Spaces:
Runtime error
Runtime error
File size: 4,818 Bytes
a7d37a3 a6636f6 c2769e9 a6636f6 a7d37a3 c2769e9 b00538d a6636f6 b00538d a6636f6 c2769e9 b00538d a6636f6 b00538d 05b7bb9 a6636f6 c2769e9 a6636f6 05b7bb9 d04c3ba a6636f6 c2769e9 a6636f6 a7d37a3 a6636f6 a7d37a3 b00538d a7d37a3 b00538d a7d37a3 b00538d c2769e9 b00538d b64802f 10a57ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
from threading import Thread
import gradio as gr
import torch
# lora_folder = ''
# model_folder = ''
#
# config = PeftConfig.from_pretrained(("Junity/Genshin-World-Model" if lora_folder == ''
# else lora_folder),
# trust_remote_code=True)
# model = AutoModelForCausalLM.from_pretrained(("baichuan-inc/Baichuan-13B-Base" if model_folder == ''
# else model_folder), torch_dtype=torch.float32, trust_remote_code=True)
# model = PeftModel.from_pretrained(model,
# ("Junity/Genshin-World-Model" if lora_folder == ''
# else lora_folder)
# , torch_dtype=torch.float32, trust_remote_code=True)
# tokenizer = AutoTokenizer.from_pretrained(("baichuan-inc/Baichuan-13B-Base" if model_folder == ''
# else model_folder),
# trust_remote_code=True)
# history = []
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# if device == "cuda":
# model.cuda()
# model = model.half()
def respond(role_name, character_name, msg, textbox, temp, rep, max_len, top_p, top_k):
    """Append the user's message to the transcript and stream the model's reply.

    Generator used as a Gradio event handler: every ``yield`` is
    ``[msg_value, textbox_value]`` — the message box is cleared and the
    transcript box is updated incrementally as tokens stream in.

    Parameters
    ----------
    role_name : str
        Role the user plays; prefixed to the message as ``name:``. May be empty.
    character_name : str
        The responding character; appended as ``name:`` so the model
        continues in that voice. May be empty.
    msg : str
        The user's new message. May be empty (nothing is appended then).
    textbox : str
        The transcript so far (user-editable).
    temp, rep, max_len, top_p, top_k
        Sampling parameters forwarded to ``model.generate``.

    NOTE(review): depends on module-level ``tokenizer``, ``model`` and
    ``device``, whose setup is commented out at the top of this file — as
    committed, generation raises NameError. Confirm the setup block is
    restored before deploying.
    """
    # Fix: the original indexed msg[-1] (and textbox[-1] below)
    # unconditionally, raising IndexError on empty input. Every such index
    # is now guarded by a truthiness check; behavior for non-empty input is
    # unchanged.
    if textbox != '':
        # Continue an existing transcript on a new line; append a sentence
        # terminator only when the message doesn't already end with one.
        textbox = (textbox
                   + "\n"
                   + role_name
                   + (":" if role_name != '' else '')
                   + msg
                   + ('。\n' if msg and msg[-1] not in ['。', '!', '?'] else ''))
        yield ["", textbox]
    else:
        # First transcript line: same formatting, but a wider set of trailing
        # characters counts as "already punctuated".
        textbox = (textbox
                   + role_name
                   + (":" if role_name != '' else '')
                   + msg
                   + ('。' if msg and msg[-1] not in ['。', '!', '?', ')', '}', ':', ':', '('] else '')
                   + ('\n' if msg and msg[-1] in ['。', '!', '?', ')', '}'] else ''))
        yield ["", textbox]
    if character_name != '':
        # Prompt the model to answer in the other character's voice.
        textbox += ('\n' if textbox and textbox[-1] != '\n' else '') + character_name + ':'
    # Keep only the most recent 3200 tokens of context.
    input_ids = tokenizer.encode(textbox)[-3200:]
    input_ids = torch.LongTensor([input_ids]).to(device)
    # (Removed: unused `outputs` list, unused StreamGenerationConfig, and a
    # leftover debug print of input_ids.)
    gen_kwargs = dict(
        input_ids=input_ids,
        temperature=temp,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=rep,
        max_new_tokens=max_len,
        do_sample=True,
    )
    # Run generate() on a background thread and stream decoded text back so
    # the UI updates token-by-token.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs["streamer"] = streamer
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    for new_text in streamer:
        textbox += new_text
        yield ["", textbox]
# ---- Gradio UI ---------------------------------------------------------
# One "创作" tab: role/character/message inputs, sampling sliders, and a
# streaming transcript textbox wired to respond().
with gr.Blocks() as demo:
    gr.Markdown(
        """
        ## Genshin-World-Model
        - 模型地址 [https://huggingface.co/Junity/Genshin-World-Model](https://huggingface.co/Junity/Genshin-World-Model)
        - 此模型不支持要求对方回答什么,只支持续写。
        - 目前运行不了,因为没有钱租卡。
        """
    )
    with gr.Tab("创作") as chat:
        role_box = gr.Textbox(label="你将扮演的角色(可留空)")
        char_box = gr.Textbox(label="对方的角色(可留空)")
        msg_box = gr.Textbox(label="你说的话")
        with gr.Row():
            clear_btn = gr.ClearButton()
            submit_btn = gr.Button("Submit", variant="primary")
        with gr.Row():
            temp_slider = gr.Slider(minimum=0, maximum=2.0, step=0.1, value=1.5, label="温度(调大则更随机)", interactive=True)
            rep_slider = gr.Slider(minimum=0, maximum=2.0, step=0.1, value=1.0, label="对重复生成的惩罚", interactive=True)
            len_slider = gr.Slider(minimum=4, maximum=512, step=4, value=256, label="对方回答的最大长度", interactive=True)
            top_p_slider = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.7, label="Top-p(调大则更随机)", interactive=True)
            top_k_slider = gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Top-k(调大则更随机)", interactive=True)
        story_box = gr.Textbox(interactive=True, label="全部文本(可修改)")
        # Clear wipes the message, role, and transcript boxes.
        clear_btn.add([msg_box, role_box, story_box])
        # Submit streams respond() output into (cleared msg, updated transcript).
        submit_btn.click(
            fn=respond,
            inputs=[role_box, char_box, msg_box, story_box,
                    temp_slider, rep_slider, len_slider, top_p_slider, top_k_slider],
            outputs=[msg_box, story_box],
        )
    gr.Markdown(
        """
        #### 特别鸣谢 XXXX
        """
    )
demo.queue().launch()
|