import os

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
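
# Token budgets: hard cap and default for generated tokens, plus a prompt-length
# limit that can be overridden via the MAX_INPUT_TOKEN_LENGTH environment variable.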
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 2048
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
DESCRIPTION = """\
# CantoneseLLM Chat
Please join our [Discord server](https://discord.gg/gG6GPp8XxQ) and give me your feedback
"""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

# Load the model and tokenizer only when a GPU is available; on CPU the demo
# only shows the warning appended to DESCRIPTION above.
if torch.cuda.is_available():
    model_id = "hon9kon9ize/CantoneseLLMChat-v0.5"
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    model = torch.compile(model)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False
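
# On ZeroGPU Spaces, `@spaces.GPU` allocates a GPU for each call to the function.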
@spaces.GPU(queue=False)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> str:
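    """Build the chat prompt from the history and return a single, non-streamed response."""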
    # Build the conversation in chat-template format; fall back to the default
    # Cantonese system prompt ("You are CantoneseLLM, developed by hon9kon9ize,
    # and you are a very helpful assistant").
    default_system_prompt = "你係由 hon9kon9ize 開發嘅 CantoneseLLM,你係一個好幫得手嘅助理"
    conversation = [{"role": "system", "content": system_prompt or default_system_prompt}]
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors='pt')
    # Keep only the most recent tokens when the prompt exceeds the context budget.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )
    # Decode only the newly generated tokens, skipping the prompt portion.
    response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
    return response
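
# `gr.ChatInterface` calls `generate` with (message, history, *additional_inputs),
# so the widgets below must appear in the same order as the function's parameters.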
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["咩嘢係氣候變化?"],  # "What is climate change?" (Cantonese)
        ["香港最高嘅山係?"],  # "What is the highest mountain in Hong Kong?" (Cantonese)
        ["邊個係香港特首?"],  # "Who is the Chief Executive of Hong Kong?" (Cantonese)
        ["香港行政长官是谁?"],  # Same question in Simplified Chinese
    ],
)
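
# Compose the page: the Markdown header followed by the chat interface.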
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
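
# A minimal client-side sketch for querying this demo programmatically, assuming
# the `gradio_client` package is installed, the app is running locally, and
# gr.ChatInterface exposes its default "/chat" endpoint (the endpoint name and
# required arguments may differ across Gradio versions):
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   reply = client.predict("咩嘢係氣候變化?", api_name="/chat")
#   print(reply)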