Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer | |
import time | |
import numpy as np | |
from torch.nn import functional as F | |
import os | |
from threading import Thread | |
# Hub id shared by the model and tokenizer loads below.
MODEL_ID = "stabilityai/stablelm-2-1_6b-zephyr"

# Use half precision only on a GPU: many CPU kernels lack float16 support
# (or are very slow in it), so fall back to float32 when running on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

print("Starting to load the model to memory")
m = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=DTYPE, trust_remote_code=True
).to(DEVICE)
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# NOTE(review): not referenced by the handlers in this file (they call
# ``m.generate`` directly); kept so any external users of ``generator`` still work.
generator = pipeline("text-generation", model=m, tokenizer=tok, device=m.device)
print("Successfully loaded the model to the memory")

# NOTE(review): not referenced elsewhere in this file; kept for compatibility.
start_message = ""
def user(message, history):
    """Queue the submitted message as a new chat turn awaiting a reply.

    Args:
        message: Text the user just submitted.
        history: List of ``[user_text, assistant_text]`` pairs.

    Returns:
        A 2-tuple: an empty string (clears the input textbox) and a new
        history list with ``[message, ""]`` appended. The caller's list is
        not modified.
    """
    new_turn = [message, ""]
    return "", history + [new_turn]
def _history_to_messages(history):
    """Convert ``[user_text, assistant_text]`` pairs into chat-template messages.

    A turn whose assistant slot is empty or ``None`` (such as the turn that is
    currently awaiting a reply) contributes only its user message.
    """
    messages = []
    for turn in history:
        messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})
    return messages


def chat(history):
    """Stream a model reply into the last turn of ``history``.

    Args:
        history: List of ``[user_text, assistant_text]`` pairs; the last
            pair's assistant slot is overwritten with the streamed reply.

    Yields:
        ``history`` after each streamed chunk, so the UI can re-render.

    Returns:
        The full generated reply text (the generator's return value).
    """
    if not history:
        return

    prompt = tok.apply_chat_template(
        _history_to_messages(history),
        tokenize=False,
        # Append the assistant header so the model answers the last user turn.
        add_generation_prompt=True,
    )
    # Inputs must live on the same device as the model weights.
    model_inputs = tok([prompt], return_tensors="pt").to(m.device)

    streamer = TextIteratorStreamer(
        tok, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    # ``generate`` blocks until done, so run it in a worker thread and
    # consume its tokens here as the streamer produces them.
    worker = Thread(target=m.generate, kwargs=generate_kwargs)
    worker.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        history[-1][1] = partial_text
        yield history
    worker.join()
    return partial_text
with gr.Blocks() as demo:
    gr.Markdown("## Stable LM 1.6b Zephyr")
    gr.HTML('''<center><a href="https://huggingface.co/spaces/stabilityai/stablelm-2-1_6b-zephyr?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate the Space to skip the queue and run in a private space</center>''')
    # Component ``.style()`` was removed in Gradio 4; layout options are
    # passed as constructor keyword arguments instead.
    chatbot = gr.Chatbot(height=500)
    with gr.Row():
        with gr.Column():
            msg = gr.Textbox(label="Chat Message Box", placeholder="Chat Message Box",
                             show_label=False, container=False)
        with gr.Column():
            with gr.Row():
                submit = gr.Button("Submit")
                stop = gr.Button("Stop")
                clear = gr.Button("Clear")

    # Enter in the textbox and the Submit button run the same pipeline:
    # record the user turn immediately (unqueued), then stream the reply.
    submit_event = msg.submit(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=chat, inputs=[chatbot], outputs=[chatbot], queue=True)
    submit_click_event = submit.click(fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False).then(
        fn=chat, inputs=[chatbot], outputs=[chatbot], queue=True)
    stop.click(fn=None, inputs=None, outputs=None, cancels=[
        submit_event, submit_click_event], queue=False)
    clear.click(lambda: None, None, [chatbot], queue=False)

# Gradio 4 replaced ``concurrency_count`` with ``default_concurrency_limit``;
# fall back to the old keyword on Gradio 3.x, which rejects the new one.
try:
    demo.queue(max_size=32, default_concurrency_limit=2)
except TypeError:
    demo.queue(max_size=32, concurrency_count=2)
demo.launch()