Spaces:
Running
on
A10G
Running
on
A10G
update to 1.2
Browse files- app.py +420 -129
- fish_speech/configs/base.yaml +1 -0
- fish_speech/configs/firefly_gan_vq.yaml +34 -0
- fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
- fish_speech/configs/text2semantic_finetune.yaml +22 -18
- fish_speech/datasets/concat_repeat.py +53 -0
- fish_speech/datasets/semantic.py +496 -0
- fish_speech/datasets/vqgan.py +3 -1
- fish_speech/models/text2semantic/__init__.py +0 -3
- fish_speech/models/text2semantic/lit_module.py +22 -164
- fish_speech/models/text2semantic/llama.py +227 -70
- fish_speech/models/text2semantic/lora.py +92 -0
- fish_speech/models/vqgan/modules/firefly.py +88 -1
- fish_speech/models/vqgan/modules/fsq.py +1 -1
- fish_speech/text/__init__.py +2 -1
- fish_speech/text/chn_text_norm/.gitignore +114 -0
- fish_speech/text/chn_text_norm/README.md +36 -0
- fish_speech/text/chn_text_norm/__init__.py +0 -0
- fish_speech/text/chn_text_norm/basic_class.py +172 -0
- fish_speech/text/chn_text_norm/basic_constant.py +30 -0
- fish_speech/text/chn_text_norm/basic_util.py +342 -0
- fish_speech/text/chn_text_norm/cardinal.py +32 -0
- fish_speech/text/chn_text_norm/date.py +75 -0
- fish_speech/text/chn_text_norm/digit.py +32 -0
- fish_speech/text/chn_text_norm/fraction.py +35 -0
- fish_speech/text/chn_text_norm/money.py +43 -0
- fish_speech/text/chn_text_norm/percentage.py +33 -0
- fish_speech/text/chn_text_norm/telephone.py +51 -0
- fish_speech/text/chn_text_norm/text.py +177 -0
- fish_speech/text/clean.py +1 -5
- fish_speech/text/spliter.py +130 -0
- fish_speech/utils/file.py +1 -1
- fish_speech/utils/rich_utils.py +7 -3
- fish_speech/utils/spectrogram.py +122 -0
- tools/api.py +482 -0
- tools/auto_rerank.py +159 -0
- tools/llama/build_dataset.py +169 -0
- tools/llama/eval_in_context.py +171 -0
- tools/llama/generate.py +119 -180
- tools/llama/merge_lora.py +95 -0
- tools/llama/quantize.py +46 -64
- tools/llama/rebuild_tokenizer.py +57 -0
- tools/vqgan/create_train_split.py +83 -0
- tools/vqgan/extract_vq.py +227 -0
- tools/vqgan/inference.py +29 -26
app.py
CHANGED
@@ -5,7 +5,7 @@ import hydra
|
|
5 |
|
6 |
# Download if not exists
|
7 |
os.makedirs("checkpoints", exist_ok=True)
|
8 |
-
snapshot_download(repo_id="fishaudio/fish-speech-1", local_dir="./checkpoints/fish-speech-1")
|
9 |
|
10 |
print("All checkpoints downloaded")
|
11 |
|
@@ -23,6 +23,16 @@ from transformers import AutoTokenizer
|
|
23 |
|
24 |
from tools.llama.generate import launch_thread_safe_queue
|
25 |
from tools.vqgan.inference import load_model as load_vqgan_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
# Make einx happy
|
28 |
os.environ["EINX_FILTER_TRACEBACK"] = "false"
|
@@ -30,8 +40,8 @@ os.environ["EINX_FILTER_TRACEBACK"] = "false"
|
|
30 |
|
31 |
HEADER_MD = """# Fish Speech
|
32 |
|
33 |
-
## The demo in this space is version 1.
|
34 |
-
## 该 Demo 为 Fish Speech 1.
|
35 |
|
36 |
A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
|
37 |
由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.
|
@@ -39,14 +49,14 @@ A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https
|
|
39 |
You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).
|
40 |
你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.
|
41 |
|
42 |
-
Related code
|
43 |
-
|
44 |
|
45 |
We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
|
46 |
我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
|
47 |
|
48 |
-
The model running in this WebUI is Fish Speech V1 Medium SFT
|
49 |
-
在此 WebUI 中运行的模型是 Fish Speech V1 Medium SFT
|
50 |
"""
|
51 |
|
52 |
TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
|
@@ -85,36 +95,27 @@ def inference(
|
|
85 |
top_p,
|
86 |
repetition_penalty,
|
87 |
temperature,
|
88 |
-
|
89 |
):
|
90 |
if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
|
91 |
-
return
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
reference_audio_content, _ = librosa.load(
|
98 |
-
reference_audio, sr=vqgan_model.sampling_rate, mono=True
|
99 |
-
)
|
100 |
-
audios = torch.from_numpy(reference_audio_content).to(vqgan_model.device)[
|
101 |
-
None, None, :
|
102 |
-
]
|
103 |
-
|
104 |
-
logger.info(
|
105 |
-
f"Loaded audio with {audios.shape[2] / vqgan_model.sampling_rate:.2f} seconds"
|
106 |
)
|
107 |
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
|
|
113 |
|
114 |
# LLAMA Inference
|
115 |
request = dict(
|
116 |
-
|
117 |
-
device=vqgan_model.device,
|
118 |
max_new_tokens=max_new_tokens,
|
119 |
text=text,
|
120 |
top_p=top_p,
|
@@ -123,43 +124,246 @@ def inference(
|
|
123 |
compile=args.compile,
|
124 |
iterative_prompt=chunk_length > 0,
|
125 |
chunk_length=chunk_length,
|
126 |
-
max_length=
|
127 |
-
speaker=speaker if speaker else None,
|
128 |
prompt_tokens=prompt_tokens if enable_reference_audio else None,
|
129 |
prompt_text=reference_text if enable_reference_audio else None,
|
130 |
)
|
131 |
|
132 |
-
|
133 |
-
|
134 |
-
|
|
|
|
|
|
|
135 |
)
|
136 |
-
llama_queue.put(payload)
|
137 |
|
138 |
-
|
|
|
|
|
|
|
|
|
139 |
while True:
|
140 |
-
result =
|
141 |
-
if result == "
|
142 |
-
|
143 |
-
continue
|
144 |
-
|
145 |
-
if result == "done":
|
146 |
-
if payload["success"] is False:
|
147 |
-
return None, build_html_error_message(payload["response"])
|
148 |
break
|
149 |
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
|
152 |
-
codes = torch.cat(codes, dim=1)
|
153 |
|
154 |
-
|
155 |
-
|
156 |
-
fake_audios = vqgan_model.decode(
|
157 |
-
indices=codes[None], feature_lengths=feature_lengths, return_audios=True
|
158 |
-
)[0, 0]
|
159 |
|
160 |
-
|
|
|
|
|
|
|
|
|
161 |
|
162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
|
164 |
|
165 |
def build_app():
|
@@ -170,95 +374,179 @@ def build_app():
|
|
170 |
app.load(
|
171 |
None,
|
172 |
None,
|
173 |
-
js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '
|
|
|
174 |
)
|
175 |
|
176 |
# Inference
|
177 |
with gr.Row():
|
178 |
with gr.Column(scale=3):
|
179 |
text = gr.Textbox(
|
180 |
-
label="Input Text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
)
|
182 |
|
183 |
with gr.Row():
|
184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
chunk_length = gr.Slider(
|
186 |
-
label="Iterative Prompt Length, 0 means off
|
187 |
minimum=0,
|
188 |
-
maximum=
|
189 |
-
value=
|
190 |
step=8,
|
191 |
)
|
192 |
|
193 |
max_new_tokens = gr.Slider(
|
194 |
-
label="Maximum tokens per batch, 0 means no limit
|
195 |
-
minimum=
|
196 |
-
maximum=
|
197 |
-
value=
|
198 |
step=8,
|
199 |
)
|
200 |
|
201 |
top_p = gr.Slider(
|
202 |
-
label="Top-P",
|
|
|
|
|
|
|
|
|
203 |
)
|
204 |
|
205 |
repetition_penalty = gr.Slider(
|
206 |
label="Repetition Penalty",
|
207 |
-
minimum=
|
208 |
-
maximum=
|
209 |
-
value=1.
|
210 |
step=0.01,
|
211 |
)
|
212 |
|
213 |
temperature = gr.Slider(
|
214 |
label="Temperature",
|
215 |
-
minimum=0,
|
216 |
-
maximum=
|
217 |
value=0.7,
|
218 |
step=0.01,
|
219 |
)
|
220 |
|
221 |
-
|
222 |
-
label="Speaker / 说话人",
|
223 |
-
placeholder="Type name of the speaker / 输入说话人的名称",
|
224 |
-
lines=1,
|
225 |
-
)
|
226 |
-
|
227 |
-
with gr.Tab(label="Reference Audio / 参考音频"):
|
228 |
gr.Markdown(
|
229 |
-
|
230 |
)
|
231 |
|
232 |
enable_reference_audio = gr.Checkbox(
|
233 |
-
label="Enable Reference Audio
|
234 |
)
|
235 |
reference_audio = gr.Audio(
|
236 |
-
label="Reference Audio
|
237 |
type="filepath",
|
238 |
)
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
243 |
)
|
244 |
|
245 |
with gr.Column(scale=3):
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
251 |
with gr.Row():
|
252 |
with gr.Column(scale=3):
|
253 |
generate = gr.Button(
|
254 |
-
value="\U0001F3A7
|
|
|
|
|
|
|
|
|
255 |
)
|
256 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
# # Submit
|
258 |
generate.click(
|
259 |
-
|
260 |
[
|
261 |
-
|
262 |
enable_reference_audio,
|
263 |
reference_audio,
|
264 |
reference_text,
|
@@ -267,12 +555,29 @@ def build_app():
|
|
267 |
top_p,
|
268 |
repetition_penalty,
|
269 |
temperature,
|
270 |
-
|
|
|
271 |
],
|
272 |
-
[
|
273 |
concurrency_limit=1,
|
274 |
)
|
275 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
return app
|
277 |
|
278 |
|
@@ -281,74 +586,60 @@ def parse_args():
|
|
281 |
parser.add_argument(
|
282 |
"--llama-checkpoint-path",
|
283 |
type=Path,
|
284 |
-
default="checkpoints/
|
285 |
)
|
286 |
parser.add_argument(
|
287 |
-
"--
|
288 |
-
)
|
289 |
-
parser.add_argument(
|
290 |
-
"--vqgan-checkpoint-path",
|
291 |
type=Path,
|
292 |
-
default="checkpoints/
|
293 |
)
|
294 |
-
parser.add_argument("--
|
295 |
-
parser.add_argument("--tokenizer", type=str, default="fishaudio/fish-speech-1")
|
296 |
parser.add_argument("--device", type=str, default="cuda")
|
297 |
parser.add_argument("--half", action="store_true")
|
298 |
-
parser.add_argument("--max-length", type=int, default=2048)
|
299 |
parser.add_argument("--compile", action="store_true")
|
300 |
parser.add_argument("--max-gradio-length", type=int, default=0)
|
|
|
301 |
|
302 |
return parser.parse_args()
|
303 |
|
304 |
|
305 |
if __name__ == "__main__":
|
306 |
args = parse_args()
|
307 |
-
|
308 |
args.precision = torch.half if args.half else torch.bfloat16
|
309 |
-
args.compile = True
|
310 |
-
args.max_gradio_length = 1024
|
311 |
-
args.tokenizer = "./checkpoints/fish-speech-1"
|
312 |
-
args.llama_checkpoint_path = "./checkpoints/fish-speech-1/text2semantic-sft-medium-v1-4k.pth"
|
313 |
-
args.llama_config_name = "dual_ar_2_codebook_medium"
|
314 |
-
args.vqgan_checkpoint_path = "./checkpoints/fish-speech-1/vq-gan-group-fsq-2x1024.pth"
|
315 |
-
args.vqgan_config_name = "vqgan_pretrain"
|
316 |
|
317 |
logger.info("Loading Llama model...")
|
318 |
llama_queue = launch_thread_safe_queue(
|
319 |
-
config_name=args.llama_config_name,
|
320 |
checkpoint_path=args.llama_checkpoint_path,
|
321 |
device=args.device,
|
322 |
precision=args.precision,
|
323 |
-
max_length=args.max_length,
|
324 |
compile=args.compile,
|
325 |
)
|
326 |
-
llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
|
327 |
logger.info("Llama model loaded, loading VQ-GAN model...")
|
328 |
|
329 |
-
|
330 |
-
config_name=args.
|
331 |
-
checkpoint_path=args.
|
332 |
device=args.device,
|
333 |
)
|
334 |
|
335 |
-
logger.info("
|
336 |
|
337 |
# Dry run to check if the model is loaded correctly and avoid the first-time latency
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
|
|
349 |
)
|
350 |
|
351 |
logger.info("Warming up done, launching the web UI...")
|
352 |
|
353 |
app = build_app()
|
354 |
-
app.launch(show_api=
|
|
|
5 |
|
6 |
# Download if not exists
|
7 |
os.makedirs("checkpoints", exist_ok=True)
|
8 |
+
snapshot_download(repo_id="fishaudio/fish-speech-1.2-sft", local_dir="./checkpoints/fish-speech-1.2-sft")
|
9 |
|
10 |
print("All checkpoints downloaded")
|
11 |
|
|
|
23 |
|
24 |
from tools.llama.generate import launch_thread_safe_queue
|
25 |
from tools.vqgan.inference import load_model as load_vqgan_model
|
26 |
+
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
|
27 |
+
from tools.api import decode_vq_tokens, encode_reference
|
28 |
+
from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
|
29 |
+
from tools.llama.generate import (
|
30 |
+
GenerateRequest,
|
31 |
+
GenerateResponse,
|
32 |
+
WrappedGenerateResponse,
|
33 |
+
launch_thread_safe_queue,
|
34 |
+
)
|
35 |
+
from tools.vqgan.inference import load_model as load_decoder_model
|
36 |
|
37 |
# Make einx happy
|
38 |
os.environ["EINX_FILTER_TRACEBACK"] = "false"
|
|
|
40 |
|
41 |
HEADER_MD = """# Fish Speech
|
42 |
|
43 |
+
## The demo in this space is version 1.2, Please check [Fish Audio](https://fish.audio) for the best model.
|
44 |
+
## 该 Demo 为 Fish Speech 1.2 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.
|
45 |
|
46 |
A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
|
47 |
由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.
|
|
|
49 |
You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).
|
50 |
你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.
|
51 |
|
52 |
+
Related code and weights are released under CC BY-NC-SA 4.0 License.
|
53 |
+
相关代码,权重使用 CC BY-NC-SA 4.0 许可证发布.
|
54 |
|
55 |
We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
|
56 |
我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
|
57 |
|
58 |
+
The model running in this WebUI is Fish Speech V1.2 Medium SFT.
|
59 |
+
在此 WebUI 中运行的模型是 Fish Speech V1.2 Medium SFT.
|
60 |
"""
|
61 |
|
62 |
TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
|
|
|
95 |
top_p,
|
96 |
repetition_penalty,
|
97 |
temperature,
|
98 |
+
streaming=False,
|
99 |
):
|
100 |
if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
|
101 |
+
return (
|
102 |
+
None,
|
103 |
+
None,
|
104 |
+
"Text is too long, please keep it under {} characters.".format(
|
105 |
+
args.max_gradio_length
|
106 |
+
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
)
|
108 |
|
109 |
+
# Parse reference audio aka prompt
|
110 |
+
prompt_tokens = encode_reference(
|
111 |
+
decoder_model=decoder_model,
|
112 |
+
reference_audio=reference_audio,
|
113 |
+
enable_reference_audio=enable_reference_audio,
|
114 |
+
)
|
115 |
|
116 |
# LLAMA Inference
|
117 |
request = dict(
|
118 |
+
device=decoder_model.device,
|
|
|
119 |
max_new_tokens=max_new_tokens,
|
120 |
text=text,
|
121 |
top_p=top_p,
|
|
|
124 |
compile=args.compile,
|
125 |
iterative_prompt=chunk_length > 0,
|
126 |
chunk_length=chunk_length,
|
127 |
+
max_length=2048,
|
|
|
128 |
prompt_tokens=prompt_tokens if enable_reference_audio else None,
|
129 |
prompt_text=reference_text if enable_reference_audio else None,
|
130 |
)
|
131 |
|
132 |
+
response_queue = queue.Queue()
|
133 |
+
llama_queue.put(
|
134 |
+
GenerateRequest(
|
135 |
+
request=request,
|
136 |
+
response_queue=response_queue,
|
137 |
+
)
|
138 |
)
|
|
|
139 |
|
140 |
+
if streaming:
|
141 |
+
yield wav_chunk_header(), None, None
|
142 |
+
|
143 |
+
segments = []
|
144 |
+
|
145 |
while True:
|
146 |
+
result: WrappedGenerateResponse = response_queue.get()
|
147 |
+
if result.status == "error":
|
148 |
+
yield None, None, build_html_error_message(result.response)
|
|
|
|
|
|
|
|
|
|
|
149 |
break
|
150 |
|
151 |
+
result: GenerateResponse = result.response
|
152 |
+
if result.action == "next":
|
153 |
+
break
|
154 |
+
|
155 |
+
with torch.autocast(
|
156 |
+
device_type=(
|
157 |
+
"cpu"
|
158 |
+
if decoder_model.device.type == "mps"
|
159 |
+
else decoder_model.device.type
|
160 |
+
),
|
161 |
+
dtype=args.precision,
|
162 |
+
):
|
163 |
+
fake_audios = decode_vq_tokens(
|
164 |
+
decoder_model=decoder_model,
|
165 |
+
codes=result.codes,
|
166 |
+
)
|
167 |
+
|
168 |
+
fake_audios = fake_audios.float().cpu().numpy()
|
169 |
+
segments.append(fake_audios)
|
170 |
+
|
171 |
+
if streaming:
|
172 |
+
yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
|
173 |
+
|
174 |
+
if len(segments) == 0:
|
175 |
+
return (
|
176 |
+
None,
|
177 |
+
None,
|
178 |
+
build_html_error_message(
|
179 |
+
"No audio generated, please check the input text."
|
180 |
+
),
|
181 |
+
)
|
182 |
+
|
183 |
+
# No matter streaming or not, we need to return the final audio
|
184 |
+
audio = np.concatenate(segments, axis=0)
|
185 |
+
yield None, (decoder_model.spec_transform.sample_rate, audio), None
|
186 |
+
|
187 |
+
if torch.cuda.is_available():
|
188 |
+
torch.cuda.empty_cache()
|
189 |
+
gc.collect()
|
190 |
+
|
191 |
+
|
192 |
+
def inference_with_auto_rerank(
|
193 |
+
text,
|
194 |
+
enable_reference_audio,
|
195 |
+
reference_audio,
|
196 |
+
reference_text,
|
197 |
+
max_new_tokens,
|
198 |
+
chunk_length,
|
199 |
+
top_p,
|
200 |
+
repetition_penalty,
|
201 |
+
temperature,
|
202 |
+
use_auto_rerank,
|
203 |
+
streaming=False,
|
204 |
+
):
|
205 |
+
|
206 |
+
max_attempts = 2 if use_auto_rerank else 1
|
207 |
+
best_wer = float("inf")
|
208 |
+
best_audio = None
|
209 |
+
best_sample_rate = None
|
210 |
+
|
211 |
+
for attempt in range(max_attempts):
|
212 |
+
audio_generator = inference(
|
213 |
+
text,
|
214 |
+
enable_reference_audio,
|
215 |
+
reference_audio,
|
216 |
+
reference_text,
|
217 |
+
max_new_tokens,
|
218 |
+
chunk_length,
|
219 |
+
top_p,
|
220 |
+
repetition_penalty,
|
221 |
+
temperature,
|
222 |
+
streaming=False,
|
223 |
+
)
|
224 |
+
|
225 |
+
# 获取音频数据
|
226 |
+
for _ in audio_generator:
|
227 |
+
pass
|
228 |
+
_, (sample_rate, audio), message = _
|
229 |
+
|
230 |
+
if audio is None:
|
231 |
+
return None, None, message
|
232 |
+
|
233 |
+
if not use_auto_rerank:
|
234 |
+
return None, (sample_rate, audio), None
|
235 |
+
|
236 |
+
asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
|
237 |
+
wer = calculate_wer(text, asr_result["text"])
|
238 |
+
if wer <= 0.3 and not asr_result["huge_gap"]:
|
239 |
+
return None, (sample_rate, audio), None
|
240 |
+
|
241 |
+
if wer < best_wer:
|
242 |
+
best_wer = wer
|
243 |
+
best_audio = audio
|
244 |
+
best_sample_rate = sample_rate
|
245 |
+
|
246 |
+
if attempt == max_attempts - 1:
|
247 |
+
break
|
248 |
+
|
249 |
+
return None, (best_sample_rate, best_audio), None
|
250 |
+
|
251 |
+
|
252 |
+
inference_stream = partial(inference, streaming=True)
|
253 |
+
|
254 |
+
n_audios = 4
|
255 |
+
|
256 |
+
global_audio_list = []
|
257 |
+
global_error_list = []
|
258 |
+
|
259 |
+
|
260 |
+
def inference_wrapper(
|
261 |
+
text,
|
262 |
+
enable_reference_audio,
|
263 |
+
reference_audio,
|
264 |
+
reference_text,
|
265 |
+
max_new_tokens,
|
266 |
+
chunk_length,
|
267 |
+
top_p,
|
268 |
+
repetition_penalty,
|
269 |
+
temperature,
|
270 |
+
batch_infer_num,
|
271 |
+
if_load_asr_model,
|
272 |
+
):
|
273 |
+
audios = []
|
274 |
+
errors = []
|
275 |
+
|
276 |
+
for _ in range(batch_infer_num):
|
277 |
+
result = inference_with_auto_rerank(
|
278 |
+
text,
|
279 |
+
enable_reference_audio,
|
280 |
+
reference_audio,
|
281 |
+
reference_text,
|
282 |
+
max_new_tokens,
|
283 |
+
chunk_length,
|
284 |
+
top_p,
|
285 |
+
repetition_penalty,
|
286 |
+
temperature,
|
287 |
+
if_load_asr_model,
|
288 |
+
)
|
289 |
+
|
290 |
+
_, audio_data, error_message = result
|
291 |
+
|
292 |
+
audios.append(
|
293 |
+
gr.Audio(value=audio_data if audio_data else None, visible=True),
|
294 |
+
)
|
295 |
+
errors.append(
|
296 |
+
gr.HTML(value=error_message if error_message else None, visible=True),
|
297 |
+
)
|
298 |
+
|
299 |
+
for _ in range(batch_infer_num, n_audios):
|
300 |
+
audios.append(
|
301 |
+
gr.Audio(value=None, visible=False),
|
302 |
+
)
|
303 |
+
errors.append(
|
304 |
+
gr.HTML(value=None, visible=False),
|
305 |
+
)
|
306 |
+
|
307 |
+
return None, *audios, *errors
|
308 |
+
|
309 |
+
|
310 |
+
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
|
311 |
+
buffer = io.BytesIO()
|
312 |
+
|
313 |
+
with wave.open(buffer, "wb") as wav_file:
|
314 |
+
wav_file.setnchannels(channels)
|
315 |
+
wav_file.setsampwidth(bit_depth // 8)
|
316 |
+
wav_file.setframerate(sample_rate)
|
317 |
+
|
318 |
+
wav_header_bytes = buffer.getvalue()
|
319 |
+
buffer.close()
|
320 |
+
return wav_header_bytes
|
321 |
+
|
322 |
+
|
323 |
+
def normalize_text(user_input, use_normalization):
|
324 |
+
if use_normalization:
|
325 |
+
return ChnNormedText(raw_text=user_input).normalize()
|
326 |
+
else:
|
327 |
+
return user_input
|
328 |
+
|
329 |
+
|
330 |
+
asr_model = None
|
331 |
|
|
|
332 |
|
333 |
+
def change_if_load_asr_model(if_load):
|
334 |
+
global asr_model
|
|
|
|
|
|
|
335 |
|
336 |
+
if if_load:
|
337 |
+
gr.Warning("Loading faster whisper model...")
|
338 |
+
if asr_model is None:
|
339 |
+
asr_model = load_model()
|
340 |
+
return gr.Checkbox(label="Unload faster whisper model", value=if_load)
|
341 |
|
342 |
+
if if_load is False:
|
343 |
+
gr.Warning("Unloading faster whisper model...")
|
344 |
+
del asr_model
|
345 |
+
asr_model = None
|
346 |
+
if torch.cuda.is_available():
|
347 |
+
torch.cuda.empty_cache()
|
348 |
+
gc.collect()
|
349 |
+
return gr.Checkbox(label="Load faster whisper model", value=if_load)
|
350 |
+
|
351 |
+
|
352 |
+
def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
|
353 |
+
if if_load and asr_model is not None:
|
354 |
+
if (
|
355 |
+
if_auto_label
|
356 |
+
and enable_ref
|
357 |
+
and ref_audio is not None
|
358 |
+
and ref_text.strip() == ""
|
359 |
+
):
|
360 |
+
data, sample_rate = librosa.load(ref_audio)
|
361 |
+
res = batch_asr(asr_model, [data], sample_rate)[0]
|
362 |
+
ref_text = res["text"]
|
363 |
+
else:
|
364 |
+
gr.Warning("Whisper model not loaded!")
|
365 |
+
|
366 |
+
return gr.Textbox(value=ref_text)
|
367 |
|
368 |
|
369 |
def build_app():
|
|
|
374 |
app.load(
|
375 |
None,
|
376 |
None,
|
377 |
+
js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
|
378 |
+
% args.theme,
|
379 |
)
|
380 |
|
381 |
# Inference
|
382 |
with gr.Row():
|
383 |
with gr.Column(scale=3):
|
384 |
text = gr.Textbox(
|
385 |
+
label="Input Text", placeholder=TEXTBOX_PLACEHOLDER, lines=10
|
386 |
+
)
|
387 |
+
refined_text = gr.Textbox(
|
388 |
+
label="Realtime Transform Text",
|
389 |
+
placeholder=
|
390 |
+
"Normalization Result Preview (Currently Only Chinese)",
|
391 |
+
lines=5,
|
392 |
+
interactive=False,
|
393 |
)
|
394 |
|
395 |
with gr.Row():
|
396 |
+
if_refine_text = gr.Checkbox(
|
397 |
+
label="Text Normalization",
|
398 |
+
value=True,
|
399 |
+
scale=1,
|
400 |
+
)
|
401 |
+
|
402 |
+
if_load_asr_model = gr.Checkbox(
|
403 |
+
label="Load / Unload ASR model for auto-reranking",
|
404 |
+
value=False,
|
405 |
+
scale=3,
|
406 |
+
)
|
407 |
+
|
408 |
+
with gr.Row():
|
409 |
+
with gr.Tab(label="Advanced Config"):
|
410 |
chunk_length = gr.Slider(
|
411 |
+
label="Iterative Prompt Length, 0 means off",
|
412 |
minimum=0,
|
413 |
+
maximum=500,
|
414 |
+
value=100,
|
415 |
step=8,
|
416 |
)
|
417 |
|
418 |
max_new_tokens = gr.Slider(
|
419 |
+
label="Maximum tokens per batch, 0 means no limit",
|
420 |
+
minimum=0,
|
421 |
+
maximum=2048,
|
422 |
+
value=1024, # 0 means no limit
|
423 |
step=8,
|
424 |
)
|
425 |
|
426 |
top_p = gr.Slider(
|
427 |
+
label="Top-P",
|
428 |
+
minimum=0.6,
|
429 |
+
maximum=0.9,
|
430 |
+
value=0.7,
|
431 |
+
step=0.01,
|
432 |
)
|
433 |
|
434 |
repetition_penalty = gr.Slider(
|
435 |
label="Repetition Penalty",
|
436 |
+
minimum=1,
|
437 |
+
maximum=1.5,
|
438 |
+
value=1.2,
|
439 |
step=0.01,
|
440 |
)
|
441 |
|
442 |
temperature = gr.Slider(
|
443 |
label="Temperature",
|
444 |
+
minimum=0.6,
|
445 |
+
maximum=0.9,
|
446 |
value=0.7,
|
447 |
step=0.01,
|
448 |
)
|
449 |
|
450 |
+
with gr.Tab(label="Reference Audio"):
|
|
|
|
|
|
|
|
|
|
|
|
|
451 |
gr.Markdown(
|
452 |
+
"5 to 10 seconds of reference audio, useful for specifying speaker."
|
453 |
)
|
454 |
|
455 |
enable_reference_audio = gr.Checkbox(
|
456 |
+
label="Enable Reference Audio",
|
457 |
)
|
458 |
reference_audio = gr.Audio(
|
459 |
+
label="Reference Audio",
|
460 |
type="filepath",
|
461 |
)
|
462 |
+
with gr.Row():
|
463 |
+
if_auto_label = gr.Checkbox(
|
464 |
+
label="Auto Labeling",
|
465 |
+
min_width=100,
|
466 |
+
scale=0,
|
467 |
+
value=False,
|
468 |
+
)
|
469 |
+
reference_text = gr.Textbox(
|
470 |
+
label="Reference Text",
|
471 |
+
lines=1,
|
472 |
+
placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
|
473 |
+
value="",
|
474 |
+
)
|
475 |
+
with gr.Tab(label="Batch Inference"):
|
476 |
+
batch_infer_num = gr.Slider(
|
477 |
+
label="Batch infer nums",
|
478 |
+
minimum=1,
|
479 |
+
maximum=n_audios,
|
480 |
+
step=1,
|
481 |
+
value=1,
|
482 |
)
|
483 |
|
484 |
with gr.Column(scale=3):
|
485 |
+
for _ in range(n_audios):
|
486 |
+
with gr.Row():
|
487 |
+
error = gr.HTML(
|
488 |
+
label="Error Message",
|
489 |
+
visible=True if _ == 0 else False,
|
490 |
+
)
|
491 |
+
global_error_list.append(error)
|
492 |
+
with gr.Row():
|
493 |
+
audio = gr.Audio(
|
494 |
+
label="Generated Audio",
|
495 |
+
type="numpy",
|
496 |
+
interactive=False,
|
497 |
+
visible=True if _ == 0 else False,
|
498 |
+
)
|
499 |
+
global_audio_list.append(audio)
|
500 |
|
501 |
+
with gr.Row():
|
502 |
+
stream_audio = gr.Audio(
|
503 |
+
label="Streaming Audio",
|
504 |
+
streaming=True,
|
505 |
+
autoplay=True,
|
506 |
+
interactive=False,
|
507 |
+
show_download_button=True,
|
508 |
+
)
|
509 |
with gr.Row():
|
510 |
with gr.Column(scale=3):
|
511 |
generate = gr.Button(
|
512 |
+
value="\U0001F3A7 " + "Generate", variant="primary"
|
513 |
+
)
|
514 |
+
generate_stream = gr.Button(
|
515 |
+
value="\U0001F3A7 " + "Streaming Generate",
|
516 |
+
variant="primary",
|
517 |
)
|
518 |
|
519 |
+
text.input(
|
520 |
+
fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
|
521 |
+
)
|
522 |
+
|
523 |
+
if_load_asr_model.change(
|
524 |
+
fn=change_if_load_asr_model,
|
525 |
+
inputs=[if_load_asr_model],
|
526 |
+
outputs=[if_load_asr_model],
|
527 |
+
)
|
528 |
+
|
529 |
+
if_auto_label.change(
|
530 |
+
fn=lambda: gr.Textbox(value=""),
|
531 |
+
inputs=[],
|
532 |
+
outputs=[reference_text],
|
533 |
+
).then(
|
534 |
+
fn=change_if_auto_label,
|
535 |
+
inputs=[
|
536 |
+
if_load_asr_model,
|
537 |
+
if_auto_label,
|
538 |
+
enable_reference_audio,
|
539 |
+
reference_audio,
|
540 |
+
reference_text,
|
541 |
+
],
|
542 |
+
outputs=[reference_text],
|
543 |
+
)
|
544 |
+
|
545 |
# # Submit
|
546 |
generate.click(
|
547 |
+
inference_wrapper,
|
548 |
[
|
549 |
+
refined_text,
|
550 |
enable_reference_audio,
|
551 |
reference_audio,
|
552 |
reference_text,
|
|
|
555 |
top_p,
|
556 |
repetition_penalty,
|
557 |
temperature,
|
558 |
+
batch_infer_num,
|
559 |
+
if_load_asr_model,
|
560 |
],
|
561 |
+
[stream_audio, *global_audio_list, *global_error_list],
|
562 |
concurrency_limit=1,
|
563 |
)
|
564 |
|
565 |
+
generate_stream.click(
|
566 |
+
inference_stream,
|
567 |
+
[
|
568 |
+
refined_text,
|
569 |
+
enable_reference_audio,
|
570 |
+
reference_audio,
|
571 |
+
reference_text,
|
572 |
+
max_new_tokens,
|
573 |
+
chunk_length,
|
574 |
+
top_p,
|
575 |
+
repetition_penalty,
|
576 |
+
temperature,
|
577 |
+
],
|
578 |
+
[stream_audio, global_audio_list[0], global_error_list[0]],
|
579 |
+
concurrency_limit=10,
|
580 |
+
)
|
581 |
return app
|
582 |
|
583 |
|
|
|
586 |
parser.add_argument(
|
587 |
"--llama-checkpoint-path",
|
588 |
type=Path,
|
589 |
+
default="checkpoints/fish-speech-1.2-sft",
|
590 |
)
|
591 |
parser.add_argument(
|
592 |
+
"--decoder-checkpoint-path",
|
|
|
|
|
|
|
593 |
type=Path,
|
594 |
+
default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
|
595 |
)
|
596 |
+
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
|
|
597 |
parser.add_argument("--device", type=str, default="cuda")
|
598 |
parser.add_argument("--half", action="store_true")
|
|
|
599 |
parser.add_argument("--compile", action="store_true")
|
600 |
parser.add_argument("--max-gradio-length", type=int, default=0)
|
601 |
+
parser.add_argument("--theme", type=str, default="light")
|
602 |
|
603 |
return parser.parse_args()
|
604 |
|
605 |
|
606 |
if __name__ == "__main__":
|
607 |
args = parse_args()
|
|
|
608 |
args.precision = torch.half if args.half else torch.bfloat16
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
609 |
|
610 |
logger.info("Loading Llama model...")
|
611 |
llama_queue = launch_thread_safe_queue(
|
|
|
612 |
checkpoint_path=args.llama_checkpoint_path,
|
613 |
device=args.device,
|
614 |
precision=args.precision,
|
|
|
615 |
compile=args.compile,
|
616 |
)
|
|
|
617 |
logger.info("Llama model loaded, loading VQ-GAN model...")
|
618 |
|
619 |
+
decoder_model = load_decoder_model(
|
620 |
+
config_name=args.decoder_config_name,
|
621 |
+
checkpoint_path=args.decoder_checkpoint_path,
|
622 |
device=args.device,
|
623 |
)
|
624 |
|
625 |
+
logger.info("Decoder model loaded, warming up...")
|
626 |
|
627 |
# Dry run to check if the model is loaded correctly and avoid the first-time latency
|
628 |
+
list(
|
629 |
+
inference(
|
630 |
+
text="Hello, world!",
|
631 |
+
enable_reference_audio=False,
|
632 |
+
reference_audio=None,
|
633 |
+
reference_text="",
|
634 |
+
max_new_tokens=0,
|
635 |
+
chunk_length=100,
|
636 |
+
top_p=0.7,
|
637 |
+
repetition_penalty=1.2,
|
638 |
+
temperature=0.7,
|
639 |
+
)
|
640 |
)
|
641 |
|
642 |
logger.info("Warming up done, launching the web UI...")
|
643 |
|
644 |
app = build_app()
|
645 |
+
app.launch(show_api=True)
|
fish_speech/configs/base.yaml
CHANGED
@@ -17,6 +17,7 @@ trainer:
|
|
17 |
devices: auto
|
18 |
strategy:
|
19 |
_target_: lightning.pytorch.strategies.DDPStrategy
|
|
|
20 |
|
21 |
precision: bf16-mixed
|
22 |
|
|
|
17 |
devices: auto
|
18 |
strategy:
|
19 |
_target_: lightning.pytorch.strategies.DDPStrategy
|
20 |
+
process_group_backend: nccl # This should be override when training on windows
|
21 |
|
22 |
precision: bf16-mixed
|
23 |
|
fish_speech/configs/firefly_gan_vq.yaml
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture
|
2 |
+
spec_transform:
|
3 |
+
_target_: fish_speech.utils.spectrogram.LogMelSpectrogram
|
4 |
+
sample_rate: 44100
|
5 |
+
n_mels: 160
|
6 |
+
n_fft: 2048
|
7 |
+
hop_length: 512
|
8 |
+
win_length: 2048
|
9 |
+
backbone:
|
10 |
+
_target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder
|
11 |
+
input_channels: 160
|
12 |
+
depths: [3, 3, 9, 3]
|
13 |
+
dims: [128, 256, 384, 512]
|
14 |
+
drop_path_rate: 0.2
|
15 |
+
kernel_size: 7
|
16 |
+
head:
|
17 |
+
_target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator
|
18 |
+
hop_length: 512
|
19 |
+
upsample_rates: [8, 8, 2, 2, 2] # aka. strides
|
20 |
+
upsample_kernel_sizes: [16, 16, 4, 4, 4]
|
21 |
+
resblock_kernel_sizes: [3, 7, 11]
|
22 |
+
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
23 |
+
num_mels: 512
|
24 |
+
upsample_initial_channel: 512
|
25 |
+
use_template: false
|
26 |
+
pre_conv_kernel_size: 13
|
27 |
+
post_conv_kernel_size: 13
|
28 |
+
quantizer:
|
29 |
+
_target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
|
30 |
+
input_dim: 512
|
31 |
+
n_groups: 4
|
32 |
+
n_codebooks: 1
|
33 |
+
levels: [8, 5, 5, 5]
|
34 |
+
downsample_factor: [2]
|
fish_speech/configs/lora/r_8_alpha_16.yaml
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: fish_speech.models.text2semantic.lora.LoraConfig
|
2 |
+
r: 8
|
3 |
+
lora_alpha: 16
|
4 |
+
lora_dropout: 0.01
|
fish_speech/configs/text2semantic_finetune.yaml
CHANGED
@@ -1,18 +1,16 @@
|
|
1 |
defaults:
|
2 |
- base
|
3 |
-
- [email protected]: dual_ar_2_codebook_small
|
4 |
- _self_
|
5 |
|
6 |
project: text2semantic_finetune_dual_ar
|
7 |
-
max_length:
|
8 |
-
|
9 |
-
resume_weights_only: true
|
10 |
|
11 |
# Lightning Trainer
|
12 |
trainer:
|
13 |
accumulate_grad_batches: 1
|
14 |
gradient_clip_val: 1.0
|
15 |
-
gradient_clip_algorithm:
|
16 |
max_steps: 1000
|
17 |
precision: bf16-true
|
18 |
limit_val_batches: 10
|
@@ -21,29 +19,31 @@ trainer:
|
|
21 |
# Dataset Configuration
|
22 |
tokenizer:
|
23 |
_target_: transformers.AutoTokenizer.from_pretrained
|
24 |
-
pretrained_model_name_or_path:
|
25 |
|
26 |
# Dataset Configuration
|
27 |
train_dataset:
|
28 |
-
_target_: fish_speech.datasets.
|
29 |
proto_files:
|
30 |
- data/protos
|
31 |
tokenizer: ${tokenizer}
|
|
|
32 |
max_length: ${max_length}
|
33 |
-
num_codebooks: ${model.model.config.num_codebooks}
|
34 |
use_speaker: false
|
|
|
35 |
|
36 |
val_dataset:
|
37 |
-
_target_: fish_speech.datasets.
|
38 |
proto_files:
|
39 |
- data/protos
|
40 |
tokenizer: ${tokenizer}
|
|
|
41 |
max_length: ${max_length}
|
42 |
-
num_codebooks: ${model.model.config.num_codebooks}
|
43 |
use_speaker: false
|
|
|
44 |
|
45 |
data:
|
46 |
-
_target_: fish_speech.datasets.
|
47 |
train_dataset: ${train_dataset}
|
48 |
val_dataset: ${val_dataset}
|
49 |
num_workers: 4
|
@@ -53,13 +53,18 @@ data:
|
|
53 |
|
54 |
# Model Configuration
|
55 |
model:
|
56 |
-
_target_: fish_speech.models.text2semantic.TextToSemantic
|
57 |
-
model:
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
optimizer:
|
60 |
_target_: torch.optim.AdamW
|
61 |
_partial_: true
|
62 |
-
lr: 1e-
|
63 |
weight_decay: 0
|
64 |
betas: [0.9, 0.95]
|
65 |
eps: 1e-5
|
@@ -68,12 +73,11 @@ model:
|
|
68 |
_target_: torch.optim.lr_scheduler.LambdaLR
|
69 |
_partial_: true
|
70 |
lr_lambda:
|
71 |
-
_target_: fish_speech.scheduler.
|
72 |
_partial_: true
|
73 |
-
num_warmup_steps:
|
74 |
-
num_training_steps: ${trainer.max_steps}
|
75 |
|
76 |
# Callbacks
|
77 |
callbacks:
|
78 |
model_checkpoint:
|
79 |
-
every_n_train_steps:
|
|
|
1 |
defaults:
|
2 |
- base
|
|
|
3 |
- _self_
|
4 |
|
5 |
project: text2semantic_finetune_dual_ar
|
6 |
+
max_length: 4096
|
7 |
+
pretrained_ckpt_path: checkpoints/fish-speech-1.2-sft
|
|
|
8 |
|
9 |
# Lightning Trainer
|
10 |
trainer:
|
11 |
accumulate_grad_batches: 1
|
12 |
gradient_clip_val: 1.0
|
13 |
+
gradient_clip_algorithm: "norm"
|
14 |
max_steps: 1000
|
15 |
precision: bf16-true
|
16 |
limit_val_batches: 10
|
|
|
19 |
# Dataset Configuration
|
20 |
tokenizer:
|
21 |
_target_: transformers.AutoTokenizer.from_pretrained
|
22 |
+
pretrained_model_name_or_path: ${pretrained_ckpt_path}
|
23 |
|
24 |
# Dataset Configuration
|
25 |
train_dataset:
|
26 |
+
_target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
|
27 |
proto_files:
|
28 |
- data/protos
|
29 |
tokenizer: ${tokenizer}
|
30 |
+
causal: true
|
31 |
max_length: ${max_length}
|
|
|
32 |
use_speaker: false
|
33 |
+
interactive_prob: 0.7
|
34 |
|
35 |
val_dataset:
|
36 |
+
_target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
|
37 |
proto_files:
|
38 |
- data/protos
|
39 |
tokenizer: ${tokenizer}
|
40 |
+
causal: true
|
41 |
max_length: ${max_length}
|
|
|
42 |
use_speaker: false
|
43 |
+
interactive_prob: 0.7
|
44 |
|
45 |
data:
|
46 |
+
_target_: fish_speech.datasets.semantic.SemanticDataModule
|
47 |
train_dataset: ${train_dataset}
|
48 |
val_dataset: ${val_dataset}
|
49 |
num_workers: 4
|
|
|
53 |
|
54 |
# Model Configuration
|
55 |
model:
|
56 |
+
_target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
|
57 |
+
model:
|
58 |
+
_target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
|
59 |
+
path: ${pretrained_ckpt_path}
|
60 |
+
load_weights: true
|
61 |
+
max_length: ${max_length}
|
62 |
+
lora_config: null
|
63 |
|
64 |
optimizer:
|
65 |
_target_: torch.optim.AdamW
|
66 |
_partial_: true
|
67 |
+
lr: 1e-4
|
68 |
weight_decay: 0
|
69 |
betas: [0.9, 0.95]
|
70 |
eps: 1e-5
|
|
|
73 |
_target_: torch.optim.lr_scheduler.LambdaLR
|
74 |
_partial_: true
|
75 |
lr_lambda:
|
76 |
+
_target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
|
77 |
_partial_: true
|
78 |
+
num_warmup_steps: 10
|
|
|
79 |
|
80 |
# Callbacks
|
81 |
callbacks:
|
82 |
model_checkpoint:
|
83 |
+
every_n_train_steps: ${trainer.val_check_interval}
|
fish_speech/datasets/concat_repeat.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import bisect
|
2 |
+
import random
|
3 |
+
from typing import Iterable
|
4 |
+
|
5 |
+
from torch.utils.data import Dataset, IterableDataset
|
6 |
+
|
7 |
+
|
8 |
+
class ConcatRepeatDataset(Dataset):
|
9 |
+
datasets: list[Dataset]
|
10 |
+
cumulative_sizes: list[int]
|
11 |
+
repeats: list[int]
|
12 |
+
|
13 |
+
@staticmethod
|
14 |
+
def cumsum(sequence, repeats):
|
15 |
+
r, s = [], 0
|
16 |
+
for dataset, repeat in zip(sequence, repeats):
|
17 |
+
l = len(dataset) * repeat
|
18 |
+
r.append(l + s)
|
19 |
+
s += l
|
20 |
+
return r
|
21 |
+
|
22 |
+
def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
|
23 |
+
super().__init__()
|
24 |
+
|
25 |
+
self.datasets = list(datasets)
|
26 |
+
self.repeats = repeats
|
27 |
+
|
28 |
+
assert len(self.datasets) > 0, "datasets should not be an empty iterable"
|
29 |
+
assert len(self.datasets) == len(
|
30 |
+
repeats
|
31 |
+
), "datasets and repeats should have the same length"
|
32 |
+
|
33 |
+
for d in self.datasets:
|
34 |
+
assert not isinstance(
|
35 |
+
d, IterableDataset
|
36 |
+
), "ConcatRepeatDataset does not support IterableDataset"
|
37 |
+
|
38 |
+
self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
|
39 |
+
|
40 |
+
def __len__(self):
|
41 |
+
return self.cumulative_sizes[-1]
|
42 |
+
|
43 |
+
def __getitem__(self, idx):
|
44 |
+
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
45 |
+
|
46 |
+
if dataset_idx == 0:
|
47 |
+
sample_idx = idx
|
48 |
+
else:
|
49 |
+
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
50 |
+
|
51 |
+
dataset = self.datasets[dataset_idx]
|
52 |
+
|
53 |
+
return dataset[sample_idx % len(dataset)]
|
fish_speech/datasets/semantic.py
ADDED
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from itertools import chain
|
4 |
+
from pathlib import Path
|
5 |
+
from random import Random
|
6 |
+
from typing import Optional, Union
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import pyarrow.parquet as pq
|
10 |
+
import torch
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from datasets.download.streaming_download_manager import xopen
|
13 |
+
from huggingface_hub import HfApi
|
14 |
+
from lightning import LightningDataModule
|
15 |
+
from torch.distributed import get_rank, get_world_size, is_initialized
|
16 |
+
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
|
17 |
+
from transformers import AutoTokenizer
|
18 |
+
|
19 |
+
from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
|
20 |
+
from fish_speech.datasets.protos.text_data_pb2 import SampledData
|
21 |
+
from fish_speech.datasets.protos.text_data_stream import read_pb_stream
|
22 |
+
from fish_speech.text.clean import clean_text
|
23 |
+
from fish_speech.utils import RankedLogger
|
24 |
+
from fish_speech.utils.braceexpand import braceexpand
|
25 |
+
|
26 |
+
log = RankedLogger(__name__, rank_zero_only=True)
|
27 |
+
|
28 |
+
|
29 |
+
def split_by_rank_worker(files):
|
30 |
+
# We need to know the total number of devices
|
31 |
+
# to split the data properly
|
32 |
+
|
33 |
+
total_devices = 1
|
34 |
+
if is_initialized():
|
35 |
+
total_devices = get_world_size()
|
36 |
+
|
37 |
+
worker_info = get_worker_info()
|
38 |
+
if worker_info is not None:
|
39 |
+
total_devices *= worker_info.num_workers
|
40 |
+
|
41 |
+
if len(files) < total_devices:
|
42 |
+
# Repeat the files N times to match the number of devices
|
43 |
+
files = files * (total_devices // len(files) + 1)
|
44 |
+
|
45 |
+
# DDP
|
46 |
+
if is_initialized():
|
47 |
+
files = files[get_rank() :: get_world_size()]
|
48 |
+
|
49 |
+
# Split by worker
|
50 |
+
if worker_info is not None:
|
51 |
+
files = files[worker_info.id :: worker_info.num_workers]
|
52 |
+
|
53 |
+
return files
|
54 |
+
|
55 |
+
|
56 |
+
class AutoTextSemanticInstructionDataset(IterableDataset):
|
57 |
+
"""
|
58 |
+
Auto Augment Dataset by Speaker
|
59 |
+
|
60 |
+
1. Random concatenate multiple sentences from the same speaker to form a longer sentence
|
61 |
+
2. Automatically normalize the text
|
62 |
+
|
63 |
+
For interactive mode, we use the following format (multiple sequences):
|
64 |
+
<s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
|
65 |
+
|
66 |
+
For non-interactive mode, we use the following format (one long sequence):
|
67 |
+
<s> [INST] text [/INST] ... </s>
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
proto_files: list[str],
|
73 |
+
seed: int = 42,
|
74 |
+
interactive_prob: float = 0.5,
|
75 |
+
max_length: int = 1024,
|
76 |
+
tokenizer: AutoTokenizer = None,
|
77 |
+
use_speaker: bool | float = True,
|
78 |
+
causal: bool = True,
|
79 |
+
num_codebooks: Optional[int] = None,
|
80 |
+
skip_text_prob: float = 0.0,
|
81 |
+
):
|
82 |
+
"""
|
83 |
+
Args:
|
84 |
+
proto_files: proto buf files if using local data
|
85 |
+
seed: random seed
|
86 |
+
interactive_prob: probability to use interactive mode
|
87 |
+
max_length: max length of the text
|
88 |
+
tokenizer: tokenizer
|
89 |
+
use_speaker: include speaker information in the prompt
|
90 |
+
causal: use causal sampling when using local data, disable will lead to random sampling
|
91 |
+
num_codebooks: number of codebooks, if None, it will be automatically detected
|
92 |
+
skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
|
93 |
+
"""
|
94 |
+
|
95 |
+
super().__init__()
|
96 |
+
|
97 |
+
assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
|
98 |
+
|
99 |
+
self.seed = seed
|
100 |
+
self.max_length = max_length
|
101 |
+
self.tokenizer = tokenizer
|
102 |
+
self.interactive_prob = interactive_prob
|
103 |
+
self.use_speaker = use_speaker
|
104 |
+
self.proto_files = proto_files
|
105 |
+
self.causal = causal
|
106 |
+
self.num_codebooks = num_codebooks
|
107 |
+
self.skip_text_prob = skip_text_prob
|
108 |
+
|
109 |
+
self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
|
110 |
+
self.groups = None
|
111 |
+
|
112 |
+
def init_mock_data_server(self):
|
113 |
+
if self.groups is not None:
|
114 |
+
return
|
115 |
+
|
116 |
+
# Expand the proto files
|
117 |
+
expanded_proto_files = []
|
118 |
+
for filename in self.proto_files:
|
119 |
+
for i in braceexpand(filename):
|
120 |
+
i = Path(i)
|
121 |
+
if i.is_file():
|
122 |
+
expanded_proto_files.append(i)
|
123 |
+
elif i.is_dir():
|
124 |
+
expanded_proto_files.extend(i.rglob("*.proto"))
|
125 |
+
expanded_proto_files.extend(i.rglob("*.protos"))
|
126 |
+
else:
|
127 |
+
raise ValueError(f"{i} is not a file or directory")
|
128 |
+
|
129 |
+
expanded_proto_files = sorted(expanded_proto_files)
|
130 |
+
Random(self.seed).shuffle(expanded_proto_files)
|
131 |
+
|
132 |
+
self.groups = []
|
133 |
+
shard_proto_files = split_by_rank_worker(expanded_proto_files)
|
134 |
+
log.info(
|
135 |
+
f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
|
136 |
+
)
|
137 |
+
|
138 |
+
count = 0
|
139 |
+
for filename in shard_proto_files:
|
140 |
+
with open(filename, "rb") as f:
|
141 |
+
for text_data in read_pb_stream(f):
|
142 |
+
self.groups.append(text_data)
|
143 |
+
count += 1
|
144 |
+
|
145 |
+
log.info(f"Read total {count} groups of data")
|
146 |
+
|
147 |
+
# Shuffle the lines
|
148 |
+
Random(self.seed).shuffle(self.groups)
|
149 |
+
self.group_weights = [len(i.sentences) for i in self.groups]
|
150 |
+
|
151 |
+
def __iter__(self):
|
152 |
+
while True:
|
153 |
+
yield self.augment()
|
154 |
+
|
155 |
+
def tokenize_sentence(self, sentence: str):
|
156 |
+
sentence = clean_text(sentence)
|
157 |
+
tokens = self.tokenizer.encode(
|
158 |
+
f"{sentence}",
|
159 |
+
max_length=10**6,
|
160 |
+
add_special_tokens=False,
|
161 |
+
truncation=False,
|
162 |
+
)
|
163 |
+
return sentence, len(tokens)
|
164 |
+
|
165 |
+
def sample_data(self):
|
166 |
+
if self.groups is None:
|
167 |
+
self.init_mock_data_server()
|
168 |
+
|
169 |
+
# Shuffle unique lines, estimate that each sample is at least 20 tokens
|
170 |
+
num_samples = self.max_length // 20
|
171 |
+
|
172 |
+
# choice group based on their number of samples
|
173 |
+
group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
|
174 |
+
|
175 |
+
if self.causal:
|
176 |
+
# Sample in order
|
177 |
+
if num_samples >= len(group.sentences):
|
178 |
+
samples = group.sentences
|
179 |
+
else:
|
180 |
+
begin = random.randint(0, len(group.sentences) - num_samples)
|
181 |
+
samples = group.sentences[begin : begin + num_samples]
|
182 |
+
else:
|
183 |
+
samples = random.choices(
|
184 |
+
group.sentences, k=min(num_samples, len(group.sentences))
|
185 |
+
)
|
186 |
+
|
187 |
+
return SampledData(
|
188 |
+
source=group.source,
|
189 |
+
name=group.name,
|
190 |
+
samples=samples,
|
191 |
+
)
|
192 |
+
|
193 |
+
def augment(self):
|
194 |
+
final_text, final_semantic = [], []
|
195 |
+
response = self.sample_data()
|
196 |
+
if len(response.samples) == 0:
|
197 |
+
# Invalid group
|
198 |
+
return None
|
199 |
+
|
200 |
+
samples = list(response.samples)
|
201 |
+
idx = 0
|
202 |
+
use_interactive = random.random() < self.interactive_prob
|
203 |
+
|
204 |
+
if use_interactive is False:
|
205 |
+
# Random sample based on speaker using a truncated normal distribution
|
206 |
+
a = torch.tensor([0], dtype=torch.float32)
|
207 |
+
torch.nn.init.trunc_normal_(
|
208 |
+
a,
|
209 |
+
mean=self.max_length // 2,
|
210 |
+
std=self.max_length // 4,
|
211 |
+
a=10,
|
212 |
+
b=self.max_length,
|
213 |
+
)
|
214 |
+
remaining_tokens = a.long().item() - 4
|
215 |
+
else:
|
216 |
+
remaining_tokens = self.max_length
|
217 |
+
|
218 |
+
# Use speaker
|
219 |
+
if isinstance(self.use_speaker, float):
|
220 |
+
use_speaker = random.random() < self.use_speaker
|
221 |
+
else:
|
222 |
+
use_speaker = self.use_speaker
|
223 |
+
|
224 |
+
all_tokens, all_labels = [], []
|
225 |
+
while remaining_tokens > 0 and len(samples) > 0:
|
226 |
+
sentence = samples.pop(0)
|
227 |
+
|
228 |
+
text = random.choice(sentence.texts)
|
229 |
+
text, length = self.tokenize_sentence(text)
|
230 |
+
remaining_tokens -= length + len(sentence.semantics[0].values)
|
231 |
+
|
232 |
+
if use_interactive is False:
|
233 |
+
final_text.append(text)
|
234 |
+
final_semantic.append(sentence.semantics)
|
235 |
+
else:
|
236 |
+
# For interactive mode, we only apply speaker for the first sentence
|
237 |
+
# [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
|
238 |
+
tokens, labels = self.pack_sentences(
|
239 |
+
sentences=[text],
|
240 |
+
semantics=[sentence.semantics],
|
241 |
+
speaker=response.name if use_speaker else None,
|
242 |
+
skip_text=random.random() < self.skip_text_prob,
|
243 |
+
)
|
244 |
+
|
245 |
+
all_tokens.append(tokens)
|
246 |
+
all_labels.append(labels)
|
247 |
+
|
248 |
+
idx += 1
|
249 |
+
|
250 |
+
if use_interactive is False:
|
251 |
+
tokens, labels = self.pack_sentences(
|
252 |
+
final_text,
|
253 |
+
semantics=final_semantic,
|
254 |
+
speaker=response.name if use_speaker else None,
|
255 |
+
)
|
256 |
+
all_tokens.append(tokens)
|
257 |
+
all_labels.append(labels)
|
258 |
+
|
259 |
+
tokens = torch.cat(all_tokens, dim=1)
|
260 |
+
labels = torch.cat(all_labels, dim=1)
|
261 |
+
|
262 |
+
# Verify that the length is correct
|
263 |
+
assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
|
264 |
+
|
265 |
+
data = {"tokens": tokens, "labels": labels}
|
266 |
+
|
267 |
+
return data
|
268 |
+
|
269 |
+
def pack_sentences(
|
270 |
+
self,
|
271 |
+
sentences: list[str],
|
272 |
+
semantics: list,
|
273 |
+
speaker: Optional[str] = None,
|
274 |
+
skip_text: bool = False,
|
275 |
+
):
|
276 |
+
if speaker is None:
|
277 |
+
speaker = "assistant"
|
278 |
+
|
279 |
+
cated_sentences = " ".join(sentences)
|
280 |
+
if skip_text:
|
281 |
+
cated_sentences = "<|skip_text|>"
|
282 |
+
|
283 |
+
final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
|
284 |
+
final_text = final_text + f"<|im_start|>{speaker}\n"
|
285 |
+
|
286 |
+
encoded = self.tokenizer.encode(
|
287 |
+
final_text,
|
288 |
+
add_special_tokens=False,
|
289 |
+
truncation=False,
|
290 |
+
max_length=10**6,
|
291 |
+
)
|
292 |
+
semantic_length = sum([len(i[0].values) for i in semantics])
|
293 |
+
prompt_length = len(encoded)
|
294 |
+
num_codebooks = (
|
295 |
+
len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
|
296 |
+
)
|
297 |
+
|
298 |
+
# Pack the tokens and semantics (add <s> and </s> to semantic tokens)
|
299 |
+
tokens = (
|
300 |
+
encoded
|
301 |
+
+ [self.semantic_token_id] * semantic_length
|
302 |
+
+ self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
|
303 |
+
)
|
304 |
+
|
305 |
+
# Codebook bos/padding: 0, eos: 1
|
306 |
+
codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
|
307 |
+
for segment in semantics:
|
308 |
+
for book_idx, book in zip(range(num_codebooks), segment):
|
309 |
+
for j in book.values:
|
310 |
+
codes[book_idx].append(int(j) + 1)
|
311 |
+
|
312 |
+
for book in codes:
|
313 |
+
book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)
|
314 |
+
|
315 |
+
tokens = [tokens] + codes
|
316 |
+
|
317 |
+
tokens = torch.tensor(tokens, dtype=torch.long)
|
318 |
+
labels = tokens.clone()
|
319 |
+
|
320 |
+
if skip_text:
|
321 |
+
# If text is not provided, the sentence is used for condition only, all labels are -100
|
322 |
+
torch.fill_(labels, -100)
|
323 |
+
return tokens, labels
|
324 |
+
|
325 |
+
# Mask out the <s> tokens for semantic, predict semantic tokens only
|
326 |
+
# Since we don't mask out the input tokens, the language modeling still works
|
327 |
+
labels[1:, :prompt_length] = -100
|
328 |
+
|
329 |
+
tokens = tokens[:, :-1]
|
330 |
+
labels = labels[:, 1:]
|
331 |
+
|
332 |
+
# Verify the padding is correct, and the last token is eos
|
333 |
+
assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
|
334 |
+
assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
|
335 |
+
|
336 |
+
return tokens, labels
|
337 |
+
|
338 |
+
|
339 |
+
@dataclass
|
340 |
+
class TextDataCollator:
|
341 |
+
tokenizer: AutoTokenizer
|
342 |
+
max_length: int = 1024
|
343 |
+
|
344 |
+
def __call__(self, examples):
|
345 |
+
if "negative_tokens" in examples:
|
346 |
+
positive_examples = []
|
347 |
+
negative_examples = []
|
348 |
+
|
349 |
+
for i in examples:
|
350 |
+
positive_examples.append(
|
351 |
+
{
|
352 |
+
"tokens": i["tokens"],
|
353 |
+
"labels": i["labels"],
|
354 |
+
}
|
355 |
+
)
|
356 |
+
negative_examples.append(
|
357 |
+
{
|
358 |
+
"tokens": i["negative_tokens"],
|
359 |
+
"labels": i["negative_labels"],
|
360 |
+
}
|
361 |
+
)
|
362 |
+
|
363 |
+
examples = positive_examples + negative_examples
|
364 |
+
|
365 |
+
return self.batchify(examples)
|
366 |
+
|
367 |
+
def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
|
368 |
+
tokens, attention_masks, labels = [], [], []
|
369 |
+
|
370 |
+
# Calculate the max length
|
371 |
+
max_tokens_length = 0
|
372 |
+
for example in examples:
|
373 |
+
max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
|
374 |
+
max_tokens_length = min(max_tokens_length, self.max_length)
|
375 |
+
|
376 |
+
for example in examples:
|
377 |
+
_tokens = example[tokens_key][:, :max_tokens_length]
|
378 |
+
_labels = example[labels_key][:, :max_tokens_length]
|
379 |
+
_attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
|
380 |
+
tokens_length = _tokens.size(1)
|
381 |
+
_attention_mask[:tokens_length] = False
|
382 |
+
|
383 |
+
assert tokens_length == _labels.size(
|
384 |
+
1
|
385 |
+
), f"{tokens_length} != {_labels.size(1)}"
|
386 |
+
|
387 |
+
if tokens_length < max_tokens_length:
|
388 |
+
_tokens = F.pad(
|
389 |
+
_tokens,
|
390 |
+
(0, max_tokens_length - tokens_length),
|
391 |
+
value=self.tokenizer.eos_token_id,
|
392 |
+
)
|
393 |
+
_tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
|
394 |
+
_labels = F.pad(
|
395 |
+
_labels, (0, max_tokens_length - _labels.size(1)), value=-100
|
396 |
+
)
|
397 |
+
|
398 |
+
tokens.append(_tokens)
|
399 |
+
attention_masks.append(_attention_mask)
|
400 |
+
labels.append(_labels)
|
401 |
+
|
402 |
+
tokens = torch.stack(tokens, dim=0)
|
403 |
+
attention_masks = torch.stack(attention_masks, dim=0)
|
404 |
+
labels = torch.stack(labels, dim=0)
|
405 |
+
|
406 |
+
return {
|
407 |
+
"inputs": tokens,
|
408 |
+
"attention_masks": attention_masks,
|
409 |
+
"labels": labels,
|
410 |
+
}
|
411 |
+
|
412 |
+
|
413 |
+
class InterleaveDataset(IterableDataset):
|
414 |
+
def __init__(
|
415 |
+
self,
|
416 |
+
datasets: list[IterableDataset],
|
417 |
+
probabilities: list[float],
|
418 |
+
seed: int = 42,
|
419 |
+
):
|
420 |
+
super().__init__()
|
421 |
+
|
422 |
+
self.datasets = datasets
|
423 |
+
self.probabilities = probabilities
|
424 |
+
self.seed = seed
|
425 |
+
|
426 |
+
def __iter__(self):
|
427 |
+
rng = np.random.default_rng(self.seed)
|
428 |
+
dataset_iterators = [iter(dataset) for dataset in self.datasets]
|
429 |
+
|
430 |
+
while True:
|
431 |
+
# Random choice one
|
432 |
+
dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
|
433 |
+
dataset_iterator = dataset_iterators[dataset_idx]
|
434 |
+
|
435 |
+
try:
|
436 |
+
yield next(dataset_iterator)
|
437 |
+
except StopIteration:
|
438 |
+
# Exhausted, create a new iterator
|
439 |
+
dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
|
440 |
+
yield next(dataset_iterators[dataset_idx])
|
441 |
+
|
442 |
+
|
443 |
+
class SemanticDataModule(LightningDataModule):
|
444 |
+
def __init__(
|
445 |
+
self,
|
446 |
+
train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
|
447 |
+
val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
|
448 |
+
batch_size: int = 32,
|
449 |
+
tokenizer: AutoTokenizer = None,
|
450 |
+
max_length: int = 1024,
|
451 |
+
num_workers: int = 4,
|
452 |
+
):
|
453 |
+
super().__init__()
|
454 |
+
|
455 |
+
self.train_dataset = train_dataset
|
456 |
+
self.val_dataset = val_dataset
|
457 |
+
self.batch_size = batch_size
|
458 |
+
self.tokenizer = tokenizer
|
459 |
+
self.max_length = max_length
|
460 |
+
self.num_workers = num_workers
|
461 |
+
|
462 |
+
def train_dataloader(self):
|
463 |
+
return DataLoader(
|
464 |
+
self.train_dataset,
|
465 |
+
batch_size=self.batch_size,
|
466 |
+
collate_fn=TextDataCollator(self.tokenizer, self.max_length),
|
467 |
+
num_workers=self.num_workers,
|
468 |
+
persistent_workers=True,
|
469 |
+
)
|
470 |
+
|
471 |
+
def val_dataloader(self):
|
472 |
+
return DataLoader(
|
473 |
+
self.val_dataset,
|
474 |
+
batch_size=self.batch_size,
|
475 |
+
collate_fn=TextDataCollator(self.tokenizer, self.max_length),
|
476 |
+
num_workers=self.num_workers,
|
477 |
+
persistent_workers=True,
|
478 |
+
)
|
479 |
+
|
480 |
+
|
481 |
+
if __name__ == "__main__":
|
482 |
+
from tqdm import tqdm
|
483 |
+
|
484 |
+
ds = AutoTextSemanticInstructionDataset(
|
485 |
+
["data/protos"],
|
486 |
+
tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
|
487 |
+
use_speaker=False,
|
488 |
+
interactive_prob=1.0,
|
489 |
+
skip_text_prob=0.5,
|
490 |
+
)
|
491 |
+
|
492 |
+
for i in ds:
|
493 |
+
print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
|
494 |
+
# i["labels"][0][i["labels"][0] == -100] = 0
|
495 |
+
# print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
|
496 |
+
break
|
fish_speech/datasets/vqgan.py
CHANGED
@@ -28,7 +28,7 @@ class VQGANDataset(Dataset):
|
|
28 |
|
29 |
self.files = [
|
30 |
root / line.strip()
|
31 |
-
for line in filelist.read_text().splitlines()
|
32 |
if line.strip()
|
33 |
]
|
34 |
self.sample_rate = sample_rate
|
@@ -120,6 +120,7 @@ class VQGANDataModule(LightningDataModule):
|
|
120 |
collate_fn=VQGANCollator(),
|
121 |
num_workers=self.num_workers,
|
122 |
shuffle=True,
|
|
|
123 |
)
|
124 |
|
125 |
def val_dataloader(self):
|
@@ -128,6 +129,7 @@ class VQGANDataModule(LightningDataModule):
|
|
128 |
batch_size=self.val_batch_size,
|
129 |
collate_fn=VQGANCollator(),
|
130 |
num_workers=self.num_workers,
|
|
|
131 |
)
|
132 |
|
133 |
|
|
|
28 |
|
29 |
self.files = [
|
30 |
root / line.strip()
|
31 |
+
for line in filelist.read_text(encoding="utf-8").splitlines()
|
32 |
if line.strip()
|
33 |
]
|
34 |
self.sample_rate = sample_rate
|
|
|
120 |
collate_fn=VQGANCollator(),
|
121 |
num_workers=self.num_workers,
|
122 |
shuffle=True,
|
123 |
+
persistent_workers=True,
|
124 |
)
|
125 |
|
126 |
def val_dataloader(self):
|
|
|
129 |
batch_size=self.val_batch_size,
|
130 |
collate_fn=VQGANCollator(),
|
131 |
num_workers=self.num_workers,
|
132 |
+
persistent_workers=True,
|
133 |
)
|
134 |
|
135 |
|
fish_speech/models/text2semantic/__init__.py
CHANGED
@@ -1,3 +0,0 @@
|
|
1 |
-
from .lit_module import TextToSemantic
|
2 |
-
|
3 |
-
__all__ = ["TextToSemantic"]
|
|
|
|
|
|
|
|
fish_speech/models/text2semantic/lit_module.py
CHANGED
@@ -1,110 +1,40 @@
|
|
1 |
-
from dataclasses import dataclass
|
2 |
from typing import Any, Optional
|
3 |
|
4 |
import lightning as L
|
5 |
-
import loralib as lora
|
6 |
import torch
|
7 |
import torch.nn.functional as F
|
8 |
from lightning.pytorch.utilities.types import OptimizerLRScheduler
|
9 |
|
10 |
import fish_speech.utils as utils
|
|
|
11 |
from fish_speech.models.text2semantic.llama import NaiveTransformer
|
12 |
|
13 |
log = utils.RankedLogger(__name__, rank_zero_only=True)
|
14 |
|
15 |
|
16 |
-
@dataclass
|
17 |
-
class LoraConfig:
|
18 |
-
r: int
|
19 |
-
lora_alpha: float
|
20 |
-
lora_dropout: float = 0.0
|
21 |
-
|
22 |
-
|
23 |
class TextToSemantic(L.LightningModule):
|
24 |
def __init__(
|
25 |
self,
|
26 |
model: NaiveTransformer,
|
27 |
optimizer: Any,
|
28 |
lr_scheduler: Any,
|
29 |
-
lora_config: Optional[LoraConfig] = None,
|
30 |
-
save_lora_only: bool = False,
|
31 |
-
use_dpo: bool = False,
|
32 |
-
dpo_beta: float = 0.2,
|
33 |
):
|
34 |
super().__init__()
|
35 |
|
36 |
self.model = model
|
37 |
self.optimizer_builder = optimizer
|
38 |
self.lr_scheduler_builder = lr_scheduler
|
39 |
-
self.lora_config = lora_config
|
40 |
-
self.save_lora_only = save_lora_only
|
41 |
-
self.use_dpo = use_dpo # We don't support reference model yet
|
42 |
-
self.dpo_beta = dpo_beta
|
43 |
-
|
44 |
-
if self.lora_config is not None:
|
45 |
-
self.setup_lora()
|
46 |
-
|
47 |
-
def setup_lora(self):
|
48 |
-
# Replace the embedding layer with a LoRA layer
|
49 |
-
self.model.embeddings = lora.Embedding(
|
50 |
-
num_embeddings=self.model.embeddings.num_embeddings,
|
51 |
-
embedding_dim=self.model.embeddings.embedding_dim,
|
52 |
-
padding_idx=self.model.embeddings.padding_idx,
|
53 |
-
r=self.lora_config.r,
|
54 |
-
lora_alpha=self.lora_config.lora_alpha,
|
55 |
-
)
|
56 |
-
|
57 |
-
# Replace output layer with a LoRA layer
|
58 |
-
linears = [(self.model, "output")]
|
59 |
-
|
60 |
-
# Replace all linear layers with LoRA layers
|
61 |
-
for layer in self.model.layers:
|
62 |
-
linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
|
63 |
-
linears.extend(
|
64 |
-
[
|
65 |
-
(layer.feed_forward, "w1"),
|
66 |
-
(layer.feed_forward, "w2"),
|
67 |
-
(layer.feed_forward, "w3"),
|
68 |
-
]
|
69 |
-
)
|
70 |
-
|
71 |
-
if hasattr(self.model, "fast_layers"):
|
72 |
-
# Dual-AR model
|
73 |
-
linears.extend([(self.model, "fast_output")])
|
74 |
-
|
75 |
-
for layer in self.model.fast_layers:
|
76 |
-
linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
|
77 |
-
linears.extend(
|
78 |
-
[
|
79 |
-
(layer.feed_forward, "w1"),
|
80 |
-
(layer.feed_forward, "w2"),
|
81 |
-
(layer.feed_forward, "w3"),
|
82 |
-
]
|
83 |
-
)
|
84 |
-
|
85 |
-
for module, layer in linears:
|
86 |
-
updated_linear = lora.Linear(
|
87 |
-
in_features=getattr(module, layer).in_features,
|
88 |
-
out_features=getattr(module, layer).out_features,
|
89 |
-
bias=getattr(module, layer).bias,
|
90 |
-
r=self.lora_config.r,
|
91 |
-
lora_alpha=self.lora_config.lora_alpha,
|
92 |
-
lora_dropout=self.lora_config.lora_dropout,
|
93 |
-
)
|
94 |
-
setattr(module, layer, updated_linear)
|
95 |
-
|
96 |
-
# Mark only the LoRA layers as trainable
|
97 |
-
lora.mark_only_lora_as_trainable(self.model, bias="lora_only")
|
98 |
|
99 |
def forward(self, x):
|
100 |
return self.model(x)
|
101 |
|
102 |
def on_save_checkpoint(self, checkpoint):
|
103 |
-
if self.lora_config is None or self.save_lora_only is False:
|
104 |
-
return
|
105 |
-
|
106 |
# Save only LoRA parameters
|
107 |
state_dict = checkpoint["state_dict"]
|
|
|
|
|
|
|
|
|
108 |
for name in list(state_dict.keys()):
|
109 |
if "lora" not in name:
|
110 |
state_dict.pop(name)
|
@@ -178,6 +108,11 @@ class TextToSemantic(L.LightningModule):
|
|
178 |
def _step(self, batch, batch_idx, stage: str):
|
179 |
is_train = stage == "train"
|
180 |
|
|
|
|
|
|
|
|
|
|
|
181 |
# Do positive and negative samples in the same batch to speed up training
|
182 |
labels = batch["labels"]
|
183 |
outputs = self.model(
|
@@ -187,92 +122,22 @@ class TextToSemantic(L.LightningModule):
|
|
187 |
token_logits = outputs.token_logits
|
188 |
codebook_logits = outputs.codebook_logits
|
189 |
|
190 |
-
if self.use_dpo:
|
191 |
-
# Firtst half is positive, second half is negative
|
192 |
-
token_logits, negative_token_logits = token_logits.chunk(2)
|
193 |
-
codebook_logits, negative_codebook_logits = codebook_logits.chunk(2)
|
194 |
-
labels, negative_labels = labels.chunk(2)
|
195 |
-
|
196 |
# Generate labels
|
197 |
base_loss = F.cross_entropy(
|
198 |
-
token_logits.
|
199 |
labels[:, 0].reshape(-1),
|
200 |
ignore_index=-100,
|
201 |
)
|
202 |
|
203 |
codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
|
204 |
semantic_loss = F.cross_entropy(
|
205 |
-
codebook_logits.
|
206 |
codebook_labels.reshape(-1),
|
207 |
ignore_index=-100,
|
208 |
)
|
209 |
|
210 |
loss = base_loss + semantic_loss
|
211 |
|
212 |
-
# If we use dpo
|
213 |
-
if self.use_dpo:
|
214 |
-
negative_codebook_labels = negative_labels[
|
215 |
-
:, 1 : 1 + self.model.config.num_codebooks
|
216 |
-
].mT
|
217 |
-
|
218 |
-
positive_codebook_logps = self.get_batch_logps(
|
219 |
-
codebook_logits, codebook_labels
|
220 |
-
)
|
221 |
-
negative_codebook_logps = self.get_batch_logps(
|
222 |
-
negative_codebook_logits, negative_codebook_labels
|
223 |
-
)
|
224 |
-
|
225 |
-
# TODO: implement the reference model, avoid screwing up the gradients
|
226 |
-
dpo_loss = -F.logsigmoid(
|
227 |
-
(positive_codebook_logps - negative_codebook_logps) * self.dpo_beta
|
228 |
-
).mean()
|
229 |
-
|
230 |
-
chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
|
231 |
-
rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
|
232 |
-
reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()
|
233 |
-
chosen_rewards, rejected_rewards = (
|
234 |
-
chosen_rewards.mean(),
|
235 |
-
rejected_rewards.mean(),
|
236 |
-
)
|
237 |
-
|
238 |
-
loss = loss + dpo_loss
|
239 |
-
|
240 |
-
self.log(
|
241 |
-
f"{stage}/dpo_loss",
|
242 |
-
dpo_loss,
|
243 |
-
on_step=is_train,
|
244 |
-
on_epoch=not is_train,
|
245 |
-
prog_bar=False,
|
246 |
-
logger=True,
|
247 |
-
)
|
248 |
-
|
249 |
-
self.log(
|
250 |
-
f"{stage}/chosen_rewards",
|
251 |
-
chosen_rewards,
|
252 |
-
on_step=is_train,
|
253 |
-
on_epoch=not is_train,
|
254 |
-
prog_bar=False,
|
255 |
-
logger=True,
|
256 |
-
)
|
257 |
-
|
258 |
-
self.log(
|
259 |
-
f"{stage}/rejected_rewards",
|
260 |
-
rejected_rewards,
|
261 |
-
on_step=is_train,
|
262 |
-
on_epoch=not is_train,
|
263 |
-
prog_bar=False,
|
264 |
-
logger=True,
|
265 |
-
)
|
266 |
-
|
267 |
-
self.log(
|
268 |
-
f"{stage}/reward_accuracy",
|
269 |
-
reward_accuracy,
|
270 |
-
on_step=is_train,
|
271 |
-
on_epoch=not is_train,
|
272 |
-
prog_bar=False,
|
273 |
-
logger=True,
|
274 |
-
)
|
275 |
-
|
276 |
self.log(
|
277 |
f"{stage}/loss",
|
278 |
loss,
|
@@ -280,6 +145,7 @@ class TextToSemantic(L.LightningModule):
|
|
280 |
on_epoch=not is_train,
|
281 |
prog_bar=True,
|
282 |
logger=True,
|
|
|
283 |
)
|
284 |
|
285 |
self.log(
|
@@ -289,6 +155,7 @@ class TextToSemantic(L.LightningModule):
|
|
289 |
on_epoch=not is_train,
|
290 |
prog_bar=False,
|
291 |
logger=True,
|
|
|
292 |
)
|
293 |
|
294 |
self.log(
|
@@ -298,6 +165,7 @@ class TextToSemantic(L.LightningModule):
|
|
298 |
on_epoch=not is_train,
|
299 |
prog_bar=False,
|
300 |
logger=True,
|
|
|
301 |
)
|
302 |
|
303 |
# Top-5 accuracy
|
@@ -309,31 +177,21 @@ class TextToSemantic(L.LightningModule):
|
|
309 |
on_epoch=not is_train,
|
310 |
prog_bar=True,
|
311 |
logger=True,
|
|
|
312 |
)
|
313 |
|
314 |
-
if self.model.config.num_codebooks != self.model.config.num_in_codebooks:
|
315 |
-
accuracy = self.get_accuracy(
|
316 |
-
codebook_logits[:, :, : self.model.config.num_in_codebooks],
|
317 |
-
codebook_labels[:, :, : self.model.config.num_in_codebooks],
|
318 |
-
)
|
319 |
-
|
320 |
-
self.log(
|
321 |
-
f"{stage}/top_5_accuracy_in",
|
322 |
-
accuracy,
|
323 |
-
on_step=is_train,
|
324 |
-
on_epoch=not is_train,
|
325 |
-
prog_bar=True,
|
326 |
-
logger=True,
|
327 |
-
)
|
328 |
-
|
329 |
return loss
|
330 |
|
331 |
def get_accuracy(self, logits, labels):
|
|
|
|
|
|
|
|
|
332 |
_, indices = logits.topk(5, dim=-1)
|
333 |
correct = indices.eq(labels.unsqueeze(-1))
|
334 |
-
correct[
|
335 |
correct = correct.sum()
|
336 |
-
accuracy = correct /
|
337 |
|
338 |
return accuracy
|
339 |
|
|
|
|
|
1 |
from typing import Any, Optional
|
2 |
|
3 |
import lightning as L
|
|
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
6 |
from lightning.pytorch.utilities.types import OptimizerLRScheduler
|
7 |
|
8 |
import fish_speech.utils as utils
|
9 |
+
from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
|
10 |
from fish_speech.models.text2semantic.llama import NaiveTransformer
|
11 |
|
12 |
log = utils.RankedLogger(__name__, rank_zero_only=True)
|
13 |
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
class TextToSemantic(L.LightningModule):
|
16 |
def __init__(
|
17 |
self,
|
18 |
model: NaiveTransformer,
|
19 |
optimizer: Any,
|
20 |
lr_scheduler: Any,
|
|
|
|
|
|
|
|
|
21 |
):
|
22 |
super().__init__()
|
23 |
|
24 |
self.model = model
|
25 |
self.optimizer_builder = optimizer
|
26 |
self.lr_scheduler_builder = lr_scheduler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
def forward(self, x):
|
29 |
return self.model(x)
|
30 |
|
31 |
def on_save_checkpoint(self, checkpoint):
|
|
|
|
|
|
|
32 |
# Save only LoRA parameters
|
33 |
state_dict = checkpoint["state_dict"]
|
34 |
+
use_lora = any("lora" in name for name in state_dict.keys())
|
35 |
+
if not use_lora:
|
36 |
+
return
|
37 |
+
|
38 |
for name in list(state_dict.keys()):
|
39 |
if "lora" not in name:
|
40 |
state_dict.pop(name)
|
|
|
108 |
def _step(self, batch, batch_idx, stage: str):
|
109 |
is_train = stage == "train"
|
110 |
|
111 |
+
if is_train:
|
112 |
+
# Key part to make lora work
|
113 |
+
# Otherwise the parameters are merged, which lead to incorrect gradients
|
114 |
+
self.model.train()
|
115 |
+
|
116 |
# Do positive and negative samples in the same batch to speed up training
|
117 |
labels = batch["labels"]
|
118 |
outputs = self.model(
|
|
|
122 |
token_logits = outputs.token_logits
|
123 |
codebook_logits = outputs.codebook_logits
|
124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
# Generate labels
|
126 |
base_loss = F.cross_entropy(
|
127 |
+
token_logits.view(-1, token_logits.size(-1)),
|
128 |
labels[:, 0].reshape(-1),
|
129 |
ignore_index=-100,
|
130 |
)
|
131 |
|
132 |
codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
|
133 |
semantic_loss = F.cross_entropy(
|
134 |
+
codebook_logits.view(-1, codebook_logits.size(-1)),
|
135 |
codebook_labels.reshape(-1),
|
136 |
ignore_index=-100,
|
137 |
)
|
138 |
|
139 |
loss = base_loss + semantic_loss
|
140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
self.log(
|
142 |
f"{stage}/loss",
|
143 |
loss,
|
|
|
145 |
on_epoch=not is_train,
|
146 |
prog_bar=True,
|
147 |
logger=True,
|
148 |
+
sync_dist=not is_train,
|
149 |
)
|
150 |
|
151 |
self.log(
|
|
|
155 |
on_epoch=not is_train,
|
156 |
prog_bar=False,
|
157 |
logger=True,
|
158 |
+
sync_dist=not is_train,
|
159 |
)
|
160 |
|
161 |
self.log(
|
|
|
165 |
on_epoch=not is_train,
|
166 |
prog_bar=False,
|
167 |
logger=True,
|
168 |
+
sync_dist=not is_train,
|
169 |
)
|
170 |
|
171 |
# Top-5 accuracy
|
|
|
177 |
on_epoch=not is_train,
|
178 |
prog_bar=True,
|
179 |
logger=True,
|
180 |
+
sync_dist=not is_train,
|
181 |
)
|
182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
return loss
|
184 |
|
185 |
def get_accuracy(self, logits, labels):
|
186 |
+
mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
|
187 |
+
if mask.sum() == 0:
|
188 |
+
return torch.tensor(0.0, device=logits.device)
|
189 |
+
|
190 |
_, indices = logits.topk(5, dim=-1)
|
191 |
correct = indices.eq(labels.unsqueeze(-1))
|
192 |
+
correct[~mask] = 0
|
193 |
correct = correct.sum()
|
194 |
+
accuracy = correct / mask.sum()
|
195 |
|
196 |
return accuracy
|
197 |
|
fish_speech/models/text2semantic/llama.py
CHANGED
@@ -1,13 +1,25 @@
|
|
|
|
1 |
import math
|
2 |
from dataclasses import dataclass
|
|
|
3 |
from typing import Optional
|
4 |
|
5 |
import torch
|
6 |
import torch.nn as nn
|
7 |
from einops import rearrange
|
|
|
8 |
from torch import Tensor
|
9 |
from torch.nn import functional as F
|
|
|
10 |
from torch.utils.checkpoint import checkpoint
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
|
13 |
def find_multiple(n: int, k: int) -> int:
|
@@ -18,6 +30,8 @@ def find_multiple(n: int, k: int) -> int:
|
|
18 |
|
19 |
@dataclass
|
20 |
class BaseModelArgs:
|
|
|
|
|
21 |
vocab_size: int = 32000
|
22 |
n_layer: int = 32
|
23 |
n_head: int = 32
|
@@ -29,16 +43,19 @@ class BaseModelArgs:
|
|
29 |
norm_eps: float = 1e-5
|
30 |
max_seq_len: int = 2048
|
31 |
dropout: float = 0.0
|
|
|
|
|
32 |
|
33 |
# Codebook configs
|
34 |
codebook_size: int = 160
|
35 |
num_codebooks: int = 4
|
36 |
-
num_in_codebooks: Optional[int] = None
|
37 |
-
codebook_padding_idx: int = 0
|
38 |
|
39 |
# Gradient checkpointing
|
40 |
use_gradient_checkpointing: bool = True
|
41 |
|
|
|
|
|
|
|
42 |
def __post_init__(self):
|
43 |
if self.n_local_heads == -1:
|
44 |
self.n_local_heads = self.n_head
|
@@ -46,18 +63,41 @@ class BaseModelArgs:
|
|
46 |
hidden_dim = 4 * self.dim
|
47 |
n_hidden = int(2 * hidden_dim / 3)
|
48 |
self.intermediate_size = find_multiple(n_hidden, 256)
|
49 |
-
if self.num_in_codebooks is None:
|
50 |
-
self.num_in_codebooks = self.num_codebooks
|
51 |
self.head_dim = self.dim // self.n_head
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
@dataclass
|
55 |
class NaiveModelArgs(BaseModelArgs):
|
56 |
-
|
57 |
|
58 |
|
59 |
@dataclass
|
60 |
class DualARModelArgs(BaseModelArgs):
|
|
|
61 |
n_fast_layer: int = 4
|
62 |
|
63 |
|
@@ -95,24 +135,35 @@ class BaseTransformerForwardResult:
|
|
95 |
|
96 |
|
97 |
class BaseTransformer(nn.Module):
|
98 |
-
def __init__(
|
|
|
|
|
99 |
super().__init__()
|
100 |
self.config = config
|
|
|
|
|
|
|
101 |
|
102 |
# Slow transformer
|
103 |
self.embeddings = nn.Embedding(
|
104 |
-
config.vocab_size
|
|
|
|
|
|
|
|
|
105 |
config.dim,
|
106 |
)
|
107 |
self.layers = nn.ModuleList(
|
108 |
TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
|
109 |
)
|
110 |
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
116 |
|
117 |
self.register_buffer(
|
118 |
"freqs_cis",
|
@@ -139,6 +190,9 @@ class BaseTransformer(nn.Module):
|
|
139 |
self.max_batch_size = -1
|
140 |
self.max_seq_len = -1
|
141 |
|
|
|
|
|
|
|
142 |
def setup_caches(
|
143 |
self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
|
144 |
):
|
@@ -161,11 +215,9 @@ class BaseTransformer(nn.Module):
|
|
161 |
|
162 |
def embed(self, x: Tensor) -> Tensor:
|
163 |
vocab_embeds = [self.embeddings(x[:, 0])]
|
164 |
-
for i in range(self.config.
|
165 |
-
emb = self.
|
166 |
-
|
167 |
-
)
|
168 |
-
emb[x[:, i + 1] == self.config.codebook_padding_idx] = 0
|
169 |
vocab_embeds.append(emb)
|
170 |
|
171 |
x = torch.stack(vocab_embeds, dim=3)
|
@@ -174,21 +226,23 @@ class BaseTransformer(nn.Module):
|
|
174 |
return x
|
175 |
|
176 |
def forward(
|
177 |
-
self,
|
|
|
|
|
178 |
) -> BaseTransformerForwardResult:
|
179 |
-
# x: (batch, num_codebooks + 1, seq_len)
|
180 |
seq_len = inp.size(2)
|
181 |
|
182 |
# Here we want to merge the embeddings of the codebooks
|
183 |
x = self.embed(inp)
|
184 |
|
185 |
-
mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
|
186 |
freqs_cis = self.freqs_cis[:seq_len]
|
187 |
|
188 |
# Not that the causal mask here follows the definition of scaled_dot_product_attention
|
189 |
# That is, FALSE means masked out
|
190 |
# To maintain consistency, key_padding_mask use TRUE to mask out
|
|
|
191 |
if key_padding_mask is not None:
|
|
|
192 |
mask = mask & key_padding_mask[:, None, None, :].logical_not()
|
193 |
|
194 |
for layer in self.layers:
|
@@ -199,7 +253,11 @@ class BaseTransformer(nn.Module):
|
|
199 |
|
200 |
# We got slow_out here
|
201 |
slow_out = self.norm(x)
|
202 |
-
|
|
|
|
|
|
|
|
|
203 |
|
204 |
return BaseTransformerForwardResult(
|
205 |
logits=token_logits,
|
@@ -207,7 +265,10 @@ class BaseTransformer(nn.Module):
|
|
207 |
)
|
208 |
|
209 |
def forward_generate(
|
210 |
-
self,
|
|
|
|
|
|
|
211 |
) -> BaseTransformerForwardResult:
|
212 |
# This is used for generation, optimized for torch compile
|
213 |
assert (
|
@@ -225,22 +286,117 @@ class BaseTransformer(nn.Module):
|
|
225 |
x = layer(x, freqs_cis, mask, input_pos=input_pos)
|
226 |
|
227 |
# If prefill, we only calculate the logits of last token
|
228 |
-
if x.size(1) > 1:
|
229 |
x = x[:, -1:]
|
230 |
|
231 |
# We got slow_out here
|
232 |
slow_out = self.norm(x)
|
233 |
-
|
|
|
|
|
|
|
|
|
234 |
|
235 |
return BaseTransformerForwardResult(
|
236 |
logits=token_logits,
|
237 |
hidden_states=x,
|
238 |
)
|
239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
|
241 |
class NaiveTransformer(BaseTransformer):
|
242 |
-
def __init__(self, config: NaiveModelArgs) -> None:
|
243 |
-
super().__init__(config)
|
244 |
|
245 |
self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
246 |
self.codebook_output = nn.Linear(
|
@@ -249,6 +405,8 @@ class NaiveTransformer(BaseTransformer):
|
|
249 |
bias=False,
|
250 |
)
|
251 |
|
|
|
|
|
252 |
def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
|
253 |
token_logits = result.logits
|
254 |
x = result.hidden_states
|
@@ -265,9 +423,14 @@ class NaiveTransformer(BaseTransformer):
|
|
265 |
)
|
266 |
|
267 |
def forward(
|
268 |
-
self,
|
|
|
|
|
269 |
) -> TransformerForwardResult:
|
270 |
-
result = super().forward(
|
|
|
|
|
|
|
271 |
return self.decode(result)
|
272 |
|
273 |
def forward_generate(
|
@@ -278,13 +441,11 @@ class NaiveTransformer(BaseTransformer):
|
|
278 |
|
279 |
|
280 |
class DualARTransformer(BaseTransformer):
|
281 |
-
def __init__(self, config:
|
282 |
-
super().__init__(config)
|
283 |
|
284 |
# Fast transformer
|
285 |
-
self.fast_embeddings = nn.Embedding(
|
286 |
-
config.codebook_size, config.dim, padding_idx=config.codebook_padding_idx
|
287 |
-
)
|
288 |
|
289 |
# The equivalent bs is so large that sdpa doesn't work
|
290 |
self.fast_layers = nn.ModuleList(
|
@@ -297,6 +458,8 @@ class DualARTransformer(BaseTransformer):
|
|
297 |
bias=False,
|
298 |
)
|
299 |
|
|
|
|
|
300 |
def setup_caches(
|
301 |
self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
|
302 |
):
|
@@ -316,7 +479,9 @@ class DualARTransformer(BaseTransformer):
|
|
316 |
)
|
317 |
|
318 |
def forward(
|
319 |
-
self,
|
|
|
|
|
320 |
) -> TransformerForwardResult:
|
321 |
parent_result = super().forward(inp, key_padding_mask)
|
322 |
token_logits = parent_result.logits
|
@@ -331,7 +496,7 @@ class DualARTransformer(BaseTransformer):
|
|
331 |
|
332 |
# Drop the last token and rotate left
|
333 |
codebooks = inp[:, 1:-1, 1:]
|
334 |
-
codebooks = F.pad(codebooks, (0, 1), value=
|
335 |
codebook_embeddings = self.fast_embeddings(codebooks)
|
336 |
x = torch.cat([x[:, None], codebook_embeddings], dim=1)
|
337 |
b, s = x.size(0), x.size(2)
|
@@ -339,7 +504,12 @@ class DualARTransformer(BaseTransformer):
|
|
339 |
|
340 |
# Remove padded part
|
341 |
codebooks = rearrange(codebooks, "b n s -> (b s) n")
|
342 |
-
codebook_mask = (codebooks ==
|
|
|
|
|
|
|
|
|
|
|
343 |
x_bs, x_len = x.size(0), x.size(1)
|
344 |
x = x[~codebook_mask]
|
345 |
|
@@ -422,7 +592,9 @@ class Attention(nn.Module):
|
|
422 |
|
423 |
total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
|
424 |
# key, query, value projections for all heads, but in a batch
|
425 |
-
self.wqkv = nn.Linear(
|
|
|
|
|
426 |
self.wo = nn.Linear(config.dim, config.dim, bias=False)
|
427 |
self.kv_cache = None
|
428 |
|
@@ -469,13 +641,24 @@ class Attention(nn.Module):
|
|
469 |
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
470 |
|
471 |
if self.use_sdpa:
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
479 |
else:
|
480 |
y = self.eq_scaled_dot_product_attention(
|
481 |
q,
|
@@ -567,29 +750,3 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
|
|
567 |
|
568 |
x_out2 = x_out2.flatten(3)
|
569 |
return x_out2.type_as(x)
|
570 |
-
|
571 |
-
|
572 |
-
if __name__ == "__main__":
|
573 |
-
args = DualARModelArgs(
|
574 |
-
max_seq_len=4096,
|
575 |
-
vocab_size=32312,
|
576 |
-
n_layer=12,
|
577 |
-
n_fast_layer=4,
|
578 |
-
n_head=12,
|
579 |
-
dim=768,
|
580 |
-
rope_base=10000,
|
581 |
-
norm_eps=1e-5,
|
582 |
-
codebook_size=128,
|
583 |
-
num_codebooks=4,
|
584 |
-
)
|
585 |
-
|
586 |
-
model = DualARTransformer(args)
|
587 |
-
model = model.cuda().bfloat16()
|
588 |
-
print("Total params:", sum(i.numel() for i in model.parameters()) / 1024 / 1024)
|
589 |
-
|
590 |
-
inputs = torch.randint(0, 100, (2, 5, 128)).cuda()
|
591 |
-
key_padding_mask = torch.zeros(2, 128).bool().cuda()
|
592 |
-
key_padding_mask[0, 2:] = True
|
593 |
-
x1 = model(inputs, key_padding_mask=key_padding_mask)
|
594 |
-
print(x1.token_logits.shape)
|
595 |
-
print(x1.codebook_logits.shape)
|
|
|
1 |
+
import json
|
2 |
import math
|
3 |
from dataclasses import dataclass
|
4 |
+
from pathlib import Path
|
5 |
from typing import Optional
|
6 |
|
7 |
import torch
|
8 |
import torch.nn as nn
|
9 |
from einops import rearrange
|
10 |
+
from loguru import logger
|
11 |
from torch import Tensor
|
12 |
from torch.nn import functional as F
|
13 |
+
from torch.nn.attention import SDPBackend, sdpa_kernel
|
14 |
from torch.utils.checkpoint import checkpoint
|
15 |
+
from transformers import AutoTokenizer
|
16 |
+
|
17 |
+
from fish_speech.conversation import SEMANTIC_TOKEN
|
18 |
+
from fish_speech.utils import RankedLogger
|
19 |
+
|
20 |
+
from .lora import LoraConfig, setup_lora
|
21 |
+
|
22 |
+
log = RankedLogger(__name__, rank_zero_only=True)
|
23 |
|
24 |
|
25 |
def find_multiple(n: int, k: int) -> int:
|
|
|
30 |
|
31 |
@dataclass
|
32 |
class BaseModelArgs:
|
33 |
+
model_type: str = "base"
|
34 |
+
|
35 |
vocab_size: int = 32000
|
36 |
n_layer: int = 32
|
37 |
n_head: int = 32
|
|
|
43 |
norm_eps: float = 1e-5
|
44 |
max_seq_len: int = 2048
|
45 |
dropout: float = 0.0
|
46 |
+
tie_word_embeddings: bool = True
|
47 |
+
attention_qkv_bias: bool = False
|
48 |
|
49 |
# Codebook configs
|
50 |
codebook_size: int = 160
|
51 |
num_codebooks: int = 4
|
|
|
|
|
52 |
|
53 |
# Gradient checkpointing
|
54 |
use_gradient_checkpointing: bool = True
|
55 |
|
56 |
+
# Initialize the model
|
57 |
+
initializer_range: float = 0.02
|
58 |
+
|
59 |
def __post_init__(self):
|
60 |
if self.n_local_heads == -1:
|
61 |
self.n_local_heads = self.n_head
|
|
|
63 |
hidden_dim = 4 * self.dim
|
64 |
n_hidden = int(2 * hidden_dim / 3)
|
65 |
self.intermediate_size = find_multiple(n_hidden, 256)
|
|
|
|
|
66 |
self.head_dim = self.dim // self.n_head
|
67 |
|
68 |
+
@staticmethod
|
69 |
+
def from_pretrained(path: str):
|
70 |
+
path = Path(path)
|
71 |
+
|
72 |
+
if path.is_dir():
|
73 |
+
path = path / "config.json"
|
74 |
+
|
75 |
+
with open(path, "r", encoding="utf-8") as f:
|
76 |
+
data = json.load(f)
|
77 |
+
|
78 |
+
match data["model_type"]:
|
79 |
+
case "naive":
|
80 |
+
cls = NaiveModelArgs
|
81 |
+
case "dual_ar":
|
82 |
+
cls = DualARModelArgs
|
83 |
+
case _:
|
84 |
+
raise ValueError(f"Unknown model type: {data['model_type']}")
|
85 |
+
|
86 |
+
return cls(**data)
|
87 |
+
|
88 |
+
def save(self, path: str):
|
89 |
+
with open(path, "w") as f:
|
90 |
+
json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
|
91 |
+
|
92 |
|
93 |
@dataclass
|
94 |
class NaiveModelArgs(BaseModelArgs):
|
95 |
+
model_type: str = "naive"
|
96 |
|
97 |
|
98 |
@dataclass
|
99 |
class DualARModelArgs(BaseModelArgs):
|
100 |
+
model_type: str = "dual_ar"
|
101 |
n_fast_layer: int = 4
|
102 |
|
103 |
|
|
|
135 |
|
136 |
|
137 |
class BaseTransformer(nn.Module):
|
138 |
+
def __init__(
|
139 |
+
self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
|
140 |
+
) -> None:
|
141 |
super().__init__()
|
142 |
self.config = config
|
143 |
+
self.tokenizer = tokenizer
|
144 |
+
|
145 |
+
self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)
|
146 |
|
147 |
# Slow transformer
|
148 |
self.embeddings = nn.Embedding(
|
149 |
+
config.vocab_size,
|
150 |
+
config.dim,
|
151 |
+
)
|
152 |
+
self.codebook_embeddings = nn.Embedding(
|
153 |
+
config.codebook_size * config.num_codebooks,
|
154 |
config.dim,
|
155 |
)
|
156 |
self.layers = nn.ModuleList(
|
157 |
TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
|
158 |
)
|
159 |
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
|
160 |
+
|
161 |
+
if self.config.tie_word_embeddings is False:
|
162 |
+
self.output = nn.Linear(
|
163 |
+
config.dim,
|
164 |
+
config.vocab_size,
|
165 |
+
bias=False,
|
166 |
+
)
|
167 |
|
168 |
self.register_buffer(
|
169 |
"freqs_cis",
|
|
|
190 |
self.max_batch_size = -1
|
191 |
self.max_seq_len = -1
|
192 |
|
193 |
+
if init_weights:
|
194 |
+
self.apply(self._init_weights)
|
195 |
+
|
196 |
def setup_caches(
|
197 |
self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
|
198 |
):
|
|
|
215 |
|
216 |
def embed(self, x: Tensor) -> Tensor:
|
217 |
vocab_embeds = [self.embeddings(x[:, 0])]
|
218 |
+
for i in range(self.config.num_codebooks):
|
219 |
+
emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
|
220 |
+
emb[x[:, 0] != self.semantic_token_id] = 0
|
|
|
|
|
221 |
vocab_embeds.append(emb)
|
222 |
|
223 |
x = torch.stack(vocab_embeds, dim=3)
|
|
|
226 |
return x
|
227 |
|
228 |
def forward(
|
229 |
+
self,
|
230 |
+
inp: Tensor,
|
231 |
+
key_padding_mask: Optional[Tensor] = None,
|
232 |
) -> BaseTransformerForwardResult:
|
|
|
233 |
seq_len = inp.size(2)
|
234 |
|
235 |
# Here we want to merge the embeddings of the codebooks
|
236 |
x = self.embed(inp)
|
237 |
|
|
|
238 |
freqs_cis = self.freqs_cis[:seq_len]
|
239 |
|
240 |
# Not that the causal mask here follows the definition of scaled_dot_product_attention
|
241 |
# That is, FALSE means masked out
|
242 |
# To maintain consistency, key_padding_mask use TRUE to mask out
|
243 |
+
mask = None
|
244 |
if key_padding_mask is not None:
|
245 |
+
mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
|
246 |
mask = mask & key_padding_mask[:, None, None, :].logical_not()
|
247 |
|
248 |
for layer in self.layers:
|
|
|
253 |
|
254 |
# We got slow_out here
|
255 |
slow_out = self.norm(x)
|
256 |
+
|
257 |
+
if self.config.tie_word_embeddings:
|
258 |
+
token_logits = F.linear(slow_out, self.embeddings.weight)
|
259 |
+
else:
|
260 |
+
token_logits = self.output(slow_out)
|
261 |
|
262 |
return BaseTransformerForwardResult(
|
263 |
logits=token_logits,
|
|
|
265 |
)
|
266 |
|
267 |
def forward_generate(
|
268 |
+
self,
|
269 |
+
x: Tensor,
|
270 |
+
input_pos: Optional[Tensor] = None,
|
271 |
+
return_all: bool = False,
|
272 |
) -> BaseTransformerForwardResult:
|
273 |
# This is used for generation, optimized for torch compile
|
274 |
assert (
|
|
|
286 |
x = layer(x, freqs_cis, mask, input_pos=input_pos)
|
287 |
|
288 |
# If prefill, we only calculate the logits of last token
|
289 |
+
if x.size(1) > 1 and not return_all:
|
290 |
x = x[:, -1:]
|
291 |
|
292 |
# We got slow_out here
|
293 |
slow_out = self.norm(x)
|
294 |
+
|
295 |
+
if self.config.tie_word_embeddings:
|
296 |
+
token_logits = F.linear(slow_out, self.embeddings.weight)
|
297 |
+
else:
|
298 |
+
token_logits = self.output(slow_out)
|
299 |
|
300 |
return BaseTransformerForwardResult(
|
301 |
logits=token_logits,
|
302 |
hidden_states=x,
|
303 |
)
|
304 |
|
305 |
+
def _init_weights(self, module):
|
306 |
+
std = self.config.initializer_range
|
307 |
+
if isinstance(module, nn.Linear):
|
308 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
309 |
+
if module.bias is not None:
|
310 |
+
module.bias.data.zero_()
|
311 |
+
elif isinstance(module, nn.Embedding):
|
312 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
313 |
+
if module.padding_idx is not None:
|
314 |
+
module.weight.data[module.padding_idx].zero_()
|
315 |
+
|
316 |
+
@staticmethod
|
317 |
+
def from_pretrained(
|
318 |
+
path: str,
|
319 |
+
load_weights: bool = False,
|
320 |
+
max_length: int | None = None,
|
321 |
+
lora_config: LoraConfig | None = None,
|
322 |
+
rope_base: int | None = None,
|
323 |
+
) -> "BaseTransformer":
|
324 |
+
config = BaseModelArgs.from_pretrained(str(path))
|
325 |
+
if max_length is not None:
|
326 |
+
config.max_seq_len = max_length
|
327 |
+
log.info(f"Override max_seq_len to {max_length}")
|
328 |
+
|
329 |
+
if rope_base is not None:
|
330 |
+
config.rope_base = rope_base
|
331 |
+
log.info(f"Override rope_base to {rope_base}")
|
332 |
+
|
333 |
+
match config.model_type:
|
334 |
+
case "naive":
|
335 |
+
model_cls = NaiveTransformer
|
336 |
+
case "dual_ar":
|
337 |
+
model_cls = DualARTransformer
|
338 |
+
case _:
|
339 |
+
raise ValueError(f"Unknown model type: {config.model_type}")
|
340 |
+
|
341 |
+
tokenizer = AutoTokenizer.from_pretrained(str(path))
|
342 |
+
log.info(f"Loading model from {path}, config: {config}")
|
343 |
+
model = model_cls(config, tokenizer=tokenizer)
|
344 |
+
|
345 |
+
if lora_config is not None:
|
346 |
+
setup_lora(model, lora_config)
|
347 |
+
log.info(f"LoRA setup: {lora_config}")
|
348 |
+
|
349 |
+
if load_weights is False:
|
350 |
+
log.info("Randomly initialized model")
|
351 |
+
else:
|
352 |
+
|
353 |
+
if "int8" in str(Path(path)):
|
354 |
+
logger.info("Using int8 weight-only quantization!")
|
355 |
+
from tools.llama.quantize import WeightOnlyInt8QuantHandler
|
356 |
+
|
357 |
+
simple_quantizer = WeightOnlyInt8QuantHandler(model)
|
358 |
+
model = simple_quantizer.convert_for_runtime()
|
359 |
+
|
360 |
+
if "int4" in str(Path(path)):
|
361 |
+
logger.info("Using int4 quantization!")
|
362 |
+
path_comps = path.name.split("-")
|
363 |
+
assert path_comps[-2].startswith("g")
|
364 |
+
groupsize = int(path_comps[-2][1:])
|
365 |
+
from tools.llama.quantize import WeightOnlyInt4QuantHandler
|
366 |
+
|
367 |
+
simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
|
368 |
+
model = simple_quantizer.convert_for_runtime()
|
369 |
+
|
370 |
+
weights = torch.load(
|
371 |
+
Path(path) / "model.pth", map_location="cpu", mmap=True
|
372 |
+
)
|
373 |
+
err = model.load_state_dict(weights, strict=False, assign=True)
|
374 |
+
log.info(f"Loaded weights with error: {err}")
|
375 |
+
|
376 |
+
return model
|
377 |
+
|
378 |
+
def save_pretrained(self, path: str, drop_lora: bool = False):
|
379 |
+
path = Path(path)
|
380 |
+
path.mkdir(parents=True, exist_ok=True)
|
381 |
+
|
382 |
+
self.config.save(path / "config.json")
|
383 |
+
state_dict = self.state_dict()
|
384 |
+
|
385 |
+
if drop_lora:
|
386 |
+
for key in list(state_dict.keys()):
|
387 |
+
if "lora" not in key:
|
388 |
+
continue
|
389 |
+
|
390 |
+
state_dict.pop(key)
|
391 |
+
log.info(f"Drop LoRA parameter: {key}")
|
392 |
+
|
393 |
+
torch.save(state_dict, path / "model.pth")
|
394 |
+
self.tokenizer.save_pretrained(path)
|
395 |
+
|
396 |
|
397 |
class NaiveTransformer(BaseTransformer):
|
398 |
+
def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
|
399 |
+
super().__init__(config, init_weights=False, tokenizer=tokenizer)
|
400 |
|
401 |
self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
|
402 |
self.codebook_output = nn.Linear(
|
|
|
405 |
bias=False,
|
406 |
)
|
407 |
|
408 |
+
self.apply(self._init_weights)
|
409 |
+
|
410 |
def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
|
411 |
token_logits = result.logits
|
412 |
x = result.hidden_states
|
|
|
423 |
)
|
424 |
|
425 |
def forward(
|
426 |
+
self,
|
427 |
+
inp: Tensor,
|
428 |
+
key_padding_mask: Optional[Tensor] = None,
|
429 |
) -> TransformerForwardResult:
|
430 |
+
result = super().forward(
|
431 |
+
inp=inp,
|
432 |
+
key_padding_mask=key_padding_mask,
|
433 |
+
)
|
434 |
return self.decode(result)
|
435 |
|
436 |
def forward_generate(
|
|
|
441 |
|
442 |
|
443 |
class DualARTransformer(BaseTransformer):
|
444 |
+
def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
|
445 |
+
super().__init__(config, init_weights=False, tokenizer=tokenizer)
|
446 |
|
447 |
# Fast transformer
|
448 |
+
self.fast_embeddings = nn.Embedding(config.codebook_size, config.dim)
|
|
|
|
|
449 |
|
450 |
# The equivalent bs is so large that sdpa doesn't work
|
451 |
self.fast_layers = nn.ModuleList(
|
|
|
458 |
bias=False,
|
459 |
)
|
460 |
|
461 |
+
self.apply(self._init_weights)
|
462 |
+
|
463 |
def setup_caches(
|
464 |
self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
|
465 |
):
|
|
|
479 |
)
|
480 |
|
481 |
def forward(
|
482 |
+
self,
|
483 |
+
inp: Tensor,
|
484 |
+
key_padding_mask: Optional[Tensor] = None,
|
485 |
) -> TransformerForwardResult:
|
486 |
parent_result = super().forward(inp, key_padding_mask)
|
487 |
token_logits = parent_result.logits
|
|
|
496 |
|
497 |
# Drop the last token and rotate left
|
498 |
codebooks = inp[:, 1:-1, 1:]
|
499 |
+
codebooks = F.pad(codebooks, (0, 1), value=0)
|
500 |
codebook_embeddings = self.fast_embeddings(codebooks)
|
501 |
x = torch.cat([x[:, None], codebook_embeddings], dim=1)
|
502 |
b, s = x.size(0), x.size(2)
|
|
|
504 |
|
505 |
# Remove padded part
|
506 |
codebooks = rearrange(codebooks, "b n s -> (b s) n")
|
507 |
+
codebook_mask = (codebooks == 0).all(dim=-1)
|
508 |
+
|
509 |
+
if torch.all(codebook_mask):
|
510 |
+
# If all codebooks are padded, we keep first 8 to make sure the model runs
|
511 |
+
codebook_mask[:8] = False
|
512 |
+
|
513 |
x_bs, x_len = x.size(0), x.size(1)
|
514 |
x = x[~codebook_mask]
|
515 |
|
|
|
592 |
|
593 |
total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
|
594 |
# key, query, value projections for all heads, but in a batch
|
595 |
+
self.wqkv = nn.Linear(
|
596 |
+
config.dim, total_head_dim, bias=config.attention_qkv_bias
|
597 |
+
)
|
598 |
self.wo = nn.Linear(config.dim, config.dim, bias=False)
|
599 |
self.kv_cache = None
|
600 |
|
|
|
641 |
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
642 |
|
643 |
if self.use_sdpa:
|
644 |
+
if mask is None:
|
645 |
+
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
|
646 |
+
y = F.scaled_dot_product_attention(
|
647 |
+
q,
|
648 |
+
k,
|
649 |
+
v,
|
650 |
+
dropout_p=self.dropout if self.training else 0.0,
|
651 |
+
is_causal=True,
|
652 |
+
# No third party attn_mask here to use flash_attention
|
653 |
+
)
|
654 |
+
else:
|
655 |
+
y = F.scaled_dot_product_attention(
|
656 |
+
q,
|
657 |
+
k,
|
658 |
+
v,
|
659 |
+
attn_mask=mask,
|
660 |
+
dropout_p=self.dropout if self.training else 0.0,
|
661 |
+
)
|
662 |
else:
|
663 |
y = self.eq_scaled_dot_product_attention(
|
664 |
q,
|
|
|
750 |
|
751 |
x_out2 = x_out2.flatten(3)
|
752 |
return x_out2.type_as(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fish_speech/models/text2semantic/lora.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import loralib as lora
|
4 |
+
|
5 |
+
|
6 |
+
@dataclass
|
7 |
+
class LoraConfig:
|
8 |
+
r: int
|
9 |
+
lora_alpha: float
|
10 |
+
lora_dropout: float = 0.0
|
11 |
+
|
12 |
+
|
13 |
+
def setup_lora(model, lora_config):
|
14 |
+
# Replace the embedding layer with a LoRA layer
|
15 |
+
model.embeddings = lora.Embedding(
|
16 |
+
num_embeddings=model.embeddings.num_embeddings,
|
17 |
+
embedding_dim=model.embeddings.embedding_dim,
|
18 |
+
padding_idx=model.embeddings.padding_idx,
|
19 |
+
r=lora_config.r,
|
20 |
+
lora_alpha=lora_config.lora_alpha,
|
21 |
+
)
|
22 |
+
|
23 |
+
model.codebook_embeddings = lora.Embedding(
|
24 |
+
num_embeddings=model.codebook_embeddings.num_embeddings,
|
25 |
+
embedding_dim=model.codebook_embeddings.embedding_dim,
|
26 |
+
padding_idx=model.codebook_embeddings.padding_idx,
|
27 |
+
r=lora_config.r,
|
28 |
+
lora_alpha=lora_config.lora_alpha,
|
29 |
+
)
|
30 |
+
|
31 |
+
# Replace output layer with a LoRA layer
|
32 |
+
linears = [(model, "output")]
|
33 |
+
|
34 |
+
# Replace all linear layers with LoRA layers
|
35 |
+
for layer in model.layers:
|
36 |
+
linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
|
37 |
+
linears.extend(
|
38 |
+
[
|
39 |
+
(layer.feed_forward, "w1"),
|
40 |
+
(layer.feed_forward, "w2"),
|
41 |
+
(layer.feed_forward, "w3"),
|
42 |
+
]
|
43 |
+
)
|
44 |
+
|
45 |
+
if hasattr(model, "fast_layers"):
|
46 |
+
model.fast_embeddings = lora.Embedding(
|
47 |
+
num_embeddings=model.fast_embeddings.num_embeddings,
|
48 |
+
embedding_dim=model.fast_embeddings.embedding_dim,
|
49 |
+
padding_idx=model.fast_embeddings.padding_idx,
|
50 |
+
r=lora_config.r,
|
51 |
+
lora_alpha=lora_config.lora_alpha,
|
52 |
+
)
|
53 |
+
|
54 |
+
# Dual-AR model
|
55 |
+
linears.append((model, "fast_output"))
|
56 |
+
|
57 |
+
for layer in model.fast_layers:
|
58 |
+
linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
|
59 |
+
linears.extend(
|
60 |
+
[
|
61 |
+
(layer.feed_forward, "w1"),
|
62 |
+
(layer.feed_forward, "w2"),
|
63 |
+
(layer.feed_forward, "w3"),
|
64 |
+
]
|
65 |
+
)
|
66 |
+
|
67 |
+
for module, layer in linears:
|
68 |
+
updated_linear = lora.Linear(
|
69 |
+
in_features=getattr(module, layer).in_features,
|
70 |
+
out_features=getattr(module, layer).out_features,
|
71 |
+
bias=getattr(module, layer).bias,
|
72 |
+
r=lora_config.r,
|
73 |
+
lora_alpha=lora_config.lora_alpha,
|
74 |
+
lora_dropout=lora_config.lora_dropout,
|
75 |
+
)
|
76 |
+
setattr(module, layer, updated_linear)
|
77 |
+
|
78 |
+
# Mark only the LoRA layers as trainable
|
79 |
+
lora.mark_only_lora_as_trainable(model, bias="none")
|
80 |
+
|
81 |
+
|
82 |
+
def get_merged_state_dict(model):
|
83 |
+
# This line will merge the state dict of the model and the LoRA parameters
|
84 |
+
model.eval()
|
85 |
+
|
86 |
+
# Then we need to remove the LoRA parameters from the state dict
|
87 |
+
state_dict = model.state_dict()
|
88 |
+
for name in list(state_dict.keys()):
|
89 |
+
if "lora" in name:
|
90 |
+
state_dict.pop(name)
|
91 |
+
|
92 |
+
return state_dict
|
fish_speech/models/vqgan/modules/firefly.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
# A inference only version of the FireflyGAN model
|
2 |
|
|
|
3 |
from functools import partial
|
4 |
from math import prod
|
5 |
from typing import Callable
|
@@ -13,6 +14,8 @@ from torch.nn.utils.parametrizations import weight_norm
|
|
13 |
from torch.nn.utils.parametrize import remove_parametrizations
|
14 |
from torch.utils.checkpoint import checkpoint
|
15 |
|
|
|
|
|
16 |
|
17 |
def init_weights(m, mean=0.0, std=0.01):
|
18 |
classname = m.__class__.__name__
|
@@ -474,6 +477,89 @@ class ConvNeXtEncoder(nn.Module):
|
|
474 |
return self.norm(x)
|
475 |
|
476 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
477 |
class FireflyBase(nn.Module):
|
478 |
def __init__(self, ckpt_path: str = None, pretrained: bool = True):
|
479 |
super().__init__()
|
@@ -500,11 +586,12 @@ class FireflyBase(nn.Module):
|
|
500 |
)
|
501 |
|
502 |
if ckpt_path is not None:
|
503 |
-
|
504 |
elif pretrained:
|
505 |
state_dict = torch.hub.load_state_dict_from_url(
|
506 |
"https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
|
507 |
map_location="cpu",
|
|
|
508 |
)
|
509 |
|
510 |
if "state_dict" in state_dict:
|
|
|
1 |
# A inference only version of the FireflyGAN model
|
2 |
|
3 |
+
import math
|
4 |
from functools import partial
|
5 |
from math import prod
|
6 |
from typing import Callable
|
|
|
14 |
from torch.nn.utils.parametrize import remove_parametrizations
|
15 |
from torch.utils.checkpoint import checkpoint
|
16 |
|
17 |
+
from fish_speech.models.vqgan.utils import sequence_mask
|
18 |
+
|
19 |
|
20 |
def init_weights(m, mean=0.0, std=0.01):
|
21 |
classname = m.__class__.__name__
|
|
|
477 |
return self.norm(x)
|
478 |
|
479 |
|
480 |
+
class FireflyArchitecture(nn.Module):
|
481 |
+
def __init__(
|
482 |
+
self,
|
483 |
+
backbone: nn.Module,
|
484 |
+
head: nn.Module,
|
485 |
+
quantizer: nn.Module,
|
486 |
+
spec_transform: nn.Module,
|
487 |
+
):
|
488 |
+
super().__init__()
|
489 |
+
|
490 |
+
self.backbone = backbone
|
491 |
+
self.head = head
|
492 |
+
self.quantizer = quantizer
|
493 |
+
self.spec_transform = spec_transform
|
494 |
+
|
495 |
+
def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
|
496 |
+
if self.spec_transform is not None:
|
497 |
+
x = self.spec_transform(x)
|
498 |
+
|
499 |
+
x = self.backbone(x)
|
500 |
+
if mask is not None:
|
501 |
+
x = x * mask
|
502 |
+
|
503 |
+
if self.quantizer is not None:
|
504 |
+
vq_result = self.quantizer(x)
|
505 |
+
x = vq_result.z
|
506 |
+
|
507 |
+
if mask is not None:
|
508 |
+
x = x * mask
|
509 |
+
|
510 |
+
x = self.head(x, template=template)
|
511 |
+
|
512 |
+
if x.ndim == 2:
|
513 |
+
x = x[:, None, :]
|
514 |
+
|
515 |
+
if self.vq is not None:
|
516 |
+
return x, vq_result
|
517 |
+
|
518 |
+
return x
|
519 |
+
|
520 |
+
def encode(self, audios, audio_lengths):
|
521 |
+
audios = audios.float()
|
522 |
+
|
523 |
+
mels = self.spec_transform(audios)
|
524 |
+
mel_lengths = audio_lengths // self.spec_transform.hop_length
|
525 |
+
mel_masks = sequence_mask(mel_lengths, mels.shape[2])
|
526 |
+
mel_masks_float_conv = mel_masks[:, None, :].float()
|
527 |
+
mels = mels * mel_masks_float_conv
|
528 |
+
|
529 |
+
# Encode
|
530 |
+
encoded_features = self.backbone(mels) * mel_masks_float_conv
|
531 |
+
feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
|
532 |
+
|
533 |
+
return self.quantizer.encode(encoded_features), feature_lengths
|
534 |
+
|
535 |
+
def decode(self, indices, feature_lengths) -> torch.Tensor:
|
536 |
+
factor = math.prod(self.quantizer.downsample_factor)
|
537 |
+
mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
|
538 |
+
mel_masks_float_conv = mel_masks[:, None, :].float()
|
539 |
+
|
540 |
+
audio_masks = sequence_mask(
|
541 |
+
feature_lengths * factor * self.spec_transform.hop_length,
|
542 |
+
indices.shape[2] * factor * self.spec_transform.hop_length,
|
543 |
+
)
|
544 |
+
audio_masks_float_conv = audio_masks[:, None, :].float()
|
545 |
+
|
546 |
+
z = self.quantizer.decode(indices) * mel_masks_float_conv
|
547 |
+
x = self.head(z) * audio_masks_float_conv
|
548 |
+
|
549 |
+
return x
|
550 |
+
|
551 |
+
def remove_parametrizations(self):
|
552 |
+
if hasattr(self.backbone, "remove_parametrizations"):
|
553 |
+
self.backbone.remove_parametrizations()
|
554 |
+
|
555 |
+
if hasattr(self.head, "remove_parametrizations"):
|
556 |
+
self.head.remove_parametrizations()
|
557 |
+
|
558 |
+
@property
|
559 |
+
def device(self):
|
560 |
+
return next(self.parameters()).device
|
561 |
+
|
562 |
+
|
563 |
class FireflyBase(nn.Module):
|
564 |
def __init__(self, ckpt_path: str = None, pretrained: bool = True):
|
565 |
super().__init__()
|
|
|
586 |
)
|
587 |
|
588 |
if ckpt_path is not None:
|
589 |
+
state_dict = torch.load(ckpt_path, map_location="cpu")
|
590 |
elif pretrained:
|
591 |
state_dict = torch.hub.load_state_dict_from_url(
|
592 |
"https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
|
593 |
map_location="cpu",
|
594 |
+
model_dir="checkpoints",
|
595 |
)
|
596 |
|
597 |
if "state_dict" in state_dict:
|
fish_speech/models/vqgan/modules/fsq.py
CHANGED
@@ -20,7 +20,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
|
|
20 |
def __init__(
|
21 |
self,
|
22 |
input_dim: int = 512,
|
23 |
-
n_codebooks: int =
|
24 |
n_groups: int = 1,
|
25 |
levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
|
26 |
downsample_factor: tuple[int] = (2, 2),
|
|
|
20 |
def __init__(
|
21 |
self,
|
22 |
input_dim: int = 512,
|
23 |
+
n_codebooks: int = 1,
|
24 |
n_groups: int = 1,
|
25 |
levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
|
26 |
downsample_factor: tuple[int] = (2, 2),
|
fish_speech/text/__init__.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
from .clean import clean_text
|
|
|
2 |
|
3 |
-
__all__ = ["clean_text"]
|
|
|
1 |
from .clean import clean_text
|
2 |
+
from .spliter import split_text
|
3 |
|
4 |
+
__all__ = ["clean_text", "split_text"]
|
fish_speech/text/chn_text_norm/.gitignore
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
*.egg-info/
|
24 |
+
.installed.cfg
|
25 |
+
*.egg
|
26 |
+
MANIFEST
|
27 |
+
|
28 |
+
# PyInstaller
|
29 |
+
# Usually these files are written by a python script from a template
|
30 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
31 |
+
*.manifest
|
32 |
+
*.spec
|
33 |
+
|
34 |
+
# Installer logs
|
35 |
+
pip-log.txt
|
36 |
+
pip-delete-this-directory.txt
|
37 |
+
|
38 |
+
# Unit test / coverage reports
|
39 |
+
htmlcov/
|
40 |
+
.tox/
|
41 |
+
.coverage
|
42 |
+
.coverage.*
|
43 |
+
.cache
|
44 |
+
nosetests.xml
|
45 |
+
coverage.xml
|
46 |
+
*.cover
|
47 |
+
.hypothesis/
|
48 |
+
.pytest_cache/
|
49 |
+
|
50 |
+
# Translations
|
51 |
+
*.mo
|
52 |
+
*.pot
|
53 |
+
|
54 |
+
# Django stuff:
|
55 |
+
*.log
|
56 |
+
local_settings.py
|
57 |
+
db.sqlite3
|
58 |
+
|
59 |
+
# Flask stuff:
|
60 |
+
instance/
|
61 |
+
.webassets-cache
|
62 |
+
|
63 |
+
# Scrapy stuff:
|
64 |
+
.scrapy
|
65 |
+
|
66 |
+
# Sphinx documentation
|
67 |
+
docs/_build/
|
68 |
+
|
69 |
+
# PyBuilder
|
70 |
+
target/
|
71 |
+
|
72 |
+
# Jupyter Notebook
|
73 |
+
.ipynb_checkpoints
|
74 |
+
|
75 |
+
# pyenv
|
76 |
+
.python-version
|
77 |
+
|
78 |
+
# celery beat schedule file
|
79 |
+
celerybeat-schedule
|
80 |
+
|
81 |
+
# SageMath parsed files
|
82 |
+
*.sage.py
|
83 |
+
|
84 |
+
# Environments
|
85 |
+
.env
|
86 |
+
.venv
|
87 |
+
env/
|
88 |
+
venv/
|
89 |
+
ENV/
|
90 |
+
env.bak/
|
91 |
+
venv.bak/
|
92 |
+
|
93 |
+
# Spyder project settings
|
94 |
+
.spyderproject
|
95 |
+
.spyproject
|
96 |
+
|
97 |
+
# Rope project settings
|
98 |
+
.ropeproject
|
99 |
+
|
100 |
+
# mkdocs documentation
|
101 |
+
/site
|
102 |
+
|
103 |
+
# mypy
|
104 |
+
.mypy_cache/
|
105 |
+
|
106 |
+
# JetBrains PyCharm
|
107 |
+
.idea
|
108 |
+
|
109 |
+
# Customize
|
110 |
+
references
|
111 |
+
url.txt
|
112 |
+
|
113 |
+
# Git
|
114 |
+
.git
|
fish_speech/text/chn_text_norm/README.md
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works.
|
2 |
+
|
3 |
+
# Chn Text Norm
|
4 |
+
|
5 |
+
this is a repository for chinese text normalization (no longer maintained).
|
6 |
+
|
7 |
+
## Quick Start ##
|
8 |
+
|
9 |
+
### Git Clone Repo ###
|
10 |
+
|
11 |
+
git clone this repo to the root directory of your project which need to use it.
|
12 |
+
|
13 |
+
cd /path/to/proj
|
14 |
+
git clone https://github.com/Joee1995/chn-text-norm.git
|
15 |
+
|
16 |
+
after that, your doc tree should be:
|
17 |
+
```
|
18 |
+
proj # root of your project
|
19 |
+
|--- chn_text_norm # this chn-text-norm tool
|
20 |
+
|--- text.py
|
21 |
+
|--- ...
|
22 |
+
|--- text_normalize.py # your text normalization code
|
23 |
+
|--- ...
|
24 |
+
```
|
25 |
+
|
26 |
+
### How to Use ? ###
|
27 |
+
|
28 |
+
# text_normalize.py
|
29 |
+
from chn_text_norm.text import *
|
30 |
+
|
31 |
+
raw_text = 'your raw text'
|
32 |
+
text = Text(raw_text=raw_text).normalize()
|
33 |
+
|
34 |
+
### How to add quantums ###
|
35 |
+
|
36 |
+
打开test.py,然后你就知道怎么做了。
|
fish_speech/text/chn_text_norm/__init__.py
ADDED
File without changes
|
fish_speech/text/chn_text_norm/basic_class.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""基本类
|
3 |
+
中文字符类
|
4 |
+
中文数字/数位类
|
5 |
+
中文数字类
|
6 |
+
中文数位类
|
7 |
+
中文数字系统类
|
8 |
+
中文数学符号类
|
9 |
+
*中文其他符号类
|
10 |
+
"""
|
11 |
+
|
12 |
+
__author__ = "Zhiyang Zhou <[email protected]>"
|
13 |
+
__data__ = "2019-05-02"
|
14 |
+
|
15 |
+
from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES
|
16 |
+
|
17 |
+
|
18 |
+
class ChineseChar(object):
|
19 |
+
"""
|
20 |
+
中文字符
|
21 |
+
每个字符对应简体和繁体,
|
22 |
+
e.g. 简体 = '负', 繁体 = '負'
|
23 |
+
转换时可转换为简体或繁体
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, simplified, traditional):
|
27 |
+
self.simplified = simplified
|
28 |
+
self.traditional = traditional
|
29 |
+
self.__repr__ = self.__str__
|
30 |
+
|
31 |
+
def __str__(self):
|
32 |
+
return self.simplified or self.traditional or None
|
33 |
+
|
34 |
+
def __repr__(self):
|
35 |
+
return self.__str__()
|
36 |
+
|
37 |
+
|
38 |
+
class ChineseNumberUnit(ChineseChar):
|
39 |
+
"""
|
40 |
+
中文数字/数位字符
|
41 |
+
每个字符除繁简体外还有一个额外的大写字符
|
42 |
+
e.g. '陆' 和 '陸'
|
43 |
+
"""
|
44 |
+
|
45 |
+
def __init__(self, power, simplified, traditional, big_s, big_t):
|
46 |
+
super(ChineseNumberUnit, self).__init__(simplified, traditional)
|
47 |
+
self.power = power
|
48 |
+
self.big_s = big_s
|
49 |
+
self.big_t = big_t
|
50 |
+
|
51 |
+
def __str__(self):
|
52 |
+
return "10^{}".format(self.power)
|
53 |
+
|
54 |
+
@classmethod
|
55 |
+
def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
|
56 |
+
|
57 |
+
if small_unit:
|
58 |
+
return ChineseNumberUnit(
|
59 |
+
power=index + 1,
|
60 |
+
simplified=value[0],
|
61 |
+
traditional=value[1],
|
62 |
+
big_s=value[1],
|
63 |
+
big_t=value[1],
|
64 |
+
)
|
65 |
+
elif numbering_type == NUMBERING_TYPES[0]:
|
66 |
+
return ChineseNumberUnit(
|
67 |
+
power=index + 8,
|
68 |
+
simplified=value[0],
|
69 |
+
traditional=value[1],
|
70 |
+
big_s=value[0],
|
71 |
+
big_t=value[1],
|
72 |
+
)
|
73 |
+
elif numbering_type == NUMBERING_TYPES[1]:
|
74 |
+
return ChineseNumberUnit(
|
75 |
+
power=(index + 2) * 4,
|
76 |
+
simplified=value[0],
|
77 |
+
traditional=value[1],
|
78 |
+
big_s=value[0],
|
79 |
+
big_t=value[1],
|
80 |
+
)
|
81 |
+
elif numbering_type == NUMBERING_TYPES[2]:
|
82 |
+
return ChineseNumberUnit(
|
83 |
+
power=pow(2, index + 3),
|
84 |
+
simplified=value[0],
|
85 |
+
traditional=value[1],
|
86 |
+
big_s=value[0],
|
87 |
+
big_t=value[1],
|
88 |
+
)
|
89 |
+
else:
|
90 |
+
raise ValueError(
|
91 |
+
"Counting type should be in {0} ({1} provided).".format(
|
92 |
+
NUMBERING_TYPES, numbering_type
|
93 |
+
)
|
94 |
+
)
|
95 |
+
|
96 |
+
|
97 |
+
class ChineseNumberDigit(ChineseChar):
|
98 |
+
"""
|
99 |
+
中文数字字符
|
100 |
+
"""
|
101 |
+
|
102 |
+
def __init__(
|
103 |
+
self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None
|
104 |
+
):
|
105 |
+
super(ChineseNumberDigit, self).__init__(simplified, traditional)
|
106 |
+
self.value = value
|
107 |
+
self.big_s = big_s
|
108 |
+
self.big_t = big_t
|
109 |
+
self.alt_s = alt_s
|
110 |
+
self.alt_t = alt_t
|
111 |
+
|
112 |
+
def __str__(self):
|
113 |
+
return str(self.value)
|
114 |
+
|
115 |
+
@classmethod
|
116 |
+
def create(cls, i, v):
|
117 |
+
return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
|
118 |
+
|
119 |
+
|
120 |
+
class ChineseMath(ChineseChar):
|
121 |
+
"""
|
122 |
+
中文数位字符
|
123 |
+
"""
|
124 |
+
|
125 |
+
def __init__(self, simplified, traditional, symbol, expression=None):
|
126 |
+
super(ChineseMath, self).__init__(simplified, traditional)
|
127 |
+
self.symbol = symbol
|
128 |
+
self.expression = expression
|
129 |
+
self.big_s = simplified
|
130 |
+
self.big_t = traditional
|
131 |
+
|
132 |
+
|
133 |
+
CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
|
134 |
+
|
135 |
+
|
136 |
+
class NumberSystem(object):
|
137 |
+
"""
|
138 |
+
中文数字系统
|
139 |
+
"""
|
140 |
+
|
141 |
+
pass
|
142 |
+
|
143 |
+
|
144 |
+
class MathSymbol(object):
|
145 |
+
"""
|
146 |
+
用于中文数字系统的数学符号 (繁/简体), e.g.
|
147 |
+
positive = ['正', '正']
|
148 |
+
negative = ['负', '負']
|
149 |
+
point = ['点', '點']
|
150 |
+
"""
|
151 |
+
|
152 |
+
def __init__(self, positive, negative, point):
|
153 |
+
self.positive = positive
|
154 |
+
self.negative = negative
|
155 |
+
self.point = point
|
156 |
+
|
157 |
+
def __iter__(self):
|
158 |
+
for v in self.__dict__.values():
|
159 |
+
yield v
|
160 |
+
|
161 |
+
|
162 |
+
# class OtherSymbol(object):
|
163 |
+
# """
|
164 |
+
# 其他符号
|
165 |
+
# """
|
166 |
+
#
|
167 |
+
# def __init__(self, sil):
|
168 |
+
# self.sil = sil
|
169 |
+
#
|
170 |
+
# def __iter__(self):
|
171 |
+
# for v in self.__dict__.values():
|
172 |
+
# yield v
|
fish_speech/text/chn_text_norm/basic_constant.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""基本常量
|
3 |
+
中文数字/数位/符号字符常量
|
4 |
+
"""
|
5 |
+
|
6 |
+
__author__ = "Zhiyang Zhou <[email protected]>"
|
7 |
+
__data__ = "2019-05-02"
|
8 |
+
|
9 |
+
CHINESE_DIGIS = "零一二三四五六七八九"
|
10 |
+
BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
|
11 |
+
BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
|
12 |
+
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
|
13 |
+
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
|
14 |
+
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
|
15 |
+
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
|
16 |
+
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
|
17 |
+
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"
|
18 |
+
|
19 |
+
ZERO_ALT = "〇"
|
20 |
+
ONE_ALT = "幺"
|
21 |
+
TWO_ALTS = ["两", "兩"]
|
22 |
+
|
23 |
+
POSITIVE = ["正", "正"]
|
24 |
+
NEGATIVE = ["负", "負"]
|
25 |
+
POINT = ["点", "點"]
|
26 |
+
# PLUS = [u'加', u'加']
|
27 |
+
# SIL = [u'杠', u'槓']
|
28 |
+
|
29 |
+
# 中文数字系统类型
|
30 |
+
NUMBERING_TYPES = ["low", "mid", "high"]
|
fish_speech/text/chn_text_norm/basic_util.py
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""基本方法
|
3 |
+
创建中文数字系统 方法
|
4 |
+
中文字符串 <=> 数字串 方法
|
5 |
+
数字串 <=> 中文字符串 方法
|
6 |
+
"""
|
7 |
+
|
8 |
+
__author__ = "Zhiyang Zhou <[email protected]>"
|
9 |
+
__data__ = "2019-05-02"
|
10 |
+
|
11 |
+
from fish_speech.text.chn_text_norm.basic_class import *
|
12 |
+
from fish_speech.text.chn_text_norm.basic_constant import *
|
13 |
+
|
14 |
+
|
15 |
+
def create_system(numbering_type=NUMBERING_TYPES[1]):
|
16 |
+
"""
|
17 |
+
根据数字系统类型返回创建相应的数字系统,默认为 mid
|
18 |
+
NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型
|
19 |
+
low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc.
|
20 |
+
mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
|
21 |
+
high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
|
22 |
+
返回对应的数字系统
|
23 |
+
"""
|
24 |
+
|
25 |
+
# chinese number units of '亿' and larger
|
26 |
+
all_larger_units = zip(
|
27 |
+
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
|
28 |
+
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL,
|
29 |
+
)
|
30 |
+
larger_units = [
|
31 |
+
CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)
|
32 |
+
]
|
33 |
+
# chinese number units of '十, 百, 千, 万'
|
34 |
+
all_smaller_units = zip(
|
35 |
+
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
|
36 |
+
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL,
|
37 |
+
)
|
38 |
+
smaller_units = [
|
39 |
+
CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)
|
40 |
+
]
|
41 |
+
# digis
|
42 |
+
chinese_digis = zip(
|
43 |
+
CHINESE_DIGIS,
|
44 |
+
CHINESE_DIGIS,
|
45 |
+
BIG_CHINESE_DIGIS_SIMPLIFIED,
|
46 |
+
BIG_CHINESE_DIGIS_TRADITIONAL,
|
47 |
+
)
|
48 |
+
digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
|
49 |
+
digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
|
50 |
+
digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
|
51 |
+
digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
|
52 |
+
|
53 |
+
# symbols
|
54 |
+
positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
|
55 |
+
negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
|
56 |
+
point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
|
57 |
+
# sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
|
58 |
+
system = NumberSystem()
|
59 |
+
system.units = smaller_units + larger_units
|
60 |
+
system.digits = digits
|
61 |
+
system.math = MathSymbol(positive_cn, negative_cn, point_cn)
|
62 |
+
# system.symbols = OtherSymbol(sil_cn)
|
63 |
+
return system
|
64 |
+
|
65 |
+
|
66 |
+
def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
|
67 |
+
|
68 |
+
def get_symbol(char, system):
|
69 |
+
for u in system.units:
|
70 |
+
if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
|
71 |
+
return u
|
72 |
+
for d in system.digits:
|
73 |
+
if char in [
|
74 |
+
d.traditional,
|
75 |
+
d.simplified,
|
76 |
+
d.big_s,
|
77 |
+
d.big_t,
|
78 |
+
d.alt_s,
|
79 |
+
d.alt_t,
|
80 |
+
]:
|
81 |
+
return d
|
82 |
+
for m in system.math:
|
83 |
+
if char in [m.traditional, m.simplified]:
|
84 |
+
return m
|
85 |
+
|
86 |
+
def string2symbols(chinese_string, system):
|
87 |
+
int_string, dec_string = chinese_string, ""
|
88 |
+
for p in [system.math.point.simplified, system.math.point.traditional]:
|
89 |
+
if p in chinese_string:
|
90 |
+
int_string, dec_string = chinese_string.split(p)
|
91 |
+
break
|
92 |
+
return [get_symbol(c, system) for c in int_string], [
|
93 |
+
get_symbol(c, system) for c in dec_string
|
94 |
+
]
|
95 |
+
|
96 |
+
def correct_symbols(integer_symbols, system):
|
97 |
+
"""
|
98 |
+
一百八 to 一百八十
|
99 |
+
一亿一千三百万 to 一亿 一千万 三百万
|
100 |
+
"""
|
101 |
+
|
102 |
+
if integer_symbols and isinstance(integer_symbols[0], CNU):
|
103 |
+
if integer_symbols[0].power == 1:
|
104 |
+
integer_symbols = [system.digits[1]] + integer_symbols
|
105 |
+
|
106 |
+
if len(integer_symbols) > 1:
|
107 |
+
if isinstance(integer_symbols[-1], CND) and isinstance(
|
108 |
+
integer_symbols[-2], CNU
|
109 |
+
):
|
110 |
+
integer_symbols.append(
|
111 |
+
CNU(integer_symbols[-2].power - 1, None, None, None, None)
|
112 |
+
)
|
113 |
+
|
114 |
+
result = []
|
115 |
+
unit_count = 0
|
116 |
+
for s in integer_symbols:
|
117 |
+
if isinstance(s, CND):
|
118 |
+
result.append(s)
|
119 |
+
unit_count = 0
|
120 |
+
elif isinstance(s, CNU):
|
121 |
+
current_unit = CNU(s.power, None, None, None, None)
|
122 |
+
unit_count += 1
|
123 |
+
|
124 |
+
if unit_count == 1:
|
125 |
+
result.append(current_unit)
|
126 |
+
elif unit_count > 1:
|
127 |
+
for i in range(len(result)):
|
128 |
+
if (
|
129 |
+
isinstance(result[-i - 1], CNU)
|
130 |
+
and result[-i - 1].power < current_unit.power
|
131 |
+
):
|
132 |
+
result[-i - 1] = CNU(
|
133 |
+
result[-i - 1].power + current_unit.power,
|
134 |
+
None,
|
135 |
+
None,
|
136 |
+
None,
|
137 |
+
None,
|
138 |
+
)
|
139 |
+
return result
|
140 |
+
|
141 |
+
def compute_value(integer_symbols):
|
142 |
+
"""
|
143 |
+
Compute the value.
|
144 |
+
When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
|
145 |
+
e.g. '两千万' = 2000 * 10000 not 2000 + 10000
|
146 |
+
"""
|
147 |
+
value = [0]
|
148 |
+
last_power = 0
|
149 |
+
for s in integer_symbols:
|
150 |
+
if isinstance(s, CND):
|
151 |
+
value[-1] = s.value
|
152 |
+
elif isinstance(s, CNU):
|
153 |
+
value[-1] *= pow(10, s.power)
|
154 |
+
if s.power > last_power:
|
155 |
+
value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
|
156 |
+
last_power = s.power
|
157 |
+
value.append(0)
|
158 |
+
return sum(value)
|
159 |
+
|
160 |
+
system = create_system(numbering_type)
|
161 |
+
int_part, dec_part = string2symbols(chinese_string, system)
|
162 |
+
int_part = correct_symbols(int_part, system)
|
163 |
+
int_str = str(compute_value(int_part))
|
164 |
+
dec_str = "".join([str(d.value) for d in dec_part])
|
165 |
+
if dec_part:
|
166 |
+
return "{0}.{1}".format(int_str, dec_str)
|
167 |
+
else:
|
168 |
+
return int_str
|
169 |
+
|
170 |
+
|
171 |
+
def num2chn(
|
172 |
+
number_string,
|
173 |
+
numbering_type=NUMBERING_TYPES[1],
|
174 |
+
big=False,
|
175 |
+
traditional=False,
|
176 |
+
alt_zero=False,
|
177 |
+
alt_one=False,
|
178 |
+
alt_two=True,
|
179 |
+
use_zeros=True,
|
180 |
+
use_units=True,
|
181 |
+
):
|
182 |
+
|
183 |
+
def get_value(value_string, use_zeros=True):
|
184 |
+
|
185 |
+
striped_string = value_string.lstrip("0")
|
186 |
+
|
187 |
+
# record nothing if all zeros
|
188 |
+
if not striped_string:
|
189 |
+
return []
|
190 |
+
|
191 |
+
# record one digits
|
192 |
+
elif len(striped_string) == 1:
|
193 |
+
if use_zeros and len(value_string) != len(striped_string):
|
194 |
+
return [system.digits[0], system.digits[int(striped_string)]]
|
195 |
+
else:
|
196 |
+
return [system.digits[int(striped_string)]]
|
197 |
+
|
198 |
+
# recursively record multiple digits
|
199 |
+
else:
|
200 |
+
result_unit = next(
|
201 |
+
u for u in reversed(system.units) if u.power < len(striped_string)
|
202 |
+
)
|
203 |
+
result_string = value_string[: -result_unit.power]
|
204 |
+
return (
|
205 |
+
get_value(result_string)
|
206 |
+
+ [result_unit]
|
207 |
+
+ get_value(striped_string[-result_unit.power :])
|
208 |
+
)
|
209 |
+
|
210 |
+
system = create_system(numbering_type)
|
211 |
+
|
212 |
+
int_dec = number_string.split(".")
|
213 |
+
if len(int_dec) == 1:
|
214 |
+
int_string = int_dec[0]
|
215 |
+
dec_string = ""
|
216 |
+
elif len(int_dec) == 2:
|
217 |
+
int_string = int_dec[0]
|
218 |
+
dec_string = int_dec[1]
|
219 |
+
else:
|
220 |
+
raise ValueError(
|
221 |
+
"invalid input num string with more than one dot: {}".format(number_string)
|
222 |
+
)
|
223 |
+
|
224 |
+
if use_units and len(int_string) > 1:
|
225 |
+
result_symbols = get_value(int_string)
|
226 |
+
else:
|
227 |
+
result_symbols = [system.digits[int(c)] for c in int_string]
|
228 |
+
dec_symbols = [system.digits[int(c)] for c in dec_string]
|
229 |
+
if dec_string:
|
230 |
+
result_symbols += [system.math.point] + dec_symbols
|
231 |
+
|
232 |
+
if alt_two:
|
233 |
+
liang = CND(
|
234 |
+
2,
|
235 |
+
system.digits[2].alt_s,
|
236 |
+
system.digits[2].alt_t,
|
237 |
+
system.digits[2].big_s,
|
238 |
+
system.digits[2].big_t,
|
239 |
+
)
|
240 |
+
for i, v in enumerate(result_symbols):
|
241 |
+
if isinstance(v, CND) and v.value == 2:
|
242 |
+
next_symbol = (
|
243 |
+
result_symbols[i + 1] if i < len(result_symbols) - 1 else None
|
244 |
+
)
|
245 |
+
previous_symbol = result_symbols[i - 1] if i > 0 else None
|
246 |
+
if isinstance(next_symbol, CNU) and isinstance(
|
247 |
+
previous_symbol, (CNU, type(None))
|
248 |
+
):
|
249 |
+
if next_symbol.power != 1 and (
|
250 |
+
(previous_symbol is None) or (previous_symbol.power != 1)
|
251 |
+
):
|
252 |
+
result_symbols[i] = liang
|
253 |
+
|
254 |
+
# if big is True, '两' will not be used and `alt_two` has no impact on output
|
255 |
+
if big:
|
256 |
+
attr_name = "big_"
|
257 |
+
if traditional:
|
258 |
+
attr_name += "t"
|
259 |
+
else:
|
260 |
+
attr_name += "s"
|
261 |
+
else:
|
262 |
+
if traditional:
|
263 |
+
attr_name = "traditional"
|
264 |
+
else:
|
265 |
+
attr_name = "simplified"
|
266 |
+
|
267 |
+
result = "".join([getattr(s, attr_name) for s in result_symbols])
|
268 |
+
|
269 |
+
# if not use_zeros:
|
270 |
+
# result = result.strip(getattr(system.digits[0], attr_name))
|
271 |
+
|
272 |
+
if alt_zero:
|
273 |
+
result = result.replace(
|
274 |
+
getattr(system.digits[0], attr_name), system.digits[0].alt_s
|
275 |
+
)
|
276 |
+
|
277 |
+
if alt_one:
|
278 |
+
result = result.replace(
|
279 |
+
getattr(system.digits[1], attr_name), system.digits[1].alt_s
|
280 |
+
)
|
281 |
+
|
282 |
+
for i, p in enumerate(POINT):
|
283 |
+
if result.startswith(p):
|
284 |
+
return CHINESE_DIGIS[0] + result
|
285 |
+
|
286 |
+
# ^10, 11, .., 19
|
287 |
+
if (
|
288 |
+
len(result) >= 2
|
289 |
+
and result[1]
|
290 |
+
in [
|
291 |
+
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
|
292 |
+
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
|
293 |
+
]
|
294 |
+
and result[0]
|
295 |
+
in [
|
296 |
+
CHINESE_DIGIS[1],
|
297 |
+
BIG_CHINESE_DIGIS_SIMPLIFIED[1],
|
298 |
+
BIG_CHINESE_DIGIS_TRADITIONAL[1],
|
299 |
+
]
|
300 |
+
):
|
301 |
+
result = result[1:]
|
302 |
+
|
303 |
+
return result
|
304 |
+
|
305 |
+
|
306 |
+
if __name__ == "__main__":
|
307 |
+
|
308 |
+
# 测试程序
|
309 |
+
all_chinese_number_string = (
|
310 |
+
CHINESE_DIGIS
|
311 |
+
+ BIG_CHINESE_DIGIS_SIMPLIFIED
|
312 |
+
+ BIG_CHINESE_DIGIS_TRADITIONAL
|
313 |
+
+ LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED
|
314 |
+
+ LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
|
315 |
+
+ SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED
|
316 |
+
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
|
317 |
+
+ ZERO_ALT
|
318 |
+
+ ONE_ALT
|
319 |
+
+ "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT)
|
320 |
+
)
|
321 |
+
|
322 |
+
print("num:", chn2num("一万零四百零三点八零五"))
|
323 |
+
print("num:", chn2num("一亿六点三"))
|
324 |
+
print("num:", chn2num("一亿零六点三"))
|
325 |
+
print("num:", chn2num("两千零一亿六点三"))
|
326 |
+
# print('num:', chn2num('一零零八六'))
|
327 |
+
print("txt:", num2chn("10260.03", alt_zero=True))
|
328 |
+
print("txt:", num2chn("20037.090", numbering_type="low", traditional=True))
|
329 |
+
print("txt:", num2chn("100860001.77", numbering_type="high", big=True))
|
330 |
+
print(
|
331 |
+
"txt:",
|
332 |
+
num2chn(
|
333 |
+
"059523810880",
|
334 |
+
alt_one=True,
|
335 |
+
alt_two=False,
|
336 |
+
use_lzeros=True,
|
337 |
+
use_rzeros=True,
|
338 |
+
use_units=False,
|
339 |
+
),
|
340 |
+
)
|
341 |
+
|
342 |
+
print(all_chinese_number_string)
|
fish_speech/text/chn_text_norm/cardinal.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""CARDINAL类 (包含小数DECIMAL类)
|
3 |
+
纯数 <=> 中文字符串 方法
|
4 |
+
中文字符串 <=> 纯数 方法
|
5 |
+
"""
|
6 |
+
|
7 |
+
__author__ = "Zhiyang Zhou <[email protected]>"
|
8 |
+
__data__ = "2019-05-03"
|
9 |
+
|
10 |
+
from fish_speech.text.chn_text_norm.basic_util import *
|
11 |
+
|
12 |
+
|
13 |
+
class Cardinal:
|
14 |
+
"""
|
15 |
+
CARDINAL类
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, cardinal=None, chntext=None):
|
19 |
+
self.cardinal = cardinal
|
20 |
+
self.chntext = chntext
|
21 |
+
|
22 |
+
def chntext2cardinal(self):
|
23 |
+
return chn2num(self.chntext)
|
24 |
+
|
25 |
+
def cardinal2chntext(self):
|
26 |
+
return num2chn(self.cardinal)
|
27 |
+
|
28 |
+
|
29 |
+
if __name__ == "__main__":
|
30 |
+
|
31 |
+
# 测试程序
|
32 |
+
print(Cardinal(cardinal="21357.230").cardinal2chntext())
|
fish_speech/text/chn_text_norm/date.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""DATE类
|
3 |
+
日期 <=> 中文字符串 方法
|
4 |
+
中文字符串 <=> 日期 方法
|
5 |
+
"""
|
6 |
+
|
7 |
+
__author__ = "Zhiyang Zhou <[email protected]>"
|
8 |
+
__data__ = "2019-05-07"
|
9 |
+
|
10 |
+
from fish_speech.text.chn_text_norm.cardinal import Cardinal
|
11 |
+
from fish_speech.text.chn_text_norm.digit import Digit
|
12 |
+
|
13 |
+
|
14 |
+
class Date:
|
15 |
+
"""
|
16 |
+
DATE类
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, date=None, chntext=None):
|
20 |
+
self.date = date
|
21 |
+
self.chntext = chntext
|
22 |
+
|
23 |
+
# def chntext2date(self):
|
24 |
+
# chntext = self.chntext
|
25 |
+
# try:
|
26 |
+
# year, other = chntext.strip().split('年', maxsplit=1)
|
27 |
+
# year = Digit(chntext=year).digit2chntext() + '年'
|
28 |
+
# except ValueError:
|
29 |
+
# other = chntext
|
30 |
+
# year = ''
|
31 |
+
# if other:
|
32 |
+
# try:
|
33 |
+
# month, day = other.strip().split('月', maxsplit=1)
|
34 |
+
# month = Cardinal(chntext=month).chntext2cardinal() + '月'
|
35 |
+
# except ValueError:
|
36 |
+
# day = chntext
|
37 |
+
# month = ''
|
38 |
+
# if day:
|
39 |
+
# day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
|
40 |
+
# else:
|
41 |
+
# month = ''
|
42 |
+
# day = ''
|
43 |
+
# date = year + month + day
|
44 |
+
# self.date = date
|
45 |
+
# return self.date
|
46 |
+
|
47 |
+
def date2chntext(self):
|
48 |
+
date = self.date
|
49 |
+
try:
|
50 |
+
year, other = date.strip().split("年", maxsplit=1)
|
51 |
+
year = Digit(digit=year).digit2chntext() + "年"
|
52 |
+
except ValueError:
|
53 |
+
other = date
|
54 |
+
year = ""
|
55 |
+
if other:
|
56 |
+
try:
|
57 |
+
month, day = other.strip().split("月", maxsplit=1)
|
58 |
+
month = Cardinal(cardinal=month).cardinal2chntext() + "月"
|
59 |
+
except ValueError:
|
60 |
+
day = date
|
61 |
+
month = ""
|
62 |
+
if day:
|
63 |
+
day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
|
64 |
+
else:
|
65 |
+
month = ""
|
66 |
+
day = ""
|
67 |
+
chntext = year + month + day
|
68 |
+
self.chntext = chntext
|
69 |
+
return self.chntext
|
70 |
+
|
71 |
+
|
72 |
+
if __name__ == "__main__":
|
73 |
+
|
74 |
+
# 测试
|
75 |
+
print(Date(date="09年3月16日").date2chntext())
|
fish_speech/text/chn_text_norm/digit.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""DIGIT类
|
3 |
+
数字串 <=> 中文字符串 方法
|
4 |
+
中文字符串 <=> 数字串 方法
|
5 |
+
"""
|
6 |
+
|
7 |
+
__author__ = "Zhiyang Zhou <[email protected]>"
|
8 |
+
__data__ = "2019-05-03"
|
9 |
+
|
10 |
+
from fish_speech.text.chn_text_norm.basic_util import *
|
11 |
+
|
12 |
+
|
13 |
+
class Digit:
|
14 |
+
"""
|
15 |
+
DIGIT类
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, digit=None, chntext=None):
|
19 |
+
self.digit = digit
|
20 |
+
self.chntext = chntext
|
21 |
+
|
22 |
+
# def chntext2digit(self):
|
23 |
+
# return chn2num(self.chntext)
|
24 |
+
|
25 |
+
def digit2chntext(self):
|
26 |
+
return num2chn(self.digit, alt_two=False, use_units=False)
|
27 |
+
|
28 |
+
|
29 |
+
if __name__ == "__main__":
|
30 |
+
|
31 |
+
# 测试程序
|
32 |
+
print(Digit(digit="2016").digit2chntext())
|
fish_speech/text/chn_text_norm/fraction.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""FRACTION类
|
3 |
+
分数 <=> 中文字符串 方法
|
4 |
+
中文字符串 <=> 分数 方法
|
5 |
+
"""
|
6 |
+
|
7 |
+
__author__ = "Zhiyang Zhou <[email protected]>"
|
8 |
+
__data__ = "2019-05-03"
|
9 |
+
|
10 |
+
from fish_speech.text.chn_text_norm.basic_util import *
|
11 |
+
|
12 |
+
|
13 |
+
class Fraction:
|
14 |
+
"""
|
15 |
+
FRACTION类
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, fraction=None, chntext=None):
|
19 |
+
self.fraction = fraction
|
20 |
+
self.chntext = chntext
|
21 |
+
|
22 |
+
def chntext2fraction(self):
|
23 |
+
denominator, numerator = self.chntext.split("分之")
|
24 |
+
return chn2num(numerator) + "/" + chn2num(denominator)
|
25 |
+
|
26 |
+
def fraction2chntext(self):
|
27 |
+
numerator, denominator = self.fraction.split("/")
|
28 |
+
return num2chn(denominator) + "分之" + num2chn(numerator)
|
29 |
+
|
30 |
+
|
31 |
+
if __name__ == "__main__":
|
32 |
+
|
33 |
+
# 测试程序
|
34 |
+
print(Fraction(fraction="2135/7230").fraction2chntext())
|
35 |
+
print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction())
|
fish_speech/text/chn_text_norm/money.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""MONEY类
|
3 |
+
金钱 <=> 中文字符串 方法
|
4 |
+
中文字符串 <=> 金钱 方法
|
5 |
+
"""
|
6 |
+
import re
|
7 |
+
|
8 |
+
__author__ = "Zhiyang Zhou <[email protected]>"
|
9 |
+
__data__ = "2019-05-08"
|
10 |
+
|
11 |
+
from fish_speech.text.chn_text_norm.cardinal import Cardinal
|
12 |
+
|
13 |
+
|
14 |
+
class Money:
|
15 |
+
"""
|
16 |
+
MONEY类
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, money=None, chntext=None):
|
20 |
+
self.money = money
|
21 |
+
self.chntext = chntext
|
22 |
+
|
23 |
+
# def chntext2money(self):
|
24 |
+
# return self.money
|
25 |
+
|
26 |
+
def money2chntext(self):
|
27 |
+
money = self.money
|
28 |
+
pattern = re.compile(r"(\d+(\.\d+)?)")
|
29 |
+
matchers = pattern.findall(money)
|
30 |
+
if matchers:
|
31 |
+
for matcher in matchers:
|
32 |
+
money = money.replace(
|
33 |
+
matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext()
|
34 |
+
)
|
35 |
+
self.chntext = money
|
36 |
+
return self.chntext
|
37 |
+
|
38 |
+
|
39 |
+
if __name__ == "__main__":
|
40 |
+
|
41 |
+
# 测试
|
42 |
+
print(Money(money="21.5万元").money2chntext())
|
43 |
+
print(Money(money="230块5毛").money2chntext())
|
fish_speech/text/chn_text_norm/percentage.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""PERCENTAGE类
|
3 |
+
百分数 <=> 中文字符串 方法
|
4 |
+
中文字符串 <=> 百分数 方法
|
5 |
+
"""
|
6 |
+
|
7 |
+
__author__ = "Zhiyang Zhou <[email protected]>"
|
8 |
+
__data__ = "2019-05-06"
|
9 |
+
|
10 |
+
from fish_speech.text.chn_text_norm.basic_util import *
|
11 |
+
|
12 |
+
|
13 |
+
class Percentage:
|
14 |
+
"""
|
15 |
+
PERCENTAGE类
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, percentage=None, chntext=None):
|
19 |
+
self.percentage = percentage
|
20 |
+
self.chntext = chntext
|
21 |
+
|
22 |
+
def chntext2percentage(self):
|
23 |
+
return chn2num(self.chntext.strip().strip("百分之")) + "%"
|
24 |
+
|
25 |
+
def percentage2chntext(self):
|
26 |
+
return "百分之" + num2chn(self.percentage.strip().strip("%"))
|
27 |
+
|
28 |
+
|
29 |
+
if __name__ == "__main__":
|
30 |
+
|
31 |
+
# 测试程序
|
32 |
+
print(Percentage(chntext="百分之五十六点零三").chntext2percentage())
|
33 |
+
print(Percentage(percentage="65.3%").percentage2chntext())
|
fish_speech/text/chn_text_norm/telephone.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""TELEPHONE类
|
3 |
+
电话号码 <=> 中文字符串 方法
|
4 |
+
中文字符串 <=> 电话号码 方法
|
5 |
+
"""
|
6 |
+
|
7 |
+
__author__ = "Zhiyang Zhou <[email protected]>"
|
8 |
+
__data__ = "2019-05-03"
|
9 |
+
|
10 |
+
from fish_speech.text.chn_text_norm.basic_util import *
|
11 |
+
|
12 |
+
|
13 |
+
class TelePhone:
|
14 |
+
"""
|
15 |
+
TELEPHONE类
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, telephone=None, raw_chntext=None, chntext=None):
|
19 |
+
self.telephone = telephone
|
20 |
+
self.raw_chntext = raw_chntext
|
21 |
+
self.chntext = chntext
|
22 |
+
|
23 |
+
# def chntext2telephone(self):
|
24 |
+
# sil_parts = self.raw_chntext.split('<SIL>')
|
25 |
+
# self.telephone = '-'.join([
|
26 |
+
# str(chn2num(p)) for p in sil_parts
|
27 |
+
# ])
|
28 |
+
# return self.telephone
|
29 |
+
|
30 |
+
def telephone2chntext(self, fixed=False):
|
31 |
+
|
32 |
+
if fixed:
|
33 |
+
sil_parts = self.telephone.split("-")
|
34 |
+
self.raw_chntext = "<SIL>".join(
|
35 |
+
[num2chn(part, alt_two=False, use_units=False) for part in sil_parts]
|
36 |
+
)
|
37 |
+
self.chntext = self.raw_chntext.replace("<SIL>", "")
|
38 |
+
else:
|
39 |
+
sp_parts = self.telephone.strip("+").split()
|
40 |
+
self.raw_chntext = "<SP>".join(
|
41 |
+
[num2chn(part, alt_two=False, use_units=False) for part in sp_parts]
|
42 |
+
)
|
43 |
+
self.chntext = self.raw_chntext.replace("<SP>", "")
|
44 |
+
return self.chntext
|
45 |
+
|
46 |
+
|
47 |
+
if __name__ == "__main__":
|
48 |
+
|
49 |
+
# 测试程序
|
50 |
+
print(TelePhone(telephone="0595-23980880").telephone2chntext())
|
51 |
+
# print(TelePhone(raw_chntext='零五九五杠二三八六五零九八').chntext2telephone())
|
fish_speech/text/chn_text_norm/text.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
TEXT类
|
4 |
+
"""
|
5 |
+
|
6 |
+
__author__ = "Zhiyang Zhou <[email protected]>"
|
7 |
+
__data__ = "2019-05-03"
|
8 |
+
|
9 |
+
import re
|
10 |
+
|
11 |
+
from fish_speech.text.chn_text_norm.cardinal import Cardinal
|
12 |
+
from fish_speech.text.chn_text_norm.date import Date
|
13 |
+
from fish_speech.text.chn_text_norm.digit import Digit
|
14 |
+
from fish_speech.text.chn_text_norm.fraction import Fraction
|
15 |
+
from fish_speech.text.chn_text_norm.money import Money
|
16 |
+
from fish_speech.text.chn_text_norm.percentage import Percentage
|
17 |
+
from fish_speech.text.chn_text_norm.telephone import TelePhone
|
18 |
+
|
19 |
+
CURRENCY_NAMES = (
|
20 |
+
"(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
|
21 |
+
"里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
|
22 |
+
)
|
23 |
+
CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
|
24 |
+
COM_QUANTIFIERS = (
|
25 |
+
"(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
|
26 |
+
"砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
|
27 |
+
"针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
|
28 |
+
"毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
|
29 |
+
"盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
|
30 |
+
"纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|人|抽)"
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
class Text:
|
35 |
+
"""
|
36 |
+
Text类
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, raw_text, norm_text=None):
|
40 |
+
self.raw_text = "^" + raw_text + "$"
|
41 |
+
self.norm_text = norm_text
|
42 |
+
|
43 |
+
def _particular(self):
|
44 |
+
text = self.norm_text
|
45 |
+
pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
|
46 |
+
matchers = pattern.findall(text)
|
47 |
+
if matchers:
|
48 |
+
# print('particular')
|
49 |
+
for matcher in matchers:
|
50 |
+
text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
|
51 |
+
self.norm_text = text
|
52 |
+
return self.norm_text
|
53 |
+
|
54 |
+
def normalize(self):
|
55 |
+
text = self.raw_text
|
56 |
+
|
57 |
+
# 规范化日期
|
58 |
+
pattern = re.compile(
|
59 |
+
r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)"
|
60 |
+
)
|
61 |
+
matchers = pattern.findall(text)
|
62 |
+
if matchers:
|
63 |
+
# print('date')
|
64 |
+
for matcher in matchers:
|
65 |
+
text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
|
66 |
+
|
67 |
+
# 规范化金钱
|
68 |
+
pattern = re.compile(
|
69 |
+
r"\D+((\d+(\.\d+)?)[多余几]?"
|
70 |
+
+ CURRENCY_UNITS
|
71 |
+
+ "(\d"
|
72 |
+
+ CURRENCY_UNITS
|
73 |
+
+ "?)?)"
|
74 |
+
)
|
75 |
+
matchers = pattern.findall(text)
|
76 |
+
if matchers:
|
77 |
+
# print('money')
|
78 |
+
for matcher in matchers:
|
79 |
+
text = text.replace(
|
80 |
+
matcher[0], Money(money=matcher[0]).money2chntext(), 1
|
81 |
+
)
|
82 |
+
|
83 |
+
# 规范化固话/手机号码
|
84 |
+
# 手机
|
85 |
+
# http://www.jihaoba.com/news/show/13680
|
86 |
+
# 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
|
87 |
+
# 联通:130、131、132、156、155、186、185、176
|
88 |
+
# 电信:133、153、189、180、181、177
|
89 |
+
pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
|
90 |
+
matchers = pattern.findall(text)
|
91 |
+
if matchers:
|
92 |
+
# print('telephone')
|
93 |
+
for matcher in matchers:
|
94 |
+
text = text.replace(
|
95 |
+
matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1
|
96 |
+
)
|
97 |
+
# 固话
|
98 |
+
pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
|
99 |
+
matchers = pattern.findall(text)
|
100 |
+
if matchers:
|
101 |
+
# print('fixed telephone')
|
102 |
+
for matcher in matchers:
|
103 |
+
text = text.replace(
|
104 |
+
matcher[0],
|
105 |
+
TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True),
|
106 |
+
1,
|
107 |
+
)
|
108 |
+
|
109 |
+
# 规范化分数
|
110 |
+
pattern = re.compile(r"(\d+/\d+)")
|
111 |
+
matchers = pattern.findall(text)
|
112 |
+
if matchers:
|
113 |
+
# print('fraction')
|
114 |
+
for matcher in matchers:
|
115 |
+
text = text.replace(
|
116 |
+
matcher, Fraction(fraction=matcher).fraction2chntext(), 1
|
117 |
+
)
|
118 |
+
|
119 |
+
# 规范化百分数
|
120 |
+
text = text.replace("%", "%")
|
121 |
+
pattern = re.compile(r"(\d+(\.\d+)?%)")
|
122 |
+
matchers = pattern.findall(text)
|
123 |
+
if matchers:
|
124 |
+
# print('percentage')
|
125 |
+
for matcher in matchers:
|
126 |
+
text = text.replace(
|
127 |
+
matcher[0],
|
128 |
+
Percentage(percentage=matcher[0]).percentage2chntext(),
|
129 |
+
1,
|
130 |
+
)
|
131 |
+
|
132 |
+
# 规范化纯数+量词
|
133 |
+
pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
|
134 |
+
matchers = pattern.findall(text)
|
135 |
+
if matchers:
|
136 |
+
# print('cardinal+quantifier')
|
137 |
+
for matcher in matchers:
|
138 |
+
text = text.replace(
|
139 |
+
matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
|
140 |
+
)
|
141 |
+
|
142 |
+
# 规范化数字编号
|
143 |
+
pattern = re.compile(r"(\d{4,32})")
|
144 |
+
matchers = pattern.findall(text)
|
145 |
+
if matchers:
|
146 |
+
# print('digit')
|
147 |
+
for matcher in matchers:
|
148 |
+
text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
|
149 |
+
|
150 |
+
# 规范化纯数
|
151 |
+
pattern = re.compile(r"(\d+(\.\d+)?)")
|
152 |
+
matchers = pattern.findall(text)
|
153 |
+
if matchers:
|
154 |
+
# print('cardinal')
|
155 |
+
for matcher in matchers:
|
156 |
+
text = text.replace(
|
157 |
+
matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
|
158 |
+
)
|
159 |
+
|
160 |
+
self.norm_text = text
|
161 |
+
self._particular()
|
162 |
+
|
163 |
+
return self.norm_text.lstrip("^").rstrip("$")
|
164 |
+
|
165 |
+
|
166 |
+
if __name__ == "__main__":
|
167 |
+
|
168 |
+
# 测试程序
|
169 |
+
print(Text(raw_text="固话:0595-23865596或23880880。").normalize())
|
170 |
+
print(Text(raw_text="手机:+86 19859213959或15659451527。").normalize())
|
171 |
+
print(Text(raw_text="分数:32477/76391。").normalize())
|
172 |
+
print(Text(raw_text="百分数:80.03%。").normalize())
|
173 |
+
print(Text(raw_text="编号:31520181154418。").normalize())
|
174 |
+
print(Text(raw_text="纯数:2983.07克或12345.60米。").normalize())
|
175 |
+
print(Text(raw_text="日期:1999年2月20日或09年3月15号。").normalize())
|
176 |
+
print(Text(raw_text="金钱:12块5,34.5元,20.1万").normalize())
|
177 |
+
print(Text(raw_text="特殊:O2O或B2C。").normalize())
|
fish_speech/text/clean.py
CHANGED
@@ -18,7 +18,6 @@ SYMBOLS_MAPPING = {
|
|
18 |
"·": ",",
|
19 |
"、": ",",
|
20 |
"...": "…",
|
21 |
-
"$": ".",
|
22 |
"“": "'",
|
23 |
"”": "'",
|
24 |
"‘": "'",
|
@@ -62,12 +61,9 @@ REMOVE_UNKNOWN_SYMBOL_REGEX = re.compile(
|
|
62 |
def clean_text(text):
|
63 |
# Clean the text
|
64 |
text = text.strip()
|
65 |
-
|
66 |
-
text = re.sub(r"<p:(.*?)>", r"<PPP\1PPP>", text)
|
67 |
# Replace all chinese symbols with their english counterparts
|
68 |
text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
|
69 |
text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
|
70 |
-
# Replace <PPP(.*?)PPP> with <p:(.*?)>
|
71 |
-
text = re.sub(r"<PPP(.*?)PPP>", r"<p:\1>", text)
|
72 |
|
73 |
return text
|
|
|
18 |
"·": ",",
|
19 |
"、": ",",
|
20 |
"...": "…",
|
|
|
21 |
"“": "'",
|
22 |
"”": "'",
|
23 |
"‘": "'",
|
|
|
61 |
def clean_text(text):
|
62 |
# Clean the text
|
63 |
text = text.strip()
|
64 |
+
|
|
|
65 |
# Replace all chinese symbols with their english counterparts
|
66 |
text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
|
67 |
text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
|
|
|
|
|
68 |
|
69 |
return text
|
fish_speech/text/spliter.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import string
|
3 |
+
|
4 |
+
from fish_speech.text.clean import clean_text
|
5 |
+
|
6 |
+
|
7 |
+
def utf_8_len(text):
|
8 |
+
return len(text.encode("utf-8"))
|
9 |
+
|
10 |
+
|
11 |
+
def break_text(texts, length, splits: set):
|
12 |
+
for text in texts:
|
13 |
+
if utf_8_len(text) <= length:
|
14 |
+
yield text
|
15 |
+
continue
|
16 |
+
|
17 |
+
curr = ""
|
18 |
+
for char in text:
|
19 |
+
curr += char
|
20 |
+
|
21 |
+
if char in splits:
|
22 |
+
yield curr
|
23 |
+
curr = ""
|
24 |
+
|
25 |
+
if curr:
|
26 |
+
yield curr
|
27 |
+
|
28 |
+
|
29 |
+
def break_text_by_length(texts, length):
|
30 |
+
for text in texts:
|
31 |
+
if utf_8_len(text) <= length:
|
32 |
+
yield text
|
33 |
+
continue
|
34 |
+
|
35 |
+
curr = ""
|
36 |
+
for char in text:
|
37 |
+
curr += char
|
38 |
+
|
39 |
+
if utf_8_len(curr) >= length:
|
40 |
+
yield curr
|
41 |
+
curr = ""
|
42 |
+
|
43 |
+
if curr:
|
44 |
+
yield curr
|
45 |
+
|
46 |
+
|
47 |
+
def add_cleaned(curr, segments):
|
48 |
+
curr = curr.strip()
|
49 |
+
if curr and not all(c.isspace() or c in string.punctuation for c in curr):
|
50 |
+
segments.append(curr)
|
51 |
+
|
52 |
+
|
53 |
+
def protect_float(text):
|
54 |
+
# Turns 3.14 into <3_f_14> to prevent splitting
|
55 |
+
return re.sub(r"(\d+)\.(\d+)", r"<\1_f_\2>", text)
|
56 |
+
|
57 |
+
|
58 |
+
def unprotect_float(text):
|
59 |
+
# Turns <3_f_14> into 3.14
|
60 |
+
return re.sub(r"<(\d+)_f_(\d+)>", r"\1.\2", text)
|
61 |
+
|
62 |
+
|
63 |
+
def split_text(text, length):
|
64 |
+
text = clean_text(text)
|
65 |
+
|
66 |
+
# Break the text into pieces with following rules:
|
67 |
+
# 1. Split the text at ".", "!", "?" if text is NOT a float
|
68 |
+
# 2. If the text is longer than length, split at ","
|
69 |
+
# 3. If the text is still longer than length, split at " "
|
70 |
+
# 4. If the text is still longer than length, split at any character to length
|
71 |
+
|
72 |
+
texts = [text]
|
73 |
+
texts = map(protect_float, texts)
|
74 |
+
texts = break_text(texts, length, {".", "!", "?"})
|
75 |
+
texts = map(unprotect_float, texts)
|
76 |
+
texts = break_text(texts, length, {","})
|
77 |
+
texts = break_text(texts, length, {" "})
|
78 |
+
texts = list(break_text_by_length(texts, length))
|
79 |
+
|
80 |
+
# Then, merge the texts into segments with length <= length
|
81 |
+
segments = []
|
82 |
+
curr = ""
|
83 |
+
|
84 |
+
for text in texts:
|
85 |
+
if utf_8_len(curr) + utf_8_len(text) <= length:
|
86 |
+
curr += text
|
87 |
+
else:
|
88 |
+
add_cleaned(curr, segments)
|
89 |
+
curr = text
|
90 |
+
|
91 |
+
if curr:
|
92 |
+
add_cleaned(curr, segments)
|
93 |
+
|
94 |
+
return segments
|
95 |
+
|
96 |
+
|
97 |
+
if __name__ == "__main__":
|
98 |
+
# Test the split_text function
|
99 |
+
|
100 |
+
text = "This is a test sentence. This is another test sentence. And a third one."
|
101 |
+
|
102 |
+
assert split_text(text, 50) == [
|
103 |
+
"This is a test sentence.",
|
104 |
+
"This is another test sentence. And a third one.",
|
105 |
+
]
|
106 |
+
assert split_text("a,aaaaaa3.14", 10) == ["a,", "aaaaaa3.14"]
|
107 |
+
assert split_text(" ", 10) == []
|
108 |
+
assert split_text("a", 10) == ["a"]
|
109 |
+
|
110 |
+
text = "This is a test sentence with only commas, and no dots, and no exclamation marks, and no question marks, and no newlines."
|
111 |
+
assert split_text(text, 50) == [
|
112 |
+
"This is a test sentence with only commas,",
|
113 |
+
"and no dots, and no exclamation marks,",
|
114 |
+
"and no question marks, and no newlines.",
|
115 |
+
]
|
116 |
+
|
117 |
+
text = "This is a test sentence This is a test sentence This is a test sentence. This is a test sentence, This is a test sentence, This is a test sentence."
|
118 |
+
# First half split at " ", second half split at ","
|
119 |
+
assert split_text(text, 50) == [
|
120 |
+
"This is a test sentence This is a test sentence",
|
121 |
+
"This is a test sentence. This is a test sentence,",
|
122 |
+
"This is a test sentence, This is a test sentence.",
|
123 |
+
]
|
124 |
+
|
125 |
+
text = "这是一段很长的中文文本,而且没有句号,也没有感叹号,也没有问号,也没有换行符。"
|
126 |
+
assert split_text(text, 50) == [
|
127 |
+
"这是一段很长的中文文本,",
|
128 |
+
"而且没有句号,也没有感叹号,",
|
129 |
+
"也没有问号,也没有换行符.",
|
130 |
+
]
|
fish_speech/utils/file.py
CHANGED
@@ -44,7 +44,7 @@ def list_files(
|
|
44 |
if not path.exists():
|
45 |
raise FileNotFoundError(f"Directory {path} does not exist.")
|
46 |
|
47 |
-
files = [file for ext in extensions for file in path.
|
48 |
|
49 |
if sort:
|
50 |
files = natsorted(files)
|
|
|
44 |
if not path.exists():
|
45 |
raise FileNotFoundError(f"Directory {path} does not exist.")
|
46 |
|
47 |
+
files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
|
48 |
|
49 |
if sort:
|
50 |
files = natsorted(files)
|
fish_speech/utils/rich_utils.py
CHANGED
@@ -43,9 +43,13 @@ def print_config_tree(
|
|
43 |
|
44 |
# add fields from `print_order` to queue
|
45 |
for field in print_order:
|
46 |
-
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
49 |
)
|
50 |
|
51 |
# add all the other fields to queue (not specified in `print_order`)
|
|
|
43 |
|
44 |
# add fields from `print_order` to queue
|
45 |
for field in print_order:
|
46 |
+
(
|
47 |
+
queue.append(field)
|
48 |
+
if field in cfg
|
49 |
+
else log.warning(
|
50 |
+
f"Field '{field}' not found in config. "
|
51 |
+
+ f"Skipping '{field}' config printing..."
|
52 |
+
)
|
53 |
)
|
54 |
|
55 |
# add all the other fields to queue (not specified in `print_order`)
|
fish_speech/utils/spectrogram.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchaudio.functional as F
|
3 |
+
from torch import Tensor, nn
|
4 |
+
from torchaudio.transforms import MelScale
|
5 |
+
|
6 |
+
|
7 |
+
class LinearSpectrogram(nn.Module):
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
n_fft=2048,
|
11 |
+
win_length=2048,
|
12 |
+
hop_length=512,
|
13 |
+
center=False,
|
14 |
+
mode="pow2_sqrt",
|
15 |
+
):
|
16 |
+
super().__init__()
|
17 |
+
|
18 |
+
self.n_fft = n_fft
|
19 |
+
self.win_length = win_length
|
20 |
+
self.hop_length = hop_length
|
21 |
+
self.center = center
|
22 |
+
self.mode = mode
|
23 |
+
|
24 |
+
self.register_buffer("window", torch.hann_window(win_length), persistent=False)
|
25 |
+
|
26 |
+
def forward(self, y: Tensor) -> Tensor:
|
27 |
+
if y.ndim == 3:
|
28 |
+
y = y.squeeze(1)
|
29 |
+
|
30 |
+
y = torch.nn.functional.pad(
|
31 |
+
y.unsqueeze(1),
|
32 |
+
(
|
33 |
+
(self.win_length - self.hop_length) // 2,
|
34 |
+
(self.win_length - self.hop_length + 1) // 2,
|
35 |
+
),
|
36 |
+
mode="reflect",
|
37 |
+
).squeeze(1)
|
38 |
+
|
39 |
+
spec = torch.stft(
|
40 |
+
y,
|
41 |
+
self.n_fft,
|
42 |
+
hop_length=self.hop_length,
|
43 |
+
win_length=self.win_length,
|
44 |
+
window=self.window,
|
45 |
+
center=self.center,
|
46 |
+
pad_mode="reflect",
|
47 |
+
normalized=False,
|
48 |
+
onesided=True,
|
49 |
+
return_complex=True,
|
50 |
+
)
|
51 |
+
|
52 |
+
spec = torch.view_as_real(spec)
|
53 |
+
|
54 |
+
if self.mode == "pow2_sqrt":
|
55 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
56 |
+
|
57 |
+
return spec
|
58 |
+
|
59 |
+
|
60 |
+
class LogMelSpectrogram(nn.Module):
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
sample_rate=44100,
|
64 |
+
n_fft=2048,
|
65 |
+
win_length=2048,
|
66 |
+
hop_length=512,
|
67 |
+
n_mels=128,
|
68 |
+
center=False,
|
69 |
+
f_min=0.0,
|
70 |
+
f_max=None,
|
71 |
+
):
|
72 |
+
super().__init__()
|
73 |
+
|
74 |
+
self.sample_rate = sample_rate
|
75 |
+
self.n_fft = n_fft
|
76 |
+
self.win_length = win_length
|
77 |
+
self.hop_length = hop_length
|
78 |
+
self.center = center
|
79 |
+
self.n_mels = n_mels
|
80 |
+
self.f_min = f_min
|
81 |
+
self.f_max = f_max or float(sample_rate // 2)
|
82 |
+
|
83 |
+
self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
|
84 |
+
|
85 |
+
fb = F.melscale_fbanks(
|
86 |
+
n_freqs=self.n_fft // 2 + 1,
|
87 |
+
f_min=self.f_min,
|
88 |
+
f_max=self.f_max,
|
89 |
+
n_mels=self.n_mels,
|
90 |
+
sample_rate=self.sample_rate,
|
91 |
+
norm="slaney",
|
92 |
+
mel_scale="slaney",
|
93 |
+
)
|
94 |
+
self.register_buffer(
|
95 |
+
"fb",
|
96 |
+
fb,
|
97 |
+
persistent=False,
|
98 |
+
)
|
99 |
+
|
100 |
+
def compress(self, x: Tensor) -> Tensor:
|
101 |
+
return torch.log(torch.clamp(x, min=1e-5))
|
102 |
+
|
103 |
+
def decompress(self, x: Tensor) -> Tensor:
|
104 |
+
return torch.exp(x)
|
105 |
+
|
106 |
+
def apply_mel_scale(self, x: Tensor) -> Tensor:
|
107 |
+
return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2)
|
108 |
+
|
109 |
+
def forward(
|
110 |
+
self, x: Tensor, return_linear: bool = False, sample_rate: int = None
|
111 |
+
) -> Tensor:
|
112 |
+
if sample_rate is not None and sample_rate != self.sample_rate:
|
113 |
+
x = F.resample(x, orig_freq=sample_rate, new_freq=self.sample_rate)
|
114 |
+
|
115 |
+
linear = self.spectrogram(x)
|
116 |
+
x = self.apply_mel_scale(linear)
|
117 |
+
x = self.compress(x)
|
118 |
+
|
119 |
+
if return_linear:
|
120 |
+
return x, self.compress(linear)
|
121 |
+
|
122 |
+
return x
|
tools/api.py
ADDED
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import io
|
3 |
+
import json
|
4 |
+
import queue
|
5 |
+
import random
|
6 |
+
import traceback
|
7 |
+
import wave
|
8 |
+
from argparse import ArgumentParser
|
9 |
+
from http import HTTPStatus
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import Annotated, Literal, Optional
|
12 |
+
|
13 |
+
import librosa
|
14 |
+
import numpy as np
|
15 |
+
import pyrootutils
|
16 |
+
import soundfile as sf
|
17 |
+
import torch
|
18 |
+
from kui.asgi import (
|
19 |
+
Body,
|
20 |
+
HTTPException,
|
21 |
+
HttpView,
|
22 |
+
JSONResponse,
|
23 |
+
Kui,
|
24 |
+
OpenAPI,
|
25 |
+
StreamResponse,
|
26 |
+
)
|
27 |
+
from kui.asgi.routing import MultimethodRoutes
|
28 |
+
from loguru import logger
|
29 |
+
from pydantic import BaseModel, Field
|
30 |
+
|
31 |
+
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
32 |
+
|
33 |
+
# from fish_speech.models.vqgan.lit_module import VQGAN
|
34 |
+
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
|
35 |
+
from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
|
36 |
+
from tools.llama.generate import (
|
37 |
+
GenerateRequest,
|
38 |
+
GenerateResponse,
|
39 |
+
WrappedGenerateResponse,
|
40 |
+
launch_thread_safe_queue,
|
41 |
+
)
|
42 |
+
from tools.vqgan.inference import load_model as load_decoder_model
|
43 |
+
|
44 |
+
|
45 |
+
def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
|
46 |
+
buffer = io.BytesIO()
|
47 |
+
|
48 |
+
with wave.open(buffer, "wb") as wav_file:
|
49 |
+
wav_file.setnchannels(channels)
|
50 |
+
wav_file.setsampwidth(bit_depth // 8)
|
51 |
+
wav_file.setframerate(sample_rate)
|
52 |
+
|
53 |
+
wav_header_bytes = buffer.getvalue()
|
54 |
+
buffer.close()
|
55 |
+
return wav_header_bytes
|
56 |
+
|
57 |
+
|
58 |
+
# Define utils for web server
|
59 |
+
async def http_execption_handler(exc: HTTPException):
|
60 |
+
return JSONResponse(
|
61 |
+
dict(
|
62 |
+
statusCode=exc.status_code,
|
63 |
+
message=exc.content,
|
64 |
+
error=HTTPStatus(exc.status_code).phrase,
|
65 |
+
),
|
66 |
+
exc.status_code,
|
67 |
+
exc.headers,
|
68 |
+
)
|
69 |
+
|
70 |
+
|
71 |
+
async def other_exception_handler(exc: "Exception"):
|
72 |
+
traceback.print_exc()
|
73 |
+
|
74 |
+
status = HTTPStatus.INTERNAL_SERVER_ERROR
|
75 |
+
return JSONResponse(
|
76 |
+
dict(statusCode=status, message=str(exc), error=status.phrase),
|
77 |
+
status,
|
78 |
+
)
|
79 |
+
|
80 |
+
|
81 |
+
def load_audio(reference_audio, sr):
|
82 |
+
if len(reference_audio) > 255 or not Path(reference_audio).exists():
|
83 |
+
try:
|
84 |
+
audio_data = base64.b64decode(reference_audio)
|
85 |
+
reference_audio = io.BytesIO(audio_data)
|
86 |
+
except base64.binascii.Error:
|
87 |
+
raise ValueError("Invalid path or base64 string")
|
88 |
+
|
89 |
+
audio, _ = librosa.load(reference_audio, sr=sr, mono=True)
|
90 |
+
return audio
|
91 |
+
|
92 |
+
|
93 |
+
def encode_reference(*, decoder_model, reference_audio, enable_reference_audio):
|
94 |
+
if enable_reference_audio and reference_audio is not None:
|
95 |
+
# Load audios, and prepare basic info here
|
96 |
+
reference_audio_content = load_audio(
|
97 |
+
reference_audio, decoder_model.spec_transform.sample_rate
|
98 |
+
)
|
99 |
+
|
100 |
+
audios = torch.from_numpy(reference_audio_content).to(decoder_model.device)[
|
101 |
+
None, None, :
|
102 |
+
]
|
103 |
+
audio_lengths = torch.tensor(
|
104 |
+
[audios.shape[2]], device=decoder_model.device, dtype=torch.long
|
105 |
+
)
|
106 |
+
logger.info(
|
107 |
+
f"Loaded audio with {audios.shape[2] / decoder_model.spec_transform.sample_rate:.2f} seconds"
|
108 |
+
)
|
109 |
+
|
110 |
+
# VQ Encoder
|
111 |
+
if isinstance(decoder_model, FireflyArchitecture):
|
112 |
+
prompt_tokens = decoder_model.encode(audios, audio_lengths)[0][0]
|
113 |
+
|
114 |
+
logger.info(f"Encoded prompt: {prompt_tokens.shape}")
|
115 |
+
else:
|
116 |
+
prompt_tokens = None
|
117 |
+
logger.info("No reference audio provided")
|
118 |
+
|
119 |
+
return prompt_tokens
|
120 |
+
|
121 |
+
|
122 |
+
def decode_vq_tokens(
|
123 |
+
*,
|
124 |
+
decoder_model,
|
125 |
+
codes,
|
126 |
+
):
|
127 |
+
feature_lengths = torch.tensor([codes.shape[1]], device=decoder_model.device)
|
128 |
+
logger.info(f"VQ features: {codes.shape}")
|
129 |
+
|
130 |
+
if isinstance(decoder_model, FireflyArchitecture):
|
131 |
+
# VQGAN Inference
|
132 |
+
return decoder_model.decode(
|
133 |
+
indices=codes[None],
|
134 |
+
feature_lengths=feature_lengths,
|
135 |
+
).squeeze()
|
136 |
+
|
137 |
+
raise ValueError(f"Unknown model type: {type(decoder_model)}")
|
138 |
+
|
139 |
+
|
140 |
+
routes = MultimethodRoutes(base_class=HttpView)
|
141 |
+
|
142 |
+
|
143 |
+
def get_random_paths(base_path, data, speaker, emotion):
|
144 |
+
if base_path and data and speaker and emotion and (Path(base_path).exists()):
|
145 |
+
if speaker in data and emotion in data[speaker]:
|
146 |
+
files = data[speaker][emotion]
|
147 |
+
lab_files = [f for f in files if f.endswith(".lab")]
|
148 |
+
wav_files = [f for f in files if f.endswith(".wav")]
|
149 |
+
|
150 |
+
if lab_files and wav_files:
|
151 |
+
selected_lab = random.choice(lab_files)
|
152 |
+
selected_wav = random.choice(wav_files)
|
153 |
+
|
154 |
+
lab_path = Path(base_path) / speaker / emotion / selected_lab
|
155 |
+
wav_path = Path(base_path) / speaker / emotion / selected_wav
|
156 |
+
if lab_path.exists() and wav_path.exists():
|
157 |
+
return lab_path, wav_path
|
158 |
+
|
159 |
+
return None, None
|
160 |
+
|
161 |
+
|
162 |
+
def load_json(json_file):
|
163 |
+
if not json_file:
|
164 |
+
logger.info("Not using a json file")
|
165 |
+
return None
|
166 |
+
try:
|
167 |
+
with open(json_file, "r", encoding="utf-8") as file:
|
168 |
+
data = json.load(file)
|
169 |
+
except FileNotFoundError:
|
170 |
+
logger.warning(f"ref json not found: {json_file}")
|
171 |
+
data = None
|
172 |
+
except Exception as e:
|
173 |
+
logger.warning(f"Loading json failed: {e}")
|
174 |
+
data = None
|
175 |
+
return data
|
176 |
+
|
177 |
+
|
178 |
+
class InvokeRequest(BaseModel):
|
179 |
+
text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
|
180 |
+
reference_text: Optional[str] = None
|
181 |
+
reference_audio: Optional[str] = None
|
182 |
+
max_new_tokens: int = 1024
|
183 |
+
chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
|
184 |
+
top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
185 |
+
repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
|
186 |
+
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|
187 |
+
emotion: Optional[str] = None
|
188 |
+
format: Literal["wav", "mp3", "flac"] = "wav"
|
189 |
+
streaming: bool = False
|
190 |
+
ref_json: Optional[str] = "ref_data.json"
|
191 |
+
ref_base: Optional[str] = "ref_data"
|
192 |
+
speaker: Optional[str] = None
|
193 |
+
|
194 |
+
|
195 |
+
def get_content_type(audio_format):
|
196 |
+
if audio_format == "wav":
|
197 |
+
return "audio/wav"
|
198 |
+
elif audio_format == "flac":
|
199 |
+
return "audio/flac"
|
200 |
+
elif audio_format == "mp3":
|
201 |
+
return "audio/mpeg"
|
202 |
+
else:
|
203 |
+
return "application/octet-stream"
|
204 |
+
|
205 |
+
|
206 |
+
@torch.inference_mode()
|
207 |
+
def inference(req: InvokeRequest):
|
208 |
+
# Parse reference audio aka prompt
|
209 |
+
prompt_tokens = None
|
210 |
+
|
211 |
+
ref_data = load_json(req.ref_json)
|
212 |
+
ref_base = req.ref_base
|
213 |
+
|
214 |
+
lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)
|
215 |
+
|
216 |
+
if lab_path and wav_path:
|
217 |
+
with open(lab_path, "r", encoding="utf-8") as lab_file:
|
218 |
+
ref_text = lab_file.read()
|
219 |
+
req.reference_audio = wav_path
|
220 |
+
req.reference_text = ref_text
|
221 |
+
logger.info("ref_path: " + str(wav_path))
|
222 |
+
logger.info("ref_text: " + ref_text)
|
223 |
+
|
224 |
+
# Parse reference audio aka prompt
|
225 |
+
prompt_tokens = encode_reference(
|
226 |
+
decoder_model=decoder_model,
|
227 |
+
reference_audio=req.reference_audio,
|
228 |
+
enable_reference_audio=req.reference_audio is not None,
|
229 |
+
)
|
230 |
+
logger.info(f"ref_text: {req.reference_text}")
|
231 |
+
# LLAMA Inference
|
232 |
+
request = dict(
|
233 |
+
device=decoder_model.device,
|
234 |
+
max_new_tokens=req.max_new_tokens,
|
235 |
+
text=req.text,
|
236 |
+
top_p=req.top_p,
|
237 |
+
repetition_penalty=req.repetition_penalty,
|
238 |
+
temperature=req.temperature,
|
239 |
+
compile=args.compile,
|
240 |
+
iterative_prompt=req.chunk_length > 0,
|
241 |
+
chunk_length=req.chunk_length,
|
242 |
+
max_length=2048,
|
243 |
+
prompt_tokens=prompt_tokens,
|
244 |
+
prompt_text=req.reference_text,
|
245 |
+
)
|
246 |
+
|
247 |
+
response_queue = queue.Queue()
|
248 |
+
llama_queue.put(
|
249 |
+
GenerateRequest(
|
250 |
+
request=request,
|
251 |
+
response_queue=response_queue,
|
252 |
+
)
|
253 |
+
)
|
254 |
+
|
255 |
+
if req.streaming:
|
256 |
+
yield wav_chunk_header()
|
257 |
+
|
258 |
+
segments = []
|
259 |
+
while True:
|
260 |
+
result: WrappedGenerateResponse = response_queue.get()
|
261 |
+
if result.status == "error":
|
262 |
+
raise result.response
|
263 |
+
break
|
264 |
+
|
265 |
+
result: GenerateResponse = result.response
|
266 |
+
if result.action == "next":
|
267 |
+
break
|
268 |
+
|
269 |
+
with torch.autocast(
|
270 |
+
device_type=decoder_model.device.type, dtype=args.precision
|
271 |
+
):
|
272 |
+
fake_audios = decode_vq_tokens(
|
273 |
+
decoder_model=decoder_model,
|
274 |
+
codes=result.codes,
|
275 |
+
)
|
276 |
+
|
277 |
+
fake_audios = fake_audios.float().cpu().numpy()
|
278 |
+
|
279 |
+
if req.streaming:
|
280 |
+
yield (fake_audios * 32768).astype(np.int16).tobytes()
|
281 |
+
else:
|
282 |
+
segments.append(fake_audios)
|
283 |
+
|
284 |
+
if req.streaming:
|
285 |
+
return
|
286 |
+
|
287 |
+
if len(segments) == 0:
|
288 |
+
raise HTTPException(
|
289 |
+
HTTPStatus.INTERNAL_SERVER_ERROR,
|
290 |
+
content="No audio generated, please check the input text.",
|
291 |
+
)
|
292 |
+
|
293 |
+
fake_audios = np.concatenate(segments, axis=0)
|
294 |
+
yield fake_audios
|
295 |
+
|
296 |
+
|
297 |
+
def auto_rerank_inference(req: InvokeRequest, use_auto_rerank: bool = True):
|
298 |
+
if not use_auto_rerank:
|
299 |
+
# 如果不使用 auto_rerank,直接调用原始的 inference 函数
|
300 |
+
return inference(req)
|
301 |
+
|
302 |
+
zh_model, en_model = load_model()
|
303 |
+
max_attempts = 5
|
304 |
+
best_wer = float("inf")
|
305 |
+
best_audio = None
|
306 |
+
|
307 |
+
for attempt in range(max_attempts):
|
308 |
+
# 调用原始的 inference 函数
|
309 |
+
audio_generator = inference(req)
|
310 |
+
fake_audios = next(audio_generator)
|
311 |
+
|
312 |
+
asr_result = batch_asr(
|
313 |
+
zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
|
314 |
+
)[0]
|
315 |
+
wer = calculate_wer(req.text, asr_result["text"])
|
316 |
+
|
317 |
+
if wer <= 0.1 and not asr_result["huge_gap"]:
|
318 |
+
return fake_audios
|
319 |
+
|
320 |
+
if wer < best_wer:
|
321 |
+
best_wer = wer
|
322 |
+
best_audio = fake_audios
|
323 |
+
|
324 |
+
if attempt == max_attempts - 1:
|
325 |
+
break
|
326 |
+
|
327 |
+
return best_audio
|
328 |
+
|
329 |
+
|
330 |
+
async def inference_async(req: InvokeRequest):
|
331 |
+
for chunk in inference(req):
|
332 |
+
yield chunk
|
333 |
+
|
334 |
+
|
335 |
+
async def buffer_to_async_generator(buffer):
|
336 |
+
yield buffer
|
337 |
+
|
338 |
+
|
339 |
+
@routes.http.post("/v1/invoke")
|
340 |
+
async def api_invoke_model(
|
341 |
+
req: Annotated[InvokeRequest, Body(exclusive=True)],
|
342 |
+
):
|
343 |
+
"""
|
344 |
+
Invoke model and generate audio
|
345 |
+
"""
|
346 |
+
|
347 |
+
if args.max_text_length > 0 and len(req.text) > args.max_text_length:
|
348 |
+
raise HTTPException(
|
349 |
+
HTTPStatus.BAD_REQUEST,
|
350 |
+
content=f"Text is too long, max length is {args.max_text_length}",
|
351 |
+
)
|
352 |
+
|
353 |
+
if req.streaming and req.format != "wav":
|
354 |
+
raise HTTPException(
|
355 |
+
HTTPStatus.BAD_REQUEST,
|
356 |
+
content="Streaming only supports WAV format",
|
357 |
+
)
|
358 |
+
|
359 |
+
if req.streaming:
|
360 |
+
return StreamResponse(
|
361 |
+
iterable=inference_async(req),
|
362 |
+
headers={
|
363 |
+
"Content-Disposition": f"attachment; filename=audio.{req.format}",
|
364 |
+
},
|
365 |
+
content_type=get_content_type(req.format),
|
366 |
+
)
|
367 |
+
else:
|
368 |
+
fake_audios = next(inference(req))
|
369 |
+
buffer = io.BytesIO()
|
370 |
+
sf.write(
|
371 |
+
buffer,
|
372 |
+
fake_audios,
|
373 |
+
decoder_model.spec_transform.sample_rate,
|
374 |
+
format=req.format,
|
375 |
+
)
|
376 |
+
|
377 |
+
return StreamResponse(
|
378 |
+
iterable=buffer_to_async_generator(buffer.getvalue()),
|
379 |
+
headers={
|
380 |
+
"Content-Disposition": f"attachment; filename=audio.{req.format}",
|
381 |
+
},
|
382 |
+
content_type=get_content_type(req.format),
|
383 |
+
)
|
384 |
+
|
385 |
+
|
386 |
+
@routes.http.post("/v1/health")
|
387 |
+
async def api_health():
|
388 |
+
"""
|
389 |
+
Health check
|
390 |
+
"""
|
391 |
+
|
392 |
+
return JSONResponse({"status": "ok"})
|
393 |
+
|
394 |
+
|
395 |
+
def parse_args():
|
396 |
+
parser = ArgumentParser()
|
397 |
+
parser.add_argument(
|
398 |
+
"--llama-checkpoint-path",
|
399 |
+
type=str,
|
400 |
+
default="checkpoints/fish-speech-1.2-sft",
|
401 |
+
)
|
402 |
+
parser.add_argument(
|
403 |
+
"--decoder-checkpoint-path",
|
404 |
+
type=str,
|
405 |
+
default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
|
406 |
+
)
|
407 |
+
parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
|
408 |
+
parser.add_argument("--device", type=str, default="cuda")
|
409 |
+
parser.add_argument("--half", action="store_true")
|
410 |
+
parser.add_argument("--compile", action="store_true")
|
411 |
+
parser.add_argument("--max-text-length", type=int, default=0)
|
412 |
+
parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
|
413 |
+
parser.add_argument("--workers", type=int, default=1)
|
414 |
+
parser.add_argument("--use-auto-rerank", type=bool, default=True)
|
415 |
+
|
416 |
+
return parser.parse_args()
|
417 |
+
|
418 |
+
|
419 |
+
# Define Kui app
|
420 |
+
openapi = OpenAPI(
|
421 |
+
{
|
422 |
+
"title": "Fish Speech API",
|
423 |
+
},
|
424 |
+
).routes
|
425 |
+
|
426 |
+
app = Kui(
|
427 |
+
routes=routes + openapi[1:], # Remove the default route
|
428 |
+
exception_handlers={
|
429 |
+
HTTPException: http_execption_handler,
|
430 |
+
Exception: other_exception_handler,
|
431 |
+
},
|
432 |
+
cors_config={},
|
433 |
+
)
|
434 |
+
|
435 |
+
|
436 |
+
if __name__ == "__main__":
|
437 |
+
import threading
|
438 |
+
|
439 |
+
import uvicorn
|
440 |
+
|
441 |
+
args = parse_args()
|
442 |
+
args.precision = torch.half if args.half else torch.bfloat16
|
443 |
+
|
444 |
+
logger.info("Loading Llama model...")
|
445 |
+
llama_queue = launch_thread_safe_queue(
|
446 |
+
checkpoint_path=args.llama_checkpoint_path,
|
447 |
+
device=args.device,
|
448 |
+
precision=args.precision,
|
449 |
+
compile=args.compile,
|
450 |
+
)
|
451 |
+
logger.info("Llama model loaded, loading VQ-GAN model...")
|
452 |
+
|
453 |
+
decoder_model = load_decoder_model(
|
454 |
+
config_name=args.decoder_config_name,
|
455 |
+
checkpoint_path=args.decoder_checkpoint_path,
|
456 |
+
device=args.device,
|
457 |
+
)
|
458 |
+
|
459 |
+
logger.info("VQ-GAN model loaded, warming up...")
|
460 |
+
|
461 |
+
# Dry run to check if the model is loaded correctly and avoid the first-time latency
|
462 |
+
list(
|
463 |
+
inference(
|
464 |
+
InvokeRequest(
|
465 |
+
text="Hello world.",
|
466 |
+
reference_text=None,
|
467 |
+
reference_audio=None,
|
468 |
+
max_new_tokens=0,
|
469 |
+
top_p=0.7,
|
470 |
+
repetition_penalty=1.2,
|
471 |
+
temperature=0.7,
|
472 |
+
emotion=None,
|
473 |
+
format="wav",
|
474 |
+
ref_base=None,
|
475 |
+
ref_json=None,
|
476 |
+
)
|
477 |
+
)
|
478 |
+
)
|
479 |
+
|
480 |
+
logger.info(f"Warming up done, starting server at http://{args.listen}")
|
481 |
+
host, port = args.listen.split(":")
|
482 |
+
uvicorn.run(app, host=host, port=int(port), workers=args.workers, log_level="info")
|
tools/auto_rerank.py
ADDED
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
os.environ["MODELSCOPE_CACHE"] = ".cache/"
|
4 |
+
|
5 |
+
import string
|
6 |
+
import time
|
7 |
+
from threading import Lock
|
8 |
+
|
9 |
+
import librosa
|
10 |
+
import numpy as np
|
11 |
+
import opencc
|
12 |
+
import torch
|
13 |
+
from faster_whisper import WhisperModel
|
14 |
+
|
15 |
+
t2s_converter = opencc.OpenCC("t2s")
|
16 |
+
|
17 |
+
|
18 |
+
def load_model(*, device="cuda"):
|
19 |
+
model = WhisperModel(
|
20 |
+
"medium",
|
21 |
+
device=device,
|
22 |
+
compute_type="float16",
|
23 |
+
download_root="faster_whisper",
|
24 |
+
)
|
25 |
+
print("faster_whisper loaded!")
|
26 |
+
return model
|
27 |
+
|
28 |
+
|
29 |
+
@torch.no_grad()
|
30 |
+
def batch_asr_internal(model: WhisperModel, audios, sr):
|
31 |
+
resampled_audios = []
|
32 |
+
for audio in audios:
|
33 |
+
|
34 |
+
if isinstance(audio, np.ndarray):
|
35 |
+
audio = torch.from_numpy(audio).float()
|
36 |
+
|
37 |
+
if audio.dim() > 1:
|
38 |
+
audio = audio.squeeze()
|
39 |
+
|
40 |
+
assert audio.dim() == 1
|
41 |
+
audio_np = audio.numpy()
|
42 |
+
resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
|
43 |
+
resampled_audios.append(resampled_audio)
|
44 |
+
|
45 |
+
trans_results = []
|
46 |
+
|
47 |
+
for resampled_audio in resampled_audios:
|
48 |
+
segments, info = model.transcribe(
|
49 |
+
resampled_audio,
|
50 |
+
language=None,
|
51 |
+
beam_size=5,
|
52 |
+
initial_prompt="Punctuation is needed in any language.",
|
53 |
+
)
|
54 |
+
trans_results.append(list(segments))
|
55 |
+
|
56 |
+
results = []
|
57 |
+
for trans_res, audio in zip(trans_results, audios):
|
58 |
+
|
59 |
+
duration = len(audio) / sr * 1000
|
60 |
+
huge_gap = False
|
61 |
+
max_gap = 0.0
|
62 |
+
|
63 |
+
text = None
|
64 |
+
last_tr = None
|
65 |
+
|
66 |
+
for tr in trans_res:
|
67 |
+
delta = tr.text.strip()
|
68 |
+
if tr.id > 1:
|
69 |
+
max_gap = max(tr.start - last_tr.end, max_gap)
|
70 |
+
text += delta
|
71 |
+
else:
|
72 |
+
text = delta
|
73 |
+
|
74 |
+
last_tr = tr
|
75 |
+
if max_gap > 3.0:
|
76 |
+
huge_gap = True
|
77 |
+
break
|
78 |
+
|
79 |
+
sim_text = t2s_converter.convert(text)
|
80 |
+
results.append(
|
81 |
+
{
|
82 |
+
"text": sim_text,
|
83 |
+
"duration": duration,
|
84 |
+
"huge_gap": huge_gap,
|
85 |
+
}
|
86 |
+
)
|
87 |
+
|
88 |
+
return results
|
89 |
+
|
90 |
+
|
91 |
+
global_lock = Lock()
|
92 |
+
|
93 |
+
|
94 |
+
def batch_asr(model, audios, sr):
|
95 |
+
return batch_asr_internal(model, audios, sr)
|
96 |
+
|
97 |
+
|
98 |
+
def is_chinese(text):
|
99 |
+
return True
|
100 |
+
|
101 |
+
|
102 |
+
def calculate_wer(text1, text2, debug=False):
|
103 |
+
chars1 = remove_punctuation(text1)
|
104 |
+
chars2 = remove_punctuation(text2)
|
105 |
+
|
106 |
+
m, n = len(chars1), len(chars2)
|
107 |
+
|
108 |
+
if m > n:
|
109 |
+
chars1, chars2 = chars2, chars1
|
110 |
+
m, n = n, m
|
111 |
+
|
112 |
+
prev = list(range(m + 1)) # row 0 distance: [0, 1, 2, ...]
|
113 |
+
curr = [0] * (m + 1)
|
114 |
+
|
115 |
+
for j in range(1, n + 1):
|
116 |
+
curr[0] = j
|
117 |
+
for i in range(1, m + 1):
|
118 |
+
if chars1[i - 1] == chars2[j - 1]:
|
119 |
+
curr[i] = prev[i - 1]
|
120 |
+
else:
|
121 |
+
curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
|
122 |
+
prev, curr = curr, prev
|
123 |
+
|
124 |
+
edits = prev[m]
|
125 |
+
tot = max(len(chars1), len(chars2))
|
126 |
+
wer = edits / tot
|
127 |
+
|
128 |
+
if debug:
|
129 |
+
print(" gt: ", chars1)
|
130 |
+
print(" pred: ", chars2)
|
131 |
+
print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
|
132 |
+
|
133 |
+
return wer
|
134 |
+
|
135 |
+
|
136 |
+
def remove_punctuation(text):
|
137 |
+
chinese_punctuation = (
|
138 |
+
" \n\t”“!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
|
139 |
+
'‛""„‟…‧﹏'
|
140 |
+
)
|
141 |
+
all_punctuation = string.punctuation + chinese_punctuation
|
142 |
+
translator = str.maketrans("", "", all_punctuation)
|
143 |
+
text_without_punctuation = text.translate(translator)
|
144 |
+
return text_without_punctuation
|
145 |
+
|
146 |
+
|
147 |
+
if __name__ == "__main__":
|
148 |
+
model = load_model()
|
149 |
+
audios = [
|
150 |
+
librosa.load("44100.wav", sr=44100)[0],
|
151 |
+
librosa.load("lengyue.wav", sr=44100)[0],
|
152 |
+
]
|
153 |
+
print(np.array(audios[0]))
|
154 |
+
print(batch_asr(model, audios, 44100))
|
155 |
+
|
156 |
+
start_time = time.time()
|
157 |
+
for _ in range(10):
|
158 |
+
print(batch_asr(model, audios, 44100))
|
159 |
+
print("Time taken:", time.time() - start_time)
|
tools/llama/build_dataset.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
from collections import defaultdict
|
5 |
+
from functools import partial
|
6 |
+
from multiprocessing import Pool
|
7 |
+
from pathlib import Path
|
8 |
+
|
9 |
+
import click
|
10 |
+
import numpy as np
|
11 |
+
from loguru import logger
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
|
15 |
+
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
|
16 |
+
from fish_speech.utils.file import load_filelist
|
17 |
+
|
18 |
+
# To avoid CPU overload
|
19 |
+
os.environ["MKL_NUM_THREADS"] = "1"
|
20 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
21 |
+
|
22 |
+
|
23 |
+
def task_generator_folder(root: Path, text_extension: str):
|
24 |
+
files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
|
25 |
+
files = sorted(files)
|
26 |
+
|
27 |
+
grouped_files = defaultdict(list)
|
28 |
+
for file in tqdm(files, desc=f"Grouping {root}"):
|
29 |
+
p = str(file.parent)
|
30 |
+
speaker = file.parent.name
|
31 |
+
|
32 |
+
try:
|
33 |
+
if isinstance(text_extension, str):
|
34 |
+
texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
|
35 |
+
else:
|
36 |
+
texts = [
|
37 |
+
file.with_suffix(ext).read_text(encoding="utf-8")
|
38 |
+
for ext in text_extension
|
39 |
+
]
|
40 |
+
except Exception as e:
|
41 |
+
logger.error(f"Failed to read text {file}: {e}")
|
42 |
+
continue
|
43 |
+
|
44 |
+
grouped_files[p].append((speaker, file, texts))
|
45 |
+
|
46 |
+
logger.info(
|
47 |
+
f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
|
48 |
+
)
|
49 |
+
|
50 |
+
for i in grouped_files.values():
|
51 |
+
subset = [(f, t) for _, f, t in i]
|
52 |
+
yield i[0][0], subset, "folder"
|
53 |
+
|
54 |
+
|
55 |
+
def task_generator_filelist(filelist):
|
56 |
+
grouped_files = defaultdict(list)
|
57 |
+
for filename, speaker, _, text in load_filelist(filelist):
|
58 |
+
grouped_files[speaker].append((Path(filename), [text]))
|
59 |
+
|
60 |
+
logger.info(f"Found {len(grouped_files)} groups in {filelist}")
|
61 |
+
for speaker, values in grouped_files.items():
|
62 |
+
yield speaker, values, "filelist"
|
63 |
+
|
64 |
+
|
65 |
+
def run_task(task):
|
66 |
+
name, subset, source = task
|
67 |
+
|
68 |
+
# Parse the files
|
69 |
+
sentences = []
|
70 |
+
for file, texts in subset:
|
71 |
+
np_file = file.with_suffix(".npy")
|
72 |
+
if np_file.exists() is False:
|
73 |
+
logger.warning(f"Can't find {np_file}")
|
74 |
+
continue
|
75 |
+
|
76 |
+
new_texts = []
|
77 |
+
|
78 |
+
for text in texts:
|
79 |
+
# Simple cleaning: replace { xxx } and < xxx > with space
|
80 |
+
text = re.sub(r"\{.*?\}", " ", text)
|
81 |
+
text = re.sub(r"<.*?>", " ", text)
|
82 |
+
text = re.sub(r"\s+", " ", text)
|
83 |
+
new_texts.append(text)
|
84 |
+
|
85 |
+
try:
|
86 |
+
semantics = np.load(np_file)
|
87 |
+
except Exception as e:
|
88 |
+
logger.error(f"Failed to parse {file}: {e}")
|
89 |
+
continue
|
90 |
+
|
91 |
+
if isinstance(semantics, np.ndarray):
|
92 |
+
semantics = semantics.tolist()
|
93 |
+
|
94 |
+
sentences.append(
|
95 |
+
Sentence(
|
96 |
+
texts=new_texts,
|
97 |
+
semantics=[Semantics(values=s) for s in semantics],
|
98 |
+
)
|
99 |
+
)
|
100 |
+
|
101 |
+
# Pack the sentences
|
102 |
+
return pack_pb_stream(
|
103 |
+
TextData(
|
104 |
+
source=source,
|
105 |
+
name=name,
|
106 |
+
sentences=sentences,
|
107 |
+
)
|
108 |
+
)
|
109 |
+
|
110 |
+
|
111 |
+
@click.command()
|
112 |
+
@click.option(
|
113 |
+
"--input",
|
114 |
+
type=click.Path(path_type=Path),
|
115 |
+
required=True,
|
116 |
+
help="A folder containing the dataset or a filelist",
|
117 |
+
multiple=True,
|
118 |
+
)
|
119 |
+
@click.option(
|
120 |
+
"--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
|
121 |
+
)
|
122 |
+
@click.option("--num-workers", type=int, default=16)
|
123 |
+
@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
|
124 |
+
@click.option(
|
125 |
+
"--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
|
126 |
+
)
|
127 |
+
def main(input, output, num_workers, text_extension, shard_size):
|
128 |
+
generator_fns = []
|
129 |
+
|
130 |
+
for f in input:
|
131 |
+
assert f.exists(), f"{f} not found"
|
132 |
+
|
133 |
+
if f.is_dir():
|
134 |
+
generator_fn = task_generator_folder(f, text_extension)
|
135 |
+
else:
|
136 |
+
generator_fn = task_generator_filelist(f)
|
137 |
+
|
138 |
+
generator_fns.append(generator_fn)
|
139 |
+
|
140 |
+
generator_fn = itertools.chain(*generator_fns)
|
141 |
+
output.mkdir(parents=True, exist_ok=True)
|
142 |
+
|
143 |
+
dataset_fp = None
|
144 |
+
tar_idx = 0
|
145 |
+
written_size = 0
|
146 |
+
|
147 |
+
with Pool(num_workers) as p:
|
148 |
+
for result in tqdm(p.imap_unordered(run_task, generator_fn)):
|
149 |
+
if dataset_fp is None:
|
150 |
+
dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")
|
151 |
+
|
152 |
+
dataset_fp.write(result)
|
153 |
+
written_size += len(result)
|
154 |
+
|
155 |
+
if written_size > shard_size * 1024 * 1024:
|
156 |
+
logger.info(f"Finished writing {tar_idx} shards to {output}")
|
157 |
+
dataset_fp.close()
|
158 |
+
dataset_fp = None
|
159 |
+
written_size = 0
|
160 |
+
tar_idx += 1
|
161 |
+
|
162 |
+
if dataset_fp is not None:
|
163 |
+
dataset_fp.close()
|
164 |
+
|
165 |
+
logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
|
166 |
+
|
167 |
+
|
168 |
+
if __name__ == "__main__":
|
169 |
+
main()
|
tools/llama/eval_in_context.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pyrootutils
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from matplotlib import pyplot as plt
|
5 |
+
from transformers import AutoTokenizer
|
6 |
+
|
7 |
+
# register eval resolver and root
|
8 |
+
pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
9 |
+
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
+
|
12 |
+
from fish_speech.datasets.semantic import AutoAugTextDataset, TextDataCollator
|
13 |
+
from tools.llama.generate import load_model
|
14 |
+
|
15 |
+
|
16 |
+
def smooth(
|
17 |
+
scalars: list[float], weight: float
|
18 |
+
) -> list[float]: # Weight between 0 and 1
|
19 |
+
last = scalars[0] # First value in the plot (first timestep)
|
20 |
+
smoothed = list()
|
21 |
+
for point in scalars:
|
22 |
+
smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value
|
23 |
+
smoothed.append(smoothed_val) # Save it
|
24 |
+
last = smoothed_val # Anchor the last smoothed value
|
25 |
+
|
26 |
+
return smoothed
|
27 |
+
|
28 |
+
|
29 |
+
@torch.inference_mode()
|
30 |
+
def analyze_one_model(loader, config, weight, max_length):
|
31 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
32 |
+
model = load_model(
|
33 |
+
config,
|
34 |
+
weight,
|
35 |
+
device,
|
36 |
+
torch.bfloat16,
|
37 |
+
max_length,
|
38 |
+
compile=False,
|
39 |
+
)[0]
|
40 |
+
|
41 |
+
current_step = 0
|
42 |
+
model.eval()
|
43 |
+
|
44 |
+
semantic_loss_sum = torch.zeros(
|
45 |
+
max_length,
|
46 |
+
dtype=torch.float32,
|
47 |
+
device=device,
|
48 |
+
)
|
49 |
+
counter = torch.zeros(
|
50 |
+
max_length,
|
51 |
+
dtype=torch.long,
|
52 |
+
device=device,
|
53 |
+
)
|
54 |
+
|
55 |
+
for batch in loader:
|
56 |
+
batch = {k: v.to(device) for k, v in batch.items()}
|
57 |
+
|
58 |
+
labels = batch["labels"]
|
59 |
+
outputs = model(
|
60 |
+
inp=batch["inputs"],
|
61 |
+
key_padding_mask=batch["attention_masks"],
|
62 |
+
)
|
63 |
+
|
64 |
+
token_logits = outputs.token_logits
|
65 |
+
codebook_logits = outputs.codebook_logits
|
66 |
+
|
67 |
+
# Generate labels
|
68 |
+
base_loss = F.cross_entropy(
|
69 |
+
token_logits.reshape(-1, token_logits.size(-1)),
|
70 |
+
labels[:, 0].reshape(-1),
|
71 |
+
ignore_index=-100,
|
72 |
+
reduction="none",
|
73 |
+
)
|
74 |
+
|
75 |
+
codebook_labels = labels[:, 1 : 1 + model.config.num_codebooks].mT
|
76 |
+
semantic_loss = F.cross_entropy(
|
77 |
+
codebook_logits.reshape(-1, codebook_logits.size(-1)),
|
78 |
+
codebook_labels.reshape(-1),
|
79 |
+
ignore_index=-100,
|
80 |
+
reduction="none",
|
81 |
+
)
|
82 |
+
|
83 |
+
base_loss = base_loss.reshape(labels[:, 0].shape)
|
84 |
+
semantic_loss = semantic_loss.reshape(codebook_labels.shape)
|
85 |
+
|
86 |
+
semantic_loss_frame = semantic_loss.mean(-1)
|
87 |
+
pad_pos = codebook_labels.sum(-1) == -100 * model.config.num_codebooks
|
88 |
+
|
89 |
+
for loss_sample, pad in zip(semantic_loss_frame, pad_pos):
|
90 |
+
semantic_loss_sum[~pad] += loss_sample[~pad]
|
91 |
+
counter[~pad] += 1
|
92 |
+
|
93 |
+
current_step += 1
|
94 |
+
if current_step == 10:
|
95 |
+
break
|
96 |
+
|
97 |
+
semantic_loss = semantic_loss.cpu()
|
98 |
+
counter = counter.cpu()
|
99 |
+
xs, ys = [], []
|
100 |
+
|
101 |
+
for i, (loss, count) in enumerate(zip(semantic_loss_sum, counter)):
|
102 |
+
if count > 0:
|
103 |
+
xs.append(i)
|
104 |
+
ys.append((loss / count).item()) # for better loss visualization
|
105 |
+
|
106 |
+
smoothed_ys = smooth(ys, 0.95)
|
107 |
+
|
108 |
+
# Unload model
|
109 |
+
del model
|
110 |
+
torch.cuda.empty_cache()
|
111 |
+
|
112 |
+
return xs, ys, smoothed_ys
|
113 |
+
|
114 |
+
|
115 |
+
def main():
|
116 |
+
tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
|
117 |
+
max_length = 4096
|
118 |
+
|
119 |
+
ds = AutoAugTextDataset(
|
120 |
+
["data/protos/sft/云天河"],
|
121 |
+
tokenizer=tokenizer,
|
122 |
+
use_speaker=False,
|
123 |
+
interactive_prob=1.0,
|
124 |
+
max_length=max_length,
|
125 |
+
)
|
126 |
+
|
127 |
+
loader = DataLoader(
|
128 |
+
ds,
|
129 |
+
batch_size=8,
|
130 |
+
collate_fn=TextDataCollator(tokenizer, max_length=max_length),
|
131 |
+
num_workers=0,
|
132 |
+
shuffle=False,
|
133 |
+
)
|
134 |
+
|
135 |
+
plt.figure(figsize=(10, 5), dpi=200)
|
136 |
+
|
137 |
+
plt.xlabel("Frame")
|
138 |
+
plt.ylabel("Loss")
|
139 |
+
plt.yscale("log")
|
140 |
+
plt.title("Semantic Loss")
|
141 |
+
plt.grid(which="both", axis="both")
|
142 |
+
plt.xlim(0, max_length)
|
143 |
+
|
144 |
+
tests = [
|
145 |
+
(
|
146 |
+
"pertrain-medium",
|
147 |
+
"dual_ar_2_codebook_medium",
|
148 |
+
"checkpoints/text2semantic-pretrain-medium-2k-v1.pth",
|
149 |
+
),
|
150 |
+
(
|
151 |
+
"sft-medium",
|
152 |
+
"dual_ar_2_codebook_medium",
|
153 |
+
"checkpoints/text2semantic-sft-medium-v1.1-4k.pth",
|
154 |
+
),
|
155 |
+
(
|
156 |
+
"sft-large",
|
157 |
+
"dual_ar_2_codebook_large",
|
158 |
+
"checkpoints/text2semantic-sft-large-v1.1-4k.pth",
|
159 |
+
),
|
160 |
+
]
|
161 |
+
|
162 |
+
for name, config, weight in tests:
|
163 |
+
xs, _, smoothed_ys = analyze_one_model(loader, config, weight, max_length)
|
164 |
+
plt.plot(xs, smoothed_ys, label=name)
|
165 |
+
|
166 |
+
plt.legend()
|
167 |
+
plt.savefig("semantic_loss.png")
|
168 |
+
|
169 |
+
|
170 |
+
if __name__ == "__main__":
|
171 |
+
main()
|
tools/llama/generate.py
CHANGED
@@ -2,8 +2,9 @@ import os
|
|
2 |
import queue
|
3 |
import threading
|
4 |
import time
|
|
|
5 |
from pathlib import Path
|
6 |
-
from typing import Optional, Tuple, Union
|
7 |
|
8 |
import click
|
9 |
import hydra
|
@@ -11,14 +12,11 @@ import numpy as np
|
|
11 |
import torch
|
12 |
import torch._dynamo.config
|
13 |
import torch._inductor.config
|
14 |
-
from hydra import compose, initialize
|
15 |
-
from hydra.utils import instantiate
|
16 |
from loguru import logger
|
17 |
from tqdm import tqdm
|
18 |
-
from transformers import AutoTokenizer
|
19 |
|
20 |
-
from fish_speech.
|
21 |
-
from fish_speech.text
|
22 |
|
23 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
24 |
torch._inductor.config.coordinate_descent_tuning = True
|
@@ -29,7 +27,11 @@ if hasattr(torch._inductor.config, "fx_graph_cache"):
|
|
29 |
torch._inductor.config.fx_graph_cache = True
|
30 |
|
31 |
|
32 |
-
from fish_speech.models.text2semantic.llama import
|
|
|
|
|
|
|
|
|
33 |
|
34 |
|
35 |
def multinomial_sample_one_no_sync(
|
@@ -94,7 +96,9 @@ def decode_one_token_ar(
|
|
94 |
codebooks = [
|
95 |
sample(
|
96 |
x.logits,
|
97 |
-
previous_tokens=
|
|
|
|
|
98 |
**sampling_kwargs,
|
99 |
)[0]
|
100 |
]
|
@@ -159,7 +163,6 @@ def decode_n_tokens(
|
|
159 |
cur_token: torch.Tensor,
|
160 |
input_pos: torch.Tensor,
|
161 |
num_new_tokens: int,
|
162 |
-
eos_token_id: int = 2,
|
163 |
im_end_id: int = 4,
|
164 |
decode_one_token=decode_one_token_naive,
|
165 |
**sampling_kwargs,
|
@@ -195,11 +198,7 @@ def decode_n_tokens(
|
|
195 |
model.config.num_codebooks + 1, -1
|
196 |
)
|
197 |
|
198 |
-
if
|
199 |
-
cur_token[0, 0, -1] == eos_token_id
|
200 |
-
or cur_token[0, 0, -1] == im_end_id
|
201 |
-
or (cur_token[0, 1:, -1] == CODEBOOK_EOS_TOKEN_ID).any()
|
202 |
-
):
|
203 |
break
|
204 |
|
205 |
return previous_tokens[:, : i + 1]
|
@@ -212,7 +211,6 @@ def generate(
|
|
212 |
model: NaiveTransformer,
|
213 |
prompt: torch.Tensor,
|
214 |
max_new_tokens: int,
|
215 |
-
eos_token_id: int = 2,
|
216 |
im_end_id: int = 4,
|
217 |
decode_one_token=decode_one_token_naive,
|
218 |
**sampling_kwargs,
|
@@ -253,6 +251,7 @@ def generate(
|
|
253 |
if isinstance(model, NaiveTransformer)
|
254 |
else decode_one_token_ar
|
255 |
)
|
|
|
256 |
next_token = prefill_decode(
|
257 |
model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
|
258 |
)
|
@@ -264,7 +263,6 @@ def generate(
|
|
264 |
next_token.view(1, codebook_dim, -1),
|
265 |
input_pos,
|
266 |
max_new_tokens - 1,
|
267 |
-
eos_token_id=eos_token_id,
|
268 |
im_end_id=im_end_id,
|
269 |
decode_one_token=decode_one_token,
|
270 |
**sampling_kwargs,
|
@@ -279,22 +277,12 @@ def generate(
|
|
279 |
def encode_tokens(
|
280 |
tokenizer,
|
281 |
string,
|
282 |
-
bos=True,
|
283 |
device="cuda",
|
284 |
prompt_tokens=None,
|
285 |
-
speaker=None,
|
286 |
num_codebooks=4,
|
287 |
):
|
288 |
string = clean_text(string)
|
289 |
-
|
290 |
-
if speaker is None:
|
291 |
-
speaker = "assistant"
|
292 |
-
|
293 |
-
string = (
|
294 |
-
f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>{speaker}<|im_sep|>"
|
295 |
-
)
|
296 |
-
if bos:
|
297 |
-
string = f"<|begin_of_sequence|>{string}"
|
298 |
|
299 |
new_tokens = tokenizer.encode(
|
300 |
string,
|
@@ -322,7 +310,7 @@ def encode_tokens(
|
|
322 |
prompt_tokens = prompt_tokens[0]
|
323 |
|
324 |
assert prompt_tokens.ndim == 2
|
325 |
-
data = prompt_tokens +
|
326 |
|
327 |
if prompt_tokens.shape[0] > num_codebooks:
|
328 |
logger.warning(
|
@@ -330,13 +318,9 @@ def encode_tokens(
|
|
330 |
)
|
331 |
data = data[:num_codebooks]
|
332 |
|
333 |
-
# Add
|
334 |
data = torch.cat(
|
335 |
-
(
|
336 |
-
data,
|
337 |
-
torch.ones((data.size(0), 1), dtype=torch.int, device=device)
|
338 |
-
* CODEBOOK_EOS_TOKEN_ID,
|
339 |
-
),
|
340 |
dim=1,
|
341 |
)
|
342 |
|
@@ -354,49 +338,13 @@ def encode_tokens(
|
|
354 |
return prompt
|
355 |
|
356 |
|
357 |
-
def load_model(
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
with initialize(version_base="1.3", config_path="../../fish_speech/configs/model"):
|
362 |
-
cfg = compose(
|
363 |
-
config_name=config_name, overrides=[f"config.max_seq_len={max_length}"]
|
364 |
-
)
|
365 |
-
|
366 |
-
model: Union[NaiveTransformer, DualARTransformer] = instantiate(cfg)
|
367 |
-
|
368 |
-
if "int8" in str(checkpoint_path):
|
369 |
-
logger.info("Using int8 weight-only quantization!")
|
370 |
-
from quantize import WeightOnlyInt8QuantHandler
|
371 |
-
|
372 |
-
simple_quantizer = WeightOnlyInt8QuantHandler(model)
|
373 |
-
model = simple_quantizer.convert_for_runtime()
|
374 |
-
|
375 |
-
if "int4" in str(checkpoint_path):
|
376 |
-
logger.info("Using int4 quantization!")
|
377 |
-
path_comps = checkpoint_path.name.split(".")
|
378 |
-
assert path_comps[-2].startswith("g")
|
379 |
-
groupsize = int(path_comps[-2][1:])
|
380 |
-
from quantize import WeightOnlyInt4QuantHandler
|
381 |
-
|
382 |
-
simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
|
383 |
-
model = simple_quantizer.convert_for_runtime()
|
384 |
-
|
385 |
-
checkpoint = torch.load(str(checkpoint_path), map_location="cpu")
|
386 |
-
if "state_dict" in checkpoint:
|
387 |
-
checkpoint = checkpoint["state_dict"]
|
388 |
-
|
389 |
-
if any(k.startswith("model.") for k in checkpoint):
|
390 |
-
checkpoint = {
|
391 |
-
k.replace("model.", ""): v
|
392 |
-
for k, v in checkpoint.items()
|
393 |
-
if k.startswith("model.")
|
394 |
-
}
|
395 |
-
|
396 |
-
model.load_state_dict(checkpoint, assign=True)
|
397 |
|
398 |
model = model.to(device=device, dtype=precision)
|
399 |
-
logger.info("Restored model from checkpoint")
|
400 |
|
401 |
if isinstance(model, DualARTransformer):
|
402 |
decode_one_token = decode_one_token_ar
|
@@ -414,29 +362,16 @@ def load_model(
|
|
414 |
return model.eval(), decode_one_token
|
415 |
|
416 |
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
curr += char
|
423 |
-
if char not in [".", ",", "!", "?"]:
|
424 |
-
continue
|
425 |
-
|
426 |
-
if len(curr) >= min_length:
|
427 |
-
segments.append(curr)
|
428 |
-
curr = ""
|
429 |
-
|
430 |
-
if curr:
|
431 |
-
segments.append(curr)
|
432 |
-
|
433 |
-
return segments
|
434 |
|
435 |
|
436 |
def generate_long(
|
437 |
*,
|
438 |
model,
|
439 |
-
tokenizer: callable,
|
440 |
device: str | torch.device,
|
441 |
decode_one_token: callable,
|
442 |
text: str,
|
@@ -448,42 +383,49 @@ def generate_long(
|
|
448 |
compile: bool = False,
|
449 |
iterative_prompt: bool = True,
|
450 |
max_length: int = 2048,
|
451 |
-
chunk_length: int =
|
452 |
-
|
453 |
-
|
454 |
-
prompt_tokens: Optional[torch.Tensor] = None,
|
455 |
-
is_streaming: bool = False,
|
456 |
):
|
457 |
assert 0 < top_p <= 1, "top_p must be in (0, 1]"
|
458 |
assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
|
459 |
assert 0 < temperature < 2, "temperature must be in (0, 2)"
|
460 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
461 |
model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
|
462 |
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
463 |
|
464 |
-
use_prompt = prompt_text is not None and prompt_tokens is not None
|
465 |
encoded = []
|
466 |
texts = split_text(text, chunk_length) if iterative_prompt else [text]
|
|
|
467 |
|
468 |
if use_prompt:
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
|
|
478 |
|
479 |
for idx, text in enumerate(texts):
|
480 |
encoded.append(
|
481 |
encode_tokens(
|
482 |
tokenizer,
|
483 |
string=text,
|
484 |
-
bos=idx == 0 and not use_prompt,
|
485 |
device=device,
|
486 |
-
speaker=speaker,
|
487 |
num_codebooks=model.config.num_codebooks,
|
488 |
)
|
489 |
)
|
@@ -502,7 +444,6 @@ def generate_long(
|
|
502 |
torch.cuda.synchronize()
|
503 |
|
504 |
global_encoded = []
|
505 |
-
all_codes = []
|
506 |
seg_idx = 0
|
507 |
|
508 |
while seg_idx < len(encoded):
|
@@ -519,7 +460,9 @@ def generate_long(
|
|
519 |
count = 0
|
520 |
for i, length in enumerate(lengths):
|
521 |
count += length
|
522 |
-
if count + length > max_length - 1024
|
|
|
|
|
523 |
break
|
524 |
|
525 |
if i != 0 and i % 2 == 0:
|
@@ -532,7 +475,7 @@ def generate_long(
|
|
532 |
partial_encoded = global_encoded
|
533 |
|
534 |
if use_prompt:
|
535 |
-
partial_encoded =
|
536 |
|
537 |
cat_encoded = torch.cat(partial_encoded, dim=1)
|
538 |
prompt_length = cat_encoded.size(1)
|
@@ -542,7 +485,6 @@ def generate_long(
|
|
542 |
model=model,
|
543 |
prompt=cat_encoded,
|
544 |
max_new_tokens=max_new_tokens,
|
545 |
-
eos_token_id=tokenizer.eos_token_id,
|
546 |
im_end_id=im_end_id,
|
547 |
decode_one_token=decode_one_token,
|
548 |
temperature=temperature,
|
@@ -574,76 +516,66 @@ def generate_long(
|
|
574 |
|
575 |
# Put the generated tokens
|
576 |
# since there is <im_end> and <eos> tokens, we remove last 2 tokens
|
577 |
-
codes = y[1:, prompt_length:-
|
578 |
-
|
579 |
-
codes = codes - 2
|
580 |
assert (codes >= 0).all(), f"Negative code found"
|
581 |
|
582 |
decoded = y[:, prompt_length:-1].clone()
|
583 |
-
if decoded[0, -1] != im_end_id: # <im_end>
|
584 |
-
val = [[im_end_id]] + [[CODEBOOK_EOS_TOKEN_ID]] * (decoded.size(0) - 1)
|
585 |
-
decoded = torch.cat(
|
586 |
-
(decoded, torch.tensor(val, device=device, dtype=torch.int)), dim=1
|
587 |
-
)
|
588 |
-
|
589 |
# But for global encoding, we should keep the <im_end> token
|
|
|
590 |
global_encoded.append(decoded)
|
|
|
|
|
|
|
591 |
|
592 |
-
|
593 |
-
|
594 |
-
yield codes
|
595 |
-
else:
|
596 |
-
all_codes.append(codes)
|
597 |
|
598 |
-
seg_idx += 1
|
599 |
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
|
|
|
|
|
|
607 |
|
608 |
|
609 |
def launch_thread_safe_queue(
|
610 |
-
config_name,
|
611 |
checkpoint_path,
|
612 |
device,
|
613 |
precision,
|
614 |
-
|
615 |
-
compile=False,
|
616 |
):
|
617 |
input_queue = queue.Queue()
|
618 |
init_event = threading.Event()
|
619 |
|
620 |
def worker():
|
621 |
model, decode_one_token = load_model(
|
622 |
-
|
623 |
)
|
624 |
init_event.set()
|
625 |
|
626 |
while True:
|
627 |
-
item = input_queue.get()
|
628 |
if item is None:
|
629 |
break
|
630 |
|
631 |
-
kwargs = item
|
632 |
-
response_queue = item
|
633 |
|
634 |
try:
|
635 |
-
item["success"] = True
|
636 |
for chunk in generate_long(
|
637 |
model=model, decode_one_token=decode_one_token, **kwargs
|
638 |
):
|
639 |
-
response_queue.put(
|
640 |
-
|
641 |
-
|
642 |
except Exception as e:
|
643 |
-
|
644 |
-
item["response"] = e
|
645 |
-
|
646 |
-
response_queue.put("done")
|
647 |
|
648 |
threading.Thread(target=worker, daemon=True).start()
|
649 |
init_event.wait()
|
@@ -657,57 +589,58 @@ def launch_thread_safe_queue(
|
|
657 |
type=str,
|
658 |
default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
|
659 |
)
|
660 |
-
@click.option("--prompt-text", type=str, default=None)
|
661 |
@click.option(
|
662 |
-
"--prompt-tokens",
|
|
|
|
|
|
|
663 |
)
|
664 |
@click.option("--num-samples", type=int, default=1)
|
665 |
@click.option("--max-new-tokens", type=int, default=0)
|
666 |
@click.option("--top-p", type=float, default=0.7)
|
667 |
-
@click.option("--repetition-penalty", type=float, default=1.
|
668 |
@click.option("--temperature", type=float, default=0.7)
|
669 |
@click.option(
|
670 |
"--checkpoint-path",
|
671 |
type=click.Path(path_type=Path, exists=True),
|
672 |
-
default="
|
673 |
)
|
674 |
-
@click.option("--
|
675 |
-
@click.option("--tokenizer", type=str, default="fishaudio/fish-speech-1")
|
676 |
@click.option("--compile/--no-compile", default=False)
|
677 |
@click.option("--seed", type=int, default=42)
|
678 |
-
@click.option("--speaker", type=str, default=None)
|
679 |
@click.option("--half/--no-half", default=False)
|
680 |
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
|
681 |
-
@click.option("--
|
682 |
-
@click.option("--chunk-length", type=int, default=30)
|
683 |
def main(
|
684 |
text: str,
|
685 |
-
prompt_text: Optional[str],
|
686 |
-
prompt_tokens: Optional[Path],
|
687 |
num_samples: int,
|
688 |
max_new_tokens: int,
|
689 |
top_p: int,
|
690 |
repetition_penalty: float,
|
691 |
temperature: float,
|
692 |
checkpoint_path: Path,
|
693 |
-
|
694 |
-
tokenizer: str,
|
695 |
compile: bool,
|
696 |
seed: int,
|
697 |
-
speaker: Optional[str],
|
698 |
half: bool,
|
699 |
iterative_prompt: bool,
|
700 |
-
max_length: int,
|
701 |
chunk_length: int,
|
702 |
) -> None:
|
703 |
-
device = "cuda"
|
704 |
|
705 |
precision = torch.half if half else torch.bfloat16
|
706 |
|
|
|
|
|
|
|
|
|
|
|
707 |
logger.info("Loading model ...")
|
708 |
t0 = time.time()
|
709 |
model, decode_one_token = load_model(
|
710 |
-
|
711 |
)
|
712 |
|
713 |
if torch.cuda.is_available():
|
@@ -715,13 +648,9 @@ def main(
|
|
715 |
|
716 |
logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
|
717 |
|
718 |
-
prompt_tokens
|
719 |
-
torch.from_numpy(np.load(
|
720 |
-
if prompt_tokens is not None
|
721 |
-
else None
|
722 |
-
)
|
723 |
|
724 |
-
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
|
725 |
torch.manual_seed(seed)
|
726 |
|
727 |
if torch.cuda.is_available():
|
@@ -737,19 +666,29 @@ def main(
|
|
737 |
top_p=top_p,
|
738 |
repetition_penalty=repetition_penalty,
|
739 |
temperature=temperature,
|
740 |
-
tokenizer=tokenizer,
|
741 |
compile=compile,
|
742 |
-
speaker=speaker,
|
743 |
iterative_prompt=iterative_prompt,
|
744 |
-
max_length=max_length,
|
745 |
chunk_length=chunk_length,
|
746 |
prompt_text=prompt_text,
|
747 |
prompt_tokens=prompt_tokens,
|
748 |
)
|
749 |
|
750 |
-
|
751 |
-
|
752 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
753 |
|
754 |
|
755 |
if __name__ == "__main__":
|
|
|
2 |
import queue
|
3 |
import threading
|
4 |
import time
|
5 |
+
from dataclasses import dataclass
|
6 |
from pathlib import Path
|
7 |
+
from typing import Literal, Optional, Tuple, Union
|
8 |
|
9 |
import click
|
10 |
import hydra
|
|
|
12 |
import torch
|
13 |
import torch._dynamo.config
|
14 |
import torch._inductor.config
|
|
|
|
|
15 |
from loguru import logger
|
16 |
from tqdm import tqdm
|
|
|
17 |
|
18 |
+
from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
|
19 |
+
from fish_speech.text import clean_text, split_text
|
20 |
|
21 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
22 |
torch._inductor.config.coordinate_descent_tuning = True
|
|
|
27 |
torch._inductor.config.fx_graph_cache = True
|
28 |
|
29 |
|
30 |
+
from fish_speech.models.text2semantic.llama import (
|
31 |
+
BaseTransformer,
|
32 |
+
DualARTransformer,
|
33 |
+
NaiveTransformer,
|
34 |
+
)
|
35 |
|
36 |
|
37 |
def multinomial_sample_one_no_sync(
|
|
|
96 |
codebooks = [
|
97 |
sample(
|
98 |
x.logits,
|
99 |
+
previous_tokens=(
|
100 |
+
previous_tokens[0] if previous_tokens is not None else None
|
101 |
+
), # Disable repetition penalty for the token codebook
|
102 |
**sampling_kwargs,
|
103 |
)[0]
|
104 |
]
|
|
|
163 |
cur_token: torch.Tensor,
|
164 |
input_pos: torch.Tensor,
|
165 |
num_new_tokens: int,
|
|
|
166 |
im_end_id: int = 4,
|
167 |
decode_one_token=decode_one_token_naive,
|
168 |
**sampling_kwargs,
|
|
|
198 |
model.config.num_codebooks + 1, -1
|
199 |
)
|
200 |
|
201 |
+
if cur_token[0, 0, -1] == im_end_id:
|
|
|
|
|
|
|
|
|
202 |
break
|
203 |
|
204 |
return previous_tokens[:, : i + 1]
|
|
|
211 |
model: NaiveTransformer,
|
212 |
prompt: torch.Tensor,
|
213 |
max_new_tokens: int,
|
|
|
214 |
im_end_id: int = 4,
|
215 |
decode_one_token=decode_one_token_naive,
|
216 |
**sampling_kwargs,
|
|
|
251 |
if isinstance(model, NaiveTransformer)
|
252 |
else decode_one_token_ar
|
253 |
)
|
254 |
+
|
255 |
next_token = prefill_decode(
|
256 |
model, prompt.view(1, codebook_dim, -1), input_pos, **sampling_kwargs
|
257 |
)
|
|
|
263 |
next_token.view(1, codebook_dim, -1),
|
264 |
input_pos,
|
265 |
max_new_tokens - 1,
|
|
|
266 |
im_end_id=im_end_id,
|
267 |
decode_one_token=decode_one_token,
|
268 |
**sampling_kwargs,
|
|
|
277 |
def encode_tokens(
|
278 |
tokenizer,
|
279 |
string,
|
|
|
280 |
device="cuda",
|
281 |
prompt_tokens=None,
|
|
|
282 |
num_codebooks=4,
|
283 |
):
|
284 |
string = clean_text(string)
|
285 |
+
string = f"<|im_start|>user\n{string}<|im_end|><|im_start|>assistant\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
286 |
|
287 |
new_tokens = tokenizer.encode(
|
288 |
string,
|
|
|
310 |
prompt_tokens = prompt_tokens[0]
|
311 |
|
312 |
assert prompt_tokens.ndim == 2
|
313 |
+
data = prompt_tokens + 1
|
314 |
|
315 |
if prompt_tokens.shape[0] > num_codebooks:
|
316 |
logger.warning(
|
|
|
318 |
)
|
319 |
data = data[:num_codebooks]
|
320 |
|
321 |
+
# Add pad token for each codebook
|
322 |
data = torch.cat(
|
323 |
+
(data, torch.zeros((data.size(0), 1), dtype=torch.int, device=device)),
|
|
|
|
|
|
|
|
|
324 |
dim=1,
|
325 |
)
|
326 |
|
|
|
338 |
return prompt
|
339 |
|
340 |
|
341 |
+
def load_model(checkpoint_path, device, precision, compile=False):
|
342 |
+
model: Union[NaiveTransformer, DualARTransformer] = BaseTransformer.from_pretrained(
|
343 |
+
checkpoint_path, load_weights=True
|
344 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
345 |
|
346 |
model = model.to(device=device, dtype=precision)
|
347 |
+
logger.info(f"Restored model from checkpoint")
|
348 |
|
349 |
if isinstance(model, DualARTransformer):
|
350 |
decode_one_token = decode_one_token_ar
|
|
|
362 |
return model.eval(), decode_one_token
|
363 |
|
364 |
|
365 |
+
@dataclass
|
366 |
+
class GenerateResponse:
|
367 |
+
action: Literal["sample", "next"]
|
368 |
+
codes: Optional[torch.Tensor] = None
|
369 |
+
text: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
370 |
|
371 |
|
372 |
def generate_long(
|
373 |
*,
|
374 |
model,
|
|
|
375 |
device: str | torch.device,
|
376 |
decode_one_token: callable,
|
377 |
text: str,
|
|
|
383 |
compile: bool = False,
|
384 |
iterative_prompt: bool = True,
|
385 |
max_length: int = 2048,
|
386 |
+
chunk_length: int = 150,
|
387 |
+
prompt_text: Optional[str | list[str]] = None,
|
388 |
+
prompt_tokens: Optional[torch.Tensor | list[torch.Tensor]] = None,
|
|
|
|
|
389 |
):
|
390 |
assert 0 < top_p <= 1, "top_p must be in (0, 1]"
|
391 |
assert 0 < repetition_penalty < 2, "repetition_penalty must be in (0, 2)"
|
392 |
assert 0 < temperature < 2, "temperature must be in (0, 2)"
|
393 |
|
394 |
+
use_prompt = prompt_text is not None and prompt_tokens is not None
|
395 |
+
if use_prompt and isinstance(prompt_text, str):
|
396 |
+
prompt_text = [prompt_text]
|
397 |
+
prompt_tokens = [prompt_tokens]
|
398 |
+
|
399 |
+
assert use_prompt is False or len(prompt_text) == len(
|
400 |
+
prompt_tokens
|
401 |
+
), "Prompt text and tokens must have the same length"
|
402 |
+
|
403 |
model_size = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
404 |
+
tokenizer = model.tokenizer
|
405 |
im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
406 |
|
|
|
407 |
encoded = []
|
408 |
texts = split_text(text, chunk_length) if iterative_prompt else [text]
|
409 |
+
encoded_prompts = []
|
410 |
|
411 |
if use_prompt:
|
412 |
+
for idx, (t, c) in enumerate(zip(prompt_text, prompt_tokens)):
|
413 |
+
encoded_prompts.append(
|
414 |
+
encode_tokens(
|
415 |
+
tokenizer,
|
416 |
+
string=t,
|
417 |
+
device=device,
|
418 |
+
prompt_tokens=c,
|
419 |
+
num_codebooks=model.config.num_codebooks,
|
420 |
+
)
|
421 |
+
)
|
422 |
|
423 |
for idx, text in enumerate(texts):
|
424 |
encoded.append(
|
425 |
encode_tokens(
|
426 |
tokenizer,
|
427 |
string=text,
|
|
|
428 |
device=device,
|
|
|
429 |
num_codebooks=model.config.num_codebooks,
|
430 |
)
|
431 |
)
|
|
|
444 |
torch.cuda.synchronize()
|
445 |
|
446 |
global_encoded = []
|
|
|
447 |
seg_idx = 0
|
448 |
|
449 |
while seg_idx < len(encoded):
|
|
|
460 |
count = 0
|
461 |
for i, length in enumerate(lengths):
|
462 |
count += length
|
463 |
+
if count + length > max_length - 1024 - sum(
|
464 |
+
t.shape[1] for t in encoded_prompts
|
465 |
+
):
|
466 |
break
|
467 |
|
468 |
if i != 0 and i % 2 == 0:
|
|
|
475 |
partial_encoded = global_encoded
|
476 |
|
477 |
if use_prompt:
|
478 |
+
partial_encoded = encoded_prompts + partial_encoded
|
479 |
|
480 |
cat_encoded = torch.cat(partial_encoded, dim=1)
|
481 |
prompt_length = cat_encoded.size(1)
|
|
|
485 |
model=model,
|
486 |
prompt=cat_encoded,
|
487 |
max_new_tokens=max_new_tokens,
|
|
|
488 |
im_end_id=im_end_id,
|
489 |
decode_one_token=decode_one_token,
|
490 |
temperature=temperature,
|
|
|
516 |
|
517 |
# Put the generated tokens
|
518 |
# since there is <im_end> and <eos> tokens, we remove last 2 tokens
|
519 |
+
codes = y[1:, prompt_length:-1].clone()
|
520 |
+
codes = codes - 1
|
|
|
521 |
assert (codes >= 0).all(), f"Negative code found"
|
522 |
|
523 |
decoded = y[:, prompt_length:-1].clone()
|
|
|
|
|
|
|
|
|
|
|
|
|
524 |
# But for global encoding, we should keep the <im_end> token
|
525 |
+
|
526 |
global_encoded.append(decoded)
|
527 |
+
assert (codes >= 0).all(), f"Negative code found: {codes}"
|
528 |
+
yield GenerateResponse(action="sample", codes=codes, text=texts[seg_idx])
|
529 |
+
seg_idx += 1
|
530 |
|
531 |
+
# This indicates the end of the current sample
|
532 |
+
yield GenerateResponse(action="next")
|
|
|
|
|
|
|
533 |
|
|
|
534 |
|
535 |
+
@dataclass
|
536 |
+
class WrappedGenerateResponse:
|
537 |
+
status: Literal["success", "error"]
|
538 |
+
response: Optional[GenerateResponse | Exception] = None
|
539 |
+
|
540 |
+
|
541 |
+
@dataclass
|
542 |
+
class GenerateRequest:
|
543 |
+
request: dict
|
544 |
+
response_queue: queue.Queue
|
545 |
|
546 |
|
547 |
def launch_thread_safe_queue(
|
|
|
548 |
checkpoint_path,
|
549 |
device,
|
550 |
precision,
|
551 |
+
compile: bool = False,
|
|
|
552 |
):
|
553 |
input_queue = queue.Queue()
|
554 |
init_event = threading.Event()
|
555 |
|
556 |
def worker():
|
557 |
model, decode_one_token = load_model(
|
558 |
+
checkpoint_path, device, precision, compile=compile
|
559 |
)
|
560 |
init_event.set()
|
561 |
|
562 |
while True:
|
563 |
+
item: GenerateRequest | None = input_queue.get()
|
564 |
if item is None:
|
565 |
break
|
566 |
|
567 |
+
kwargs = item.request
|
568 |
+
response_queue = item.response_queue
|
569 |
|
570 |
try:
|
|
|
571 |
for chunk in generate_long(
|
572 |
model=model, decode_one_token=decode_one_token, **kwargs
|
573 |
):
|
574 |
+
response_queue.put(
|
575 |
+
WrappedGenerateResponse(status="success", response=chunk)
|
576 |
+
)
|
577 |
except Exception as e:
|
578 |
+
response_queue.put(WrappedGenerateResponse(status="error", response=e))
|
|
|
|
|
|
|
579 |
|
580 |
threading.Thread(target=worker, daemon=True).start()
|
581 |
init_event.wait()
|
|
|
589 |
type=str,
|
590 |
default="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
|
591 |
)
|
592 |
+
@click.option("--prompt-text", type=str, default=None, multiple=True)
|
593 |
@click.option(
|
594 |
+
"--prompt-tokens",
|
595 |
+
type=click.Path(path_type=Path, exists=True),
|
596 |
+
default=None,
|
597 |
+
multiple=True,
|
598 |
)
|
599 |
@click.option("--num-samples", type=int, default=1)
|
600 |
@click.option("--max-new-tokens", type=int, default=0)
|
601 |
@click.option("--top-p", type=float, default=0.7)
|
602 |
+
@click.option("--repetition-penalty", type=float, default=1.2)
|
603 |
@click.option("--temperature", type=float, default=0.7)
|
604 |
@click.option(
|
605 |
"--checkpoint-path",
|
606 |
type=click.Path(path_type=Path, exists=True),
|
607 |
+
default="checkpoints/fish-speech-1.2-sft",
|
608 |
)
|
609 |
+
@click.option("--device", type=str, default="cuda")
|
|
|
610 |
@click.option("--compile/--no-compile", default=False)
|
611 |
@click.option("--seed", type=int, default=42)
|
|
|
612 |
@click.option("--half/--no-half", default=False)
|
613 |
@click.option("--iterative-prompt/--no-iterative-prompt", default=True)
|
614 |
+
@click.option("--chunk-length", type=int, default=100)
|
|
|
615 |
def main(
|
616 |
text: str,
|
617 |
+
prompt_text: Optional[list[str]],
|
618 |
+
prompt_tokens: Optional[list[Path]],
|
619 |
num_samples: int,
|
620 |
max_new_tokens: int,
|
621 |
top_p: int,
|
622 |
repetition_penalty: float,
|
623 |
temperature: float,
|
624 |
checkpoint_path: Path,
|
625 |
+
device: str,
|
|
|
626 |
compile: bool,
|
627 |
seed: int,
|
|
|
628 |
half: bool,
|
629 |
iterative_prompt: bool,
|
|
|
630 |
chunk_length: int,
|
631 |
) -> None:
|
|
|
632 |
|
633 |
precision = torch.half if half else torch.bfloat16
|
634 |
|
635 |
+
if prompt_text is not None and len(prompt_text) != len(prompt_tokens):
|
636 |
+
raise ValueError(
|
637 |
+
f"Number of prompt text ({len(prompt_text)}) and prompt tokens ({len(prompt_tokens)}) should be the same"
|
638 |
+
)
|
639 |
+
|
640 |
logger.info("Loading model ...")
|
641 |
t0 = time.time()
|
642 |
model, decode_one_token = load_model(
|
643 |
+
checkpoint_path, device, precision, compile=compile
|
644 |
)
|
645 |
|
646 |
if torch.cuda.is_available():
|
|
|
648 |
|
649 |
logger.info(f"Time to load model: {time.time() - t0:.02f} seconds")
|
650 |
|
651 |
+
if prompt_tokens is not None:
|
652 |
+
prompt_tokens = [torch.from_numpy(np.load(p)).to(device) for p in prompt_tokens]
|
|
|
|
|
|
|
653 |
|
|
|
654 |
torch.manual_seed(seed)
|
655 |
|
656 |
if torch.cuda.is_available():
|
|
|
666 |
top_p=top_p,
|
667 |
repetition_penalty=repetition_penalty,
|
668 |
temperature=temperature,
|
|
|
669 |
compile=compile,
|
|
|
670 |
iterative_prompt=iterative_prompt,
|
|
|
671 |
chunk_length=chunk_length,
|
672 |
prompt_text=prompt_text,
|
673 |
prompt_tokens=prompt_tokens,
|
674 |
)
|
675 |
|
676 |
+
idx = 0
|
677 |
+
codes = []
|
678 |
+
|
679 |
+
for response in generator:
|
680 |
+
if response.action == "sample":
|
681 |
+
codes.append(response.codes)
|
682 |
+
logger.info(f"Sampled text: {response.text}")
|
683 |
+
elif response.action == "next":
|
684 |
+
if codes:
|
685 |
+
np.save(f"codes_{idx}.npy", torch.cat(codes, dim=1).cpu().numpy())
|
686 |
+
logger.info(f"Saved codes to codes_{idx}.npy")
|
687 |
+
logger.info(f"Next sample")
|
688 |
+
codes = []
|
689 |
+
idx += 1
|
690 |
+
else:
|
691 |
+
logger.error(f"Error: {response}")
|
692 |
|
693 |
|
694 |
if __name__ == "__main__":
|
tools/llama/merge_lora.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import shutil
|
2 |
+
from copy import deepcopy
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import click
|
6 |
+
import hydra
|
7 |
+
import torch
|
8 |
+
from hydra import compose, initialize
|
9 |
+
from hydra.utils import instantiate
|
10 |
+
from loguru import logger
|
11 |
+
|
12 |
+
from fish_speech.models.text2semantic.llama import BaseTransformer
|
13 |
+
from fish_speech.models.text2semantic.lora import get_merged_state_dict
|
14 |
+
|
15 |
+
|
16 |
+
@click.command()
|
17 |
+
@click.option("--lora-config", type=str, default="r_8_alpha_16")
|
18 |
+
@click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.2-sft")
|
19 |
+
@click.option("--lora-weight", type=str, required=True)
|
20 |
+
@click.option("--output", type=str, required=True)
|
21 |
+
def merge(lora_config, base_weight, lora_weight, output):
|
22 |
+
output = Path(output)
|
23 |
+
logger.info(
|
24 |
+
f"Merging {base_weight} and {lora_weight} into {output} with {lora_config}"
|
25 |
+
)
|
26 |
+
|
27 |
+
with initialize(version_base="1.3", config_path="../../fish_speech/configs/lora"):
|
28 |
+
cfg = compose(config_name=lora_config)
|
29 |
+
|
30 |
+
lora_config = instantiate(cfg)
|
31 |
+
logger.info(f"Loaded lora model with config {lora_config}")
|
32 |
+
|
33 |
+
llama_model = BaseTransformer.from_pretrained(
|
34 |
+
path=base_weight,
|
35 |
+
load_weights=True,
|
36 |
+
lora_config=lora_config,
|
37 |
+
)
|
38 |
+
logger.info(f"Loaded llama model")
|
39 |
+
|
40 |
+
llama_state_dict = llama_model.state_dict()
|
41 |
+
llama_state_dict = {k: v for k, v in llama_state_dict.items() if "lora" not in k}
|
42 |
+
llama_state_dict_copy = deepcopy(llama_state_dict)
|
43 |
+
lora_state_dict = torch.load(lora_weight, map_location="cpu")
|
44 |
+
|
45 |
+
if "state_dict" in llama_state_dict:
|
46 |
+
llama_state_dict = llama_state_dict["state_dict"]
|
47 |
+
|
48 |
+
if "state_dict" in lora_state_dict:
|
49 |
+
lora_state_dict = lora_state_dict["state_dict"]
|
50 |
+
|
51 |
+
# remove prefix model.
|
52 |
+
if any(k.startswith("model.") for k in llama_state_dict.keys()):
|
53 |
+
llama_state_dict = {
|
54 |
+
k.replace("model.", ""): v
|
55 |
+
for k, v in llama_state_dict.items()
|
56 |
+
if k.startswith("model.")
|
57 |
+
}
|
58 |
+
if any(k.startswith("model.") for k in lora_state_dict.keys()):
|
59 |
+
lora_state_dict = {
|
60 |
+
k.replace("model.", ""): v
|
61 |
+
for k, v in lora_state_dict.items()
|
62 |
+
if k.startswith("model.")
|
63 |
+
}
|
64 |
+
|
65 |
+
logger.info(f"Found {len(llama_state_dict)} keys in llama model")
|
66 |
+
logger.info(f"Found {len(lora_state_dict)} keys in lora model")
|
67 |
+
|
68 |
+
merged_state_dict = llama_state_dict | lora_state_dict
|
69 |
+
llama_model.load_state_dict(merged_state_dict, strict=True)
|
70 |
+
logger.info(f"Merged model loaded")
|
71 |
+
|
72 |
+
# Trigger eval mode to merge lora
|
73 |
+
llama_model.eval()
|
74 |
+
llama_model.save_pretrained(output, drop_lora=True)
|
75 |
+
logger.info(f"Saved merged model to {output}, validating")
|
76 |
+
|
77 |
+
new_state_dict = torch.load(output / "model.pth", map_location="cpu")
|
78 |
+
original_keys = set(llama_state_dict_copy.keys())
|
79 |
+
merged_keys = set(new_state_dict.keys())
|
80 |
+
|
81 |
+
assert original_keys == merged_keys, "Keys should be same"
|
82 |
+
|
83 |
+
for key in original_keys:
|
84 |
+
diff_l1 = (new_state_dict[key] - llama_state_dict_copy[key]).abs().sum().item()
|
85 |
+
if diff_l1 != 0:
|
86 |
+
break
|
87 |
+
else:
|
88 |
+
logger.error("Merged model is same as the original model")
|
89 |
+
exit(1)
|
90 |
+
|
91 |
+
logger.info("Merged model is different from the original model, check passed")
|
92 |
+
|
93 |
+
|
94 |
+
if __name__ == "__main__":
|
95 |
+
merge()
|
tools/llama/quantize.py
CHANGED
@@ -1,16 +1,20 @@
|
|
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
# All rights reserved.
|
|
|
|
|
3 |
|
4 |
# This source code is licensed under the license found in the
|
5 |
# LICENSE file in the root directory of this source tree.
|
6 |
import time
|
7 |
from pathlib import Path
|
8 |
|
|
|
9 |
import torch
|
10 |
import torch.nn as nn
|
11 |
import torch.nn.functional as F
|
12 |
|
13 |
-
from fish_speech.models.text2semantic.llama import
|
|
|
14 |
|
15 |
##### Quantization Primitives ######
|
16 |
|
@@ -414,13 +418,26 @@ class WeightOnlyInt4Linear(torch.nn.Module):
|
|
414 |
)
|
415 |
|
416 |
|
417 |
-
def
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
424 |
|
425 |
device = "cpu"
|
426 |
precision = torch.bfloat16
|
@@ -428,31 +445,14 @@ def quantize(
|
|
428 |
print("Loading model ...")
|
429 |
t0 = time.time()
|
430 |
|
431 |
-
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
|
436 |
-
|
437 |
-
|
438 |
-
|
439 |
-
rope_base=10000,
|
440 |
-
norm_eps=1e-5,
|
441 |
-
num_codebooks=4, # single codebook
|
442 |
-
codebook_size=168, # codebook size 160 + 2 special tokens
|
443 |
-
)
|
444 |
-
)
|
445 |
-
|
446 |
-
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
|
447 |
-
if "state_dict" in checkpoint:
|
448 |
-
checkpoint = checkpoint["state_dict"]
|
449 |
-
checkpoint = {
|
450 |
-
k.replace("model.", ""): v
|
451 |
-
for k, v in checkpoint.items()
|
452 |
-
if k.startswith("model.")
|
453 |
-
}
|
454 |
-
model.load_state_dict(checkpoint, assign=True)
|
455 |
-
model = model.to(dtype=precision, device=device)
|
456 |
|
457 |
if mode == "int8":
|
458 |
print(
|
@@ -461,10 +461,12 @@ def quantize(
|
|
461 |
quant_handler = WeightOnlyInt8QuantHandler(model)
|
462 |
quantized_state_dict = quant_handler.create_quantized_state_dict()
|
463 |
|
464 |
-
dir_name = checkpoint_path
|
465 |
-
|
466 |
-
|
467 |
-
|
|
|
|
|
468 |
|
469 |
elif mode == "int4":
|
470 |
print(
|
@@ -473,10 +475,12 @@ def quantize(
|
|
473 |
quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
|
474 |
quantized_state_dict = quant_handler.create_quantized_state_dict()
|
475 |
|
476 |
-
dir_name = checkpoint_path
|
477 |
-
|
478 |
-
|
479 |
-
|
|
|
|
|
480 |
|
481 |
else:
|
482 |
raise ValueError(
|
@@ -490,26 +494,4 @@ def quantize(
|
|
490 |
|
491 |
|
492 |
if __name__ == "__main__":
|
493 |
-
|
494 |
-
|
495 |
-
parser = argparse.ArgumentParser(description="Quantize a model.")
|
496 |
-
parser.add_argument(
|
497 |
-
"--checkpoint_path",
|
498 |
-
type=Path,
|
499 |
-
default=Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"),
|
500 |
-
help="Path to the model checkpoint to be quantized.",
|
501 |
-
)
|
502 |
-
parser.add_argument(
|
503 |
-
"--mode",
|
504 |
-
"-q",
|
505 |
-
type=str,
|
506 |
-
default="int8",
|
507 |
-
choices=["int8", "int4"],
|
508 |
-
help="type of quantization to perform",
|
509 |
-
)
|
510 |
-
parser.add_argument(
|
511 |
-
"--groupsize", type=int, default=32, help="Group size for int4 quantization."
|
512 |
-
)
|
513 |
-
|
514 |
-
args = parser.parse_args()
|
515 |
-
quantize(args.checkpoint_path, args.mode, args.groupsize)
|
|
|
1 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
# All rights reserved.
|
3 |
+
import datetime
|
4 |
+
import shutil
|
5 |
|
6 |
# This source code is licensed under the license found in the
|
7 |
# LICENSE file in the root directory of this source tree.
|
8 |
import time
|
9 |
from pathlib import Path
|
10 |
|
11 |
+
import click
|
12 |
import torch
|
13 |
import torch.nn as nn
|
14 |
import torch.nn.functional as F
|
15 |
|
16 |
+
from fish_speech.models.text2semantic.llama import find_multiple
|
17 |
+
from tools.llama.generate import load_model
|
18 |
|
19 |
##### Quantization Primitives ######
|
20 |
|
|
|
418 |
)
|
419 |
|
420 |
|
421 |
+
def generate_folder_name():
|
422 |
+
now = datetime.datetime.now()
|
423 |
+
folder_name = now.strftime("%Y%m%d_%H%M%S")
|
424 |
+
return folder_name
|
425 |
+
|
426 |
+
|
427 |
+
@click.command()
|
428 |
+
@click.option(
|
429 |
+
"--checkpoint-path",
|
430 |
+
type=click.Path(path_type=Path, exists=True),
|
431 |
+
default="checkpoints/fish-speech-1.2-sft",
|
432 |
+
)
|
433 |
+
@click.option(
|
434 |
+
"--mode", type=str, default="int8", help="type of quantization to perform"
|
435 |
+
)
|
436 |
+
@click.option(
|
437 |
+
"--groupsize", type=int, default=128, help="Group size for int4 quantization."
|
438 |
+
)
|
439 |
+
@click.option("--timestamp", type=str, default="None", help="When to do quantization")
|
440 |
+
def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
|
441 |
|
442 |
device = "cpu"
|
443 |
precision = torch.bfloat16
|
|
|
445 |
print("Loading model ...")
|
446 |
t0 = time.time()
|
447 |
|
448 |
+
model, _ = load_model(
|
449 |
+
checkpoint_path=checkpoint_path,
|
450 |
+
device=device,
|
451 |
+
precision=precision,
|
452 |
+
compile=False,
|
453 |
+
)
|
454 |
+
vq_model = "firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
|
455 |
+
now = timestamp if timestamp != "None" else generate_folder_name()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
456 |
|
457 |
if mode == "int8":
|
458 |
print(
|
|
|
461 |
quant_handler = WeightOnlyInt8QuantHandler(model)
|
462 |
quantized_state_dict = quant_handler.create_quantized_state_dict()
|
463 |
|
464 |
+
dir_name = checkpoint_path
|
465 |
+
dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
|
466 |
+
shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
|
467 |
+
if (dst_name / vq_model).exists():
|
468 |
+
(dst_name / vq_model).unlink()
|
469 |
+
quantize_path = dst_name / "model.pth"
|
470 |
|
471 |
elif mode == "int4":
|
472 |
print(
|
|
|
475 |
quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
|
476 |
quantized_state_dict = quant_handler.create_quantized_state_dict()
|
477 |
|
478 |
+
dir_name = checkpoint_path
|
479 |
+
dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")
|
480 |
+
shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
|
481 |
+
if (dst_name / vq_model).exists():
|
482 |
+
(dst_name / vq_model).unlink()
|
483 |
+
quantize_path = dst_name / "model.pth"
|
484 |
|
485 |
else:
|
486 |
raise ValueError(
|
|
|
494 |
|
495 |
|
496 |
if __name__ == "__main__":
|
497 |
+
quantize()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tools/llama/rebuild_tokenizer.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
|
2 |
+
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
3 |
+
|
4 |
+
# Initialize a tokenizer
|
5 |
+
tokenizer = Tokenizer(models.BPE())
|
6 |
+
|
7 |
+
# Customize pre-tokenization and decoding
|
8 |
+
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
9 |
+
tokenizer.decoder = decoders.ByteLevel()
|
10 |
+
tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
|
11 |
+
|
12 |
+
# Don't train the tokenizer
|
13 |
+
trainer = trainers.BpeTrainer(
|
14 |
+
vocab_size=0,
|
15 |
+
min_frequency=2,
|
16 |
+
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
|
17 |
+
special_tokens=[
|
18 |
+
"<|begin_of_sequence|>",
|
19 |
+
"<|end_of_sequence|>",
|
20 |
+
"<|im_start|>",
|
21 |
+
"<|im_sep|>", # system, user, assistant, etc.
|
22 |
+
"<|im_end|>",
|
23 |
+
"<|semantic|>", # audio features
|
24 |
+
"<|pad|>",
|
25 |
+
],
|
26 |
+
)
|
27 |
+
|
28 |
+
# <|im_start|>user<|im_sep|>...<|im_end|>
|
29 |
+
# <|im_start|>assistant<|im_sep|><|semantic|><|semantic|><|semantic|><|semantic|><|semantic|><|im_end|>
|
30 |
+
tokenizer.train_from_iterator([], trainer=trainer)
|
31 |
+
|
32 |
+
print(len(tokenizer.get_vocab()))
|
33 |
+
x = tokenizer.encode(
|
34 |
+
"Hello, how are you? dfgnviadfjoiviouajeiodfjv 你好世界 🈶<|semantic|>"
|
35 |
+
).ids
|
36 |
+
print(x, len(x))
|
37 |
+
print(tokenizer.decode(x, skip_special_tokens=True))
|
38 |
+
|
39 |
+
|
40 |
+
tokenizer = PreTrainedTokenizerFast(
|
41 |
+
tokenizer_object=tokenizer,
|
42 |
+
pad_token="<|pad|>",
|
43 |
+
bos_token="<|begin_of_sequence|>",
|
44 |
+
eos_token="<|end_of_sequence|>",
|
45 |
+
)
|
46 |
+
|
47 |
+
# Try tokenizing a new sequence
|
48 |
+
sequence = "All around, too, lay vast quantities of the costliest merchandise, and treasures were heaped in every cranny of the rocks, but all these things only added to the desolation of the scene. 测试中文, 你好世界 🈶<|semantic|>"
|
49 |
+
encoded = tokenizer(sequence).input_ids
|
50 |
+
|
51 |
+
print("Test encoding....")
|
52 |
+
print(f"\tSentence: {sequence}")
|
53 |
+
print(f"\tEncoded: {encoded}")
|
54 |
+
print(f"\tDecoded: {tokenizer.batch_decode(encoded)}")
|
55 |
+
print(f"\tDecoded: {tokenizer.decode(encoded)}")
|
56 |
+
|
57 |
+
tokenizer.push_to_hub("fishaudio/fish-speech-1", private=True)
|
tools/vqgan/create_train_split.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from pathlib import Path
|
3 |
+
from random import Random
|
4 |
+
|
5 |
+
import click
|
6 |
+
from loguru import logger
|
7 |
+
from pydub import AudioSegment
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
|
11 |
+
|
12 |
+
|
13 |
+
@click.command()
|
14 |
+
@click.argument("root", type=click.Path(exists=True, path_type=Path))
|
15 |
+
@click.option("--val-ratio", type=float, default=None)
|
16 |
+
@click.option("--val-count", type=int, default=None)
|
17 |
+
@click.option("--filelist", default=None, type=Path)
|
18 |
+
@click.option("--min-duration", default=None, type=float)
|
19 |
+
@click.option("--max-duration", default=None, type=float)
|
20 |
+
def main(root, val_ratio, val_count, filelist, min_duration, max_duration):
|
21 |
+
if filelist:
|
22 |
+
files = [i[0] for i in load_filelist(filelist)]
|
23 |
+
else:
|
24 |
+
files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
|
25 |
+
|
26 |
+
if min_duration is None and max_duration is None:
|
27 |
+
filtered_files = list(map(str, [file.relative_to(root) for file in files]))
|
28 |
+
else:
|
29 |
+
filtered_files = []
|
30 |
+
for file in tqdm(files):
|
31 |
+
try:
|
32 |
+
audio = AudioSegment.from_file(str(file))
|
33 |
+
duration = len(audio) / 1000.0
|
34 |
+
|
35 |
+
if min_duration is not None and duration < min_duration:
|
36 |
+
logger.info(
|
37 |
+
f"Skipping {file} due to duration {duration:.2f} < {min_duration:.2f}"
|
38 |
+
)
|
39 |
+
continue
|
40 |
+
|
41 |
+
if max_duration is not None and duration > max_duration:
|
42 |
+
logger.info(
|
43 |
+
f"Skipping {file} due to duration {duration:.2f} > {max_duration:.2f}"
|
44 |
+
)
|
45 |
+
continue
|
46 |
+
|
47 |
+
filtered_files.append(str(file.relative_to(root)))
|
48 |
+
except Exception as e:
|
49 |
+
logger.info(f"Error processing {file}: {e}")
|
50 |
+
|
51 |
+
logger.info(
|
52 |
+
f"Found {len(files)} files, remaining {len(filtered_files)} files after filtering"
|
53 |
+
)
|
54 |
+
|
55 |
+
Random(42).shuffle(filtered_files)
|
56 |
+
|
57 |
+
if val_count is None and val_ratio is None:
|
58 |
+
logger.info("Validation ratio and count not specified, using min(20%, 100)")
|
59 |
+
val_size = min(100, math.ceil(len(filtered_files) * 0.2))
|
60 |
+
elif val_count is not None and val_ratio is not None:
|
61 |
+
logger.error("Cannot specify both val_count and val_ratio")
|
62 |
+
return
|
63 |
+
elif val_count is not None:
|
64 |
+
if val_count < 1 or val_count > len(filtered_files):
|
65 |
+
logger.error("val_count must be between 1 and number of files")
|
66 |
+
return
|
67 |
+
val_size = val_count
|
68 |
+
else:
|
69 |
+
val_size = math.ceil(len(filtered_files) * val_ratio)
|
70 |
+
|
71 |
+
logger.info(f"Using {val_size} files for validation")
|
72 |
+
|
73 |
+
with open(root / "vq_train_filelist.txt", "w", encoding="utf-8") as f:
|
74 |
+
f.write("\n".join(filtered_files[val_size:]))
|
75 |
+
|
76 |
+
with open(root / "vq_val_filelist.txt", "w", encoding="utf-8") as f:
|
77 |
+
f.write("\n".join(filtered_files[:val_size]))
|
78 |
+
|
79 |
+
logger.info("Done")
|
80 |
+
|
81 |
+
|
82 |
+
if __name__ == "__main__":
|
83 |
+
main()
|
tools/vqgan/extract_vq.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess as sp
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
from datetime import timedelta
|
6 |
+
from functools import lru_cache
|
7 |
+
from pathlib import Path
|
8 |
+
from random import Random
|
9 |
+
|
10 |
+
import click
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torchaudio
|
14 |
+
from hydra import compose, initialize
|
15 |
+
from hydra.utils import instantiate
|
16 |
+
from lightning import LightningModule
|
17 |
+
from loguru import logger
|
18 |
+
from omegaconf import OmegaConf
|
19 |
+
|
20 |
+
from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
|
21 |
+
|
22 |
+
# register eval resolver
|
23 |
+
OmegaConf.register_new_resolver("eval", eval)
|
24 |
+
# This file is used to convert the audio files to text files using the Whisper model.
|
25 |
+
# It's mainly used to generate the training data for the VQ model.
|
26 |
+
|
27 |
+
|
28 |
+
RANK = int(os.environ.get("SLURM_PROCID", 0))
|
29 |
+
WORLD_SIZE = int(os.environ.get("SLURM_NTASKS", 1))
|
30 |
+
|
31 |
+
logger_format = (
|
32 |
+
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
33 |
+
"<level>{level: <8}</level> | "
|
34 |
+
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> | "
|
35 |
+
"{extra[rank]} - <level>{message}</level>"
|
36 |
+
)
|
37 |
+
logger.configure(extra={"rank": f"RANK: {RANK} / {WORLD_SIZE}"})
|
38 |
+
logger.remove()
|
39 |
+
logger.add(sys.stderr, format=logger_format)
|
40 |
+
|
41 |
+
|
42 |
+
@lru_cache(maxsize=1)
|
43 |
+
def get_model(
|
44 |
+
config_name: str = "firefly_gan_vq",
|
45 |
+
checkpoint_path: str = "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
|
46 |
+
device: str | torch.device = "cuda",
|
47 |
+
):
|
48 |
+
with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
|
49 |
+
cfg = compose(config_name=config_name)
|
50 |
+
|
51 |
+
model = instantiate(cfg)
|
52 |
+
state_dict = torch.load(
|
53 |
+
checkpoint_path,
|
54 |
+
map_location=device,
|
55 |
+
)
|
56 |
+
if "state_dict" in state_dict:
|
57 |
+
state_dict = state_dict["state_dict"]
|
58 |
+
|
59 |
+
if any("generator" in k for k in state_dict):
|
60 |
+
state_dict = {
|
61 |
+
k.replace("generator.", ""): v
|
62 |
+
for k, v in state_dict.items()
|
63 |
+
if "generator." in k
|
64 |
+
}
|
65 |
+
|
66 |
+
model.load_state_dict(state_dict, strict=False)
|
67 |
+
model.eval()
|
68 |
+
model.to(device)
|
69 |
+
|
70 |
+
logger.info(f"Loaded model")
|
71 |
+
return model
|
72 |
+
|
73 |
+
|
74 |
+
@torch.inference_mode()
|
75 |
+
def process_batch(files: list[Path], model) -> float:
|
76 |
+
wavs = []
|
77 |
+
audio_lengths = []
|
78 |
+
new_files = []
|
79 |
+
max_length = total_time = 0
|
80 |
+
|
81 |
+
for file in files:
|
82 |
+
try:
|
83 |
+
wav, sr = torchaudio.load(
|
84 |
+
str(file), backend="sox" if sys.platform == "linux" else "soundfile"
|
85 |
+
) # Need to install libsox-dev
|
86 |
+
except Exception as e:
|
87 |
+
logger.error(f"Error reading {file}: {e}")
|
88 |
+
continue
|
89 |
+
|
90 |
+
if wav.shape[0] > 1:
|
91 |
+
wav = wav.mean(dim=0, keepdim=True)
|
92 |
+
|
93 |
+
wav = torchaudio.functional.resample(
|
94 |
+
wav.cuda(), sr, model.spec_transform.sample_rate
|
95 |
+
)[0]
|
96 |
+
total_time += len(wav) / model.spec_transform.sample_rate
|
97 |
+
max_length = max(max_length, len(wav))
|
98 |
+
|
99 |
+
wavs.append(wav)
|
100 |
+
audio_lengths.append(len(wav))
|
101 |
+
new_files.append(file)
|
102 |
+
|
103 |
+
files = new_files
|
104 |
+
|
105 |
+
# Pad to max length
|
106 |
+
for i, wav in enumerate(wavs):
|
107 |
+
wavs[i] = torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
|
108 |
+
|
109 |
+
audios = torch.stack(wavs, dim=0)[:, None]
|
110 |
+
audio_lengths = torch.tensor(audio_lengths, device=model.device, dtype=torch.long)
|
111 |
+
|
112 |
+
# Calculate lengths
|
113 |
+
indices, feature_lengths = model.encode(audios, audio_lengths)
|
114 |
+
|
115 |
+
# Save to disk
|
116 |
+
outputs = indices.cpu().numpy()
|
117 |
+
|
118 |
+
for file, length, feature, audio_length in zip(
|
119 |
+
files, feature_lengths, outputs, audio_lengths
|
120 |
+
):
|
121 |
+
feature = feature[:, :length]
|
122 |
+
|
123 |
+
# (T,)
|
124 |
+
with open(file.with_suffix(".npy"), "wb") as f:
|
125 |
+
np.save(f, feature)
|
126 |
+
|
127 |
+
return total_time
|
128 |
+
|
129 |
+
|
130 |
+
@click.command()
|
131 |
+
@click.argument("folder")
|
132 |
+
@click.option("--num-workers", default=1)
|
133 |
+
@click.option("--config-name", default="firefly_gan_vq")
|
134 |
+
@click.option(
|
135 |
+
"--checkpoint-path",
|
136 |
+
default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
|
137 |
+
)
|
138 |
+
@click.option("--batch-size", default=64)
|
139 |
+
@click.option("--filelist", default=None, type=Path)
|
140 |
+
def main(
|
141 |
+
folder: str,
|
142 |
+
num_workers: int,
|
143 |
+
config_name: str,
|
144 |
+
checkpoint_path: str,
|
145 |
+
batch_size: int,
|
146 |
+
filelist: Path,
|
147 |
+
):
|
148 |
+
if num_workers > 1 and WORLD_SIZE != num_workers:
|
149 |
+
assert WORLD_SIZE == 1, "You should either use SLURM or this launcher, not both"
|
150 |
+
|
151 |
+
logger.info(f"Spawning {num_workers} workers")
|
152 |
+
|
153 |
+
if torch.cuda.is_available():
|
154 |
+
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
155 |
+
if visible_devices is None:
|
156 |
+
visible_devices = list(range(torch.cuda.device_count()))
|
157 |
+
else:
|
158 |
+
visible_devices = visible_devices.split(",")
|
159 |
+
else:
|
160 |
+
# Set to empty string to avoid using GPU
|
161 |
+
visible_devices = [""]
|
162 |
+
|
163 |
+
processes = []
|
164 |
+
for i in range(num_workers):
|
165 |
+
env = os.environ.copy()
|
166 |
+
env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
|
167 |
+
env["SLURM_PROCID"] = str(i)
|
168 |
+
env["SLURM_NTASKS"] = str(num_workers)
|
169 |
+
|
170 |
+
processes.append(
|
171 |
+
sp.Popen(
|
172 |
+
[sys.executable] + sys.argv.copy(),
|
173 |
+
env=env,
|
174 |
+
)
|
175 |
+
)
|
176 |
+
|
177 |
+
for p in processes:
|
178 |
+
p.wait()
|
179 |
+
|
180 |
+
logger.info(f"All workers finished")
|
181 |
+
return
|
182 |
+
|
183 |
+
# This is a worker
|
184 |
+
logger.info(f"Starting worker")
|
185 |
+
if filelist:
|
186 |
+
files = [i[0] for i in load_filelist(filelist)]
|
187 |
+
else:
|
188 |
+
files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)
|
189 |
+
|
190 |
+
print(f"Found {len(files)} files")
|
191 |
+
files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]
|
192 |
+
|
193 |
+
total_files = len(files)
|
194 |
+
files = files[RANK::WORLD_SIZE]
|
195 |
+
logger.info(f"Processing {len(files)}/{total_files} files")
|
196 |
+
|
197 |
+
# Batch processing
|
198 |
+
total_time = 0
|
199 |
+
begin_time = time.time()
|
200 |
+
processed_files = 0
|
201 |
+
model = get_model(config_name, checkpoint_path)
|
202 |
+
|
203 |
+
for n_batch, idx in enumerate(range(0, len(files), batch_size)):
|
204 |
+
batch = files[idx : idx + batch_size]
|
205 |
+
batch_time = process_batch(batch, model)
|
206 |
+
|
207 |
+
total_time += batch_time
|
208 |
+
processed_files += len(batch)
|
209 |
+
|
210 |
+
if (n_batch + 1) % 10 == 0:
|
211 |
+
eta = (
|
212 |
+
(time.time() - begin_time)
|
213 |
+
/ processed_files
|
214 |
+
* (len(files) - processed_files)
|
215 |
+
)
|
216 |
+
logger.info(
|
217 |
+
f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
|
218 |
+
+ f"ETA: {timedelta(seconds=round(eta))}s"
|
219 |
+
)
|
220 |
+
|
221 |
+
logger.info(
|
222 |
+
f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
|
223 |
+
)
|
224 |
+
|
225 |
+
|
226 |
+
if __name__ == "__main__":
|
227 |
+
main()
|
tools/vqgan/inference.py
CHANGED
@@ -2,13 +2,12 @@ from pathlib import Path
|
|
2 |
|
3 |
import click
|
4 |
import hydra
|
5 |
-
import librosa
|
6 |
import numpy as np
|
7 |
import soundfile as sf
|
8 |
import torch
|
|
|
9 |
from hydra import compose, initialize
|
10 |
from hydra.utils import instantiate
|
11 |
-
from lightning import LightningModule
|
12 |
from loguru import logger
|
13 |
from omegaconf import OmegaConf
|
14 |
|
@@ -23,20 +22,26 @@ def load_model(config_name, checkpoint_path, device="cuda"):
|
|
23 |
with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
|
24 |
cfg = compose(config_name=config_name)
|
25 |
|
26 |
-
model
|
27 |
state_dict = torch.load(
|
28 |
checkpoint_path,
|
29 |
-
map_location=
|
30 |
)
|
31 |
-
|
32 |
if "state_dict" in state_dict:
|
33 |
state_dict = state_dict["state_dict"]
|
34 |
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
model.eval()
|
37 |
model.to(device)
|
38 |
-
logger.info("Restored model from checkpoint")
|
39 |
|
|
|
40 |
return model
|
41 |
|
42 |
|
@@ -51,11 +56,10 @@ def load_model(config_name, checkpoint_path, device="cuda"):
|
|
51 |
@click.option(
|
52 |
"--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
|
53 |
)
|
54 |
-
@click.option("--config-name",
|
55 |
@click.option(
|
56 |
"--checkpoint-path",
|
57 |
-
"-
|
58 |
-
default="checkpoints/vq-gan-group-fsq-2x1024.pth",
|
59 |
)
|
60 |
@click.option(
|
61 |
"--device",
|
@@ -67,21 +71,22 @@ def main(input_path, output_path, config_name, checkpoint_path, device):
|
|
67 |
|
68 |
if input_path.suffix in AUDIO_EXTENSIONS:
|
69 |
logger.info(f"Processing in-place reconstruction of {input_path}")
|
|
|
70 |
# Load audio
|
71 |
-
audio,
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
75 |
)
|
76 |
-
|
|
|
77 |
logger.info(
|
78 |
-
f"Loaded audio with {audios.shape[2] / model.
|
79 |
)
|
80 |
|
81 |
# VQ Encoder
|
82 |
-
audio_lengths = torch.tensor(
|
83 |
-
[audios.shape[2]], device=model.device, dtype=torch.long
|
84 |
-
)
|
85 |
indices = model.encode(audios, audio_lengths)[0][0]
|
86 |
|
87 |
logger.info(f"Generated indices of shape {indices.shape}")
|
@@ -91,17 +96,15 @@ def main(input_path, output_path, config_name, checkpoint_path, device):
|
|
91 |
elif input_path.suffix == ".npy":
|
92 |
logger.info(f"Processing precomputed indices from {input_path}")
|
93 |
indices = np.load(input_path)
|
94 |
-
indices = torch.from_numpy(indices).to(
|
95 |
assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
|
96 |
else:
|
97 |
raise ValueError(f"Unknown input type: {input_path}")
|
98 |
|
99 |
# Restore
|
100 |
-
feature_lengths = torch.tensor([indices.shape[1]], device=
|
101 |
-
fake_audios = model.decode(
|
102 |
-
|
103 |
-
)
|
104 |
-
audio_time = fake_audios.shape[-1] / model.sampling_rate
|
105 |
|
106 |
logger.info(
|
107 |
f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
|
@@ -109,7 +112,7 @@ def main(input_path, output_path, config_name, checkpoint_path, device):
|
|
109 |
|
110 |
# Save audio
|
111 |
fake_audio = fake_audios[0, 0].float().cpu().numpy()
|
112 |
-
sf.write(output_path, fake_audio, model.
|
113 |
logger.info(f"Saved audio to {output_path}")
|
114 |
|
115 |
|
|
|
2 |
|
3 |
import click
|
4 |
import hydra
|
|
|
5 |
import numpy as np
|
6 |
import soundfile as sf
|
7 |
import torch
|
8 |
+
import torchaudio
|
9 |
from hydra import compose, initialize
|
10 |
from hydra.utils import instantiate
|
|
|
11 |
from loguru import logger
|
12 |
from omegaconf import OmegaConf
|
13 |
|
|
|
22 |
with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
|
23 |
cfg = compose(config_name=config_name)
|
24 |
|
25 |
+
model = instantiate(cfg)
|
26 |
state_dict = torch.load(
|
27 |
checkpoint_path,
|
28 |
+
map_location=device,
|
29 |
)
|
|
|
30 |
if "state_dict" in state_dict:
|
31 |
state_dict = state_dict["state_dict"]
|
32 |
|
33 |
+
if any("generator" in k for k in state_dict):
|
34 |
+
state_dict = {
|
35 |
+
k.replace("generator.", ""): v
|
36 |
+
for k, v in state_dict.items()
|
37 |
+
if "generator." in k
|
38 |
+
}
|
39 |
+
|
40 |
+
result = model.load_state_dict(state_dict, strict=False)
|
41 |
model.eval()
|
42 |
model.to(device)
|
|
|
43 |
|
44 |
+
logger.info(f"Loaded model: {result}")
|
45 |
return model
|
46 |
|
47 |
|
|
|
56 |
@click.option(
|
57 |
"--output-path", "-o", default="fake.wav", type=click.Path(path_type=Path)
|
58 |
)
|
59 |
+
@click.option("--config-name", default="firefly_gan_vq")
|
60 |
@click.option(
|
61 |
"--checkpoint-path",
|
62 |
+
default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
|
|
|
63 |
)
|
64 |
@click.option(
|
65 |
"--device",
|
|
|
71 |
|
72 |
if input_path.suffix in AUDIO_EXTENSIONS:
|
73 |
logger.info(f"Processing in-place reconstruction of {input_path}")
|
74 |
+
|
75 |
# Load audio
|
76 |
+
audio, sr = torchaudio.load(str(input_path))
|
77 |
+
if audio.shape[0] > 1:
|
78 |
+
audio = audio.mean(0, keepdim=True)
|
79 |
+
audio = torchaudio.functional.resample(
|
80 |
+
audio, sr, model.spec_transform.sample_rate
|
81 |
)
|
82 |
+
|
83 |
+
audios = audio[None].to(device)
|
84 |
logger.info(
|
85 |
+
f"Loaded audio with {audios.shape[2] / model.spec_transform.sample_rate:.2f} seconds"
|
86 |
)
|
87 |
|
88 |
# VQ Encoder
|
89 |
+
audio_lengths = torch.tensor([audios.shape[2]], device=device, dtype=torch.long)
|
|
|
|
|
90 |
indices = model.encode(audios, audio_lengths)[0][0]
|
91 |
|
92 |
logger.info(f"Generated indices of shape {indices.shape}")
|
|
|
96 |
elif input_path.suffix == ".npy":
|
97 |
logger.info(f"Processing precomputed indices from {input_path}")
|
98 |
indices = np.load(input_path)
|
99 |
+
indices = torch.from_numpy(indices).to(device).long()
|
100 |
assert indices.ndim == 2, f"Expected 2D indices, got {indices.ndim}"
|
101 |
else:
|
102 |
raise ValueError(f"Unknown input type: {input_path}")
|
103 |
|
104 |
# Restore
|
105 |
+
feature_lengths = torch.tensor([indices.shape[1]], device=device)
|
106 |
+
fake_audios = model.decode(indices=indices[None], feature_lengths=feature_lengths)
|
107 |
+
audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
|
|
|
|
|
108 |
|
109 |
logger.info(
|
110 |
f"Generated audio of shape {fake_audios.shape}, equivalent to {audio_time:.2f} seconds from {indices.shape[1]} features, features/second: {indices.shape[1] / audio_time:.2f}"
|
|
|
112 |
|
113 |
# Save audio
|
114 |
fake_audio = fake_audios[0, 0].float().cpu().numpy()
|
115 |
+
sf.write(output_path, fake_audio, model.spec_transform.sample_rate)
|
116 |
logger.info(f"Saved audio to {output_path}")
|
117 |
|
118 |
|