Spaces: Running on Zero
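# Gradio chat demo for nvidia/Mistral-NeMo-Minitron-8B-Instruct, intended to run on a ZeroGPU Space.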
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from globe import title, description, customtool, presentation1, presentation2, joinus
import spaces

# Load the instruct-tuned Minitron model and its tokenizer once at startup.
model_path = "nvidia/Mistral-NeMo-Minitron-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

# If the tokenizer has no pad token, reuse the EOS token for padding.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
def create_prompt(system_message, user_message, tool_definition="", context=""):
    if tool_definition:
        return f"""<extra_id_0>System
{system_message}
<tool>
{tool_definition}
</tool>
<context>
{context}
</context>
<extra_id_1>User
{user_message}
<extra_id_1>Assistant
"""
    else:
        return f"<extra_id_0>System\n{system_message}\n\n<extra_id_1>User\n{user_message}\n<extra_id_1>Assistant\n"
# ZeroGPU: request a GPU for the duration of each generation call.
@spaces.GPU
def generate_response(message, history, system_message, max_tokens, temperature, top_p, do_sample, use_pipeline=False, tool_definition="", context=""):
    full_prompt = create_prompt(system_message, message, tool_definition, context)
    if use_pipeline:
        response = pipe(full_prompt, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p, do_sample=do_sample)[0]['generated_text']
    else:
        # Truncate the prompt so that prompt plus generated tokens fit in the model's context window.
        max_model_length = model.config.max_position_embeddings if hasattr(model.config, 'max_position_embeddings') else 8192
        max_length = max_model_length - max_tokens
        inputs = tokenizer(full_prompt, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
        input_ids = inputs['input_ids'].to(model.device)
        attention_mask = inputs['attention_mask'].to(model.device)
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=do_sample,
                attention_mask=attention_mask
            )
        response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Keep only the assistant's turn from the decoded text.
    assistant_response = response.split("<extra_id_1>Assistant\n")[-1].strip()
    if tool_definition and "<toolcall>" in assistant_response:
        tool_call = assistant_response.split("<toolcall>")[1].split("</toolcall>")[0]
        assistant_response += f"\n\nTool Call: {tool_call}\n\nNote: This is a simulated tool call. In a real scenario, the tool would be executed and its output would be used to generate a final response."
    return assistant_response
def user(user_message, history):
    # Append the new user turn to the chat history and clear the input box.
    return "", history + [[user_message, None]]

def bot(history, system_prompt, max_length, temperature, top_p, advanced_checkbox, use_pipeline, tool_definition):
    # Generate a reply for the latest user turn; sampling is only enabled when advanced settings are on.
    user_message = history[-1][0]
    do_sample = advanced_checkbox
    bot_message = generate_response(user_message, history, system_prompt, max_length, temperature, top_p, do_sample, use_pipeline, tool_definition)
    history[-1][1] = bot_message
    return history
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown(title)
    with gr.Row():
        gr.Markdown(description)
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown(presentation1)
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown(joinus)
    with gr.Row():
        with gr.Column(scale=2):
            system_prompt = gr.TextArea(label="📑Context", placeholder="add context here...", lines=5)
            user_input = gr.TextArea(label="🤷🏻‍♂️User Input", placeholder="Hi there my name is Tonic!", lines=2)
            advanced_checkbox = gr.Checkbox(label="🧪 Advanced Settings", value=False)
            with gr.Column(visible=False) as advanced_settings:
                max_length = gr.Slider(label="📏Max Length", minimum=12, maximum=1700, value=650, step=1)
                temperature = gr.Slider(label="🌡️Temperature", minimum=0.01, maximum=1.0, value=0.7, step=0.01)
                top_p = gr.Slider(label="⚛️Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
                use_pipeline = gr.Checkbox(label="Use Pipeline", value=False)
                use_tool = gr.Checkbox(label="Use Function Calling", value=False)
                with gr.Column(visible=False) as tool_options:
                    tool_definition = gr.Code(
                        label="Tool Definition (JSON)",
                        value=customtool,
                        lines=15,
                        language="json"
                    )
            generate_button = gr.Button(value="🤖Mistral-NeMo-Minitron")
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="🤖Mistral-NeMo-Minitron")

    generate_button.click(
        user,
        [user_input, chatbot],
        [user_input, chatbot],
        queue=False
    ).then(
        bot,
        [chatbot, system_prompt, max_length, temperature, top_p, advanced_checkbox, use_pipeline, tool_definition],
        chatbot
    )
    advanced_checkbox.change(
        fn=lambda x: gr.update(visible=x),
        inputs=[advanced_checkbox],
        outputs=[advanced_settings]
    )
    use_tool.change(
        fn=lambda x: gr.update(visible=x),
        inputs=[use_tool],
        outputs=[tool_options]
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch()