PoTaTo721 commited on
Commit
4f6613a
1 Parent(s): 5a969dd

Upload Fish-Agent Demo

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .project-root +0 -0
  2. app.py +217 -51
  3. fish_speech/callbacks/__init__.py +3 -0
  4. fish_speech/callbacks/grad_norm.py +113 -0
  5. fish_speech/configs/base.yaml +87 -0
  6. fish_speech/configs/firefly_gan_vq.yaml +33 -0
  7. fish_speech/configs/lora/r_8_alpha_16.yaml +4 -0
  8. fish_speech/configs/text2semantic_finetune.yaml +83 -0
  9. fish_speech/conversation.py +256 -0
  10. fish_speech/datasets/concat_repeat.py +53 -0
  11. fish_speech/datasets/protos/text-data.proto +24 -0
  12. fish_speech/datasets/protos/text_data_pb2.py +33 -0
  13. fish_speech/datasets/protos/text_data_stream.py +36 -0
  14. fish_speech/datasets/semantic.py +496 -0
  15. fish_speech/datasets/vqgan.py +147 -0
  16. fish_speech/i18n/README.md +27 -0
  17. fish_speech/i18n/__init__.py +3 -0
  18. fish_speech/i18n/core.py +40 -0
  19. fish_speech/i18n/locale/en_US.json +123 -0
  20. fish_speech/i18n/locale/es_ES.json +123 -0
  21. fish_speech/i18n/locale/ja_JP.json +123 -0
  22. fish_speech/i18n/locale/ko_KR.json +123 -0
  23. fish_speech/i18n/locale/pt_BR.json +133 -0
  24. fish_speech/i18n/locale/zh_CN.json +123 -0
  25. fish_speech/i18n/scan.py +122 -0
  26. fish_speech/models/text2semantic/__init__.py +0 -0
  27. fish_speech/models/text2semantic/lit_module.py +202 -0
  28. fish_speech/models/text2semantic/llama.py +844 -0
  29. fish_speech/models/text2semantic/lora.py +92 -0
  30. fish_speech/models/vqgan/__init__.py +0 -0
  31. fish_speech/models/vqgan/modules/firefly.py +596 -0
  32. fish_speech/models/vqgan/modules/fsq.py +116 -0
  33. fish_speech/models/vqgan/utils.py +94 -0
  34. fish_speech/scheduler.py +40 -0
  35. fish_speech/text/__init__.py +4 -0
  36. fish_speech/text/chn_text_norm/.gitignore +114 -0
  37. fish_speech/text/chn_text_norm/README.md +36 -0
  38. fish_speech/text/chn_text_norm/__init__.py +0 -0
  39. fish_speech/text/chn_text_norm/basic_class.py +172 -0
  40. fish_speech/text/chn_text_norm/basic_constant.py +30 -0
  41. fish_speech/text/chn_text_norm/basic_util.py +342 -0
  42. fish_speech/text/chn_text_norm/cardinal.py +32 -0
  43. fish_speech/text/chn_text_norm/date.py +75 -0
  44. fish_speech/text/chn_text_norm/digit.py +32 -0
  45. fish_speech/text/chn_text_norm/fraction.py +35 -0
  46. fish_speech/text/chn_text_norm/money.py +43 -0
  47. fish_speech/text/chn_text_norm/percentage.py +33 -0
  48. fish_speech/text/chn_text_norm/telephone.py +51 -0
  49. fish_speech/text/chn_text_norm/text.py +177 -0
  50. fish_speech/text/clean.py +62 -0
.project-root ADDED
File without changes
app.py CHANGED
@@ -1,64 +1,230 @@
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
 
 
3
 
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
 
 
 
 
 
9
 
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  ):
18
- messages = [{"role": "system", "content": system_message}]
 
 
 
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
 
 
 
25
 
26
- messages.append({"role": "user", "content": message})
 
 
 
 
 
 
27
 
28
- response = ""
 
 
 
 
 
 
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  if __name__ == "__main__":
64
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
  import gradio as gr
3
+ import numpy as np
4
+ import os
5
+ import threading
6
+ import subprocess
7
+ import sys
8
+ import time
9
 
10
+ from huggingface_hub import snapshot_download
11
+ from tools.fish_e2e import FishE2EAgent, FishE2EEventType
12
+ from tools.schema import ServeMessage, ServeTextPart, ServeVQPart
 
13
 
14
+ # Download Weights
15
+ os.makedirs("checkpoints", exist_ok=True)
16
+ snapshot_download(repo_id="fishaudio/fish-speech-1.4", local_dir="./checkpoints/fish-speech-1.4")
17
+ snapshot_download(repo_id="fishaudio/fish-agent-v0.1-3b", local_dir="./checkpoints/fish-agent-v0.1-3b")
18
 
19
+ class ChatState:
20
+ def __init__(self):
21
+ self.conversation = []
22
+ self.added_systext = False
23
+ self.added_sysaudio = False
24
+
25
+ def get_history(self):
26
+ results = []
27
+ for msg in self.conversation:
28
+ results.append({"role": msg.role, "content": self.repr_message(msg)})
29
+
30
+ # Process assistant messages to extract questions and update user messages
31
+ for i, msg in enumerate(results):
32
+ if msg["role"] == "assistant":
33
+ match = re.search(r"Question: (.*?)\n\nResponse:", msg["content"])
34
+ if match and i > 0 and results[i - 1]["role"] == "user":
35
+ # Update previous user message with extracted question
36
+ results[i - 1]["content"] += "\n" + match.group(1)
37
+ # Remove the Question/Answer format from assistant message
38
+ msg["content"] = msg["content"].split("\n\nResponse: ", 1)[1]
39
+ return results
40
+
41
+ def repr_message(self, msg: ServeMessage):
42
+ response = ""
43
+ for part in msg.parts:
44
+ if isinstance(part, ServeTextPart):
45
+ response += part.text
46
+ elif isinstance(part, ServeVQPart):
47
+ response += f"<audio {len(part.codes[0]) / 21:.2f}s>"
48
+ return response
49
+
50
+
51
+ def clear_fn():
52
+ return [], ChatState(), None, None, None
53
+
54
+
55
+ async def process_audio_input(
56
+ sys_audio_input, sys_text_input, audio_input, state: ChatState, text_input: str
57
  ):
58
+ if audio_input is None and not text_input:
59
+ raise gr.Error("No input provided")
60
+
61
+ agent = FishE2EAgent() # Create new agent instance for each request
62
 
63
+ # Convert audio input to numpy array
64
+ if isinstance(audio_input, tuple):
65
+ sr, audio_data = audio_input
66
+ elif text_input:
67
+ sr = 44100
68
+ audio_data = None
69
+ else:
70
+ raise gr.Error("Invalid audio format")
71
 
72
+ if isinstance(sys_audio_input, tuple):
73
+ sr, sys_audio_data = sys_audio_input
74
+ elif text_input:
75
+ sr = 44100
76
+ sys_audio_data = None
77
+ else:
78
+ raise gr.Error("Invalid audio format")
79
 
80
+ def append_to_chat_ctx(
81
+ part: ServeTextPart | ServeVQPart, role: str = "assistant"
82
+ ) -> None:
83
+ if not state.conversation or state.conversation[-1].role != role:
84
+ state.conversation.append(ServeMessage(role=role, parts=[part]))
85
+ else:
86
+ state.conversation[-1].parts.append(part)
87
 
88
+ if state.added_systext is False and sys_text_input:
89
+ state.added_systext = True
90
+ append_to_chat_ctx(ServeTextPart(text=sys_text_input), role="system")
91
+ if text_input:
92
+ append_to_chat_ctx(ServeTextPart(text=text_input), role="user")
93
+ audio_data = None
94
+
95
+ result_audio = b""
96
+ async for event in agent.stream(
97
+ sys_audio_data,
98
+ audio_data,
99
+ sr,
100
+ 1,
101
+ chat_ctx={
102
+ "messages": state.conversation,
103
+ "added_sysaudio": state.added_sysaudio,
104
+ },
105
+ ):
106
+ if event.type == FishE2EEventType.USER_CODES:
107
+ append_to_chat_ctx(ServeVQPart(codes=event.vq_codes), role="user")
108
+ elif event.type == FishE2EEventType.SPEECH_SEGMENT:
109
+ result_audio += event.frame.data
110
+ np_audio = np.frombuffer(result_audio, dtype=np.int16)
111
+ append_to_chat_ctx(ServeVQPart(codes=event.vq_codes))
112
+
113
+ yield state.get_history(), (44100, np_audio), None, None
114
+ elif event.type == FishE2EEventType.TEXT_SEGMENT:
115
+ append_to_chat_ctx(ServeTextPart(text=event.text))
116
+ if result_audio:
117
+ np_audio = np.frombuffer(result_audio, dtype=np.int16)
118
+ yield state.get_history(), (44100, np_audio), None, None
119
+ else:
120
+ yield state.get_history(), None, None, None
121
+
122
+ np_audio = np.frombuffer(result_audio, dtype=np.int16)
123
+ yield state.get_history(), (44100, np_audio), None, None
124
+
125
+
126
+ async def process_text_input(
127
+ sys_audio_input, sys_text_input, state: ChatState, text_input: str
128
+ ):
129
+ async for event in process_audio_input(
130
+ sys_audio_input, sys_text_input, None, state, text_input
131
  ):
132
+ yield event
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
 
135
+ def create_demo():
136
+ with gr.Blocks() as demo:
137
+ state = gr.State(ChatState())
138
+
139
+ with gr.Row():
140
+ # Left column (70%) for chatbot and notes
141
+ with gr.Column(scale=7):
142
+ chatbot = gr.Chatbot(
143
+ [],
144
+ elem_id="chatbot",
145
+ bubble_full_width=False,
146
+ height=600,
147
+ type="messages",
148
+ )
149
+
150
+ notes = gr.Markdown(
151
+ """
152
+ # Fish Agent
153
+ 1. 此Demo为Fish Audio自研端到端语言模型Fish Agent 3B版本.
154
+ 2. 你可以在我们的官方仓库找到代码以及权重,但是相关内容全部基于 CC BY-NC-SA 4.0 许可证发布.
155
+ 3. Demo为早期灰度测试版本,推理速度尚待优化.
156
+ # 特色
157
+ 1. 该模型自动集成ASR与TTS部分,不需要外挂其它模型,即真正的端到端,而非三段式(ASR+LLM+TTS).
158
+ 2. 模型可以使用reference audio控制说话音色.
159
+ 3. 可以生成具有较强情感与韵律的音频.
160
+ """
161
+ )
162
+
163
+ # Right column (30%) for controls
164
+ with gr.Column(scale=3):
165
+ sys_audio_input = gr.Audio(
166
+ sources=["upload"],
167
+ type="numpy",
168
+ label="Give a timbre for your assistant",
169
+ )
170
+ sys_text_input = gr.Textbox(
171
+ label="What is your assistant's role?",
172
+ value='您是由 Fish Audio 设计的语音助手,提供端到端的语音交互,实现无缝用户体验。首先转录用户的语音,然后使用以下格式回答:"Question: [用户语音]\n\nResponse: [你的回答]\n"。',
173
+ type="text",
174
+ )
175
+ audio_input = gr.Audio(
176
+ sources=["microphone"], type="numpy", label="Speak your message"
177
+ )
178
+
179
+ text_input = gr.Textbox(label="Or type your message", type="text")
180
+
181
+ output_audio = gr.Audio(label="Assistant's Voice", type="numpy")
182
+
183
+ send_button = gr.Button("Send", variant="primary")
184
+ clear_button = gr.Button("Clear")
185
+
186
+ # Event handlers
187
+ audio_input.stop_recording(
188
+ process_audio_input,
189
+ inputs=[sys_audio_input, sys_text_input, audio_input, state, text_input],
190
+ outputs=[chatbot, output_audio, audio_input, text_input],
191
+ show_progress=True,
192
+ )
193
+
194
+ send_button.click(
195
+ process_text_input,
196
+ inputs=[sys_audio_input, sys_text_input, state, text_input],
197
+ outputs=[chatbot, output_audio, audio_input, text_input],
198
+ show_progress=True,
199
+ )
200
+
201
+ text_input.submit(
202
+ process_text_input,
203
+ inputs=[sys_audio_input, sys_text_input, state, text_input],
204
+ outputs=[chatbot, output_audio, audio_input, text_input],
205
+ show_progress=True,
206
+ )
207
+
208
+ clear_button.click(
209
+ clear_fn,
210
+ inputs=[],
211
+ outputs=[chatbot, state, audio_input, output_audio, text_input],
212
+ )
213
+
214
+ return demo
215
+
216
+ def run_api():
217
+ subprocess.run([sys.executable, "-m", "tools.api"])
218
+
219
  if __name__ == "__main__":
220
+
221
+ # 创建并启动 API 线程
222
+ api_thread = threading.Thread(target=run_api, daemon=True)
223
+ api_thread.start()
224
+
225
+ # 给 API 一些时间启动
226
+ time.sleep(60)
227
+
228
+ # 创建并启动 Gradio demo
229
+ demo = create_demo()
230
+ demo.launch(server_name="127.0.0.1", server_port=7860, share=True)
fish_speech/callbacks/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .grad_norm import GradNormMonitor
2
+
3
+ __all__ = ["GradNormMonitor"]
fish_speech/callbacks/grad_norm.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union
2
+
3
+ import lightning.pytorch as pl
4
+ import torch
5
+ from lightning import LightningModule, Trainer
6
+ from lightning.pytorch.callbacks import Callback
7
+ from torch import Tensor, nn
8
+ from torch.utils._foreach_utils import (
9
+ _group_tensors_by_device_and_dtype,
10
+ _has_foreach_support,
11
+ )
12
+
13
+
14
+ @torch.no_grad()
15
+ def grad_norm(
16
+ parameters: Union[Tensor, list[Tensor]],
17
+ norm_type: float = 2.0,
18
+ ) -> float:
19
+ """
20
+ Returns the norm of the gradients of the given parameters.
21
+
22
+ Args:
23
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
24
+ single Tensor that will have gradients normalized
25
+ norm_type (float): type of the used p-norm.
26
+
27
+ Returns:
28
+ Total norm of the parameter gradients (viewed as a single vector).
29
+ """ # noqa: E501
30
+
31
+ if isinstance(parameters, Tensor):
32
+ parameters = [parameters]
33
+
34
+ grads = [p.grad for p in parameters if p.grad is not None]
35
+ if len(grads) == 0:
36
+ return None
37
+
38
+ first_device = grads[0].device
39
+ grouped_grads: dict[
40
+ tuple[torch.device, torch.dtype], list[list[Tensor]]
41
+ ] = _group_tensors_by_device_and_dtype(
42
+ [[g.detach() for g in grads]]
43
+ ) # type: ignore[assignment]
44
+
45
+ norms = []
46
+ for (device, _), ([grads], _) in grouped_grads.items():
47
+ if _has_foreach_support(grads, device=device):
48
+ norms.extend(torch._foreach_norm(grads, norm_type))
49
+ else:
50
+ norms.extend([torch.norm(g, norm_type) for g in grads])
51
+
52
+ return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
53
+
54
+
55
+ class GradNormMonitor(Callback):
56
+ """
57
+ Callback that computes the gradient norm of the model parameters.
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ norm_type: float = 2.0,
63
+ logging_interval: str = "step",
64
+ sub_module: Optional[Union[str, list[str]]] = None,
65
+ ) -> None:
66
+ """
67
+ Args:
68
+ norm_type (float): type of the used p-norm.
69
+ logging_interval (str): "step" or "epoch".
70
+ """
71
+ super().__init__()
72
+
73
+ self.norm_type = norm_type
74
+ self.logging_interval = logging_interval
75
+ self.sub_module = sub_module
76
+
77
+ def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
78
+ """
79
+ Computes the gradient norm of the model parameters and logs it to the logger.
80
+
81
+ Args:
82
+ trainer (Trainer): The trainer object
83
+ model (LightningModule): The current lightningModule
84
+ """
85
+
86
+ lightning_model = model
87
+
88
+ if self.sub_module is None:
89
+ return self.log_sub_module_grad_norm(lightning_model, model, "")
90
+
91
+ sub_modules = self.sub_module
92
+ if isinstance(sub_modules, str):
93
+ sub_modules = [sub_modules]
94
+
95
+ for sub_module in sub_modules:
96
+ self.log_sub_module_grad_norm(
97
+ lightning_model, getattr(model, sub_module), f"/{sub_module}"
98
+ )
99
+
100
+ def log_sub_module_grad_norm(
101
+ self, lightning_model: LightningModule, model: nn.Module, path: str
102
+ ) -> None:
103
+ grad_norm_val = grad_norm(model.parameters(), self.norm_type)
104
+ if grad_norm_val is None:
105
+ return
106
+
107
+ on_step = self.logging_interval == "step"
108
+ lightning_model.log(
109
+ f"train{path}/grad_norm",
110
+ grad_norm_val,
111
+ on_step=on_step,
112
+ on_epoch=not on_step,
113
+ )
fish_speech/configs/base.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Base configuration for training a model
2
+ paths:
3
+ run_dir: results/${project}
4
+ ckpt_dir: ${paths.run_dir}/checkpoints
5
+
6
+ hydra:
7
+ run:
8
+ dir: ${paths.run_dir}
9
+
10
+ # Lightning Trainer
11
+ trainer:
12
+ _target_: lightning.pytorch.trainer.Trainer
13
+
14
+ default_root_dir: ${paths.run_dir}
15
+ accelerator: gpu
16
+ num_nodes: 1
17
+ devices: auto
18
+ strategy:
19
+ _target_: lightning.pytorch.strategies.DDPStrategy
20
+ process_group_backend: nccl # This should be override when training on windows
21
+
22
+ precision: bf16-mixed
23
+
24
+ # disable validation by epoch end
25
+ check_val_every_n_epoch: null
26
+ val_check_interval: 5000
27
+ max_steps: 100_000
28
+
29
+ # Use torch.backends.cudnn.benchmark to speed up training
30
+ benchmark: true
31
+
32
+ # Callbacks
33
+ callbacks:
34
+ model_checkpoint:
35
+ _target_: lightning.pytorch.callbacks.ModelCheckpoint
36
+ dirpath: ${paths.ckpt_dir}
37
+ filename: "step_{step:09d}"
38
+ save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
39
+ save_top_k: 5 # save 5 latest checkpoints
40
+ monitor: step # use step to monitor checkpoints
41
+ mode: max # save the latest checkpoint with the highest global_step
42
+ every_n_epochs: null # don't save checkpoints by epoch end
43
+ every_n_train_steps: 5000 # save checkpoints every 5000 steps
44
+ auto_insert_metric_name: false
45
+
46
+ model_summary:
47
+ _target_: lightning.pytorch.callbacks.ModelSummary
48
+ max_depth: 2 # the maximum depth of layer nesting that the summary will include
49
+
50
+ learning_rate_monitor:
51
+ _target_: lightning.pytorch.callbacks.LearningRateMonitor
52
+ logging_interval: step
53
+ log_momentum: false
54
+
55
+ grad_norm_monitor:
56
+ _target_: fish_speech.callbacks.GradNormMonitor
57
+ norm_type: 2
58
+ logging_interval: step
59
+
60
+ # Logger
61
+ logger:
62
+ tensorboard:
63
+ _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
64
+ save_dir: "${paths.run_dir}/tensorboard/"
65
+ name: null
66
+ log_graph: false
67
+ default_hp_metric: true
68
+ prefix: ""
69
+
70
+ # wandb:
71
+ # _target_: lightning.pytorch.loggers.wandb.WandbLogger
72
+ # # name: "" # name of the run (normally generated by wandb)
73
+ # save_dir: "${paths.run_dir}"
74
+ # offline: False
75
+ # id: null # pass correct id to resume experiment!
76
+ # anonymous: null # enable anonymous logging
77
+ # project: "fish-speech"
78
+ # log_model: False # upload lightning ckpts
79
+ # prefix: "" # a string to put at the beginning of metric keys
80
+ # # entity: "" # set to name of your wandb team
81
+ # group: ""
82
+ # tags: ["vq", "hq", "finetune"]
83
+ # job_type: ""
84
+
85
+ # Loop
86
+ train: true
87
+ test: false
fish_speech/configs/firefly_gan_vq.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: fish_speech.models.vqgan.modules.firefly.FireflyArchitecture
2
+ spec_transform:
3
+ _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
4
+ sample_rate: 44100
5
+ n_mels: 160
6
+ n_fft: 2048
7
+ hop_length: 512
8
+ win_length: 2048
9
+ backbone:
10
+ _target_: fish_speech.models.vqgan.modules.firefly.ConvNeXtEncoder
11
+ input_channels: 160
12
+ depths: [3, 3, 9, 3]
13
+ dims: [128, 256, 384, 512]
14
+ drop_path_rate: 0.2
15
+ kernel_size: 7
16
+ head:
17
+ _target_: fish_speech.models.vqgan.modules.firefly.HiFiGANGenerator
18
+ hop_length: 512
19
+ upsample_rates: [8, 8, 2, 2, 2] # aka. strides
20
+ upsample_kernel_sizes: [16, 16, 4, 4, 4]
21
+ resblock_kernel_sizes: [3, 7, 11]
22
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
23
+ num_mels: 512
24
+ upsample_initial_channel: 512
25
+ pre_conv_kernel_size: 13
26
+ post_conv_kernel_size: 13
27
+ quantizer:
28
+ _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
29
+ input_dim: 512
30
+ n_groups: 8
31
+ n_codebooks: 1
32
+ levels: [8, 5, 5, 5]
33
+ downsample_factor: [2, 2]
fish_speech/configs/lora/r_8_alpha_16.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ _target_: fish_speech.models.text2semantic.lora.LoraConfig
2
+ r: 8
3
+ lora_alpha: 16
4
+ lora_dropout: 0.01
fish_speech/configs/text2semantic_finetune.yaml ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base
3
+ - _self_
4
+
5
+ project: text2semantic_finetune_dual_ar
6
+ max_length: 4096
7
+ pretrained_ckpt_path: checkpoints/fish-speech-1.4
8
+
9
+ # Lightning Trainer
10
+ trainer:
11
+ accumulate_grad_batches: 1
12
+ gradient_clip_val: 1.0
13
+ gradient_clip_algorithm: "norm"
14
+ max_steps: 1000
15
+ precision: bf16-true
16
+ limit_val_batches: 10
17
+ val_check_interval: 100
18
+
19
+ # Dataset Configuration
20
+ tokenizer:
21
+ _target_: transformers.AutoTokenizer.from_pretrained
22
+ pretrained_model_name_or_path: ${pretrained_ckpt_path}
23
+
24
+ # Dataset Configuration
25
+ train_dataset:
26
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
27
+ proto_files:
28
+ - data/protos
29
+ tokenizer: ${tokenizer}
30
+ causal: true
31
+ max_length: ${max_length}
32
+ use_speaker: false
33
+ interactive_prob: 0.7
34
+
35
+ val_dataset:
36
+ _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
37
+ proto_files:
38
+ - data/protos
39
+ tokenizer: ${tokenizer}
40
+ causal: true
41
+ max_length: ${max_length}
42
+ use_speaker: false
43
+ interactive_prob: 0.7
44
+
45
+ data:
46
+ _target_: fish_speech.datasets.semantic.SemanticDataModule
47
+ train_dataset: ${train_dataset}
48
+ val_dataset: ${val_dataset}
49
+ num_workers: 4
50
+ batch_size: 8
51
+ tokenizer: ${tokenizer}
52
+ max_length: ${max_length}
53
+
54
+ # Model Configuration
55
+ model:
56
+ _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
57
+ model:
58
+ _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
59
+ path: ${pretrained_ckpt_path}
60
+ load_weights: true
61
+ max_length: ${max_length}
62
+ lora_config: null
63
+
64
+ optimizer:
65
+ _target_: torch.optim.AdamW
66
+ _partial_: true
67
+ lr: 1e-4
68
+ weight_decay: 0
69
+ betas: [0.9, 0.95]
70
+ eps: 1e-5
71
+
72
+ lr_scheduler:
73
+ _target_: torch.optim.lr_scheduler.LambdaLR
74
+ _partial_: true
75
+ lr_lambda:
76
+ _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
77
+ _partial_: true
78
+ num_warmup_steps: 10
79
+
80
+ # Callbacks
81
+ callbacks:
82
+ model_checkpoint:
83
+ every_n_train_steps: ${trainer.val_check_interval}
fish_speech/conversation.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Literal
3
+
4
+ import torch
5
+ from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerFast
6
+
7
+ IM_START_TOKEN = "<|im_start|>"
8
+ IM_END_TOKEN = "<|im_end|>"
9
+ SEMANTIC_TOKEN = "<|semantic|>"
10
+ MEL_TOKEN = "<|mel|>"
11
+ PHONEME_START_TOKEN = "<|phoneme_start|>"
12
+ PHONEME_END_TOKEN = "<|phoneme_end|>"
13
+ ALL_SPECIAL_TOKENS = [
14
+ IM_START_TOKEN,
15
+ IM_END_TOKEN,
16
+ SEMANTIC_TOKEN,
17
+ MEL_TOKEN,
18
+ PHONEME_START_TOKEN,
19
+ PHONEME_END_TOKEN,
20
+ ]
21
+
22
+ CODEBOOK_PAD_TOKEN_ID = 0
23
+
24
+
25
+ class FishTokenizerConfig(PretrainedConfig):
26
+ share_codebook_embeddings: bool = True
27
+ codebook_size: int = 1024
28
+ num_codebooks: int = 8
29
+
30
+
31
+ class FishTokenizerFast(PreTrainedTokenizerFast):
32
+ def __init__(self, *args, **kwargs):
33
+ super().__init__(*args, **kwargs)
34
+ self.share_codebook_embeddings = kwargs.pop("share_codebook_embeddings", True)
35
+ self.codebook_size = kwargs.pop("codebook_size", 1024)
36
+ self.num_codebooks = kwargs.pop("num_codebooks", 8)
37
+
38
+
39
+ AutoTokenizer.register(FishTokenizerConfig, fast_tokenizer_class=FishTokenizerFast)
40
+
41
+
42
+ @dataclass(kw_only=True)
43
+ class BasePart:
44
+ pass
45
+
46
+
47
+ @dataclass(kw_only=True)
48
+ class VQPart(BasePart):
49
+ codes: torch.Tensor
50
+
51
+
52
+ @dataclass(kw_only=True)
53
+ class TextPart(BasePart):
54
+ text: str
55
+
56
+
57
+ @dataclass(kw_only=True)
58
+ class MelPart(BasePart):
59
+ mels: torch.Tensor
60
+
61
+
62
+ @dataclass(kw_only=True)
63
+ class EncodedMessage:
64
+ tokens: torch.Tensor
65
+ labels: torch.Tensor
66
+ vq_parts: list[torch.Tensor]
67
+ mel_parts: list[torch.Tensor]
68
+ vq_require_losses: torch.Tensor | None = None
69
+
70
+
71
+ @dataclass(kw_only=True)
72
+ class Message:
73
+ role: Literal["system", "user", "assistant"]
74
+ parts: list[VQPart | TextPart | MelPart] = field(default_factory=list)
75
+ add_im_start: bool = True
76
+ add_im_end: bool = True
77
+ cal_loss: bool = False
78
+
79
+ # By default, ignore the loss of the auto-generated im_start token
80
+ ignore_im_start_loss: bool = True
81
+
82
+ def encode(
83
+ self: "Message",
84
+ tokenizer: AutoTokenizer,
85
+ ) -> EncodedMessage:
86
+ all_tokens = []
87
+ all_labels = []
88
+
89
+ # Multi-modal tokens
90
+ vq_parts = []
91
+ mel_parts = []
92
+
93
+ semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
94
+ [SEMANTIC_TOKEN, MEL_TOKEN]
95
+ )
96
+
97
+ parts = self.parts.copy()
98
+ if self.add_im_start:
99
+ parts.insert(0, TextPart(text=f"<|im_start|>{self.role}\n"))
100
+
101
+ if self.add_im_end:
102
+ parts.append(TextPart(text="<|im_end|>"))
103
+
104
+ for part in parts:
105
+ if isinstance(part, TextPart):
106
+ tokens = tokenizer.encode(
107
+ part.text,
108
+ add_special_tokens=False,
109
+ truncation=False,
110
+ return_tensors="pt",
111
+ ).int()[0]
112
+ elif isinstance(part, VQPart):
113
+ tokens = torch.zeros(part.codes.shape[1], dtype=torch.int) + semantic_id
114
+ codes = part.codes.clone() + 1
115
+
116
+ if getattr(tokenizer, "share_codebook_embeddings", True) is False:
117
+ for i in range(len(codes)):
118
+ codes[i] += tokenizer.codebook_size * i
119
+
120
+ vq_parts.append(codes)
121
+ elif isinstance(part, MelPart):
122
+ tokens = torch.zeros(part.mels.shape[1], dtype=torch.int) + mel_id
123
+ mel_parts.append(part.mels)
124
+ else:
125
+ raise ValueError(f"Unsupported part type: {type(part)}")
126
+
127
+ all_tokens.append(tokens)
128
+ if self.cal_loss:
129
+ all_labels.append(tokens.clone())
130
+ else:
131
+ all_labels.append(torch.full_like(tokens, -100))
132
+
133
+ tokens = torch.cat(all_tokens, dim=0)
134
+ labels = torch.cat(all_labels, dim=0)
135
+ assert tokens.shape == labels.shape
136
+
137
+ if self.ignore_im_start_loss and self.add_im_start:
138
+ labels[: len(all_tokens[0])] = -100
139
+
140
+ return EncodedMessage(
141
+ tokens=tokens,
142
+ labels=labels,
143
+ vq_parts=vq_parts,
144
+ mel_parts=mel_parts,
145
+ )
146
+
147
+
148
+ @dataclass
149
+ class Conversation:
150
+ messages: list[Message]
151
+
152
+ def encode(
153
+ self: "Conversation",
154
+ tokenizer: AutoTokenizer,
155
+ add_shift: bool = True,
156
+ ) -> EncodedMessage:
157
+ # Build the input_ids and labels
158
+ tokens = []
159
+ labels = []
160
+ vq_parts = []
161
+ mel_parts = []
162
+ vq_require_losses = []
163
+
164
+ for message in self.messages:
165
+ encoded = message.encode(
166
+ tokenizer,
167
+ )
168
+ tokens.append(encoded.tokens)
169
+ labels.append(encoded.labels)
170
+ vq_parts.extend(encoded.vq_parts)
171
+ mel_parts.extend(encoded.mel_parts)
172
+ vq_require_losses.extend([message.cal_loss] * len(encoded.vq_parts))
173
+
174
+ tokens = torch.cat(tokens, dim=0)
175
+ labels = torch.cat(labels, dim=0)
176
+ vq_require_losses = torch.tensor(vq_require_losses, dtype=torch.bool)
177
+
178
+ if add_shift:
179
+ tokens = tokens[:-1]
180
+ labels = labels[1:]
181
+
182
+ assert tokens.dtype in [
183
+ torch.int,
184
+ torch.long,
185
+ ], f"Invalid dtype: {tokens.dtype}, conv: {conversation}"
186
+
187
+ return EncodedMessage(
188
+ tokens=tokens,
189
+ labels=labels,
190
+ vq_parts=vq_parts,
191
+ mel_parts=mel_parts,
192
+ vq_require_losses=vq_require_losses,
193
+ )
194
+
195
+ def encode_for_inference(
196
+ self: "Conversation",
197
+ tokenizer: AutoTokenizer,
198
+ num_codebooks: int,
199
+ ) -> EncodedMessage:
200
+ encoded = self.encode(tokenizer, add_shift=False)
201
+ tokens = encoded.tokens
202
+ values = torch.zeros((num_codebooks + 1, len(tokens)), dtype=torch.int)
203
+ values[0] = tokens
204
+
205
+ if encoded.vq_parts is None or len(encoded.vq_parts) == 0:
206
+ return values
207
+
208
+ semantic_id, mel_id = tokenizer.convert_tokens_to_ids(
209
+ [SEMANTIC_TOKEN, MEL_TOKEN]
210
+ )
211
+ vq_parts = encoded.vq_parts
212
+ vq_parts = torch.cat(vq_parts, dim=1)
213
+ values[1:, tokens == semantic_id] = vq_parts
214
+ return values
215
+
216
+ def visualize(self: "Conversation", tokenizer: AutoTokenizer):
217
+ encoded = self.encode(tokenizer, add_shift=False)
218
+
219
+ print_in_blue = lambda x: print("\033[94m" + x + "\033[0m", end="")
220
+ print_in_green = lambda x: print("\033[92m" + x + "\033[0m", end="")
221
+
222
+ for tok, lab in zip(encoded.tokens, encoded.labels):
223
+ val = tokenizer.decode(tok, skip_special_tokens=False)
224
+ if val == "\n":
225
+ val = "\\n\n"
226
+
227
+ if lab == -100:
228
+ print_in_green(val)
229
+ else:
230
+ print_in_blue(val)
231
+
232
+ print()
233
+
234
+
235
+ if __name__ == "__main__":
236
+ message0 = Message(
237
+ role="user",
238
+ parts=[
239
+ TextPart(text="Hello, how are you?"),
240
+ VQPart(codes=torch.zeros((4, 10))),
241
+ ],
242
+ cal_loss=False,
243
+ )
244
+
245
+ message1 = Message(
246
+ role="assistant",
247
+ parts=[TextPart(text="I'm fine, thank you.")],
248
+ cal_loss=True,
249
+ )
250
+ conversation = Conversation([message0, message1])
251
+ tokenizer = AutoTokenizer.from_pretrained("checkpoints/Qwen2-1.5B-Instruct")
252
+ conversation.visualize(tokenizer)
253
+
254
+ encoded = conversation.encode(tokenizer)
255
+ print(encoded)
256
+ print(tokenizer.batch_decode(encoded.tokens))
fish_speech/datasets/concat_repeat.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import bisect
2
+ import random
3
+ from typing import Iterable
4
+
5
+ from torch.utils.data import Dataset, IterableDataset
6
+
7
+
8
+ class ConcatRepeatDataset(Dataset):
9
+ datasets: list[Dataset]
10
+ cumulative_sizes: list[int]
11
+ repeats: list[int]
12
+
13
+ @staticmethod
14
+ def cumsum(sequence, repeats):
15
+ r, s = [], 0
16
+ for dataset, repeat in zip(sequence, repeats):
17
+ l = len(dataset) * repeat
18
+ r.append(l + s)
19
+ s += l
20
+ return r
21
+
22
+ def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
23
+ super().__init__()
24
+
25
+ self.datasets = list(datasets)
26
+ self.repeats = repeats
27
+
28
+ assert len(self.datasets) > 0, "datasets should not be an empty iterable"
29
+ assert len(self.datasets) == len(
30
+ repeats
31
+ ), "datasets and repeats should have the same length"
32
+
33
+ for d in self.datasets:
34
+ assert not isinstance(
35
+ d, IterableDataset
36
+ ), "ConcatRepeatDataset does not support IterableDataset"
37
+
38
+ self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)
39
+
40
+ def __len__(self):
41
+ return self.cumulative_sizes[-1]
42
+
43
+ def __getitem__(self, idx):
44
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
45
+
46
+ if dataset_idx == 0:
47
+ sample_idx = idx
48
+ else:
49
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
50
+
51
+ dataset = self.datasets[dataset_idx]
52
+
53
+ return dataset[sample_idx % len(dataset)]
fish_speech/datasets/protos/text-data.proto ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ syntax = "proto3";
2
+
3
+ package text_data;
4
+
5
+ message Semantics {
6
+ repeated uint32 values = 1;
7
+ }
8
+
9
+ message Sentence {
10
+ repeated string texts = 1;
11
+ repeated Semantics semantics = 3;
12
+ }
13
+
14
+ message TextData {
15
+ string source = 1;
16
+ string name = 2;
17
+ repeated Sentence sentences = 4;
18
+ }
19
+
20
+ message SampledData {
21
+ string source = 1;
22
+ string name = 2;
23
+ repeated Sentence samples = 3;
24
+ }
fish_speech/datasets/protos/text_data_pb2.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: text-data.proto
4
+ # Protobuf Python Version: 4.25.1
5
+ """Generated protocol buffer code."""
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+
11
+ # @@protoc_insertion_point(imports)
12
+
13
+ _sym_db = _symbol_database.Default()
14
+
15
+
16
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
17
+ b'\n\x0ftext-data.proto\x12\ttext_data"\x1b\n\tSemantics\x12\x0e\n\x06values\x18\x01 \x03(\r"B\n\x08Sentence\x12\r\n\x05texts\x18\x01 \x03(\t\x12\'\n\tsemantics\x18\x03 \x03(\x0b\x32\x14.text_data.Semantics"P\n\x08TextData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12&\n\tsentences\x18\x04 \x03(\x0b\x32\x13.text_data.Sentence"Q\n\x0bSampledData\x12\x0e\n\x06source\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12$\n\x07samples\x18\x03 \x03(\x0b\x32\x13.text_data.Sentenceb\x06proto3'
18
+ )
19
+
20
+ _globals = globals()
21
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
22
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "text_data_pb2", _globals)
23
+ if _descriptor._USE_C_DESCRIPTORS == False:
24
+ DESCRIPTOR._options = None
25
+ _globals["_SEMANTICS"]._serialized_start = 30
26
+ _globals["_SEMANTICS"]._serialized_end = 57
27
+ _globals["_SENTENCE"]._serialized_start = 59
28
+ _globals["_SENTENCE"]._serialized_end = 125
29
+ _globals["_TEXTDATA"]._serialized_start = 127
30
+ _globals["_TEXTDATA"]._serialized_end = 207
31
+ _globals["_SAMPLEDDATA"]._serialized_start = 209
32
+ _globals["_SAMPLEDDATA"]._serialized_end = 290
33
+ # @@protoc_insertion_point(module_scope)
fish_speech/datasets/protos/text_data_stream.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import struct
2
+
3
+ from .text_data_pb2 import TextData
4
+
5
+
6
+ def read_pb_stream(f):
7
+ while True:
8
+ buf = f.read(4)
9
+ if len(buf) == 0:
10
+ break
11
+ size = struct.unpack("I", buf)[0]
12
+ buf = f.read(size)
13
+ text_data = TextData()
14
+ text_data.ParseFromString(buf)
15
+ yield text_data
16
+
17
+
18
+ def write_pb_stream(f, text_data):
19
+ buf = text_data.SerializeToString()
20
+ f.write(struct.pack("I", len(buf)))
21
+ f.write(buf)
22
+
23
+
24
+ def pack_pb_stream(text_data):
25
+ buf = text_data.SerializeToString()
26
+ return struct.pack("I", len(buf)) + buf
27
+
28
+
29
+ def split_pb_stream(f):
30
+ while True:
31
+ head = f.read(4)
32
+ if len(head) == 0:
33
+ break
34
+ size = struct.unpack("I", head)[0]
35
+ buf = f.read(size)
36
+ yield head + buf
fish_speech/datasets/semantic.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from dataclasses import dataclass
3
+ from itertools import chain
4
+ from pathlib import Path
5
+ from random import Random
6
+ from typing import Optional, Union
7
+
8
+ import numpy as np
9
+ import pyarrow.parquet as pq
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from datasets.download.streaming_download_manager import xopen
13
+ from huggingface_hub import HfApi
14
+ from lightning import LightningDataModule
15
+ from torch.distributed import get_rank, get_world_size, is_initialized
16
+ from torch.utils.data import DataLoader, IterableDataset, get_worker_info
17
+ from transformers import AutoTokenizer
18
+
19
+ from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
20
+ from fish_speech.datasets.protos.text_data_pb2 import SampledData
21
+ from fish_speech.datasets.protos.text_data_stream import read_pb_stream
22
+ from fish_speech.text.clean import clean_text
23
+ from fish_speech.utils import RankedLogger
24
+ from fish_speech.utils.braceexpand import braceexpand
25
+
26
+ log = RankedLogger(__name__, rank_zero_only=True)
27
+
28
+
29
+ def split_by_rank_worker(files):
30
+ # We need to know the total number of devices
31
+ # to split the data properly
32
+
33
+ total_devices = 1
34
+ if is_initialized():
35
+ total_devices = get_world_size()
36
+
37
+ worker_info = get_worker_info()
38
+ if worker_info is not None:
39
+ total_devices *= worker_info.num_workers
40
+
41
+ if len(files) < total_devices:
42
+ # Repeat the files N times to match the number of devices
43
+ files = files * (total_devices // len(files) + 1)
44
+
45
+ # DDP
46
+ if is_initialized():
47
+ files = files[get_rank() :: get_world_size()]
48
+
49
+ # Split by worker
50
+ if worker_info is not None:
51
+ files = files[worker_info.id :: worker_info.num_workers]
52
+
53
+ return files
54
+
55
+
56
+ class AutoTextSemanticInstructionDataset(IterableDataset):
57
+ """
58
+ Auto Augment Dataset by Speaker
59
+
60
+ 1. Random concatenate multiple sentences from the same speaker to form a longer sentence
61
+ 2. Automatically normalize the text
62
+
63
+ For interactive mode, we use the following format (multiple sequences):
64
+ <s> [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST] </s>
65
+
66
+ For non-interactive mode, we use the following format (one long sequence):
67
+ <s> [INST] text [/INST] ... </s>
68
+ """
69
+
70
+ def __init__(
71
+ self,
72
+ proto_files: list[str],
73
+ seed: int = 42,
74
+ interactive_prob: float = 0.5,
75
+ max_length: int = 1024,
76
+ tokenizer: AutoTokenizer = None,
77
+ use_speaker: bool | float = True,
78
+ causal: bool = True,
79
+ num_codebooks: Optional[int] = None,
80
+ skip_text_prob: float = 0.0,
81
+ ):
82
+ """
83
+ Args:
84
+ proto_files: proto buf files if using local data
85
+ seed: random seed
86
+ interactive_prob: probability to use interactive mode
87
+ max_length: max length of the text
88
+ tokenizer: tokenizer
89
+ use_speaker: include speaker information in the prompt
90
+ causal: use causal sampling when using local data, disable will lead to random sampling
91
+ num_codebooks: number of codebooks, if None, it will be automatically detected
92
+ skip_text_prob: probability to skip the text (audio only), this only applies to interactive mode
93
+ """
94
+
95
+ super().__init__()
96
+
97
+ assert 0 <= interactive_prob <= 1, "interactive_prob must be in [0, 1]"
98
+
99
+ self.seed = seed
100
+ self.max_length = max_length
101
+ self.tokenizer = tokenizer
102
+ self.interactive_prob = interactive_prob
103
+ self.use_speaker = use_speaker
104
+ self.proto_files = proto_files
105
+ self.causal = causal
106
+ self.num_codebooks = num_codebooks
107
+ self.skip_text_prob = skip_text_prob
108
+
109
+ self.semantic_token_id = self.tokenizer.convert_tokens_to_ids("<|semantic|>")
110
+ self.groups = None
111
+
112
+ def init_mock_data_server(self):
113
+ if self.groups is not None:
114
+ return
115
+
116
+ # Expand the proto files
117
+ expanded_proto_files = []
118
+ for filename in self.proto_files:
119
+ for i in braceexpand(filename):
120
+ i = Path(i)
121
+ if i.is_file():
122
+ expanded_proto_files.append(i)
123
+ elif i.is_dir():
124
+ expanded_proto_files.extend(i.rglob("*.proto"))
125
+ expanded_proto_files.extend(i.rglob("*.protos"))
126
+ else:
127
+ raise ValueError(f"{i} is not a file or directory")
128
+
129
+ expanded_proto_files = sorted(expanded_proto_files)
130
+ Random(self.seed).shuffle(expanded_proto_files)
131
+
132
+ self.groups = []
133
+ shard_proto_files = split_by_rank_worker(expanded_proto_files)
134
+ log.info(
135
+ f"Reading {len(shard_proto_files)} / {len(expanded_proto_files)} files"
136
+ )
137
+
138
+ count = 0
139
+ for filename in shard_proto_files:
140
+ with open(filename, "rb") as f:
141
+ for text_data in read_pb_stream(f):
142
+ self.groups.append(text_data)
143
+ count += 1
144
+
145
+ log.info(f"Read total {count} groups of data")
146
+
147
+ # Shuffle the lines
148
+ Random(self.seed).shuffle(self.groups)
149
+ self.group_weights = [len(i.sentences) for i in self.groups]
150
+
151
+ def __iter__(self):
152
+ while True:
153
+ yield self.augment()
154
+
155
+ def tokenize_sentence(self, sentence: str):
156
+ sentence = clean_text(sentence)
157
+ tokens = self.tokenizer.encode(
158
+ f"{sentence}",
159
+ max_length=10**6,
160
+ add_special_tokens=False,
161
+ truncation=False,
162
+ )
163
+ return sentence, len(tokens)
164
+
165
+ def sample_data(self):
166
+ if self.groups is None:
167
+ self.init_mock_data_server()
168
+
169
+ # Shuffle unique lines, estimate that each sample is at least 20 tokens
170
+ num_samples = self.max_length // 20
171
+
172
+ # choice group based on their number of samples
173
+ group = random.choices(self.groups, weights=self.group_weights, k=1)[0]
174
+
175
+ if self.causal:
176
+ # Sample in order
177
+ if num_samples >= len(group.sentences):
178
+ samples = group.sentences
179
+ else:
180
+ begin = random.randint(0, len(group.sentences) - num_samples)
181
+ samples = group.sentences[begin : begin + num_samples]
182
+ else:
183
+ samples = random.choices(
184
+ group.sentences, k=min(num_samples, len(group.sentences))
185
+ )
186
+
187
+ return SampledData(
188
+ source=group.source,
189
+ name=group.name,
190
+ samples=samples,
191
+ )
192
+
193
+ def augment(self):
194
+ final_text, final_semantic = [], []
195
+ response = self.sample_data()
196
+ if len(response.samples) == 0:
197
+ # Invalid group
198
+ return None
199
+
200
+ samples = list(response.samples)
201
+ idx = 0
202
+ use_interactive = random.random() < self.interactive_prob
203
+
204
+ if use_interactive is False:
205
+ # Random sample based on speaker using a truncated normal distribution
206
+ a = torch.tensor([0], dtype=torch.float32)
207
+ torch.nn.init.trunc_normal_(
208
+ a,
209
+ mean=self.max_length // 2,
210
+ std=self.max_length // 4,
211
+ a=10,
212
+ b=self.max_length,
213
+ )
214
+ remaining_tokens = a.long().item() - 4
215
+ else:
216
+ remaining_tokens = self.max_length
217
+
218
+ # Use speaker
219
+ if isinstance(self.use_speaker, float):
220
+ use_speaker = random.random() < self.use_speaker
221
+ else:
222
+ use_speaker = self.use_speaker
223
+
224
+ all_tokens, all_labels = [], []
225
+ while remaining_tokens > 0 and len(samples) > 0:
226
+ sentence = samples.pop(0)
227
+
228
+ text = random.choice(sentence.texts)
229
+ text, length = self.tokenize_sentence(text)
230
+ remaining_tokens -= length + len(sentence.semantics[0].values)
231
+
232
+ if use_interactive is False:
233
+ final_text.append(text)
234
+ final_semantic.append(sentence.semantics)
235
+ else:
236
+ # For interactive mode, we only apply speaker for the first sentence
237
+ # [INST] [SPK: speaker] text [/INST] ... [INST] text [/INST]
238
+ tokens, labels = self.pack_sentences(
239
+ sentences=[text],
240
+ semantics=[sentence.semantics],
241
+ speaker=response.name if use_speaker else None,
242
+ skip_text=random.random() < self.skip_text_prob,
243
+ )
244
+
245
+ all_tokens.append(tokens)
246
+ all_labels.append(labels)
247
+
248
+ idx += 1
249
+
250
+ if use_interactive is False:
251
+ tokens, labels = self.pack_sentences(
252
+ final_text,
253
+ semantics=final_semantic,
254
+ speaker=response.name if use_speaker else None,
255
+ )
256
+ all_tokens.append(tokens)
257
+ all_labels.append(labels)
258
+
259
+ tokens = torch.cat(all_tokens, dim=1)
260
+ labels = torch.cat(all_labels, dim=1)
261
+
262
+ # Verify that the length is correct
263
+ assert tokens.size(1) == labels.size(1), f"{tokens.size(1)} != {labels.size(1)}"
264
+
265
+ data = {"tokens": tokens, "labels": labels}
266
+
267
+ return data
268
+
269
+ def pack_sentences(
270
+ self,
271
+ sentences: list[str],
272
+ semantics: list,
273
+ speaker: Optional[str] = None,
274
+ skip_text: bool = False,
275
+ ):
276
+ if speaker is None:
277
+ speaker = "assistant"
278
+
279
+ cated_sentences = " ".join(sentences)
280
+ if skip_text:
281
+ cated_sentences = "<|skip_text|>"
282
+
283
+ final_text = "<|im_start|>user\n" + cated_sentences + "<|im_end|>"
284
+ final_text = final_text + f"<|im_start|>{speaker}\n"
285
+
286
+ encoded = self.tokenizer.encode(
287
+ final_text,
288
+ add_special_tokens=False,
289
+ truncation=False,
290
+ max_length=10**6,
291
+ )
292
+ semantic_length = sum([len(i[0].values) for i in semantics])
293
+ prompt_length = len(encoded)
294
+ num_codebooks = (
295
+ len(semantics[0]) if self.num_codebooks is None else self.num_codebooks
296
+ )
297
+
298
+ # Pack the tokens and semantics (add <s> and </s> to semantic tokens)
299
+ tokens = (
300
+ encoded
301
+ + [self.semantic_token_id] * semantic_length
302
+ + self.tokenizer.convert_tokens_to_ids(["<|im_end|>"])
303
+ )
304
+
305
+ # Codebook bos/padding: 0, eos: 1
306
+ codes = [[CODEBOOK_PAD_TOKEN_ID] * prompt_length for _ in range(num_codebooks)]
307
+ for segment in semantics:
308
+ for book_idx, book in zip(range(num_codebooks), segment):
309
+ for j in book.values:
310
+ codes[book_idx].append(int(j) + 1)
311
+
312
+ for book in codes:
313
+ book.extend([CODEBOOK_PAD_TOKEN_ID] * 1)
314
+
315
+ tokens = [tokens] + codes
316
+
317
+ tokens = torch.tensor(tokens, dtype=torch.long)
318
+ labels = tokens.clone()
319
+
320
+ if skip_text:
321
+ # If text is not provided, the sentence is used for condition only, all labels are -100
322
+ torch.fill_(labels, -100)
323
+ return tokens, labels
324
+
325
+ # Mask out the <s> tokens for semantic, predict semantic tokens only
326
+ # Since we don't mask out the input tokens, the language modeling still works
327
+ labels[1:, :prompt_length] = -100
328
+
329
+ tokens = tokens[:, :-1]
330
+ labels = labels[:, 1:]
331
+
332
+ # Verify the padding is correct, and the last token is eos
333
+ assert (tokens[1:, :prompt_length] == CODEBOOK_PAD_TOKEN_ID).all()
334
+ assert (labels[1:, -1:] == CODEBOOK_PAD_TOKEN_ID).all()
335
+
336
+ return tokens, labels
337
+
338
+
339
+ @dataclass
340
+ class TextDataCollator:
341
+ tokenizer: AutoTokenizer
342
+ max_length: int = 1024
343
+
344
+ def __call__(self, examples):
345
+ if "negative_tokens" in examples:
346
+ positive_examples = []
347
+ negative_examples = []
348
+
349
+ for i in examples:
350
+ positive_examples.append(
351
+ {
352
+ "tokens": i["tokens"],
353
+ "labels": i["labels"],
354
+ }
355
+ )
356
+ negative_examples.append(
357
+ {
358
+ "tokens": i["negative_tokens"],
359
+ "labels": i["negative_labels"],
360
+ }
361
+ )
362
+
363
+ examples = positive_examples + negative_examples
364
+
365
+ return self.batchify(examples)
366
+
367
+ def batchify(self, examples, tokens_key="tokens", labels_key="labels"):
368
+ tokens, attention_masks, labels = [], [], []
369
+
370
+ # Calculate the max length
371
+ max_tokens_length = 0
372
+ for example in examples:
373
+ max_tokens_length = max(max_tokens_length, example[tokens_key].size(1))
374
+ max_tokens_length = min(max_tokens_length, self.max_length)
375
+
376
+ for example in examples:
377
+ _tokens = example[tokens_key][:, :max_tokens_length]
378
+ _labels = example[labels_key][:, :max_tokens_length]
379
+ _attention_mask = torch.ones((max_tokens_length,), dtype=torch.bool)
380
+ tokens_length = _tokens.size(1)
381
+ _attention_mask[:tokens_length] = False
382
+
383
+ assert tokens_length == _labels.size(
384
+ 1
385
+ ), f"{tokens_length} != {_labels.size(1)}"
386
+
387
+ if tokens_length < max_tokens_length:
388
+ _tokens = F.pad(
389
+ _tokens,
390
+ (0, max_tokens_length - tokens_length),
391
+ value=self.tokenizer.eos_token_id,
392
+ )
393
+ _tokens[1:, tokens_length:] = CODEBOOK_PAD_TOKEN_ID
394
+ _labels = F.pad(
395
+ _labels, (0, max_tokens_length - _labels.size(1)), value=-100
396
+ )
397
+
398
+ tokens.append(_tokens)
399
+ attention_masks.append(_attention_mask)
400
+ labels.append(_labels)
401
+
402
+ tokens = torch.stack(tokens, dim=0)
403
+ attention_masks = torch.stack(attention_masks, dim=0)
404
+ labels = torch.stack(labels, dim=0)
405
+
406
+ return {
407
+ "inputs": tokens,
408
+ "attention_masks": attention_masks,
409
+ "labels": labels,
410
+ }
411
+
412
+
413
+ class InterleaveDataset(IterableDataset):
414
+ def __init__(
415
+ self,
416
+ datasets: list[IterableDataset],
417
+ probabilities: list[float],
418
+ seed: int = 42,
419
+ ):
420
+ super().__init__()
421
+
422
+ self.datasets = datasets
423
+ self.probabilities = probabilities
424
+ self.seed = seed
425
+
426
+ def __iter__(self):
427
+ rng = np.random.default_rng(self.seed)
428
+ dataset_iterators = [iter(dataset) for dataset in self.datasets]
429
+
430
+ while True:
431
+ # Random choice one
432
+ dataset_idx = rng.choice(len(self.datasets), p=self.probabilities)
433
+ dataset_iterator = dataset_iterators[dataset_idx]
434
+
435
+ try:
436
+ yield next(dataset_iterator)
437
+ except StopIteration:
438
+ # Exhausted, create a new iterator
439
+ dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
440
+ yield next(dataset_iterators[dataset_idx])
441
+
442
+
443
+ class SemanticDataModule(LightningDataModule):
444
+ def __init__(
445
+ self,
446
+ train_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
447
+ val_dataset: Union[AutoTextSemanticInstructionDataset, InterleaveDataset],
448
+ batch_size: int = 32,
449
+ tokenizer: AutoTokenizer = None,
450
+ max_length: int = 1024,
451
+ num_workers: int = 4,
452
+ ):
453
+ super().__init__()
454
+
455
+ self.train_dataset = train_dataset
456
+ self.val_dataset = val_dataset
457
+ self.batch_size = batch_size
458
+ self.tokenizer = tokenizer
459
+ self.max_length = max_length
460
+ self.num_workers = num_workers
461
+
462
+ def train_dataloader(self):
463
+ return DataLoader(
464
+ self.train_dataset,
465
+ batch_size=self.batch_size,
466
+ collate_fn=TextDataCollator(self.tokenizer, self.max_length),
467
+ num_workers=self.num_workers,
468
+ persistent_workers=True,
469
+ )
470
+
471
+ def val_dataloader(self):
472
+ return DataLoader(
473
+ self.val_dataset,
474
+ batch_size=self.batch_size,
475
+ collate_fn=TextDataCollator(self.tokenizer, self.max_length),
476
+ num_workers=self.num_workers,
477
+ persistent_workers=True,
478
+ )
479
+
480
+
481
+ if __name__ == "__main__":
482
+ from tqdm import tqdm
483
+
484
+ ds = AutoTextSemanticInstructionDataset(
485
+ ["data/protos"],
486
+ tokenizer=AutoTokenizer.from_pretrained("fishaudio/fish-speech-1"),
487
+ use_speaker=False,
488
+ interactive_prob=1.0,
489
+ skip_text_prob=0.5,
490
+ )
491
+
492
+ for i in ds:
493
+ print(ds.tokenizer.decode(i["tokens"][0], skip_special_tokens=False))
494
+ # i["labels"][0][i["labels"][0] == -100] = 0
495
+ # print(ds.tokenizer.decode(i["labels"][0], skip_special_tokens=False))
496
+ break
fish_speech/datasets/vqgan.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import librosa
6
+ import numpy as np
7
+ import torch
8
+ from lightning import LightningDataModule
9
+ from torch.utils.data import DataLoader, Dataset
10
+
11
+ from fish_speech.utils import RankedLogger
12
+
13
+ logger = RankedLogger(__name__, rank_zero_only=False)
14
+
15
+
16
+ class VQGANDataset(Dataset):
17
+ def __init__(
18
+ self,
19
+ filelist: str,
20
+ sample_rate: int = 32000,
21
+ hop_length: int = 640,
22
+ slice_frames: Optional[int] = None,
23
+ ):
24
+ super().__init__()
25
+
26
+ filelist = Path(filelist)
27
+ root = filelist.parent
28
+
29
+ self.files = [
30
+ root / line.strip()
31
+ for line in filelist.read_text(encoding="utf-8").splitlines()
32
+ if line.strip()
33
+ ]
34
+ self.sample_rate = sample_rate
35
+ self.hop_length = hop_length
36
+ self.slice_frames = slice_frames
37
+
38
+ def __len__(self):
39
+ return len(self.files)
40
+
41
+ def get_item(self, idx):
42
+ file = self.files[idx]
43
+
44
+ audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)
45
+
46
+ # Slice audio and features
47
+ if (
48
+ self.slice_frames is not None
49
+ and audio.shape[0] > self.slice_frames * self.hop_length
50
+ ):
51
+ start = np.random.randint(
52
+ 0, audio.shape[0] - self.slice_frames * self.hop_length
53
+ )
54
+ audio = audio[start : start + self.slice_frames * self.hop_length]
55
+
56
+ if len(audio) == 0:
57
+ return None
58
+
59
+ max_value = np.abs(audio).max()
60
+ if max_value > 1.0:
61
+ audio = audio / max_value
62
+
63
+ return {
64
+ "audio": torch.from_numpy(audio),
65
+ }
66
+
67
+ def __getitem__(self, idx):
68
+ try:
69
+ return self.get_item(idx)
70
+ except Exception as e:
71
+ import traceback
72
+
73
+ traceback.print_exc()
74
+ logger.error(f"Error loading {self.files[idx]}: {e}")
75
+ return None
76
+
77
+
78
+ @dataclass
79
+ class VQGANCollator:
80
+ def __call__(self, batch):
81
+ batch = [x for x in batch if x is not None]
82
+
83
+ audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
84
+ audio_maxlen = audio_lengths.max()
85
+
86
+ # Rounds up to nearest multiple of 2 (audio_lengths)
87
+ audios = []
88
+ for x in batch:
89
+ audios.append(
90
+ torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
91
+ )
92
+
93
+ return {
94
+ "audios": torch.stack(audios),
95
+ "audio_lengths": audio_lengths,
96
+ }
97
+
98
+
99
+ class VQGANDataModule(LightningDataModule):
100
+ def __init__(
101
+ self,
102
+ train_dataset: VQGANDataset,
103
+ val_dataset: VQGANDataset,
104
+ batch_size: int = 32,
105
+ num_workers: int = 4,
106
+ val_batch_size: Optional[int] = None,
107
+ ):
108
+ super().__init__()
109
+
110
+ self.train_dataset = train_dataset
111
+ self.val_dataset = val_dataset
112
+ self.batch_size = batch_size
113
+ self.val_batch_size = val_batch_size or batch_size
114
+ self.num_workers = num_workers
115
+
116
+ def train_dataloader(self):
117
+ return DataLoader(
118
+ self.train_dataset,
119
+ batch_size=self.batch_size,
120
+ collate_fn=VQGANCollator(),
121
+ num_workers=self.num_workers,
122
+ shuffle=True,
123
+ persistent_workers=True,
124
+ )
125
+
126
+ def val_dataloader(self):
127
+ return DataLoader(
128
+ self.val_dataset,
129
+ batch_size=self.val_batch_size,
130
+ collate_fn=VQGANCollator(),
131
+ num_workers=self.num_workers,
132
+ persistent_workers=True,
133
+ )
134
+
135
+
136
+ if __name__ == "__main__":
137
+ dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
138
+ dataloader = DataLoader(
139
+ dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
140
+ )
141
+
142
+ for batch in dataloader:
143
+ print(batch["audios"].shape)
144
+ print(batch["features"].shape)
145
+ print(batch["audio_lengths"])
146
+ print(batch["feature_lengths"])
147
+ break
fish_speech/i18n/README.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## i18n Folder Attribution
2
+
3
+ The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:
4
+
5
+ ### fish_speech/i18n/core.py
6
+
7
+ **Related code from RVC:**
8
+ [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)
9
+
10
+ **Initial commit:**
11
+ add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)
12
+
13
+ **Initial author:**
14
+ [@L4Ph](https://github.com/L4Ph)
15
+
16
+ ### fish_speech/i18n/scan.py
17
+
18
+ **Related code from RVC:**
19
+ [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)
20
+
21
+ **Initial commit:**
22
+ File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)
23
+
24
+ **Initial author:**
25
+ [@towzeur](https://github.com/towzeur)
26
+
27
+ We appreciate the contributions of the RVC project and its authors.
fish_speech/i18n/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .core import i18n
2
+
3
+ __all__ = ["i18n"]
fish_speech/i18n/core.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import locale
3
+ from pathlib import Path
4
+
5
+ I18N_FILE_PATH = Path(__file__).parent / "locale"
6
+ DEFAULT_LANGUAGE = "en_US"
7
+
8
+
9
+ def load_language_list(language):
10
+ with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f:
11
+ language_list = json.load(f)
12
+
13
+ return language_list
14
+
15
+
16
+ class I18nAuto:
17
+ def __init__(self):
18
+ i18n_file = Path(".locale")
19
+
20
+ if i18n_file.exists():
21
+ with open(i18n_file, "r", encoding="utf-8") as f:
22
+ language = f.read().strip()
23
+ else:
24
+ # getlocale can't identify the system's language ((None, None))
25
+ language = locale.getdefaultlocale()[0]
26
+
27
+ if (I18N_FILE_PATH / f"{language}.json").exists() is False:
28
+ language = DEFAULT_LANGUAGE
29
+
30
+ self.language = language
31
+ self.language_map = load_language_list(language)
32
+
33
+ def __call__(self, key):
34
+ return self.language_map.get(key, key)
35
+
36
+ def __repr__(self):
37
+ return "Use Language: " + self.language
38
+
39
+
40
+ i18n = I18nAuto()
fish_speech/i18n/locale/en_US.json ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
5
+ "Accumulate Gradient Batches": "Accumulate Gradient Batches",
6
+ "Add to Processing Area": "Add to Processing Area",
7
+ "Added path successfully!": "Added path successfully!",
8
+ "Advanced Config": "Advanced Config",
9
+ "Base LLAMA Model": "Base LLAMA Model",
10
+ "Batch Inference": "Batch Inference",
11
+ "Batch Size": "Batch Size",
12
+ "Changing with the Model Path": "Changing with the Model Path",
13
+ "Chinese": "Chinese",
14
+ "Compile Model": "Compile Model",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
16
+ "Copy": "Copy",
17
+ "Data Preprocessing": "Data Preprocessing",
18
+ "Data Preprocessing Path": "Data Preprocessing Path",
19
+ "Data Source": "Data Source",
20
+ "Decoder Model Config": "Decoder Model Config",
21
+ "Decoder Model Path": "Decoder Model Path",
22
+ "Disabled": "Disabled",
23
+ "Enable Reference Audio": "Enable Reference Audio",
24
+ "English": "English",
25
+ "Error Message": "Error Message",
26
+ "File Preprocessing": "File Preprocessing",
27
+ "Generate": "Generate",
28
+ "Generated Audio": "Generated Audio",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
30
+ "Infer interface is closed": "Infer interface is closed",
31
+ "Inference Configuration": "Inference Configuration",
32
+ "Inference Server Configuration": "Inference Server Configuration",
33
+ "Inference Server Error": "Inference Server Error",
34
+ "Inferring interface is launched at {}": "Inferring interface is launched at {}",
35
+ "Initial Learning Rate": "Initial Learning Rate",
36
+ "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
37
+ "Input Text": "Input Text",
38
+ "Invalid path: {}": "Invalid path: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
40
+ "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
41
+ "Japanese": "Japanese",
42
+ "LLAMA Configuration": "LLAMA Configuration",
43
+ "LLAMA Model Config": "LLAMA Model Config",
44
+ "LLAMA Model Path": "LLAMA Model Path",
45
+ "Labeling Device": "Labeling Device",
46
+ "LoRA Model to be merged": "LoRA Model to be merged",
47
+ "Maximum Audio Duration": "Maximum Audio Duration",
48
+ "Maximum Length per Sample": "Maximum Length per Sample",
49
+ "Maximum Training Steps": "Maximum Training Steps",
50
+ "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
51
+ "Merge": "Merge",
52
+ "Merge LoRA": "Merge LoRA",
53
+ "Merge successfully": "Merge successfully",
54
+ "Minimum Audio Duration": "Minimum Audio Duration",
55
+ "Model Output Path": "Model Output Path",
56
+ "Model Size": "Model Size",
57
+ "Move": "Move",
58
+ "Move files successfully": "Move files successfully",
59
+ "No audio generated, please check the input text.": "No audio generated, please check the input text.",
60
+ "No selected options": "No selected options",
61
+ "Number of Workers": "Number of Workers",
62
+ "Open Inference Server": "Open Inference Server",
63
+ "Open Labeler WebUI": "Open Labeler WebUI",
64
+ "Open Tensorboard": "Open Tensorboard",
65
+ "Opened labeler in browser": "Opened labeler in browser",
66
+ "Optional Label Language": "Optional Label Language",
67
+ "Optional online ver": "Optional online ver",
68
+ "Output Path": "Output Path",
69
+ "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
70
+ "Precision": "Precision",
71
+ "Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
72
+ "Put your text here.": "Put your text here.",
73
+ "Reference Audio": "Reference Audio",
74
+ "Reference Text": "Reference Text",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.",
76
+ "Remove Selected Data": "Remove Selected Data",
77
+ "Removed path successfully!": "Removed path successfully!",
78
+ "Repetition Penalty": "Repetition Penalty",
79
+ "Save model every n steps": "Save model every n steps",
80
+ "Select LLAMA ckpt": "Select LLAMA ckpt",
81
+ "Select VITS ckpt": "Select VITS ckpt",
82
+ "Select VQGAN ckpt": "Select VQGAN ckpt",
83
+ "Select source file processing method": "Select source file processing method",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
85
+ "Selected: {}": "Selected: {}",
86
+ "Speaker": "Speaker",
87
+ "Speaker is identified by the folder name": "Speaker is identified by the folder name",
88
+ "Start Training": "Start Training",
89
+ "Streaming Audio": "Streaming Audio",
90
+ "Streaming Generate": "Streaming Generate",
91
+ "Tensorboard Host": "Tensorboard Host",
92
+ "Tensorboard Log Path": "Tensorboard Log Path",
93
+ "Tensorboard Port": "Tensorboard Port",
94
+ "Tensorboard interface is closed": "Tensorboard interface is closed",
95
+ "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
96
+ "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
98
+ "Training Configuration": "Training Configuration",
99
+ "Training Error": "Training Error",
100
+ "Training stopped": "Training stopped",
101
+ "Type name of the speaker": "Type name of the speaker",
102
+ "Type the path or select from the dropdown": "Type the path or select from the dropdown",
103
+ "Use LoRA": "Use LoRA",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
105
+ "Use filelist": "Use filelist",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
107
+ "VITS Configuration": "VITS Configuration",
108
+ "VQGAN Configuration": "VQGAN Configuration",
109
+ "Validation Batch Size": "Validation Batch Size",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
112
+ "WebUI Host": "WebUI Host",
113
+ "WebUI Port": "WebUI Port",
114
+ "Whisper Model": "Whisper Model",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
117
+ "latest": "latest",
118
+ "new": "new",
119
+ "Realtime Transform Text": "Realtime Transform Text",
120
+ "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
121
+ "Text Normalization": "Text Normalization",
122
+ "Select Example Audio": "Select Example Audio"
123
+ }
fish_speech/i18n/locale/es_ES.json ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
5
+ "Accumulate Gradient Batches": "Acumular lotes de gradientes",
6
+ "Add to Processing Area": "Agregar al Área de Procesamiento",
7
+ "Added path successfully!": "¡Ruta agregada exitosamente!",
8
+ "Advanced Config": "Configuración Avanzada",
9
+ "Base LLAMA Model": "Modelo Base LLAMA",
10
+ "Batch Inference": "Inferencia por Lote",
11
+ "Batch Size": "Tamaño del Lote",
12
+ "Changing with the Model Path": "Cambiando con la Ruta del Modelo",
13
+ "Chinese": "Chino",
14
+ "Compile Model": "Compilar Modelo",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
16
+ "Copy": "Copiar",
17
+ "Data Preprocessing": "Preprocesamiento de Datos",
18
+ "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
19
+ "Data Source": "Fuente de Datos",
20
+ "Decoder Model Config": "Configuración del modelo decodificador",
21
+ "Decoder Model Path": "Ruta del modelo decodificador",
22
+ "Disabled": "Desactivado",
23
+ "Enable Reference Audio": "Habilitar Audio de Referencia",
24
+ "English": "Inglés",
25
+ "Error Message": "Mensaje de Error",
26
+ "File Preprocessing": "Preprocesamiento de Archivos",
27
+ "Generate": "Generar",
28
+ "Generated Audio": "Audio Generado",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
30
+ "Infer interface is closed": "La interfaz de inferencia está cerrada",
31
+ "Inference Configuration": "Configuración de Inferencia",
32
+ "Inference Server Configuration": "Configuración del Servidor de Inferencia",
33
+ "Inference Server Error": "Error del Servidor de Inferencia",
34
+ "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
35
+ "Initial Learning Rate": "Tasa de Aprendizaje Inicial",
36
+ "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
37
+ "Input Text": "Texto de Entrada",
38
+ "Invalid path: {}": "Ruta inválida: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
40
+ "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
41
+ "Japanese": "Japonés",
42
+ "LLAMA Configuration": "Configuración de LLAMA",
43
+ "LLAMA Model Config": "Configuración del Modelo LLAMA",
44
+ "LLAMA Model Path": "Ruta del Modelo LLAMA",
45
+ "Labeling Device": "Dispositivo de Etiquetado",
46
+ "LoRA Model to be merged": "Modelo LoRA a fusionar",
47
+ "Maximum Audio Duration": "Duración máxima de audio",
48
+ "Maximum Length per Sample": "Longitud Máxima por Muestra",
49
+ "Maximum Training Steps": "Pasos Máximos de Entrenamiento",
50
+ "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
51
+ "Merge": "Fusionar",
52
+ "Merge LoRA": "Fusionar LoRA",
53
+ "Merge successfully": "Fusionado exitosamente",
54
+ "Minimum Audio Duration": "Duración mínima de audio",
55
+ "Model Output Path": "Ruta de Salida del Modelo",
56
+ "Model Size": "Tamaño del Modelo",
57
+ "Move": "Mover",
58
+ "Move files successfully": "Archivos movidos exitosamente",
59
+ "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
60
+ "No selected options": "No hay opciones seleccionadas",
61
+ "Number of Workers": "Número de Trabajadores",
62
+ "Open Inference Server": "Abrir Servidor de Inferencia",
63
+ "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
64
+ "Open Tensorboard": "Abrir Tensorboard",
65
+ "Opened labeler in browser": "Se abrió el etiquetador en el navegador",
66
+ "Optional Label Language": "Idioma de Etiquetado Opcional",
67
+ "Optional online ver": "Ver en línea opcional",
68
+ "Output Path": "Ruta de Salida",
69
+ "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente",
70
+ "Precision": "Precisión",
71
+ "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
72
+ "Put your text here.": "Ponga su texto aquí.",
73
+ "Reference Audio": "Audio de Referencia",
74
+ "Reference Text": "Texto de Referencia",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
76
+ "Remove Selected Data": "Eliminar Datos Seleccionados",
77
+ "Removed path successfully!": "¡Ruta eliminada exitosamente!",
78
+ "Repetition Penalty": "Penalización por Repetición",
79
+ "Save model every n steps": "Guardar modelo cada n pasos",
80
+ "Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
81
+ "Select VITS ckpt": "Seleccionar punto de control VITS",
82
+ "Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
83
+ "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
85
+ "Selected: {}": "Seleccionado: {}",
86
+ "Speaker": "Hablante",
87
+ "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
88
+ "Start Training": "Iniciar Entrenamiento",
89
+ "Streaming Audio": "transmisión de audio",
90
+ "Streaming Generate": "síntesis en flujo",
91
+ "Tensorboard Host": "Host de Tensorboard",
92
+ "Tensorboard Log Path": "Ruta de Registro de Tensorboard",
93
+ "Tensorboard Port": "Puerto de Tensorboard",
94
+ "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
95
+ "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
96
+ "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
98
+ "Training Configuration": "Configuración de Entrenamiento",
99
+ "Training Error": "Error de Entrenamiento",
100
+ "Training stopped": "Entrenamiento detenido",
101
+ "Type name of the speaker": "Escriba el nombre del hablante",
102
+ "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
103
+ "Use LoRA": "Usar LoRA",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
105
+ "Use filelist": "Usar lista de archivos",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
107
+ "VITS Configuration": "Configuración de VITS",
108
+ "VQGAN Configuration": "Configuración de VQGAN",
109
+ "Validation Batch Size": "Tamaño del Lote de Validación",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
112
+ "WebUI Host": "Host de WebUI",
113
+ "WebUI Port": "Puerto de WebUI",
114
+ "Whisper Model": "Modelo Whisper",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
117
+ "latest": "más reciente",
118
+ "new": "nuevo",
119
+ "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
120
+ "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
121
+ "Text Normalization": "Normalización de Texto",
122
+ "Select Example Audio": "Selecionar áudio de exemplo"
123
+ }
fish_speech/i18n/locale/ja_JP.json ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
5
+ "Accumulate Gradient Batches": "勾配バッチの累積",
6
+ "Add to Processing Area": "処理エリアに追加",
7
+ "Added path successfully!": "パスの追加に成功しました!",
8
+ "Advanced Config": "詳細設定",
9
+ "Base LLAMA Model": "基本LLAMAモデル",
10
+ "Batch Inference": "バッチ推論",
11
+ "Batch Size": "バッチサイズ",
12
+ "Changing with the Model Path": "モデルのパスに伴って変化する",
13
+ "Chinese": "中国語",
14
+ "Compile Model": "モデルのコンパイル",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
16
+ "Copy": "コピー",
17
+ "Data Preprocessing": "データ前処理",
18
+ "Data Preprocessing Path": "データ前処理パス",
19
+ "Data Source": "データソース",
20
+ "Decoder Model Config": "デコーダーモデルの構成",
21
+ "Decoder Model Path": "デコーダーモデルのパス",
22
+ "Disabled": "無効",
23
+ "Enable Reference Audio": "リファレンスオーディオを有効にする",
24
+ "English": "英語",
25
+ "Error Message": "エラーメッセージ",
26
+ "File Preprocessing": "文書前处理",
27
+ "Generate": "生成",
28
+ "Generated Audio": "生成されたオーディオ",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
30
+ "Infer interface is closed": "推論インターフェースが閉じられています",
31
+ "Inference Configuration": "推論設定",
32
+ "Inference Server Configuration": "推論サーバー設定",
33
+ "Inference Server Error": "推論サーバーエラー",
34
+ "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
35
+ "Initial Learning Rate": "初期学習率",
36
+ "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
37
+ "Input Text": "入力テキスト",
38
+ "Invalid path: {}": "無効なパス: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
40
+ "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
41
+ "Japanese": "日本語",
42
+ "LLAMA Configuration": "LLAMA設定",
43
+ "LLAMA Model Config": "LLAMAモデル設定",
44
+ "LLAMA Model Path": "LLAMAモデルパス",
45
+ "Labeling Device": "ラベリングデバイス",
46
+ "LoRA Model to be merged": "マージするLoRAモデル",
47
+ "Maximum Audio Duration": "最大オーディオの長さ",
48
+ "Maximum Length per Sample": "サンプルあたりの最大長",
49
+ "Maximum Training Steps": "最大トレーニングステップ数",
50
+ "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
51
+ "Merge": "マージ",
52
+ "Merge LoRA": "LoRAのマージ",
53
+ "Merge successfully": "マージに成功しました",
54
+ "Minimum Audio Duration": "最小オーディオの長さ",
55
+ "Model Output Path": "モデル出力パス",
56
+ "Model Size": "モデルサイズ",
57
+ "Move": "移動",
58
+ "Move files successfully": "ファイルの移動に成功しました",
59
+ "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
60
+ "No selected options": "選択されたオプションはありません",
61
+ "Number of Workers": "ワーカー数",
62
+ "Open Inference Server": "推論サーバーを開く",
63
+ "Open Labeler WebUI": "ラベラーWebUIを開く",
64
+ "Open Tensorboard": "Tensorboardを開く",
65
+ "Opened labeler in browser": "ブラウザでラベラーを開きました",
66
+ "Optional Label Language": "オプションのラベル言語",
67
+ "Optional online ver": "オプションのオンラインバージョン",
68
+ "Output Path": "出力パス",
69
+ "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
70
+ "Precision": "精度",
71
+ "Probability of applying Speaker Condition": "話者条件を適用する確率",
72
+ "Put your text here.": "ここにテキストを入力してください。",
73
+ "Reference Audio": "リファレンスオーディオ",
74
+ "Reference Text": "リファレンステキスト",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
76
+ "Remove Selected Data": "選択したデータを削除",
77
+ "Removed path successfully!": "パスの削除に成功しました!",
78
+ "Repetition Penalty": "反復ペナルティ",
79
+ "Save model every n steps": "nステップごとにモデルを保存",
80
+ "Select LLAMA ckpt": " LLAMA チェックポイントを選択",
81
+ "Select VITS ckpt": "VITS チェックポイントを選択",
82
+ "Select VQGAN ckpt": "VQGAN チェックポイントを選択",
83
+ "Select source file processing method": "ソースファイルの処理方法を選択",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
85
+ "Selected: {}": "選択済み: {}",
86
+ "Speaker": "話者",
87
+ "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
88
+ "Start Training": "トレーニング開始",
89
+ "Streaming Audio": "ストリーミングオーディオ",
90
+ "Streaming Generate": "ストリーミング合成",
91
+ "Tensorboard Host": "Tensorboardホスト",
92
+ "Tensorboard Log Path": "Tensorboardログパス",
93
+ "Tensorboard Port": "Tensorboardポート",
94
+ "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
95
+ "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
96
+ "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
98
+ "Training Configuration": "トレーニング設定",
99
+ "Training Error": "トレーニングエラー",
100
+ "Training stopped": "トレーニングが停止しました",
101
+ "Type name of the speaker": "話者の名前を入力",
102
+ "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
103
+ "Use LoRA": "LoRAを使用",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
105
+ "Use filelist": "ファイルリストを使用",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
107
+ "VITS Configuration": "VITS の構成",
108
+ "VQGAN Configuration": "VQGAN の構成",
109
+ "Validation Batch Size": "検証バッチサイズ",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
112
+ "WebUI Host": "WebUIホスト",
113
+ "WebUI Port": "WebUIポート",
114
+ "Whisper Model": "Whisperモデル",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
117
+ "latest": "最新",
118
+ "new": "新規",
119
+ "Realtime Transform Text": "リアルタイム変換テキスト",
120
+ "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
121
+ "Text Normalization": "テキスト正規化",
122
+ "Select Example Audio": "サンプル音声を選択"
123
+ }
fish_speech/i18n/locale/ko_KR.json ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "10+ 시리즈 GPU에는 16-mixed를 권장합니다.",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "화자를 특정하는 데 유의미한 5~10초의 길이의 참조 오디오 데이터.",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)에서 개발한 VQ-GAN 및 Llama 기반의 텍스트 음성 변환 모델.",
5
+ "Accumulate Gradient Batches": "그라디언트 배치 누적",
6
+ "Add to Processing Area": "처리 영역에 추가",
7
+ "Added path successfully!": "경로가 성공적으로 추가되었습니다!",
8
+ "Advanced Config": "고급 설정",
9
+ "Base LLAMA Model": "기본 LLAMA 모델",
10
+ "Batch Inference": "배치 추론",
11
+ "Batch Size": "배치 크기",
12
+ "Changing with the Model Path": "모델 경로에 따라 변경 중",
13
+ "Chinese": "중국어",
14
+ "Compile Model": "모델 컴파일",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "모델을 컴파일하면 추론 시간이 크게 줄어들지만, 초기 시작 시간이 길어집니다.",
16
+ "Copy": "복사",
17
+ "Data Preprocessing": "데이터 전처리",
18
+ "Data Preprocessing Path": "데이터 전처리 경로",
19
+ "Data Source": "데이터 소스",
20
+ "Decoder Model Config": "디코더 모델 설정",
21
+ "Decoder Model Path": "디코더 모델 경로",
22
+ "Disabled": "비활성화 됨",
23
+ "Enable Reference Audio": "참고 음성 활성화",
24
+ "English": "영어",
25
+ "Error Message": "오류 메시지",
26
+ "File Preprocessing": "파일 전처리",
27
+ "Generate": "생성",
28
+ "Generated Audio": "생성된 오디오",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "오디오애 대응하는 텍스트가 없을 경우, ASR을 적용해 지원하며, .txt 또는 .lab 형식을 지원합니다.",
30
+ "Infer interface is closed": "추론 인터페이스가 닫혔습니다.",
31
+ "Inference Configuration": "추론 설정",
32
+ "Inference Server Configuration": "추론 서버 설정",
33
+ "Inference Server Error": "추론 서버 오류",
34
+ "Inferring interface is launched at {}": "추론 인터페이스가 {}에서 시작되었습니다.",
35
+ "Initial Learning Rate": "초기 학습률",
36
+ "Input Audio & Source Path for Transcription": "전사할 입력 오디오 및 소스 경로",
37
+ "Input Text": "입력 텍스트",
38
+ "Invalid path: {}": "유효하지 않은 경로: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDA 사용을 권장하며, 낮은 사양일 경우 CPU를 사용하는 것을 권장합니다.",
40
+ "Iterative Prompt Length, 0 means off": "반복 프롬프트 길이. (0:비활성화)",
41
+ "Japanese": "일본어",
42
+ "LLAMA Configuration": "LLAMA 설정",
43
+ "LLAMA Model Config": "LLAMA 모델 설정",
44
+ "LLAMA Model Path": "LLAMA 모델 경로",
45
+ "Labeling Device": "라벨링 장치",
46
+ "LoRA Model to be merged": "병합할 LoRA 모델",
47
+ "Maximum Audio Duration": "최대 오디오 길이",
48
+ "Maximum Length per Sample": "샘플당 최대 길이",
49
+ "Maximum Training Steps": "최대 학습 단계",
50
+ "Maximum tokens per batch, 0 means no limit": "배치당 최대 토큰 수(0:제한 없음)",
51
+ "Merge": "병합",
52
+ "Merge LoRA": "LoRA 병합",
53
+ "Merge successfully": "성공적으로 병합 되었습니다.",
54
+ "Minimum Audio Duration": "최소 오디오 길이",
55
+ "Model Output Path": "모델 출력 경로",
56
+ "Model Size": "모델 크기",
57
+ "Move": "이동",
58
+ "Move files successfully": "파일이 성공적으로 이동되었습니다.",
59
+ "No audio generated, please check the input text.": "생성된 오디오가 없습니다. 입력된 텍스트를 확인하세요.",
60
+ "No selected options": "옵션이 선택되지 않았습니다.",
61
+ "Number of Workers": "작업자 수",
62
+ "Open Inference Server": "추론 서버 열기",
63
+ "Open Labeler WebUI": "라벨러 WebUI 열기",
64
+ "Open Tensorboard": "Tensorboard 열기",
65
+ "Opened labeler in browser": "브라우저에서 라벨러가 열렸습니다.",
66
+ "Optional Label Language": "선택적 라벨 언어",
67
+ "Optional online ver": "온라인 버전 선택",
68
+ "Output Path": "출력 경로",
69
+ "Path error, please check the model file exists in the corresponding path": "경로 오류, 해당 경로에 모델 파일이 있는지 확인하십시오.",
70
+ "Precision": "정밀도",
71
+ "Probability of applying Speaker Condition": "화자 조건 적용 확률",
72
+ "Put your text here.": "여기에 텍스트를 입력하세요.",
73
+ "Reference Audio": "참고 오디오",
74
+ "Reference Text": "참고 텍스트",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "관련 코드 및 가중치는 CC BY-NC-SA 4.0 라이선스 하에 배포됩니다.",
76
+ "Remove Selected Data": "선택한 데이터 제거",
77
+ "Removed path successfully!": "경로가 성공적으로 제거되었습니다!",
78
+ "Repetition Penalty": "���복 패널티",
79
+ "Save model every n steps": "n 단계마다 모델 저장",
80
+ "Select LLAMA ckpt": "LLAMA ckpt 선택",
81
+ "Select VITS ckpt": "VITS ckpt 선택",
82
+ "Select VQGAN ckpt": "VQGAN ckpt 선택",
83
+ "Select source file processing method": "소스 파일 처리 방법 선택",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "학습할 모델 선택(탭 페이지에 따라 다름)",
85
+ "Selected: {}": "선택됨: {}",
86
+ "Speaker": "화자",
87
+ "Speaker is identified by the folder name": "화자는 폴더 이름으로 식별됩니다",
88
+ "Start Training": "학습 시작",
89
+ "Streaming Audio": "스트리밍 오디오",
90
+ "Streaming Generate": "스트리밍 생성",
91
+ "Tensorboard Host": "Tensorboard 호스트",
92
+ "Tensorboard Log Path": "Tensorboard 로그 경로",
93
+ "Tensorboard Port": "Tensorboard 포트",
94
+ "Tensorboard interface is closed": "Tensorboard 인터페이스가 닫혔습니다",
95
+ "Tensorboard interface is launched at {}": "Tensorboard 인터페이스가 {}에서 시작되었습니다.",
96
+ "Text is too long, please keep it under {} characters.": "텍스트가 너무 깁니다. {}자 이하로 입력해주세요.",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "왼쪽의 입력 폴더 경로 또는 파일 목록의 경로. 체크 여부에 관계없이 이 목록에서 후속 학습에 사용됩니다.",
98
+ "Training Configuration": "학습 설정",
99
+ "Training Error": "학습 오류",
100
+ "Training stopped": "학습이 중지되었습니다.",
101
+ "Type name of the speaker": "화자의 이름을 입력하세요.",
102
+ "Type the path or select from the dropdown": "경로를 입력하거나 드롭다운에서 선택하세요.",
103
+ "Use LoRA": "LoRA 사용",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRA를 사용하면 GPU 메모리를 절약할 수 있지만, 모델의 품질이 저하될 수 있습니다.",
105
+ "Use filelist": "파일 목록 사용",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 환경에선 large, 5G에선 medium, 2G에선 small을 사용할 것을 권장합니다.",
107
+ "VITS Configuration": "VITS 설정",
108
+ "VQGAN Configuration": "VQGAN 설정",
109
+ "Validation Batch Size": "검증 배치 크기",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "전처리 폴더의 상태를 확인합니다(슬라이더를 사용하여 트리의 깊이를 조절합니다)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "모델의 오용에 대해 책임지지 않습니다. 사용하기 전에 현지 법률과 규정을 고려하시길 바랍니다.",
112
+ "WebUI Host": "WebUI 호스트",
113
+ "WebUI Port": "WebUI 포트",
114
+ "Whisper Model": "Whisper 모델",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "소스 코드는 [이곳](https://github.com/fishaudio/fish-speech)에서, 모델은 [이곳](https://huggingface.co/fishaudio/fish-speech-1)에서 확인하실 수 있습니다.",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 시리즈 GPU에는 bf16-true를, 10+ 시리즈 GPU에는 16-mixed를 권장합니다",
117
+ "latest": "최신",
118
+ "new": "새로운",
119
+ "Realtime Transform Text": "실시간 텍스트 변환",
120
+ "Normalization Result Preview (Currently Only Chinese)": "정규화 결과 미리보기(현재 중국어만 지원)",
121
+ "Text Normalization": "텍스트 정규화",
122
+ "Select Example Audio": "예시 오디오 선택"
123
+ }
fish_speech/i18n/locale/pt_BR.json ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
3
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
4
+ "Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
5
+ "Add to Processing Area": "Adicionar à Área de Processamento",
6
+ "Added path successfully!": "Caminho adicionado com sucesso!",
7
+ "Advanced Config": "Configuração Avançada",
8
+ "Base LLAMA Model": "Modelo LLAMA Base",
9
+ "Batch Inference": "Inferência em Lote",
10
+ "Batch Size": "Tamanho do Lote",
11
+ "Changing with the Model Path": "Alterando com o Caminho do Modelo",
12
+
13
+ "Compile Model": "Compilar Modelo",
14
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
15
+ "Copy": "Copiar",
16
+ "Data Preprocessing": "Pré-processamento de Dados",
17
+ "Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
18
+ "Data Source": "Fonte de Dados",
19
+ "Decoder Model Config": "Configuração do Modelo Decodificador",
20
+ "Decoder Model Path": "Caminho do Modelo Decodificador",
21
+ "Disabled": "Desativado",
22
+ "Enable Initial Prompt": "Habilitar Prompt Inicial",
23
+ "Enable Reference Audio": "Habilitar Áudio de Referência",
24
+ "English": "Inglês",
25
+ "Japanese": "Japonês",
26
+ "Chinese": "Chinês",
27
+ "Portuguese": "Português",
28
+ "Spanish": "Espanhol",
29
+ "Error Message": "Mensagem de Erro",
30
+ "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
31
+ "File Preprocessing": "Pré-processamento de Arquivos",
32
+ "Generate": "Gerar",
33
+ "Generated Audio": "Áudio Gerado",
34
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
35
+ "Infer interface is closed": "A interface de inferência foi fechada",
36
+ "Inference Configuration": "Configuração de Inferência",
37
+ "Inference Server Configuration": "Configuração do Servidor de Inferência",
38
+ "Inference Server Error": "Erro do Servidor de Inferência",
39
+ "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
40
+ "Initial Learning Rate": "Taxa de Aprendizagem Inicial",
41
+ "Initial Prompt": "Prompt Inicial",
42
+ "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
43
+ "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
44
+ "Input Text": "Texto de Entrada",
45
+ "Invalid path: {}": "Caminho inválido: {}",
46
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
47
+ "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
48
+ "LLAMA Configuration": "Configuração do LLAMA",
49
+ "LLAMA Model Config": "Configuração do Modelo LLAMA",
50
+ "LLAMA Model Path": "Caminho do Modelo LLAMA",
51
+ "Labeling Device": "Dispositivo de Rotulagem",
52
+ "LoRA Model to be merged": "Modelo LoRA para mesclagem",
53
+ "Maximum Length per Sample": "Comprimento Máximo por Amostra",
54
+ "Maximum Training Steps": "Etapas Máximas de Treinamento",
55
+ "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
56
+ "Merge": "Mesclar",
57
+ "Merge LoRA": "Mesclar LoRA",
58
+ "Merge successfully": "Mesclado com sucesso",
59
+ "Model Output Path": "Caminho de Saída do Modelo",
60
+ "Model Quantization": "Quantização do Modelo",
61
+ "Model Size": "Tamanho do Modelo",
62
+ "Move": "Mover",
63
+ "Move files successfully": "Arquivos movidos com sucesso",
64
+ "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
65
+ "No selected options": "Nenhuma opção selecionada",
66
+ "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
67
+ "Number of Workers": "Número de Processos",
68
+ "Open Inference Server": "Abrir Servidor de Inferência",
69
+ "Open Labeler WebUI": "Abrir WebUI de Rotulagem",
70
+ "Open Tensorboard": "Abrir Tensorboard",
71
+ "Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
72
+ "Optional Label Language": "Idioma do Rótulo (Opcional)",
73
+ "Optional online ver": "Versão online (opcional)",
74
+ "Output Path": "Caminho de Saída",
75
+ "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
76
+ "Post-quantification Precision": "Precisão Pós-quantização",
77
+ "Precision": "Precisão",
78
+ "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
79
+ "Put your text here.": "Insira seu texto aqui.",
80
+ "Quantify": "Quantizar",
81
+ "Quantify successfully": "Quantizado com sucesso",
82
+ "Realtime Transform Text": "Transformar Texto em Tempo Real",
83
+ "Reference Audio": "Áudio de Referência",
84
+ "Reference Text": "Texto de Referência",
85
+ "warning": "Aviso",
86
+ "Pre-processing begins...": "O pré-processamento começou!",
87
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
88
+ "Remove Selected Data": "Remover Dados Selecionados",
89
+ "Removed path successfully!": "Caminho removido com sucesso!",
90
+ "Repetition Penalty": "Penalidade de Repetição",
91
+ "Save model every n steps": "Salvar modelo a cada n etapas",
92
+ "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
93
+ "Select source file processing method": "Escolha como processar o arquivo de origem",
94
+ "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
95
+ "Selected: {}": "Selecionado: {}",
96
+ "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
97
+ "Start Training": "Iniciar Treinamento",
98
+ "Streaming Audio": "Áudio em Streaming",
99
+ "Streaming Generate": "Geração em Streaming",
100
+ "Tensorboard Host": "Host do Tensorboard",
101
+ "Tensorboard Log Path": "Caminho de Log do Tensorboard",
102
+ "Tensorboard Port": "Porta do Tensorboard",
103
+ "Tensorboard interface is closed": "A interface do Tensorboard está fechada",
104
+ "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
105
+ "Text Normalization": "Normalização de Texto",
106
+ "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
107
+ "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
108
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
109
+ "Training Configuration": "Configuração de Treinamento",
110
+ "Training Error": "Erro de Treinamento",
111
+ "Training stopped": "Treinamento interrompido!",
112
+ "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
113
+ "Use LoRA": "Usar LoRA",
114
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
115
+ "Use filelist": "Usar lista de arquivos",
116
+ "VQGAN Configuration": "Configuração do VQGAN",
117
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
118
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
119
+ "WebUI Host": "Host da WebUI",
120
+ "WebUI Port": "Porta da WebUI",
121
+ "Whisper Model": "Modelo Whisper",
122
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).",
123
+ "auto": "automático",
124
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
125
+ "latest": "mais recente",
126
+ "new": "novo",
127
+ "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
128
+ "You don't need to train this model!": "Não é necessário treinar este modelo!",
129
+ "Yes": "Sim",
130
+ "No": "Não",
131
+ "version:": "versão:",
132
+ "author:": "autor:"
133
+ }
fish_speech/i18n/locale/zh_CN.json ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
5
+ "Accumulate Gradient Batches": "梯度累积批次",
6
+ "Add to Processing Area": "加入处理区",
7
+ "Added path successfully!": "添加路径成功!",
8
+ "Advanced Config": "高级参数",
9
+ "Base LLAMA Model": "基础 LLAMA 模型",
10
+ "Batch Inference": "批量推理",
11
+ "Batch Size": "批次大小",
12
+ "Changing with the Model Path": "随模型路径变化",
13
+ "Chinese": "中文",
14
+ "Compile Model": "编译模型",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
16
+ "Copy": "复制",
17
+ "Data Preprocessing": "数据预处理",
18
+ "Data Preprocessing Path": "数据预处理路径",
19
+ "Data Source": "数据源",
20
+ "Decoder Model Config": "解码器模型配置",
21
+ "Decoder Model Path": "解码器模型路径",
22
+ "Disabled": "禁用",
23
+ "Enable Reference Audio": "启用参考音频",
24
+ "English": "英文",
25
+ "Error Message": "错误信息",
26
+ "File Preprocessing": "文件预处理",
27
+ "Generate": "生成",
28
+ "Generated Audio": "音频",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
30
+ "Infer interface is closed": "推理界面已关闭",
31
+ "Inference Configuration": "推理配置",
32
+ "Inference Server Configuration": "推理服务器配置",
33
+ "Inference Server Error": "推理服务器错误",
34
+ "Inferring interface is launched at {}": "推理界面已在 {} 上启动",
35
+ "Initial Learning Rate": "初始学习率",
36
+ "Input Audio & Source Path for Transcription": "输入音频和转录源路径",
37
+ "Input Text": "输入文本",
38
+ "Invalid path: {}": "无效路径: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
40
+ "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
41
+ "Japanese": "日文",
42
+ "LLAMA Configuration": "LLAMA 配置",
43
+ "LLAMA Model Config": "LLAMA 模型配置",
44
+ "LLAMA Model Path": "LLAMA 模型路径",
45
+ "Labeling Device": "标注加速设备",
46
+ "LoRA Model to be merged": "要合并的 LoRA 模型",
47
+ "Maximum Audio Duration": "最大音频时长",
48
+ "Maximum Length per Sample": "每个样本的最大长度",
49
+ "Maximum Training Steps": "最大训练步数",
50
+ "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
51
+ "Merge": "合并",
52
+ "Merge LoRA": "合并 LoRA",
53
+ "Merge successfully": "合并成功",
54
+ "Minimum Audio Duration": "最小音频时长",
55
+ "Model Output Path": "模型输出路径",
56
+ "Model Size": "模型规模",
57
+ "Move": "移动",
58
+ "Move files successfully": "移动文件成功",
59
+ "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
60
+ "No selected options": "没有选择的选项",
61
+ "Number of Workers": "数据加载进程数",
62
+ "Open Inference Server": "打开推理服务器",
63
+ "Open Labeler WebUI": "打开标注工具",
64
+ "Open Tensorboard": "打开 Tensorboard",
65
+ "Opened labeler in browser": "在浏览器中打开标注工具",
66
+ "Optional Label Language": "[可选] 标注语言",
67
+ "Optional online ver": "[可选] 使用在线版",
68
+ "Output Path": "输出路径",
69
+ "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
70
+ "Precision": "精度",
71
+ "Probability of applying Speaker Condition": "应用说话人条件的概率",
72
+ "Put your text here.": "在此处输入文本.",
73
+ "Reference Audio": "参考音频",
74
+ "Reference Text": "参考文本",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
76
+ "Remove Selected Data": "移除选中数据",
77
+ "Removed path successfully!": "移除路径成功!",
78
+ "Repetition Penalty": "重复惩罚",
79
+ "Save model every n steps": "每 n 步保存模型",
80
+ "Select LLAMA ckpt": "选择 LLAMA 检查点",
81
+ "Select VITS ckpt": "选择 VITS 检查点",
82
+ "Select VQGAN ckpt": "选择 VQGAN 检查点",
83
+ "Select source file processing method": "选择源文件处理方法",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
85
+ "Selected: {}": "已选择: {}",
86
+ "Speaker": "说话人",
87
+ "Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
88
+ "Start Training": "开始训练",
89
+ "Streaming Audio": "流式音频",
90
+ "Streaming Generate": "流式合成",
91
+ "Tensorboard Host": "Tensorboard 监听地址",
92
+ "Tensorboard Log Path": "Tensorboard 日志路径",
93
+ "Tensorboard Port": "Tensorboard 端口",
94
+ "Tensorboard interface is closed": "Tensorboard 界面已关闭",
95
+ "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
96
+ "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
98
+ "Training Configuration": "训练配置",
99
+ "Training Error": "训练错误",
100
+ "Training stopped": "训练已停止",
101
+ "Type name of the speaker": "输入说话人的名称",
102
+ "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
103
+ "Use LoRA": "使用 LoRA",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
105
+ "Use filelist": "使用文件列表",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
107
+ "VITS Configuration": "VITS 配置",
108
+ "VQGAN Configuration": "VQGAN 配置",
109
+ "Validation Batch Size": "验证批次大小",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
112
+ "WebUI Host": "WebUI 监听地址",
113
+ "WebUI Port": "WebUI 端口",
114
+ "Whisper Model": "Whisper 模型",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
117
+ "latest": "最近的检查点",
118
+ "new": "创建新的检查点",
119
+ "Realtime Transform Text": "实时规范化文本",
120
+ "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
121
+ "Text Normalization": "文本规范化",
122
+ "Select Example Audio": "选择参考音频"
123
+ }
fish_speech/i18n/scan.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import glob
3
+ import json
4
+ from collections import OrderedDict
5
+ from pathlib import Path
6
+
7
+ from loguru import logger
8
+
9
+ from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
10
+
11
+
12
+ def extract_i18n_strings(node):
13
+ i18n_strings = []
14
+
15
+ if (
16
+ isinstance(node, ast.Call)
17
+ and isinstance(node.func, ast.Name)
18
+ and node.func.id == "i18n"
19
+ ):
20
+ for arg in node.args:
21
+ if isinstance(arg, ast.Str):
22
+ i18n_strings.append(arg.s)
23
+
24
+ for child_node in ast.iter_child_nodes(node):
25
+ i18n_strings.extend(extract_i18n_strings(child_node))
26
+
27
+ return i18n_strings
28
+
29
+
30
+ # scan the directory for all .py files (recursively)
31
+ # for each file, parse the code into an AST
32
+ # for each AST, extract the i18n strings
33
+
34
+ strings = []
35
+ folders = ["fish_speech", "tools"]
36
+ # for filename in glob.iglob("**/*.py", recursive=True):
37
+ for folder in folders:
38
+ for f in Path(folder).rglob("*.py"):
39
+ code = f.read_text(encoding="utf-8")
40
+ if "i18n(" in code:
41
+ tree = ast.parse(code)
42
+ i18n_strings = extract_i18n_strings(tree)
43
+ logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
44
+ strings.extend(i18n_strings)
45
+
46
+ code_keys = set(strings)
47
+ logger.info(f"Total unique: {len(code_keys)}")
48
+
49
+
50
+ standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
51
+ with open(standard_file, "r", encoding="utf-8") as f:
52
+ standard_data = json.load(f, object_pairs_hook=OrderedDict)
53
+ standard_keys = set(standard_data.keys())
54
+
55
+ # Define the standard file name
56
+ unused_keys = standard_keys - code_keys
57
+ logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
58
+ for unused_key in unused_keys:
59
+ logger.info(f"\t{unused_key}")
60
+
61
+ missing_keys = code_keys - standard_keys
62
+ logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
63
+ for missing_key in missing_keys:
64
+ logger.info(f"\t{missing_key}")
65
+
66
+ code_keys_dict = OrderedDict()
67
+ for s in strings:
68
+ code_keys_dict[s] = s
69
+
70
+ # write back
71
+ with open(standard_file, "w", encoding="utf-8") as f:
72
+ json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
73
+ f.write("\n")
74
+
75
+ logger.info(f"Updated {standard_file}")
76
+
77
+
78
+ # Define the standard file name
79
+ standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
80
+
81
+ # Find all JSON files in the directory
82
+ dir_path = I18N_FILE_PATH
83
+ languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]
84
+
85
+ # Load the standard file
86
+ with open(standard_file, "r", encoding="utf-8") as f:
87
+ standard_data = json.load(f, object_pairs_hook=OrderedDict)
88
+
89
+ # Loop through each language file
90
+ for lang_file in languages:
91
+ # Load the language file
92
+ with open(lang_file, "r", encoding="utf-8") as f:
93
+ lang_data = json.load(f, object_pairs_hook=OrderedDict)
94
+
95
+ # Find the difference between the language file and the standard file
96
+ diff = set(standard_data.keys()) - set(lang_data.keys())
97
+
98
+ miss = set(lang_data.keys()) - set(standard_data.keys())
99
+
100
+ # Add any missing keys to the language file
101
+ for key in diff:
102
+ lang_data[key] = "#!" + key
103
+ logger.info(f"Added missing key: {key} to {lang_file}")
104
+
105
+ # Del any extra keys to the language file
106
+ for key in miss:
107
+ del lang_data[key]
108
+ logger.info(f"Del extra key: {key} from {lang_file}")
109
+
110
+ # Sort the keys of the language file to match the order of the standard file
111
+ lang_data = OrderedDict(
112
+ sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
113
+ )
114
+
115
+ # Save the updated language file
116
+ with open(lang_file, "w", encoding="utf-8") as f:
117
+ json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
118
+ f.write("\n")
119
+
120
+ logger.info(f"Updated {lang_file}")
121
+
122
+ logger.info("Done")
fish_speech/models/text2semantic/__init__.py ADDED
File without changes
fish_speech/models/text2semantic/lit_module.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional
2
+
3
+ import lightning as L
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from lightning.pytorch.utilities.types import OptimizerLRScheduler
7
+
8
+ import fish_speech.utils as utils
9
+ from fish_speech.conversation import CODEBOOK_PAD_TOKEN_ID
10
+ from fish_speech.models.text2semantic.llama import NaiveTransformer
11
+
12
+ log = utils.RankedLogger(__name__, rank_zero_only=True)
13
+
14
+
15
+ class TextToSemantic(L.LightningModule):
16
+ def __init__(
17
+ self,
18
+ model: NaiveTransformer,
19
+ optimizer: Any,
20
+ lr_scheduler: Any,
21
+ ):
22
+ super().__init__()
23
+
24
+ self.model = model
25
+ self.optimizer_builder = optimizer
26
+ self.lr_scheduler_builder = lr_scheduler
27
+
28
+ def forward(self, x):
29
+ return self.model(x)
30
+
31
+ def on_save_checkpoint(self, checkpoint):
32
+ # Save only LoRA parameters
33
+ state_dict = checkpoint["state_dict"]
34
+ use_lora = any("lora" in name for name in state_dict.keys())
35
+ if not use_lora:
36
+ return
37
+
38
+ for name in list(state_dict.keys()):
39
+ if "lora" not in name:
40
+ state_dict.pop(name)
41
+
42
+ def configure_optimizers(self) -> OptimizerLRScheduler:
43
+ # Get weight decay parameters
44
+ weight_decay_parameters, other_parameters = [], []
45
+ for name, param in self.named_parameters():
46
+ if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
47
+ other_parameters.append(param)
48
+ else:
49
+ weight_decay_parameters.append(param)
50
+
51
+ optimizer = self.optimizer_builder(
52
+ [
53
+ {"params": weight_decay_parameters},
54
+ {"params": other_parameters, "weight_decay": 0.0},
55
+ ]
56
+ )
57
+
58
+ # Print the parameters and their weight decay
59
+ for i in optimizer.param_groups:
60
+ log.info(
61
+ f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
62
+ )
63
+
64
+ lr_scheduler = self.lr_scheduler_builder(optimizer)
65
+
66
+ return {
67
+ "optimizer": optimizer,
68
+ "lr_scheduler": {
69
+ "scheduler": lr_scheduler,
70
+ "interval": "step",
71
+ },
72
+ }
73
+
74
+ # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
75
+ def get_batch_logps(
76
+ self,
77
+ logits: torch.FloatTensor,
78
+ labels: torch.LongTensor,
79
+ average_log_prob: bool = False,
80
+ ) -> torch.FloatTensor:
81
+ """Compute the log probabilities of the given labels under the given logits.
82
+
83
+ Args:
84
+ logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
85
+ labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
86
+ average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
87
+
88
+ Returns:
89
+ A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
90
+ """
91
+ assert logits.shape[:-1] == labels.shape
92
+
93
+ labels = labels.clone()
94
+ loss_mask = labels != -100
95
+
96
+ # dummy token; we'll ignore the losses on these tokens later
97
+ labels[labels == -100] = 0
98
+
99
+ per_token_logps = torch.gather(
100
+ logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
101
+ ).squeeze(-1)
102
+
103
+ if average_log_prob:
104
+ return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
105
+ else:
106
+ return (per_token_logps * loss_mask).sum(-1)
107
+
108
+ def _step(self, batch, batch_idx, stage: str):
109
+ is_train = stage == "train"
110
+
111
+ if is_train:
112
+ # Key part to make lora work
113
+ # Otherwise the parameters are merged, which lead to incorrect gradients
114
+ self.model.train()
115
+
116
+ # Do positive and negative samples in the same batch to speed up training
117
+ labels = batch["labels"]
118
+ outputs = self.model(
119
+ inp=batch["inputs"],
120
+ key_padding_mask=batch["attention_masks"],
121
+ )
122
+ token_logits = outputs.token_logits
123
+ codebook_logits = outputs.codebook_logits
124
+
125
+ # Generate labels
126
+ base_loss = F.cross_entropy(
127
+ token_logits.view(-1, token_logits.size(-1)),
128
+ labels[:, 0].reshape(-1),
129
+ ignore_index=-100,
130
+ )
131
+
132
+ codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
133
+ semantic_loss = F.cross_entropy(
134
+ codebook_logits.view(-1, codebook_logits.size(-1)),
135
+ codebook_labels.reshape(-1),
136
+ ignore_index=-100,
137
+ )
138
+
139
+ loss = base_loss + semantic_loss
140
+
141
+ self.log(
142
+ f"{stage}/loss",
143
+ loss,
144
+ on_step=is_train,
145
+ on_epoch=not is_train,
146
+ prog_bar=True,
147
+ logger=True,
148
+ sync_dist=not is_train,
149
+ )
150
+
151
+ self.log(
152
+ f"{stage}/base_loss",
153
+ base_loss,
154
+ on_step=is_train,
155
+ on_epoch=not is_train,
156
+ prog_bar=False,
157
+ logger=True,
158
+ sync_dist=not is_train,
159
+ )
160
+
161
+ self.log(
162
+ f"{stage}/semantic_loss",
163
+ semantic_loss,
164
+ on_step=is_train,
165
+ on_epoch=not is_train,
166
+ prog_bar=False,
167
+ logger=True,
168
+ sync_dist=not is_train,
169
+ )
170
+
171
+ # Top-5 accuracy
172
+ accuracy = self.get_accuracy(codebook_logits, codebook_labels)
173
+ self.log(
174
+ f"{stage}/top_5_accuracy",
175
+ accuracy,
176
+ on_step=is_train,
177
+ on_epoch=not is_train,
178
+ prog_bar=True,
179
+ logger=True,
180
+ sync_dist=not is_train,
181
+ )
182
+
183
+ return loss
184
+
185
+ def get_accuracy(self, logits, labels):
186
+ mask = (labels != -100) & (labels != CODEBOOK_PAD_TOKEN_ID)
187
+ if mask.sum() == 0:
188
+ return torch.tensor(0.0, device=logits.device)
189
+
190
+ _, indices = logits.topk(5, dim=-1)
191
+ correct = indices.eq(labels.unsqueeze(-1))
192
+ correct[~mask] = 0
193
+ correct = correct.sum()
194
+ accuracy = correct / mask.sum()
195
+
196
+ return accuracy
197
+
198
+ def training_step(self, batch, batch_idx):
199
+ return self._step(batch, batch_idx, "train")
200
+
201
+ def validation_step(self, batch, batch_idx):
202
+ return self._step(batch, batch_idx, "val")
fish_speech/models/text2semantic/llama.py ADDED
@@ -0,0 +1,844 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import json
3
+ import math
4
+ from collections import OrderedDict
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+ from loguru import logger
13
+ from torch import Tensor
14
+ from torch.nn import functional as F
15
+ from torch.nn.attention import SDPBackend, sdpa_kernel
16
+ from torch.utils.checkpoint import checkpoint
17
+ from transformers import AutoTokenizer
18
+
19
+ from fish_speech.conversation import SEMANTIC_TOKEN
20
+ from fish_speech.utils import RankedLogger
21
+
22
+ from .lora import LoraConfig, setup_lora
23
+
24
+ log = RankedLogger(__name__, rank_zero_only=True)
25
+
26
+
27
+ def find_multiple(n: int, k: int) -> int:
28
+ if n % k == 0:
29
+ return n
30
+ return n + k - (n % k)
31
+
32
+
33
+ @dataclass
34
+ class BaseModelArgs:
35
+ model_type: str = "base"
36
+
37
+ vocab_size: int = 32000
38
+ n_layer: int = 32
39
+ n_head: int = 32
40
+ dim: int = 4096
41
+ intermediate_size: int = None
42
+ n_local_heads: int = -1
43
+ head_dim: int = 64
44
+ rope_base: float = 10000
45
+ norm_eps: float = 1e-5
46
+ max_seq_len: int = 2048
47
+ dropout: float = 0.0
48
+ tie_word_embeddings: bool = True
49
+ attention_qkv_bias: bool = False
50
+
51
+ # Codebook configs
52
+ codebook_size: int = 160
53
+ num_codebooks: int = 4
54
+
55
+ # Gradient checkpointing
56
+ use_gradient_checkpointing: bool = True
57
+
58
+ # Initialize the model
59
+ initializer_range: float = 0.02
60
+
61
+ # Dummy vars
62
+ is_reward_model: bool = False
63
+ share_codebook_embeddings: bool = True
64
+
65
+ def __post_init__(self):
66
+ if self.n_local_heads == -1:
67
+ self.n_local_heads = self.n_head
68
+ if self.intermediate_size is None:
69
+ hidden_dim = 4 * self.dim
70
+ n_hidden = int(2 * hidden_dim / 3)
71
+ self.intermediate_size = find_multiple(n_hidden, 256)
72
+ self.head_dim = self.dim // self.n_head
73
+
74
+ @staticmethod
75
+ def from_pretrained(path: str):
76
+ path = Path(path)
77
+
78
+ if path.is_dir():
79
+ path = path / "config.json"
80
+
81
+ with open(path, "r", encoding="utf-8") as f:
82
+ data = json.load(f)
83
+
84
+ match data["model_type"]:
85
+ case "naive":
86
+ cls = NaiveModelArgs
87
+ case "dual_ar":
88
+ cls = DualARModelArgs
89
+ case _:
90
+ raise ValueError(f"Unknown model type: {data['model_type']}")
91
+
92
+ return cls(**data)
93
+
94
+ def save(self, path: str):
95
+ with open(path, "w") as f:
96
+ json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
97
+
98
+
99
+ @dataclass
100
+ class NaiveModelArgs(BaseModelArgs):
101
+ model_type: str = "naive"
102
+
103
+
104
+ @dataclass
105
+ class DualARModelArgs(BaseModelArgs):
106
+ model_type: str = "dual_ar"
107
+ n_fast_layer: int = 4
108
+ fast_dim: int | None = None
109
+ fast_n_head: int | None = None
110
+ fast_n_local_heads: int | None = None
111
+ fast_head_dim: int | None = None
112
+ fast_intermediate_size: int | None = None
113
+ fast_attention_qkv_bias: bool | None = None
114
+
115
+ def __post_init__(self):
116
+ super().__post_init__()
117
+
118
+ self.fast_dim = self.fast_dim or self.dim
119
+ self.fast_n_head = self.fast_n_head or self.n_head
120
+ self.fast_n_local_heads = self.fast_n_local_heads or self.n_local_heads
121
+ self.fast_head_dim = self.fast_head_dim or self.head_dim
122
+ self.fast_intermediate_size = (
123
+ self.fast_intermediate_size or self.intermediate_size
124
+ )
125
+ self.fast_attention_qkv_bias = (
126
+ self.fast_attention_qkv_bias
127
+ if self.fast_attention_qkv_bias is not None
128
+ else self.attention_qkv_bias
129
+ )
130
+
131
+
132
+ class KVCache(nn.Module):
133
+ def __init__(
134
+ self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
135
+ ):
136
+ super().__init__()
137
+ cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
138
+ self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
139
+ self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
140
+
141
+ def update(self, input_pos, k_val, v_val):
142
+ # input_pos: [S], k_val: [B, H, S, D]
143
+ assert input_pos.shape[0] == k_val.shape[2]
144
+
145
+ k_out = self.k_cache
146
+ v_out = self.v_cache
147
+ k_out[:, :, input_pos] = k_val
148
+ v_out[:, :, input_pos] = v_val
149
+
150
+ return k_out, v_out
151
+
152
+
153
+ @dataclass
154
+ class TransformerForwardResult:
155
+ token_logits: Tensor
156
+ codebook_logits: Tensor
157
+
158
+
159
+ @dataclass
160
+ class BaseTransformerForwardResult:
161
+ logits: Tensor
162
+ hidden_states: Tensor
163
+
164
+
165
+ class BaseTransformer(nn.Module):
166
+ def __init__(
167
+ self, config: BaseModelArgs, tokenizer: AutoTokenizer, init_weights: bool = True
168
+ ) -> None:
169
+ super().__init__()
170
+ self.config = config
171
+ self.tokenizer = tokenizer
172
+
173
+ self.semantic_token_id = tokenizer.convert_tokens_to_ids(SEMANTIC_TOKEN)
174
+
175
+ # Slow transformer
176
+ self.embeddings = nn.Embedding(
177
+ config.vocab_size,
178
+ config.dim,
179
+ )
180
+ self.codebook_embeddings = nn.Embedding(
181
+ config.codebook_size * config.num_codebooks,
182
+ config.dim,
183
+ )
184
+ self.layers = nn.ModuleList(
185
+ TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
186
+ )
187
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
188
+
189
+ if self.config.tie_word_embeddings is False:
190
+ self.output = nn.Linear(
191
+ config.dim,
192
+ config.vocab_size,
193
+ bias=False,
194
+ )
195
+
196
+ self.register_buffer(
197
+ "freqs_cis",
198
+ precompute_freqs_cis(
199
+ config.max_seq_len,
200
+ config.dim // config.n_head,
201
+ config.rope_base,
202
+ ),
203
+ persistent=False,
204
+ )
205
+ self.register_buffer(
206
+ "causal_mask",
207
+ torch.tril(
208
+ torch.ones(
209
+ config.max_seq_len,
210
+ config.max_seq_len,
211
+ dtype=torch.bool,
212
+ )
213
+ ),
214
+ persistent=False,
215
+ )
216
+
217
+ # For kv cache
218
+ self.max_batch_size = -1
219
+ self.max_seq_len = -1
220
+
221
+ if init_weights:
222
+ self.apply(self._init_weights)
223
+
224
+ def setup_caches(
225
+ self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
226
+ ):
227
+ if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
228
+ return
229
+
230
+ head_dim = self.config.dim // self.config.n_head
231
+ max_seq_len = find_multiple(max_seq_len, 8)
232
+ self.max_seq_len = max_seq_len
233
+ self.max_batch_size = max_batch_size
234
+
235
+ for b in self.layers:
236
+ b.attention.kv_cache = KVCache(
237
+ max_batch_size,
238
+ max_seq_len,
239
+ self.config.n_local_heads,
240
+ head_dim,
241
+ dtype=dtype,
242
+ )
243
+
244
+ def embed(self, x: Tensor) -> Tensor:
245
+ vocab_embeds = [self.embeddings(x[:, 0])]
246
+ for i in range(self.config.num_codebooks):
247
+ emb = self.codebook_embeddings(x[:, i + 1] + i * self.config.codebook_size)
248
+ emb[x[:, 0] != self.semantic_token_id] = 0
249
+ vocab_embeds.append(emb)
250
+
251
+ x = torch.stack(vocab_embeds, dim=3)
252
+ x = x.sum(dim=3)
253
+
254
+ return x
255
+
256
+ def forward(
257
+ self,
258
+ inp: Tensor,
259
+ key_padding_mask: Optional[Tensor] = None,
260
+ ) -> BaseTransformerForwardResult:
261
+ seq_len = inp.size(2)
262
+
263
+ # Here we want to merge the embeddings of the codebooks
264
+ x = self.embed(inp)
265
+
266
+ freqs_cis = self.freqs_cis[:seq_len]
267
+
268
+ # Not that the causal mask here follows the definition of scaled_dot_product_attention
269
+ # That is, FALSE means masked out
270
+ # To maintain consistency, key_padding_mask use TRUE to mask out
271
+ mask = None
272
+ if key_padding_mask is not None:
273
+ mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
274
+ mask = mask & key_padding_mask[:, None, None, :].logical_not()
275
+
276
+ for layer in self.layers:
277
+ if self.config.use_gradient_checkpointing and self.training:
278
+ x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
279
+ else:
280
+ x = layer(x, freqs_cis, mask)
281
+
282
+ # We got slow_out here
283
+ slow_out = self.norm(x)
284
+
285
+ if self.config.tie_word_embeddings:
286
+ token_logits = F.linear(slow_out, self.embeddings.weight)
287
+ else:
288
+ token_logits = self.output(slow_out)
289
+
290
+ return BaseTransformerForwardResult(
291
+ logits=token_logits,
292
+ hidden_states=x,
293
+ )
294
+
295
+ def forward_generate(
296
+ self,
297
+ x: Tensor,
298
+ input_pos: Optional[Tensor] = None,
299
+ return_all: bool = False,
300
+ ) -> BaseTransformerForwardResult:
301
+ # This is used for generation, optimized for torch compile
302
+ assert (
303
+ self.max_seq_len != -1 and self.max_batch_size != -1
304
+ ), "Please call setup_caches before forward_generate"
305
+
306
+ x = self.embed(x)
307
+
308
+ mask = self.causal_mask[
309
+ None, None, input_pos, : self.max_seq_len
310
+ ] # (B, N, Q, K)
311
+ freqs_cis = self.freqs_cis[input_pos]
312
+
313
+ for layer in self.layers:
314
+ x = layer(x, freqs_cis, mask, input_pos=input_pos)
315
+
316
+ # If prefill, we only calculate the logits of last token
317
+ if x.size(1) > 1 and not return_all:
318
+ x = x[:, -1:]
319
+
320
+ # We got slow_out here
321
+ slow_out = self.norm(x)
322
+
323
+ if self.config.tie_word_embeddings:
324
+ token_logits = F.linear(slow_out, self.embeddings.weight)
325
+ else:
326
+ token_logits = self.output(slow_out)
327
+
328
+ return BaseTransformerForwardResult(
329
+ logits=token_logits,
330
+ hidden_states=x,
331
+ )
332
+
333
+ def _init_weights(self, module):
334
+ std = self.config.initializer_range
335
+ if isinstance(module, nn.Linear):
336
+ module.weight.data.normal_(mean=0.0, std=std)
337
+ if module.bias is not None:
338
+ module.bias.data.zero_()
339
+ elif isinstance(module, nn.Embedding):
340
+ module.weight.data.normal_(mean=0.0, std=std)
341
+ if module.padding_idx is not None:
342
+ module.weight.data[module.padding_idx].zero_()
343
+
344
+ @staticmethod
345
+ def from_pretrained(
346
+ path: str,
347
+ load_weights: bool = False,
348
+ max_length: int | None = None,
349
+ lora_config: LoraConfig | None = None,
350
+ rope_base: int | None = None,
351
+ ) -> "BaseTransformer":
352
+ config = BaseModelArgs.from_pretrained(str(path))
353
+ if max_length is not None:
354
+ config.max_seq_len = max_length
355
+ log.info(f"Override max_seq_len to {max_length}")
356
+
357
+ if rope_base is not None:
358
+ config.rope_base = rope_base
359
+ log.info(f"Override rope_base to {rope_base}")
360
+
361
+ match config.model_type:
362
+ case "naive":
363
+ model_cls = NaiveTransformer
364
+ case "dual_ar":
365
+ model_cls = DualARTransformer
366
+ case _:
367
+ raise ValueError(f"Unknown model type: {config.model_type}")
368
+
369
+ tokenizer = AutoTokenizer.from_pretrained(str(path))
370
+ log.info(f"Loading model from {path}, config: {config}")
371
+ model = model_cls(config, tokenizer=tokenizer)
372
+
373
+ if lora_config is not None:
374
+ setup_lora(model, lora_config)
375
+ log.info(f"LoRA setup: {lora_config}")
376
+
377
+ if load_weights is False:
378
+ log.info("Randomly initialized model")
379
+ else:
380
+
381
+ if "int8" in str(Path(path)):
382
+ logger.info("Using int8 weight-only quantization!")
383
+ from tools.llama.quantize import WeightOnlyInt8QuantHandler
384
+
385
+ simple_quantizer = WeightOnlyInt8QuantHandler(model)
386
+ model = simple_quantizer.convert_for_runtime()
387
+
388
+ if "int4" in str(Path(path)):
389
+ logger.info("Using int4 quantization!")
390
+ path_comps = path.name.split("-")
391
+ assert path_comps[-2].startswith("g")
392
+ groupsize = int(path_comps[-2][1:])
393
+ from tools.llama.quantize import WeightOnlyInt4QuantHandler
394
+
395
+ simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
396
+ model = simple_quantizer.convert_for_runtime()
397
+
398
+ weights = torch.load(
399
+ Path(path) / "model.pth",
400
+ map_location="cpu",
401
+ mmap=True,
402
+ weights_only=True,
403
+ )
404
+
405
+ if "state_dict" in weights:
406
+ logger.warning(
407
+ "Using a TextToSemantic LightningModule checkpoint, "
408
+ "please make sure it is a full model, not a LoRA model."
409
+ )
410
+ weights = weights["state_dict"]
411
+
412
+ if next(iter(weights.keys())).startswith("model."):
413
+ logger.info(
414
+ f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
415
+ )
416
+ new_weights = OrderedDict()
417
+ for k, v in weights.items():
418
+ new_weights[k.replace("model.", "")] = v
419
+ weights = new_weights
420
+
421
+ # Verify the name and shape of parameters since strict=False in load_state_dict.
422
+ for k, v in model.named_parameters():
423
+ if k not in weights:
424
+ logger.warning(f"No weight for {k}")
425
+ elif v.shape != weights[k].shape:
426
+ logger.warning(
427
+ f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
428
+ )
429
+
430
+ err = model.load_state_dict(weights, strict=False, assign=True)
431
+ log.info(f"Loaded weights with error: {err}")
432
+
433
+ return model
434
+
435
+ def save_pretrained(self, path: str, drop_lora: bool = False):
436
+ path = Path(path)
437
+ path.mkdir(parents=True, exist_ok=True)
438
+
439
+ self.config.save(path / "config.json")
440
+ state_dict = self.state_dict()
441
+
442
+ if drop_lora:
443
+ for key in list(state_dict.keys()):
444
+ if "lora" not in key:
445
+ continue
446
+
447
+ state_dict.pop(key)
448
+ log.info(f"Drop LoRA parameter: {key}")
449
+
450
+ torch.save(state_dict, path / "model.pth")
451
+ self.tokenizer.save_pretrained(path)
452
+
453
+
454
+ class NaiveTransformer(BaseTransformer):
455
+ def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
456
+ super().__init__(config, init_weights=False, tokenizer=tokenizer)
457
+
458
+ self.codebook_norm = RMSNorm(config.dim, eps=config.norm_eps)
459
+ self.codebook_output = nn.Linear(
460
+ config.dim,
461
+ config.codebook_size * config.num_codebooks,
462
+ bias=False,
463
+ )
464
+
465
+ self.apply(self._init_weights)
466
+
467
+ def decode(self, result: BaseTransformerForwardResult) -> TransformerForwardResult:
468
+ token_logits = result.logits
469
+ x = result.hidden_states
470
+
471
+ # Codebook
472
+ codebook_logits = self.codebook_output(self.codebook_norm(x))
473
+ codebook_logits = rearrange(
474
+ codebook_logits, "b n (c d) -> b n c d", c=self.config.num_codebooks
475
+ )
476
+
477
+ return TransformerForwardResult(
478
+ token_logits=token_logits,
479
+ codebook_logits=codebook_logits,
480
+ )
481
+
482
+ def forward(
483
+ self,
484
+ inp: Tensor,
485
+ key_padding_mask: Optional[Tensor] = None,
486
+ ) -> TransformerForwardResult:
487
+ result = super().forward(
488
+ inp=inp,
489
+ key_padding_mask=key_padding_mask,
490
+ )
491
+ return self.decode(result)
492
+
493
+ def forward_generate(
494
+ self, x: Tensor, input_pos: Optional[Tensor] = None
495
+ ) -> TransformerForwardResult:
496
+ result = super().forward_generate(x, input_pos)
497
+ return self.decode(result)
498
+
499
+
500
+ class DualARTransformer(BaseTransformer):
501
+ def __init__(self, config: NaiveModelArgs, tokenizer: AutoTokenizer) -> None:
502
+ super().__init__(config, init_weights=False, tokenizer=tokenizer)
503
+
504
+ # Project to fast dim if needed
505
+ if config.fast_dim is not None and config.fast_dim != config.dim:
506
+ self.fast_project_in = nn.Linear(config.dim, config.fast_dim)
507
+ else:
508
+ self.fast_project_in = nn.Identity()
509
+
510
+ # Fast transformer
511
+ self.fast_embeddings = nn.Embedding(config.codebook_size, config.fast_dim)
512
+
513
+ # The equivalent bs is so large that sdpa doesn't work
514
+ override_config = dataclasses.replace(
515
+ config,
516
+ dim=config.fast_dim,
517
+ n_head=config.fast_n_head,
518
+ n_local_heads=config.fast_n_local_heads,
519
+ head_dim=config.fast_head_dim,
520
+ intermediate_size=config.fast_intermediate_size,
521
+ attention_qkv_bias=config.fast_attention_qkv_bias,
522
+ )
523
+
524
+ self.fast_layers = nn.ModuleList(
525
+ TransformerBlock(override_config, use_sdpa=False)
526
+ for _ in range(config.n_fast_layer)
527
+ )
528
+ self.fast_norm = RMSNorm(config.fast_dim, eps=config.norm_eps)
529
+ self.fast_output = nn.Linear(
530
+ config.fast_dim,
531
+ config.codebook_size,
532
+ bias=False,
533
+ )
534
+
535
+ self.register_buffer(
536
+ "fast_freqs_cis",
537
+ precompute_freqs_cis(
538
+ config.num_codebooks,
539
+ config.fast_dim // config.fast_n_head,
540
+ config.rope_base,
541
+ ),
542
+ persistent=False,
543
+ )
544
+ self.apply(self._init_weights)
545
+
546
+ def setup_caches(
547
+ self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16
548
+ ):
549
+ super().setup_caches(max_batch_size, max_seq_len, dtype)
550
+
551
+ head_dim = self.config.fast_dim // self.config.fast_n_head
552
+
553
+ # Fast transformer
554
+ # The max seq len here is the number of codebooks
555
+ for b in self.fast_layers:
556
+ b.attention.kv_cache = KVCache(
557
+ max_batch_size,
558
+ self.config.num_codebooks,
559
+ self.config.fast_n_local_heads,
560
+ head_dim,
561
+ dtype=dtype,
562
+ )
563
+
564
+ def forward(
565
+ self,
566
+ inp: Tensor,
567
+ key_padding_mask: Optional[Tensor] = None,
568
+ ) -> TransformerForwardResult:
569
+ parent_result = super().forward(inp, key_padding_mask)
570
+ token_logits = parent_result.logits
571
+ x = parent_result.hidden_states
572
+ x = self.fast_project_in(x)
573
+
574
+ # Fast transformer
575
+ fast_seq_len = self.config.num_codebooks
576
+ fast_mask = self.causal_mask[
577
+ None, None, :fast_seq_len, :fast_seq_len
578
+ ] # (B, N, Q, K)
579
+
580
+ # Drop the last token and rotate left
581
+ codebooks = inp[:, 1:-1, 1:]
582
+ codebooks = F.pad(codebooks, (0, 1), value=0)
583
+ codebook_embeddings = self.fast_embeddings(codebooks)
584
+ x = torch.cat([x[:, None], codebook_embeddings], dim=1)
585
+ b, s = x.size(0), x.size(2)
586
+ x = rearrange(x, "b n s d -> (b s) n d") # flatten the batch and seq_len
587
+
588
+ # Remove padded part
589
+ codebooks = rearrange(codebooks, "b n s -> (b s) n")
590
+ codebook_mask = (codebooks == 0).all(dim=-1)
591
+
592
+ if torch.all(codebook_mask):
593
+ # If all codebooks are padded, we keep first 8 to make sure the model runs
594
+ codebook_mask[:8] = False
595
+
596
+ x_bs, x_len = x.size(0), x.size(1)
597
+ x = x[~codebook_mask]
598
+
599
+ for layer in self.fast_layers:
600
+ if self.config.use_gradient_checkpointing and self.training:
601
+ x = checkpoint(
602
+ layer, x, self.fast_freqs_cis, fast_mask, use_reentrant=True
603
+ )
604
+ else:
605
+ x = layer(x, self.fast_freqs_cis, fast_mask)
606
+
607
+ # unflatten the batch and num_codebooks
608
+ fast_out = self.fast_norm(x)
609
+ codebook_logits = self.fast_output(fast_out)
610
+
611
+ # Re-pad the codebook_logits
612
+ buffer = torch.zeros(
613
+ x_bs,
614
+ x_len,
615
+ codebook_logits.size(-1),
616
+ device=codebook_logits.device,
617
+ dtype=codebook_logits.dtype,
618
+ )
619
+ buffer[~codebook_mask] = codebook_logits
620
+ codebook_logits = buffer
621
+
622
+ assert codebook_logits.shape[1] == self.config.num_codebooks
623
+ codebook_logits = rearrange(
624
+ codebook_logits,
625
+ "(b s) n d -> b s n d",
626
+ b=b,
627
+ s=s,
628
+ n=self.config.num_codebooks,
629
+ )
630
+
631
+ return TransformerForwardResult(
632
+ token_logits=token_logits,
633
+ codebook_logits=codebook_logits,
634
+ )
635
+
636
+ def forward_generate_fast(
637
+ self, x: Tensor, input_pos: Optional[Tensor] = None
638
+ ) -> Tensor:
639
+ # Fast transformer
640
+ x = x.view(1, 1, -1)
641
+
642
+ fast_mask = self.causal_mask[
643
+ None, None, input_pos, : self.config.num_codebooks
644
+ ] # (B, N, Q, K)
645
+ fast_freqs_cis = self.fast_freqs_cis[input_pos]
646
+
647
+ for layer in self.fast_layers:
648
+ x = layer(x, fast_freqs_cis, fast_mask, input_pos=input_pos)
649
+
650
+ # unflatten the batch and num_codebooks
651
+ fast_out = self.fast_norm(x) # only take the last token
652
+ codebook_logits = self.fast_output(fast_out)
653
+
654
+ return codebook_logits
655
+
656
+ def forward_generate(
657
+ self, x: Tensor, input_pos: Optional[Tensor] = None
658
+ ) -> TransformerForwardResult:
659
+ x = super().forward_generate(x, input_pos)
660
+ x.hidden_states = self.fast_project_in(x.hidden_states)
661
+ return x
662
+
663
+
664
+ class TransformerBlock(nn.Module):
665
+ def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
666
+ super().__init__()
667
+ self.attention = Attention(config, use_sdpa=use_sdpa)
668
+ self.feed_forward = FeedForward(config)
669
+ self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
670
+ self.attention_norm = RMSNorm(config.dim, config.norm_eps)
671
+
672
+ def forward(
673
+ self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
674
+ ) -> Tensor:
675
+ h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
676
+ out = h + self.feed_forward(self.ffn_norm(h))
677
+ return out
678
+
679
+
680
+ class Attention(nn.Module):
681
+ def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
682
+ super().__init__()
683
+ assert config.dim % config.n_head == 0
684
+
685
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
686
+ # key, query, value projections for all heads, but in a batch
687
+ self.wqkv = nn.Linear(
688
+ config.dim, total_head_dim, bias=config.attention_qkv_bias
689
+ )
690
+ self.wo = nn.Linear(config.dim, config.dim, bias=False)
691
+ self.kv_cache = None
692
+
693
+ self.dropout = config.dropout
694
+ self.n_head = config.n_head
695
+ self.head_dim = config.head_dim
696
+ self.n_local_heads = config.n_local_heads
697
+ self.dim = config.dim
698
+ self.use_sdpa = use_sdpa
699
+ self._register_load_state_dict_pre_hook(self.load_hook)
700
+
701
+ def load_hook(self, state_dict, prefix, *args):
702
+ if prefix + "wq.weight" in state_dict:
703
+ wq = state_dict.pop(prefix + "wq.weight")
704
+ wk = state_dict.pop(prefix + "wk.weight")
705
+ wv = state_dict.pop(prefix + "wv.weight")
706
+ state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
707
+
708
+ def forward(
709
+ self,
710
+ x: Tensor,
711
+ freqs_cis: Tensor,
712
+ mask: Tensor,
713
+ input_pos: Optional[Tensor] = None,
714
+ ) -> Tensor:
715
+ bsz, seqlen, _ = x.shape
716
+
717
+ kv_size = self.n_local_heads * self.head_dim
718
+ q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
719
+
720
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
721
+ k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
722
+ v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
723
+
724
+ q = apply_rotary_emb(q, freqs_cis)
725
+ k = apply_rotary_emb(k, freqs_cis)
726
+
727
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
728
+
729
+ if self.kv_cache is not None:
730
+ k, v = self.kv_cache.update(input_pos, k, v)
731
+
732
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
733
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
734
+
735
+ if self.use_sdpa:
736
+ if mask is None:
737
+ with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
738
+ y = F.scaled_dot_product_attention(
739
+ q,
740
+ k,
741
+ v,
742
+ dropout_p=self.dropout if self.training else 0.0,
743
+ is_causal=True,
744
+ # No third party attn_mask here to use flash_attention
745
+ )
746
+ else:
747
+ y = F.scaled_dot_product_attention(
748
+ q,
749
+ k,
750
+ v,
751
+ attn_mask=mask,
752
+ dropout_p=self.dropout if self.training else 0.0,
753
+ )
754
+ else:
755
+ y = self.eq_scaled_dot_product_attention(
756
+ q,
757
+ k,
758
+ v,
759
+ attn_mask=mask,
760
+ dropout_p=self.dropout if self.training else 0.0,
761
+ )
762
+
763
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
764
+
765
+ return self.wo(y)
766
+
767
+ def eq_scaled_dot_product_attention(
768
+ self,
769
+ query,
770
+ key,
771
+ value,
772
+ attn_mask=None,
773
+ dropout_p=0.0,
774
+ ) -> torch.Tensor:
775
+ # This is a standard scaled dot product attention
776
+ # It's low efficient, but it doesn't raise cuda error
777
+
778
+ L, S = query.size(-2), key.size(-2)
779
+ scale_factor = 1 / math.sqrt(query.size(-1))
780
+ attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
781
+
782
+ if attn_mask is not None:
783
+ if attn_mask.dtype == torch.bool:
784
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
785
+ else:
786
+ attn_bias += attn_mask
787
+
788
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
789
+ attn_weight += attn_bias
790
+ attn_weight = torch.softmax(attn_weight, dim=-1)
791
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
792
+
793
+ return attn_weight @ value
794
+
795
+
796
+ class FeedForward(nn.Module):
797
+ def __init__(self, config: BaseModelArgs) -> None:
798
+ super().__init__()
799
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
800
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
801
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
802
+
803
+ def forward(self, x: Tensor) -> Tensor:
804
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
805
+
806
+
807
+ class RMSNorm(nn.Module):
808
+ def __init__(self, dim: int, eps: float = 1e-5):
809
+ super().__init__()
810
+ self.eps = eps
811
+ self.weight = nn.Parameter(torch.ones(dim))
812
+
813
+ def _norm(self, x):
814
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
815
+
816
+ def forward(self, x: Tensor) -> Tensor:
817
+ output = self._norm(x.float()).type_as(x)
818
+ return output * self.weight
819
+
820
+
821
+ def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
822
+ freqs = 1.0 / (
823
+ base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
824
+ )
825
+ t = torch.arange(seq_len, device=freqs.device)
826
+ freqs = torch.outer(t, freqs)
827
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
828
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
829
+ return cache.to(dtype=torch.bfloat16)
830
+
831
+
832
+ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
833
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
834
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
835
+ x_out2 = torch.stack(
836
+ [
837
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
838
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
839
+ ],
840
+ -1,
841
+ )
842
+
843
+ x_out2 = x_out2.flatten(3)
844
+ return x_out2.type_as(x)
fish_speech/models/text2semantic/lora.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import loralib as lora
4
+
5
+
6
+ @dataclass
7
+ class LoraConfig:
8
+ r: int
9
+ lora_alpha: float
10
+ lora_dropout: float = 0.0
11
+
12
+
13
+ def setup_lora(model, lora_config):
14
+ # Replace the embedding layer with a LoRA layer
15
+ model.embeddings = lora.Embedding(
16
+ num_embeddings=model.embeddings.num_embeddings,
17
+ embedding_dim=model.embeddings.embedding_dim,
18
+ padding_idx=model.embeddings.padding_idx,
19
+ r=lora_config.r,
20
+ lora_alpha=lora_config.lora_alpha,
21
+ )
22
+
23
+ model.codebook_embeddings = lora.Embedding(
24
+ num_embeddings=model.codebook_embeddings.num_embeddings,
25
+ embedding_dim=model.codebook_embeddings.embedding_dim,
26
+ padding_idx=model.codebook_embeddings.padding_idx,
27
+ r=lora_config.r,
28
+ lora_alpha=lora_config.lora_alpha,
29
+ )
30
+
31
+ # Replace output layer with a LoRA layer
32
+ linears = [(model, "output")]
33
+
34
+ # Replace all linear layers with LoRA layers
35
+ for layer in model.layers:
36
+ linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
37
+ linears.extend(
38
+ [
39
+ (layer.feed_forward, "w1"),
40
+ (layer.feed_forward, "w2"),
41
+ (layer.feed_forward, "w3"),
42
+ ]
43
+ )
44
+
45
+ if hasattr(model, "fast_layers"):
46
+ model.fast_embeddings = lora.Embedding(
47
+ num_embeddings=model.fast_embeddings.num_embeddings,
48
+ embedding_dim=model.fast_embeddings.embedding_dim,
49
+ padding_idx=model.fast_embeddings.padding_idx,
50
+ r=lora_config.r,
51
+ lora_alpha=lora_config.lora_alpha,
52
+ )
53
+
54
+ # Dual-AR model
55
+ linears.append((model, "fast_output"))
56
+
57
+ for layer in model.fast_layers:
58
+ linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
59
+ linears.extend(
60
+ [
61
+ (layer.feed_forward, "w1"),
62
+ (layer.feed_forward, "w2"),
63
+ (layer.feed_forward, "w3"),
64
+ ]
65
+ )
66
+
67
+ for module, layer in linears:
68
+ updated_linear = lora.Linear(
69
+ in_features=getattr(module, layer).in_features,
70
+ out_features=getattr(module, layer).out_features,
71
+ bias=getattr(module, layer).bias,
72
+ r=lora_config.r,
73
+ lora_alpha=lora_config.lora_alpha,
74
+ lora_dropout=lora_config.lora_dropout,
75
+ )
76
+ setattr(module, layer, updated_linear)
77
+
78
+ # Mark only the LoRA layers as trainable
79
+ lora.mark_only_lora_as_trainable(model, bias="none")
80
+
81
+
82
+ def get_merged_state_dict(model):
83
+ # This line will merge the state dict of the model and the LoRA parameters
84
+ model.eval()
85
+
86
+ # Then we need to remove the LoRA parameters from the state dict
87
+ state_dict = model.state_dict()
88
+ for name in list(state_dict.keys()):
89
+ if "lora" in name:
90
+ state_dict.pop(name)
91
+
92
+ return state_dict
fish_speech/models/vqgan/__init__.py ADDED
File without changes
fish_speech/models/vqgan/modules/firefly.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from functools import partial
3
+ from math import prod
4
+ from typing import Callable
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+ from torch.nn.utils.parametrizations import weight_norm
10
+ from torch.nn.utils.parametrize import remove_parametrizations
11
+ from torch.utils.checkpoint import checkpoint
12
+
13
+
14
+ def sequence_mask(length, max_length=None):
15
+ if max_length is None:
16
+ max_length = length.max()
17
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
18
+ return x.unsqueeze(0) < length.unsqueeze(1)
19
+
20
+
21
+ def init_weights(m, mean=0.0, std=0.01):
22
+ classname = m.__class__.__name__
23
+ if classname.find("Conv1D") != -1:
24
+ m.weight.data.normal_(mean, std)
25
+
26
+
27
+ def get_padding(kernel_size, dilation=1):
28
+ return (kernel_size * dilation - dilation) // 2
29
+
30
+
31
+ def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
32
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
33
+ padding_left, padding_right = paddings
34
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
35
+ assert (padding_left + padding_right) <= x.shape[-1]
36
+ end = x.shape[-1] - padding_right
37
+ return x[..., padding_left:end]
38
+
39
+
40
+ def get_extra_padding_for_conv1d(
41
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
42
+ ) -> int:
43
+ """See `pad_for_conv1d`."""
44
+ length = x.shape[-1]
45
+ n_frames = (length - kernel_size + padding_total) / stride + 1
46
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
47
+ return ideal_length - length
48
+
49
+
50
+ def pad1d(
51
+ x: torch.Tensor,
52
+ paddings: tuple[int, int],
53
+ mode: str = "zeros",
54
+ value: float = 0.0,
55
+ ):
56
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
57
+ If this is the case, we insert extra 0 padding to the right
58
+ before the reflection happen.
59
+ """
60
+ length = x.shape[-1]
61
+ padding_left, padding_right = paddings
62
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
63
+ if mode == "reflect":
64
+ max_pad = max(padding_left, padding_right)
65
+ extra_pad = 0
66
+ if length <= max_pad:
67
+ extra_pad = max_pad - length + 1
68
+ x = F.pad(x, (0, extra_pad))
69
+ padded = F.pad(x, paddings, mode, value)
70
+ end = padded.shape[-1] - extra_pad
71
+ return padded[..., :end]
72
+ else:
73
+ return F.pad(x, paddings, mode, value)
74
+
75
+
76
+ class FishConvNet(nn.Module):
77
+ def __init__(
78
+ self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
79
+ ):
80
+ super(FishConvNet, self).__init__()
81
+ self.conv = nn.Conv1d(
82
+ in_channels,
83
+ out_channels,
84
+ kernel_size,
85
+ stride=stride,
86
+ dilation=dilation,
87
+ groups=groups,
88
+ )
89
+ self.stride = stride
90
+ self.kernel_size = (kernel_size - 1) * dilation + 1
91
+ self.dilation = dilation
92
+
93
+ def forward(self, x):
94
+ pad = self.kernel_size - self.stride
95
+ extra_padding = get_extra_padding_for_conv1d(
96
+ x, self.kernel_size, self.stride, pad
97
+ )
98
+ x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
99
+ return self.conv(x).contiguous()
100
+
101
+ def weight_norm(self, name="weight", dim=0):
102
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
103
+ return self
104
+
105
+ def remove_parametrizations(self, name="weight"):
106
+ self.conv = remove_parametrizations(self.conv, name)
107
+ return self
108
+
109
+
110
+ class FishTransConvNet(nn.Module):
111
+ def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
112
+ super(FishTransConvNet, self).__init__()
113
+ self.conv = nn.ConvTranspose1d(
114
+ in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
115
+ )
116
+ self.stride = stride
117
+ self.kernel_size = kernel_size
118
+
119
+ def forward(self, x):
120
+ x = self.conv(x)
121
+ pad = self.kernel_size - self.stride
122
+ padding_right = math.ceil(pad)
123
+ padding_left = pad - padding_right
124
+ x = unpad1d(x, (padding_left, padding_right))
125
+ return x.contiguous()
126
+
127
+ def weight_norm(self, name="weight", dim=0):
128
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
129
+ return self
130
+
131
+ def remove_parametrizations(self, name="weight"):
132
+ self.conv = remove_parametrizations(self.conv, name)
133
+ return self
134
+
135
+
136
+ class ResBlock1(torch.nn.Module):
137
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
138
+ super().__init__()
139
+
140
+ self.convs1 = nn.ModuleList(
141
+ [
142
+ FishConvNet(
143
+ channels, channels, kernel_size, stride=1, dilation=dilation[0]
144
+ ).weight_norm(),
145
+ FishConvNet(
146
+ channels, channels, kernel_size, stride=1, dilation=dilation[1]
147
+ ).weight_norm(),
148
+ FishConvNet(
149
+ channels, channels, kernel_size, stride=1, dilation=dilation[2]
150
+ ).weight_norm(),
151
+ ]
152
+ )
153
+ self.convs1.apply(init_weights)
154
+
155
+ self.convs2 = nn.ModuleList(
156
+ [
157
+ FishConvNet(
158
+ channels, channels, kernel_size, stride=1, dilation=dilation[0]
159
+ ).weight_norm(),
160
+ FishConvNet(
161
+ channels, channels, kernel_size, stride=1, dilation=dilation[1]
162
+ ).weight_norm(),
163
+ FishConvNet(
164
+ channels, channels, kernel_size, stride=1, dilation=dilation[2]
165
+ ).weight_norm(),
166
+ ]
167
+ )
168
+ self.convs2.apply(init_weights)
169
+
170
+ def forward(self, x):
171
+ for c1, c2 in zip(self.convs1, self.convs2):
172
+ xt = F.silu(x)
173
+ xt = c1(xt)
174
+ xt = F.silu(xt)
175
+ xt = c2(xt)
176
+ x = xt + x
177
+ return x
178
+
179
+ def remove_parametrizations(self):
180
+ for conv in self.convs1:
181
+ conv.remove_parametrizations()
182
+ for conv in self.convs2:
183
+ conv.remove_parametrizations()
184
+
185
+
186
+ class ParallelBlock(nn.Module):
187
+ def __init__(
188
+ self,
189
+ channels: int,
190
+ kernel_sizes: tuple[int] = (3, 7, 11),
191
+ dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
192
+ ):
193
+ super().__init__()
194
+
195
+ assert len(kernel_sizes) == len(dilation_sizes)
196
+
197
+ self.blocks = nn.ModuleList()
198
+ for k, d in zip(kernel_sizes, dilation_sizes):
199
+ self.blocks.append(ResBlock1(channels, k, d))
200
+
201
+ def forward(self, x):
202
+ return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
203
+
204
+ def remove_parametrizations(self):
205
+ for block in self.blocks:
206
+ block.remove_parametrizations()
207
+
208
+
209
+ class HiFiGANGenerator(nn.Module):
210
+ def __init__(
211
+ self,
212
+ *,
213
+ hop_length: int = 512,
214
+ upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
215
+ upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
216
+ resblock_kernel_sizes: tuple[int] = (3, 7, 11),
217
+ resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
218
+ num_mels: int = 128,
219
+ upsample_initial_channel: int = 512,
220
+ pre_conv_kernel_size: int = 7,
221
+ post_conv_kernel_size: int = 7,
222
+ post_activation: Callable = partial(nn.SiLU, inplace=True),
223
+ ):
224
+ super().__init__()
225
+
226
+ assert (
227
+ prod(upsample_rates) == hop_length
228
+ ), f"hop_length must be {prod(upsample_rates)}"
229
+
230
+ self.conv_pre = FishConvNet(
231
+ num_mels,
232
+ upsample_initial_channel,
233
+ pre_conv_kernel_size,
234
+ stride=1,
235
+ ).weight_norm()
236
+
237
+ self.num_upsamples = len(upsample_rates)
238
+ self.num_kernels = len(resblock_kernel_sizes)
239
+
240
+ self.noise_convs = nn.ModuleList()
241
+ self.ups = nn.ModuleList()
242
+
243
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
244
+ self.ups.append(
245
+ FishTransConvNet(
246
+ upsample_initial_channel // (2**i),
247
+ upsample_initial_channel // (2 ** (i + 1)),
248
+ k,
249
+ stride=u,
250
+ ).weight_norm()
251
+ )
252
+
253
+ self.resblocks = nn.ModuleList()
254
+ for i in range(len(self.ups)):
255
+ ch = upsample_initial_channel // (2 ** (i + 1))
256
+ self.resblocks.append(
257
+ ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
258
+ )
259
+
260
+ self.activation_post = post_activation()
261
+ self.conv_post = FishConvNet(
262
+ ch, 1, post_conv_kernel_size, stride=1
263
+ ).weight_norm()
264
+ self.ups.apply(init_weights)
265
+ self.conv_post.apply(init_weights)
266
+
267
+ def forward(self, x):
268
+ x = self.conv_pre(x)
269
+
270
+ for i in range(self.num_upsamples):
271
+ x = F.silu(x, inplace=True)
272
+ x = self.ups[i](x)
273
+
274
+ if self.training and self.checkpointing:
275
+ x = checkpoint(
276
+ self.resblocks[i],
277
+ x,
278
+ use_reentrant=False,
279
+ )
280
+ else:
281
+ x = self.resblocks[i](x)
282
+
283
+ x = self.activation_post(x)
284
+ x = self.conv_post(x)
285
+ x = torch.tanh(x)
286
+
287
+ return x
288
+
289
+ def remove_parametrizations(self):
290
+ for up in self.ups:
291
+ up.remove_parametrizations()
292
+ for block in self.resblocks:
293
+ block.remove_parametrizations()
294
+ self.conv_pre.remove_parametrizations()
295
+ self.conv_post.remove_parametrizations()
296
+
297
+
298
+ # DropPath copied from timm library
299
+ def drop_path(
300
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
301
+ ):
302
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
303
+
304
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
305
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
306
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
307
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
308
+ 'survival rate' as the argument.
309
+
310
+ """ # noqa: E501
311
+
312
+ if drop_prob == 0.0 or not training:
313
+ return x
314
+ keep_prob = 1 - drop_prob
315
+ shape = (x.shape[0],) + (1,) * (
316
+ x.ndim - 1
317
+ ) # work with diff dim tensors, not just 2D ConvNets
318
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
319
+ if keep_prob > 0.0 and scale_by_keep:
320
+ random_tensor.div_(keep_prob)
321
+ return x * random_tensor
322
+
323
+
324
+ class DropPath(nn.Module):
325
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
326
+
327
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
328
+ super(DropPath, self).__init__()
329
+ self.drop_prob = drop_prob
330
+ self.scale_by_keep = scale_by_keep
331
+
332
+ def forward(self, x):
333
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
334
+
335
+ def extra_repr(self):
336
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
337
+
338
+
339
+ class LayerNorm(nn.Module):
340
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
341
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
342
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
343
+ with shape (batch_size, channels, height, width).
344
+ """ # noqa: E501
345
+
346
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
347
+ super().__init__()
348
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
349
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
350
+ self.eps = eps
351
+ self.data_format = data_format
352
+ if self.data_format not in ["channels_last", "channels_first"]:
353
+ raise NotImplementedError
354
+ self.normalized_shape = (normalized_shape,)
355
+
356
+ def forward(self, x):
357
+ if self.data_format == "channels_last":
358
+ return F.layer_norm(
359
+ x, self.normalized_shape, self.weight, self.bias, self.eps
360
+ )
361
+ elif self.data_format == "channels_first":
362
+ u = x.mean(1, keepdim=True)
363
+ s = (x - u).pow(2).mean(1, keepdim=True)
364
+ x = (x - u) / torch.sqrt(s + self.eps)
365
+ x = self.weight[:, None] * x + self.bias[:, None]
366
+ return x
367
+
368
+
369
+ # ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
370
+ class ConvNeXtBlock(nn.Module):
371
+ r"""ConvNeXt Block. There are two equivalent implementations:
372
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
373
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
374
+ We use (2) as we find it slightly faster in PyTorch
375
+
376
+ Args:
377
+ dim (int): Number of input channels.
378
+ drop_path (float): Stochastic depth rate. Default: 0.0
379
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
380
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
381
+ kernel_size (int): Kernel size for depthwise conv. Default: 7.
382
+ dilation (int): Dilation for depthwise conv. Default: 1.
383
+ """ # noqa: E501
384
+
385
+ def __init__(
386
+ self,
387
+ dim: int,
388
+ drop_path: float = 0.0,
389
+ layer_scale_init_value: float = 1e-6,
390
+ mlp_ratio: float = 4.0,
391
+ kernel_size: int = 7,
392
+ dilation: int = 1,
393
+ ):
394
+ super().__init__()
395
+
396
+ self.dwconv = FishConvNet(
397
+ dim,
398
+ dim,
399
+ kernel_size=kernel_size,
400
+ # padding=int(dilation * (kernel_size - 1) / 2),
401
+ groups=dim,
402
+ ) # depthwise conv
403
+ self.norm = LayerNorm(dim, eps=1e-6)
404
+ self.pwconv1 = nn.Linear(
405
+ dim, int(mlp_ratio * dim)
406
+ ) # pointwise/1x1 convs, implemented with linear layers
407
+ self.act = nn.GELU()
408
+ self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
409
+ self.gamma = (
410
+ nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
411
+ if layer_scale_init_value > 0
412
+ else None
413
+ )
414
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
415
+
416
+ def forward(self, x, apply_residual: bool = True):
417
+ input = x
418
+
419
+ x = self.dwconv(x)
420
+ x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
421
+ x = self.norm(x)
422
+ x = self.pwconv1(x)
423
+ x = self.act(x)
424
+ x = self.pwconv2(x)
425
+
426
+ if self.gamma is not None:
427
+ x = self.gamma * x
428
+
429
+ x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
430
+ x = self.drop_path(x)
431
+
432
+ if apply_residual:
433
+ x = input + x
434
+
435
+ return x
436
+
437
+
438
+ class ConvNeXtEncoder(nn.Module):
439
+ def __init__(
440
+ self,
441
+ input_channels: int = 3,
442
+ depths: list[int] = [3, 3, 9, 3],
443
+ dims: list[int] = [96, 192, 384, 768],
444
+ drop_path_rate: float = 0.0,
445
+ layer_scale_init_value: float = 1e-6,
446
+ kernel_size: int = 7,
447
+ ):
448
+ super().__init__()
449
+ assert len(depths) == len(dims)
450
+
451
+ self.downsample_layers = nn.ModuleList()
452
+ stem = nn.Sequential(
453
+ FishConvNet(
454
+ input_channels,
455
+ dims[0],
456
+ kernel_size=7,
457
+ # padding=3,
458
+ # padding_mode="replicate",
459
+ # padding_mode="zeros",
460
+ ),
461
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
462
+ )
463
+ self.downsample_layers.append(stem)
464
+
465
+ for i in range(len(depths) - 1):
466
+ mid_layer = nn.Sequential(
467
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
468
+ nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
469
+ )
470
+ self.downsample_layers.append(mid_layer)
471
+
472
+ self.stages = nn.ModuleList()
473
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
474
+
475
+ cur = 0
476
+ for i in range(len(depths)):
477
+ stage = nn.Sequential(
478
+ *[
479
+ ConvNeXtBlock(
480
+ dim=dims[i],
481
+ drop_path=dp_rates[cur + j],
482
+ layer_scale_init_value=layer_scale_init_value,
483
+ kernel_size=kernel_size,
484
+ )
485
+ for j in range(depths[i])
486
+ ]
487
+ )
488
+ self.stages.append(stage)
489
+ cur += depths[i]
490
+
491
+ self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
492
+ self.apply(self._init_weights)
493
+
494
+ def _init_weights(self, m):
495
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
496
+ nn.init.trunc_normal_(m.weight, std=0.02)
497
+ nn.init.constant_(m.bias, 0)
498
+
499
+ def forward(
500
+ self,
501
+ x: torch.Tensor,
502
+ ) -> torch.Tensor:
503
+ for i in range(len(self.downsample_layers)):
504
+ x = self.downsample_layers[i](x)
505
+ x = self.stages[i](x)
506
+
507
+ return self.norm(x)
508
+
509
+
510
+ class FireflyArchitecture(nn.Module):
511
+ def __init__(
512
+ self,
513
+ backbone: nn.Module,
514
+ head: nn.Module,
515
+ quantizer: nn.Module,
516
+ spec_transform: nn.Module,
517
+ ):
518
+ super().__init__()
519
+
520
+ self.backbone = backbone
521
+ self.head = head
522
+ self.quantizer = quantizer
523
+ self.spec_transform = spec_transform
524
+ self.downsample_factor = math.prod(self.quantizer.downsample_factor)
525
+
526
+ def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
527
+ if self.spec_transform is not None:
528
+ x = self.spec_transform(x)
529
+
530
+ x = self.backbone(x)
531
+ if mask is not None:
532
+ x = x * mask
533
+
534
+ if self.quantizer is not None:
535
+ vq_result = self.quantizer(x)
536
+ x = vq_result.z
537
+
538
+ if mask is not None:
539
+ x = x * mask
540
+
541
+ x = self.head(x, template=template)
542
+
543
+ if x.ndim == 2:
544
+ x = x[:, None, :]
545
+
546
+ if self.vq is not None:
547
+ return x, vq_result
548
+
549
+ return x
550
+
551
+ def encode(self, audios, audio_lengths):
552
+ audios = audios.float()
553
+
554
+ mels = self.spec_transform(audios)
555
+ mel_lengths = audio_lengths // self.spec_transform.hop_length
556
+ mel_masks = sequence_mask(mel_lengths, mels.shape[2])
557
+ mel_masks_float_conv = mel_masks[:, None, :].float()
558
+ mels = mels * mel_masks_float_conv
559
+
560
+ # Encode
561
+ encoded_features = self.backbone(mels) * mel_masks_float_conv
562
+ feature_lengths = mel_lengths // self.downsample_factor
563
+
564
+ return self.quantizer.encode(encoded_features), feature_lengths
565
+
566
+ def decode(self, indices, feature_lengths) -> torch.Tensor:
567
+ mel_masks = sequence_mask(
568
+ feature_lengths * self.downsample_factor,
569
+ indices.shape[2] * self.downsample_factor,
570
+ )
571
+ mel_masks_float_conv = mel_masks[:, None, :].float()
572
+ audio_lengths = (
573
+ feature_lengths * self.downsample_factor * self.spec_transform.hop_length
574
+ )
575
+
576
+ audio_masks = sequence_mask(
577
+ audio_lengths,
578
+ indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length,
579
+ )
580
+ audio_masks_float_conv = audio_masks[:, None, :].float()
581
+
582
+ z = self.quantizer.decode(indices) * mel_masks_float_conv
583
+ x = self.head(z) * audio_masks_float_conv
584
+
585
+ return x, audio_lengths
586
+
587
+ def remove_parametrizations(self):
588
+ if hasattr(self.backbone, "remove_parametrizations"):
589
+ self.backbone.remove_parametrizations()
590
+
591
+ if hasattr(self.head, "remove_parametrizations"):
592
+ self.head.remove_parametrizations()
593
+
594
+ @property
595
+ def device(self):
596
+ return next(self.parameters()).device
fish_speech/models/vqgan/modules/fsq.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+ from vector_quantize_pytorch import GroupedResidualFSQ
8
+
9
+ from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
10
+
11
+
12
+ @dataclass
13
+ class FSQResult:
14
+ z: torch.Tensor
15
+ codes: torch.Tensor
16
+ latents: torch.Tensor
17
+
18
+
19
+ class DownsampleFiniteScalarQuantize(nn.Module):
20
+ def __init__(
21
+ self,
22
+ input_dim: int = 512,
23
+ n_codebooks: int = 9,
24
+ n_groups: int = 1,
25
+ levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
26
+ downsample_factor: tuple[int] = (2, 2),
27
+ downsample_dims: tuple[int] | None = None,
28
+ ):
29
+ super().__init__()
30
+
31
+ if downsample_dims is None:
32
+ downsample_dims = [input_dim for _ in range(len(downsample_factor))]
33
+
34
+ all_dims = (input_dim,) + tuple(downsample_dims)
35
+
36
+ self.residual_fsq = GroupedResidualFSQ(
37
+ dim=all_dims[-1],
38
+ levels=levels,
39
+ num_quantizers=n_codebooks,
40
+ groups=n_groups,
41
+ )
42
+
43
+ self.downsample_factor = downsample_factor
44
+ self.downsample_dims = downsample_dims
45
+
46
+ self.downsample = nn.Sequential(
47
+ *[
48
+ nn.Sequential(
49
+ FishConvNet(
50
+ all_dims[idx],
51
+ all_dims[idx + 1],
52
+ kernel_size=factor,
53
+ stride=factor,
54
+ ),
55
+ ConvNeXtBlock(dim=all_dims[idx + 1]),
56
+ )
57
+ for idx, factor in enumerate(downsample_factor)
58
+ ]
59
+ )
60
+
61
+ self.upsample = nn.Sequential(
62
+ *[
63
+ nn.Sequential(
64
+ FishTransConvNet(
65
+ all_dims[idx + 1],
66
+ all_dims[idx],
67
+ kernel_size=factor,
68
+ stride=factor,
69
+ ),
70
+ ConvNeXtBlock(dim=all_dims[idx]),
71
+ )
72
+ for idx, factor in reversed(list(enumerate(downsample_factor)))
73
+ ]
74
+ )
75
+
76
+ self.apply(self._init_weights)
77
+
78
+ def _init_weights(self, m):
79
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
80
+ nn.init.trunc_normal_(m.weight, std=0.02)
81
+ nn.init.constant_(m.bias, 0)
82
+
83
+ def forward(self, z) -> FSQResult:
84
+ original_shape = z.shape
85
+ z = self.downsample(z)
86
+ quantized, indices = self.residual_fsq(z.mT)
87
+ result = FSQResult(
88
+ z=quantized.mT,
89
+ codes=indices.mT,
90
+ latents=z,
91
+ )
92
+ result.z = self.upsample(result.z)
93
+
94
+ # Pad or crop z to match original shape
95
+ diff = original_shape[-1] - result.z.shape[-1]
96
+ left = diff // 2
97
+ right = diff - left
98
+
99
+ if diff > 0:
100
+ result.z = F.pad(result.z, (left, right))
101
+ elif diff < 0:
102
+ result.z = result.z[..., left:-right]
103
+
104
+ return result
105
+
106
+ def encode(self, z):
107
+ z = self.downsample(z)
108
+ _, indices = self.residual_fsq(z.mT)
109
+ indices = rearrange(indices, "g b l r -> b (g r) l")
110
+ return indices
111
+
112
+ def decode(self, indices: torch.Tensor):
113
+ indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
114
+ z_q = self.residual_fsq.get_output_from_indices(indices)
115
+ z_q = self.upsample(z_q.mT)
116
+ return z_q
fish_speech/models/vqgan/utils.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ import torch
3
+ from matplotlib import pyplot as plt
4
+
5
+ matplotlib.use("Agg")
6
+
7
+
8
+ def convert_pad_shape(pad_shape):
9
+ l = pad_shape[::-1]
10
+ pad_shape = [item for sublist in l for item in sublist]
11
+ return pad_shape
12
+
13
+
14
+ def sequence_mask(length, max_length=None):
15
+ if max_length is None:
16
+ max_length = length.max()
17
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
18
+ return x.unsqueeze(0) < length.unsqueeze(1)
19
+
20
+
21
+ def init_weights(m, mean=0.0, std=0.01):
22
+ classname = m.__class__.__name__
23
+ if classname.find("Conv") != -1:
24
+ m.weight.data.normal_(mean, std)
25
+
26
+
27
+ def get_padding(kernel_size, dilation=1):
28
+ return int((kernel_size * dilation - dilation) / 2)
29
+
30
+
31
+ def plot_mel(data, titles=None):
32
+ fig, axes = plt.subplots(len(data), 1, squeeze=False)
33
+
34
+ if titles is None:
35
+ titles = [None for i in range(len(data))]
36
+
37
+ plt.tight_layout()
38
+
39
+ for i in range(len(data)):
40
+ mel = data[i]
41
+
42
+ if isinstance(mel, torch.Tensor):
43
+ mel = mel.float().detach().cpu().numpy()
44
+
45
+ axes[i][0].imshow(mel, origin="lower")
46
+ axes[i][0].set_aspect(2.5, adjustable="box")
47
+ axes[i][0].set_ylim(0, mel.shape[0])
48
+ axes[i][0].set_title(titles[i], fontsize="medium")
49
+ axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
50
+ axes[i][0].set_anchor("W")
51
+
52
+ return fig
53
+
54
+
55
+ def slice_segments(x, ids_str, segment_size=4):
56
+ ret = torch.zeros_like(x[:, :, :segment_size])
57
+ for i in range(x.size(0)):
58
+ idx_str = ids_str[i]
59
+ idx_end = idx_str + segment_size
60
+ ret[i] = x[i, :, idx_str:idx_end]
61
+
62
+ return ret
63
+
64
+
65
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
66
+ b, d, t = x.size()
67
+ if x_lengths is None:
68
+ x_lengths = t
69
+ ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
70
+ ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
71
+ ret = slice_segments(x, ids_str, segment_size)
72
+ return ret, ids_str
73
+
74
+
75
+ @torch.jit.script
76
+ def fused_add_tanh_sigmoid_multiply(in_act, n_channels):
77
+ n_channels_int = n_channels[0]
78
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
79
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
80
+ acts = t_act * s_act
81
+
82
+ return acts
83
+
84
+
85
+ def avg_with_mask(x, mask):
86
+ assert mask.dtype == torch.float, "Mask should be float"
87
+
88
+ if mask.ndim == 2:
89
+ mask = mask.unsqueeze(1)
90
+
91
+ if mask.shape[1] == 1:
92
+ mask = mask.expand_as(x)
93
+
94
+ return (x * mask).sum() / mask.sum()
fish_speech/scheduler.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+
4
+ def get_cosine_schedule_with_warmup_lr_lambda(
5
+ current_step: int,
6
+ *,
7
+ num_warmup_steps: int | float,
8
+ num_training_steps: int,
9
+ num_cycles: float = 0.5,
10
+ final_lr_ratio: float = 0.0,
11
+ ):
12
+ if 0 < num_warmup_steps < 1: # float mode
13
+ num_warmup_steps = int(num_warmup_steps * num_training_steps)
14
+
15
+ if current_step < num_warmup_steps:
16
+ return float(current_step) / float(max(1, num_warmup_steps))
17
+
18
+ progress = float(current_step - num_warmup_steps) / float(
19
+ max(1, num_training_steps - num_warmup_steps)
20
+ )
21
+
22
+ return max(
23
+ final_lr_ratio,
24
+ 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
25
+ )
26
+
27
+
28
+ def get_constant_schedule_with_warmup_lr_lambda(
29
+ current_step: int,
30
+ *,
31
+ num_warmup_steps: int | float,
32
+ num_training_steps: int | None = None,
33
+ ):
34
+ if 0 < num_warmup_steps < 1: # float mode
35
+ num_warmup_steps = int(num_warmup_steps * num_training_steps)
36
+
37
+ if current_step < num_warmup_steps:
38
+ return float(current_step) / float(max(1, num_warmup_steps))
39
+
40
+ return 1.0
fish_speech/text/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .clean import clean_text
2
+ from .spliter import split_text
3
+
4
+ __all__ = ["clean_text", "split_text"]
fish_speech/text/chn_text_norm/.gitignore ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+ MANIFEST
27
+
28
+ # PyInstaller
29
+ # Usually these files are written by a python script from a template
30
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Installer logs
35
+ pip-log.txt
36
+ pip-delete-this-directory.txt
37
+
38
+ # Unit test / coverage reports
39
+ htmlcov/
40
+ .tox/
41
+ .coverage
42
+ .coverage.*
43
+ .cache
44
+ nosetests.xml
45
+ coverage.xml
46
+ *.cover
47
+ .hypothesis/
48
+ .pytest_cache/
49
+
50
+ # Translations
51
+ *.mo
52
+ *.pot
53
+
54
+ # Django stuff:
55
+ *.log
56
+ local_settings.py
57
+ db.sqlite3
58
+
59
+ # Flask stuff:
60
+ instance/
61
+ .webassets-cache
62
+
63
+ # Scrapy stuff:
64
+ .scrapy
65
+
66
+ # Sphinx documentation
67
+ docs/_build/
68
+
69
+ # PyBuilder
70
+ target/
71
+
72
+ # Jupyter Notebook
73
+ .ipynb_checkpoints
74
+
75
+ # pyenv
76
+ .python-version
77
+
78
+ # celery beat schedule file
79
+ celerybeat-schedule
80
+
81
+ # SageMath parsed files
82
+ *.sage.py
83
+
84
+ # Environments
85
+ .env
86
+ .venv
87
+ env/
88
+ venv/
89
+ ENV/
90
+ env.bak/
91
+ venv.bak/
92
+
93
+ # Spyder project settings
94
+ .spyderproject
95
+ .spyproject
96
+
97
+ # Rope project settings
98
+ .ropeproject
99
+
100
+ # mkdocs documentation
101
+ /site
102
+
103
+ # mypy
104
+ .mypy_cache/
105
+
106
+ # JetBrains PyCharm
107
+ .idea
108
+
109
+ # Customize
110
+ references
111
+ url.txt
112
+
113
+ # Git
114
+ .git
fish_speech/text/chn_text_norm/README.md ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This account is no longer in use, see [Atomicoo](https://github.com/atomicoo) for my latest works.
2
+
3
+ # Chn Text Norm
4
+
5
+ this is a repository for chinese text normalization (no longer maintained).
6
+
7
+ ## Quick Start ##
8
+
9
+ ### Git Clone Repo ###
10
+
11
+ git clone this repo to the root directory of your project which need to use it.
12
+
13
+ cd /path/to/proj
14
+ git clone https://github.com/Joee1995/chn-text-norm.git
15
+
16
+ after that, your doc tree should be:
17
+ ```
18
+ proj # root of your project
19
+ |--- chn_text_norm # this chn-text-norm tool
20
+ |--- text.py
21
+ |--- ...
22
+ |--- text_normalize.py # your text normalization code
23
+ |--- ...
24
+ ```
25
+
26
+ ### How to Use ? ###
27
+
28
+ # text_normalize.py
29
+ from chn_text_norm.text import *
30
+
31
+ raw_text = 'your raw text'
32
+ text = Text(raw_text=raw_text).normalize()
33
+
34
+ ### How to add quantums ###
35
+
36
+ 打开test.py,然后你就知道怎么做了。
fish_speech/text/chn_text_norm/__init__.py ADDED
File without changes
fish_speech/text/chn_text_norm/basic_class.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """基本类
3
+ 中文字符类
4
+ 中文数字/数位类
5
+ 中文数字类
6
+ 中文数位类
7
+ 中文数字系统类
8
+ 中文数学符号类
9
+ *中文其他符号类
10
+ """
11
+
12
+ __author__ = "Zhiyang Zhou <[email protected]>"
13
+ __data__ = "2019-05-02"
14
+
15
+ from fish_speech.text.chn_text_norm.basic_constant import NUMBERING_TYPES
16
+
17
+
18
+ class ChineseChar(object):
19
+ """
20
+ 中文字符
21
+ 每个字符对应简体和繁体,
22
+ e.g. 简体 = '负', 繁体 = '負'
23
+ 转换时可转换为简体或繁体
24
+ """
25
+
26
+ def __init__(self, simplified, traditional):
27
+ self.simplified = simplified
28
+ self.traditional = traditional
29
+ self.__repr__ = self.__str__
30
+
31
+ def __str__(self):
32
+ return self.simplified or self.traditional or None
33
+
34
+ def __repr__(self):
35
+ return self.__str__()
36
+
37
+
38
+ class ChineseNumberUnit(ChineseChar):
39
+ """
40
+ 中文数字/数位字符
41
+ 每个字符除繁简体外还有一个额外的大写字符
42
+ e.g. '陆' 和 '陸'
43
+ """
44
+
45
+ def __init__(self, power, simplified, traditional, big_s, big_t):
46
+ super(ChineseNumberUnit, self).__init__(simplified, traditional)
47
+ self.power = power
48
+ self.big_s = big_s
49
+ self.big_t = big_t
50
+
51
+ def __str__(self):
52
+ return "10^{}".format(self.power)
53
+
54
+ @classmethod
55
+ def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
56
+
57
+ if small_unit:
58
+ return ChineseNumberUnit(
59
+ power=index + 1,
60
+ simplified=value[0],
61
+ traditional=value[1],
62
+ big_s=value[1],
63
+ big_t=value[1],
64
+ )
65
+ elif numbering_type == NUMBERING_TYPES[0]:
66
+ return ChineseNumberUnit(
67
+ power=index + 8,
68
+ simplified=value[0],
69
+ traditional=value[1],
70
+ big_s=value[0],
71
+ big_t=value[1],
72
+ )
73
+ elif numbering_type == NUMBERING_TYPES[1]:
74
+ return ChineseNumberUnit(
75
+ power=(index + 2) * 4,
76
+ simplified=value[0],
77
+ traditional=value[1],
78
+ big_s=value[0],
79
+ big_t=value[1],
80
+ )
81
+ elif numbering_type == NUMBERING_TYPES[2]:
82
+ return ChineseNumberUnit(
83
+ power=pow(2, index + 3),
84
+ simplified=value[0],
85
+ traditional=value[1],
86
+ big_s=value[0],
87
+ big_t=value[1],
88
+ )
89
+ else:
90
+ raise ValueError(
91
+ "Counting type should be in {0} ({1} provided).".format(
92
+ NUMBERING_TYPES, numbering_type
93
+ )
94
+ )
95
+
96
+
97
+ class ChineseNumberDigit(ChineseChar):
98
+ """
99
+ 中文数字字符
100
+ """
101
+
102
+ def __init__(
103
+ self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None
104
+ ):
105
+ super(ChineseNumberDigit, self).__init__(simplified, traditional)
106
+ self.value = value
107
+ self.big_s = big_s
108
+ self.big_t = big_t
109
+ self.alt_s = alt_s
110
+ self.alt_t = alt_t
111
+
112
+ def __str__(self):
113
+ return str(self.value)
114
+
115
+ @classmethod
116
+ def create(cls, i, v):
117
+ return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
118
+
119
+
120
+ class ChineseMath(ChineseChar):
121
+ """
122
+ 中文数位字符
123
+ """
124
+
125
+ def __init__(self, simplified, traditional, symbol, expression=None):
126
+ super(ChineseMath, self).__init__(simplified, traditional)
127
+ self.symbol = symbol
128
+ self.expression = expression
129
+ self.big_s = simplified
130
+ self.big_t = traditional
131
+
132
+
133
+ CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
134
+
135
+
136
+ class NumberSystem(object):
137
+ """
138
+ 中文数字系统
139
+ """
140
+
141
+ pass
142
+
143
+
144
+ class MathSymbol(object):
145
+ """
146
+ 用于中文数字系统的数学符号 (繁/简体), e.g.
147
+ positive = ['正', '正']
148
+ negative = ['负', '負']
149
+ point = ['点', '點']
150
+ """
151
+
152
+ def __init__(self, positive, negative, point):
153
+ self.positive = positive
154
+ self.negative = negative
155
+ self.point = point
156
+
157
+ def __iter__(self):
158
+ for v in self.__dict__.values():
159
+ yield v
160
+
161
+
162
+ # class OtherSymbol(object):
163
+ # """
164
+ # 其他符号
165
+ # """
166
+ #
167
+ # def __init__(self, sil):
168
+ # self.sil = sil
169
+ #
170
+ # def __iter__(self):
171
+ # for v in self.__dict__.values():
172
+ # yield v
fish_speech/text/chn_text_norm/basic_constant.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """基本常量
3
+ 中文数字/数位/符号字符常量
4
+ """
5
+
6
+ __author__ = "Zhiyang Zhou <[email protected]>"
7
+ __data__ = "2019-05-02"
8
+
9
+ CHINESE_DIGIS = "零一二三四五六七八九"
10
+ BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
11
+ BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
12
+ SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
13
+ SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
14
+ LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
15
+ LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
16
+ SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
17
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"
18
+
19
+ ZERO_ALT = "〇"
20
+ ONE_ALT = "幺"
21
+ TWO_ALTS = ["两", "兩"]
22
+
23
+ POSITIVE = ["正", "正"]
24
+ NEGATIVE = ["负", "負"]
25
+ POINT = ["点", "點"]
26
+ # PLUS = [u'加', u'加']
27
+ # SIL = [u'杠', u'槓']
28
+
29
+ # 中文数字系统类型
30
+ NUMBERING_TYPES = ["low", "mid", "high"]
fish_speech/text/chn_text_norm/basic_util.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """基本方法
3
+ 创建中文数字系统 方法
4
+ 中文字符串 <=> 数字串 方法
5
+ 数字串 <=> 中文字符串 方法
6
+ """
7
+
8
+ __author__ = "Zhiyang Zhou <[email protected]>"
9
+ __data__ = "2019-05-02"
10
+
11
+ from fish_speech.text.chn_text_norm.basic_class import *
12
+ from fish_speech.text.chn_text_norm.basic_constant import *
13
+
14
+
15
+ def create_system(numbering_type=NUMBERING_TYPES[1]):
16
+ """
17
+ 根据数字系统类型返回创建相应的数字系统,默认为 mid
18
+ NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型
19
+ low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc.
20
+ mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
21
+ high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
22
+ 返回对应的数字系统
23
+ """
24
+
25
+ # chinese number units of '亿' and larger
26
+ all_larger_units = zip(
27
+ LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
28
+ LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL,
29
+ )
30
+ larger_units = [
31
+ CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)
32
+ ]
33
+ # chinese number units of '十, 百, 千, 万'
34
+ all_smaller_units = zip(
35
+ SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED,
36
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL,
37
+ )
38
+ smaller_units = [
39
+ CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)
40
+ ]
41
+ # digis
42
+ chinese_digis = zip(
43
+ CHINESE_DIGIS,
44
+ CHINESE_DIGIS,
45
+ BIG_CHINESE_DIGIS_SIMPLIFIED,
46
+ BIG_CHINESE_DIGIS_TRADITIONAL,
47
+ )
48
+ digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
49
+ digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
50
+ digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
51
+ digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
52
+
53
+ # symbols
54
+ positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
55
+ negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
56
+ point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
57
+ # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
58
+ system = NumberSystem()
59
+ system.units = smaller_units + larger_units
60
+ system.digits = digits
61
+ system.math = MathSymbol(positive_cn, negative_cn, point_cn)
62
+ # system.symbols = OtherSymbol(sil_cn)
63
+ return system
64
+
65
+
66
+ def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
67
+
68
+ def get_symbol(char, system):
69
+ for u in system.units:
70
+ if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
71
+ return u
72
+ for d in system.digits:
73
+ if char in [
74
+ d.traditional,
75
+ d.simplified,
76
+ d.big_s,
77
+ d.big_t,
78
+ d.alt_s,
79
+ d.alt_t,
80
+ ]:
81
+ return d
82
+ for m in system.math:
83
+ if char in [m.traditional, m.simplified]:
84
+ return m
85
+
86
+ def string2symbols(chinese_string, system):
87
+ int_string, dec_string = chinese_string, ""
88
+ for p in [system.math.point.simplified, system.math.point.traditional]:
89
+ if p in chinese_string:
90
+ int_string, dec_string = chinese_string.split(p)
91
+ break
92
+ return [get_symbol(c, system) for c in int_string], [
93
+ get_symbol(c, system) for c in dec_string
94
+ ]
95
+
96
+ def correct_symbols(integer_symbols, system):
97
+ """
98
+ 一百八 to 一百八十
99
+ 一亿一千三百万 to 一亿 一千万 三百万
100
+ """
101
+
102
+ if integer_symbols and isinstance(integer_symbols[0], CNU):
103
+ if integer_symbols[0].power == 1:
104
+ integer_symbols = [system.digits[1]] + integer_symbols
105
+
106
+ if len(integer_symbols) > 1:
107
+ if isinstance(integer_symbols[-1], CND) and isinstance(
108
+ integer_symbols[-2], CNU
109
+ ):
110
+ integer_symbols.append(
111
+ CNU(integer_symbols[-2].power - 1, None, None, None, None)
112
+ )
113
+
114
+ result = []
115
+ unit_count = 0
116
+ for s in integer_symbols:
117
+ if isinstance(s, CND):
118
+ result.append(s)
119
+ unit_count = 0
120
+ elif isinstance(s, CNU):
121
+ current_unit = CNU(s.power, None, None, None, None)
122
+ unit_count += 1
123
+
124
+ if unit_count == 1:
125
+ result.append(current_unit)
126
+ elif unit_count > 1:
127
+ for i in range(len(result)):
128
+ if (
129
+ isinstance(result[-i - 1], CNU)
130
+ and result[-i - 1].power < current_unit.power
131
+ ):
132
+ result[-i - 1] = CNU(
133
+ result[-i - 1].power + current_unit.power,
134
+ None,
135
+ None,
136
+ None,
137
+ None,
138
+ )
139
+ return result
140
+
141
+ def compute_value(integer_symbols):
142
+ """
143
+ Compute the value.
144
+ When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
145
+ e.g. '两千万' = 2000 * 10000 not 2000 + 10000
146
+ """
147
+ value = [0]
148
+ last_power = 0
149
+ for s in integer_symbols:
150
+ if isinstance(s, CND):
151
+ value[-1] = s.value
152
+ elif isinstance(s, CNU):
153
+ value[-1] *= pow(10, s.power)
154
+ if s.power > last_power:
155
+ value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
156
+ last_power = s.power
157
+ value.append(0)
158
+ return sum(value)
159
+
160
+ system = create_system(numbering_type)
161
+ int_part, dec_part = string2symbols(chinese_string, system)
162
+ int_part = correct_symbols(int_part, system)
163
+ int_str = str(compute_value(int_part))
164
+ dec_str = "".join([str(d.value) for d in dec_part])
165
+ if dec_part:
166
+ return "{0}.{1}".format(int_str, dec_str)
167
+ else:
168
+ return int_str
169
+
170
+
171
+ def num2chn(
172
+ number_string,
173
+ numbering_type=NUMBERING_TYPES[1],
174
+ big=False,
175
+ traditional=False,
176
+ alt_zero=False,
177
+ alt_one=False,
178
+ alt_two=True,
179
+ use_zeros=True,
180
+ use_units=True,
181
+ ):
182
+
183
+ def get_value(value_string, use_zeros=True):
184
+
185
+ striped_string = value_string.lstrip("0")
186
+
187
+ # record nothing if all zeros
188
+ if not striped_string:
189
+ return []
190
+
191
+ # record one digits
192
+ elif len(striped_string) == 1:
193
+ if use_zeros and len(value_string) != len(striped_string):
194
+ return [system.digits[0], system.digits[int(striped_string)]]
195
+ else:
196
+ return [system.digits[int(striped_string)]]
197
+
198
+ # recursively record multiple digits
199
+ else:
200
+ result_unit = next(
201
+ u for u in reversed(system.units) if u.power < len(striped_string)
202
+ )
203
+ result_string = value_string[: -result_unit.power]
204
+ return (
205
+ get_value(result_string)
206
+ + [result_unit]
207
+ + get_value(striped_string[-result_unit.power :])
208
+ )
209
+
210
+ system = create_system(numbering_type)
211
+
212
+ int_dec = number_string.split(".")
213
+ if len(int_dec) == 1:
214
+ int_string = int_dec[0]
215
+ dec_string = ""
216
+ elif len(int_dec) == 2:
217
+ int_string = int_dec[0]
218
+ dec_string = int_dec[1]
219
+ else:
220
+ raise ValueError(
221
+ "invalid input num string with more than one dot: {}".format(number_string)
222
+ )
223
+
224
+ if use_units and len(int_string) > 1:
225
+ result_symbols = get_value(int_string)
226
+ else:
227
+ result_symbols = [system.digits[int(c)] for c in int_string]
228
+ dec_symbols = [system.digits[int(c)] for c in dec_string]
229
+ if dec_string:
230
+ result_symbols += [system.math.point] + dec_symbols
231
+
232
+ if alt_two:
233
+ liang = CND(
234
+ 2,
235
+ system.digits[2].alt_s,
236
+ system.digits[2].alt_t,
237
+ system.digits[2].big_s,
238
+ system.digits[2].big_t,
239
+ )
240
+ for i, v in enumerate(result_symbols):
241
+ if isinstance(v, CND) and v.value == 2:
242
+ next_symbol = (
243
+ result_symbols[i + 1] if i < len(result_symbols) - 1 else None
244
+ )
245
+ previous_symbol = result_symbols[i - 1] if i > 0 else None
246
+ if isinstance(next_symbol, CNU) and isinstance(
247
+ previous_symbol, (CNU, type(None))
248
+ ):
249
+ if next_symbol.power != 1 and (
250
+ (previous_symbol is None) or (previous_symbol.power != 1)
251
+ ):
252
+ result_symbols[i] = liang
253
+
254
+ # if big is True, '两' will not be used and `alt_two` has no impact on output
255
+ if big:
256
+ attr_name = "big_"
257
+ if traditional:
258
+ attr_name += "t"
259
+ else:
260
+ attr_name += "s"
261
+ else:
262
+ if traditional:
263
+ attr_name = "traditional"
264
+ else:
265
+ attr_name = "simplified"
266
+
267
+ result = "".join([getattr(s, attr_name) for s in result_symbols])
268
+
269
+ # if not use_zeros:
270
+ # result = result.strip(getattr(system.digits[0], attr_name))
271
+
272
+ if alt_zero:
273
+ result = result.replace(
274
+ getattr(system.digits[0], attr_name), system.digits[0].alt_s
275
+ )
276
+
277
+ if alt_one:
278
+ result = result.replace(
279
+ getattr(system.digits[1], attr_name), system.digits[1].alt_s
280
+ )
281
+
282
+ for i, p in enumerate(POINT):
283
+ if result.startswith(p):
284
+ return CHINESE_DIGIS[0] + result
285
+
286
+ # ^10, 11, .., 19
287
+ if (
288
+ len(result) >= 2
289
+ and result[1]
290
+ in [
291
+ SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
292
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
293
+ ]
294
+ and result[0]
295
+ in [
296
+ CHINESE_DIGIS[1],
297
+ BIG_CHINESE_DIGIS_SIMPLIFIED[1],
298
+ BIG_CHINESE_DIGIS_TRADITIONAL[1],
299
+ ]
300
+ ):
301
+ result = result[1:]
302
+
303
+ return result
304
+
305
+
306
+ if __name__ == "__main__":
307
+
308
+ # 测试程序
309
+ all_chinese_number_string = (
310
+ CHINESE_DIGIS
311
+ + BIG_CHINESE_DIGIS_SIMPLIFIED
312
+ + BIG_CHINESE_DIGIS_TRADITIONAL
313
+ + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED
314
+ + LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
315
+ + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED
316
+ + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
317
+ + ZERO_ALT
318
+ + ONE_ALT
319
+ + "".join(TWO_ALTS + POSITIVE + NEGATIVE + POINT)
320
+ )
321
+
322
+ print("num:", chn2num("一万零四百零三点八零五"))
323
+ print("num:", chn2num("一亿六点三"))
324
+ print("num:", chn2num("一亿零六点三"))
325
+ print("num:", chn2num("两千零一亿六点三"))
326
+ # print('num:', chn2num('一零零八六'))
327
+ print("txt:", num2chn("10260.03", alt_zero=True))
328
+ print("txt:", num2chn("20037.090", numbering_type="low", traditional=True))
329
+ print("txt:", num2chn("100860001.77", numbering_type="high", big=True))
330
+ print(
331
+ "txt:",
332
+ num2chn(
333
+ "059523810880",
334
+ alt_one=True,
335
+ alt_two=False,
336
+ use_lzeros=True,
337
+ use_rzeros=True,
338
+ use_units=False,
339
+ ),
340
+ )
341
+
342
+ print(all_chinese_number_string)
fish_speech/text/chn_text_norm/cardinal.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """CARDINAL类 (包含小数DECIMAL类)
3
+ 纯数 <=> 中文字符串 方法
4
+ 中文字符串 <=> 纯数 方法
5
+ """
6
+
7
+ __author__ = "Zhiyang Zhou <[email protected]>"
8
+ __data__ = "2019-05-03"
9
+
10
+ from fish_speech.text.chn_text_norm.basic_util import *
11
+
12
+
13
+ class Cardinal:
14
+ """
15
+ CARDINAL类
16
+ """
17
+
18
+ def __init__(self, cardinal=None, chntext=None):
19
+ self.cardinal = cardinal
20
+ self.chntext = chntext
21
+
22
+ def chntext2cardinal(self):
23
+ return chn2num(self.chntext)
24
+
25
+ def cardinal2chntext(self):
26
+ return num2chn(self.cardinal)
27
+
28
+
29
+ if __name__ == "__main__":
30
+
31
+ # 测试程序
32
+ print(Cardinal(cardinal="21357.230").cardinal2chntext())
fish_speech/text/chn_text_norm/date.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """DATE类
3
+ 日期 <=> 中文字符串 方法
4
+ 中文字符串 <=> 日期 方法
5
+ """
6
+
7
+ __author__ = "Zhiyang Zhou <[email protected]>"
8
+ __data__ = "2019-05-07"
9
+
10
+ from fish_speech.text.chn_text_norm.cardinal import Cardinal
11
+ from fish_speech.text.chn_text_norm.digit import Digit
12
+
13
+
14
+ class Date:
15
+ """
16
+ DATE类
17
+ """
18
+
19
+ def __init__(self, date=None, chntext=None):
20
+ self.date = date
21
+ self.chntext = chntext
22
+
23
+ # def chntext2date(self):
24
+ # chntext = self.chntext
25
+ # try:
26
+ # year, other = chntext.strip().split('年', maxsplit=1)
27
+ # year = Digit(chntext=year).digit2chntext() + '年'
28
+ # except ValueError:
29
+ # other = chntext
30
+ # year = ''
31
+ # if other:
32
+ # try:
33
+ # month, day = other.strip().split('月', maxsplit=1)
34
+ # month = Cardinal(chntext=month).chntext2cardinal() + '月'
35
+ # except ValueError:
36
+ # day = chntext
37
+ # month = ''
38
+ # if day:
39
+ # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
40
+ # else:
41
+ # month = ''
42
+ # day = ''
43
+ # date = year + month + day
44
+ # self.date = date
45
+ # return self.date
46
+
47
+ def date2chntext(self):
48
+ date = self.date
49
+ try:
50
+ year, other = date.strip().split("年", maxsplit=1)
51
+ year = Digit(digit=year).digit2chntext() + "年"
52
+ except ValueError:
53
+ other = date
54
+ year = ""
55
+ if other:
56
+ try:
57
+ month, day = other.strip().split("月", maxsplit=1)
58
+ month = Cardinal(cardinal=month).cardinal2chntext() + "月"
59
+ except ValueError:
60
+ day = date
61
+ month = ""
62
+ if day:
63
+ day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
64
+ else:
65
+ month = ""
66
+ day = ""
67
+ chntext = year + month + day
68
+ self.chntext = chntext
69
+ return self.chntext
70
+
71
+
72
+ if __name__ == "__main__":
73
+
74
+ # 测试
75
+ print(Date(date="09年3月16日").date2chntext())
fish_speech/text/chn_text_norm/digit.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """DIGIT类
3
+ 数字串 <=> 中文字符串 方法
4
+ 中文字符串 <=> 数字串 方法
5
+ """
6
+
7
+ __author__ = "Zhiyang Zhou <[email protected]>"
8
+ __data__ = "2019-05-03"
9
+
10
+ from fish_speech.text.chn_text_norm.basic_util import *
11
+
12
+
13
+ class Digit:
14
+ """
15
+ DIGIT类
16
+ """
17
+
18
+ def __init__(self, digit=None, chntext=None):
19
+ self.digit = digit
20
+ self.chntext = chntext
21
+
22
+ # def chntext2digit(self):
23
+ # return chn2num(self.chntext)
24
+
25
+ def digit2chntext(self):
26
+ return num2chn(self.digit, alt_two=False, use_units=False)
27
+
28
+
29
+ if __name__ == "__main__":
30
+
31
+ # 测试程序
32
+ print(Digit(digit="2016").digit2chntext())
fish_speech/text/chn_text_norm/fraction.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """FRACTION类
3
+ 分数 <=> 中文字符串 方法
4
+ 中文字符串 <=> 分数 方法
5
+ """
6
+
7
+ __author__ = "Zhiyang Zhou <[email protected]>"
8
+ __data__ = "2019-05-03"
9
+
10
+ from fish_speech.text.chn_text_norm.basic_util import *
11
+
12
+
13
+ class Fraction:
14
+ """
15
+ FRACTION类
16
+ """
17
+
18
+ def __init__(self, fraction=None, chntext=None):
19
+ self.fraction = fraction
20
+ self.chntext = chntext
21
+
22
+ def chntext2fraction(self):
23
+ denominator, numerator = self.chntext.split("分之")
24
+ return chn2num(numerator) + "/" + chn2num(denominator)
25
+
26
+ def fraction2chntext(self):
27
+ numerator, denominator = self.fraction.split("/")
28
+ return num2chn(denominator) + "分之" + num2chn(numerator)
29
+
30
+
31
+ if __name__ == "__main__":
32
+
33
+ # 测试程序
34
+ print(Fraction(fraction="2135/7230").fraction2chntext())
35
+ print(Fraction(chntext="五百八十一分之三百六十九").chntext2fraction())
fish_speech/text/chn_text_norm/money.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """MONEY类
3
+ 金钱 <=> 中文字符串 方法
4
+ 中文字符串 <=> 金钱 方法
5
+ """
6
+ import re
7
+
8
+ __author__ = "Zhiyang Zhou <[email protected]>"
9
+ __data__ = "2019-05-08"
10
+
11
+ from fish_speech.text.chn_text_norm.cardinal import Cardinal
12
+
13
+
14
+ class Money:
15
+ """
16
+ MONEY类
17
+ """
18
+
19
+ def __init__(self, money=None, chntext=None):
20
+ self.money = money
21
+ self.chntext = chntext
22
+
23
+ # def chntext2money(self):
24
+ # return self.money
25
+
26
+ def money2chntext(self):
27
+ money = self.money
28
+ pattern = re.compile(r"(\d+(\.\d+)?)")
29
+ matchers = pattern.findall(money)
30
+ if matchers:
31
+ for matcher in matchers:
32
+ money = money.replace(
33
+ matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext()
34
+ )
35
+ self.chntext = money
36
+ return self.chntext
37
+
38
+
39
+ if __name__ == "__main__":
40
+
41
+ # 测试
42
+ print(Money(money="21.5万元").money2chntext())
43
+ print(Money(money="230块5毛").money2chntext())
fish_speech/text/chn_text_norm/percentage.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """PERCENTAGE类
3
+ 百分数 <=> 中文字符串 方法
4
+ 中文字符串 <=> 百分数 方法
5
+ """
6
+
7
+ __author__ = "Zhiyang Zhou <[email protected]>"
8
+ __data__ = "2019-05-06"
9
+
10
+ from fish_speech.text.chn_text_norm.basic_util import *
11
+
12
+
13
+ class Percentage:
14
+ """
15
+ PERCENTAGE类
16
+ """
17
+
18
+ def __init__(self, percentage=None, chntext=None):
19
+ self.percentage = percentage
20
+ self.chntext = chntext
21
+
22
+ def chntext2percentage(self):
23
+ return chn2num(self.chntext.strip().strip("百分之")) + "%"
24
+
25
+ def percentage2chntext(self):
26
+ return "百分之" + num2chn(self.percentage.strip().strip("%"))
27
+
28
+
29
+ if __name__ == "__main__":
30
+
31
+ # 测试程序
32
+ print(Percentage(chntext="百分之五十六点零三").chntext2percentage())
33
+ print(Percentage(percentage="65.3%").percentage2chntext())
fish_speech/text/chn_text_norm/telephone.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """TELEPHONE类
3
+ 电话号码 <=> 中文字符串 方法
4
+ 中文字符串 <=> 电话号码 方法
5
+ """
6
+
7
+ __author__ = "Zhiyang Zhou <[email protected]>"
8
+ __data__ = "2019-05-03"
9
+
10
+ from fish_speech.text.chn_text_norm.basic_util import *
11
+
12
+
13
+ class TelePhone:
14
+ """
15
+ TELEPHONE类
16
+ """
17
+
18
+ def __init__(self, telephone=None, raw_chntext=None, chntext=None):
19
+ self.telephone = telephone
20
+ self.raw_chntext = raw_chntext
21
+ self.chntext = chntext
22
+
23
+ # def chntext2telephone(self):
24
+ # sil_parts = self.raw_chntext.split('<SIL>')
25
+ # self.telephone = '-'.join([
26
+ # str(chn2num(p)) for p in sil_parts
27
+ # ])
28
+ # return self.telephone
29
+
30
+ def telephone2chntext(self, fixed=False):
31
+
32
+ if fixed:
33
+ sil_parts = self.telephone.split("-")
34
+ self.raw_chntext = "<SIL>".join(
35
+ [num2chn(part, alt_two=False, use_units=False) for part in sil_parts]
36
+ )
37
+ self.chntext = self.raw_chntext.replace("<SIL>", "")
38
+ else:
39
+ sp_parts = self.telephone.strip("+").split()
40
+ self.raw_chntext = "<SP>".join(
41
+ [num2chn(part, alt_two=False, use_units=False) for part in sp_parts]
42
+ )
43
+ self.chntext = self.raw_chntext.replace("<SP>", "")
44
+ return self.chntext
45
+
46
+
47
+ if __name__ == "__main__":
48
+
49
+ # 测试程序
50
+ print(TelePhone(telephone="0595-23980880").telephone2chntext())
51
+ # print(TelePhone(raw_chntext='零五九五杠二三八六五零九八').chntext2telephone())
fish_speech/text/chn_text_norm/text.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ TEXT类
4
+ """
5
+
6
+ __author__ = "Zhiyang Zhou <[email protected]>"
7
+ __data__ = "2019-05-03"
8
+
9
+ import re
10
+
11
+ from fish_speech.text.chn_text_norm.cardinal import Cardinal
12
+ from fish_speech.text.chn_text_norm.date import Date
13
+ from fish_speech.text.chn_text_norm.digit import Digit
14
+ from fish_speech.text.chn_text_norm.fraction import Fraction
15
+ from fish_speech.text.chn_text_norm.money import Money
16
+ from fish_speech.text.chn_text_norm.percentage import Percentage
17
+ from fish_speech.text.chn_text_norm.telephone import TelePhone
18
+
19
+ CURRENCY_NAMES = (
20
+ "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
21
+ "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
22
+ )
23
+ CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
24
+ COM_QUANTIFIERS = (
25
+ "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
26
+ "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
27
+ "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
28
+ "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
29
+ "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
30
+ "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块|人|抽)"
31
+ )
32
+
33
+
34
+ class Text:
35
+ """
36
+ Text类
37
+ """
38
+
39
+ def __init__(self, raw_text, norm_text=None):
40
+ self.raw_text = "^" + raw_text + "$"
41
+ self.norm_text = norm_text
42
+
43
+ def _particular(self):
44
+ text = self.norm_text
45
+ pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
46
+ matchers = pattern.findall(text)
47
+ if matchers:
48
+ # print('particular')
49
+ for matcher in matchers:
50
+ text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
51
+ self.norm_text = text
52
+ return self.norm_text
53
+
54
+ def normalize(self):
55
+ text = self.raw_text
56
+
57
+ # 规范化日期
58
+ pattern = re.compile(
59
+ r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)"
60
+ )
61
+ matchers = pattern.findall(text)
62
+ if matchers:
63
+ # print('date')
64
+ for matcher in matchers:
65
+ text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
66
+
67
+ # 规范化金钱
68
+ pattern = re.compile(
69
+ r"\D+((\d+(\.\d+)?)[多余几]?"
70
+ + CURRENCY_UNITS
71
+ + "(\d"
72
+ + CURRENCY_UNITS
73
+ + "?)?)"
74
+ )
75
+ matchers = pattern.findall(text)
76
+ if matchers:
77
+ # print('money')
78
+ for matcher in matchers:
79
+ text = text.replace(
80
+ matcher[0], Money(money=matcher[0]).money2chntext(), 1
81
+ )
82
+
83
+ # 规范化固话/手机号码
84
+ # 手机
85
+ # http://www.jihaoba.com/news/show/13680
86
+ # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
87
+ # 联通:130、131、132、156、155、186、185、176
88
+ # 电信:133、153、189、180、181、177
89
+ pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
90
+ matchers = pattern.findall(text)
91
+ if matchers:
92
+ # print('telephone')
93
+ for matcher in matchers:
94
+ text = text.replace(
95
+ matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1
96
+ )
97
+ # 固话
98
+ pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
99
+ matchers = pattern.findall(text)
100
+ if matchers:
101
+ # print('fixed telephone')
102
+ for matcher in matchers:
103
+ text = text.replace(
104
+ matcher[0],
105
+ TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True),
106
+ 1,
107
+ )
108
+
109
+ # 规范化分数
110
+ pattern = re.compile(r"(\d+/\d+)")
111
+ matchers = pattern.findall(text)
112
+ if matchers:
113
+ # print('fraction')
114
+ for matcher in matchers:
115
+ text = text.replace(
116
+ matcher, Fraction(fraction=matcher).fraction2chntext(), 1
117
+ )
118
+
119
+ # 规范化百分数
120
+ text = text.replace("%", "%")
121
+ pattern = re.compile(r"(\d+(\.\d+)?%)")
122
+ matchers = pattern.findall(text)
123
+ if matchers:
124
+ # print('percentage')
125
+ for matcher in matchers:
126
+ text = text.replace(
127
+ matcher[0],
128
+ Percentage(percentage=matcher[0]).percentage2chntext(),
129
+ 1,
130
+ )
131
+
132
+ # 规范化纯数+量词
133
+ pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
134
+ matchers = pattern.findall(text)
135
+ if matchers:
136
+ # print('cardinal+quantifier')
137
+ for matcher in matchers:
138
+ text = text.replace(
139
+ matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
140
+ )
141
+
142
+ # 规范化数字编号
143
+ pattern = re.compile(r"(\d{4,32})")
144
+ matchers = pattern.findall(text)
145
+ if matchers:
146
+ # print('digit')
147
+ for matcher in matchers:
148
+ text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
149
+
150
+ # 规范化纯数
151
+ pattern = re.compile(r"(\d+(\.\d+)?)")
152
+ matchers = pattern.findall(text)
153
+ if matchers:
154
+ # print('cardinal')
155
+ for matcher in matchers:
156
+ text = text.replace(
157
+ matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1
158
+ )
159
+
160
+ self.norm_text = text
161
+ self._particular()
162
+
163
+ return self.norm_text.lstrip("^").rstrip("$")
164
+
165
+
166
+ if __name__ == "__main__":
167
+
168
+ # 测试程序
169
+ print(Text(raw_text="固话:0595-23865596或23880880。").normalize())
170
+ print(Text(raw_text="手机:+86 19859213959或15659451527。").normalize())
171
+ print(Text(raw_text="分数:32477/76391。").normalize())
172
+ print(Text(raw_text="百分数:80.03%。").normalize())
173
+ print(Text(raw_text="编号:31520181154418。").normalize())
174
+ print(Text(raw_text="纯数:2983.07克或12345.60米。").normalize())
175
+ print(Text(raw_text="日期:1999年2月20日或09年3月15号。").normalize())
176
+ print(Text(raw_text="金钱:12块5,34.5元,20.1万").normalize())
177
+ print(Text(raw_text="特殊:O2O或B2C。").normalize())
fish_speech/text/clean.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ SYMBOLS_MAPPING = {
4
+ "\n": "",
5
+ "…": ".",
6
+ "“": "'",
7
+ "”": "'",
8
+ "‘": "'",
9
+ "’": "'",
10
+ "【": "",
11
+ "】": "",
12
+ "[": "",
13
+ "]": "",
14
+ "(": "",
15
+ ")": "",
16
+ "(": "",
17
+ ")": "",
18
+ "・": "",
19
+ "·": "",
20
+ "「": "'",
21
+ "」": "'",
22
+ "《": "'",
23
+ "》": "'",
24
+ "—": "",
25
+ "~": "",
26
+ "~": "",
27
+ ":": ",",
28
+ ";": ",",
29
+ ";": ",",
30
+ ":": ",",
31
+ }
32
+
33
+ REPLACE_SYMBOL_REGEX = re.compile(
34
+ "|".join(re.escape(p) for p in SYMBOLS_MAPPING.keys())
35
+ )
36
+
37
+
38
+ EMOJI_REGEX = re.compile(
39
+ "["
40
+ "\U0001F600-\U0001F64F" # emoticons
41
+ "\U0001F300-\U0001F5FF" # symbols & pictographs
42
+ "\U0001F680-\U0001F6FF" # transport & map symbols
43
+ "\U0001F1E0-\U0001F1FF" # flags (iOS)
44
+ "]+",
45
+ flags=re.UNICODE,
46
+ )
47
+
48
+
49
+ def clean_text(text):
50
+ # Clean the text
51
+ text = text.strip()
52
+
53
+ # Replace all chinese symbols with their english counterparts
54
+ text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
55
+
56
+ # Remove emojis
57
+ text = EMOJI_REGEX.sub(r"", text)
58
+
59
+ # Remove continuous periods (...) and commas (,,,)
60
+ text = re.sub(r"[.,]{2,}", lambda m: m.group()[0], text)
61
+
62
+ return text