Spaces:
Running
Running
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
import gc
from threading import Thread
from typing import List, Tuple

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import TextIteratorStreamer

from project_settings import project_path
def greet(question: str, history: List[Tuple[str, str]]):
    """Demo chat handler: reply with a fixed greeting.

    Args:
        question: the user's message.
        history: previous ``(question, answer)`` turns; not modified.

    Returns:
        A new list: ``history`` followed by ``(question, "Hello <question>!")``.
    """
    turn = (question, "".join(("Hello ", question, "!")))
    return history + [turn]
# Cache of loaded checkpoints: {name_or_path: {"model": ..., "tokenizer": ...}}.
# init_model keeps at most one entry, so only one model is resident at a time.
model_map: dict = dict()


def init_model(pretrained_model_name_or_path: str):
    """Return ``(model, tokenizer)`` for a checkpoint, loading it on first use.

    Only one checkpoint is cached. Requesting a different checkpoint evicts the
    cached one before the new model is loaded.

    Args:
        pretrained_model_name_or_path: Hugging Face hub id or local path, passed
            to ``from_pretrained`` for both the model and the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is in eval mode.
    """
    cached = model_map.get(pretrained_model_name_or_path)
    if cached is not None:
        return cached["model"], cached["tokenizer"]

    if model_map:
        # Drop every reference held by the cache and reclaim memory before the
        # next model is loaded. (Deleting a loop variable only unbinds the local
        # name and frees nothing.)
        model_map.clear()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        offload_folder="./offload",
        offload_state_dict=True,
    )
    # device_map="auto" already dispatches the weights across the available
    # devices. An extra model.to(device) raises when accelerate has offloaded
    # any module to CPU or disk, so the model is left where it was placed.
    model = model.bfloat16().eval()

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True,
        # llama models do not support the fast tokenizer.
        use_fast=model.config.model_type != "llama",
        padding_side="left",
    )
    # QWenTokenizer is a special case: pad_token_id, bos_token_id and
    # eos_token_id are all None. eod_id is the id of the <|endoftext|> token,
    # so it is used for all three.
    if tokenizer.__class__.__name__ == "QWenTokenizer":
        tokenizer.pad_token_id = tokenizer.eod_id
        tokenizer.bos_token_id = tokenizer.eod_id
        tokenizer.eos_token_id = tokenizer.eod_id

    model_map[pretrained_model_name_or_path] = {
        "model": model,
        "tokenizer": tokenizer,
    }
    return model, tokenizer
def chat_with_llm_non_stream(question: str,
                             history: List[Tuple[str, str]],
                             pretrained_model_name_or_path: str,
                             max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
                             ):
    """Generate a complete answer to ``question`` in a single call (no streaming).

    The prompt is the tokenizer's BOS id followed by every previous question and
    answer, then the new question. Each segment is terminated by the EOS id.

    Args:
        question: the new user message.
        history: previous ``(question, answer)`` turns, oldest first.
        pretrained_model_name_or_path: checkpoint to load via ``init_model``.
        max_new_tokens, top_p, temperature, repetition_penalty: sampling
            parameters forwarded to ``model.generate`` (``do_sample=True``).

    Returns:
        A new list: ``history`` followed by the ``(question, answer)`` turn.
    """
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = init_model(pretrained_model_name_or_path)

    # Flatten the history into [q1, a1, q2, a2, ..., question].
    text_list = [text for pair in history for text in pair]
    text_list.append(question)
    batch_input_ids = tokenizer(text_list, add_special_tokens=False)["input_ids"]

    input_ids = [tokenizer.bos_token_id]
    for segment_ids in batch_input_ids:
        input_ids.extend(segment_ids)
        input_ids.append(tokenizer.eos_token_id)
    input_ids = torch.tensor([input_ids], dtype=torch.long).to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
            # Passed explicitly, as in the streaming path.
            pad_token_id=tokenizer.pad_token_id,
        )

    # Keep only the newly generated tokens (drop the echoed prompt).
    new_token_ids = outputs[0, input_ids.shape[1]:].tolist()
    answer = tokenizer.decode(new_token_ids).strip()
    # eos_token can be None for some tokenizers; str.replace(None, "") raises.
    if tokenizer.eos_token:
        answer = answer.replace(tokenizer.eos_token, "").strip()
    return history + [(question, answer)]
def chat_with_llm_streaming(question: str,
                            history: List[Tuple[str, str]],
                            pretrained_model_name_or_path: str,
                            max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
                            ):
    """Stream an answer to ``question``, yielding the updated chat after each chunk.

    The prompt is the tokenizer's BOS id followed by every previous question and
    answer, then the new question. Each segment is terminated by the EOS id.
    Generation runs in a background thread and feeds a ``TextIteratorStreamer``,
    which this generator consumes.

    Args:
        question: the new user message.
        history: previous ``(question, answer)`` turns, oldest first.
        pretrained_model_name_or_path: checkpoint to load via ``init_model``.
        max_new_tokens, top_p, temperature, repetition_penalty: sampling
            parameters forwarded to ``model.generate`` (``do_sample=True``).

    Yields:
        ``history`` followed by the in-progress ``(question, answer_so_far)`` turn.
    """
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = init_model(pretrained_model_name_or_path)

    # Flatten the history into [q1, a1, q2, a2, ..., question].
    text_list = [text for pair in history for text in pair]
    text_list.append(question)
    batch_input_ids = tokenizer(text_list, add_special_tokens=False)["input_ids"]

    input_ids = [tokenizer.bos_token_id]
    for segment_ids in batch_input_ids:
        input_ids.extend(segment_ids)
        input_ids.append(tokenizer.eos_token_id)
    input_ids = torch.tensor([input_ids], dtype=torch.long).to(device)

    # skip_prompt=True: otherwise the streamer also emits the decoded prompt,
    # i.e. every earlier turn, which would then be appended to the answer.
    # Removing the question text from each chunk does not remove the history,
    # and it also deletes legitimate occurrences of the question in the answer.
    streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True)
    generation_kwargs = dict(
        inputs=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        streamer=streamer,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    answer = ""
    for chunk in streamer:
        # eos_token can be None for some tokenizers; str.replace(None, "") raises.
        if tokenizer.eos_token:
            chunk = chunk.replace(tokenizer.eos_token, "")
        answer += chunk
        yield history + [(question, answer)]
    # The streamer is exhausted once generate() has finished; reap the worker.
    thread.join()
def main():
    """Build the Gradio chat UI and launch it with a request queue.

    Submitting the textbox or clicking Submit calls ``chat_with_llm_streaming``
    with the textbox text, the chat history, the selected model and the sampling
    sliders, and streams the result into the chatbot.
    """
    description = "\nchat llm\n"

    with gr.Blocks() as blocks:
        gr.Markdown(value=description)
        chatbot = gr.Chatbot([], elem_id="chatbot", height=400)

        with gr.Row():
            with gr.Column(scale=4):
                text_box = gr.Textbox(show_label=False, placeholder="Enter text and press enter", container=False)
            with gr.Column(scale=1):
                submit_button = gr.Button("💬Submit")
            with gr.Column(scale=1):
                clear_button = gr.Button(
                    '🗑️Clear',
                    variant='secondary',
                )

        # Lower bounds keep the sliders inside what transformers' generate()
        # accepts with do_sample=True: at least one new token, and strictly
        # positive temperature and repetition_penalty (0 is rejected).
        with gr.Row():
            with gr.Column(scale=1):
                max_new_tokens = gr.Slider(minimum=1, maximum=512, value=512, step=1, label="max_new_tokens")
            with gr.Column(scale=1):
                top_p = gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p")
            with gr.Column(scale=1):
                temperature = gr.Slider(minimum=0.01, maximum=1, value=0.35, step=0.01, label="temperature")
            with gr.Column(scale=1):
                repetition_penalty = gr.Slider(minimum=0.01, maximum=2, value=1.2, step=0.01, label="repetition_penalty")

        with gr.Row():
            model_name = gr.Dropdown(choices=["Qwen/Qwen-7B-Chat"],
                                     value="Qwen/Qwen-7B-Chat",
                                     label="model_name",
                                     )

        gr.Examples(examples=["你好"], inputs=text_box)

        inputs = [
            text_box, chatbot, model_name,
            max_new_tokens, top_p, temperature, repetition_penalty,
        ]
        outputs = [
            chatbot
        ]
        text_box.submit(chat_with_llm_streaming, inputs, outputs)
        submit_button.click(chat_with_llm_streaming, inputs, outputs)
        # Reset the textbox to empty text and the chatbot to an empty history.
        # The chatbot value is a list of turns, so it is cleared with [], not ''.
        clear_button.click(
            fn=lambda: ("", []),
            outputs=[text_box, chatbot],
            queue=False,
            api_name=False,
        )

    # The queue is enabled before launch so the generator handler can stream.
    blocks.queue().launch()


if __name__ == '__main__':
    main()