import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
from peft import PeftModel
import torch
from threading import Thread
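
# Streaming chat demo: TigerBot-13B loaded in 4-bit NF4, with a
# ShareGPT-tuned LoRA adapter applied on top.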
model_path = ('TigerResearch/tigerbot-13b-chat', None)  # (repo id, optional revision)
lora_path = 'larryvrh/tigerbot-13b-chat-sharegpt-lora'

tokenizer = AutoTokenizer.from_pretrained(model_path[0])
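
# 4-bit NF4 quantization with nested (double) quantization; compute in fp16.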
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
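
# Load the quantized base model, then attach the LoRA adapter for inference.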
model = AutoModelForCausalLM.from_pretrained(
    model_path[0],
    revision=model_path[1],
    device_map="auto",
    quantization_config=quant_config,
)
model = PeftModel.from_pretrained(model, lora_path)
model.eval()
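

# predict() is a generator: it rebuilds the prompt from the chat history,
# launches generation on a background thread, and yields partial responses
# as the streamer produces them.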
def predict(query, chatbot, max_length, top_p, temperature, rep_penalty, retry):
    # On retry, drop the last exchange and regenerate from its question.
    if retry and len(chatbot) == 0:
        yield []
        return
    elif retry:
        query = chatbot[-1][0]
        chatbot = chatbot[:-1]

    chatbot.append((query, ""))

    # Flatten the history into TigerBot's instruction/response prompt format.
    prompt = '<s>' + ''.join([f'\n\n### Instruction:\n{r[0]}\n\n### Response:\n{r[1]}' for r in chatbot])
    print('prompt:', repr(prompt))
    # Reserve room for the 500 new tokens within the requested context length;
    # the max(1, ...) guard keeps the truncation length valid for small sliders.
    model_inputs = tokenizer([prompt], return_tensors="pt", truncation=True,
                             max_length=max(1, max_length - 500)).to('cuda')

    # Generate on a background thread and stream tokens back as they arrive.
    streamer = TextIteratorStreamer(tokenizer, timeout=15.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=500,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=rep_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    for response in streamer:
        chatbot[-1] = (chatbot[-1][0], chatbot[-1][1] + response)
        yield chatbot
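

# Callbacks for clearing the input textbox and the chat history.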
def reset_user_input():
    return gr.update(value='')


def reset_state():
    return []
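

# Page styling: cap the app width and give the chat pane a minimum height.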
css = '''
.contain {max-width: 50%}  /* "%" assumed here; a unit-less "50" is ignored by CSS */
#chatbot {min-height: 500px}
'''
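
# Build the UI: chat pane, input box, action buttons, and sampling sliders.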
with gr.Blocks(css=css) as demo:
    gr.HTML('<h1 align="center">TigerBot</h1>')

    chatbot = gr.Chatbot(elem_id='chatbot')
    with gr.Column():
        user_input = gr.Textbox(show_label=False, placeholder="Input", lines=1).style(container=False)
        with gr.Row():
            submitBtn = gr.Button("Send", variant="primary")
            retryBtn = gr.Button("Retry")
            cancelBtn = gr.Button('Undo')
            emptyBtn = gr.Button("Clear")
        with gr.Row():
            max_length = gr.Slider(0, 4096, value=2048, step=1, label="Context Length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top-P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.5, step=0.01, label="Temperature", interactive=True)
            rep_penalty = gr.Slider(1.0, 1.5, value=1.15, step=0.01, label='Repetition Penalty', interactive=True)
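
    # Wire up the buttons. The gr.State flag tells predict() whether this is a
    # fresh send (False) or a retry of the last question (True).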
    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, rep_penalty, gr.State(False)],
                    [chatbot], show_progress=False)
    submitBtn.click(reset_user_input, [], [user_input], show_progress=False)

    retryBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, rep_penalty, gr.State(True)],
                   [chatbot], show_progress=False)

    cancelBtn.click(lambda m: m[:-1], [chatbot], [chatbot], show_progress=False)

    emptyBtn.click(reset_state, outputs=[chatbot], show_progress=False)

demo.queue().launch(share=False, inbrowser=True)