# Source captured from a Hugging Face Space page; the Space's status was shown as "Runtime error".
import gradio as gr
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
# Choose the device before loading so the weights can be created in a dtype
# that suits it: fp16 on CUDA (fp32 weights for a 13B-parameter model need
# roughly 13e9 * 4 bytes ~= 52 GB), fp32 on CPU, where many fp16 kernels are
# unsupported.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

# LoRA adapter configuration. It appears unused in this script (the base
# model id below is hard-coded); kept as a module-level name for callers.
config = PeftConfig.from_pretrained("Junity/Genshin-World-Model")

# Baichuan models ship their own modeling/tokenizer code on the Hub, so
# trust_remote_code=True is required to load them.
model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan-13B-Base",
    torch_dtype=torch_dtype,
    trust_remote_code=True,
)
# Wrap the base model with the LoRA adapter weights.
model = PeftModel.from_pretrained(model, "Junity/Genshin-World-Model")
# Put the weights on the same device the prompt tensors are sent to, and
# make sure dropout is disabled for inference.
model = model.to(device)
model.eval()
# The tokenizer is loaded from the adapter repo; it may also be custom code.
tokenizer = AutoTokenizer.from_pretrained(
    "Junity/Genshin-World-Model", trust_remote_code=True
)

# Story lines used to build each prompt. Module-global, so it is shared by
# every browser session served by this process.
history = []
def respond(role_name, msg, chatbot, character=None, *, max_context_tokens=4096):
    """Record the user's line and stream the model's continuation of the story.

    Gradio Submit handler. It is a generator function, so Gradio forwards each
    yielded value to the UI as it is produced.

    Args:
        role_name: Character the user is playing. When empty (Gradio sends
            ``""`` for an empty Textbox) or None, ``msg`` is stored without a
            speaker prefix.
        msg: Text the user entered.
        chatbot: Current Chatbot value, a list of ``[user, bot]`` pairs, or
            None (e.g. after the chat display was cleared).
        character: Unused. Defaults to None so the handler can be wired to
            exactly three Gradio inputs while four-argument callers still work.
        max_context_tokens: Maximum number of prompt tokens sent to the model.
            The oldest history is dropped first.

    Yields:
        ``(msg_value, chatbot_value)`` tuples. The first value is ``""`` to
        clear the input box. The second is the transcript, with the partially
        generated continuation as the last bot message.

    Side effects:
        Appends the user's line, and after generation the model's
        continuation, to the module-level ``history`` list.
    """
    chatbot = list(chatbot or [])

    # Gradio Textboxes return "" (not None) when empty. Test truthiness so an
    # empty role does not produce lines like ":hello".
    user_line = (role_name + ":" + msg) if role_name else msg
    history.append(user_line)

    # Build the prompt newest-first, prepending one line at a time, and stop
    # once the token budget is exceeded. Then keep only the latest tokens.
    prompt_ids = []
    for line in reversed(history):
        prompt_ids = tokenizer.encode(line + "\n") + prompt_ids
        if len(prompt_ids) > max_context_tokens:
            break
    prompt_ids = prompt_ids[-max_context_tokens:]

    # Send the inputs to whichever device the model weights live on, so the
    # two can never disagree.
    model_device = next(model.parameters()).device
    input_ids = torch.LongTensor([prompt_ids]).to(model_device)

    # transformers_stream_generator only streams once generate/sample_stream
    # on the underlying model class are replaced with NewGenerationMixin's
    # versions. Without this, generate() returns a finished tensor and
    # token.item() below fails. Reassigning on every call is idempotent.
    base_model_cls = type(model.get_base_model())
    base_model_cls.generate = NewGenerationMixin.generate
    base_model_cls.sample_stream = NewGenerationMixin.sample_stream

    # NOTE(review): sampling settings, including the length limit, come
    # entirely from the model's generation_config. Confirm they are suitable.
    stream_config = StreamGenerationConfig(
        **model.generation_config.to_dict(), do_stream=True
    )

    chatbot.append([user_line, ""])
    # Show the user's line immediately, before the first token arrives.
    yield "", chatbot

    generated_ids = []
    reply = ""
    # Pass input_ids by keyword: some peft versions' generate() accept
    # keyword arguments only.
    for token in model.generate(input_ids=input_ids, generation_config=stream_config):
        generated_ids.append(token.item())
        reply = tokenizer.decode(generated_ids, skip_special_tokens=True)
        chatbot[-1][1] = reply
        yield "", chatbot

    # Keep the continuation in the story so the next prompt includes it.
    # Trailing newlines are dropped because a "\n" is re-added when encoding.
    continuation = reply.rstrip("\n")
    if continuation:
        history.append(continuation)
def _clear_conversation():
    """Forget the shared story history and empty the Chatbot display.

    Returns:
        None, which Gradio renders as an empty Chatbot.
    """
    # Prompts are rebuilt from the module-level ``history`` list on every
    # submit. Clearing only the Chatbot would leave the model still seeing the
    # old story, so empty the list in place.
    history.clear()
    return None


# Build the web UI. The UI strings are Chinese; English glosses are given in
# comments and the runtime strings themselves are left unchanged.
with gr.Blocks() as demo:
    # Header: a link to the model page ("模型地址" = "model page") and a note
    # that the model cannot be asked to answer questions; it only continues
    # text ("只支持续写" = "only supports continuation").
    gr.Markdown(
        """
## Genshin-World-Model
- 模型地址 [https://huggingface.co/Junity/Genshin-World-Model](https://huggingface.co/Junity/Genshin-World-Model)
- 此模型不支持要求对方回答什么,只支持续写。
"""
    )
    # "聊天" = "Chat".
    with gr.Tab("聊天") as chat:
        # "你将扮演的角色" = "The character you will play".
        role_name = gr.Textbox(label="你将扮演的角色")
        # "输入" = "Input".
        msg = gr.Textbox(label="输入")
        with gr.Row():
            clear = gr.Button("Clear")
            sub = gr.Button("Submit")
        chatbot = gr.Chatbot()
    sub.click(fn=respond, inputs=[role_name, msg, chatbot], outputs=[msg, chatbot])
    # Reset both the displayed transcript and the prompt history.
    clear.click(_clear_conversation, None, chatbot, queue=False)

# The queue is what lets Gradio stream generator outputs (required in
# Gradio 3.x).
demo.queue().launch()