mrfakename committed on
Commit
897409a
1 Parent(s): 507af76

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

Files changed (1) hide show
  1. src/f5_tts/train/finetune_gradio.py +132 -14
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -23,7 +23,7 @@ from datasets.arrow_writer import ArrowWriter
23
  from safetensors.torch import save_file
24
  from scipy.io import wavfile
25
  from transformers import pipeline
26
-
27
  from f5_tts.api import F5TTS
28
  from f5_tts.model.utils import convert_char_to_pinyin
29
 
@@ -731,6 +731,97 @@ def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str, s
731
  return f"An error occurred: {e}"
732
 
733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
734
  def vocab_check(project_name):
735
  name_project = project_name
736
  path_project = os.path.join(path_data, name_project)
@@ -739,7 +830,7 @@ def vocab_check(project_name):
739
 
740
  file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
741
  if not os.path.isfile(file_vocab):
742
- return f"the file {file_vocab} not found !"
743
 
744
  with open(file_vocab, "r", encoding="utf-8-sig") as f:
745
  data = f.read()
@@ -747,7 +838,7 @@ def vocab_check(project_name):
747
  vocab = set(vocab)
748
 
749
  if not os.path.isfile(file_metadata):
750
- return f"the file {file_metadata} not found !"
751
 
752
  with open(file_metadata, "r", encoding="utf-8-sig") as f:
753
  data = f.read()
@@ -765,12 +856,15 @@ def vocab_check(project_name):
765
  if t not in vocab and t not in miss_symbols_keep:
766
  miss_symbols.append(t)
767
  miss_symbols_keep[t] = t
 
768
  if miss_symbols == []:
 
769
  info = "You can train using your language !"
770
  else:
771
- info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
 
772
 
773
- return info
774
 
775
 
776
  def get_random_sample_prepare(project_name):
@@ -1009,6 +1103,38 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
1009
  outputs=[random_text_transcribe, random_audio_transcribe],
1010
  )
1011
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1012
  with gr.TabItem("prepare Data"):
1013
  gr.Markdown(
1014
  """```plaintext
@@ -1030,7 +1156,7 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
1030
 
1031
  ```"""
1032
  )
1033
- ch_tokenizern = gr.Checkbox(label="create vocabulary from dataset", value=False)
1034
  bt_prepare = bt_create = gr.Button("prepare")
1035
  txt_info_prepare = gr.Text(label="info", value="")
1036
  txt_vocab_prepare = gr.Text(label="vocab", value="")
@@ -1048,14 +1174,6 @@ for tutorial and updates check here (https://github.com/SWivid/F5-TTS/discussion
1048
  fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
1049
  )
1050
 
1051
- with gr.TabItem("vocab check"):
1052
- gr.Markdown("""```plaintext
1053
- check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are included. for finetune new language
1054
- ```""")
1055
- check_button = gr.Button("check vocab")
1056
- txt_info_check = gr.Text(label="info", value="")
1057
- check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check])
1058
-
1059
  with gr.TabItem("train Data"):
1060
  gr.Markdown("""```plaintext
1061
  The auto-setting is still experimental. Please make sure that the epochs , save per updates , and last per steps are set correctly, or change them manually as needed.
 
23
  from safetensors.torch import save_file
24
  from scipy.io import wavfile
25
  from transformers import pipeline
26
+ from cached_path import cached_path
27
  from f5_tts.api import F5TTS
28
  from f5_tts.model.utils import convert_char_to_pinyin
29
 
 
731
  return f"An error occurred: {e}"
732
 
733
 
734
def expand_model_embeddings(ckpt_path, new_ckpt_path, num_new_tokens=42, seed=666):
    """Grow the text-embedding table of a saved checkpoint by ``num_new_tokens`` rows.

    Loads the checkpoint at ``ckpt_path``, appends ``num_new_tokens`` randomly
    initialized rows to the EMA text-embedding weight matrix, and saves the
    expanded checkpoint to ``new_ckpt_path``.

    Args:
        ckpt_path: path of the source checkpoint (``torch.save`` format).
        new_ckpt_path: destination path for the expanded checkpoint.
        num_new_tokens: number of embedding rows to append.
        seed: RNG seed so the appended rows are reproducible across runs.

    Returns:
        The new vocabulary size (old row count + ``num_new_tokens``).
    """
    # Seed every RNG source so the newly appended rows are deterministic.
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # NOTE(review): these cudnn flags are process-global and will affect any
    # later training in the same process — confirm that is intended.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    ckpt = torch.load(ckpt_path, map_location="cpu")

    # Only the EMA copy of the text embedding is expanded here.
    ema_sd = ckpt.get("ema_model_state_dict", {})
    embed_key_ema = "ema_model.transformer.text_embed.text_embed.weight"
    old_embed_ema = ema_sd[embed_key_ema]

    vocab_old = old_embed_ema.size(0)
    embed_dim = old_embed_ema.size(1)
    vocab_new = vocab_old + num_new_tokens

    def expand_embeddings(old_embeddings):
        # Copy the existing rows unchanged; draw the new rows from N(0, 1).
        new_embeddings = torch.zeros((vocab_new, embed_dim))
        new_embeddings[:vocab_old] = old_embeddings
        new_embeddings[vocab_old:] = torch.randn((num_new_tokens, embed_dim))
        return new_embeddings

    ema_sd[embed_key_ema] = expand_embeddings(ema_sd[embed_key_ema])

    torch.save(ckpt, new_ckpt_path)

    return vocab_new
765
+
766
+
767
def vocab_count(text):
    """Return, as a string, the number of comma-separated symbols in ``text``.

    Empty or whitespace-only segments (e.g. from a trailing comma or an empty
    input box) are not counted — the original ``len(text.split(","))`` reported
    1 for an empty string and over-counted trailing commas.
    """
    return str(len([s for s in text.split(",") if s.strip()]))
769
+
770
+
771
def vocab_extend(project_name, symbols, model_type):
    """Extend the base vocabulary with user-supplied symbols and expand the model.

    Writes a new ``vocab.txt`` for the project containing the base
    Emilia_ZH_EN_pinyin vocabulary plus any of ``symbols`` (comma-separated)
    that were missing, then creates an expanded checkpoint in the project's
    ckpts folder whose embedding table matches the new vocabulary size.

    Args:
        project_name: project folder name under ``path_data``.
        symbols: comma-separated symbols to add, e.g. ``"a,b,ç"``.
        model_type: ``"F5-TTS"`` or anything else for E2-TTS, selects the base
            checkpoint to expand.

    Returns:
        A human-readable status / summary string.
    """
    if symbols == "":
        return "Symbols empty!"

    name_project = project_name
    path_project = os.path.join(path_data, name_project)
    file_vocab_project = os.path.join(path_project, "vocab.txt")

    file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
    if not os.path.isfile(file_vocab):
        return f"the file {file_vocab} not found !"

    # Strip spaces and drop empty entries (e.g. from a trailing comma) so an
    # empty string is never appended to the vocabulary file.
    symbols = [item.replace(" ", "") for item in symbols.split(",")]
    symbols = [item for item in symbols if item]
    if symbols == []:
        return "Symbols to extend not found."

    with open(file_vocab, "r", encoding="utf-8-sig") as f:
        data = f.read()
    vocab = data.split("\n")
    vocab_check = set(vocab)

    # Collect only symbols absent from the base vocab, de-duplicated so a
    # repeated input symbol is not appended twice.
    miss_symbols = []
    for item in symbols:
        if item not in vocab_check and item not in miss_symbols:
            miss_symbols.append(item)

    if miss_symbols == []:
        return "Symbols are okay no need to extend."

    size_vocab = len(vocab)
    vocab.extend(miss_symbols)

    with open(file_vocab_project, "w", encoding="utf-8-sig") as f:
        f.write("\n".join(vocab))

    # Pick the pretrained base checkpoint to expand (downloaded/cached on demand).
    if model_type == "F5-TTS":
        ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
    else:
        ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))

    new_ckpt_path = os.path.join(path_project_ckpts, name_project)
    os.makedirs(new_ckpt_path, exist_ok=True)
    new_ckpt_file = os.path.join(new_ckpt_path, "model_1200000.pt")

    size = expand_model_embeddings(ckpt_path, new_ckpt_file, num_new_tokens=len(miss_symbols))

    vocab_new = "\n".join(miss_symbols)
    return f"vocab old size : {size_vocab}\nvocab new size : {size}\nvocab add : {len(miss_symbols)}\nnew symbols :\n{vocab_new}"
823
+
824
+
825
  def vocab_check(project_name):
826
  name_project = project_name
827
  path_project = os.path.join(path_data, name_project)
 
830
 
831
  file_vocab = os.path.join(path_data, "Emilia_ZH_EN_pinyin/vocab.txt")
832
  if not os.path.isfile(file_vocab):
833
+ return f"the file {file_vocab} not found !", ""
834
 
835
  with open(file_vocab, "r", encoding="utf-8-sig") as f:
836
  data = f.read()
 
838
  vocab = set(vocab)
839
 
840
  if not os.path.isfile(file_metadata):
841
+ return f"the file {file_metadata} not found !", ""
842
 
843
  with open(file_metadata, "r", encoding="utf-8-sig") as f:
844
  data = f.read()
 
856
  if t not in vocab and t not in miss_symbols_keep:
857
  miss_symbols.append(t)
858
  miss_symbols_keep[t] = t
859
+
860
  if miss_symbols == []:
861
+ vocab_miss = ""
862
  info = "You can train using your language !"
863
  else:
864
+ vocab_miss = ",".join(miss_symbols)
865
+ info = f"The following symbols are missing in your language {len(miss_symbols)}\n\n"
866
 
867
+ return info, vocab_miss
868
 
869
 
870
  def get_random_sample_prepare(project_name):
 
1103
  outputs=[random_text_transcribe, random_audio_transcribe],
1104
  )
1105
 
1106
+ with gr.TabItem("vocab check"):
1107
+ gr.Markdown("""```plaintext
1108
+ check the vocabulary for fine-tuning Emilia_ZH_EN to ensure all symbols are included. for finetune new language
1109
+ ```""")
1110
+
1111
+ check_button = gr.Button("check vocab")
1112
+ txt_info_check = gr.Text(label="info", value="")
1113
+
1114
+ gr.Markdown("""```plaintext
1115
+ Using the extended model, you can fine-tune to a new language that is missing symbols in the vocab , this create a new model with a new vocabulary size and save it in your ckpts/project folder.
1116
+ ```""")
1117
+
1118
+ exp_name_extend = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
1119
+
1120
+ with gr.Row():
1121
+ txt_extend = gr.Textbox(
1122
+ label="Symbols",
1123
+ value="",
1124
+ placeholder="To add new symbols, make sure to use ',' for each symbol",
1125
+ scale=6,
1126
+ )
1127
+ txt_count_symbol = gr.Textbox(label="new size vocab", value="", scale=1)
1128
+
1129
+ extend_button = gr.Button("Extended")
1130
+ txt_info_extend = gr.Text(label="info", value="")
1131
+
1132
+ txt_extend.change(vocab_count, inputs=[txt_extend], outputs=[txt_count_symbol])
1133
+ check_button.click(fn=vocab_check, inputs=[cm_project], outputs=[txt_info_check, txt_extend])
1134
+ extend_button.click(
1135
+ fn=vocab_extend, inputs=[cm_project, txt_extend, exp_name_extend], outputs=[txt_info_extend]
1136
+ )
1137
+
1138
  with gr.TabItem("prepare Data"):
1139
  gr.Markdown(
1140
  """```plaintext
 
1156
 
1157
  ```"""
1158
  )
1159
+ ch_tokenizern = gr.Checkbox(label="create vocabulary", value=False, visible=False)
1160
  bt_prepare = bt_create = gr.Button("prepare")
1161
  txt_info_prepare = gr.Text(label="info", value="")
1162
  txt_vocab_prepare = gr.Text(label="vocab", value="")
 
1174
  fn=get_random_sample_prepare, inputs=[cm_project], outputs=[random_text_prepare, random_audio_prepare]
1175
  )
1176
 
 
 
 
 
 
 
 
 
1177
  with gr.TabItem("train Data"):
1178
  gr.Markdown("""```plaintext
1179
  The auto-setting is still experimental. Please make sure that the epochs , save per updates , and last per steps are set correctly, or change them manually as needed.