mrfakename committed on
Commit
33d99f2
1 Parent(s): f0d11e3

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions via the GitHub repo rather than to this Space directly.

Files changed (1) hide show
  1. src/f5_tts/train/finetune_gradio.py +145 -19
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -1,3 +1,7 @@
 
 
 
 
1
  import gc
2
  import json
3
  import os
@@ -111,7 +115,7 @@ def load_settings(project_name):
111
  "epochs": 100,
112
  "num_warmup_updates": 2,
113
  "save_per_updates": 300,
114
- "last_per_steps": 200,
115
  "finetune": True,
116
  "file_checkpoint_train": "",
117
  "tokenizer_type": "pinyin",
@@ -369,8 +373,9 @@ def start_training(
369
  tokenizer_type="pinyin",
370
  tokenizer_file="",
371
  mixed_precision="fp16",
 
372
  ):
373
- global training_process, tts_api
374
 
375
  if tts_api is not None:
376
  del tts_api
@@ -430,6 +435,7 @@ def start_training(
430
  f"--last_per_steps {last_per_steps} "
431
  f"--dataset_name {dataset_name}"
432
  )
 
433
  if finetune:
434
  cmd += f" --finetune {finetune}"
435
 
@@ -464,14 +470,112 @@ def start_training(
464
  )
465
 
466
  try:
467
- # Start the training process
468
- training_process = subprocess.Popen(cmd, shell=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
 
470
- time.sleep(5)
471
- yield "train start", gr.update(interactive=False), gr.update(interactive=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
 
473
- # Wait for the training process to finish
474
- training_process.wait()
475
  time.sleep(1)
476
 
477
  if training_process is None:
@@ -489,11 +593,13 @@ def start_training(
489
 
490
 
491
  def stop_training():
492
- global training_process
 
493
  if training_process is None:
494
  return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
495
  terminate_process_tree(training_process.pid)
496
- training_process = None
 
497
  return "train stop", gr.update(interactive=True), gr.update(interactive=False)
498
 
499
 
@@ -1202,7 +1308,11 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
1202
  project_name = gr.Textbox(label="project name", value="my_speak")
1203
  bt_create = gr.Button("create new project")
1204
 
1205
- cm_project = gr.Dropdown(choices=projects, value=projects_selelect, label="Project", allow_custom_value=True)
 
 
 
 
1206
 
1207
  bt_create.click(fn=create_data_project, inputs=[project_name, tokenizer_type], outputs=[cm_project])
1208
 
@@ -1304,6 +1414,7 @@ Using the extended model, you can fine-tune to a new language that is missing sy
1304
  bt_prepare = bt_create = gr.Button("prepare")
1305
  txt_info_prepare = gr.Text(label="info", value="")
1306
  txt_vocab_prepare = gr.Text(label="vocab", value="")
 
1307
  bt_prepare.click(
1308
  fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare]
1309
  )
@@ -1347,11 +1458,11 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1347
 
1348
  with gr.Row():
1349
  epochs = gr.Number(label="Epochs", value=10)
1350
- num_warmup_updates = gr.Number(label="Warmup Updates", value=5)
1351
 
1352
  with gr.Row():
1353
- save_per_updates = gr.Number(label="Save per Updates", value=10)
1354
- last_per_steps = gr.Number(label="Last per Steps", value=50)
1355
 
1356
  with gr.Row():
1357
  mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none")
@@ -1394,6 +1505,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1394
  tokenizer_file.value = tokenizer_filev
1395
  mixed_precision.value = mixed_precisionv
1396
 
 
1397
  txt_info_train = gr.Text(label="info", value="")
1398
  start_button.click(
1399
  fn=start_training,
@@ -1415,6 +1527,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1415
  tokenizer_type,
1416
  tokenizer_file,
1417
  mixed_precision,
 
1418
  ],
1419
  outputs=[txt_info_train, start_button, stop_button],
1420
  )
@@ -1448,10 +1561,8 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1448
  check_finetune, inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type]
1449
  )
1450
 
1451
- cm_project.change(
1452
- fn=load_settings,
1453
- inputs=[cm_project],
1454
- outputs=[
1455
  exp_name,
1456
  learning_rate,
1457
  batch_size_per_gpu,
@@ -1468,7 +1579,22 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1468
  tokenizer_type,
1469
  tokenizer_file,
1470
  mixed_precision,
1471
- ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1472
  )
1473
 
1474
  with gr.TabItem("test model"):
 
1
+ import threading
2
+ import queue
3
+ import re
4
+
5
  import gc
6
  import json
7
  import os
 
115
  "epochs": 100,
116
  "num_warmup_updates": 2,
117
  "save_per_updates": 300,
118
+ "last_per_steps": 100,
119
  "finetune": True,
120
  "file_checkpoint_train": "",
121
  "tokenizer_type": "pinyin",
 
373
  tokenizer_type="pinyin",
374
  tokenizer_file="",
375
  mixed_precision="fp16",
376
+ stream=False,
377
  ):
378
+ global training_process, tts_api, stop_signal
379
 
380
  if tts_api is not None:
381
  del tts_api
 
435
  f"--last_per_steps {last_per_steps} "
436
  f"--dataset_name {dataset_name}"
437
  )
438
+
439
  if finetune:
440
  cmd += f" --finetune {finetune}"
441
 
 
470
  )
471
 
472
  try:
473
+ if not stream:
474
+ # Start the training process
475
+ training_process = subprocess.Popen(cmd, shell=True)
476
+
477
+ time.sleep(5)
478
+ yield "train start", gr.update(interactive=False), gr.update(interactive=True)
479
+
480
+ # Wait for the training process to finish
481
+ training_process.wait()
482
+ else:
483
+
484
+ def stream_output(pipe, output_queue):
485
+ try:
486
+ for line in iter(pipe.readline, ""):
487
+ output_queue.put(line)
488
+ except Exception as e:
489
+ output_queue.put(f"Error reading pipe: {str(e)}")
490
+ finally:
491
+ pipe.close()
492
+
493
+ env = os.environ.copy()
494
+ env["PYTHONUNBUFFERED"] = "1"
495
 
496
+ training_process = subprocess.Popen(
497
+ cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, env=env
498
+ )
499
+ yield "Training started...", gr.update(interactive=False), gr.update(interactive=True)
500
+
501
+ stdout_queue = queue.Queue()
502
+ stderr_queue = queue.Queue()
503
+
504
+ stdout_thread = threading.Thread(target=stream_output, args=(training_process.stdout, stdout_queue))
505
+ stderr_thread = threading.Thread(target=stream_output, args=(training_process.stderr, stderr_queue))
506
+ stdout_thread.daemon = True
507
+ stderr_thread.daemon = True
508
+ stdout_thread.start()
509
+ stderr_thread.start()
510
+ stop_signal = False
511
+ while True:
512
+ if stop_signal:
513
+ training_process.terminate()
514
+ time.sleep(0.5)
515
+ if training_process.poll() is None:
516
+ training_process.kill()
517
+ yield "Training stopped by user.", gr.update(interactive=True), gr.update(interactive=False)
518
+ break
519
+
520
+ process_status = training_process.poll()
521
+
522
+ # Handle stdout
523
+ try:
524
+ while True:
525
+ output = stdout_queue.get_nowait()
526
+ print(output, end="")
527
+ match = re.search(
528
+ r"Epoch (\d+)/(\d+):\s+(\d+)%\|.*\[(\d+:\d+)<.*?loss=(\d+\.\d+), step=(\d+)", output
529
+ )
530
+ if match:
531
+ current_epoch = match.group(1)
532
+ total_epochs = match.group(2)
533
+ percent_complete = match.group(3)
534
+ elapsed_time = match.group(4)
535
+ loss = match.group(5)
536
+ current_step = match.group(6)
537
+ message = (
538
+ f"Epoch: {current_epoch}/{total_epochs}, "
539
+ f"Progress: {percent_complete}%, "
540
+ f"Elapsed Time: {elapsed_time}, "
541
+ f"Loss: {loss}, "
542
+ f"Step: {current_step}"
543
+ )
544
+ yield message, gr.update(interactive=False), gr.update(interactive=True)
545
+ elif output.strip():
546
+ yield output, gr.update(interactive=False), gr.update(interactive=True)
547
+ except queue.Empty:
548
+ pass
549
+
550
+ # Handle stderr
551
+ try:
552
+ while True:
553
+ error_output = stderr_queue.get_nowait()
554
+ print(error_output, end="")
555
+ if error_output.strip():
556
+ yield f"{error_output.strip()}", gr.update(interactive=False), gr.update(interactive=True)
557
+ except queue.Empty:
558
+ pass
559
+
560
+ if process_status is not None and stdout_queue.empty() and stderr_queue.empty():
561
+ if process_status != 0:
562
+ yield (
563
+ f"Process crashed with exit code {process_status}!",
564
+ gr.update(interactive=False),
565
+ gr.update(interactive=True),
566
+ )
567
+ else:
568
+ yield "Training complete!", gr.update(interactive=False), gr.update(interactive=True)
569
+ break
570
+
571
+ # Small sleep to prevent CPU thrashing
572
+ time.sleep(0.1)
573
+
574
+ # Clean up
575
+ training_process.stdout.close()
576
+ training_process.stderr.close()
577
+ training_process.wait()
578
 
 
 
579
  time.sleep(1)
580
 
581
  if training_process is None:
 
593
 
594
 
595
  def stop_training():
596
+ global training_process, stop_signal
597
+
598
  if training_process is None:
599
  return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
600
  terminate_process_tree(training_process.pid)
601
+ # training_process = None
602
+ stop_signal = True
603
  return "train stop", gr.update(interactive=True), gr.update(interactive=False)
604
 
605
 
 
1308
  project_name = gr.Textbox(label="project name", value="my_speak")
1309
  bt_create = gr.Button("create new project")
1310
 
1311
+ with gr.Row():
1312
+ cm_project = gr.Dropdown(
1313
+ choices=projects, value=projects_selelect, label="Project", allow_custom_value=True, scale=6
1314
+ )
1315
+ ch_refresh_project = gr.Button("refresh", scale=1)
1316
 
1317
  bt_create.click(fn=create_data_project, inputs=[project_name, tokenizer_type], outputs=[cm_project])
1318
 
 
1414
  bt_prepare = bt_create = gr.Button("prepare")
1415
  txt_info_prepare = gr.Text(label="info", value="")
1416
  txt_vocab_prepare = gr.Text(label="vocab", value="")
1417
+
1418
  bt_prepare.click(
1419
  fn=create_metadata, inputs=[cm_project, ch_tokenizern], outputs=[txt_info_prepare, txt_vocab_prepare]
1420
  )
 
1458
 
1459
  with gr.Row():
1460
  epochs = gr.Number(label="Epochs", value=10)
1461
+ num_warmup_updates = gr.Number(label="Warmup Updates", value=2)
1462
 
1463
  with gr.Row():
1464
+ save_per_updates = gr.Number(label="Save per Updates", value=300)
1465
+ last_per_steps = gr.Number(label="Last per Steps", value=100)
1466
 
1467
  with gr.Row():
1468
  mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none")
 
1505
  tokenizer_file.value = tokenizer_filev
1506
  mixed_precision.value = mixed_precisionv
1507
 
1508
+ ch_stream = gr.Checkbox(label="stream output experiment.", value=True)
1509
  txt_info_train = gr.Text(label="info", value="")
1510
  start_button.click(
1511
  fn=start_training,
 
1527
  tokenizer_type,
1528
  tokenizer_file,
1529
  mixed_precision,
1530
+ ch_stream,
1531
  ],
1532
  outputs=[txt_info_train, start_button, stop_button],
1533
  )
 
1561
  check_finetune, inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type]
1562
  )
1563
 
1564
+ def setup_load_settings():
1565
+ output_components = [
 
 
1566
  exp_name,
1567
  learning_rate,
1568
  batch_size_per_gpu,
 
1579
  tokenizer_type,
1580
  tokenizer_file,
1581
  mixed_precision,
1582
+ ]
1583
+
1584
+ return output_components
1585
+
1586
+ outputs = setup_load_settings()
1587
+
1588
+ cm_project.change(
1589
+ fn=load_settings,
1590
+ inputs=[cm_project],
1591
+ outputs=outputs,
1592
+ )
1593
+
1594
+ ch_refresh_project.click(
1595
+ fn=load_settings,
1596
+ inputs=[cm_project],
1597
+ outputs=outputs,
1598
  )
1599
 
1600
  with gr.TabItem("test model"):