Chris Alexiuk committed
Commit
d019a4b
•
1 Parent(s): ca0675b

Create app.py

Files changed (1):
  app.py (+282, -0)
app.py ADDED
@@ -0,0 +1,282 @@
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
import datetime
import os
from threading import Event, Thread
from uuid import uuid4
import gradio as gr
import requests

model_name = "decapoda-research/llama-13b-hf"
adapters_name = "timdettmers/guanaco-13b"

print(f"Starting to load the model {model_name} into memory")

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,
    torch_dtype=torch.bfloat16,
    device_map={"": 0}
)

model = PeftModel.from_pretrained(model, adapters_name)
model = model.merge_and_unload()
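# Note: this setup assumes bitsandbytes and accelerate are installed so the
# 13B base model can be loaded in 4-bit on a single GPU; merge_and_unload()
# folds the Guanaco LoRA adapter weights into the base model, so plain
# model.generate() can be used below without going through PEFT at inference.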
tokenizer = LlamaTokenizer.from_pretrained(model_name)
tokenizer.bos_token_id = 1
stop_token_ids = [0]

max_new_tokens = 2048

start_message = """A chat between a human user and a kind AI. The assistant gives helpful, cordial, and polite answers to the user's questions."""

class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_id in stop_token_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

def convert_history_to_text(history):
    text = start_message + "".join(
        [
            "".join(
                [
                    f"### Human: {item[0]}\n",
                    f"### Assistant: {item[1]}\n",
                ]
            )
            for item in history[:-1]
        ]
    )
    text += "".join(
        [
            "".join(
                [
                    f"### Human: {history[-1][0]}\n",
                    f"### Assistant: {history[-1][1]}\n",
                ]
            )
        ]
    )
    return text

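# For illustration: with history [["Hi", "Hello!"], ["How are you?", ""]] the
# prompt is the start_message followed immediately by
#   ### Human: Hi
#   ### Assistant: Hello!
#   ### Human: How are you?
#   ### Assistant:
# so the model is left to complete the final assistant turn.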

def log_conversation(conversation_id, history, messages, generate_kwargs):
    logging_url = os.getenv("LOGGING_URL", None)
    if logging_url is None:
        return

    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")

    data = {
        "conversation_id": conversation_id,
        "timestamp": timestamp,
        "history": history,
        "messages": messages,
        "generate_kwargs": generate_kwargs,
    }

    try:
        requests.post(logging_url, json=data)
    except requests.exceptions.RequestException as e:
        print(f"Error logging conversation: {e}")

def user(message, history):
    # Append the user's message to the conversation history
    return "", history + [[message, ""]]

def bot(history, temperature, top_p, top_k, repetition_penalty, conversation_id):
    print(f"history: {history}")
    # Initialize a StopOnTokens object
    stop = StopOnTokens()

    # Construct the input message string for the model by concatenating the
    # current system message and conversation history
    messages = convert_history_to_text(history)

    # Tokenize the messages string
    input_ids = tokenizer(messages, return_tensors="pt").input_ids
    input_ids = input_ids.to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0.0,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

    stream_complete = Event()

    def generate_and_signal_complete():
        model.generate(**generate_kwargs)
        stream_complete.set()

    def log_after_stream_complete():
        stream_complete.wait()
        log_conversation(
            conversation_id,
            history,
            messages,
            {
                "top_k": top_k,
                "top_p": top_p,
                "temperature": temperature,
                "repetition_penalty": repetition_penalty,
            },
        )

    t1 = Thread(target=generate_and_signal_complete)
    t1.start()

    t2 = Thread(target=log_after_stream_complete)
    t2.start()

    # Initialize an empty string to store the generated text
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        history[-1][1] = partial_text
        yield history

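# Design note: model.generate() blocks, so it runs on a worker thread while the
# bot() generator consumes the TextIteratorStreamer and yields partial history
# back to Gradio for live streaming; a second thread waits for generation to
# finish and only then logs the completed turn.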

def get_uuid():
    return str(uuid4())

with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}",
) as demo:
    conversation_id = gr.State(get_uuid)
    gr.Markdown(
        """<h1><center>Guanaco Demo</center></h1>
"""
    )
    chatbot = gr.Chatbot().style(height=500)
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Chat Message Box",
                show_label=False,
            ).style(container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")
    with gr.Row():
        with gr.Accordion("Advanced Options:", open=False):
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        temperature = gr.Slider(
                            label="Temperature",
                            value=0.7,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.1,
                            interactive=True,
                            info="Higher values produce more diverse outputs",
                        )
                with gr.Column():
                    with gr.Row():
                        top_p = gr.Slider(
                            label="Top-p (nucleus sampling)",
                            value=0.9,
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            interactive=True,
                            info=(
                                "Sample from the smallest possible set of tokens whose cumulative probability "
                                "exceeds top_p. Set to 1 to disable and sample from all tokens."
                            ),
                        )
                with gr.Column():
                    with gr.Row():
                        top_k = gr.Slider(
                            label="Top-k",
                            value=0,
                            minimum=0.0,
                            maximum=200,
                            step=1,
                            interactive=True,
                            info="Sample from a shortlist of top-k tokens; 0 to disable and sample from all tokens.",
                        )
                with gr.Column():
                    with gr.Row():
                        repetition_penalty = gr.Slider(
                            label="Repetition Penalty",
                            value=1.1,
                            minimum=1.0,
                            maximum=2.0,
                            step=0.1,
                            interactive=True,
                            info="Penalize repetition; 1.0 to disable.",
                        )
    with gr.Row():
        gr.Markdown(
            "Disclaimer: The model can produce factually incorrect output, and should not be relied on to produce "
            "factually accurate information. The model was trained on various public datasets; while great efforts "
            "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
            "biased, or otherwise offensive outputs.",
            elem_classes=["disclaimer"],
        )

    submit_event = msg.submit(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[
            chatbot,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            conversation_id,
        ],
        outputs=chatbot,
        queue=True,
    )
    submit_click_event = submit.click(
        fn=user,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[
            chatbot,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            conversation_id,
        ],
        outputs=chatbot,
        queue=True,
    )
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[submit_event, submit_click_event],
        queue=False,
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue(max_size=128, concurrency_count=2)

demo.launch()