"""Gradio chat demo: TigerBot-13B-chat + ShareGPT LoRA adapter, loaded in 4-bit.

Importing/running this module loads the tokenizer and the quantized model,
builds the chat UI and launches the Gradio server.
"""

from threading import Thread

import gradio as gr
import torch
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

# (repository id, revision); a revision of None is passed through unchanged.
model_path = ('TigerResearch/tigerbot-13b-chat', None)
lora_path = 'larryvrh/tigerbot-13b-chat-sharegpt-lora'

# Tokens reserved for the reply; the prompt gets whatever is left of the
# context length chosen in the UI.
MAX_NEW_TOKENS = 500

tokenizer = AutoTokenizer.from_pretrained(model_path[0])
# When the history is too long, drop the OLDEST tokens. Right-side truncation
# (the library default) would cut the newest instruction and the trailing
# "### Response:" cue that the model is supposed to continue from.
tokenizer.truncation_side = 'left'

# 4-bit NF4 weights with double quantization, computing in float16.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# An earlier variant of this script used load_in_8bit=True instead of
# quantization_config.
model = AutoModelForCausalLM.from_pretrained(
    model_path[0],
    revision=model_path[1],
    device_map="auto",
    quantization_config=quant_config,
)
model = PeftModel.from_pretrained(model, lora_path)
model.eval()


def _build_prompt(history):
    """Render (instruction, response) pairs into the Instruction/Response prompt."""
    return ''.join(
        f'\n\n### Instruction:\n{turn[0]}\n\n### Response:\n{turn[1]}'
        for turn in history
    )


def predict(input, chatbot, max_length, top_p, temperature, rep_penalty, retry):
    """Stream a model reply into the chat history.

    This is a generator used as a Gradio event handler: every yielded value
    is the full updated history for the Chatbot component.

    Args:
        input: The user's new message (ignored when ``retry`` is true).
        chatbot: Current history as a sequence of (user, bot) pairs.
        max_length: Total context budget in tokens; ``MAX_NEW_TOKENS`` of it
            is reserved for the reply and the rest is the prompt limit.
        top_p: Nucleus sampling threshold.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        rep_penalty: Repetition penalty passed to ``generate``.
        retry: When true, discard the last exchange and re-answer its
            user message instead of using ``input``.

    Yields:
        The history list, with the last entry's reply growing as text
        arrives from the streamer.
    """
    # Work on a copy and tolerate an unset (None) history.
    chatbot = list(chatbot or [])

    if retry:
        if not chatbot:
            # Nothing to retry.
            yield []
            return
        input = chatbot[-1][0]
        chatbot = chatbot[:-1]

    chatbot.append((input, ""))
    # Show the user's message right away, before any token is generated.
    yield chatbot

    prompt = _build_prompt(chatbot)
    print('prompt:', repr(prompt))

    # The context slider starts at 0, so max_length - MAX_NEW_TOKENS can be
    # zero or negative; never hand the tokenizer a non-positive limit.
    prompt_budget = max(1, int(max_length) - MAX_NEW_TOKENS)
    model_inputs = tokenizer(
        [prompt],
        return_tensors="pt",
        truncation=True,
        max_length=prompt_budget,
    ).to(model.device)  # follow device_map="auto" instead of assuming 'cuda'

    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=15.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # Slider values may arrive as int when they sit on an integral position
    # (e.g. repetition penalty 1.0 -> 1), and some transformers versions
    # validate these arguments with isinstance(..., float). Cast explicitly.
    temperature = float(temperature)
    # A temperature of 0 is invalid for sampling; treat it as greedy decoding.
    use_sampling = temperature > 0

    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=use_sampling,
        repetition_penalty=float(rep_penalty),
    )
    if use_sampling:
        generate_kwargs.update(top_p=float(top_p), temperature=temperature)

    # generate() blocks until done, so it runs in a worker thread while this
    # generator consumes the streamer.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    reply = ""
    for piece in streamer:
        reply += piece
        chatbot[-1] = (input, reply)
        yield chatbot

    # The streamer is exhausted, so generate() is finishing; reap the thread.
    worker.join()


def reset_user_input():
    """Return a Gradio update that clears the input textbox."""
    return gr.update(value='')


def reset_state():
    """Return an empty chat history."""
    return []


# NOTE(review): "max-width:50" has no unit, so browsers will likely ignore
# the rule; confirm the intended value (e.g. 50% or a pixel width).
css = '''
.contain {max-width:50}
#chatbot {min-height:500px}
'''

with gr.Blocks(css=css) as demo:
    # NOTE(review): the title looks like it lost its HTML markup (probably a
    # heading tag) somewhere upstream; only the text is kept here. Confirm
    # against the original file.
    gr.HTML('\n\nTigerBot\n\n')
    chatbot = gr.Chatbot(elem_id='chatbot')

    with gr.Column():
        user_input = gr.Textbox(
            show_label=False,
            placeholder="输入",
            lines=1,
        ).style(container=False)
        with gr.Row():
            submitBtn = gr.Button("发送", variant="primary")
            retryBtn = gr.Button("重试")
            cancelBtn = gr.Button('撤销')
            emptyBtn = gr.Button("清空")
        with gr.Row():
            max_length = gr.Slider(0, 4096, value=2048, step=1,
                                   label="Context Length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.7, step=0.01,
                              label="Top-P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.5, step=0.01,
                                    label="Temperature", interactive=True)
            rep_penalty = gr.Slider(1.0, 1.5, value=1.1, step=0.01,
                                    label='Repetition Penalty', interactive=True)

    # Submit: stream the answer, and clear the textbox in a second handler.
    submitBtn.click(
        predict,
        [user_input, chatbot, max_length, top_p, temperature, rep_penalty,
         gr.State(False)],
        [chatbot],
        show_progress=False,
    )
    submitBtn.click(reset_user_input, [], [user_input], show_progress=False)

    # Retry: same handler with retry=True, which re-answers the last message.
    retryBtn.click(
        predict,
        [user_input, chatbot, max_length, top_p, temperature, rep_penalty,
         gr.State(True)],
        [chatbot],
        show_progress=False,
    )

    # Undo: drop the last exchange from the history.
    cancelBtn.click(lambda m: m[:-1], [chatbot], [chatbot], show_progress=False)
    # Clear: reset the history to empty.
    emptyBtn.click(reset_state, outputs=[chatbot], show_progress=False)

demo.queue().launch(share=False, inbrowser=True)