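"""Gradio streaming chat demo for TigerBot-13B-chat with a ShareGPT LoRA
adapter, loaded with 4-bit NF4 quantization via bitsandbytes."""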
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
from peft import PeftModel
import torch
from threading import Thread

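# Hugging Face Hub repos: the base chat model and a ShareGPT LoRA adapter.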
model_path = 'TigerResearch/tigerbot-13b-chat'
model_revision = None

lora_path = 'larryvrh/tigerbot-13b-chat-sharegpt-lora'

tokenizer = AutoTokenizer.from_pretrained(model_path)

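# 4-bit NF4 quantization with nested (double) quantization; activations are
# computed in fp16. Presumably this is what lets the 13B model fit on a
# single GPU.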
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    revision=model_revision,
    device_map="auto",
    quantization_config=quant_config,  # alternatively: load_in_8bit=True
)

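# Attach the LoRA adapter on top of the quantized base model, then switch
# to inference mode.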
model = PeftModel.from_pretrained(model, lora_path)
model.eval()

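# Streaming chat handler. Appends the new turn to the history, rebuilds the
# full prompt, and yields the growing history as tokens stream in. When
# `retry` is set, the last exchange is re-generated from its original query.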
def predict(query, chatbot, max_length, top_p, temperature, rep_penalty, retry):
    if retry and len(chatbot) == 0:
        yield []
        return
    elif retry:
        query = chatbot[-1][0]
        chatbot = chatbot[:-1]

    chatbot.append((query, ""))
    
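    # Rebuild the whole conversation in TigerBot's instruction format:
    # one "### Instruction: / ### Response:" pair per turn.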
    prompt = '<s>' + ''.join([f'\n\n### Instruction:\n{r[0]}\n\n### Response:\n{r[1]}' for r in chatbot])
    print('prompt:', repr(prompt))
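    # Reserve 500 tokens for the reply; the guard keeps the truncation
    # budget positive even for very small context-length settings.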
    model_inputs = tokenizer([prompt], return_tensors="pt", truncation=True, max_length=max(max_length - 500, 1)).to('cuda')
    
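    # The streamer yields decoded text as generate() produces it; the prompt
    # and special tokens are skipped so only the reply is surfaced.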
    streamer = TextIteratorStreamer(tokenizer, timeout=15.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=500,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=rep_penalty,
    )
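    # Run generation on a background thread so this generator can consume
    # the streamer and yield partial output to the UI immediately.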
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    
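    # Append each streamed chunk to the pending assistant message.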
    for response in streamer:
        chatbot[-1] = (chatbot[-1][0], chatbot[-1][1] + response)
        yield chatbot


def reset_user_input():
    return gr.update(value='')


def reset_state():
    return []

css = '''
    .contain {max-width: 50%}  /* unit assumed: the original rule was unitless and would be ignored */

    #chatbot {min-height: 500px}
'''

with gr.Blocks(css=css) as demo:
    gr.HTML('<h1 align="center">TigerBot</h1>')

    chatbot = gr.Chatbot(elem_id='chatbot')
    with gr.Column():
        user_input = gr.Textbox(show_label=False, placeholder="Input", lines=1).style(container=False)
        with gr.Row():
            submitBtn = gr.Button("Send", variant="primary")
            retryBtn = gr.Button("Retry")
            cancelBtn = gr.Button("Undo")
            emptyBtn = gr.Button("Clear")
        with gr.Row():
            max_length = gr.Slider(0, 4096, value=2048, step=1, label="Context Length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top-P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.5, step=0.01, label="Temperature", interactive=True)
            rep_penalty = gr.Slider(1.0, 1.5, value=1.15, step=0.01, label='Repetition Penalty', interactive=True)


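    # Event wiring: submit and retry share the same handler; the trailing
    # gr.State(False/True) toggles retry mode.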
    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, rep_penalty, gr.State(False)],
                    [chatbot], show_progress=False)
    submitBtn.click(reset_user_input, [], [user_input], show_progress=False)
    
    retryBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, rep_penalty, gr.State(True)],
                    [chatbot], show_progress=False)
    
    cancelBtn.click(lambda m: m[:-1], [chatbot], [chatbot], show_progress=False)

    emptyBtn.click(reset_state, outputs=[chatbot], show_progress=False)

demo.queue().launch(share=False, inbrowser=True)