alfonsovelp committed on
Commit c2b4deb
1 Parent(s): f54e09b

Update app.py

Files changed (1)
app.py +109 -1
app.py CHANGED
@@ -1,3 +1,111 @@
 import gradio as gr
 
-gr.load("models/mistralai/Mixtral-8x7B-Instruct-v0.1").launch()
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from huggingface_hub import InferenceClient
+import os
+import torch
+
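+# Read the Hub access token from the environment and load the model and tokenizer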
+hf_token = os.environ.get("HF_TOKEN")
+model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", token=hf_token)
+tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
+
+
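+# Build a Mixtral instruct-style prompt ([INST] ... [/INST]) from the chat history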
+def format_prompt(message, history):
+    prompt = "<s>"
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
+    return prompt
+
+
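+# Chat handler for gr.ChatInterface: generates a single assistant reply per call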
+def generate(
+    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
+):
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
+
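+    # Sampling settings forwarded to model.generate()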
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+    )
+    formatted_prompt = format_prompt(prompt, history)
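+    # Note: formatted_prompt is not used below; only the latest message is passed to the chat template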
+
+    messages = [
+        {"role": "user", "content": prompt}
+    ]
+
+    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
+    output_ids = model.generate(inputs, **generate_kwargs)
+
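+    # Decode only the newly generated tokens so the prompt is not echoed back to the chat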
+    decoded = tokenizer.batch_decode(output_ids[:, inputs.shape[-1]:], skip_special_tokens=True)
+    print(decoded[0])
+    return decoded[0]
+
+
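+# Extra sliders exposed by the ChatInterface for the sampling parameters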
+additional_inputs = [
+    gr.Slider(
+        label="Temperature",
+        value=0.9,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Higher values produce more diverse outputs",
+    ),
+    gr.Slider(
+        label="Max new tokens",
+        value=256,
+        minimum=0,
+        maximum=1048,
+        step=64,
+        interactive=True,
+        info="The maximum number of new tokens",
+    ),
+    gr.Slider(
+        label="Top-p (nucleus sampling)",
+        value=0.90,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Higher values sample more low-probability tokens",
+    ),
+    gr.Slider(
+        label="Repetition penalty",
+        value=1.2,
+        minimum=1.0,
+        maximum=2.0,
+        step=0.05,
+        interactive=True,
+        info="Penalize repeated tokens",
+    )
+]
+
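+# Page CSS passed to gr.Blocks below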
+css = """
+#mkd {
+    height: 500px;
+    overflow: auto;
+    border: 1px solid #ccc;
+}
+"""
+
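+# Assemble the UI: header text plus a ChatInterface wired to generate()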
+with gr.Blocks(css=css) as demo:
+    gr.HTML("<h1><center>Mixtral 8x7B Instruct</center></h1>")
+    gr.HTML("<h3><center>In this demo, you can chat with the <a href='https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1'>Mixtral-8x7B-Instruct</a> model. 💬</center></h3>")
+    gr.HTML("<h3><center>Learn more about the model <a href='https://huggingface.co/docs/transformers/main/model_doc/mixtral'>here</a>. 📚</center></h3>")
+    gr.ChatInterface(
+        generate,
+        additional_inputs=additional_inputs,
+        examples=[["What is the secret to life?"], ["Write me a recipe for pancakes."]]
+    )
 
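+# Enable request queueing and launch the Space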
+demo.queue(concurrency_count=75, max_size=100).launch(debug=True)