larryvrh's picture
Update chat_webui.py
85f025c
raw
history blame
3.65 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
from peft import PeftModel
import torch
from threading import Thread
# Base checkpoint as a (hub repo id, revision) pair. The revision is forwarded
# unchanged to every from_pretrained() call below; None presumably lets the
# library fall back to its default branch -- confirm before pinning.
model_path = ('TigerResearch/tigerbot-13b-chat', None)
# LoRA adapter applied on top of the base checkpoint.
lora_path = 'larryvrh/tigerbot-13b-chat-sharegpt-lora'

_base_repo, _base_revision = model_path

# Load the tokenizer from the same revision as the model weights so a pinned
# revision cannot pair the weights with a tokenizer from another branch.
tokenizer = AutoTokenizer.from_pretrained(_base_repo, revision=_base_revision)

# 4-bit NF4 quantization with double quantization; compute runs in float16.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    _base_repo,
    revision=_base_revision,
    device_map="auto",
    quantization_config=quant_config,
)
# Wrap the quantized base model with the LoRA adapter and switch to eval mode.
model = PeftModel.from_pretrained(model, lora_path)
model.eval()
def predict(input, chatbot, max_length, top_p, temperature, rep_penalty, retry):
    """Stream a model reply for the chat UI, yielding the growing history.

    Args:
        input: Text of the new user message. Ignored when ``retry`` is true,
            in which case the last recorded instruction is reused.
        chatbot: Conversation so far, a sequence of
            ``(instruction, response)`` pairs. ``None`` is treated as empty.
        max_length: Total token budget. 500 tokens are reserved for the
            reply; the remainder caps the tokenized prompt.
        top_p: Nucleus-sampling value forwarded to ``model.generate``.
        temperature: Sampling temperature forwarded to ``model.generate``.
            A value of 0 or below switches to non-sampling decoding.
        rep_penalty: Forwarded to ``model.generate`` as
            ``repetition_penalty``.
        retry: When true, drop the last exchange and answer its instruction
            again instead of using ``input``.

    Yields:
        The history list, with the last pair's response extended by each
        streamed text chunk. A retry on an empty history yields one empty
        list and stops.
    """
    max_new_tokens = 500

    if retry:
        if not chatbot:
            yield []
            return
        input = chatbot[-1][0]
        history = list(chatbot[:-1])
    else:
        # Work on a copy so the caller's list is never mutated.
        history = list(chatbot or [])
    history.append((input, ""))

    prompt = '<s>' + ''.join([f'\n\n### Instruction:\n{r[0]}\n\n### Response:\n{r[1]}' for r in history])
    print('prompt:', repr(prompt))

    # The slider allows values below the reserved reply length; clamp so the
    # tokenizer never receives a zero or negative max_length.
    prompt_budget = max(int(max_length) - max_new_tokens, 1)
    # NOTE(review): unless the tokenizer is configured for left truncation,
    # an over-long prompt loses its tail (the newest turn) -- confirm.
    model_inputs = tokenizer(
        [prompt], return_tensors="pt", truncation=True, max_length=prompt_budget
    ).to(model.device)  # follow the model's placement instead of assuming 'cuda'

    streamer = TextIteratorStreamer(tokenizer, timeout=15.0, skip_prompt=True, skip_special_tokens=True)

    # transformers rejects a non-positive temperature when sampling, and the
    # UI slider starts at 0, so only sample when the temperature is positive.
    sampling = temperature > 0
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=sampling,
        repetition_penalty=rep_penalty,
    )
    if sampling:
        generate_kwargs.update(top_p=top_p, temperature=float(temperature))

    # Generation runs in a worker thread; this generator consumes the
    # streamer and re-yields the history after every chunk.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    for chunk in streamer:
        instruction, reply = history[-1]
        history[-1] = (instruction, reply + chunk)
        yield history
def reset_user_input():
    """Return a Gradio update that sets the bound component's value to ''."""
    cleared = gr.update(value='')
    return cleared
def reset_state():
    """Return a new, empty chat history list."""
    empty_history = list()
    return empty_history
# Stylesheet passed to gr.Blocks below.
css = """
.contain {max-width:50}
#chatbot {min-height:500px}
"""

# NOTE(review): the source lost its indentation; the Column/Row nesting below
# is reconstructed -- confirm it matches the intended layout.
with gr.Blocks(css=css) as demo:
    gr.HTML('<h1 align="center">TigerBot</h1>')
    chatbot = gr.Chatbot(elem_id='chatbot')

    with gr.Column():
        user_input = gr.Textbox(
            show_label=False,
            placeholder="输入",
            lines=1,
        ).style(container=False)

        # Action buttons: send, retry, undo last exchange, clear history.
        with gr.Row():
            submitBtn = gr.Button("发送", variant="primary")
            retryBtn = gr.Button("重试")
            cancelBtn = gr.Button('撤销')
            emptyBtn = gr.Button("清空")

        # Generation settings forwarded to predict().
        with gr.Row():
            max_length = gr.Slider(0, 4096, value=2048, step=1, label="Context Length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top-P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.5, step=0.01, label="Temperature", interactive=True)
            rep_penalty = gr.Slider(1.0, 1.5, value=1.1, step=0.01, label='Repetition Penalty', interactive=True)

    # Inputs shared by the send and retry handlers; each handler appends its
    # own constant State carrying the `retry` flag for predict().
    _predict_inputs = [user_input, chatbot, max_length, top_p, temperature, rep_penalty]

    submitBtn.click(
        predict,
        _predict_inputs + [gr.State(False)],
        [chatbot],
        show_progress=False,
    )
    # A second handler on the same button empties the text box.
    submitBtn.click(reset_user_input, [], [user_input], show_progress=False)
    retryBtn.click(
        predict,
        _predict_inputs + [gr.State(True)],
        [chatbot],
        show_progress=False,
    )
    # Undo: drop the last exchange from the displayed history.
    cancelBtn.click(lambda history: history[:-1], [chatbot], [chatbot], show_progress=False)
    emptyBtn.click(reset_state, outputs=[chatbot], show_progress=False)

demo.queue().launch(share=False, inbrowser=True)