barghavani committed
Commit
42e8393
1 Parent(s): 4d8dcfc

Upload multi-speaker.py

Files changed (1)
  1. multi-speaker.py +185 -0
multi-speaker.py ADDED
@@ -0,0 +1,185 @@
+ import os
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "7"
+
+
+ from trainer import Trainer, TrainerArgs
+
+ from TTS.tts.configs.shared_configs import BaseDatasetConfig, CharactersConfig
+ from TTS.config.shared_configs import BaseAudioConfig
+ from TTS.tts.configs.vits_config import VitsConfig
+ from TTS.tts.datasets import load_tts_samples
+ from TTS.tts.models.vits import Vits, VitsAudioConfig, VitsArgs
+ from TTS.tts.utils.text.tokenizer import TTSTokenizer
+ from TTS.utils.audio import AudioProcessor
+ from TTS.tts.utils.speakers import SpeakerManager
+
+ # import wandb
+ # Start a wandb run with `sync_tensorboard=True`
+ # if wandb.run is None:
+ #     wandb.init(project="persian-tts-vits-grapheme-cv15-fa-male-native-multispeaker-RERUN", group="GPUx8 accel mixed bf16 128x32", sync_tensorboard=True)
+
+ # output_path = os.path.dirname(os.path.abspath(__file__))
+ # output_path = output_path + '/notebook_files/runs'
+ # output_path = wandb.run.dir  # probably better for notebooks
+ output_path = "runs"
+
+ # print("output path is:")
+ # print(output_path)
+
+ cache_path = "cache"
+
+
+
+ # def mozilla(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
+ #     """Normalize Mozilla metadata files to the TTS format."""
+ #     txt_file = os.path.join(root_path, meta_file)
+ #     items = []
+ #     # speaker_name = "mozilla"
+ #     with open(txt_file, "r", encoding="utf-8") as ttf:
+ #         for line in ttf:
+ #             cols = line.split("|")
+ #             wav_file = cols[1].strip()
+ #             text = cols[0].strip()
+ #             speaker_name = cols[2].strip()
+ #             wav_file = os.path.join(root_path, "wavs", wav_file)
+ #             items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name, "root_path": root_path})
+ #     return items
+
+
+
+ dataset_config = BaseDatasetConfig(
+     formatter="common_voice", meta_file_train="validated.tsv", path="/home/bargh1/TTS/datasets"
+ )
+
+
+
+
+ character_config = CharactersConfig(
+     characters="ءابتثجحخدذرزسشصضطظعغفقلمنهويِپچژکگیآأؤإئًَُّ",
+     # characters="!¡'(),-.:;¿?ABCDEFGHIJKLMNOPRSTUVWXYZabcdefghijklmnopqrstuvwxyzáçèéêëìíîïñòóôöùúûü«°±µ»$%&‘’‚“`”„",
+     punctuations="!(),-.:;? ̠،؛؟‌<>٫",
+     phonemes='ˈˌːˑpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟaegiouwyɪʊ̩æɑɔəɚɛɝɨ̃ʉʌʍ0123456789"#$%*+/=ABCDEFGHIJKLMNOPRSTUVWXYZ[]^_{}۱۲۳۴۵۶۷۸۹۰',
+     pad="<PAD>",
+     eos="<EOS>",
+     bos="<BOS>",
+     blank="<BLNK>",
+     characters_class="TTS.tts.models.vits.VitsCharacters",
+ )
+
+ # From the Coqui multilingual recipes; will try the language embedding later.
+ vitsArgs = VitsArgs(
+     # use_language_embedding=True,
+     # embedded_language_dim=1,
+     use_speaker_embedding=True,
+     use_sdp=False,
+ )
+
+ audio_config = BaseAudioConfig(  # NOTE: defined but unused; the VitsConfig below takes vits_audio_config
+     sample_rate=22050,
+     do_trim_silence=True,
+     min_level_db=-1,
+     # do_sound_norm=True,
+     signal_norm=True,
+     clip_norm=True,
+     symmetric_norm=True,
+     max_norm=0.9,
+     resample=True,
+     win_length=1024,
+     hop_length=256,
+     num_mels=80,
+     mel_fmin=0,
+     mel_fmax=None,
+ )
+
+ vits_audio_config = VitsAudioConfig(
+     sample_rate=22050,
+     # do_sound_norm=True,
+     win_length=1024,
+     hop_length=256,
+     num_mels=80,
+     # do_trim_silence=True,  # from the Hugging Face recipe
+     mel_fmin=0,
+     mel_fmax=None,
+ )
+ config = VitsConfig(
+     model_args=vitsArgs,
+     audio=vits_audio_config,  # from the Hugging Face recipe
+     run_name="persian-tts-vits-grapheme-cv15-multispeaker-RERUN",
+     use_speaker_embedding=True,  # for multi-speaker training
+     batch_size=8,
+     batch_group_size=16,
+     eval_batch_size=4,
+     num_loader_workers=16,
+     num_eval_loader_workers=8,
+     run_eval=True,
+     run_eval_steps=1000,
+     print_eval=True,
+     test_delay_epochs=-1,
+     epochs=1000,
+     save_step=1000,
+     text_cleaner="basic_cleaners",  # from MH
+     use_phonemes=False,
+     # phonemizer='persian_mh',  # from the TTS GitHub repo
+     # phoneme_language="fa",
+     characters=character_config,  # test without it as well
+     phoneme_cache_path=os.path.join(cache_path, "phoneme_cache_grapheme_azure-2"),
+     compute_input_seq_cache=True,
+     print_step=25,
+     mixed_precision=False,  # True causes the "Expected reduction dim" error
+     test_sentences=[
+         ["زین همرهان سست عناصر، دلم گرفت."],
+         ["بیا تا گل برافشانیم و می در ساغر اندازیم."],
+         ["بنی آدم اعضای یک پیکرند, که در آفرینش ز یک گوهرند."],
+         ["سهام زندگی به 10 درصد و سهام بیتکوین گوگل به 33 درصد افزایش یافت."],
+         ["من بودم و آبجی فوتینا، و حالا رپتی پتینا. این شعر یکی از اشعار معروف رو حوضی است که در کوچه بازار تهران زمزمه می شده است."],
+         ["یه دو دقه هم به حرفم گوش کن، نگو نگوشیدم و نحرفیدی."],
+         ["داستان با توصیف طوفان‌های شدید آغاز می‌شود؛ طوفان‌هایی که مزرعه‌ها را از بین می‌برد و محصولات را زیر شن دفن می‌کند؛ محصولاتی که زندگی افراد بسیاری به آن وابسته است."],
+     ],
+     output_path=output_path,
+     datasets=[dataset_config],
+ )
+
+ # INITIALIZE THE AUDIO PROCESSOR
+ # The audio processor is used for feature extraction and audio I/O.
+ # It mainly serves the dataloader and the training loggers.
+ ap = AudioProcessor.init_from_config(config)
+
+ # INITIALIZE THE TOKENIZER
+ # The tokenizer converts text to sequences of token IDs.
+ # The config is updated with the default characters if they are not already defined.
+ tokenizer, config = TTSTokenizer.init_from_config(config)
+
+ # LOAD DATA SAMPLES
+ # Each sample is a list of `[text, audio_file_path, speaker_name]`.
+ # You can define a custom sample loader returning the list of samples,
+ # or define a custom formatter and pass it to `load_tts_samples`.
+ # Check `TTS.tts.datasets.load_tts_samples` for more details.
+ train_samples, eval_samples = load_tts_samples(
+     dataset_config,
+     eval_split=True,
+     eval_split_max_size=config.eval_split_max_size,
+     eval_split_size=config.eval_split_size,
+ )
+
+ # Init the speaker manager for multi-speaker training;
+ # it maps speaker IDs to speaker names in the model and the data loader.
+ speaker_manager = SpeakerManager()
+ speaker_manager.set_ids_from_data(train_samples + eval_samples, parse_key="speaker_name")
+ config.num_speakers = speaker_manager.num_speakers
+
+
+
+ # Init the model.
+ model = Vits(config, ap, tokenizer, speaker_manager=speaker_manager)
+
+ # Init the trainer and 🚀
+
+ trainer = Trainer(
+     TrainerArgs(use_accelerate=True),
+     config,
+     output_path,
+     model=model,
+     train_samples=train_samples,
+     eval_samples=eval_samples,
+ )
+ trainer.fit()
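
The script uses the built-in 'common_voice' formatter and keeps a commented-out `mozilla` formatter around; as the comments above `load_tts_samples` note, a custom formatter can be passed in directly. A minimal sketch of that path, not part of this commit: `pipe_formatter` is a hypothetical name, and it assumes a pipe-separated meta file laid out as text|wav_filename|speaker_name (the same layout the commented-out formatter parses) with audio under <root_path>/wavs.

import os

def pipe_formatter(root_path, meta_file, **kwargs):
    # Hypothetical formatter: turn text|wav_filename|speaker_name lines
    # into the sample dicts Coqui TTS expects.
    items = []
    with open(os.path.join(root_path, meta_file), "r", encoding="utf-8") as f:
        for line in f:
            cols = line.strip().split("|")
            items.append({
                "text": cols[0].strip(),
                "audio_file": os.path.join(root_path, "wavs", cols[1].strip()),
                "speaker_name": cols[2].strip(),
                "root_path": root_path,
            })
    return items

# Passed in place of the formatter name registered in dataset_config:
# train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True, formatter=pipe_formatter)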
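
Once training finishes, the saved checkpoint can be loaded for synthesis with Coqui's `Synthesizer`. A minimal sketch assuming the default output layout: the run-folder name, checkpoint, config, and speakers-file paths below are placeholders rather than files this commit creates, and the speaker name must be one present in the trained speakers file.

from TTS.utils.synthesizer import Synthesizer

# Placeholder paths; substitute the actual run folder created under "runs".
run_dir = "runs/persian-tts-vits-grapheme-cv15-multispeaker-RERUN-<date>"
synthesizer = Synthesizer(
    tts_checkpoint=f"{run_dir}/best_model.pth",
    tts_config_path=f"{run_dir}/config.json",
    tts_speakers_file=f"{run_dir}/speakers.pth",
)
wav = synthesizer.tts("زین همرهان سست عناصر، دلم گرفت.", speaker_name="spk_001")  # hypothetical speaker name
synthesizer.save_wav(wav, "sample.wav")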