vericudebuget committed
Commit 0318067
1 Parent(s): 6b4ef39

Update app.py

Files changed (1)
  1. app.py +79 -42
app.py CHANGED
@@ -1,58 +1,95 @@
 from huggingface_hub import InferenceClient
+
 import gradio as gr
+
 import datetime

 # Initialize the InferenceClient
+
 client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

 def format_prompt(message, history):
-    prompt = "<s>"
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
+
+    prompt = "<s>"
+
+    for user_prompt, bot_response in history:
+
+        prompt += f"[INST] {user_prompt} [/INST]"
+
+        prompt += f" {bot_response}</s> "
+
+    prompt += f"[INST] {message} [/INST]"
+
+    return prompt

 def generate(prompt, history, system_prompt, temperature=0.9, max_new_tokens=9048, top_p=0.95, repetition_penalty=1.0):
-    temperature = max(float(temperature), 1e-2)
-    top_p = float(top_p)
-    generate_kwargs = dict(
-        temperature=temperature,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        do_sample=True,
-        seed=42,
-    )
-
-    # Get current time
-    now = datetime.datetime.now()
-    formatted_time = now.strftime("%H.%M.%S, %B, %Y")
-    system_prompt = f"server log: ~This message was sent at {formatted_time}. The actual year is 2024.~"
-
-    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
-    for response in stream:
-        output += response.token.text
-        yield (prompt, output)
+
+    temperature = max(float(temperature), 1e-2)
+
+    top_p = float(top_p)
+
+    generate_kwargs = dict(
+
+        temperature=temperature,
+
+        max_new_tokens=max_new_tokens,
+
+        top_p=top_p,
+
+        repetition_penalty=repetition_penalty,
+
+        do_sample=True,
+
+        seed=42,
+
+    )
+
+    # Get current time
+
+    now = datetime.datetime.now()
+
+    formatted_time = now.strftime("%H.%M.%S, %B, %Y")
+
+    system_prompt = f"server log: This message was sent at {formatted_time}. The actual year is 2024."
+
+    formatted_prompt = format_prompt(f"{system_prompt}, {prompt}", history)
+
+    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+
+    output = ""
+
+    for response in stream:
+
+        output += response.token.text
+
+        yield output
+
+    return output

 additional_inputs = [
-    gr.Textbox(label="System Prompt", max_lines=1, interactive=True),
-    gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
-    gr.Slider(label="Max new tokens", value=9048, minimum=256, maximum=9048, step=64, interactive=True, info="The maximum numbers of new tokens"),
-    gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
-    gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens")
+
+    gr.Textbox(label="System Prompt", max_lines=1, interactive=True),
+
+    gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
+
+    gr.Slider(label="Max new tokens", value=9048, minimum=256, maximum=9048, step=64, interactive=True, info="The maximum numbers of new tokens"),
+
+    gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
+
+    gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens")
+
 ]

-app = gr.Blocks(theme=gr.themes.Soft())
-with app:
-    chatbot = gr.Chatbot()
-    text_input = gr.Textbox(label="Your message")
-
-    def process_message(message, history):
-        for response in generate(message, history, additional_inputs[0].value):
-            yield response
-
-    text_input.submit(process_message, inputs=[text_input, chatbot], outputs=[chatbot, text_input])
-app.launch(show_api=False)
+gr.ChatInterface(
+
+    fn=generate,
+
+    chatbot=gr.Chatbot(show_label=True, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
+
+    additional_inputs=additional_inputs,
+
+    title="ConvoLite",
+
+    concurrency_limit=20,
+
+).launch(show_api=False,)
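
For reference, format_prompt (unchanged in behavior by this commit, only re-spaced) builds the Mixtral-8x7B-Instruct prompt template. A quick worked example, using a made-up one-turn history; the function body is copied verbatim from the diff above:

def format_prompt(message, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

history = [("Hello", "Hey! How can I help?")]
print(format_prompt("What year is it?", history))
# <s>[INST] Hello [/INST] Hey! How can I help?</s> [INST] What year is it? [/INST]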
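The main behavioral change inside generate is the yield: the old gr.Blocks wiring yielded (prompt, output) tuples into a Chatbot/Textbox pair, while gr.ChatInterface expects the chat function to yield a growing string that replaces the bot's reply on each step. A minimal sketch of that contract, with a hypothetical fake_stream standing in for client.text_generation(..., stream=True):

import time
import gradio as gr

# Stand-in for the token stream returned by the InferenceClient.
def fake_stream():
    for token in ["Str", "eam", "ing", " works", "!"]:
        time.sleep(0.05)  # simulate per-token network latency
        yield token

# Chat fn as a generator: each yielded string is the full reply so far,
# so the displayed message grows token by token.
def demo_generate(message, history):
    output = ""
    for token in fake_stream():
        output += token
        yield output

if __name__ == "__main__":
    gr.ChatInterface(fn=demo_generate).launch()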