# CataLlama-Chat / app.py
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
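
# Cap on the number of prompt tokens fed to the model; longer conversations are trimmed in generate().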
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
DESCRIPTION = """\
# CataLlama-v0.1 Chat
This Space demonstrates model [CataLlama-v0.1-Instruct-DPO](https://huggingface.co/catallama/CataLlama-v0.1-Instruct-DPO).
CataLlama is a fine-tune of Llama-3 8B that enhances its proficiency in the Catalan language.
"""
LICENSE = """
<p/>
---
As a derivative work of [Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) by Meta,
this demo is governed by the original [Llama 3 license](https://llama.meta.com/llama3/license).
"""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

if torch.cuda.is_available():
    model_id = "catallama/CataLlama-v0.1-Instruct-DPO"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
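
# @spaces.GPU asks Hugging Face Spaces (ZeroGPU) to allocate a GPU while generate() runs.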
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> Iterator[str]:
    # Rebuild the conversation in the message format expected by the chat template.
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True appends the assistant header so the model produces a reply
    # instead of continuing the user's turn.
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens so the prompt fits within MAX_INPUT_TOKEN_LENGTH.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
    )
    # Run generation in a background thread and stream partial output as it becomes available.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(
            value="Ets un chatbot amigable. Responeu preguntes i ajudeu els usuaris",
            label="System message",
            lines=6,
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=1024,
            step=256,
            label="Max new tokens",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.6,
            step=0.05,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.90,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    examples=[
        ["A quina velocitat poden volar els cocodrils?"],
        ["Explica pas a pas com resoldre l'equació següent: 2x + 10 = 0"],
        ["Pot Donald Trump sopar amb Juli Cèsar?"],
    ],
)
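
# Wrap the chat interface in a Blocks layout so the description and license text surround it.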
with gr.Blocks(css="style.css") as demo:
gr.Markdown(DESCRIPTION)
chat_interface.render()
gr.Markdown(LICENSE)
if __name__ == "__main__":
demo.queue(max_size=20).launch()