mrfakename committed
Commit 392ff83
1 Parent(s): c971ea2

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions there rather than to this Space.

src/f5_tts/train/finetune_cli.py CHANGED
@@ -45,7 +45,7 @@ def parse_args():
    parser.add_argument("--save_per_updates", type=int, default=10000, help="Save checkpoint every X steps")
    parser.add_argument("--last_per_steps", type=int, default=50000, help="Save last checkpoint every X steps")
    parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
-    parser.add_argument("--pretrain", type=str, default=None, help="Use pretrain model for finetune")
+    parser.add_argument("--pretrain", type=str, default=None, help="Path to the pretrained checkpoint")
    parser.add_argument(
        "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
    )
@@ -89,7 +89,11 @@ def main():
    if args.finetune:
        if not os.path.isdir(checkpoint_path):
            os.makedirs(checkpoint_path, exist_ok=True)
-        shutil.copy2(ckpt_path, os.path.join(checkpoint_path, os.path.basename(ckpt_path)))
+
+        file_checkpoint = os.path.join(checkpoint_path, os.path.basename(ckpt_path))
+        if not os.path.isfile(file_checkpoint):
+            shutil.copy2(ckpt_path, file_checkpoint)
+            print("copy checkpoint for finetune")

    # Use the tokenizer and tokenizer_path provided in the command line arguments
    tokenizer = args.tokenizer
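
For illustration, a minimal sketch of launching the updated CLI the same way finetune_gradio.py does (via subprocess); only flags visible in this diff are used, the checkpoint path is a placeholder, and any dataset-specific arguments finetune_cli.py may require are omitted.

import subprocess

# Hypothetical invocation of the finetune CLI. --pretrain points at an existing
# checkpoint; per the change above it is copied into the project's checkpoint
# folder only if a file with the same name is not already there.
cmd = (
    "python src/f5_tts/train/finetune_cli.py"
    " --finetune True"
    " --tokenizer pinyin"
    " --pretrain /path/to/model_1200000.pt"  # placeholder path
)
training_process = subprocess.Popen(cmd, shell=True)
training_process.wait()
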
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -26,7 +26,7 @@ from transformers import pipeline
from cached_path import cached_path
from f5_tts.api import F5TTS
from f5_tts.model.utils import convert_char_to_pinyin
-
+from importlib.resources import files

training_process = None
system = platform.system()
@@ -36,9 +36,9 @@ last_checkpoint = ""
last_device = ""
last_ema = None

- path_basic = os.path.abspath(os.path.join(__file__, "../../../.."))
- path_data = os.path.join(path_basic, "data")
- path_project_ckpts = os.path.join(path_basic, "ckpts")
+
+ path_data = str(files("f5_tts").joinpath("../../data"))
+ path_project_ckpts = str(files("f5_tts").joinpath("../../ckpts"))
file_train = "src/f5_tts/train/finetune_cli.py"

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
@@ -46,6 +46,119 @@ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is
pipe = None


+ # Save settings to a JSON file
+ def save_settings(
+     project_name,
+     exp_name,
+     learning_rate,
+     batch_size_per_gpu,
+     batch_size_type,
+     max_samples,
+     grad_accumulation_steps,
+     max_grad_norm,
+     epochs,
+     num_warmup_updates,
+     save_per_updates,
+     last_per_steps,
+     finetune,
+     file_checkpoint_train,
+     tokenizer_type,
+     tokenizer_file,
+     mixed_precision,
+ ):
+     path_project = os.path.join(path_project_ckpts, project_name)
+     os.makedirs(path_project, exist_ok=True)
+     file_setting = os.path.join(path_project, "setting.json")
+
+     settings = {
+         "exp_name": exp_name,
+         "learning_rate": learning_rate,
+         "batch_size_per_gpu": batch_size_per_gpu,
+         "batch_size_type": batch_size_type,
+         "max_samples": max_samples,
+         "grad_accumulation_steps": grad_accumulation_steps,
+         "max_grad_norm": max_grad_norm,
+         "epochs": epochs,
+         "num_warmup_updates": num_warmup_updates,
+         "save_per_updates": save_per_updates,
+         "last_per_steps": last_per_steps,
+         "finetune": finetune,
+         "file_checkpoint_train": file_checkpoint_train,
+         "tokenizer_type": tokenizer_type,
+         "tokenizer_file": tokenizer_file,
+         "mixed_precision": mixed_precision,
+     }
+     with open(file_setting, "w") as f:
+         json.dump(settings, f, indent=4)
+     return "Settings saved!"
+
+
+ # Load settings from a JSON file
+ def load_settings(project_name):
+     project_name = project_name.replace("_pinyin", "").replace("_char", "")
+     path_project = os.path.join(path_project_ckpts, project_name)
+     file_setting = os.path.join(path_project, "setting.json")
+
+     if not os.path.isfile(file_setting):
+         settings = {
+             "exp_name": "F5TTS_Base",
+             "learning_rate": 1e-05,
+             "batch_size_per_gpu": 1000,
+             "batch_size_type": "frame",
+             "max_samples": 64,
+             "grad_accumulation_steps": 1,
+             "max_grad_norm": 1,
+             "epochs": 100,
+             "num_warmup_updates": 2,
+             "save_per_updates": 300,
+             "last_per_steps": 200,
+             "finetune": True,
+             "file_checkpoint_train": "",
+             "tokenizer_type": "pinyin",
+             "tokenizer_file": "",
+             "mixed_precision": "none",
+         }
+         return (
+             settings["exp_name"],
+             settings["learning_rate"],
+             settings["batch_size_per_gpu"],
+             settings["batch_size_type"],
+             settings["max_samples"],
+             settings["grad_accumulation_steps"],
+             settings["max_grad_norm"],
+             settings["epochs"],
+             settings["num_warmup_updates"],
+             settings["save_per_updates"],
+             settings["last_per_steps"],
+             settings["finetune"],
+             settings["file_checkpoint_train"],
+             settings["tokenizer_type"],
+             settings["tokenizer_file"],
+             settings["mixed_precision"],
+         )
+
+     with open(file_setting, "r") as f:
+         settings = json.load(f)
+     return (
+         settings["exp_name"],
+         settings["learning_rate"],
+         settings["batch_size_per_gpu"],
+         settings["batch_size_type"],
+         settings["max_samples"],
+         settings["grad_accumulation_steps"],
+         settings["max_grad_norm"],
+         settings["epochs"],
+         settings["num_warmup_updates"],
+         settings["save_per_updates"],
+         settings["last_per_steps"],
+         settings["finetune"],
+         settings["file_checkpoint_train"],
+         settings["tokenizer_type"],
+         settings["tokenizer_file"],
+         settings["mixed_precision"],
+     )
+
+
# Load metadata
def get_audio_duration(audio_path):
    """Calculate the duration of an audio file."""
@@ -330,6 +443,26 @@ def start_training(

    print(cmd)

+     save_settings(
+         dataset_name,
+         exp_name,
+         learning_rate,
+         batch_size_per_gpu,
+         batch_size_type,
+         max_samples,
+         grad_accumulation_steps,
+         max_grad_norm,
+         epochs,
+         num_warmup_updates,
+         save_per_updates,
+         last_per_steps,
+         finetune,
+         file_checkpoint_train,
+         tokenizer_type,
+         tokenizer_file,
+         mixed_precision,
+     )
+
    try:
        # Start the training process
        training_process = subprocess.Popen(cmd, shell=True)
@@ -564,10 +697,11 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):

    new_vocal = ""
    if not ch_tokenizer:
-         file_vocab_finetune = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
-         if not os.path.isfile(file_vocab_finetune):
-             return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!", ""
-         shutil.copy2(file_vocab_finetune, file_vocab)
+         if not os.path.isfile(file_vocab):
+             file_vocab_finetune = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
+             if not os.path.isfile(file_vocab_finetune):
+                 return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!", ""
+             shutil.copy2(file_vocab_finetune, file_vocab)

    with open(file_vocab, "r", encoding="utf-8-sig") as f:
        vocab_char_map = {}
@@ -801,11 +935,13 @@ def vocab_extend(project_name, symbols, model_type):
        return "Symbols are okay no need to extend."

    size_vocab = len(vocab)
-     vocab.pop()  # fix empty space leave
+     vocab.pop()
    for item in miss_symbols:
        vocab.append(item)

-     with open(file_vocab_project, "w", encoding="utf-8-sig") as f:
+     vocab.append("")
+
+     with open(file_vocab_project, "w", encoding="utf-8") as f:
        f.write("\n".join(vocab))

    if model_type == "F5-TTS":
@@ -813,14 +949,17 @@ def vocab_extend(project_name, symbols, model_type):
    else:
        ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))

-     new_ckpt_path = os.path.join(path_project_ckpts, name_project)
+     vocab_size_new = len(miss_symbols)
+
+     dataset_name = name_project.replace("_pinyin", "").replace("_char", "")
+     new_ckpt_path = os.path.join(path_project_ckpts, dataset_name)
    os.makedirs(new_ckpt_path, exist_ok=True)
    new_ckpt_file = os.path.join(new_ckpt_path, "model_1200000.pt")

-     size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=len(miss_symbols))
+     size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=vocab_size_new)

    vocab_new = "\n".join(miss_symbols)
-     return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {len(miss_symbols)}\nnew symbols :\n{vocab_new}"
+     return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {vocab_size_new}\nnew symbols :\n{vocab_new}"


def vocab_check(project_name):
@@ -1192,7 +1331,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
        with gr.Row():
            ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
            tokenizer_file = gr.Textbox(label="Tokenizer File", value="")
-             file_checkpoint_train = gr.Textbox(label="Pretrain Model", value="")
+             file_checkpoint_train = gr.Textbox(label="Path to the pretrain checkpoint", value="")

        with gr.Row():
            exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
@@ -1219,6 +1358,42 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
            start_button = gr.Button("Start Training")
            stop_button = gr.Button("Stop Training", interactive=False)

+         if projects_selelect is not None:
+             (
+                 exp_namev,
+                 learning_ratev,
+                 batch_size_per_gpuv,
+                 batch_size_typev,
+                 max_samplesv,
+                 grad_accumulation_stepsv,
+                 max_grad_normv,
+                 epochsv,
+                 num_warmupv_updatesv,
+                 save_per_updatesv,
+                 last_per_stepsv,
+                 finetunev,
+                 file_checkpoint_trainv,
+                 tokenizer_typev,
+                 tokenizer_filev,
+                 mixed_precisionv,
+             ) = load_settings(projects_selelect)
+             exp_name.value = exp_namev
+             learning_rate.value = learning_ratev
+             batch_size_per_gpu.value = batch_size_per_gpuv
+             batch_size_type.value = batch_size_typev
+             max_samples.value = max_samplesv
+             grad_accumulation_steps.value = grad_accumulation_stepsv
+             max_grad_norm.value = max_grad_normv
+             epochs.value = epochsv
+             num_warmup_updates.value = num_warmupv_updatesv
+             save_per_updates.value = save_per_updatesv
+             last_per_steps.value = last_per_stepsv
+             ch_finetune.value = finetunev
+             file_checkpoint_train.value = file_checkpoint_trainv
+             tokenizer_type.value = tokenizer_typev
+             tokenizer_file.value = tokenizer_filev
+             mixed_precision.value = mixed_precisionv
+
        txt_info_train = gr.Text(label="info", value="")
        start_button.click(
            fn=start_training,
@@ -1273,6 +1448,29 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
            check_finetune, inputs=[ch_finetune], outputs=[file_checkpoint_train, tokenizer_file, tokenizer_type]
        )

+         cm_project.change(
+             fn=load_settings,
+             inputs=[cm_project],
+             outputs=[
+                 exp_name,
+                 learning_rate,
+                 batch_size_per_gpu,
+                 batch_size_type,
+                 max_samples,
+                 grad_accumulation_steps,
+                 max_grad_norm,
+                 epochs,
+                 num_warmup_updates,
+                 save_per_updates,
+                 last_per_steps,
+                 ch_finetune,
+                 file_checkpoint_train,
+                 tokenizer_type,
+                 tokenizer_file,
+                 mixed_precision,
+             ],
+         )
+
    with gr.TabItem("test model"):
        exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
        list_checkpoints, checkpoint_select = get_checkpoints_project(projects_selelect, False)
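
For reference, a minimal sketch of reading back the per-project setting.json that the new save_settings() writes; "my_project" and the relative ckpts/ location are placeholders (in the app the directory is resolved via files("f5_tts").joinpath("../../ckpts")).

import json
import os

# Sketch: inspect the settings file written by save_settings().
# "my_project" is a placeholder project name.
file_setting = os.path.join("ckpts", "my_project", "setting.json")

if os.path.isfile(file_setting):
    with open(file_setting, "r") as f:
        settings = json.load(f)
    # Keys mirror the save_settings() arguments, e.g.:
    print(settings["exp_name"], settings["learning_rate"], settings["mixed_precision"])
else:
    # Without the file, load_settings() in the app falls back to built-in defaults.
    print("setting.json not found")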