import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Hugging Face token read from the environment (needed for gated models such as Mistral)
hf_token = os.environ.get("HF_TOKEN")

# Model to load
model_id = "mistralai/Mistral-7B-Instruct-v0.2"


# Optional: 4-bit quantization to reduce GPU memory usage. Uncomment this block and
# pass `quantization_config=quantization_config` to `from_pretrained` below to enable it.
# quantization_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.float16
# )

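# Load the model and tokenizer; device_map="auto" places the weights on the available GPU(s)/CPU.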
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", token=hf_token)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)

def format_prompt(message, history):
    """Build a Mistral-Instruct prompt from the chat history plus the new user message."""
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

def generate(
    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
    )

    # Build the full prompt (history + new message) in Mistral's [INST] format.
    # The prompt string already starts with "<s>", so skip the tokenizer's own special tokens.
    formatted_prompt = format_prompt(prompt, history)
    inputs = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = inputs.input_ids.to(model.device)

    output_ids = model.generate(input_ids, **generate_kwargs)

    # Decode only the newly generated tokens, not the prompt that was fed in.
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


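# Extra controls shown in the ChatInterface "Additional Inputs" panel; their values are
# passed to generate() after the message and history arguments.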
additional_inputs=[
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum number of new tokens to generate",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]

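# Custom CSS passed to gr.Blocks below; the #mkd selector applies to any component
# created with elem_id="mkd".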
css = """
  #mkd {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""

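# Assemble the UI: header text plus a ChatInterface wired to the generate() function.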
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1><center>Mistral 7B Instruct<h1><center>")
    gr.HTML("<h3><center>In this demo, you can chat with <a href='https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1'>Mistral-7B-Instruct</a> model. πŸ’¬<h3><center>")
    gr.HTML("<h3><center>Learn more about the model <a href='https://huggingface.co/docs/transformers/main/model_doc/mistral'>here</a>. πŸ“š<h3><center>")
    gr.ChatInterface(
        generate,
        additional_inputs=additional_inputs,
        examples=[["What is the secret to life?"], ["Write me a recipe for pancakes."]]
    )

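# Queue incoming requests and launch the app. Note: concurrency_count is a Gradio 3.x
# queue() parameter (replaced by default_concurrency_limit in Gradio 4.x).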
demo.queue(concurrency_count=75, max_size=100).launch(debug=True)