File size: 5,697 Bytes
e10040f
 
 
ee5b6dd
dec5480
6d32964
4299336
651f3a2
e10040f
a3cff49
 
e10040f
c4fc3fe
e10040f
18d32fd
e10040f
 
 
 
 
 
 
 
18d32fd
e10040f
 
 
 
 
 
 
 
 
ca7faf3
5736661
ca7faf3
18d32fd
4299336
e10040f
ee5b6dd
e10040f
5736661
 
 
 
 
55b91e5
 
 
e10040f
 
55b91e5
e10040f
 
 
ca7faf3
55b91e5
e10040f
 
 
4299336
e10040f
4299336
e10040f
 
 
4299336
e10040f
 
ca7faf3
 
 
e10040f
18d32fd
 
 
4299336
ade11b4
4299336
 
 
ee5b6dd
 
 
e10040f
ca7faf3
 
 
 
 
 
 
 
 
ad72fd3
 
 
 
 
ee5b6dd
ad72fd3
 
 
ca7faf3
 
0071153
ca7faf3
 
 
e10040f
 
 
ca7faf3
e10040f
ca7faf3
e10040f
 
 
ca7faf3
 
 
 
 
 
 
 
 
18d32fd
ca7faf3
 
 
 
 
e10040f
 
 
 
 
 
 
651f3a2
e10040f
ee5b6dd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from globe import title, description, customtool, presentation1, presentation2, joinus
import spaces

# Tokenizer, model and pipeline are created once at import time and shared by
# every call to generate_response.
model_path = "nvidia/Mistral-NeMo-Minitron-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# NOTE(review): no device/dtype arguments (e.g. torch_dtype, device_map) are
# passed and the model is not moved with .to(...) here; generate_response sends
# inputs to model.device. Confirm this placement is intended for the
# @spaces.GPU deployment.
model = AutoModelForCausalLM.from_pretrained(model_path)
# Fall back to EOS as the pad token when the tokenizer defines none, because
# generate_response tokenizes with padding=True.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Text-generation pipeline over the same model/tokenizer; used by
# generate_response when use_pipeline is True.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

def create_prompt(system_message, user_message, tool_definition="", context=""):
    """Build a single-turn prompt in the <extra_id_*> chat format.

    Args:
        system_message: Text for the System turn.
        user_message: Text for the User turn.
        tool_definition: Optional tool schema. When truthy, a ``<tool>`` section
            and a ``<context>`` section are inserted after the system text.
        context: Text for the ``<context>`` section; ignored when
            ``tool_definition`` is falsy.

    Returns:
        The prompt string, ending with the open Assistant turn so the model
        continues from there.
    """
    system_turn = f"<extra_id_0>System\n{system_message}\n\n"

    tool_section = ""
    if tool_definition:
        tool_section = (
            f"<tool>\n{tool_definition}\n</tool>\n"
            f"<context>\n{context}\n</context>\n\n"
        )

    open_turns = f"<extra_id_1>User\n{user_message}\n<extra_id_1>Assistant\n"
    return system_turn + tool_section + open_turns


@spaces.GPU(duration=94)
def generate_response(message, history, system_message, max_tokens, temperature, top_p, do_sample, use_pipeline=False, tool_definition="", context=""):
    """Generate one assistant reply for ``message``.

    Args:
        message: The latest user message.
        history: Chat history. Accepted for interface compatibility but not used;
            only ``message`` is placed in the prompt.
        system_message: Text for the prompt's System turn.
        max_tokens: Maximum number of new tokens to generate (coerced to int).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        do_sample: Sample when truthy, decode greedily otherwise. Honoured by
            both the pipeline path and the direct ``model.generate`` path.
        use_pipeline: Use the module-level ``pipe`` instead of ``model.generate``.
        tool_definition: Optional tool schema; when non-empty the tool-aware
            prompt format is used.
        context: Text for the prompt's ``<context>`` section (tool format only).

    Returns:
        The assistant reply text. When a tool is defined and the reply contains
        a ``<toolcall>`` tag, a note describing the simulated tool call is appended.
    """
    full_prompt = create_prompt(system_message, message, tool_definition, context)
    # UI sliders may deliver floats; generation lengths must be integers.
    max_tokens = int(max_tokens)
    # Normalise to a real bool so generate() never receives an arbitrary object.
    do_sample = bool(do_sample)

    if use_pipeline:
        response = pipe(
            full_prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
        )[0]['generated_text']
    else:
        max_model_length = getattr(model.config, 'max_position_embeddings', None) or 8192
        # Reserve room in the context window for the tokens about to be generated;
        # never let the prompt budget drop to zero or below.
        max_length = max(1, max_model_length - max_tokens)

        inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
        input_ids = inputs['input_ids'].to(model.device)
        attention_mask = inputs['attention_mask'].to(model.device)

        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                attention_mask=attention_mask,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Decode only the newly generated tokens. Decoding the full sequence with
        # skip_special_tokens=True can drop the <extra_id_*> role markers (if the
        # tokenizer treats them as special), which would leave the echoed prompt
        # in the reply because the marker split below would then find nothing.
        new_tokens = output_ids[0, input_ids.shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Keep only the text after the last Assistant marker (the pipeline output
    # includes the prompt); when the marker is absent this keeps everything.
    assistant_response = response.split("<extra_id_1>Assistant\n")[-1].strip()

    if tool_definition and "<toolcall>" in assistant_response:
        tool_call = assistant_response.split("<toolcall>")[1].split("</toolcall>")[0]
        assistant_response += f"\n\nTool Call: {tool_call}\n\nNote: This is a simulated tool call. In a real scenario, the tool would be executed and its output would be used to generate a final response."

    return assistant_response

def update_advanced_settings(show_advanced):
    """Return a Gradio update that shows or hides the advanced-settings panel.

    ``gr.update`` is required here: a plain ``{"visible": ...}`` dict lacks the
    update marker Gradio looks for and is treated as a component value instead,
    so the panel's visibility would never change.

    Args:
        show_advanced: Current value of the "Advanced Settings" checkbox.
    """
    return gr.update(visible=show_advanced)

# Gradio UI: title/description rows, two info panels, an input column (with a
# collapsible advanced-settings panel) and a chat display column. Component
# construction order inside each context manager determines the layout.
with gr.Blocks() as demo:
    with gr.Row(): 
        gr.Markdown(title)
    with gr.Row():
        gr.Markdown(description)
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown(presentation1)
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown(joinus)
    with gr.Row():
        # Left column: prompt inputs, advanced settings and the generate button.
        with gr.Column(scale=2):
            # NOTE(review): labelled "Context", but this value is passed to
            # generate_response as the system message by the bot handler.
            system_prompt = gr.TextArea(label="📑Context", placeholder="add context here...", lines=5)
            user_input = gr.TextArea(label="🤷🏻‍♂️User Input", placeholder="Hi there my name is Tonic!", lines=2)
            advanced_checkbox = gr.Checkbox(label="🧪 Advanced Settings", value=False)
            # Hidden until the "Advanced Settings" checkbox is ticked.
            with gr.Column(visible=False) as advanced_settings:
                # Despite its name, this value is used as max_new_tokens.
                max_length = gr.Slider(label="📏Max Length", minimum=12, maximum=1700, value=650, step=1)
                temperature = gr.Slider(label="🌡️Temperature", minimum=0.01, maximum=1.0, value=0.7, step=0.01)
                top_p = gr.Slider(label="⚛️Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
                # NOTE(review): a do_sample toggle is disabled; the bot handler
                # supplies the do_sample argument itself.
                # do_sample = gr.Checkbox(label="Do Sample", value=True)
                use_pipeline = gr.Checkbox(label="Use Pipeline", value=False)
                use_tool = gr.Checkbox(label="Use Function Calling", value=False)
                # Tool schema editor; hidden until "Use Function Calling" is checked.
                with gr.Column(visible=False) as tool_options:
                    tool_definition = gr.Code(
                        label="Tool Definition (JSON)",
                        value=customtool,
                        lines=15,
                        language="json"
                    )
            
            generate_button = gr.Button(value="🤖Mistral-NeMo-Minitron")

        # Right column: conversation display.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="🤖Mistral-NeMo-Minitron")
    
    def user(user_message, history):
        """Append the submitted message to the chat and clear the input box.

        Returns ``("", new_history)``. The empty string resets ``user_input``.
        ``new_history`` is a new list equal to ``history`` plus a
        ``[user_message, None]`` pair, whose ``None`` reply slot ``bot`` fills in.
        """
        # Treat a missing (None) chat history as empty.
        return "", (history or []) + [[user_message, None]]

    def bot(history, system_prompt, max_length, temperature, top_p, advanced_settings, use_pipeline, tool_definition, use_tool=True):
        """Generate the assistant reply for the last user turn in ``history``.

        Args:
            history: Chat history whose last entry is ``[user_message, None]``.
            system_prompt: Text used as the system message.
            max_length: Maximum number of new tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling probability mass.
            advanced_settings: Value of the "Advanced Settings" checkbox; it only
                controls panel visibility and does not affect generation.
            use_pipeline: Generate via the transformers pipeline.
            tool_definition: Tool schema from the code editor.
            use_tool: Value of the "Use Function Calling" checkbox. The tool
                schema is sent only when this is truthy (defaults to True for
                callers that do not pass it).

        Returns:
            ``history`` with the last reply slot filled in.
        """
        user_message = history[-1][0]
        # Without function calling, send no tool schema so create_prompt uses the
        # plain prompt format instead of always injecting the default tool.
        active_tool = tool_definition if use_tool else ""
        # Sampling is always enabled (matching the previous effective behaviour,
        # where the truthy Checkbox component itself was passed as do_sample).
        bot_message = generate_response(
            user_message,
            history,
            system_prompt,
            max_length,
            temperature,
            top_p,
            True,
            use_pipeline,
            active_tool,
        )
        history[-1][1] = bot_message
        return history

    generate_button.click(
        user,
        [user_input, chatbot],
        [user_input, chatbot],
        queue=False
    ).then(
        bot,
        [chatbot, system_prompt, max_length, temperature, top_p, advanced_checkbox, use_pipeline, tool_definition, use_tool],
        chatbot
    )

    advanced_checkbox.change(
        fn=update_advanced_settings,
        inputs=[advanced_checkbox],
        outputs=advanced_settings
    )

    use_tool.change(
        fn=lambda x: gr.update(visible=x),
        inputs=[use_tool],
        outputs=[tool_options]
    )

if __name__ == "__main__":
    # Turn on Gradio's request queue, then start the web server;
    # Blocks.queue() returns the same Blocks instance, so the calls chain.
    demo.queue().launch()