Kukedlc's picture
Update app.py
c2ea2f2 verified
raw
history blame
4.69 kB
import importlib.util
import json
import os
import subprocess
import sys
from threading import Event, Thread

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)
# Install FlashAttention at startup.
# - The child process inherits the current environment (PATH etc.) plus
#   FLASH_ATTENTION_SKIP_CUDA_BUILD. The original passed env={...} alone, which
#   wiped every other variable.
# - FLASH_ATTENTION_SKIP_CUDA_BUILD presumably makes flash-attn's setup skip
#   compiling CUDA kernels locally; confirm against flash-attn's setup.py.
# - Running pip through sys.executable installs into the interpreter that runs
#   this app, and the list form avoids going through a shell.
# - A failed install is deliberately not fatal here (check=False).
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
# Model selection and prompt settings.
MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct"
MODEL_NAME = MODEL_ID.rpartition("/")[2]  # repository name without the owner prefix
CHAT_TEMPLATE = "ChatML"  # "Auto", "ChatML" or "Mistral Instruct"; consumed by predict()
CONTEXT_LENGTH = 16_000  # prompt tokens kept; older tokens are trimmed from the left

# Interface settings (values set directly in code).
COLOR = "blue"  # primary hue of the Gradio theme
EMOJI = "🤖"  # shown before the model name in the page title
DESCRIPTION = f"This is the {MODEL_NAME} model designed for coding assistance and general AI tasks."
class _StopOnEvent(StoppingCriteria):
    """Stopping criterion that ends generation once ``event`` has been set."""

    def __init__(self, event):
        self.event = event

    def __call__(self, input_ids, scores, **kwargs):
        return self.event.is_set()


def _history_pairs(history):
    """Return the prior chat turns as a list of ``(user, assistant)`` pairs.

    Accepts tuple-style history (``[[user, assistant], ...]``) as-is. Also accepts
    "messages"-style history (``[{"role": ..., "content": ...}, ...]``), pairing each
    user message with the assistant message that follows it. A missing half of a
    pair becomes ``""``.
    """
    if not history:
        return []
    if not isinstance(history[0], dict):
        return [(user, assistant) for user, assistant in history]
    pairs = []
    pending_user = None
    for msg in history:
        role = msg.get("role")
        if role == "user":
            if pending_user is not None:
                pairs.append((pending_user, ""))
            pending_user = msg.get("content")
        elif role == "assistant":
            user_text = pending_user if pending_user is not None else ""
            pairs.append((user_text, msg.get("content")))
            pending_user = None
    if pending_user is not None:
        pairs.append((pending_user, ""))
    return pairs


def _build_prompt(message, history, system_prompt):
    """Render the conversation as one prompt string in the ``CHAT_TEMPLATE`` format.

    Returns:
        ``(instruction, stop_tokens)``: the prompt text and the template's stop
        markers (strings, except for the "Auto" template, which uses the
        tokenizer's integer EOS id).

    Raises:
        ValueError: If ``CHAT_TEMPLATE`` is not a supported template.
    """
    turns = _history_pairs(history)
    if CHAT_TEMPLATE == "Auto":
        stop_tokens = [tokenizer.eos_token_id]
        parts = [system_prompt + "\n\n"]
        parts += [f"User: {user}\nAssistant: {assistant}\n" for user, assistant in turns]
        parts.append(f"User: {message}\nAssistant:")
    elif CHAT_TEMPLATE == "ChatML":
        stop_tokens = ["<|endoftext|>", "<|im_end|>"]
        parts = ['<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n']
        parts += [
            f'<|im_start|>user\n{user}\n<|im_end|>\n<|im_start|>assistant\n{assistant}\n<|im_end|>\n'
            for user, assistant in turns
        ]
        parts.append(f'<|im_start|>user\n{message}\n<|im_end|>\n<|im_start|>assistant\n')
    elif CHAT_TEMPLATE == "Mistral Instruct":
        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
        parts = [f'<s>[INST] {system_prompt}\n']
        parts += [f'{user} [/INST] {assistant}</s>[INST]' for user, assistant in turns]
        parts.append(f' {message} [/INST]')
    else:
        raise ValueError("Incorrect chat template, select 'Auto', 'ChatML' or 'Mistral Instruct'")
    return "".join(parts), stop_tokens


@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    """Stream the model's reply to ``message`` as a Gradio ChatInterface callback.

    Args:
        message: The new user message.
        history: Prior turns, as ``(user, assistant)`` pairs or role/content dicts.
        system_prompt: System instruction placed at the start of the prompt.
        temperature: Sampling temperature; a value <= 0 selects greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.
        top_k: Top-k sampling cutoff (coerced to int).
        repetition_penalty: Repetition penalty; a value <= 0 is treated as 1.0
            (no penalty).
        top_p: Nucleus sampling threshold.

    Yields:
        The reply text accumulated so far.

    Raises:
        ValueError: If ``CHAT_TEMPLATE`` is unsupported.
        Exception: Whatever ``model.generate`` raised in the worker thread is
            re-raised here instead of leaving the stream blocked.
    """
    instruction, stop_tokens = _build_prompt(message, history, system_prompt)
    print(instruction)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Tokenize without truncation or padding:
    # - truncation=True clips on the tokenizer's truncation side (the right by
    #   default), which would drop the newest message; the left-trim below limits
    #   length instead.
    # - padding=True raises on tokenizers that have no pad token, and a single
    #   sequence needs no padding anyway.
    enc = tokenizer(instruction, return_tensors="pt")
    input_ids, attention_mask = enc.input_ids, enc.attention_mask
    # Keep only the most recent CONTEXT_LENGTH prompt tokens.
    if input_ids.shape[1] > CONTEXT_LENGTH:
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]

    # Set when this generator finishes, breaks or is closed by the client, so the
    # background generate() stops instead of running to max_new_tokens.
    stop_event = Event()
    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        # transformers requires a strictly positive penalty; 1.0 disables it.
        repetition_penalty=repetition_penalty if repetition_penalty > 0 else 1.0,
        stopping_criteria=StoppingCriteriaList([_StopOnEvent(stop_event)]),
    )
    if temperature > 0:
        # The Top K slider may deliver non-integer values; generate needs an int.
        generate_kwargs.update(do_sample=True, temperature=temperature, top_k=int(top_k), top_p=top_p)
    else:
        # Sampling rejects temperature == 0, so fall back to greedy decoding.
        generate_kwargs["do_sample"] = False

    errors = []

    def _generate():
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()  # unblock the consumer loop below

    Thread(target=_generate).start()

    # NOTE(review): chunks are compared as whole strings, so the "Auto" template's
    # integer EOS id never matches here. With skip_special_tokens=True, special
    # markers such as "<|im_end|>" are presumably stripped before reaching this
    # loop, so ending generation relies on generate's own EOS handling; confirm.
    outputs = []
    try:
        for new_token in streamer:
            if new_token in stop_tokens:
                break
            outputs.append(new_token)
            yield "".join(outputs)
    finally:
        stop_event.set()
    if errors:
        raise errors[0]
# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load weights in 4-bit via bitsandbytes, computing in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# The startup `pip install flash-attn` is not checked for success. Request
# FlashAttention-2 only when the package is importable; otherwise use PyTorch SDPA
# instead of failing at load time.
attn_implementation = (
    "flash_attention_2" if importlib.util.find_spec("flash_attn") is not None else "sdpa"
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=quantization_config,
    attn_implementation=attn_implementation,
)
# Create Gradio interface
gr.ChatInterface(
    predict,
    title=EMOJI + " " + MODEL_NAME,
    description=DESCRIPTION,
    examples=[
        ["¿Puedes resolver la ecuación 2x + 3 = 11 para x?"],
        ["Escribe un poema épico sobre la Antigua Roma."],
        ["¿Quién fue la primera persona en caminar sobre la Luna?"],
        ["Usa una comprensión de listas para crear una lista de cuadrados de los números del 1 al 10."],
        ["Recomienda algunos libros populares de ciencia ficción."],
        ["¿Puedes escribir una historia corta sobre un detective que viaja en el tiempo?"]
    ],
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    # Passed positionally after (message, history), so the order must match
    # predict's remaining parameters.
    additional_inputs=[
        gr.Textbox("Eres un modelo que responde de manera precisa en español.", label="System prompt"),
        gr.Slider(0, 1, 0.3, label="Temperature"),
        gr.Slider(128, 4096, 1024, label="Max new tokens"),
        # Explicit integer step: without it Gradio derives a fractional step from
        # the 1-80 range, and generate() requires an integer top_k.
        gr.Slider(1, 80, 40, step=1, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    theme=gr.themes.Soft(primary_hue=COLOR),
).queue().launch()