KdaiP committed on
Commit
3dd84f8
1 Parent(s): 015e033

Upload 80 files

This view is limited to 50 files because it contains too many changes. See raw diff.
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+audios/4.wav filter=lfs diff=lfs merge=lfs -text
api.py ADDED
@@ -0,0 +1,103 @@
import torch
import torch.nn as nn

from dataclasses import asdict

from utils.audio import LogMelSpectrogram
from config import ModelConfig, MelConfig
from models.model import StableTTS

from text import symbols
from text import cleaned_text_to_sequence
from text.mandarin import chinese_to_cnm3
from text.english import english_to_ipa2
from text.japanese import japanese_to_ipa2

from datas.dataset import intersperse
from utils.audio import load_and_resample_audio

def get_vocoder(model_path, model_name='ffgan') -> nn.Module:
    if model_name == 'ffgan':
        # training or changing ffgan config is not supported in this repo
        # you can train your own model at https://github.com/fishaudio/vocoder
        from vocoders.ffgan.model import FireflyGANBaseWrapper
        vocoder = FireflyGANBaseWrapper(model_path)

    elif model_name == 'vocos':
        from vocoders.vocos.models.model import Vocos
        from config import VocosConfig, MelConfig
        vocoder = Vocos(VocosConfig(), MelConfig())
        vocoder.load_state_dict(torch.load(model_path, weights_only=True, map_location='cpu'))
        vocoder.eval()

    else:
        raise NotImplementedError(f"Unsupported model: {model_name}")

    return vocoder

class StableTTSAPI(nn.Module):
    def __init__(self, tts_model_path, vocoder_model_path, vocoder_name='ffgan'):
        super().__init__()

        self.mel_config = MelConfig()
        self.tts_model_config = ModelConfig()

        self.mel_extractor = LogMelSpectrogram(**asdict(self.mel_config))

        # text to mel spectrogram
        self.tts_model = StableTTS(len(symbols), self.mel_config.n_mels, **asdict(self.tts_model_config))
        self.tts_model.load_state_dict(torch.load(tts_model_path, map_location='cpu', weights_only=True))
        self.tts_model.eval()

        # mel spectrogram to waveform
        self.vocoder_model = get_vocoder(vocoder_model_path, vocoder_name)
        self.vocoder_model.eval()

        self.g2p_mapping = {
            'chinese': chinese_to_cnm3,
            'japanese': japanese_to_ipa2,
            'english': english_to_ipa2,
        }
        self.supported_languages = self.g2p_mapping.keys()

    @torch.inference_mode()
    def inference(self, text, ref_audio, language, step, temperature=1.0, length_scale=1.0, solver=None, cfg=3.0):
        device = next(self.parameters()).device
        phonemizer = self.g2p_mapping.get(language)

        text = phonemizer(text)
        text = torch.tensor(intersperse(cleaned_text_to_sequence(text), item=0), dtype=torch.long, device=device).unsqueeze(0)
        text_length = torch.tensor([text.size(-1)], dtype=torch.long, device=device)

        ref_audio = load_and_resample_audio(ref_audio, self.mel_config.sample_rate).to(device)
        ref_audio = self.mel_extractor(ref_audio)

        mel_output = self.tts_model.synthesise(text, text_length, step, temperature, ref_audio, length_scale, solver, cfg)['decoder_outputs']
        audio_output = self.vocoder_model(mel_output)
        return audio_output.cpu(), mel_output.cpu()

    def get_params(self):
        tts_param = sum(p.numel() for p in self.tts_model.parameters()) / 1e6
        vocoder_param = sum(p.numel() for p in self.vocoder_model.parameters()) / 1e6
        return tts_param, vocoder_param

if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tts_model_path = './checkpoints/checkpoint_0.pt'
    vocoder_model_path = './vocoders/pretrained/vocos.pt'

    model = StableTTSAPI(tts_model_path, vocoder_model_path, 'vocos')
    model.to(device)

    text = '樱落满殇祈念集……殇歌花落集思祈……樱花满地集于我心……揲舞纷飞祈愿相随……'
    audio = './audio_1.wav'

    audio_output, mel_output = model.inference(text, audio, 'chinese', 10, solver='dopri5', cfg=3)
    print(audio_output.shape)
    print(mel_output.shape)

    import torchaudio
    torchaudio.save('output.wav', audio_output, MelConfig().sample_rate)
audios/1.wav ADDED
Binary file (374 kB).
 
audios/2.wav ADDED
Binary file (182 kB).
 
audios/3.wav ADDED
Binary file (529 kB).
 
audios/4.wav ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e6672b81d7dd41cac56cf49b75bb66a5486b5fe969ddab0f98f14b05be7857df
size 1349150
audios/5.wav ADDED
Binary file (368 kB).
 
audios/6.wav ADDED
Binary file (431 kB).
 
audios/7.wav ADDED
Binary file (514 kB).
 
audios/8.wav ADDED
Binary file (420 kB).
 
checkpoints/.keep ADDED
File without changes
checkpoints/checkpoint_0.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b041bea13241b402bbfcdbfffd14381774be1179bae78e99ebd505d6d89f9367
size 126657600
config.py ADDED
@@ -0,0 +1,50 @@
from dataclasses import dataclass

@dataclass
class MelConfig:
    sample_rate: int = 44100
    n_fft: int = 2048
    win_length: int = 2048
    hop_length: int = 512
    f_min: float = 0.0
    f_max: float = None
    pad: int = 0
    n_mels: int = 128
    center: bool = False
    pad_mode: str = "reflect"
    mel_scale: str = "slaney"

    def __post_init__(self):
        if self.pad == 0:
            self.pad = (self.n_fft - self.hop_length) // 2

@dataclass
class ModelConfig:
    hidden_channels: int = 256
    filter_channels: int = 1024
    n_heads: int = 4
    n_enc_layers: int = 3
    n_dec_layers: int = 6
    kernel_size: int = 3
    p_dropout: float = 0.1
    gin_channels: int = 256

@dataclass
class TrainConfig:
    train_dataset_path: str = 'filelists/filelist.json'
    test_dataset_path: str = 'filelists/filelist.json' # not used
    batch_size: int = 32
    learning_rate: float = 1e-4
    num_epochs: int = 10000
    model_save_path: str = './checkpoints'
    log_dir: str = './runs'
    log_interval: int = 16
    save_interval: int = 1
    warmup_steps: int = 200

@dataclass
class VocosConfig:
    input_channels: int = 128
    dim: int = 512
    intermediate_dim: int = 1536
    num_layers: int = 8
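
A quick check of the padding that `MelConfig.__post_init__` derives when `pad` is left at its default of 0 (the value follows directly from the defaults above):

    from config import MelConfig

    cfg = MelConfig()
    # pad defaults to (n_fft - hop_length) // 2 = (2048 - 512) // 2
    assert cfg.pad == 768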
datas/__init__.py ADDED
File without changes
datas/dataset.py ADDED
@@ -0,0 +1,69 @@
import os
import random

import json
import torch
from torch.utils.data import Dataset

from text import cleaned_text_to_sequence

def intersperse(lst: list, item: int):
    """
    put a blank token between any two input tokens to improve pronunciation
    see https://github.com/jaywalnut310/glow-tts/issues/43 for more details
    """
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result

class StableDataset(Dataset):
    def __init__(self, filelist_path, hop_length):
        self.filelist_path = filelist_path
        self.hop_length = hop_length

        self._load_filelist(filelist_path)

    def _load_filelist(self, filelist_path):
        filelist, lengths = [], []
        with open(filelist_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = json.loads(line.strip())
                filelist.append((line['mel_path'], line['phone']))
                lengths.append(line['mel_length'])

        self.filelist = filelist
        self.lengths = lengths # length is used for DistributedBucketSampler

    def __len__(self):
        return len(self.filelist)

    def __getitem__(self, idx):
        mel_path, phone = self.filelist[idx]
        mel = torch.load(mel_path, map_location='cpu', weights_only=True)
        phone = torch.tensor(intersperse(cleaned_text_to_sequence(phone), 0), dtype=torch.long)
        return mel, phone

def collate_fn(batch):
    texts = [item[1] for item in batch]
    mels = [item[0] for item in batch]
    mels_sliced = [random_slice_tensor(mel) for mel in mels]

    text_lengths = torch.tensor([text.size(-1) for text in texts], dtype=torch.long)
    mel_lengths = torch.tensor([mel.size(-1) for mel in mels], dtype=torch.long)
    mels_sliced_lengths = torch.tensor([mel_sliced.size(-1) for mel_sliced in mels_sliced], dtype=torch.long)

    # pad to the same length
    texts_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(texts), padding=0)
    mels_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(mels), padding=0)
    mels_sliced_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(mels_sliced), padding=0)

    return texts_padded, text_lengths, mels_padded, mel_lengths, mels_sliced_padded, mels_sliced_lengths

# randomly slice mels for the reference encoder to prevent overfitting
def random_slice_tensor(x: torch.Tensor):
    length = x.size(-1)
    if length < 8:
        return x
    segment_size = random.randint(length // 12, length // 3)
    start = random.randint(0, length - segment_size)
    return x[..., start : start + segment_size]
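
A tiny illustration of `intersperse` (the values follow directly from the definition above):

    from datas.dataset import intersperse

    # a blank token (id 0) is placed before, between, and after the phoneme ids
    assert intersperse([5, 3, 7], 0) == [0, 5, 0, 3, 0, 7, 0]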
datas/sampler.py ADDED
@@ -0,0 +1,131 @@
import torch

# reference: https://github.com/jaywalnut310/vits/blob/main/data_utils.py
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
    """
    Maintain similar input lengths in a batch.
    Length groups are specified by boundaries.
    Ex) boundaries = [b1, b2, b3] -> any batch is included in either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.

    It removes samples which are not included in the boundaries.
    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 is discarded.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        boundaries,
        num_replicas=None,
        rank=None,
        shuffle=True,
    ):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        self.lengths = dataset.lengths
        self.batch_size = batch_size
        self.boundaries = boundaries

        self.buckets, self.num_samples_per_bucket = self._create_buckets()
        self.total_size = sum(self.num_samples_per_bucket)
        self.num_samples = self.total_size // self.num_replicas

    def _create_buckets(self):
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for i in range(len(self.lengths)):
            length = self.lengths[i]
            idx_bucket = self._bisect(length)
            if idx_bucket != -1:
                buckets[idx_bucket].append(i)

        # from https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/data_utils.py
        # avoid "integer division or modulo by zero" error for very small datasets
        try:
            for i in range(len(buckets) - 1, 0, -1):
                if len(buckets[i]) == 0:
                    buckets.pop(i)
                    self.boundaries.pop(i + 1)
            assert all(len(bucket) > 0 for bucket in buckets)
        # when one bucket is not traversed
        except Exception as e:
            print('Bucket warning ', e)
            for i in range(len(buckets) - 1, -1, -1):
                if len(buckets[i]) == 0:
                    buckets.pop(i)
                    self.boundaries.pop(i + 1)

        num_samples_per_bucket = []
        for i in range(len(buckets)):
            len_bucket = len(buckets[i])
            total_batch_size = self.num_replicas * self.batch_size
            rem = (
                total_batch_size - (len_bucket % total_batch_size)
            ) % total_batch_size
            num_samples_per_bucket.append(len_bucket + rem)
        return buckets, num_samples_per_bucket

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = []
        if self.shuffle:
            for bucket in self.buckets:
                indices.append(torch.randperm(len(bucket), generator=g).tolist())
        else:
            for bucket in self.buckets:
                indices.append(list(range(len(bucket))))

        batches = []
        for i in range(len(self.buckets)):
            bucket = self.buckets[i]
            len_bucket = len(bucket)
            ids_bucket = indices[i]
            num_samples_bucket = self.num_samples_per_bucket[i]

            # add extra samples to make it evenly divisible
            rem = num_samples_bucket - len_bucket
            ids_bucket = (
                ids_bucket
                + ids_bucket * (rem // len_bucket)
                + ids_bucket[: (rem % len_bucket)]
            )

            # subsample
            ids_bucket = ids_bucket[self.rank :: self.num_replicas]

            # batching
            for j in range(len(ids_bucket) // self.batch_size):
                batch = [
                    bucket[idx]
                    for idx in ids_bucket[
                        j * self.batch_size : (j + 1) * self.batch_size
                    ]
                ]
                batches.append(batch)

        if self.shuffle:
            batch_ids = torch.randperm(len(batches), generator=g).tolist()
            batches = [batches[i] for i in batch_ids]
        self.batches = batches

        assert len(self.batches) * self.batch_size == self.num_samples
        return iter(self.batches)

    def _bisect(self, x, lo=0, hi=None):
        if hi is None:
            hi = len(self.boundaries) - 1

        if hi > lo:
            mid = (hi + lo) // 2
            if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
                return mid
            elif x <= self.boundaries[mid]:
                return self._bisect(x, lo, mid)
            else:
                return self._bisect(x, mid + 1, hi)
        else:
            return -1

    def __len__(self):
        return self.num_samples // self.batch_size
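
A minimal single-process usage sketch (the filelist path and boundaries are placeholder values; passing explicit num_replicas=1 and rank=0 avoids initializing torch.distributed):

    from torch.utils.data import DataLoader
    from datas.dataset import StableDataset, collate_fn
    from datas.sampler import DistributedBucketSampler

    dataset = StableDataset('filelists/filelist.json', hop_length=512)  # placeholder path
    sampler = DistributedBucketSampler(dataset, batch_size=32, boundaries=[32, 300, 400, 500],
                                       num_replicas=1, rank=0, shuffle=True)
    loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collate_fn)
    sampler.set_epoch(0)  # reseeds the deterministic shuffle each epoch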
models/__init__.py ADDED
File without changes
models/diffusion_transformer.py ADDED
@@ -0,0 +1,205 @@
# References:
# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/transformer.py
# https://github.com/jaywalnut310/vits/blob/main/attentions.py
# https://github.com/pytorch-labs/gpt-fast/blob/main/model.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class FFN(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., gin_channels=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.drop = nn.Dropout(p_dropout)
        self.act1 = nn.SiLU(inplace=True)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        x = self.act1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask

class MultiHeadAttention(nn.Module):
    def __init__(self, channels, out_channels, n_heads, p_dropout=0.):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout

        self.k_channels = channels // n_heads
        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
        self.conv_v = torch.nn.Conv1d(channels, channels, 1)

        # from https://nn.labml.ai/transformers/rope/index.html
        self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
        self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)

        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        self.drop = torch.nn.Dropout(p_dropout)

        torch.nn.init.xavier_uniform_(self.conv_q.weight)
        torch.nn.init.xavier_uniform_(self.conv_k.weight)
        torch.nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(x)
        v = self.conv_v(x)

        x = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        query = self.query_rotary_pe(query) # [b, n_head, t, c // n_head]
        key = self.key_rotary_pe(key)

        output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, dropout_p=self.p_dropout if self.training else 0)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output

# modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/modules.py#L390
class DiTConVBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_channels, elementwise_affine=False)
        self.attn = MultiHeadAttention(hidden_channels, hidden_channels, num_heads, p_dropout)
        self.norm2 = nn.LayerNorm(hidden_channels, elementwise_affine=False)
        self.mlp = FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(gin_channels, hidden_channels) if gin_channels != hidden_channels else nn.Identity(),
            nn.SiLU(),
            nn.Linear(hidden_channels, 6 * hidden_channels, bias=True)
        )

    def forward(self, x, c, x_mask):
        """
        Args:
            x : [batch_size, channel, time]
            c : [batch_size, channel]
            x_mask : [batch_size, 1, time]
        returns the same shape as x
        """
        x = x * x_mask
        attn_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(-1) # shape: [batch_size, 1, time, time]
        attn_mask = torch.zeros_like(attn_mask).masked_fill(attn_mask == 0, -torch.finfo(x.dtype).max)

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).unsqueeze(2).chunk(6, dim=1) # shape: [batch_size, channel, 1]
        x = x + gate_msa * self.attn(self.modulate(self.norm1(x.transpose(1,2)).transpose(1,2), shift_msa, scale_msa), attn_mask) * x_mask
        x = x + gate_mlp * self.mlp(self.modulate(self.norm2(x.transpose(1,2)).transpose(1,2), shift_mlp, scale_mlp), x_mask)

        # no condition version
        # x = x + self.attn(self.norm1(x.transpose(1,2)).transpose(1,2), attn_mask)
        # x = x + self.mlp(self.norm2(x.transpose(1,2)).transpose(1,2), x_mask)
        return x

    @staticmethod
    def modulate(x, shift, scale):
        return x * (1 + scale) + shift

class RotaryPositionalEmbeddings(nn.Module):
    """
    ## RoPE module

    Rotary encoding transforms pairs of features by rotating in the 2D plane.
    That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
    Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
    by an angle depending on the position of the token.
    """

    def __init__(self, d: int, base: int = 10_000):
        r"""
        * `d` is the number of features $d$
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()

        self.base = base
        self.d = int(d)
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        r"""
        Cache $\cos$ and $\sin$ values
        """
        # return if the cache is already built
        if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
            return

        # get sequence length
        seq_len = x.shape[0]

        # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)

        # create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)

        # calculate the product of position index and $\theta_i$
        idx_theta = torch.einsum("n,d->nd", seq_idx, theta)

        # concatenate so that for row $m$ we have
        # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        # cache them
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        # $\frac{d}{2}$
        d_2 = self.d // 2

        # calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the Tensor at the head of a key or a query with shape `[batch_size, n_heads, seq_len, d]`
        """
        x = x.permute(2, 0, 1, 3) # b h t d -> t b h d

        self._build_cache(x)

        # split the features; rotary embeddings can be applied to only a partial set of features
        x_rope, x_pass = x[..., : self.d], x[..., self.d :]

        # calculate
        # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        neg_half_x = self._neg_half(x_rope)

        x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])

        return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3) # t b h d -> b h t d

class Transpose(nn.Identity):
    """(N, T, D) -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)
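
A small shape check for the partial rotary embedding (the sizes are illustrative: 4 heads with 64 channels per head, of which half are rotated, matching the `k_channels * 0.5` usage above):

    import torch
    from models.diffusion_transformer import RotaryPositionalEmbeddings

    rope = RotaryPositionalEmbeddings(32)    # rotate 32 of the 64 per-head channels
    q = torch.randn(2, 4, 100, 64)           # [batch, heads, time, head_dim]
    assert rope(q).shape == (2, 4, 100, 64)  # shape is preserved; only phases change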
models/duration_predictor.py ADDED
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn

# modified from https://github.com/jaywalnut310/vits/blob/main/models.py#L98
class DurationPredictor(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
        self.norm1 = nn.LayerNorm(filter_channels)
        self.conv2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
        self.norm2 = nn.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g):
        x = x.detach()
        x = x + self.cond(g.unsqueeze(2).detach())
        x = self.conv1(x * x_mask)
        x = torch.relu(x)
        x = self.norm1(x.transpose(1,2)).transpose(1,2)
        x = self.drop(x)
        x = self.conv2(x * x_mask)
        x = torch.relu(x)
        x = self.norm2(x.transpose(1,2)).transpose(1,2)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask

def duration_loss(logw, logw_, lengths):
    loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
    return loss
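
A quick shape sketch of the predictor (the dimensions mirror ModelConfig; the inputs are random placeholders):

    import torch
    from models.duration_predictor import DurationPredictor

    dp = DurationPredictor(in_channels=256, filter_channels=1024, kernel_size=3, p_dropout=0.5, gin_channels=256)
    x = torch.randn(2, 256, 50)    # encoder output [batch, channels, text_len]
    x_mask = torch.ones(2, 1, 50)  # valid-token mask
    g = torch.randn(2, 256)        # speaker embedding
    logw = dp(x, x_mask, g)        # predicted log-durations
    assert logw.shape == (2, 1, 50)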
models/estimator.py ADDED
@@ -0,0 +1,138 @@
import math

import torch
import torch.nn as nn

from models.diffusion_transformer import DiTConVBlock

class DitWrapper(nn.Module):
    """ add a FiLM layer to condition the time embedding to DiT """
    def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0, time_channels=0):
        super().__init__()
        self.time_fusion = FiLMLayer(hidden_channels, time_channels)
        self.block = DiTConVBlock(hidden_channels, filter_channels, num_heads, kernel_size, p_dropout, gin_channels)

    def forward(self, x, c, t, x_mask):
        x = self.time_fusion(x, t) * x_mask
        x = self.block(x, c, x_mask)
        return x

class FiLMLayer(nn.Module):
    """
    Feature-wise Linear Modulation (FiLM) layer
    Reference: https://arxiv.org/abs/1709.07871
    """
    def __init__(self, in_channels, cond_channels):
        super(FiLMLayer, self).__init__()
        self.in_channels = in_channels
        self.film = nn.Conv1d(cond_channels, in_channels * 2, 1)

    def forward(self, x, c):
        gamma, beta = torch.chunk(self.film(c.unsqueeze(2)), chunks=2, dim=1)
        return gamma * x + beta

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"

    def forward(self, x, scale=1000):
        if x.ndim < 1:
            x = x.unsqueeze(0)
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=x.device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class TimestepEmbedding(nn.Module):
    def __init__(self, in_channels, out_channels, filter_channels):
        super().__init__()

        self.layer = nn.Sequential(
            nn.Linear(in_channels, filter_channels),
            nn.SiLU(inplace=True),
            nn.Linear(filter_channels, out_channels)
        )

    def forward(self, x):
        return self.layer(x)

# reference: https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/decoder.py
class Decoder(nn.Module):
    def __init__(self, noise_channels, cond_channels, hidden_channels, out_channels, filter_channels, dropout=0.1, n_layers=1, n_heads=4, kernel_size=3, gin_channels=0, use_lsc=True):
        super().__init__()
        self.noise_channels = noise_channels
        self.cond_channels = cond_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.use_lsc = use_lsc # whether to use unet-like long skip connections

        self.time_embeddings = SinusoidalPosEmb(hidden_channels)
        self.time_mlp = TimestepEmbedding(hidden_channels, hidden_channels, filter_channels)

        self.in_proj = nn.Conv1d(hidden_channels + noise_channels, hidden_channels, 1) # cat noise and encoder output as input
        self.blocks = nn.ModuleList([DitWrapper(hidden_channels, filter_channels, n_heads, kernel_size, dropout, gin_channels, hidden_channels) for _ in range(n_layers)])
        self.final_proj = nn.Conv1d(hidden_channels, out_channels, 1)

        # prenet for encoder output
        self.cond_proj = nn.Sequential(
            nn.Conv1d(cond_channels, filter_channels, kernel_size, padding=kernel_size//2),
            nn.SiLU(inplace=True),
            nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2), # adds about 3M params
            nn.SiLU(inplace=True),
            nn.Conv1d(filter_channels, hidden_channels, kernel_size, padding=kernel_size//2)
        )

        if use_lsc:
            assert n_layers % 2 == 0
            self.n_lsc_layers = n_layers // 2
            self.lsc_layers = nn.ModuleList([nn.Conv1d(hidden_channels + hidden_channels, hidden_channels, kernel_size, padding = kernel_size // 2) for _ in range(self.n_lsc_layers)])

        self.initialize_weights()

    def initialize_weights(self):
        for block in self.blocks:
            nn.init.constant_(block.block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.block.adaLN_modulation[-1].bias, 0)

    def forward(self, t, x, mask, mu, c):
        """Forward pass of the DiT model.

        Args:
            t (torch.Tensor): timestep, shape (batch_size)
            x (torch.Tensor): noise, shape (batch_size, in_channels, time)
            mask (torch.Tensor): shape (batch_size, 1, time)
            mu (torch.Tensor): output of encoder, shape (batch_size, in_channels, time)
            c (torch.Tensor): shape (batch_size, gin_channels)

        Returns:
            torch.Tensor: estimated vector field, shape (batch_size, out_channels, time)
        """

        t = self.time_mlp(self.time_embeddings(t))
        mu = self.cond_proj(mu)

        x = torch.cat((x, mu), dim=1)
        x = self.in_proj(x)

        lsc_outputs = [] if self.use_lsc else None

        for idx, block in enumerate(self.blocks):
            # add long skip connections, see https://arxiv.org/pdf/2209.12152 for more details
            if self.use_lsc:
                if idx < self.n_lsc_layers:
                    lsc_outputs.append(x)
                else:
                    x = torch.cat((x, lsc_outputs.pop()), dim=1)
                    x = self.lsc_layers[idx - self.n_lsc_layers](x)

            x = block(x, c, t, mask)

        output = self.final_proj(x * mask)

        return output * mask
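
A shape sketch of the estimator with ModelConfig-sized dimensions (random placeholder inputs; n_layers must be even because of the long skip connections):

    import torch
    from models.estimator import Decoder

    decoder = Decoder(noise_channels=128, cond_channels=128, hidden_channels=256, out_channels=128,
                      filter_channels=1024, n_layers=6, n_heads=4, gin_channels=256)
    t = torch.rand(2)             # diffusion timesteps
    x = torch.randn(2, 128, 80)   # noisy mel
    mask = torch.ones(2, 1, 80)
    mu = torch.randn(2, 128, 80)  # aligned encoder output
    c = torch.randn(2, 256)       # speaker embedding
    assert decoder(t, x, mask, mu, c).shape == (2, 128, 80)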
models/flow_matching.py ADDED
@@ -0,0 +1,100 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import functools
from torchdiffeq import odeint

from models.estimator import Decoder

# modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/flow_matching.py
class CFMDecoder(torch.nn.Module):
    def __init__(self, noise_channels, cond_channels, hidden_channels, out_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels):
        super().__init__()
        self.noise_channels = noise_channels
        self.cond_channels = cond_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.gin_channels = gin_channels
        self.sigma_min = 1e-4

        self.estimator = Decoder(noise_channels, cond_channels, hidden_channels, out_channels, filter_channels, p_dropout, n_layers, n_heads, kernel_size, gin_channels)

    @torch.inference_mode()
    def forward(self, mu, mask, n_timesteps, temperature=1.0, c=None, solver=None, cfg_kwargs=None):
        """Forward diffusion

        Args:
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            n_timesteps (int): number of diffusion steps
            temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
            c (torch.Tensor, optional): speaker embedding
                shape: (batch_size, gin_channels)
            solver: see https://github.com/rtqichen/torchdiffeq for supported solvers
            cfg_kwargs: used for cfg inference

        Returns:
            sample: generated mel-spectrogram
                shape: (batch_size, n_feats, mel_timesteps)
        """

        z = torch.randn_like(mu) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)

        # cfg control
        if cfg_kwargs is None:
            estimator = functools.partial(self.estimator, mask=mask, mu=mu, c=c)
        else:
            estimator = functools.partial(self.cfg_wrapper, mask=mask, mu=mu, c=c, cfg_kwargs=cfg_kwargs)

        trajectory = odeint(estimator, z, t_span, method=solver, rtol=1e-5, atol=1e-5)
        return trajectory[-1]

    # cfg inference
    def cfg_wrapper(self, t, x, mask, mu, c, cfg_kwargs):
        fake_speaker = cfg_kwargs['fake_speaker'].repeat(x.size(0), 1)
        fake_content = cfg_kwargs['fake_content'].repeat(x.size(0), 1, x.size(-1))
        cfg_strength = cfg_kwargs['cfg_strength']

        cond_output = self.estimator(t, x, mask, mu, c)
        uncond_output = self.estimator(t, x, mask, fake_content, fake_speaker)

        output = uncond_output + cfg_strength * (cond_output - uncond_output)
        return output

    def compute_loss(self, x1, mask, mu, c):
        """Computes diffusion loss

        Args:
            x1 (torch.Tensor): Target
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): target mask
                shape: (batch_size, 1, mel_timesteps)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            c (torch.Tensor, optional): speaker condition.

        Returns:
            loss: conditional flow matching loss
            y: conditional flow
                shape: (batch_size, n_feats, mel_timesteps)
        """
        b, _, t = mu.shape

        # random timestep
        # use the cosine timestep scheduler from cosyvoice: https://github.com/FunAudioLLM/CosyVoice/blob/main/cosyvoice/flow/flow_matching.py
        t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
        t = 1 - torch.cos(t * 0.5 * torch.pi)

        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        loss = F.mse_loss(self.estimator(t.squeeze(), y, mask, mu, c), u, reduction="sum") / (torch.sum(mask) * u.size(1))
        return loss, y
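
For reference, the interpolant and regression target in `compute_loss` follow the OT-CFM construction: with noise $z \sim \mathcal{N}(0, I)$, target mel $x_1$, and $\sigma_{\min} = 10^{-4}$,

$$y_t = \bigl(1 - (1 - \sigma_{\min})\,t\bigr)\,z + t\,x_1, \qquad u_t = \frac{dy_t}{dt} = x_1 - (1 - \sigma_{\min})\,z,$$

so the estimator is regressed onto the (constant-in-$t$) vector field $u_t$ with a masked MSE. The substitution $t \mapsto 1 - \cos(\pi t / 2)$ biases the sampled timesteps toward $t = 0$.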
models/model.py ADDED
@@ -0,0 +1,178 @@
import math
import torch
import torch.nn as nn

import monotonic_align
from models.text_encoder import TextEncoder
from models.flow_matching import CFMDecoder
from models.reference_encoder import MelStyleEncoder
from models.duration_predictor import DurationPredictor, duration_loss
from utils.mask import sequence_mask

def convert_pad_shape(pad_shape):
    inverted_shape = pad_shape[::-1]
    pad_shape = [item for sublist in inverted_shape for item in sublist]
    return pad_shape

def generate_path(duration, mask):
    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype, device=duration.device)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path

# modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/matcha_tts.py
class StableTTS(nn.Module):
    def __init__(self, n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, n_dec_layers, kernel_size, p_dropout, gin_channels):
        super().__init__()

        self.n_vocab = n_vocab
        self.mel_channels = mel_channels

        self.encoder = TextEncoder(n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, kernel_size, p_dropout, gin_channels)
        self.ref_encoder = MelStyleEncoder(mel_channels, style_vector_dim=gin_channels, style_kernel_size=5, dropout=0.25)
        self.dp = DurationPredictor(hidden_channels, filter_channels, kernel_size, 0.5, gin_channels)
        self.decoder = CFMDecoder(mel_channels, mel_channels, hidden_channels, mel_channels, filter_channels, n_heads, n_dec_layers, kernel_size, p_dropout, gin_channels)

        # unconditional inputs for cfg
        self.fake_speaker = nn.Parameter(torch.zeros(1, gin_channels))
        self.fake_content = nn.Parameter(torch.zeros(1, mel_channels, 1))

        self.cfg_dropout = 0.2

    @torch.inference_mode()
    def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, y=None, length_scale=1.0, solver=None, cfg=1.0):
        """
        Generates mel-spectrogram from text. Returns:
            1. encoder outputs
            2. decoder outputs
            3. generated alignment

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
                shape: (batch_size, max_text_length)
            x_lengths (torch.Tensor): lengths of texts in batch.
                shape: (batch_size,)
            n_timesteps (int): number of steps to use for reverse diffusion in decoder.
            temperature (float, optional): controls variance of terminal distribution.
            y (torch.Tensor): mel spectrogram of reference audio
                shape: (batch_size, mel_channels, time)
            length_scale (float, optional): controls speech pace.
                Increase value to slow down generated speech and vice versa.

        Returns:
            dict: {
                "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
                # Average mel spectrogram generated by the encoder
                "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
                # Refined mel spectrogram improved by the CFM
                "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length),
                # Alignment map between text and mel spectrogram
            }
        """

        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        c = self.ref_encoder(y, None)
        x, mu_x, x_mask = self.encoder(x, c, x_lengths)
        logw = self.dp(x, x_mask, c)

        w = torch.exp(logw) * x_mask
        w_ceil = torch.ceil(w) * length_scale
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = y_lengths.max()

        # Using obtained durations `w` construct alignment map `attn`
        y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask.dtype)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

        # Align encoded text and get mu_y
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)
        encoder_outputs = mu_y[:, :, :y_max_length]

        # Generate sample tracing the probability flow
        if cfg == 1.0:
            decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, c, solver)
        else:
            cfg_kwargs = {'fake_speaker': self.fake_speaker, 'fake_content': self.fake_content, 'cfg_strength': cfg}
            decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, c, solver, cfg_kwargs)

        decoder_outputs = decoder_outputs[:, :, :y_max_length]

        return {
            "encoder_outputs": encoder_outputs,
            "decoder_outputs": decoder_outputs,
            "attn": attn[:, :, :y_max_length],
        }

    def forward(self, x, x_lengths, y, y_lengths, z, z_lengths):
        """
        Computes 3 losses:
            1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
            2. prior loss: loss between mel-spectrogram and encoder outputs.
            3. flow matching loss: loss between mel-spectrogram and decoder outputs.

        Args:
            x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
                shape: (batch_size, max_text_length)
            x_lengths (torch.Tensor): lengths of texts in batch.
                shape: (batch_size,)
            y (torch.Tensor): batch of corresponding mel-spectrograms.
                shape: (batch_size, n_feats, max_mel_length)
            y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
                shape: (batch_size,)
            z (torch.Tensor): batch of sliced mel-spectrograms.
                shape: (batch_size, n_feats, max_mel_length)
            z_lengths (torch.Tensor): lengths of sliced mel-spectrograms in batch.
                shape: (batch_size,)
        """
        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        y_mask = sequence_mask(y_lengths, y.size(2)).unsqueeze(1).to(y.dtype)
        z_mask = sequence_mask(z_lengths, z.size(2)).unsqueeze(1).to(z.dtype)
        cfg_mask = torch.rand(y.size(0), 1, device=y.device) > self.cfg_dropout

        # compute global speaker embedding
        c = self.ref_encoder(z, z_mask) * cfg_mask + ~cfg_mask * self.fake_speaker.repeat(z.size(0), 1)

        x, mu_x, x_mask = self.encoder(x, c, x_lengths)
        logw = self.dp(x, x_mask, c)

        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)

        # Use MAS to find the most likely alignment `attn` between text and mel-spectrogram
        with torch.no_grad():
            s_p_sq_r = torch.ones_like(mu_x) # [b, d, t]
            neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - torch.zeros_like(mu_x), [1], keepdim=True)
            neg_cent2 = torch.einsum("bdt, bds -> bts", -0.5 * (y**2), s_p_sq_r)
            neg_cent3 = torch.einsum("bdt, bds -> bts", y, (mu_x * s_p_sq_r))
            neg_cent4 = torch.sum(-0.5 * (mu_x**2) * s_p_sq_r, [1], keepdim=True)
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = (monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach())

        # Compute loss between predicted log-scaled durations and those obtained from MAS
        # referred to as prior loss in the paper
        logw_ = torch.log(1e-8 + attn.sum(2)) * x_mask
        dur_loss = duration_loss(logw, logw_, x_lengths)

        # Align encoded text with mel-spectrogram and get mu_y segment
        attn = attn.squeeze(1).transpose(1,2)
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)

        # Compute loss of the decoder
        cfg_mask = cfg_mask.unsqueeze(-1)
        mu_y_masked = mu_y * cfg_mask + ~cfg_mask * self.fake_content.repeat(mu_y.size(0), 1, mu_y.size(-1)) # mask content information for better diversity in flow matching
        diff_loss, _ = self.decoder.compute_loss(y, y_mask, mu_y_masked, c)

        prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
        prior_loss = prior_loss / (torch.sum(y_mask) * self.mel_channels)

        return dur_loss, diff_loss, prior_loss, attn
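
A hedged end-to-end sketch of computing the training losses with random placeholder tensors (shapes follow ModelConfig; this assumes the numba monotonic_align extension and utils.mask import cleanly):

    import torch
    from dataclasses import asdict
    from config import ModelConfig
    from models.model import StableTTS

    model = StableTTS(n_vocab=100, mel_channels=128, **asdict(ModelConfig()))
    x = torch.randint(0, 100, (2, 50))   # phoneme ids
    x_lengths = torch.tensor([50, 42])
    y = torch.randn(2, 128, 200)         # target mels
    y_lengths = torch.tensor([200, 180])
    z = torch.randn(2, 128, 60)          # sliced reference mels
    z_lengths = torch.tensor([60, 48])
    dur_loss, diff_loss, prior_loss, attn = model(x, x_lengths, y, y_lengths, z, z_lengths)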
models/reference_encoder.py ADDED
@@ -0,0 +1,168 @@
import torch
import torch.nn as nn

class Conv1dGLU(nn.Module):
    """
    Conv1d + GLU (Gated Linear Unit) with residual connection.
    For GLU refer to the https://arxiv.org/abs/1612.08083 paper.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dropout):
        super(Conv1dGLU, self).__init__()
        self.out_channels = out_channels
        self.conv1 = nn.Conv1d(in_channels, 2 * out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.conv1(x)
        x1, x2 = torch.split(x, self.out_channels, dim=1)
        x = x1 * torch.sigmoid(x2)
        x = residual + self.dropout(x)
        return x

# modified from https://github.com/RVC-Boss/GPT-SoVITS/blob/main/GPT_SoVITS/module/modules.py#L766
class MelStyleEncoder(nn.Module):
    """MelStyleEncoder"""

    def __init__(
        self,
        n_mel_channels=80,
        style_hidden=128,
        style_vector_dim=256,
        style_kernel_size=5,
        style_head=2,
        dropout=0.1,
    ):
        super(MelStyleEncoder, self).__init__()
        self.in_dim = n_mel_channels
        self.hidden_dim = style_hidden
        self.out_dim = style_vector_dim
        self.kernel_size = style_kernel_size
        self.n_head = style_head
        self.dropout = dropout

        self.spectral = nn.Sequential(
            nn.Linear(self.in_dim, self.hidden_dim),
            nn.Mish(inplace=True),
            nn.Dropout(self.dropout),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.Mish(inplace=True),
            nn.Dropout(self.dropout),
        )

        self.temporal = nn.Sequential(
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
        )

        self.slf_attn = nn.MultiheadAttention(
            self.hidden_dim,
            self.n_head,
            self.dropout,
            batch_first=True
        )

        self.fc = nn.Linear(self.hidden_dim, self.out_dim)

    def temporal_avg_pool(self, x, mask=None):
        if mask is None:
            return torch.mean(x, dim=1)
        else:
            return torch.sum(x * ~mask.unsqueeze(-1), dim=1) / (~mask).sum(dim=1).unsqueeze(1)

    def forward(self, x, x_mask=None):
        x = x.transpose(1, 2)

        # spectral
        x = self.spectral(x)
        # temporal
        x = x.transpose(1, 2)
        x = self.temporal(x)
        x = x.transpose(1, 2)
        # self-attention
        if x_mask is not None:
            x_mask = ~x_mask.squeeze(1).to(torch.bool)
        x, _ = self.slf_attn(x, x, x, key_padding_mask=x_mask, need_weights=False)
        # fc
        x = self.fc(x)
        # temporal average pooling
        w = self.temporal_avg_pool(x, mask=x_mask)

        return w

# Attention-pool version of MelStyleEncoder, not used
class AttnMelStyleEncoder(nn.Module):
    """MelStyleEncoder"""

    def __init__(
        self,
        n_mel_channels=80,
        style_hidden=128,
        style_vector_dim=256,
        style_kernel_size=5,
        style_head=2,
        dropout=0.1,
    ):
        super().__init__()
        self.in_dim = n_mel_channels
        self.hidden_dim = style_hidden
        self.out_dim = style_vector_dim
        self.kernel_size = style_kernel_size
        self.n_head = style_head
        self.dropout = dropout

        self.spectral = nn.Sequential(
            nn.Linear(self.in_dim, self.hidden_dim),
            nn.Mish(inplace=True),
            nn.Dropout(self.dropout),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.Mish(inplace=True),
            nn.Dropout(self.dropout),
        )

        self.temporal = nn.Sequential(
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
            Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
        )

        self.slf_attn = nn.MultiheadAttention(
            self.hidden_dim,
            self.n_head,
            self.dropout,
            batch_first=True
        )

        self.fc = nn.Linear(self.hidden_dim, self.out_dim)

    def temporal_avg_pool(self, x, mask=None):
        if mask is None:
            return torch.mean(x, dim=1)
        else:
            return torch.sum(x * ~mask.unsqueeze(-1), dim=1) / (~mask).sum(dim=1).unsqueeze(1)

    def forward(self, x, x_mask=None):
        x = x.transpose(1, 2)

        # spectral
        x = self.spectral(x)
        # temporal
        x = x.transpose(1, 2)
        x = self.temporal(x)
        x = x.transpose(1, 2)
        # self-attention
        if x_mask is not None:
            x_mask = ~x_mask.squeeze(1).to(torch.bool)
            zeros = torch.zeros(x_mask.size(0), 1, device=x_mask.device, dtype=x_mask.dtype)
            x_attn_mask = torch.cat((zeros, x_mask), dim=1)
        else:
            x_attn_mask = None

        avg = self.temporal_avg_pool(x, x_mask).unsqueeze(1)
        x = torch.cat([avg, x], dim=1)
        x, _ = self.slf_attn(x, x, x, key_padding_mask=x_attn_mask, need_weights=False)
        x = x[:, 0, :]
        # fc
        x = self.fc(x)

        return x
models/text_encoder.py ADDED
@@ -0,0 +1,44 @@
import torch
import torch.nn as nn

from models.diffusion_transformer import DiTConVBlock
from utils.mask import sequence_mask

# modified from https://github.com/jaywalnut310/vits/blob/main/models.py
class TextEncoder(nn.Module):
    def __init__(self, n_vocab, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.scale = self.hidden_channels ** 0.5

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)

        self.encoder = nn.ModuleList([DiTConVBlock(hidden_channels, filter_channels, n_heads, kernel_size, p_dropout, gin_channels) for _ in range(n_layers)])
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

        self.initialize_weights()

    def initialize_weights(self):
        for block in self.encoder:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, x: torch.Tensor, c: torch.Tensor, x_lengths: torch.Tensor):
        x = self.emb(x) * self.scale # [b, t, h]
        x = x.transpose(1, -1) # [b, h, t]
        x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)

        for layer in self.encoder:
            x = layer(x, c, x_mask)
        mu_x = self.proj(x) * x_mask

        return x, mu_x, x_mask
monotonic_align/__init__.py ADDED
@@ -0,0 +1,16 @@
from numpy import zeros, int32, float32
from torch import from_numpy

from .core import maximum_path_jit


def maximum_path(neg_cent, mask):
    device = neg_cent.device
    dtype = neg_cent.dtype
    neg_cent = neg_cent.data.cpu().numpy().astype(float32)
    path = zeros(neg_cent.shape, dtype=int32)

    t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
    t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
    maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
    return from_numpy(path).to(device=device, dtype=dtype)
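
A tiny illustration with placeholder scores: `maximum_path` returns a 0/1 alignment matrix that assigns every mel frame to exactly one text token and moves monotonically through both axes (this assumes the numba JIT compiles on first call):

    import torch
    from monotonic_align import maximum_path

    neg_cent = torch.randn(1, 6, 4)  # log-likelihood scores, [batch, mel_len, text_len]
    mask = torch.ones(1, 6, 4)
    path = maximum_path(neg_cent, mask)
    assert path.shape == (1, 6, 4)
    assert torch.all(path.sum(-1) == 1)  # each mel frame attends to exactly one token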
monotonic_align/core.py ADDED
@@ -0,0 +1,46 @@
import numba


@numba.jit(
    numba.void(
        numba.int32[:, :, ::1],
        numba.float32[:, :, ::1],
        numba.int32[::1],
        numba.int32[::1],
    ),
    nopython=True,
    nogil=True,
)
def maximum_path_jit(paths, values, t_ys, t_xs):
    b = paths.shape[0]
    max_neg_val = -1e9
    for i in range(int(b)):
        path = paths[i]
        value = values[i]
        t_y = t_ys[i]
        t_x = t_xs[i]

        v_prev = v_cur = 0.0
        index = t_x - 1

        for y in range(t_y):
            for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
                if x == y:
                    v_cur = max_neg_val
                else:
                    v_cur = value[y - 1, x]
                if x == 0:
                    if y == 0:
                        v_prev = 0.0
                    else:
                        v_prev = max_neg_val
                else:
                    v_prev = value[y - 1, x - 1]
                value[y, x] += max(v_prev, v_cur)

        for y in range(t_y - 1, -1, -1):
            path[y, index] = 1
            if index != 0 and (
                index == y or value[y - 1, index] < value[y - 1, index - 1]
            ):
                index = index - 1
requirements.txt ADDED
@@ -0,0 +1,33 @@
torch
torchaudio

tqdm
numpy
soundfile # to make sure that torchaudio has at least one valid backend

tensorboard

# for monotonic_align
numba

# ODE solver
torchdiffeq

# for g2p
# chinese
pypinyin
jieba
# english
eng_to_ipa
unidecode
inflect
# japanese
# if pyopenjtalk fails to download open_jtalk_dic_utf_8-1.11.tar.gz, manually download and unzip the file below
# https://github.com/r9y9/open_jtalk/releases/download/v1.11.1/open_jtalk_dic_utf_8-1.11.tar.gz
# and set os.environ['OPEN_JTALK_DICT_DIR'] to the folder path
pyopenjtalk-prebuilt # if using python >= 3.12, install pyopenjtalk instead

# for webui
gradio
matplotlib
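
A minimal sketch of the manual dictionary setup described in the comment above (the extraction path is a placeholder; set the variable before pyopenjtalk needs the dictionary):

    import os

    # point pyopenjtalk at a manually extracted dictionary (placeholder path)
    os.environ['OPEN_JTALK_DICT_DIR'] = './open_jtalk_dic_utf_8-1.11'

    import pyopenjtalk  # imported after the environment variable is set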
text/LICENSE ADDED
@@ -0,0 +1,19 @@
Copyright (c) 2017 Keith Ito

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
text/__init__.py ADDED
@@ -0,0 +1,71 @@
+ """ from https://github.com/keithito/tacotron """
+ from text import cleaners
+ from text.symbols import symbols
+
+
+ # Mappings from symbol to numeric ID and vice versa:
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
+
+
+ def text_to_sequence(text, symbols, cleaner_names):
+     '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+     Args:
+         text: string to convert to a sequence
+         symbols: list of symbols used to build the lookup table
+         cleaner_names: names of the cleaner functions to run the text through
+     Returns:
+         List of integers corresponding to the symbols in the text
+     '''
+     sequence = []
+     symbol_to_id = {s: i for i, s in enumerate(symbols)}
+     clean_text = _clean_text(text, cleaner_names)
+     print(clean_text)
+     print(f" length:{len(clean_text)}")
+     for symbol in clean_text:
+         if symbol not in symbol_to_id.keys():
+             continue
+         symbol_id = symbol_to_id[symbol]
+         sequence += [symbol_id]
+     print(f" length:{len(sequence)}")
+     return sequence
+
+
+ def cleaned_text_to_sequence(cleaned_text):
+     '''Converts a string of already-cleaned text to a sequence of IDs corresponding to the symbols in the text.
+     Args:
+         text: string to convert to a sequence
+     Returns:
+         List of integers corresponding to the symbols in the text
+     '''
+     # symbol_to_id = {s: i for i, s in enumerate(symbols)}
+     sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()]
+     return sequence
+
+ def cleaned_text_to_sequence_chinese(cleaned_text):
+     '''Converts a string of space-separated CNM3 phones to a sequence of IDs corresponding to the symbols in the text.
+     Args:
+         text: string to convert to a sequence
+     Returns:
+         List of integers corresponding to the symbols in the text
+     '''
+     # symbol_to_id = {s: i for i, s in enumerate(symbols)}
+     sequence = [_symbol_to_id[symbol] for symbol in cleaned_text.split(' ') if symbol in _symbol_to_id.keys()]
+     return sequence
+
+
+ def sequence_to_text(sequence):
+     '''Converts a sequence of IDs back to a string'''
+     result = ''
+     for symbol_id in sequence:
+         s = _id_to_symbol[symbol_id]
+         result += s
+     return result
+
+
+ def _clean_text(text, cleaner_names):
+     for name in cleaner_names:
+         cleaner = getattr(cleaners, name)
+         if not cleaner:
+             raise Exception('Unknown cleaner: %s' % name)
+         text = cleaner(text)
+     return text
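A minimal usage sketch of the lookup helpers above, assuming every IPA character below is present in text.symbols:

    from text import cleaned_text_to_sequence, sequence_to_text

    cleaned = ['h', 'ə', 'l', 'o', 'ʊ']        # symbol list as produced by a G2P frontend
    ids = cleaned_text_to_sequence(cleaned)    # unknown symbols are silently skipped
    assert sequence_to_text(ids) == 'həloʊ'    # round-trips when every symbol is known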
text/cleaners.py ADDED
@@ -0,0 +1,58 @@
+ import re
+
+ from text.english import english_to_ipa2
+ from text.mandarin import chinese_to_cnm3
+ from text.japanese import japanese_to_ipa2
+
+ language_module_map = {"PAD": 0, "ZH": 1, "EN": 2, "JA": 3}
+
+ # pre-compiled regular expressions
+ ZH_PATTERN = re.compile(r'[\u3400-\u4DBF\u4e00-\u9FFF\uF900-\uFAFF\u3000-\u303F]')
+ EN_PATTERN = re.compile(r'[a-zA-Z.,!?\'"(){}[\]<>:;@#$%^&*-_+=/\\|~`]+')
+ JP_PATTERN = re.compile(r'[\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FAF\u31F0-\u31FF\uFF00-\uFFEF\u3000-\u303F]')
+ CLEANER_PATTERN = re.compile(r'\[(ZH|EN|JA)\]')
+
+ def detect_language(text: str, prev_lang=None):
+     """
+     Detect the language of the given text.
+
+     :param text: input text
+     :param prev_lang: the previously detected language
+     :return: 'ZH' for Chinese, 'EN' for English, 'JA' for Japanese, or prev_lang for spaces
+     """
+     if ZH_PATTERN.search(text): return 'ZH'
+     if EN_PATTERN.search(text): return 'EN'
+     if JP_PATTERN.search(text): return 'JA'
+     if text.isspace(): return prev_lang  # for spaces, keep the previous language
+     return None
+
+ # auto detect language using re
+ def cjke_cleaners4(text: str):
+     """
+     Auto-detect the language of each run of text and convert it to IPA phonemes.
+
+     :param text: input text
+     :return: text converted to IPA phonemes
+     """
+     text = CLEANER_PATTERN.sub('', text)
+     pointer = 0
+     output = ''
+     current_language = detect_language(text[pointer])
+
+     while pointer < len(text):
+         temp_text = ''
+         while pointer < len(text) and detect_language(text[pointer], current_language) == current_language:
+             temp_text += text[pointer]
+             pointer += 1
+         if current_language == 'ZH':
+             output += chinese_to_cnm3(temp_text)
+         elif current_language == 'JA':
+             output += japanese_to_ipa2(temp_text)
+         elif current_language == 'EN':
+             output += english_to_ipa2(temp_text)
+         if pointer < len(text):
+             current_language = detect_language(text[pointer])
+
+     output = re.sub(r'\s+$', '', output)
+     output = re.sub(r'([^\.,!\?\-…~])$', r'\1.', output)
+     return output
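The routing logic above is script-based and checks Chinese first, so kanji are classified as ZH; a few checked examples of detect_language:

    from text.cleaners import detect_language

    assert detect_language('你') == 'ZH'   # CJK ideograph
    assert detect_language('ニ') == 'JA'   # katakana
    assert detect_language('n') == 'EN'    # latin letter
    assert detect_language(' ', prev_lang='EN') == 'EN'   # spaces inherit the previous language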
text/cn2an/__init__.py ADDED
@@ -0,0 +1,16 @@
+ __version__ = "0.5.22"
+
+ from .cn2an import Cn2An
+ from .an2cn import An2Cn
+ from .transform import Transform
+
+ cn2an = Cn2An().cn2an
+ an2cn = An2Cn().an2cn
+ transform = Transform().transform
+
+ __all__ = [
+     "__version__",
+     "cn2an",
+     "an2cn",
+     "transform"
+ ]
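The package re-exports the three entry points as plain functions; the checks below follow the upstream cn2an README:

    from text import cn2an

    assert cn2an.cn2an('一百二十三') == 123          # strict mode is the default
    assert cn2an.cn2an('1.23万', 'smart') == 12300   # smart mode accepts digits mixed with units
    assert cn2an.an2cn(123) == '一百二十三'           # low mode is the default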
text/cn2an/an2cn.py ADDED
@@ -0,0 +1,204 @@
+ from typing import Union
+ from warnings import warn
+
+ # from proces import preprocess
+
+ from .conf import NUMBER_LOW_AN2CN, NUMBER_UP_AN2CN, UNIT_LOW_ORDER_AN2CN, UNIT_UP_ORDER_AN2CN
+
+
+ class An2Cn(object):
+     def __init__(self) -> None:
+         self.all_num = "0123456789"
+         self.number_low = NUMBER_LOW_AN2CN
+         self.number_up = NUMBER_UP_AN2CN
+         self.mode_list = ["low", "up", "rmb", "direct"]
+
+     def an2cn(self, inputs: Union[str, int, float] = None, mode: str = "low") -> str:
+         """Convert Arabic numerals to Chinese numerals.
+
+         :param inputs: Arabic number
+         :param mode: low = lowercase numerals, up = uppercase (banker's) numerals, rmb = RMB amount in words, direct = digit-by-digit conversion
+         :return: Chinese numeral string
+         """
+         if inputs is not None and inputs != "":
+             if mode not in self.mode_list:
+                 raise ValueError(f"mode 仅支持 {str(self.mode_list)} !")
+
+             # convert the number to a string; Python normalizes it along the way:
+             # 1. -> 1.0    1.00 -> 1.0    -0 -> 0
+             if not isinstance(inputs, str):
+                 inputs = self.__number_to_string(inputs)
+
+             # data preprocessing:
+             # 1. traditional to simplified Chinese
+             # 2. full-width to half-width characters
+             # inputs = preprocess(inputs, pipelines=[
+             #     "traditional_to_simplified",
+             #     "full_angle_to_half_angle"
+             # ])
+
+             # check that the input is valid
+             self.__check_inputs_is_valid(inputs)
+
+             # determine the sign
+             if inputs[0] == "-":
+                 sign = "负"
+                 inputs = inputs[1:]
+             else:
+                 sign = ""
+
+             if mode == "direct":
+                 output = self.__direct_convert(inputs)
+             else:
+                 # split into integer part and decimal part
+                 split_result = inputs.split(".")
+                 len_split_result = len(split_result)
+                 if len_split_result == 1:
+                     # input without a decimal part
+                     integer_data = split_result[0]
+                     if mode == "rmb":
+                         output = self.__integer_convert(integer_data, "up") + "元整"
+                     else:
+                         output = self.__integer_convert(integer_data, mode)
+                 elif len_split_result == 2:
+                     # input with a decimal part
+                     integer_data, decimal_data = split_result
+                     if mode == "rmb":
+                         int_data = self.__integer_convert(integer_data, "up")
+                         dec_data = self.__decimal_convert(decimal_data, "up")
+                         len_dec_data = len(dec_data)
+
+                         if len_dec_data == 0:
+                             output = int_data + "元整"
+                         elif len_dec_data == 1:
+                             raise ValueError(f"异常输出:{dec_data}")
+                         elif len_dec_data == 2:
+                             if dec_data[1] != "零":
+                                 if int_data == "零":
+                                     output = dec_data[1] + "角"
+                                 else:
+                                     output = int_data + "元" + dec_data[1] + "角"
+                             else:
+                                 output = int_data + "元整"
+                         else:
+                             if dec_data[1] != "零":
+                                 if dec_data[2] != "零":
+                                     if int_data == "零":
+                                         output = dec_data[1] + "角" + dec_data[2] + "分"
+                                     else:
+                                         output = int_data + "元" + dec_data[1] + "角" + dec_data[2] + "分"
+                                 else:
+                                     if int_data == "零":
+                                         output = dec_data[1] + "角"
+                                     else:
+                                         output = int_data + "元" + dec_data[1] + "角"
+                             else:
+                                 if dec_data[2] != "零":
+                                     if int_data == "零":
+                                         output = dec_data[2] + "分"
+                                     else:
+                                         output = int_data + "元" + "零" + dec_data[2] + "分"
+                                 else:
+                                     output = int_data + "元整"
+                     else:
+                         output = self.__integer_convert(integer_data, mode) + self.__decimal_convert(decimal_data, mode)
+                 else:
+                     raise ValueError(f"输入格式错误:{inputs}!")
+         else:
+             raise ValueError("输入数据为空!")
+
+         return sign + output
+
+     def __direct_convert(self, inputs: str) -> str:
+         _output = ""
+         for d in inputs:
+             if d == ".":
+                 _output += "点"
+             else:
+                 _output += self.number_low[int(d)]
+         return _output
+
+     @staticmethod
+     def __number_to_string(number_data: Union[int, float]) -> str:
+         # decimal handling: Python renders 0.00005 as 5e-05, so str(0.00005) != "0.00005"
+         string_data = str(number_data)
+         if "e" in string_data:
+             string_data_list = string_data.split("e")
+             string_key = string_data_list[0]
+             string_value = string_data_list[1]
+             if string_value[0] == "-":
+                 string_data = "0." + "0" * (int(string_value[1:]) - 1) + string_key
+             else:
+                 string_data = string_key + "0" * int(string_value)
+         return string_data
+
+     def __check_inputs_is_valid(self, check_data: str) -> None:
+         # make sure every character of the input is in the allowed set
+         all_check_keys = self.all_num + ".-"
+         for data in check_data:
+             if data not in all_check_keys:
+                 raise ValueError(f"输入的数据不在转化范围内:{data}!")
+
+     def __integer_convert(self, integer_data: str, mode: str) -> str:
+         if mode == "low":
+             numeral_list = NUMBER_LOW_AN2CN
+             unit_list = UNIT_LOW_ORDER_AN2CN
+         elif mode == "up":
+             numeral_list = NUMBER_UP_AN2CN
+             unit_list = UNIT_UP_ORDER_AN2CN
+         else:
+             raise ValueError(f"error mode: {mode}")
+
+         # strip leading zeros, e.g. 007 => 7
+         integer_data = str(int(integer_data))
+
+         len_integer_data = len(integer_data)
+         if len_integer_data > len(unit_list):
+             raise ValueError(f"超出数据范围,最长支持 {len(unit_list)} 位")
+
+         output_an = ""
+         for i, d in enumerate(integer_data):
+             if int(d):
+                 output_an += numeral_list[int(d)] + unit_list[len_integer_data - i - 1]
+             else:
+                 if not (len_integer_data - i - 1) % 4:
+                     output_an += numeral_list[int(d)] + unit_list[len_integer_data - i - 1]
+
+                 if i > 0 and not output_an[-1] == "零":
+                     output_an += numeral_list[int(d)]
+
+         output_an = output_an.replace("零零", "零").replace("零万", "万").replace("零亿", "亿").replace("亿万", "亿") \
+             .strip("零")
+
+         # fix the "一十X" case, e.g. 一十三 -> 十三
+         if output_an[:2] in ["一十"]:
+             output_an = output_an[1:]
+
+         # decimals between 0 and 1
+         if not output_an:
+             output_an = "零"
+
+         return output_an
+
+     def __decimal_convert(self, decimal_data: str, o_mode: str) -> str:
+         len_decimal_data = len(decimal_data)
+
+         if len_decimal_data > 16:
+             warn(f"注意:小数部分长度为 {len_decimal_data} ,将自动截取前 16 位有效精度!")
+             decimal_data = decimal_data[:16]
+
+         if len_decimal_data:
+             output_an = "点"
+         else:
+             output_an = ""
+
+         if o_mode == "low":
+             numeral_list = NUMBER_LOW_AN2CN
+         elif o_mode == "up":
+             numeral_list = NUMBER_UP_AN2CN
+         else:
+             raise ValueError(f"error mode: {o_mode}")
+
+         for data in decimal_data:
+             output_an += numeral_list[int(data)]
+         return output_an
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from warnings import warn
3
+ from typing import Union
4
+
5
+ # from proces import preprocess
6
+
7
+ from .an2cn import An2Cn
8
+ from .conf import NUMBER_CN2AN, UNIT_CN2AN, STRICT_CN_NUMBER, NORMAL_CN_NUMBER, NUMBER_LOW_AN2CN, UNIT_LOW_AN2CN
9
+
10
+
11
+ class Cn2An(object):
12
+ def __init__(self) -> None:
13
+ self.all_num = "".join(list(NUMBER_CN2AN.keys()))
14
+ self.all_unit = "".join(list(UNIT_CN2AN.keys()))
15
+ self.strict_cn_number = STRICT_CN_NUMBER
16
+ self.normal_cn_number = NORMAL_CN_NUMBER
17
+ self.check_key_dict = {
18
+ "strict": "".join(self.strict_cn_number.values()) + "点负",
19
+ "normal": "".join(self.normal_cn_number.values()) + "点负",
20
+ "smart": "".join(self.normal_cn_number.values()) + "点负" + "01234567890.-"
21
+ }
22
+ self.pattern_dict = self.__get_pattern()
23
+ self.ac = An2Cn()
24
+ self.mode_list = ["strict", "normal", "smart"]
25
+ self.yjf_pattern = re.compile(fr"^.*?[元圆][{self.all_num}]角([{self.all_num}]分)?$")
26
+ self.pattern1 = re.compile(fr"^-?\d+(\.\d+)?[{self.all_unit}]?$")
27
+ self.ptn_all_num = re.compile(f"^[{self.all_num}]+$")
28
+ # "十?" is for special case "十一万三"
29
+ self.ptn_speaking_mode = re.compile(f"^([{self.all_num}]{{0,2}}[{self.all_unit}])+[{self.all_num}]$")
30
+
31
+ def cn2an(self, inputs: Union[str, int, float] = None, mode: str = "strict") -> Union[float, int]:
32
+ """中文数字转阿拉伯数字
33
+
34
+ :param inputs: 中文数字、阿拉伯数字、中文数字和阿拉伯数字
35
+ :param mode: strict 严格,normal 正常,smart 智能
36
+ :return: 阿拉伯数字
37
+ """
38
+ if inputs is not None or inputs == "":
39
+ if mode not in self.mode_list:
40
+ raise ValueError(f"mode 仅支持 {str(self.mode_list)} !")
41
+
42
+ # 将数字转化为字符串
43
+ if not isinstance(inputs, str):
44
+ inputs = str(inputs)
45
+
46
+ # 数据预处理:
47
+ # 1. 繁体转简体
48
+ # 2. 全角转半角
49
+ # inputs = preprocess(inputs, pipelines=[
50
+ # "traditional_to_simplified",
51
+ # "full_angle_to_half_angle"
52
+ # ])
53
+
54
+ # 特殊转化 廿
55
+ inputs = inputs.replace("廿", "二十")
56
+
57
+ # 检查输入数据是否有效
58
+ sign, integer_data, decimal_data, is_all_num = self.__check_input_data_is_valid(inputs, mode)
59
+
60
+ # smart 下的特殊情况
61
+ if sign == 0:
62
+ return integer_data
63
+ else:
64
+ if not is_all_num:
65
+ if decimal_data is None:
66
+ output = self.__integer_convert(integer_data)
67
+ else:
68
+ output = self.__integer_convert(integer_data) + self.__decimal_convert(decimal_data)
69
+ # fix 1 + 0.57 = 1.5699999999999998
70
+ output = round(output, len(decimal_data))
71
+ else:
72
+ if decimal_data is None:
73
+ output = self.__direct_convert(integer_data)
74
+ else:
75
+ output = self.__direct_convert(integer_data) + self.__decimal_convert(decimal_data)
76
+ # fix 1 + 0.57 = 1.5699999999999998
77
+ output = round(output, len(decimal_data))
78
+ else:
79
+ raise ValueError("输入数据为空!")
80
+
81
+ return sign * output
82
+
83
+ def __get_pattern(self) -> dict:
84
+ # 整数严格检查
85
+ _0 = "[零]"
86
+ _1_9 = "[一二三四五六七八九]"
87
+ _10_99 = f"{_1_9}?[十]{_1_9}?"
88
+ _1_99 = f"({_10_99}|{_1_9})"
89
+ _100_999 = f"({_1_9}[百]([零]{_1_9})?|{_1_9}[百]{_10_99})"
90
+ _1_999 = f"({_100_999}|{_1_99})"
91
+ _1000_9999 = f"({_1_9}[千]([零]{_1_99})?|{_1_9}[千]{_100_999})"
92
+ _1_9999 = f"({_1000_9999}|{_1_999})"
93
+ _10000_99999999 = f"({_1_9999}[万]([零]{_1_999})?|{_1_9999}[万]{_1000_9999})"
94
+ _1_99999999 = f"({_10000_99999999}|{_1_9999})"
95
+ _100000000_9999999999999999 = f"({_1_99999999}[亿]([零]{_1_99999999})?|{_1_99999999}[亿]{_10000_99999999})"
96
+ _1_9999999999999999 = f"({_100000000_9999999999999999}|{_1_99999999})"
97
+ str_int_pattern = f"^({_0}|{_1_9999999999999999})$"
98
+ nor_int_pattern = f"^({_0}|{_1_9999999999999999})$"
99
+
100
+ str_dec_pattern = "^[零一二三四五六七八九]{0,15}[一二三四五六七八九]$"
101
+ nor_dec_pattern = "^[零一二三四五六七八九]{0,16}$"
102
+
103
+ for str_num in self.strict_cn_number.keys():
104
+ str_int_pattern = str_int_pattern.replace(str_num, self.strict_cn_number[str_num])
105
+ str_dec_pattern = str_dec_pattern.replace(str_num, self.strict_cn_number[str_num])
106
+ for nor_num in self.normal_cn_number.keys():
107
+ nor_int_pattern = nor_int_pattern.replace(nor_num, self.normal_cn_number[nor_num])
108
+ nor_dec_pattern = nor_dec_pattern.replace(nor_num, self.normal_cn_number[nor_num])
109
+
110
+ pattern_dict = {
111
+ "strict": {
112
+ "int": re.compile(str_int_pattern),
113
+ "dec": re.compile(str_dec_pattern)
114
+ },
115
+ "normal": {
116
+ "int": re.compile(nor_int_pattern),
117
+ "dec": re.compile(nor_dec_pattern)
118
+ }
119
+ }
120
+ return pattern_dict
121
+
122
+ def __copy_num(self, num):
123
+ cn_num = ""
124
+ for n in num:
125
+ cn_num += NUMBER_LOW_AN2CN[int(n)]
126
+ return cn_num
127
+
128
+ def __check_input_data_is_valid(self, check_data: str, mode: str) -> (int, str, str, bool):
129
+ # 去除 元整、圆整、元正、圆正
130
+ stop_words = ["元整", "圆整", "元正", "圆正"]
131
+ for word in stop_words:
132
+ if check_data[-2:] == word:
133
+ check_data = check_data[:-2]
134
+
135
+ # 去除 元、圆
136
+ if mode != "strict":
137
+ normal_stop_words = ["圆", "元"]
138
+ for word in normal_stop_words:
139
+ if check_data[-1] == word:
140
+ check_data = check_data[:-1]
141
+
142
+ # 处理元角分
143
+ result = self.yjf_pattern.search(check_data)
144
+ if result:
145
+ check_data = check_data.replace("元", "点").replace("角", "").replace("分", "")
146
+
147
+ # 处理特殊问法:一千零十一 一万零百一十一
148
+ if "零十" in check_data:
149
+ check_data = check_data.replace("零十", "零一十")
150
+ if "零百" in check_data:
151
+ check_data = check_data.replace("零百", "零一百")
152
+
153
+ for data in check_data:
154
+ if data not in self.check_key_dict[mode]:
155
+ raise ValueError(f"当前为{mode}模式,输入的数据不在转化范围内:{data}!")
156
+
157
+ # 确定正负号
158
+ if check_data[0] == "负":
159
+ check_data = check_data[1:]
160
+ sign = -1
161
+ else:
162
+ sign = 1
163
+
164
+ if "点" in check_data:
165
+ split_data = check_data.split("点")
166
+ if len(split_data) == 2:
167
+ integer_data, decimal_data = split_data
168
+ # 将 smart 模式中的阿拉伯数字转化成中文数字
169
+ if mode == "smart":
170
+ integer_data = re.sub(r"\d+", lambda x: self.ac.an2cn(x.group()), integer_data)
171
+ decimal_data = re.sub(r"\d+", lambda x: self.__copy_num(x.group()), decimal_data)
172
+ mode = "normal"
173
+ else:
174
+ raise ValueError("数据中包含不止一个点!")
175
+ else:
176
+ integer_data = check_data
177
+ decimal_data = None
178
+ # 将 smart 模式中的阿拉伯数字转化成中文数字
179
+ if mode == "smart":
180
+ # 10.1万 10.1
181
+ result1 = self.pattern1.search(integer_data)
182
+ if result1:
183
+ if result1.group() == integer_data:
184
+ if integer_data[-1] in UNIT_CN2AN.keys():
185
+ output = int(float(integer_data[:-1]) * UNIT_CN2AN[integer_data[-1]])
186
+ else:
187
+ output = float(integer_data)
188
+ return 0, output, None, None
189
+
190
+ integer_data = re.sub(r"\d+", lambda x: self.ac.an2cn(x.group()), integer_data)
191
+ mode = "normal"
192
+
193
+ result_int = self.pattern_dict[mode]["int"].search(integer_data)
194
+ if result_int:
195
+ if result_int.group() == integer_data:
196
+ if decimal_data is not None:
197
+ result_dec = self.pattern_dict[mode]["dec"].search(decimal_data)
198
+ if result_dec:
199
+ if result_dec.group() == decimal_data:
200
+ return sign, integer_data, decimal_data, False
201
+ else:
202
+ return sign, integer_data, decimal_data, False
203
+ else:
204
+ if mode == "strict":
205
+ raise ValueError(f"不符合格式的数据:{integer_data}")
206
+ elif mode == "normal":
207
+ # 纯数模式:一二三
208
+ result_all_num = self.ptn_all_num.search(integer_data)
209
+ if result_all_num:
210
+ if result_all_num.group() == integer_data:
211
+ if decimal_data is not None:
212
+ result_dec = self.pattern_dict[mode]["dec"].search(decimal_data)
213
+ if result_dec:
214
+ if result_dec.group() == decimal_data:
215
+ return sign, integer_data, decimal_data, True
216
+ else:
217
+ return sign, integer_data, decimal_data, True
218
+
219
+ # 口语模式:一万二,两千三,三百四,十三万六,一百二十五万���
220
+ result_speaking_mode = self.ptn_speaking_mode.search(integer_data)
221
+ if len(integer_data) >= 3 and result_speaking_mode and result_speaking_mode.group() == integer_data:
222
+ # len(integer_data)>=3: because the minimum length of integer_data that can be matched is 3
223
+ # to find the last unit
224
+ last_unit = result_speaking_mode.groups()[-1][-1]
225
+ _unit = UNIT_LOW_AN2CN[UNIT_CN2AN[last_unit] // 10]
226
+ integer_data = integer_data + _unit
227
+ if decimal_data is not None:
228
+ result_dec = self.pattern_dict[mode]["dec"].search(decimal_data)
229
+ if result_dec:
230
+ if result_dec.group() == decimal_data:
231
+ return sign, integer_data, decimal_data, False
232
+ else:
233
+ return sign, integer_data, decimal_data, False
234
+
235
+ raise ValueError(f"不符合格式的数据:{check_data}")
236
+
237
+ def __integer_convert(self, integer_data: str) -> int:
238
+ # 核心
239
+ output_integer = 0
240
+ unit = 1
241
+ ten_thousand_unit = 1
242
+ for index, cn_num in enumerate(reversed(integer_data)):
243
+ # 数值
244
+ if cn_num in NUMBER_CN2AN:
245
+ num = NUMBER_CN2AN[cn_num]
246
+ output_integer += num * unit
247
+ # 单位
248
+ elif cn_num in UNIT_CN2AN:
249
+ unit = UNIT_CN2AN[cn_num]
250
+ # 判断出万、亿、万亿
251
+ if unit % 10000 == 0:
252
+ # 万 亿
253
+ if unit > ten_thousand_unit:
254
+ ten_thousand_unit = unit
255
+ # 万亿
256
+ else:
257
+ ten_thousand_unit = unit * ten_thousand_unit
258
+ unit = ten_thousand_unit
259
+
260
+ if unit < ten_thousand_unit:
261
+ unit = unit * ten_thousand_unit
262
+
263
+ if index == len(integer_data) - 1:
264
+ output_integer += unit
265
+ else:
266
+ raise ValueError(f"{cn_num} 不在转化范围内")
267
+
268
+ return int(output_integer)
269
+
270
+ def __decimal_convert(self, decimal_data: str) -> float:
271
+ len_decimal_data = len(decimal_data)
272
+
273
+ if len_decimal_data > 16:
274
+ warn(f"注意:小数部分长度为 {len_decimal_data} ,将自动截取前 16 位有效精度!")
275
+ decimal_data = decimal_data[:16]
276
+ len_decimal_data = 16
277
+
278
+ output_decimal = 0
279
+ for index in range(len(decimal_data) - 1, -1, -1):
280
+ unit_key = NUMBER_CN2AN[decimal_data[index]]
281
+ output_decimal += unit_key * 10 ** -(index + 1)
282
+
283
+ # 处理精度溢出问题
284
+ output_decimal = round(output_decimal, len_decimal_data)
285
+
286
+ return output_decimal
287
+
288
+ def __direct_convert(self, data: str) -> int:
289
+ output_data = 0
290
+ for index in range(len(data) - 1, -1, -1):
291
+ unit_key = NUMBER_CN2AN[data[index]]
292
+ output_data += unit_key * 10 ** (len(data) - index - 1)
293
+
294
+ return output_data
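Two behaviors worth illustrating from the validation logic above: speaking mode pads the omitted unit, and smart mode short-circuits digit-plus-unit inputs (expected values traced by hand):

    from text.cn2an import cn2an

    assert cn2an('两千三', 'smart') == 2300     # speaking mode reads 三 as 三百
    assert cn2an('10.1万', 'smart') == 101000   # digits plus a unit are computed directly
    assert cn2an('负一点五', 'normal') == -1.5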
text/cn2an/conf.py ADDED
@@ -0,0 +1,135 @@
+ NUMBER_CN2AN = {
+     "零": 0,
+     "〇": 0,
+     "一": 1,
+     "壹": 1,
+     "幺": 1,
+     "二": 2,
+     "贰": 2,
+     "两": 2,
+     "三": 3,
+     "叁": 3,
+     "四": 4,
+     "肆": 4,
+     "五": 5,
+     "伍": 5,
+     "六": 6,
+     "陆": 6,
+     "七": 7,
+     "柒": 7,
+     "八": 8,
+     "捌": 8,
+     "九": 9,
+     "玖": 9,
+ }
+ UNIT_CN2AN = {
+     "十": 10,
+     "拾": 10,
+     "百": 100,
+     "佰": 100,
+     "千": 1000,
+     "仟": 1000,
+     "万": 10000,
+     "亿": 100000000,
+ }
+ UNIT_LOW_AN2CN = {
+     10: "十",
+     100: "百",
+     1000: "千",
+     10000: "万",
+     100000000: "亿",
+ }
+ NUMBER_LOW_AN2CN = {
+     0: "零",
+     1: "一",
+     2: "二",
+     3: "三",
+     4: "四",
+     5: "五",
+     6: "六",
+     7: "七",
+     8: "八",
+     9: "九",
+ }
+ NUMBER_UP_AN2CN = {
+     0: "零",
+     1: "壹",
+     2: "贰",
+     3: "叁",
+     4: "肆",
+     5: "伍",
+     6: "陆",
+     7: "柒",
+     8: "捌",
+     9: "玖",
+ }
+ UNIT_LOW_ORDER_AN2CN = [
+     "",
+     "十",
+     "百",
+     "千",
+     "万",
+     "十",
+     "百",
+     "千",
+     "亿",
+     "十",
+     "百",
+     "千",
+     "万",
+     "十",
+     "百",
+     "千",
+ ]
+ UNIT_UP_ORDER_AN2CN = [
+     "",
+     "拾",
+     "佰",
+     "仟",
+     "万",
+     "拾",
+     "佰",
+     "仟",
+     "亿",
+     "拾",
+     "佰",
+     "仟",
+     "万",
+     "拾",
+     "佰",
+     "仟",
+ ]
+ STRICT_CN_NUMBER = {
+     "零": "零",
+     "一": "一壹",
+     "二": "二贰",
+     "三": "三叁",
+     "四": "四肆",
+     "五": "五伍",
+     "六": "六陆",
+     "七": "七柒",
+     "八": "八捌",
+     "九": "九玖",
+     "十": "十拾",
+     "百": "百佰",
+     "千": "千仟",
+     "万": "万",
+     "亿": "亿",
+ }
+ NORMAL_CN_NUMBER = {
+     "零": "零〇",
+     "一": "一壹幺",
+     "二": "二贰两",
+     "三": "三叁仨",
+     "四": "四肆",
+     "五": "五伍",
+     "六": "六陆",
+     "七": "七柒",
+     "八": "八捌",
+     "九": "九玖",
+     "十": "十拾",
+     "百": "百佰",
+     "千": "千仟",
+     "万": "万",
+     "亿": "亿",
+ }
text/cn2an/transform.py ADDED
@@ -0,0 +1,104 @@
+ import re
+ from warnings import warn
+
+ from .cn2an import Cn2An
+ from .an2cn import An2Cn
+ from .conf import UNIT_CN2AN
+
+
+ class Transform(object):
+     def __init__(self) -> None:
+         self.all_num = "零一二三四五六七八九"
+         self.all_unit = "".join(list(UNIT_CN2AN.keys()))
+         self.cn2an = Cn2An().cn2an
+         self.an2cn = An2Cn().an2cn
+         self.cn_pattern = f"负?([{self.all_num}{self.all_unit}]+点)?[{self.all_num}{self.all_unit}]+"
+         self.smart_cn_pattern = f"-?([0-9]+.)?[0-9]+[{self.all_unit}]+"
+
+     def transform(self, inputs: str, method: str = "cn2an") -> str:
+         if method == "cn2an":
+             inputs = inputs.replace("廿", "二十").replace("半", "0.5").replace("两", "2")
+             # date
+             inputs = re.sub(
+                 fr"((({self.smart_cn_pattern})|({self.cn_pattern}))年)?([{self.all_num}十]+月)?([{self.all_num}十]+日)?",
+                 lambda x: self.__sub_util(x.group(), "cn2an", "date"), inputs)
+             # fraction
+             inputs = re.sub(fr"{self.cn_pattern}分之{self.cn_pattern}",
+                             lambda x: self.__sub_util(x.group(), "cn2an", "fraction"), inputs)
+             # percent
+             inputs = re.sub(fr"百分之{self.cn_pattern}",
+                             lambda x: self.__sub_util(x.group(), "cn2an", "percent"), inputs)
+             # celsius
+             inputs = re.sub(fr"{self.cn_pattern}摄氏度",
+                             lambda x: self.__sub_util(x.group(), "cn2an", "celsius"), inputs)
+             # number
+             output = re.sub(self.cn_pattern,
+                             lambda x: self.__sub_util(x.group(), "cn2an", "number"), inputs)
+
+         elif method == "an2cn":
+             # date
+             inputs = re.sub(r"(\d{2,4}年)?(\d{1,2}月)?(\d{1,2}日)?",
+                             lambda x: self.__sub_util(x.group(), "an2cn", "date"), inputs)
+             # fraction
+             inputs = re.sub(r"\d+/\d+",
+                             lambda x: self.__sub_util(x.group(), "an2cn", "fraction"), inputs)
+             # percent
+             inputs = re.sub(r"-?(\d+\.)?\d+%",
+                             lambda x: self.__sub_util(x.group(), "an2cn", "percent"), inputs)
+             # celsius
+             inputs = re.sub(r"\d+℃",
+                             lambda x: self.__sub_util(x.group(), "an2cn", "celsius"), inputs)
+             # number
+             output = re.sub(r"-?(\d+\.)?\d+",
+                             lambda x: self.__sub_util(x.group(), "an2cn", "number"), inputs)
+         else:
+             raise ValueError(f"error method: {method}, only support 'cn2an' and 'an2cn'!")
+
+         return output
+
+     def __sub_util(self, inputs, method: str = "cn2an", sub_mode: str = "number") -> str:
+         try:
+             if inputs:
+                 if method == "cn2an":
+                     if sub_mode == "date":
+                         return re.sub(fr"(({self.smart_cn_pattern})|({self.cn_pattern}))",
+                                       lambda x: str(self.cn2an(x.group(), "smart")), inputs)
+                     elif sub_mode == "fraction":
+                         if inputs[0] != "百":
+                             frac_result = re.sub(self.cn_pattern,
+                                                  lambda x: str(self.cn2an(x.group(), "smart")), inputs)
+                             numerator, denominator = frac_result.split("分之")
+                             return f"{denominator}/{numerator}"
+                         else:
+                             return inputs
+                     elif sub_mode == "percent":
+                         return re.sub(f"(?<=百分之){self.cn_pattern}",
+                                       lambda x: str(self.cn2an(x.group(), "smart")), inputs).replace("百分之", "") + "%"
+                     elif sub_mode == "celsius":
+                         return re.sub(f"{self.cn_pattern}(?=摄氏度)",
+                                       lambda x: str(self.cn2an(x.group(), "smart")), inputs).replace("摄氏度", "℃")
+                     elif sub_mode == "number":
+                         return str(self.cn2an(inputs, "smart"))
+                     else:
+                         raise Exception(f"error sub_mode: {sub_mode} !")
+                 else:
+                     if sub_mode == "date":
+                         inputs = re.sub(r"\d+(?=年)",
+                                         lambda x: self.an2cn(x.group(), "direct"), inputs)
+                         return re.sub(r"\d+",
+                                       lambda x: self.an2cn(x.group(), "low"), inputs)
+                     elif sub_mode == "fraction":
+                         frac_result = re.sub(r"\d+", lambda x: self.an2cn(x.group(), "low"), inputs)
+                         numerator, denominator = frac_result.split("/")
+                         return f"{denominator}分之{numerator}"
+                     elif sub_mode == "celsius":
+                         return self.an2cn(inputs[:-1], "low") + "摄氏度"
+                     elif sub_mode == "percent":
+                         return "百分之" + self.an2cn(inputs[:-1], "low")
+                     elif sub_mode == "number":
+                         return self.an2cn(inputs, "low")
+                     else:
+                         raise Exception(f"error sub_mode: {sub_mode} !")
+         except Exception as e:
+             warn(str(e))
+             return inputs
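For reference, the upstream cn2an README documents the intended round trip of this wrapper (shown as-is, not re-verified against this vendored copy):

    from text.cn2an import transform

    transform('小王捡了100块钱', 'an2cn')    # -> 小王捡了一百块钱
    transform('小王捡了一百块钱', 'cn2an')   # -> 小王捡了100块钱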
text/cnm3/ds_CNM3.txt ADDED
@@ -0,0 +1,606 @@
+ a,a
+ ai,ai
+ ai0,a0 I0
+ an,an
+ an0,a0 N0
+ ang,ang
+ ang0,A0 ng0
+ ao,ao
+ ao0,A0 O0
+ ba,b a
+ bai,b a0 I0
+ ban,b a0 N0
+ bang,b A0 ng0
+ bao,b A0 O0
+ be,b e
+ bei,b E0 I0
+ ben,b e0 N0
+ beng,b e0 ng0
+ ber,b er
+ bi,b i
+ bia,b ia
+ bian,b iE0 N0
+ biang,b iA0 ng0
+ biao,b iA0 O0
+ bie,b ie
+ bin,b i N0
+ bing,b i ng0
+ biong,b iO0 ng0
+ biu,b io0 U0
+ bo,b o
+ bong,b oo0 ng0
+ bou,b o0 U0
+ bu,b u
+ bua,b ua
+ buai,b ua0 I0
+ buan,b ua0 N0
+ buang,b uA0 ng0
+ bui,b uE0 I0
+ bun,b ue0 N0
+ bv,b v
+ bve,b ve
+ ca,c a
+ cai,c a0 I0
+ can,c a0 N0
+ cang,c A0 ng0
+ cao,c A0 O0
+ ce,c e
+ cei,c E0 I0
+ cen,c e0 N0
+ ceng,c e0 ng0
+ cer,c er
+ cha,ch a
+ chai,ch a0 I0
+ chan,ch a0 N0
+ chang,ch A0 ng0
+ chao,ch A0 O0
+ che,ch e
+ chei,ch E0 I0
+ chen,ch e0 N0
+ cheng,ch e0 ng0
+ cher,ch er
+ chi,ch ir
+ chong,ch oo0 ng0
+ chou,ch o0 U0
+ chu,ch u
+ chua,ch ua
+ chuai,ch ua0 I0
+ chuan,ch ua0 N0
+ chuang,ch uA0 ng0
+ chui,ch uE0 I0
+ chun,ch ue0 N0
+ chuo,ch uo
+ chv,ch v
+ chyi,ch i
+ ci,c i0
+ cong,c oo0 ng0
+ cou,c o0 U0
+ cu,c u
+ cua,c ua
+ cuai,c ua0 I0
+ cuan,c ua0 N0
+ cuang,c uA0 ng0
+ cui,c uE0 I0
+ cun,c ue0 N0
+ cuo,c uo
+ cv,c v
+ cyi,c i
+ da,d a
+ dai,d a0 I0
+ dan,d a0 N0
+ dang,d A0 ng0
+ dao,d A0 O0
+ de,d e
+ dei,d E0 I0
+ den,d e0 N0
+ deng,d e0 ng0
+ der,d er
+ di,d i
+ dia,d ia
+ dian,d iE0 N0
+ diang,d iA0 ng0
+ diao,d iA0 O0
+ die,d ie
+ din,d i N0
+ ding,d i ng0
+ diong,d iO0 ng0
+ diu,d io0 U0
+ dong,d oo0 ng0
+ dou,d o0 U0
+ du,d u
+ dua,d ua
+ duai,d ua0 I0
+ duan,d ua0 N0
+ duang,d uA0 ng0
+ dui,d uE0 I0
+ dun,d ue0 N0
+ duo,d uo
+ dv,d v
+ dve,d ve
+ e,e
+ ei,E0 I0
+ en,e0 N0
+ eng,e0 ng0
+ er,er
+ fa,f a
+ fai,f a0 I0
+ fan,f a0 N0
+ fang,f A0 ng0
+ fao,f A0 O0
+ fe,f e
+ fei,f E0 I0
+ fen,f e0 N0
+ feng,f e0 ng0
+ fer,f er
+ fi,f i
+ fia,f ia
+ fian,f iE0 N0
+ fiang,f iA0 ng0
+ fiao,f iA0 O0
+ fie,f ie
+ fin,f i N0
+ fing,f i ng0
+ fiong,f iO0 ng0
+ fiu,f io0 U0
+ fo,f o
+ fong,f oo0 ng0
+ fou,f o0 U0
+ fu,f u
+ fua,f ua
+ fuai,f ua0 I0
+ fuan,f ua0 N0
+ fuang,f uA0 ng0
+ fui,f uE0 I0
+ fun,f ue0 N0
+ fv,f v
+ fve,f ve
+ ga,g a
+ gai,g a0 I0
+ gan,g a0 N0
+ gang,g A0 ng0
+ gao,g A0 O0
+ ge,g e
+ gei,g E0 I0
+ gen,g e0 N0
+ geng,g e0 ng0
+ ger,g er
+ gi,g i
+ gia,g ia
+ gian,g iE0 N0
+ giang,g iA0 ng0
+ giao,g iA0 O0
+ gie,g ie
+ gin,g i N0
+ ging,g i ng0
+ giong,g iO0 ng0
+ giu,g io0 U0
+ gong,g oo0 ng0
+ gou,g o0 U0
+ gu,g u
+ gua,g ua
+ guai,g ua0 I0
+ guan,g ua0 N0
+ guang,g uA0 ng0
+ gui,g uE0 I0
+ gun,g ue0 N0
+ guo,g uo
+ gv,g v
+ gve,g ve
+ ha,h a
+ hai,h a0 I0
+ han,h a0 N0
+ hang,h A0 ng0
+ hao,h A0 O0
+ he,h e
+ hei,h E0 I0
+ hen,h e0 N0
+ heng,h e0 ng0
+ her,h er
+ hi,h i
+ hia,h ia
+ hian,h iE0 N0
+ hiang,h iA0 ng0
+ hiao,h iA0 O0
+ hie,h ie
+ hin,h i N0
+ hing,h i ng0
+ hiong,h iO0 ng0
+ hiu,h io0 U0
+ hong,h oo0 ng0
+ hou,h o0 U0
+ hu,h u
+ hua,h ua
+ huai,h ua0 I0
+ huan,h ua0 N0
+ huang,h uA0 ng0
+ hui,h uE0 I0
+ hun,h ue0 N0
+ huo,h uo
+ hv,h v
+ hve,h ve
+ ji,j i
+ jia,j ia
+ jian,j iE0 N0
+ jiang,j iA0 ng0
+ jiao,j iA0 O0
+ jie,j ie
+ jin,j i N0
+ jing,j i ng0
+ jiong,j iO0 ng0
+ jiu,j io0 U0
+ ju,j v
+ juan,j vE0 N0
+ jue,j ve
+ jun,j v0 N0
+ ka,k a
+ kai,k a0 I0
+ kan,k a0 N0
+ kang,k A0 ng0
+ kao,k A0 O0
+ ke,k e
+ kei,k E0 I0
+ ken,k e0 N0
+ keng,k e0 ng0
+ ker,k er
+ ki,k i
+ kia,k ia
+ kian,k iE0 N0
+ kiang,k iA0 ng0
+ kiao,k iA0 O0
+ kie,k ie
+ kin,k i N0
+ king,k i ng0
+ kiong,k iO0 ng0
+ kiu,k io0 U0
+ kong,k oo0 ng0
+ kou,k o0 U0
+ ku,k u
+ kua,k ua
+ kuai,k ua0 I0
+ kuan,k ua0 N0
+ kuang,k uA0 ng0
+ kui,k uE0 I0
+ kun,k ue0 N0
+ kuo,k uo
+ kv,k v
+ kve,k ve
+ la,l a
+ lai,l a0 I0
+ lan,l a0 N0
+ lang,l A0 ng0
+ lao,l A0 O0
+ le,l e
+ lei,l E0 I0
+ len,l e0 N0
+ leng,l e0 ng0
+ ler,l er
+ li,l i
+ lia,l ia
+ lian,l iE0 N0
+ liang,l iA0 ng0
+ liao,l iA0 O0
+ lie,l ie
+ lin,l i N0
+ ling,l i ng0
+ liong,l iO0 ng0
+ liu,l io0 U0
+ lo,l o
+ long,l oo0 ng0
+ lou,l o0 U0
+ lu,l u
+ lua,l ua
+ luai,l ua0 I0
+ luan,l ua0 N0
+ luang,l uA0 ng0
+ lui,l uE0 I0
+ lun,l ue0 N0
+ luo,l uo
+ lv,l v
+ lve,l ve
+ ma,m a
+ mai,m a0 I0
+ man,m a0 N0
+ mang,m A0 ng0
+ mao,m A0 O0
+ me,m e
+ mei,m E0 I0
+ men,m e0 N0
+ meng,m e0 ng0
+ mer,m er
+ mi,m i
+ mia,m ia
+ mian,m iE0 N0
+ miang,m iA0 ng0
+ miao,m iA0 O0
+ mie,m ie
+ min,m i N0
+ ming,m i ng0
+ miong,m iO0 ng0
+ miu,m io0 U0
+ mo,m o
+ mong,m oo0 ng0
+ mou,m o0 U0
+ mu,m u
+ mua,m ua
+ muai,m ua0 I0
+ muan,m ua0 N0
+ muang,m uA0 ng0
+ mui,m uE0 I0
+ mun,m ue0 N0
+ mv,m v
+ mve,m ve
+ n,ng
+ na,n a
+ nai,n a0 I0
+ nan,n a0 N0
+ nang,n A0 ng0
+ nao,n A0 O0
+ ne,n e
+ nei,n E0 I0
+ nen,n e0 N0
+ neng,n e0 ng0
+ ner,n er
+ ni,n i
+ nia,n ia
+ nian,n iE0 N0
+ niang,n iA0 ng0
+ niao,n iA0 O0
+ nie,n ie
+ nin,n i N0
+ ning,n i ng0
+ niong,n iO0 ng0
+ niu,n io0 U0
+ nong,n oo0 ng0
+ nou,n o0 U0
+ nu,n u
+ nua,n ua
+ nuai,n ua0 I0
+ nuan,n ua0 N0
+ nuang,n uA0 ng0
+ nui,n uE0 I0
+ nun,n ue0 N0
+ nuo,n uo
+ nv,n v
+ nve,n ve
+ o,o
+ ong,ong
+ ou,ou
+ pa,p a
+ pai,p a0 I0
+ pan,p a0 N0
+ pang,p A0 ng0
+ pao,p A0 O0
+ pe,p e
+ pei,p E0 I0
+ pen,p e0 N0
+ peng,p e0 ng0
+ per,p er
+ pi,p i
+ pia,p ia
+ pian,p iE0 N0
+ piang,p iA0 ng0
+ piao,p iA0 O0
+ pie,p ie
+ pin,p i N0
+ ping,p i ng0
+ piong,p iO0 ng0
+ piu,p io0 U0
+ po,p o
+ pong,p oo0 ng0
+ pou,p o0 U0
+ pu,p u
+ pua,p ua
+ puai,p ua0 I0
+ puan,p ua0 N0
+ puang,p uA0 ng0
+ pui,p uE0 I0
+ pun,p ue0 N0
+ pv,p v
+ pve,p ve
+ qi,q i
+ qia,q ia
+ qian,q iE0 N0
+ qiang,q iA0 ng0
+ qiao,q iA0 O0
+ qie,q ie
+ qin,q i N0
+ qing,q i ng0
+ qiong,q iO0 ng0
+ qiu,q io0 U0
+ qu,q v
+ quan,q vE0 N0
+ que,q ve
+ qun,q v0 N0
+ ra,r a
+ rai,r a0 I0
+ ran,r a0 N0
+ rang,r A0 ng0
+ rao,r A0 O0
+ re,r e
+ rei,r E0 I0
+ ren,r e0 N0
+ reng,r e0 ng0
+ rer,r er
+ ri,r ir
+ rong,r oo0 ng0
+ rou,r o0 U0
+ ru,r u
+ rua,r ua
+ ruai,r ua0 I0
+ ruan,r ua0 N0
+ ruang,r uA0 ng0
+ rui,r uE0 I0
+ run,r ue0 N0
+ ruo,r uo
+ rv,r v
+ ryi,r i
+ sa,s a
+ sai,s a0 I0
+ san,s a0 N0
+ sang,s A0 ng0
+ sao,s A0 O0
+ se,s e
+ sei,s E0 I0
+ sen,s e0 N0
+ seng,s e0 ng0
+ ser,s er
+ sha,sh a
+ shai,sh a0 I0
+ shan,sh a0 N0
+ shang,sh A0 ng0
+ shao,sh A0 O0
+ she,sh e
+ shei,sh E0 I0
+ shen,sh e0 N0
+ sheng,sh e0 ng0
+ sher,sh er
+ shi,sh ir
+ shong,sh oo0 ng0
+ shou,sh o0 U0
+ shu,sh u
+ shua,sh ua
+ shuai,sh ua0 I0
+ shuan,sh ua0 N0
+ shuang,sh uA0 ng0
+ shui,sh uE0 I0
+ shun,sh ue0 N0
+ shuo,sh uo
+ shv,sh v
+ shyi,sh i
+ si,s i0
+ song,s oo0 ng0
+ sou,s o0 U0
+ su,s u
+ sua,s ua
+ suai,s ua0 I0
+ suan,s ua0 N0
+ suang,s uA0 ng0
+ sui,s uE0 I0
+ sun,s ue0 N0
+ suo,s uo
+ sv,s v
+ syi,s i
+ ta,t a
+ tai,t a0 I0
+ tan,t a0 N0
+ tang,t A0 ng0
+ tao,t A0 O0
+ te,t e
+ tei,t E0 I0
+ ten,t e0 N0
+ teng,t e0 ng0
+ ter,t er
+ ti,t i
+ tia,t ia
+ tian,t iE0 N0
+ tiang,t iA0 ng0
+ tiao,t iA0 O0
+ tie,t ie
+ tin,t i N0
+ ting,t i ng0
+ tiong,t iO0 ng0
+ tong,t oo0 ng0
+ tou,t o0 U0
+ tu,t u
+ tua,t ua
+ tuai,t ua0 I0
+ tuan,t ua0 N0
+ tuang,t uA0 ng0
+ tui,t uE0 I0
+ tun,t ue0 N0
+ tuo,t uo
+ tv,t v
+ tve,t ve
+ wa,w a
+ wai,w a0 I0
+ wan,w a0 N0
+ wang,w A0 ng0
+ wao,w A0 O0
+ we,w e
+ wei,w E0 I0
+ wen,w e0 N0
+ weng,w e0 ng0
+ wer,w er
+ wi,w i
+ wo,w o
+ wong,w oo0 ng0
+ wou,w o0 U0
+ wu,w u
+ xi,x i
+ xia,x ia
+ xian,x iE0 N0
+ xiang,x iA0 ng0
+ xiao,x iA0 O0
+ xie,x ie
+ xin,x i N0
+ xing,x i ng0
+ xiong,x iO0 ng0
+ xiu,x io0 U0
+ xu,x v
+ xuan,x vE0 N0
+ xue,x ve
+ xun,x v0 N0
+ ya,y a
+ yai,y a0 I0
+ yan,y iE0 N0
+ yang,y A0 ng0
+ yao,y A0 O0
+ ye,y E
+ yei,y E0 I0
+ yi,y i
+ yin,y i N0
+ ying,y i ng0
+ yo,y o
+ yong,y oo0 ng0
+ you,y o0 U0
+ yu,y v
+ yuan,y vE0 N0
+ yue,y ve
+ yun,y v0 N0
+ ywu,y u
+ za,z a
+ zai,z a0 I0
+ zan,z a0 N0
+ zang,z A0 ng0
+ zao,z A0 O0
+ ze,z e
+ zei,z E0 I0
+ zen,z e0 N0
+ zeng,z e0 ng0
+ zer,z er
+ zha,zh a
+ zhai,zh a0 I0
+ zhan,zh a0 N0
+ zhang,zh A0 ng0
+ zhao,zh A0 O0
+ zhe,zh e
+ zhei,zh E0 I0
+ zhen,zh e0 N0
+ zheng,zh e0 ng0
+ zher,zh er
+ zhi,zh ir
+ zhong,zh oo0 ng0
+ zhou,zh o0 U0
+ zhu,zh u
+ zhua,zh ua
+ zhuai,zh ua0 I0
+ zhuan,zh ua0 N0
+ zhuang,zh uA0 ng0
+ zhui,zh uE0 I0
+ zhun,zh ue0 N0
+ zhuo,zh uo
+ zhv,zh v
+ zhyi,zh i
+ zi,z i0
+ zong,z oo0 ng0
+ zou,z o0 U0
+ zu,z u
+ zua,z ua
+ zuai,z ua0 I0
+ zuan,z ua0 N0
+ zuang,z uA0 ng0
+ zui,z uE0 I0
+ zun,z ue0 N0
+ zuo,z uo
+ zv,z v
+ zyi,z i
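Each row maps one pinyin syllable to its CNM3 phone sequence; load_pinyin_dict in text/mandarin.py (later in this commit) parses a row like this:

    line = 'zhuang,zh uA0 ng0'
    key, value = line.strip().split(',', 1)
    assert key == 'zhuang' and value.split() == ['zh', 'uA0', 'ng0']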
text/custom_pypinyin_dict/__init__.py ADDED
@@ -0,0 +1 @@
+ # -*- coding: utf-8 -*-
text/custom_pypinyin_dict/cc_cedict_0.py ADDED
The diff for this file is too large to render. See raw diff
text/custom_pypinyin_dict/cc_cedict_1.py ADDED
The diff for this file is too large to render. See raw diff
text/custom_pypinyin_dict/cc_cedict_2.py ADDED
The diff for this file is too large to render. See raw diff
text/custom_pypinyin_dict/cc_cedict_3.py ADDED
@@ -0,0 +1,14 @@
+ # -*- coding: utf-8 -*-
+ from __future__ import unicode_literals
+
+ # Warning: Auto-generated file, don't edit.
+ phrases_dict = {
+     '𰻝𰻝面': [['biáng'], ['biáng'], ['miàn']],
+ }
+
+
+ from pypinyin import load_phrases_dict
+
+
+ def load():
+     load_phrases_dict(phrases_dict)
text/custom_pypinyin_dict/genshin.py ADDED
@@ -0,0 +1,11 @@
+ # -*- coding: utf-8 -*-
+ from __future__ import unicode_literals
+
+ phrases_dict = {
+     '㐖毒': [['xié'], ['dú']],
+     '若陀': [['rě'], ['tuó']],
+     '平藏': [['píng'], ['zàng']],
+     '派蒙': [['pài'], ['méng']],
+     '安柏': [['ān'], ['bó']],
+     '一斗': [['yī'], ['dǒu']]
+ }
text/custom_pypinyin_dict/phrase_pinyin_data.py ADDED
@@ -0,0 +1,24 @@
+ # -*- coding: utf-8 -*-
+ from __future__ import unicode_literals
+
+ from pypinyin import load_phrases_dict
+
+ from text.custom_pypinyin_dict import cc_cedict_0
+ from text.custom_pypinyin_dict import cc_cedict_1
+ from text.custom_pypinyin_dict import cc_cedict_2
+ from text.custom_pypinyin_dict import cc_cedict_3
+ from text.custom_pypinyin_dict import genshin
+
+ phrases_dict = {}
+ phrases_dict.update(cc_cedict_0.phrases_dict)
+ phrases_dict.update(cc_cedict_1.phrases_dict)
+ phrases_dict.update(cc_cedict_2.phrases_dict)
+ phrases_dict.update(cc_cedict_3.phrases_dict)
+ phrases_dict.update(genshin.phrases_dict)
+
+ def load():
+     load_phrases_dict(phrases_dict)
+     print("custom pinyin phrase dictionary loaded")
+
+ if __name__ == '__main__':
+     print(phrases_dict)
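A sketch of the intended effect once load() has run, assuming pypinyin segments 若陀 (from the genshin table above) as one phrase:

    from pypinyin import lazy_pinyin
    from text.custom_pypinyin_dict import phrase_pinyin_data

    phrase_pinyin_data.load()
    print(lazy_pinyin('若陀'))   # expected ['re', 'tuo'], the custom reading rather than the default 'ruo'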
text/english.py ADDED
@@ -0,0 +1,175 @@
+ """ from https://github.com/keithito/tacotron """
+
+ '''
+ Cleaners are transformations that run over the input text at both training and eval time.
+
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
+     1. "english_cleaners" for English text
+     2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
+        the Unidecode library (https://pypi.python.org/pypi/Unidecode)
+     3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
+        the symbols in symbols.py to match your data).
+ '''
+
+
+ # Regular expression matching whitespace:
+
+
+ import re
+ import inflect
+ from unidecode import unidecode
+ import eng_to_ipa as ipa
+ _inflect = inflect.engine()
+ _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
+ _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
+ _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
+ _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
+ _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
+ _number_re = re.compile(r'[0-9]+')
+
+ # List of (regular expression, replacement) pairs for abbreviations:
+ _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
+     ('mrs', 'misess'),
+     ('mr', 'mister'),
+     ('dr', 'doctor'),
+     ('st', 'saint'),
+     ('co', 'company'),
+     ('jr', 'junior'),
+     ('maj', 'major'),
+     ('gen', 'general'),
+     ('drs', 'doctors'),
+     ('rev', 'reverend'),
+     ('lt', 'lieutenant'),
+     ('hon', 'honorable'),
+     ('sgt', 'sergeant'),
+     ('capt', 'captain'),
+     ('esq', 'esquire'),
+     ('ltd', 'limited'),
+     ('col', 'colonel'),
+     ('ft', 'fort'),
+ ]]
+
+
+ # List of (ipa, lazy ipa) pairs:
+ _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
+     ('r', 'ɹ'),
+     ('æ', 'e'),
+     ('ɑ', 'a'),
+     ('ɔ', 'o'),
+     ('ð', 'z'),
+     ('θ', 's'),
+     ('ɛ', 'e'),
+     ('ɪ', 'i'),
+     ('ʊ', 'u'),
+     ('ʒ', 'ʥ'),
+     ('ʤ', 'ʥ'),
+     ('ˈ', '↓'),
+ ]]
+
+ # List of (ipa, lazy ipa2) pairs:
+ _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
+     ('r', 'ɹ'),
+     ('ð', 'z'),
+     ('θ', 's'),
+     ('ʒ', 'ʑ'),
+     ('ʤ', 'dʑ'),
+     ('ˈ', '↓'),
+ ]]
+
+ # List of (ipa, ipa2) pairs
+ _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
+     ('r', 'ɹ'),
+     ('ʤ', 'dʒ'),
+     ('ʧ', 'tʃ')
+ ]]
+
+
+ def expand_abbreviations(text):
+     for regex, replacement in _abbreviations:
+         text = re.sub(regex, replacement, text)
+     return text
+
+
+ def collapse_whitespace(text):
+     return re.sub(r'\s+', ' ', text)
+
+
+ def _remove_commas(m):
+     return m.group(1).replace(',', '')
+
+
+ def _expand_decimal_point(m):
+     return m.group(1).replace('.', ' point ')
+
+
+ def _expand_dollars(m):
+     match = m.group(1)
+     parts = match.split('.')
+     if len(parts) > 2:
+         return match + ' dollars'  # Unexpected format
+     dollars = int(parts[0]) if parts[0] else 0
+     cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
+     if dollars and cents:
+         dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+         cent_unit = 'cent' if cents == 1 else 'cents'
+         return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
+     elif dollars:
+         dollar_unit = 'dollar' if dollars == 1 else 'dollars'
+         return '%s %s' % (dollars, dollar_unit)
+     elif cents:
+         cent_unit = 'cent' if cents == 1 else 'cents'
+         return '%s %s' % (cents, cent_unit)
+     else:
+         return 'zero dollars'
+
+
+ def _expand_ordinal(m):
+     return _inflect.number_to_words(m.group(0))
+
+
+ def _expand_number(m):
+     num = int(m.group(0))
+     if num > 1000 and num < 3000:
+         if num == 2000:
+             return 'two thousand'
+         elif num > 2000 and num < 2010:
+             return 'two thousand ' + _inflect.number_to_words(num % 100)
+         elif num % 100 == 0:
+             return _inflect.number_to_words(num // 100) + ' hundred'
+         else:
+             return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
+     else:
+         return _inflect.number_to_words(num, andword='')
+
+
+ def normalize_numbers(text):
+     text = re.sub(_comma_number_re, _remove_commas, text)
+     text = re.sub(_pounds_re, r'\1 pounds', text)
+     text = re.sub(_dollars_re, _expand_dollars, text)
+     text = re.sub(_decimal_number_re, _expand_decimal_point, text)
+     text = re.sub(_ordinal_re, _expand_ordinal, text)
+     text = re.sub(_number_re, _expand_number, text)
+     return text
+
+
+ def mark_dark_l(text):
+     return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
+
+
+ def english_to_ipa(text):
+     text = unidecode(text).lower()
+     text = expand_abbreviations(text)
+     text = normalize_numbers(text)
+     phonemes = ipa.convert(text)
+     phonemes = collapse_whitespace(phonemes)
+     return phonemes
+
+
+ def english_to_ipa2(text):
+     text = english_to_ipa(text)
+     text = mark_dark_l(text)
+     for regex, replacement in _ipa_to_ipa2:
+         text = re.sub(regex, replacement, text)
+     return list(text.replace('...', '…'))
+
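A minimal sketch of the English pipeline above; the exact IPA depends on the eng_to_ipa lexicon, so only the output shape is asserted:

    from text.english import english_to_ipa2

    phones = english_to_ipa2('Mr. Brown paid $5.')   # 'Mr.' and '$5' are expanded before G2P
    assert isinstance(phones, list)                  # a list of single IPA characters
    print(''.join(phones))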
text/japanese.py ADDED
@@ -0,0 +1,157 @@
+ import re
+ from unidecode import unidecode
+ import pyopenjtalk
+
+
+ # Regular expression matching Japanese without punctuation marks:
+ _japanese_characters = re.compile(
+     r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
+
+ # Regular expression matching non-Japanese characters or punctuation marks:
+ _japanese_marks = re.compile(
+     r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
+
+ # List of (symbol, Japanese) pairs for marks:
+ _symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [
+     ('%', 'パーセント')
+ ]]
+
+ # List of (romaji, ipa) pairs for marks:
+ _romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
+     ('ts', 'ʦ'),
+     ('u', 'ɯ'),
+     ('j', 'ʥ'),
+     ('y', 'j'),
+     ('ni', 'n^i'),
+     ('nj', 'n^'),
+     ('hi', 'çi'),
+     ('hj', 'ç'),
+     ('f', 'ɸ'),
+     ('I', 'i*'),
+     ('U', 'ɯ*'),
+     ('r', 'ɾ')
+ ]]
+
+ # List of (romaji, ipa2) pairs for marks:
+ _romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
+     ('u', 'ɯ'),
+     ('ʧ', 'tʃ'),
+     ('j', 'dʑ'),
+     ('y', 'j'),
+     ('ni', 'n^i'),
+     ('nj', 'n^'),
+     ('hi', 'çi'),
+     ('hj', 'ç'),
+     ('f', 'ɸ'),
+     ('I', 'i*'),
+     ('U', 'ɯ*'),
+     ('r', 'ɾ')
+ ]]
+
+ # List of (consonant, sokuon) pairs:
+ _real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [
+     (r'Q([↑↓]*[kg])', r'k#\1'),
+     (r'Q([↑↓]*[tdjʧ])', r't#\1'),
+     (r'Q([↑↓]*[sʃ])', r's\1'),
+     (r'Q([↑↓]*[pb])', r'p#\1')
+ ]]
+
+ # List of (consonant, hatsuon) pairs:
+ _real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [
+     (r'N([↑↓]*[pbm])', r'm\1'),
+     (r'N([↑↓]*[ʧʥj])', r'n^\1'),
+     (r'N([↑↓]*[tdn])', r'n\1'),
+     (r'N([↑↓]*[kg])', r'ŋ\1')
+ ]]
+
+
+ def symbols_to_japanese(text):
+     for regex, replacement in _symbols_to_japanese:
+         text = re.sub(regex, replacement, text)
+     return text
+
+
+ def japanese_to_romaji_with_accent(text):
+     '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
+     text = symbols_to_japanese(text)
+     sentences = re.split(_japanese_marks, text)
+     marks = re.findall(_japanese_marks, text)
+     text = ''
+     for i, sentence in enumerate(sentences):
+         if re.match(_japanese_characters, sentence):
+             if text != '':
+                 text += ' '
+             labels = pyopenjtalk.extract_fullcontext(sentence)
+             for n, label in enumerate(labels):
+                 phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
+                 if phoneme not in ['sil', 'pau']:
+                     text += phoneme.replace('ch', 'ʧ').replace('sh', 'ʃ').replace('cl', 'Q')
+                 else:
+                     continue
+                 # n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
+                 a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
+                 a2 = int(re.search(r"\+(\d+)\+", label).group(1))
+                 a3 = int(re.search(r"\+(\d+)/", label).group(1))
+                 if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']:
+                     a2_next = -1
+                 else:
+                     a2_next = int(re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
+                 # Accent phrase boundary
+                 if a3 == 1 and a2_next == 1:
+                     text += ' '
+                 # Falling
+                 elif a1 == 0 and a2_next == a2 + 1:
+                     text += '↓'
+                 # Rising
+                 elif a2 == 1 and a2_next == 2:
+                     text += '↑'
+         if i < len(marks):
+             text += unidecode(marks[i]).replace(' ', '')
+     return text
+
+
+ def get_real_sokuon(text):
+     for regex, replacement in _real_sokuon:
+         text = re.sub(regex, replacement, text)
+     return text
+
+
+ def get_real_hatsuon(text):
+     for regex, replacement in _real_hatsuon:
+         text = re.sub(regex, replacement, text)
+     return text
+
+
+ def japanese_to_ipa(text):
+     text = japanese_to_romaji_with_accent(text).replace('...', '…')
+     text = re.sub(r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
+     text = get_real_sokuon(text)
+     text = get_real_hatsuon(text)
+     for regex, replacement in _romaji_to_ipa:
+         text = re.sub(regex, replacement, text)
+     return text
+
+
+ def japanese_to_ipa2(text):
+     text = japanese_to_romaji_with_accent(text).replace('...', '…')
+     text = get_real_sokuon(text)
+     text = get_real_hatsuon(text)
+     for regex, replacement in _romaji_to_ipa2:
+         text = re.sub(regex, replacement, text)
+     return list(text)
+
+
+ def japanese_to_ipa3(text):
+     text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace('ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a')
+     text = re.sub(r'([aiɯeo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
+     text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text)
+     return text
+
+ if __name__ == '__main__':
+     a = japanese_to_romaji_with_accent('こんにちは!はい、元気です。あなたは?')
+     print(a)
text/mandarin.py ADDED
@@ -0,0 +1,173 @@
+ import re
+ from typing import Dict, List
+ from pypinyin import lazy_pinyin, Style
+ from .custom_pypinyin_dict import phrase_pinyin_data
+ import jieba
+ from .cn2an import an2cn
+
+ # load the custom pinyin phrase dictionary
+ phrase_pinyin_data.load()
+
+ # punctuation normalization map
+ PUNC_MAP: Dict[str, str] = {
+     ":": ",",
+     ";": ",",
+     ",": ",",
+     "。": ".",
+     "!": "!",
+     "?": "?",
+     "\n": ".",
+     "·": ",",
+     "、": ",",
+     "$": ".",
+     "/": ",",
+     "“": "'",
+     "”": "'",
+     '"': "'",
+     "‘": "'",
+     "’": "'",
+     "(": "'",
+     ")": "'",
+     "(": "'",
+     ")": "'",
+     "《": "'",
+     "》": "'",
+     "【": "'",
+     "】": "'",
+     "[": "'",
+     "]": "'",
+     "—": "-",
+     "~": "~",
+     "「": "'",
+     "」": "'",
+     "『": "'",
+     "』": "'",
+ }
+
+ # from GPT_SoVITS.text.zh_normalization.text_normlization
+ PUNC_MAP.update({
+     '/': '每',
+     '①': '一',
+     '②': '二',
+     '③': '三',
+     '④': '四',
+     '⑤': '五',
+     '⑥': '六',
+     '⑦': '七',
+     '⑧': '八',
+     '⑨': '九',
+     '⑩': '十',
+     'α': '阿尔法',
+     'β': '贝塔',
+     'γ': '伽玛',
+     'Γ': '伽玛',
+     'δ': '德尔塔',
+     'Δ': '德尔塔',
+     'ε': '艾普西龙',
+     'ζ': '捷塔',
+     'η': '依塔',
+     'θ': '西塔',
+     'Θ': '西塔',
+     'ι': '艾欧塔',
+     'κ': '喀帕',
+     'λ': '拉姆达',
+     'Λ': '拉姆达',
+     'μ': '缪',
+     'ν': '拗',
+     'ξ': '克西',
+     'Ξ': '克西',
+     'ο': '欧米克伦',
+     'π': '派',
+     'Π': '派',
+     'ρ': '肉',
+     'ς': '西格玛',
+     'σ': '西格玛',
+     'Σ': '西格玛',
+     'τ': '套',
+     'υ': '宇普西龙',
+     'φ': '服艾',
+     'Φ': '服艾',
+     'χ': '器',
+     'ψ': '普赛',
+     'Ψ': '普赛',
+     'ω': '欧米伽',
+     'Ω': '欧米伽',
+     '+': '加',
+     '-': '减',
+     '×': '乘',
+     '÷': '除',
+     '=': '等',
+
+     "嗯": "恩",
+     "呣": "母"
+ })
+
+ PUNC_TABLE = str.maketrans(PUNC_MAP)
+
+ # number normalization
+ NUMBER_PATTERN: re.Pattern = re.compile(r'\d+(?:\.?\d+)?')
+
+ # convert Arabic numerals to Chinese numerals
+ def replace_number(match: re.Match) -> str:
+     return an2cn(match.group())
+
+ def normalize_number(text: str) -> str:
+     return NUMBER_PATTERN.sub(replace_number, text)
+
+ # get symbols of phones, not used
+ def load_pinyin_symbols(path):
+     pinyin_dict = {}
+     temp = []
+     with open(path, "r", encoding='utf-8') as f:
+         content = f.readlines()
+     for line in content:
+         cuts = line.strip().split(',')
+         pinyin = cuts[0]
+         phones = cuts[1].split(' ')
+         pinyin_dict[pinyin] = phones
+         temp.extend(phones)
+     temp = list(set(temp))
+     tone = []
+     for phone in temp:
+         for i in range(1, 6):
+             phone2 = phone + str(i)
+             tone.append(phone2)
+     print(sorted(tone, key=lambda x: len(x)))
+     return pinyin_dict
+
+ def load_pinyin_dict(path: str) -> Dict[str, List[str]]:
+     pinyin_dict = {}
+     with open(path, "r", encoding='utf-8') as f:
+         for line in f:
+             key, value = line.strip().split(',', 1)
+             pinyin_dict[key] = value.split()
+     return pinyin_dict
+
+ import os
+ pinyin_dict = load_pinyin_dict(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cnm3', 'ds_CNM3.txt'))
+ # pinyin_dict = load_pinyin_dict('text/cnm3/ds_CNM3.txt')
+
+ def chinese_to_cnm3(text: str) -> List[str]:
+     # normalize punctuation and numbers
+     text = text.translate(PUNC_TABLE)
+     text = normalize_number(text)
+     # filter out special characters
+     text = re.sub(r'[#&@“”^_|\\]', '', text)
+
+     words = jieba.lcut(text, cut_all=False)
+
+     phones = []
+     for word in words:
+         pinyin_list: List[str] = lazy_pinyin(word, style=Style.TONE3, neutral_tone_with_five=True)
+         for pinyin in pinyin_list:
+             if pinyin[-1].isdigit():
+                 tone = pinyin[-1]
+                 syllable = pinyin[:-1]
+                 phone = pinyin_dict[syllable]
+                 phones.extend([ph + tone for ph in phone])
+             elif pinyin[-1].isalpha():
+                 pass
+             else:
+                 phones.extend(pinyin)
+
+     return phones
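Given the ds_CNM3.txt table earlier in this commit, the output of the Chinese frontend is deterministic; a traced example:

    from text.mandarin import chinese_to_cnm3

    # 你好 -> pinyin ['ni3', 'hao3'] -> table rows 'ni,n i' and 'hao,h A0 O0', tone appended to every phone
    assert chinese_to_cnm3('你好') == ['n3', 'i3', 'h3', 'A03', 'O03']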
text/symbols.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ '''
+ Defines the set of symbols used in text input to the model.
+ '''
+
+ # japanese_cleaners
+ # _pad = '_'
+ # _punctuation = ',.!?-'
+ # _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
+
+
+ '''# japanese_cleaners2
+ _pad = '_'
+ _punctuation = ',.!?-~…'
+ _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
+ '''
+
+
+ '''# korean_cleaners
+ _pad = '_'
+ _punctuation = ',.!?…~'
+ _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
+ '''
+
+ '''# chinese_cleaners
+ _pad = '_'
+ _punctuation = ',。!?—…'
+ _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
+ '''
+
+ # # zh_ja_mixture_cleaners
+ # _pad = '_'
+ # _punctuation = ',.!?-~…'
+ # _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
+
+
+ '''# sanskrit_cleaners
+ _pad = '_'
+ _punctuation = '।'
+ _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ '
+ '''
+
+ '''# cjks_cleaners
+ _pad = '_'
+ _punctuation = ',.!?-~…'
+ _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ '
+ '''
+
+ '''# thai_cleaners
+ _pad = '_'
+ _punctuation = '.!? '
+ _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์'
+ '''
+
+ # # cjke_cleaners2
+ _pad = '_'
+ _punctuation = ',.!?-~…' + "'"
+ _IPA_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
+ _CNM3_letters = ['y1', 'y2', 'y3', 'y4', 'y5', 'n1', 'n2', 'n3', 'n4', 'n5', 'p1', 'p2', 'p3', 'p4', 'p5', 'x1', 'x2', 'x3', 'x4', 'x5', 'k1', 'k2', 'k3', 'k4', 'k5', 'l1', 'l2', 'l3', 'l4', 'l5', 'q1', 'q2', 'q3', 'q4', 'q5', 'w1', 'w2', 'w3', 'w4', 'w5', 'E1', 'E2', 'E3', 'E4', 'E5', 'b1', 'b2', 'b3', 'b4', 'b5', 'c1', 'c2', 'c3', 'c4', 'c5', 'z1', 'z2', 'z3', 'z4', 'z5', 'e1', 'e2', 'e3', 'e4', 'e5', 'f1', 'f2', 'f3', 'f4', 'f5', 's1', 's2', 's3', 's4', 's5', 'j1', 'j2', 'j3', 'j4', 'j5', 'o1', 'o2', 'o3', 'o4', 'o5', 'i1', 'i2', 'i3', 'i4', 'i5', 'd1', 'd2', 'd3', 'd4', 'd5', 'm1', 'm2', 'm3', 'm4', 'm5', 't1', 't2', 't3', 't4', 't5', 'h1', 'h2', 'h3', 'h4', 'h5', 'g1', 'g2', 'g3', 'g4', 'g5', 'v1', 'v2', 'v3', 'v4', 'v5', 'r1', 'r2', 'r3', 'r4', 'r5', 'a1', 'a2', 'a3', 'a4', 'a5', 'u1', 'u2', 'u3', 'u4', 'u5', 'I01', 'I02', 'I03', 'I04', 'I05', 'i01', 'i02', 'i03', 'i04', 'i05', 'uo1', 'uo2', 'uo3', 'uo4', 'uo5', 'o01', 'o02', 'o03', 'o04', 'o05', 'U01', 'U02', 'U03', 'U04', 'U05', 'v01', 'v02', 'v03', 'v04', 'v05', 'er1', 'er2', 'er3', 'er4', 'er5', 'A01', 'A02', 'A03', 'A04', 'A05', 'ai1', 'ai2', 'ai3', 'ai4', 'ai5', 'e01', 'e02', 'e03', 'e04', 'e05', 'sh1', 'sh2', 'sh3', 'sh4', 'sh5', 'an1', 'an2', 'an3', 'an4', 'an5', 'ou1', 'ou2', 'ou3', 'ou4', 'ou5', 'ch1', 'ch2', 'ch3', 'ch4', 'ch5', 'a01', 'a02', 'a03', 'a04', 'a05', 'N01', 'N02', 'N03', 'N04', 'N05', 'ao1', 'ao2', 'ao3', 'ao4', 'ao5', 've1', 've2', 've3', 've4', 've5', 'ir1', 'ir2', 'ir3', 'ir4', 'ir5', 'ng1', 'ng2', 'ng3', 'ng4', 'ng5', 'ua1', 'ua2', 'ua3', 'ua4', 'ua5', 'zh1', 'zh2', 'zh3', 'zh4', 'zh5', 'O01', 'O02', 'O03', 'O04', 'O05', 'ie1', 'ie2', 'ie3', 'ie4', 'ie5', 'E01', 'E02', 'E03', 'E04', 'E05', 'ia1', 'ia2', 'ia3', 'ia4', 'ia5', 'iE01', 'iE02', 'iE03', 'iE04', 'iE05', 'ang1', 'ang2', 'ang3', 'ang4', 'ang5', 'ng01', 'ng02', 'ng03', 'ng04', 'ng05', 'io01', 'io02', 'io03', 'io04', 'io05', 'iA01', 'iA02', 'iA03', 'iA04', 'iA05', 'uA01', 'uA02', 'uA03', 'uA04', 'uA05', 'ong1', 'ong2', 'ong3', 'ong4', 'ong5', 'oo01', 'oo02', 'oo03', 'oo04', 'oo05', 'uE01', 'uE02', 'uE03', 'uE04', 'uE05', 'vE01', 'vE02', 'vE03', 'vE04', 'vE05', 'ue01', 'ue02', 'ue03', 'ue04', 'ue05', 'ua01', 'ua02', 'ua03', 'ua04', 'ua05', 'iO01', 'iO02', 'iO03', 'iO04', 'iO05']
+ _additional = ['<sil>', '<asp>']
+ # _CNM3_letters = []
+
+
+ '''# shanghainese_cleaners
+ _pad = '_'
+ _punctuation = ',.!?…'
+ _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 '
+ '''
+
+ '''# chinese_dialect_cleaners
+ _pad = '_'
+ _punctuation = ',.!?~…─'
+ _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ '
+ '''
+
+ # Export all symbols:
+ symbols = [_pad] + list(_punctuation) + list(_IPA_letters) + _CNM3_letters + _additional
+
+ # Special symbol ids
+ SPACE_ID = symbols.index(" ")
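A quick sanity-check sketch of how this symbol table is typically consumed downstream (hypothetical, not part of the commit; `phones_to_ids` is an illustrative name for the kind of lookup that `cleaned_text_to_sequence` performs):

from text.symbols import symbols, SPACE_ID

_symbol_to_id = {s: i for i, s in enumerate(symbols)}

def phones_to_ids(phones):
    # unknown phones would raise KeyError; real code may filter them first
    return [_symbol_to_id[p] for p in phones]

print(len(symbols), SPACE_ID)  # vocabulary size and the id of the space symbol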
utils/__init__.py ADDED
File without changes
utils/audio.py ADDED
@@ -0,0 +1,74 @@
+ import torch
+ from torch import Tensor
+ import torch.nn as nn
+ import torchaudio
+
+ class LinearSpectrogram(nn.Module):
+     def __init__(self, n_fft, win_length, hop_length, pad, center, pad_mode):
+         super().__init__()
+
+         self.n_fft = n_fft
+         self.win_length = win_length
+         self.hop_length = hop_length
+         self.pad = pad
+         self.center = center
+         self.pad_mode = pad_mode
+
+         self.register_buffer("window", torch.hann_window(win_length))
+
+     def forward(self, waveform: Tensor) -> Tensor:
+         if waveform.ndim == 3:
+             waveform = waveform.squeeze(1)
+         waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (self.pad, self.pad), self.pad_mode).squeeze(1)
+         spec = torch.stft(waveform, self.n_fft, self.hop_length, self.win_length, self.window, self.center, self.pad_mode, False, True, True)
+         spec = torch.view_as_real(spec)
+         spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+         return spec
+
+
+ class LogMelSpectrogram(nn.Module):
+     def __init__(self, sample_rate, n_fft, win_length, hop_length, f_min, f_max, pad, n_mels, center, pad_mode, mel_scale):
+         super().__init__()
+         self.sample_rate = sample_rate
+         self.n_fft = n_fft
+         self.win_length = win_length
+         self.hop_length = hop_length
+         self.f_min = f_min
+         self.f_max = f_max
+         self.pad = pad
+         self.n_mels = n_mels
+         self.center = center
+         self.pad_mode = pad_mode
+         self.mel_scale = mel_scale
+
+         self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, pad, center, pad_mode)
+         # note: rebinds self.mel_scale from the config string above to the MelScale module
+         self.mel_scale = torchaudio.transforms.MelScale(n_mels, sample_rate, f_min, f_max, (n_fft // 2) + 1, mel_scale, mel_scale)
+
+     def compress(self, x: Tensor) -> Tensor:
+         return torch.log(torch.clamp(x, min=1e-5))
+
+     def decompress(self, x: Tensor) -> Tensor:
+         return torch.exp(x)
+
+     def forward(self, x: Tensor) -> Tensor:
+         linear_spec = self.spectrogram(x)
+         x = self.mel_scale(linear_spec)
+         x = self.compress(x)
+         return x
+
+ def load_and_resample_audio(audio_path, target_sr, device='cpu') -> Tensor:
+     try:
+         y, sr = torchaudio.load(audio_path)
+     except Exception as e:
+         print(str(e))
+         return None
+
+     y = y.to(device)  # Tensor.to is not in-place, so the result must be assigned
+     # convert to mono
+     if y.size(0) > 1:
+         y = y[0, :].unsqueeze(0)  # shape: [channels, time] -> [1, time]
+
+     # resample audio to target sample_rate
+     if sr != target_sr:
+         y = torchaudio.functional.resample(y, sr, target_sr)
+     return y
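A short usage sketch (not part of the commit): it extracts a log-mel spectrogram from a wav file. The parameter values below are illustrative placeholders; in the repo they come from MelConfig in config.py, and 'reference.wav' is a hypothetical path.

import torch
from utils.audio import LogMelSpectrogram, load_and_resample_audio

# illustrative settings only; substitute the values from MelConfig
mel_extractor = LogMelSpectrogram(sample_rate=44100, n_fft=2048, win_length=2048,
                                  hop_length=512, f_min=0.0, f_max=None, pad=0,
                                  n_mels=128, center=False, pad_mode='reflect',
                                  mel_scale='slaney')

y = load_and_resample_audio('reference.wav', target_sr=44100)  # [1, time], or None on failure
if y is not None:
    mel = mel_extractor(y)  # [1, n_mels, frames], log-compressed
    print(mel.shape)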
utils/load.py ADDED
@@ -0,0 +1,43 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.nn.parallel import DistributedDataParallel as DDP
+
+ def continue_training(checkpoint_path, model: DDP, optimizer: optim.Optimizer) -> int:
+     """Load the latest model/optimizer checkpoint pair and return the epoch to resume from."""
+     model_dict = {}
+     optimizer_dict = {}
+
+     # glob all the checkpoints in the directory
+     for file in os.listdir(checkpoint_path):
+         if file.endswith(".pt") and '_' in file:
+             name, epoch_str = file.rsplit('_', 1)
+             epoch = int(epoch_str.split('.')[0])
+
+             if name.startswith("checkpoint"):
+                 model_dict[epoch] = file
+             elif name.startswith("optimizer"):
+                 optimizer_dict[epoch] = file
+
+     # get the largest epoch present for both model and optimizer
+     common_epochs = set(model_dict.keys()) & set(optimizer_dict.keys())
+     if common_epochs:
+         max_epoch = max(common_epochs)
+         model_path = os.path.join(checkpoint_path, model_dict[max_epoch])
+         optimizer_path = os.path.join(checkpoint_path, optimizer_dict[max_epoch])
+
+         # load model and optimizer
+         model.module.load_state_dict(torch.load(model_path, map_location='cpu'))
+         optimizer.load_state_dict(torch.load(optimizer_path, map_location='cpu'))
+
+         print(f'resumed model and optimizer from epoch {max_epoch}')
+         return max_epoch + 1
+
+     else:
+         # fall back to a model-only checkpoint (e.g. a pretrained model)
+         if model_dict:
+             model_path = os.path.join(checkpoint_path, model_dict[max(model_dict.keys())])
+             model.module.load_state_dict(torch.load(model_path, map_location='cpu'))
+
+         return 0
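A sketch of how continue_training might be wired into a DDP training loop (hypothetical, not part of the commit; it assumes the process group is already initialized, e.g. via torchrun, and that checkpoints are saved as checkpoint_{epoch}.pt / optimizer_{epoch}.pt as the naming scheme above implies):

import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from utils.load import continue_training

model = DDP(nn.Linear(8, 8))  # placeholder module standing in for the TTS model
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

start_epoch = continue_training('./checkpoints', model, optimizer)  # 0 if nothing to resume
for epoch in range(start_epoch, 100):
    pass  # training step here; then save checkpoint_{epoch}.pt and optimizer_{epoch}.pt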