Spaces:
Runtime error
Runtime error
mrfakename
commited on
Commit
•
5cf7b18
1
Parent(s):
6eb9ea3
[Experimental] Gruut support
Browse files- app.py +4 -3
- gruut_phonemize.py +10 -0
- requirements.txt +2 -1
- styletts2importable.py +72 -58
app.py
CHANGED
@@ -16,13 +16,13 @@ voices = {}
|
|
16 |
# else:
|
17 |
for v in voicelist:
|
18 |
voices[v] = styletts2importable.compute_style(f'voices/{v}.wav')
|
19 |
-
def synthesize(text, voice):
|
20 |
if text.strip() == "":
|
21 |
raise gr.Error("You must enter some text")
|
22 |
if len(text) > 300:
|
23 |
raise gr.Error("Text must be under 300 characters")
|
24 |
v = voice.lower()
|
25 |
-
return (24000, styletts2importable.inference(text, voices[v], alpha=0.3, beta=0.7, diffusion_steps=7, embedding_scale=1))
|
26 |
def clsynthesize(text, voice):
|
27 |
if text.strip() == "":
|
28 |
raise gr.Error("You must enter some text")
|
@@ -43,10 +43,11 @@ with gr.Blocks() as vctk:
|
|
43 |
with gr.Column(scale=1):
|
44 |
inp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
|
45 |
voice = gr.Dropdown(voicelist, label="Voice", info="Select a default voice.", value='m-us-1', interactive=True)
|
|
|
46 |
with gr.Column(scale=1):
|
47 |
btn = gr.Button("Synthesize", variant="primary")
|
48 |
audio = gr.Audio(interactive=False, label="Synthesized Audio")
|
49 |
-
btn.click(synthesize, inputs=[inp, voice], outputs=[audio], concurrency_limit=4)
|
50 |
with gr.Blocks() as clone:
|
51 |
with gr.Row():
|
52 |
with gr.Column(scale=1):
|
|
|
16 |
# else:
|
17 |
for v in voicelist:
|
18 |
voices[v] = styletts2importable.compute_style(f'voices/{v}.wav')
|
19 |
+
def synthesize(text, voice, use_gruut):
|
20 |
if text.strip() == "":
|
21 |
raise gr.Error("You must enter some text")
|
22 |
if len(text) > 300:
|
23 |
raise gr.Error("Text must be under 300 characters")
|
24 |
v = voice.lower()
|
25 |
+
return (24000, styletts2importable.inference(text, voices[v], alpha=0.3, beta=0.7, diffusion_steps=7, embedding_scale=1, use_gruut=use_gruut))
|
26 |
def clsynthesize(text, voice):
|
27 |
if text.strip() == "":
|
28 |
raise gr.Error("You must enter some text")
|
|
|
43 |
with gr.Column(scale=1):
|
44 |
inp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
|
45 |
voice = gr.Dropdown(voicelist, label="Voice", info="Select a default voice.", value='m-us-1', interactive=True)
|
46 |
+
use_gruut = gr.Checkbox(label="Use alternate phonemizer (Gruut) - Experimental")
|
47 |
with gr.Column(scale=1):
|
48 |
btn = gr.Button("Synthesize", variant="primary")
|
49 |
audio = gr.Audio(interactive=False, label="Synthesized Audio")
|
50 |
+
btn.click(synthesize, inputs=[inp, voice, use_gruut], outputs=[audio], concurrency_limit=4)
|
51 |
with gr.Blocks() as clone:
|
52 |
with gr.Row():
|
53 |
with gr.Column(scale=1):
|
gruut_phonemize.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from gruut import sentences
|
2 |
+
|
3 |
+
|
4 |
+
def gphonemize(text):
|
5 |
+
phonemes = ''
|
6 |
+
for sent in sentences(text, lang="en-us"):
|
7 |
+
for word in sent:
|
8 |
+
if word.phonemes:
|
9 |
+
phonemes += ''.join(word.phonemes)
|
10 |
+
return phonemes
|
requirements.txt
CHANGED
@@ -18,4 +18,5 @@ git+https://github.com/resemble-ai/monotonic_align.git
|
|
18 |
scipy
|
19 |
phonemizer
|
20 |
cached-path
|
21 |
-
gradio
|
|
|
|
18 |
scipy
|
19 |
phonemizer
|
20 |
cached-path
|
21 |
+
gradio
|
22 |
+
gruut
|
styletts2importable.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
from cached_path import cached_path
|
|
|
|
|
2 |
|
3 |
# from dp.phonemizer import Phonemizer
|
4 |
print("NLTK")
|
@@ -131,9 +133,12 @@ sampler = DiffusionSampler(
|
|
131 |
clamp=False
|
132 |
)
|
133 |
|
134 |
-
def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):
|
135 |
text = text.strip()
|
136 |
-
|
|
|
|
|
|
|
137 |
ps = word_tokenize(ps[0])
|
138 |
ps = ' '.join(ps)
|
139 |
tokens = textclenaer(ps)
|
@@ -200,86 +205,92 @@ def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding
|
|
200 |
|
201 |
return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later
|
202 |
|
203 |
-
def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1):
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
210 |
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
|
223 |
-
|
224 |
embedding=bert_dur,
|
225 |
embedding_scale=embedding_scale,
|
226 |
-
|
227 |
num_steps=diffusion_steps).squeeze(1)
|
228 |
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
|
233 |
-
|
234 |
-
|
235 |
|
236 |
-
|
237 |
-
|
238 |
|
239 |
-
|
240 |
|
241 |
-
|
242 |
s, input_lengths, text_mask)
|
243 |
|
244 |
-
|
245 |
-
|
246 |
|
247 |
-
|
248 |
-
|
249 |
|
250 |
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
|
265 |
-
|
266 |
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
|
274 |
-
|
275 |
-
|
276 |
|
277 |
|
278 |
-
|
279 |
|
280 |
-
def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):
|
281 |
text = text.strip()
|
282 |
-
|
|
|
|
|
|
|
283 |
ps = word_tokenize(ps[0])
|
284 |
ps = ' '.join(ps)
|
285 |
|
@@ -288,7 +299,10 @@ def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=
|
|
288 |
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
|
289 |
|
290 |
ref_text = ref_text.strip()
|
291 |
-
|
|
|
|
|
|
|
292 |
ps = word_tokenize(ps[0])
|
293 |
ps = ' '.join(ps)
|
294 |
|
|
|
1 |
from cached_path import cached_path
|
2 |
+
print("GRUUT")
|
3 |
+
from gruut_phonemize import gphonemize
|
4 |
|
5 |
# from dp.phonemizer import Phonemizer
|
6 |
print("NLTK")
|
|
|
133 |
clamp=False
|
134 |
)
|
135 |
|
136 |
+
def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
|
137 |
text = text.strip()
|
138 |
+
if use_gruut:
|
139 |
+
ps = gphonemize(text)
|
140 |
+
else:
|
141 |
+
ps = global_phonemizer.phonemize([text])
|
142 |
ps = word_tokenize(ps[0])
|
143 |
ps = ' '.join(ps)
|
144 |
tokens = textclenaer(ps)
|
|
|
205 |
|
206 |
return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later
|
207 |
|
208 |
+
def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
|
209 |
+
text = text.strip()
|
210 |
+
if use_gruut:
|
211 |
+
ps = gphonemize(text)
|
212 |
+
else:
|
213 |
+
ps = global_phonemizer.phonemize([text])
|
214 |
+
ps = word_tokenize(ps[0])
|
215 |
+
ps = ' '.join(ps)
|
216 |
+
ps = ps.replace('``', '"')
|
217 |
+
ps = ps.replace("''", '"')
|
218 |
|
219 |
+
tokens = textclenaer(ps)
|
220 |
+
tokens.insert(0, 0)
|
221 |
+
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
|
222 |
|
223 |
+
with torch.no_grad():
|
224 |
+
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
225 |
+
text_mask = length_to_mask(input_lengths).to(device)
|
226 |
|
227 |
+
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
228 |
+
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
229 |
+
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
230 |
|
231 |
+
s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
|
232 |
embedding=bert_dur,
|
233 |
embedding_scale=embedding_scale,
|
234 |
+
features=ref_s, # reference from the same speaker as the embedding
|
235 |
num_steps=diffusion_steps).squeeze(1)
|
236 |
|
237 |
+
if s_prev is not None:
|
238 |
+
# convex combination of previous and current style
|
239 |
+
s_pred = t * s_prev + (1 - t) * s_pred
|
240 |
|
241 |
+
s = s_pred[:, 128:]
|
242 |
+
ref = s_pred[:, :128]
|
243 |
|
244 |
+
ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
|
245 |
+
s = beta * s + (1 - beta) * ref_s[:, 128:]
|
246 |
|
247 |
+
s_pred = torch.cat([ref, s], dim=-1)
|
248 |
|
249 |
+
d = model.predictor.text_encoder(d_en,
|
250 |
s, input_lengths, text_mask)
|
251 |
|
252 |
+
x, _ = model.predictor.lstm(d)
|
253 |
+
duration = model.predictor.duration_proj(x)
|
254 |
|
255 |
+
duration = torch.sigmoid(duration).sum(axis=-1)
|
256 |
+
pred_dur = torch.round(duration.squeeze()).clamp(min=1)
|
257 |
|
258 |
|
259 |
+
pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
|
260 |
+
c_frame = 0
|
261 |
+
for i in range(pred_aln_trg.size(0)):
|
262 |
+
pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
|
263 |
+
c_frame += int(pred_dur[i].data)
|
264 |
|
265 |
+
# encode prosody
|
266 |
+
en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
|
267 |
+
if model_params.decoder.type == "hifigan":
|
268 |
+
asr_new = torch.zeros_like(en)
|
269 |
+
asr_new[:, :, 0] = en[:, :, 0]
|
270 |
+
asr_new[:, :, 1:] = en[:, :, 0:-1]
|
271 |
+
en = asr_new
|
272 |
|
273 |
+
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
274 |
|
275 |
+
asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
|
276 |
+
if model_params.decoder.type == "hifigan":
|
277 |
+
asr_new = torch.zeros_like(asr)
|
278 |
+
asr_new[:, :, 0] = asr[:, :, 0]
|
279 |
+
asr_new[:, :, 1:] = asr[:, :, 0:-1]
|
280 |
+
asr = asr_new
|
281 |
|
282 |
+
out = model.decoder(asr,
|
283 |
+
F0_pred, N_pred, ref.squeeze().unsqueeze(0))
|
284 |
|
285 |
|
286 |
+
return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later
|
287 |
|
288 |
+
def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
|
289 |
text = text.strip()
|
290 |
+
if use_gruut:
|
291 |
+
ps = gphonemize(text)
|
292 |
+
else:
|
293 |
+
ps = global_phonemizer.phonemize([text])
|
294 |
ps = word_tokenize(ps[0])
|
295 |
ps = ' '.join(ps)
|
296 |
|
|
|
299 |
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
|
300 |
|
301 |
ref_text = ref_text.strip()
|
302 |
+
if use_gruut:
|
303 |
+
ps = gphonemize(text)
|
304 |
+
else:
|
305 |
+
ps = global_phonemizer.phonemize([ref_text])
|
306 |
ps = word_tokenize(ps[0])
|
307 |
ps = ' '.join(ps)
|
308 |
|