Gradio

  • Stream
  1. Gradio : transformers.TextIteratorStreamer
def generate(user_message):
    """Kick off streamed text generation for ``user_message``.

    Wraps the message in the question/answer prompt template, tokenizes it,
    and runs ``model.generate`` on a background thread with a
    ``TextIteratorStreamer`` attached. The rest of the function (consuming
    the streamer) is elided in these notes.
    """
    prompt = f"### 질문: {user_message}\n\n### 답변:"
    # Tokenize as a batch of one, return PyTorch tensors, and move them to the GPU.
    # token_type_ids are not requested from the tokenizer.
    model_inputs = tokenizer([prompt], return_tensors = 'pt', return_token_type_ids = False).to('cuda')

    # The streamer receives generated text from model.generate; skip_prompt /
    # skip_special_tokens ask it to leave the prompt and special tokens out
    # of the streamed text.
    streamer = TextIteratorStreamer(tokenizer, timeout = 100, skip_prompt = True, skip_special_tokens = True)
    # Merge the tokenized inputs with the streamer and sampling settings into
    # one kwargs mapping for model.generate.
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        temperature= 0.7,
        top_p=0.95,
        top_k=50,
        max_new_tokens=512,
        do_sample=True,
    )

    # Run generation on a separate thread so this function is not blocked by
    # model.generate (presumably it goes on to iterate over `streamer` in the
    # elided part below).
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    ...
  2. TGI : text-generation.Client.generate_stream
def generate(...):
    ...

    stream = client.generate_stream(
            prompt,
            **generate_kwargs,
        )

    output = ""
    for idx, response in enumerate(stream):
        if response.token.text == '':
            break

        if response.token.special:
            continue
        output += response.token.text
        if idx == 0:
            history.append(" " + output)
        else:
            history[-1] = output

        chat = [(history[i].strip(), history[i + 1].strip()) for i in range(0, len(history) - 1, 2)]

        yield chat, history, user_message, ""

    return chat, history, user_message, ""

TS

  • Gradio로 서비스를 띄운 후, public URL은 접속이 되었지만 local URL은 접속 x
    • port 7860 : 포트 포워딩(port forwarding) 진행 후, 재접속하였으나 접속 x
    • local ip로 접속하였으나 x
  • Solution
    • server_name="0.0.0.0"을 통해 모든 IP에서 접속 가능하게 만들어주기
# Build the UI inside a Blocks context (layout elided in these notes).
with gr.Blocks() as demo:
    ...

# server_name="0.0.0.0" makes the server listen on all network interfaces, so
# the app is reachable from other machines via the host's IP and not only
# through the public share URL.
demo.queue().launch(server_name = "0.0.0.0")

ref