"""Gradio chat demo that streams replies from the vilm/vinallama-2.7b model."""

from threading import Thread

import gradio as gr
import torch
from transformers import (
    pipeline,
    AutoTokenizer,
    TextIteratorStreamer,
)


def chat_history(history, tokenizer=None) -> str:
    """Render a Gradio chat history into a prompt string for generation.

    Args:
        history: Chatbot value in pair format: a list of
            ``[user_message, assistant_message]`` entries. The turn being
            answered has ``assistant_message`` set to ``None`` (see ``user``
            in ``launch_app``).
        tokenizer: Tokenizer whose chat template formats the messages.
            Defaults to ``pipe.tokenizer`` of the module-level ``pipe``
            created in the ``__main__`` block, for backward compatibility.

    Returns:
        The templated prompt text (not token ids), with the generation
        prompt appended so the model continues as the assistant.
    """
    if tokenizer is None:
        # Legacy behavior: rely on the global pipeline created in __main__.
        tokenizer = pipe.tokenizer

    messages = []
    for dialog in history:
        for i, message in enumerate(dialog):
            # Skip the pending (None) reply rather than blindly popping the
            # last message, which raised IndexError on an empty history.
            if message is None:
                continue
            role = "user" if i % 2 == 0 else "assistant"
            messages.append({"role": role, "content": message})

    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )


def model_loading_pipeline(model_id="vilm/vinallama-2.7b", timeout=5):
    """Load the text-generation pipeline together with a token streamer.

    Args:
        model_id: Hugging Face model id to load.
        timeout: Seconds the streamer waits for the next chunk before
            raising, so a failed generation thread cannot hang the UI.

    Returns:
        ``(pipe, streamer)``: the bfloat16 text-generation pipeline,
        configured to write its output into ``streamer``, and the streamer
        itself for iteration.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # The keyword is lowercase ``timeout``. The former ``Timeout=5`` was
    # swallowed into the decode kwargs, leaving the streamer with no timeout.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=timeout)
    pipe = pipeline(
        "text-generation",
        model=model_id,
        # Reuse the tokenizer loaded above instead of loading a second copy.
        tokenizer=tokenizer,
        model_kwargs={
            "torch_dtype": torch.bfloat16,
        },
        streamer=streamer,
    )
    return pipe, streamer


def launch_app(pipe, streamer):
    """Build and launch the Gradio chat UI.

    Args:
        pipe: Text-generation pipeline; calling it runs generation.
        streamer: The streamer ``pipe`` writes to; iterated to stream text
            into the chat window.
    """
    with gr.Blocks() as demo:
        chat = gr.Chatbot()
        msg = gr.Textbox()
        clear = gr.Button("Clear")

        def user(user_message, history):
            # Clear the textbox and append the new turn with a pending reply.
            return "", history + [[user_message, None]]

        def bot(history):
            prompt = chat_history(history, pipe.tokenizer)
            history[-1][1] = ""
            kwargs = {
                "text_inputs": prompt,
                "max_new_tokens": 64,
                "do_sample": True,
                "temperature": 0.7,
                "top_k": 50,
                "top_p": 0.95,
            }
            # Generate in a background thread; the streamer yields decoded
            # text as it is produced so the UI can update incrementally.
            # NOTE(review): one streamer is shared by every request, so this
            # assumes the queue runs a single ``bot`` event at a time --
            # confirm for the installed Gradio version.
            thread = Thread(target=pipe, kwargs=kwargs)
            thread.start()
            for token in streamer:
                history[-1][1] += token
                yield history
            # The streamer is exhausted, so generation has finished; reap it.
            thread.join()

        msg.submit(user, [msg, chat], [msg, chat], queue=False).then(bot, chat, chat)
        clear.click(lambda: None, None, chat, queue=False)

    demo.queue()
    demo.launch(share=True, debug=True)


if __name__ == "__main__":
    pipe, streamer = model_loading_pipeline()
    launch_app(pipe, streamer)