# Source: Hugging Face Hub file view (raw / history / blame), 2.25 kB.
# Last commit: c2769e9 "support chat" by LinkangZhan.
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
import gradio as gr
import torch
# Adapter configuration for the LoRA weights (loaded for reference; not read below).
config = PeftConfig.from_pretrained("Junity/Genshin-World-Model")

# Choose the device before loading so the model and the prompt tensors built in
# respond() end up on the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Baichuan ships its model and tokenizer classes as custom code on the Hub, so
# both loaders need trust_remote_code=True.
model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan-13B-Base", trust_remote_code=True
)
# Install the streaming generate() from transformers_stream_generator on the base
# model class; without it generate() returns one finished tensor instead of
# yielding tokens one by one.
model.__class__.generate = NewGenerationMixin.generate
model.__class__.sample_stream = NewGenerationMixin.sample_stream
model = PeftModel.from_pretrained(model, "Junity/Genshin-World-Model")
model = model.to(device)
model.eval()  # inference only

tokenizer = AutoTokenizer.from_pretrained(
    "Junity/Genshin-World-Model", trust_remote_code=True
)

# Conversation lines shared by every request, oldest first.
history = []
# Maximum number of prompt tokens fed to the model per request.
_MAX_PROMPT_TOKENS = 4096


def _build_prompt_tokens(messages, encode, limit=_MAX_PROMPT_TOKENS):
    """Tokenize the newest messages (kept in order) into at most ``limit`` tokens."""
    tokens = []
    for message in reversed(messages):
        tokens = encode(message + '\n') + tokens
        if len(tokens) >= limit:
            break
    # The oldest message taken may overshoot the budget; keep only the tail.
    return tokens[-limit:]


def respond(role_name, msg, chatbot, character=None):
    """Record the user's line and stream the model's continuation to the UI.

    This is a generator function, so Gradio re-renders the outputs on every
    yielded value.

    Args:
        role_name: Name of the speaker the user plays. When empty or None the
            message is stored without a ``name:`` prefix.
        msg: Text typed by the user.
        chatbot: Current value of the Chatbot component, or None.
            Assumes Gradio's pair-style history (list of ``[user, bot]``) --
            TODO confirm against the installed Gradio version.
        character: Unused. Optional so the three-input click handler can call
            this function.

    Yields:
        ``(None, pairs)`` tuples: ``None`` for the message box output and the
        chat transcript extended with the reply generated so far.
    """
    msg = msg or ""
    # An empty textbox arrives as "", not None, so test truthiness to avoid a
    # stray ":" prefix.
    user_line = role_name + ":" + msg if role_name else msg
    history.append(user_line)

    prompt_tokens = _build_prompt_tokens(history, tokenizer.encode)
    input_ids = torch.LongTensor([prompt_tokens]).to(device)

    # NOTE(review): the streaming path in transformers_stream_generator may only
    # be taken when sampling is enabled in the generation config -- confirm.
    stream_config = StreamGenerationConfig(
        **model.generation_config.to_dict(), do_stream=True
    )

    transcript = list(chatbot or [])
    generated = []
    reply = ""
    # input_ids is passed by keyword because some PEFT wrappers only accept
    # keyword arguments in generate().
    for token in model.generate(input_ids=input_ids, generation_config=stream_config):
        generated.append(token.item())
        reply = tokenizer.decode(generated, skip_special_tokens=True)
        yield None, transcript + [[user_line, reply]]

    # Keep the model's own text in the history so the next turn continues from it.
    if reply.strip():
        history.append(reply.strip("\n"))
def _reset_conversation():
    """Forget the stored conversation and return None to empty the Chatbot."""
    # Clearing only the widget would leave old lines in the prompt of the next request.
    history.clear()
    return None


# Page layout: a short description, then a chat tab with role/message inputs,
# Clear/Submit buttons and the chat transcript.
with gr.Blocks() as demo:
    gr.Markdown(
        """
## Genshin-World-Model
- 模型地址 [https://huggingface.co/Junity/Genshin-World-Model](https://huggingface.co/Junity/Genshin-World-Model)
- 此模型不支持要求对方回答什么,只支持续写。
"""
    )
    with gr.Tab("聊天") as chat:
        role_name = gr.Textbox(label="你将扮演的角色")
        msg = gr.Textbox(label="输入")
        with gr.Row():
            clear = gr.Button("Clear")
            sub = gr.Button("Submit")
        chatbot = gr.Chatbot()
        # Submit sends the role, message and transcript to respond(); its outputs
        # are written to the message box and the chatbot.
        sub.click(fn=respond, inputs=[role_name, msg, chatbot], outputs=[msg, chatbot])
        # Clear empties both the visible transcript and the stored history.
        clear.click(_reset_conversation, None, chatbot, queue=False)

# The queue is enabled before launching (generator handlers stream through it).
demo.queue().launch()