WinterGYC Junity commited on
Commit
1819a5c
0 Parent(s):

Duplicate from Junity/Genshin-World-Model

Browse files

Co-authored-by: Linkang Zhan <[email protected]>

.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ test.py
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/.name ADDED
@@ -0,0 +1 @@
 
 
1
+ app.py
.idea/Genshin-World-Model.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
5
+ <option name="ignoredPackages">
6
+ <value>
7
+ <list size="11">
8
+ <item index="0" class="java.lang.String" itemvalue="tiktoken" />
9
+ <item index="1" class="java.lang.String" itemvalue="scipy" />
10
+ <item index="2" class="java.lang.String" itemvalue="matplotlib" />
11
+ <item index="3" class="java.lang.String" itemvalue="whisper" />
12
+ <item index="4" class="java.lang.String" itemvalue="torch" />
13
+ <item index="5" class="java.lang.String" itemvalue="numpy" />
14
+ <item index="6" class="java.lang.String" itemvalue="requests" />
15
+ <item index="7" class="java.lang.String" itemvalue="torchvision" />
16
+ <item index="8" class="java.lang.String" itemvalue="torchaudio" />
17
+ <item index="9" class="java.lang.String" itemvalue="Pillow" />
18
+ <item index="10" class="java.lang.String" itemvalue="Requests" />
19
+ </list>
20
+ </value>
21
+ </option>
22
+ </inspection_tool>
23
+ </profile>
24
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectRootManager" version="2" project-jdk-name="pytorch" project-jdk-type="Python SDK" />
4
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/Genshin-World-Model.iml" filepath="$PROJECT_DIR$/.idea/Genshin-World-Model.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Genshin World Model
3
+ emoji: 📈
4
+ colorFrom: indigo
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 3.40.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ duplicated_from: Junity/Genshin-World-Model
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from peft import PeftModel, PeftConfig
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
3
+ from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
4
+ from threading import Thread
5
+ import gradio as gr
6
+ import torch
7
+
8
+
9
+ # lora_folder = ''
10
+ # model_folder = ''
11
+ #
12
+ # config = PeftConfig.from_pretrained(("Junity/Genshin-World-Model" if lora_folder == ''
13
+ # else lora_folder),
14
+ # trust_remote_code=True)
15
+ # model = AutoModelForCausalLM.from_pretrained(("baichuan-inc/Baichuan-13B-Base" if model_folder == ''
16
+ # else model_folder),
17
+ # torch_dtype=torch.float16,
18
+ # device_map="auto",
19
+ # trust_remote_code=True)
20
+ # model = PeftModel.from_pretrained(model,
21
+ # ("Junity/Genshin-World-Model" if lora_folder == ''
22
+ # else lora_folder),
23
+ # device_map="auto",
24
+ # torch_dtype=torch.float16,
25
+ # trust_remote_code=True)
26
+ # tokenizer = AutoTokenizer.from_pretrained(("baichuan-inc/Baichuan-13B-Base" if model_folder == ''
27
+ # else model_folder),
28
+ # trust_remote_code=True)
29
+ history = []
30
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+
32
+
33
+ def respond(role_name, character_name, msg, textbox, temp, rep, max_len, top_p, top_k):
34
+ if textbox != '':
35
+ textbox = (textbox
36
+ + "\n"
37
+ + role_name
38
+ + (":" if role_name != '' else '')
39
+ + msg
40
+ + ('。\n' if msg[-1] not in ['。', '!', '?'] else ''))
41
+ yield ["", textbox]
42
+ else:
43
+ textbox = (textbox
44
+ + role_name
45
+ + (":" if role_name != '' else '')
46
+ + msg
47
+ + ('。' if msg[-1] not in ['。', '!', '?', ')', '}', ':', ':', '('] else '')
48
+ + ('\n' if msg[-1] in ['。', '!', '?', ')', '}'] else ''))
49
+ yield ["", textbox]
50
+ if character_name != '':
51
+ textbox += ('\n' if textbox[-1] != '\n' else '') + character_name + ':'
52
+ input_ids = tokenizer.encode(textbox)[-3200:]
53
+ input_ids = torch.LongTensor([input_ids]).to(device)
54
+ generation_config = model.generation_config
55
+ stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
56
+ gen_kwargs = {}
57
+ gen_kwargs.update(dict(
58
+ input_ids=input_ids,
59
+ temperature=temp,
60
+ top_p=top_p,
61
+ top_k=top_k,
62
+ repetition_penalty=rep,
63
+ max_new_tokens=max_len,
64
+ do_sample=True,
65
+ ))
66
+ outputs = []
67
+ print(input_ids)
68
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
69
+ gen_kwargs["streamer"] = streamer
70
+
71
+ thread = Thread(target=model.generate, kwargs=gen_kwargs)
72
+ thread.start()
73
+
74
+ for new_text in streamer:
75
+ textbox += new_text
76
+ yield ["", textbox]
77
+
78
+
79
+ with gr.Blocks() as demo:
80
+ gr.Markdown(
81
+ """
82
+ ## Genshin-World-Model
83
+ - 模型地址 [https://huggingface.co/Junity/Genshin-World-Model](https://huggingface.co/Junity/Genshin-World-Model)
84
+ - 此模型不支持要求对方回答什么,只支持续写。
85
+ - 目前运行不了,因为没有钱租卡。
86
+ """
87
+ )
88
+ with gr.Tab("创作") as chat:
89
+ role_name = gr.Textbox(label="你将扮演的角色(可留空)")
90
+ character_name = gr.Textbox(label="对方的角色(可留空)")
91
+ msg = gr.Textbox(label="你说的话")
92
+ with gr.Row():
93
+ clear = gr.ClearButton()
94
+ sub = gr.Button("Submit", variant="primary")
95
+ with gr.Row():
96
+ temp = gr.Slider(minimum=0, maximum=2.0, step=0.1, value=1.5, label="温度(调大则更随机)", interactive=True)
97
+ rep = gr.Slider(minimum=0, maximum=2.0, step=0.1, value=1.0, label="对重复生成的惩罚", interactive=True)
98
+ max_len = gr.Slider(minimum=4, maximum=512, step=4, value=256, label="对方回答的最大长度", interactive=True)
99
+ top_p = gr.Slider(minimum=0, maximum=1.0, step=0.1, value=0.7, label="Top-p(调大则更随机)", interactive=True)
100
+ top_k = gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Top-k(调大则更随机)", interactive=True)
101
+ textbox = gr.Textbox(interactive=True, label="全部文本(可修改)")
102
+ clear.add([msg, role_name, textbox])
103
+ sub.click(fn=respond,
104
+ inputs=[role_name, character_name, msg, textbox, temp, rep, max_len, top_p, top_k],
105
+ outputs=[msg, textbox])
106
+ gr.Markdown(
107
+ """
108
+ #### 特别鸣谢 XXXX
109
+ """
110
+ )
111
+ demo.queue().launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio==3.40.1
2
+ peft==0.4.0
3
+ transformers_stream_generator
4
+ sentencepiece
5
+ accelerate
6
+ colorama
7
+ cpm_kernels