Tonic's picture
modify interface
ca7faf3 unverified
raw
history blame
5.7 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from globe import title, description, customtool, presentation1, presentation2, joinus
import spaces
model_path = "nvidia/Mistral-NeMo-Minitron-8B-Instruct"

# Load tokenizer and weights once, at import time; every request reuses these
# module-level objects. The tokenizer is loaded first, then the model.
# NOTE(review): no device/dtype is requested and the model is never moved to
# CUDA here, although generate_response runs under @spaces.GPU — confirm
# whether an explicit `.to("cuda")` is needed for GPU inference.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_path),
    AutoModelForCausalLM.from_pretrained(model_path),
)

# Some checkpoints ship without a pad token; reuse EOS so padding works.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Alternative generation path, used when the "Use Pipeline" option is on.
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
def create_prompt(system_message, user_message, tool_definition="", context=""):
    """Build a single-turn prompt in the ``<extra_id_*>`` chat format.

    Args:
        system_message: Text for the ``System`` section.
        user_message: The user's turn.
        tool_definition: Optional tool spec. When truthy, ``<tool>`` and
            ``<context>`` sections are inserted after the system message.
        context: Text for the ``<context>`` section; only emitted when
            ``tool_definition`` is truthy.

    Returns:
        The prompt string, ending with the ``<extra_id_1>Assistant`` cue and
        a newline so the model continues with its answer.
    """
    system_part = f"<extra_id_0>System\n{system_message}\n"
    dialogue_part = f"<extra_id_1>User\n{user_message}\n<extra_id_1>Assistant\n"

    # Plain chat: a blank line separates the system section from the user turn.
    if not tool_definition:
        return system_part + "\n" + dialogue_part

    # Function-calling variant: tool and context sections sit between the
    # system message and the user turn, with no blank separator lines.
    tool_part = (
        f"<tool>\n{tool_definition}\n</tool>\n"
        f"<context>\n{context}\n</context>\n"
    )
    return system_part + tool_part + dialogue_part
@spaces.GPU(duration=94)
def generate_response(message, history, system_message, max_tokens, temperature, top_p, do_sample, use_pipeline=False, tool_definition="", context=""):
    """Generate one assistant reply for ``message``.

    Args:
        message: The user's latest message.
        history: Chat history from the UI. Currently unused; each reply is
            generated from a fresh single-turn prompt.
        system_message: Text placed in the ``System`` section of the prompt.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        do_sample: Truthy to sample, falsy for greedy decoding.
        use_pipeline: Generate through the module-level ``pipe`` instead of
            calling ``model.generate`` directly.
        tool_definition: Optional tool spec; when non-empty the tool prompt
            format from ``create_prompt`` is used.
        context: Optional context; only included when ``tool_definition`` is
            non-empty.

    Returns:
        The assistant's reply text. If a tool definition was supplied and the
        reply contains ``<toolcall>...</toolcall>``, a note describing the
        simulated call is appended.
    """
    full_prompt = create_prompt(system_message, message, tool_definition, context)
    # UI sliders may deliver floats; generation and slicing need an int.
    max_new_tokens = int(max_tokens)
    # Normalize to a plain bool before handing it to transformers.
    sample = bool(do_sample)

    if use_pipeline:
        # return_full_text=False yields only the completion, so the prompt
        # cannot leak into the reply.
        response = pipe(
            full_prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=sample,
            return_full_text=False,
        )[0]["generated_text"]
    else:
        max_model_length = getattr(model.config, "max_position_embeddings", 8192)
        # Reserve room for the new tokens, but always keep at least one prompt token.
        max_prompt_tokens = max(max_model_length - max_new_tokens, 1)
        inputs = tokenizer(full_prompt, return_tensors="pt")
        # Truncate from the LEFT: the prompt ends with the
        # "<extra_id_1>Assistant" cue, which must survive truncation.
        input_ids = inputs["input_ids"][:, -max_prompt_tokens:].to(model.device)
        attention_mask = inputs["attention_mask"][:, -max_prompt_tokens:].to(model.device)
        with torch.no_grad():
            output_ids = model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=sample,
                attention_mask=attention_mask,
            )
        # generate() returns prompt + continuation; decode only the
        # continuation. Splitting the full decode on the prompt marker is
        # unreliable if the tokenizer treats "<extra_id_1>" as a special
        # token, because skip_special_tokens would remove it.
        new_tokens = output_ids[0, input_ids.shape[-1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Kept as a safety net in case the marker still appears in the text.
    assistant_response = response.split("<extra_id_1>Assistant\n")[-1].strip()
    if tool_definition and "<toolcall>" in assistant_response:
        tool_call = assistant_response.split("<toolcall>", 1)[1].split("</toolcall>", 1)[0]
        assistant_response += f"\n\nTool Call: {tool_call}\n\nNote: This is a simulated tool call. In a real scenario, the tool would be executed and its output would be used to generate a final response."
    return assistant_response
def update_advanced_settings(show_advanced):
    """Show or hide the advanced-settings column.

    Args:
        show_advanced: Current value of the "Advanced Settings" checkbox.

    Returns:
        A Gradio update payload that sets ``visible`` to ``show_advanced``.
    """
    # Gradio only applies a returned dict as a property update when it carries
    # the marker added by gr.update(); a bare {"visible": ...} dict is not
    # applied. This matches the helper the "Use Function Calling" toggle uses.
    return gr.update(visible=show_advanced)
with gr.Blocks() as demo:
    # Header: title, description and the two info panels.
    with gr.Row():
        gr.Markdown(title)
    with gr.Row():
        gr.Markdown(description)
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown(presentation1)
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown(joinus)

    with gr.Row():
        # Left column: inputs and (collapsible) generation settings.
        with gr.Column(scale=2):
            system_prompt = gr.TextArea(label="📑Context", placeholder="add context here...", lines=5)
            user_input = gr.TextArea(label="🤷🏻‍♂️User Input", placeholder="Hi there my name is Tonic!", lines=2)
            advanced_checkbox = gr.Checkbox(label="🧪 Advanced Settings", value=False)
            with gr.Column(visible=False) as advanced_settings:
                max_length = gr.Slider(label="📏Max Length", minimum=12, maximum=1700, value=650, step=1)
                temperature = gr.Slider(label="🌡️Temperature", minimum=0.01, maximum=1.0, value=0.7, step=0.01)
                top_p = gr.Slider(label="⚛️Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
                use_pipeline = gr.Checkbox(label="Use Pipeline", value=False)
                use_tool = gr.Checkbox(label="Use Function Calling", value=False)
                with gr.Column(visible=False) as tool_options:
                    tool_definition = gr.Code(
                        label="Tool Definition (JSON)",
                        value=customtool,
                        lines=15,
                        language="json"
                    )
            generate_button = gr.Button(value="🤖Mistral-NeMo-Minitron")
        # Right column: the conversation.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="🤖Mistral-NeMo-Minitron")

    def user(user_message, history):
        """Append the user's message as a new, unanswered turn and clear the input."""
        return "", history + [[user_message, None]]

    def bot(history, context_text, max_new, temp, nucleus, pipeline_on, tools_on, tool_json):
        """Fill in the reply for the last (unanswered) turn in ``history``."""
        last_user_message = history[-1][0]
        reply = generate_response(
            last_user_message,
            history,
            context_text,
            max_new,
            temp,
            nucleus,
            # Sampling is always on (the former "Do Sample" checkbox
            # defaulted to True); pass a real bool, not a UI component.
            True,
            pipeline_on,
            # Only use the tool prompt when "Use Function Calling" is ticked;
            # the editor is pre-filled, so its value is never empty on its own.
            tool_json if tools_on else "",
        )
        history[-1][1] = reply
        return history

    # Two-step chain: record the user turn instantly (unqueued), then run the
    # model and update the chat.
    generate_button.click(
        user,
        [user_input, chatbot],
        [user_input, chatbot],
        queue=False
    ).then(
        bot,
        [chatbot, system_prompt, max_length, temperature, top_p, use_pipeline, use_tool, tool_definition],
        chatbot
    )
    advanced_checkbox.change(
        fn=update_advanced_settings,
        inputs=[advanced_checkbox],
        outputs=advanced_settings
    )
    use_tool.change(
        fn=lambda x: gr.update(visible=x),
        inputs=[use_tool],
        outputs=[tool_options]
    )
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server. queue() returns
    # the same Blocks object, so the calls can be chained.
    demo.queue().launch()