Maximofn committed on
Commit
9959a9b
1 Parent(s): 2d71d11

Start de quickstart

Browse files
Files changed (1) hide show
  1. translatube.py +5 -634
translatube.py CHANGED
@@ -1,637 +1,8 @@
1
- from __future__ import annotations
2
-
3
  import gradio as gr
4
- import numpy as np
5
- import torch
6
- import torchaudio
7
- from huggingface_hub import hf_hub_download
8
- from seamless_communication.models.inference.translator import Translator
9
-
10
- DESCRIPTION = """# TranslaTube"""
11
-
12
- TASK_NAMES = [
13
- "S2ST (Speech to Speech translation)",
14
- "S2TT (Speech to Text translation)",
15
- "T2ST (Text to Speech translation)",
16
- "T2TT (Text to Text translation)",
17
- "ASR (Automatic Speech Recognition)",
18
- ]
19
-
20
- # Language dict
21
- language_code_to_name = {
22
- "afr": "Afrikaans",
23
- "amh": "Amharic",
24
- "arb": "Modern Standard Arabic",
25
- "ary": "Moroccan Arabic",
26
- "arz": "Egyptian Arabic",
27
- "asm": "Assamese",
28
- "ast": "Asturian",
29
- "azj": "North Azerbaijani",
30
- "bel": "Belarusian",
31
- "ben": "Bengali",
32
- "bos": "Bosnian",
33
- "bul": "Bulgarian",
34
- "cat": "Catalan",
35
- "ceb": "Cebuano",
36
- "ces": "Czech",
37
- "ckb": "Central Kurdish",
38
- "cmn": "Mandarin Chinese",
39
- "cym": "Welsh",
40
- "dan": "Danish",
41
- "deu": "German",
42
- "ell": "Greek",
43
- "eng": "English",
44
- "est": "Estonian",
45
- "eus": "Basque",
46
- "fin": "Finnish",
47
- "fra": "French",
48
- "gaz": "West Central Oromo",
49
- "gle": "Irish",
50
- "glg": "Galician",
51
- "guj": "Gujarati",
52
- "heb": "Hebrew",
53
- "hin": "Hindi",
54
- "hrv": "Croatian",
55
- "hun": "Hungarian",
56
- "hye": "Armenian",
57
- "ibo": "Igbo",
58
- "ind": "Indonesian",
59
- "isl": "Icelandic",
60
- "ita": "Italian",
61
- "jav": "Javanese",
62
- "jpn": "Japanese",
63
- "kam": "Kamba",
64
- "kan": "Kannada",
65
- "kat": "Georgian",
66
- "kaz": "Kazakh",
67
- "kea": "Kabuverdianu",
68
- "khk": "Halh Mongolian",
69
- "khm": "Khmer",
70
- "kir": "Kyrgyz",
71
- "kor": "Korean",
72
- "lao": "Lao",
73
- "lit": "Lithuanian",
74
- "ltz": "Luxembourgish",
75
- "lug": "Ganda",
76
- "luo": "Luo",
77
- "lvs": "Standard Latvian",
78
- "mai": "Maithili",
79
- "mal": "Malayalam",
80
- "mar": "Marathi",
81
- "mkd": "Macedonian",
82
- "mlt": "Maltese",
83
- "mni": "Meitei",
84
- "mya": "Burmese",
85
- "nld": "Dutch",
86
- "nno": "Norwegian Nynorsk",
87
- "nob": "Norwegian Bokm\u00e5l",
88
- "npi": "Nepali",
89
- "nya": "Nyanja",
90
- "oci": "Occitan",
91
- "ory": "Odia",
92
- "pan": "Punjabi",
93
- "pbt": "Southern Pashto",
94
- "pes": "Western Persian",
95
- "pol": "Polish",
96
- "por": "Portuguese",
97
- "ron": "Romanian",
98
- "rus": "Russian",
99
- "slk": "Slovak",
100
- "slv": "Slovenian",
101
- "sna": "Shona",
102
- "snd": "Sindhi",
103
- "som": "Somali",
104
- "spa": "Spanish",
105
- "srp": "Serbian",
106
- "swe": "Swedish",
107
- "swh": "Swahili",
108
- "tam": "Tamil",
109
- "tel": "Telugu",
110
- "tgk": "Tajik",
111
- "tgl": "Tagalog",
112
- "tha": "Thai",
113
- "tur": "Turkish",
114
- "ukr": "Ukrainian",
115
- "urd": "Urdu",
116
- "uzn": "Northern Uzbek",
117
- "vie": "Vietnamese",
118
- "xho": "Xhosa",
119
- "yor": "Yoruba",
120
- "yue": "Cantonese",
121
- "zlm": "Colloquial Malay",
122
- "zsm": "Standard Malay",
123
- "zul": "Zulu",
124
- }
125
- LANGUAGE_NAME_TO_CODE = {v: k for k, v in language_code_to_name.items()}
126
-
127
- # Source langs: S2ST / S2TT / ASR don't need source lang
128
- # T2TT / T2ST use this
129
- text_source_language_codes = [
130
- "afr",
131
- "amh",
132
- "arb",
133
- "ary",
134
- "arz",
135
- "asm",
136
- "azj",
137
- "bel",
138
- "ben",
139
- "bos",
140
- "bul",
141
- "cat",
142
- "ceb",
143
- "ces",
144
- "ckb",
145
- "cmn",
146
- "cym",
147
- "dan",
148
- "deu",
149
- "ell",
150
- "eng",
151
- "est",
152
- "eus",
153
- "fin",
154
- "fra",
155
- "gaz",
156
- "gle",
157
- "glg",
158
- "guj",
159
- "heb",
160
- "hin",
161
- "hrv",
162
- "hun",
163
- "hye",
164
- "ibo",
165
- "ind",
166
- "isl",
167
- "ita",
168
- "jav",
169
- "jpn",
170
- "kan",
171
- "kat",
172
- "kaz",
173
- "khk",
174
- "khm",
175
- "kir",
176
- "kor",
177
- "lao",
178
- "lit",
179
- "lug",
180
- "luo",
181
- "lvs",
182
- "mai",
183
- "mal",
184
- "mar",
185
- "mkd",
186
- "mlt",
187
- "mni",
188
- "mya",
189
- "nld",
190
- "nno",
191
- "nob",
192
- "npi",
193
- "nya",
194
- "ory",
195
- "pan",
196
- "pbt",
197
- "pes",
198
- "pol",
199
- "por",
200
- "ron",
201
- "rus",
202
- "slk",
203
- "slv",
204
- "sna",
205
- "snd",
206
- "som",
207
- "spa",
208
- "srp",
209
- "swe",
210
- "swh",
211
- "tam",
212
- "tel",
213
- "tgk",
214
- "tgl",
215
- "tha",
216
- "tur",
217
- "ukr",
218
- "urd",
219
- "uzn",
220
- "vie",
221
- "yor",
222
- "yue",
223
- "zsm",
224
- "zul",
225
- ]
226
- TEXT_SOURCE_LANGUAGE_NAMES = sorted(
227
- [language_code_to_name[code] for code in text_source_language_codes]
228
- )
229
-
230
- # Target langs:
231
- # S2ST / T2ST
232
- s2st_target_language_codes = [
233
- "eng",
234
- "arb",
235
- "ben",
236
- "cat",
237
- "ces",
238
- "cmn",
239
- "cym",
240
- "dan",
241
- "deu",
242
- "est",
243
- "fin",
244
- "fra",
245
- "hin",
246
- "ind",
247
- "ita",
248
- "jpn",
249
- "kor",
250
- "mlt",
251
- "nld",
252
- "pes",
253
- "pol",
254
- "por",
255
- "ron",
256
- "rus",
257
- "slk",
258
- "spa",
259
- "swe",
260
- "swh",
261
- "tel",
262
- "tgl",
263
- "tha",
264
- "tur",
265
- "ukr",
266
- "urd",
267
- "uzn",
268
- "vie",
269
- ]
270
- S2ST_TARGET_LANGUAGE_NAMES = sorted(
271
- [language_code_to_name[code] for code in s2st_target_language_codes]
272
- )
273
- # S2TT / ASR
274
- S2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
275
- # T2TT
276
- T2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
277
-
278
- # Download sample input audio files
279
- filenames = ["assets/sample_input.mp3", "assets/sample_input_2.mp3"]
280
- for filename in filenames:
281
- hf_hub_download(
282
- repo_id="facebook/seamless_m4t",
283
- repo_type="space",
284
- filename=filename,
285
- local_dir=".",
286
- )
287
-
288
- AUDIO_SAMPLE_RATE = 16000.0
289
- MAX_INPUT_AUDIO_LENGTH = 60 # in seconds
290
- DEFAULT_TARGET_LANGUAGE = "French"
291
-
292
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
293
- translator = Translator(
294
- model_name_or_card="seamlessM4T_large",
295
- vocoder_name_or_card="vocoder_36langs",
296
- device=device,
297
- dtype=torch.float16 if "cuda" in device.type else torch.float32,
298
- )
299
-
300
-
301
- def predict(
302
- task_name: str,
303
- audio_source: str,
304
- input_audio_mic: str | None,
305
- input_audio_file: str | None,
306
- input_text: str | None,
307
- source_language: str | None,
308
- target_language: str,
309
- ) -> tuple[tuple[int, np.ndarray] | None, str]:
310
- task_name = task_name.split()[0]
311
- source_language_code = (
312
- LANGUAGE_NAME_TO_CODE[source_language] if source_language else None
313
- )
314
- target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
315
-
316
- if task_name in ["S2ST", "S2TT", "ASR"]:
317
- if audio_source == "microphone":
318
- input_data = input_audio_mic
319
- else:
320
- input_data = input_audio_file
321
-
322
- arr, org_sr = torchaudio.load(input_data)
323
- new_arr = torchaudio.functional.resample(
324
- arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE
325
- )
326
- max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
327
- if new_arr.shape[1] > max_length:
328
- new_arr = new_arr[:, :max_length]
329
- gr.Warning(
330
- f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used."
331
- )
332
- torchaudio.save(input_data, new_arr, sample_rate=int(AUDIO_SAMPLE_RATE))
333
- else:
334
- input_data = input_text
335
- text_out, wav, sr = translator.predict(
336
- input=input_data,
337
- task_str=task_name,
338
- tgt_lang=target_language_code,
339
- src_lang=source_language_code,
340
- ngram_filtering=True,
341
- )
342
- if task_name in ["S2ST", "T2ST"]:
343
- return (sr, wav.cpu().detach().numpy()), text_out
344
- else:
345
- return None, text_out
346
-
347
-
348
- def process_s2st_example(
349
- input_audio_file: str, target_language: str
350
- ) -> tuple[tuple[int, np.ndarray] | None, str]:
351
- return predict(
352
- task_name="S2ST",
353
- audio_source="file",
354
- input_audio_mic=None,
355
- input_audio_file=input_audio_file,
356
- input_text=None,
357
- source_language=None,
358
- target_language=target_language,
359
- )
360
-
361
-
362
- def process_s2tt_example(
363
- input_audio_file: str, target_language: str
364
- ) -> tuple[tuple[int, np.ndarray] | None, str]:
365
- return predict(
366
- task_name="S2TT",
367
- audio_source="file",
368
- input_audio_mic=None,
369
- input_audio_file=input_audio_file,
370
- input_text=None,
371
- source_language=None,
372
- target_language=target_language,
373
- )
374
-
375
-
376
- def process_t2st_example(
377
- input_text: str, source_language: str, target_language: str
378
- ) -> tuple[tuple[int, np.ndarray] | None, str]:
379
- return predict(
380
- task_name="T2ST",
381
- audio_source="",
382
- input_audio_mic=None,
383
- input_audio_file=None,
384
- input_text=input_text,
385
- source_language=source_language,
386
- target_language=target_language,
387
- )
388
-
389
-
390
- def process_t2tt_example(
391
- input_text: str, source_language: str, target_language: str
392
- ) -> tuple[tuple[int, np.ndarray] | None, str]:
393
- return predict(
394
- task_name="T2TT",
395
- audio_source="",
396
- input_audio_mic=None,
397
- input_audio_file=None,
398
- input_text=input_text,
399
- source_language=source_language,
400
- target_language=target_language,
401
- )
402
-
403
-
404
- def process_asr_example(
405
- input_audio_file: str, target_language: str
406
- ) -> tuple[tuple[int, np.ndarray] | None, str]:
407
- return predict(
408
- task_name="ASR",
409
- audio_source="file",
410
- input_audio_mic=None,
411
- input_audio_file=input_audio_file,
412
- input_text=None,
413
- source_language=None,
414
- target_language=target_language,
415
- )
416
-
417
-
418
- def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
419
- mic = audio_source == "microphone"
420
- return (
421
- gr.update(visible=mic, value=None), # input_audio_mic
422
- gr.update(visible=not mic, value=None), # input_audio_file
423
- )
424
-
425
-
426
- def update_input_ui(task_name: str) -> tuple[dict, dict, dict, dict]:
427
- task_name = task_name.split()[0]
428
- if task_name == "S2ST":
429
- return (
430
- gr.update(visible=True), # audio_box
431
- gr.update(visible=False), # input_text
432
- gr.update(visible=False), # source_language
433
- gr.update(
434
- visible=True,
435
- choices=S2ST_TARGET_LANGUAGE_NAMES,
436
- value=DEFAULT_TARGET_LANGUAGE,
437
- ), # target_language
438
- )
439
- elif task_name == "S2TT":
440
- return (
441
- gr.update(visible=True), # audio_box
442
- gr.update(visible=False), # input_text
443
- gr.update(visible=False), # source_language
444
- gr.update(
445
- visible=True,
446
- choices=S2TT_TARGET_LANGUAGE_NAMES,
447
- value=DEFAULT_TARGET_LANGUAGE,
448
- ), # target_language
449
- )
450
- elif task_name == "T2ST":
451
- return (
452
- gr.update(visible=False), # audio_box
453
- gr.update(visible=True), # input_text
454
- gr.update(visible=True), # source_language
455
- gr.update(
456
- visible=True,
457
- choices=S2ST_TARGET_LANGUAGE_NAMES,
458
- value=DEFAULT_TARGET_LANGUAGE,
459
- ), # target_language
460
- )
461
- elif task_name == "T2TT":
462
- return (
463
- gr.update(visible=False), # audio_box
464
- gr.update(visible=True), # input_text
465
- gr.update(visible=True), # source_language
466
- gr.update(
467
- visible=True,
468
- choices=T2TT_TARGET_LANGUAGE_NAMES,
469
- value=DEFAULT_TARGET_LANGUAGE,
470
- ), # target_language
471
- )
472
- elif task_name == "ASR":
473
- return (
474
- gr.update(visible=True), # audio_box
475
- gr.update(visible=False), # input_text
476
- gr.update(visible=False), # source_language
477
- gr.update(
478
- visible=True,
479
- choices=S2TT_TARGET_LANGUAGE_NAMES,
480
- value=DEFAULT_TARGET_LANGUAGE,
481
- ), # target_language
482
- )
483
- else:
484
- raise ValueError(f"Unknown task: {task_name}")
485
-
486
-
487
- def update_output_ui(task_name: str) -> tuple[dict, dict]:
488
- task_name = task_name.split()[0]
489
- if task_name in ["S2ST", "T2ST"]:
490
- return (
491
- gr.update(visible=True, value=None), # output_audio
492
- gr.update(value=None), # output_text
493
- )
494
- elif task_name in ["S2TT", "T2TT", "ASR"]:
495
- return (
496
- gr.update(visible=False, value=None), # output_audio
497
- gr.update(value=None), # output_text
498
- )
499
- else:
500
- raise ValueError(f"Unknown task: {task_name}")
501
-
502
-
503
- def update_example_ui(task_name: str) -> tuple[dict, dict, dict, dict, dict]:
504
- task_name = task_name.split()[0]
505
- return (
506
- gr.update(visible=task_name == "S2ST"), # s2st_example_row
507
- gr.update(visible=task_name == "S2TT"), # s2tt_example_row
508
- gr.update(visible=task_name == "T2ST"), # t2st_example_row
509
- gr.update(visible=task_name == "T2TT"), # t2tt_example_row
510
- gr.update(visible=task_name == "ASR"), # asr_example_row
511
- )
512
-
513
- def check_url(url: str) -> bool:
514
- if url.startswith("https://www.youtube.com/watch?v="):
515
- print("URL is valid")
516
-
517
-
518
- css = """
519
- h1 {
520
- text-align: center;
521
- }
522
-
523
- .contain {
524
- max-width: 730px;
525
- margin: auto;
526
- padding-top: 1.5rem;
527
- }
528
- """
529
-
530
- with gr.Blocks(css=css) as translatube:
531
- # Title
532
- gr.Markdown(DESCRIPTION)
533
-
534
- # URL video
535
- with gr.Group():
536
- url_text = gr.Textbox(label="URL video", placeholder="Paste URL video here")
537
-
538
- with gr.Group() as tasks:
539
- task_name = gr.Dropdown(
540
- label="Task",
541
- choices=TASK_NAMES,
542
- value=TASK_NAMES[0],
543
- )
544
- with gr.Row():
545
- source_language = gr.Dropdown(
546
- label="Source language",
547
- choices=TEXT_SOURCE_LANGUAGE_NAMES,
548
- value="English",
549
- # visible=False,
550
- )
551
- target_language = gr.Dropdown(
552
- label="Target language",
553
- choices=S2ST_TARGET_LANGUAGE_NAMES,
554
- value=DEFAULT_TARGET_LANGUAGE,
555
- )
556
- # with gr.Row() as audio_box:
557
- # audio_source = gr.Radio(
558
- # label="Audio source",
559
- # choices=["file", "microphone"],
560
- # value="file",
561
- # )
562
- # input_audio_mic = gr.Audio(
563
- # label="Input speech",
564
- # type="filepath",
565
- # source="microphone",
566
- # visible=False,
567
- # )
568
- # input_audio_file = gr.Audio(
569
- # label="Input speech",
570
- # type="filepath",
571
- # source="upload",
572
- # visible=True,
573
- # )
574
- # input_text = gr.Textbox(label="Input text", visible=False)
575
- btn = gr.Button("Translate")
576
- with gr.Column():
577
- output_audio = gr.Audio(
578
- label="Translated speech",
579
- autoplay=False,
580
- streaming=False,
581
- type="numpy",
582
- )
583
- output_text = gr.Textbox(label="Translated text")
584
-
585
- url_text.change(
586
- fn=check_url,
587
- inputs=url_text,
588
- outputs=[],
589
- queue=False,
590
- api_name=False,
591
- )
592
- # audio_source.change(
593
- # fn=update_audio_ui,
594
- # inputs=audio_source,
595
- # outputs=[
596
- # input_audio_mic,
597
- # input_audio_file,
598
- # ],
599
- # queue=False,
600
- # api_name=False,
601
- # )
602
- task_name.change(
603
- fn=update_input_ui,
604
- inputs=task_name,
605
- outputs=[
606
- # audio_box,
607
- # input_text,
608
- source_language,
609
- target_language,
610
- ],
611
- queue=False,
612
- api_name=False,
613
- ).then(
614
- fn=update_output_ui,
615
- inputs=task_name,
616
- outputs=[output_audio, output_text],
617
- queue=False,
618
- api_name=False,
619
- )
620
 
621
- btn.click(
622
- fn=predict,
623
- inputs=[
624
- task_name,
625
- # audio_source,
626
- # input_audio_mic,
627
- # input_audio_file,
628
- # input_text,
629
- source_language,
630
- target_language,
631
- ],
632
- outputs=[output_audio, output_text],
633
- api_name="run",
634
- )
635
 
636
- if __name__ == "__main__":
637
- translatube.queue().launch()
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ def greet(name):
4
+ return "Hello " + name + "!"
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
+
8
+ demo.launch()