CataLlama-Chat / app.py
laurentiubp's picture
Update app.py
1e2dab1 verified
raw
history blame
4.04 kB
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Maximum prompt length in tokens; longer prompts are trimmed to their most
# recent tokens. Overridable via the MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Markdown header rendered at the top of the page.
DESCRIPTION = """\
# CataLlama-v0.1 Chat
This Space demonstrates the model [CataLlama-v0.1-Instruct-DPO](https://huggingface.co/catallama/CataLlama-v0.1-Instruct-DPO).
CataLlama is a fine-tune of Llama-3 8B to enhance its proficiency in the Catalan language.
"""
# The model is only loaded when a CUDA device is visible; otherwise the page
# header is extended with a notice that the demo cannot run on CPU.
if torch.cuda.is_available():
    model_id = "catallama/CataLlama-v0.1-Instruct-DPO"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
else:
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
) -> Iterator[str]:
    """Stream a sampled chat reply from the CataLlama model.

    Args:
        message: The latest user message.
        chat_history: Earlier turns as ``(user, assistant)`` pairs. Falsy
            entries (``None`` or empty strings) are skipped.
        system_prompt: Optional system message; omitted when empty.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The reply text accumulated so far, growing as new text is streamed.
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        # Skip missing turns so the chat template never renders a literal None.
        if user:
            conversation.append({"role": "user", "content": user})
        if assistant:
            conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True appends the assistant-turn header so the model
    # starts a fresh reply instead of continuing the user's message.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens; the oldest context (which may
        # include the system prompt) is discarded.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
    }
    # model.generate runs in a background thread while this generator
    # consumes the streamer and yields partial text to the UI.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
"""
client = InferenceClient("catallama/CataLlama-v0.1-Instruct-DPO")
def respond(
message,
history: list[tuple[str, str]],
system_message,
max_tokens,
temperature,
top_p,
):
messages = [{"role": "system", "content": system_message}]
for val in history:
if val[0]:
messages.append({"role": "user", "content": val[0]})
if val[1]:
messages.append({"role": "assistant", "content": val[1]})
messages.append({"role": "user", "content": message})
response = ""
for message in client.chat_completion(
messages,
max_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
):
token = message.choices[0].delta.content
response += token
yield response
"""
def _build_chat_interface():
    """Create the chat UI wired to ``generate`` plus its tunable controls.

    The controls are passed to ``generate`` after the message and history, in
    the order listed: system prompt, max new tokens, temperature, top-p. Their
    slider values (e.g. top-p 0.95) are what ``generate`` actually receives;
    its own keyword defaults only matter for direct calls.
    """
    system_box = gr.Textbox(
        value="Ets un chatbot amigable. Responeu preguntes i ajudeu els usuaris",
        label="System message",
    )
    max_tokens_slider = gr.Slider(
        minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens"
    )
    temperature_slider = gr.Slider(
        minimum=0.1, maximum=4.0, value=0.6, step=0.1, label="Temperature"
    )
    top_p_slider = gr.Slider(
        minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
    )

    # Sample prompts offered in the UI, one message per example.
    sample_prompts = (
        "A quina velocitat poden volar els cocodrils?",
        "Explica pas a pas com resoldre l'equació següent: 2x + 10 = 0",
        "Pot Donald Trump sopar amb Juli Cèsar?",
    )

    return gr.ChatInterface(
        fn=generate,
        additional_inputs=[system_box, max_tokens_slider, temperature_slider, top_p_slider],
        examples=[[prompt] for prompt in sample_prompts],
    )


chat_interface = _build_chat_interface()
# Page layout: the Markdown description followed by the chat interface.
demo = gr.Blocks(css="style.css")
with demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    # Enable Gradio's request queue (max_size=20), then start the server.
    demo.queue(max_size=20)
    demo.launch()