mrfakename commited on
Commit
79086d9
1 Parent(s): 4af33eb

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

Files changed (1) hide show
  1. src/f5_tts/train/finetune_cli.py +5 -6
src/f5_tts/train/finetune_cli.py CHANGED
@@ -6,6 +6,7 @@ from cached_path import cached_path
6
  from f5_tts.model import CFM, UNetT, DiT, Trainer
7
  from f5_tts.model.utils import get_tokenizer
8
  from f5_tts.model.dataset import load_dataset
 
9
 
10
 
11
  # -------------------------- Dataset Settings --------------------------- #
@@ -63,6 +64,7 @@ def parse_args():
63
 
64
  def main():
65
  args = parse_args()
 
66
 
67
  # Model parameters based on experiment name
68
  if args.exp_name == "F5TTS_Base":
@@ -85,12 +87,9 @@ def main():
85
  ckpt_path = args.pretrain
86
 
87
  if args.finetune:
88
- path_ckpt = os.path.join("ckpts", args.dataset_name)
89
- if not os.path.isdir(path_ckpt):
90
- os.makedirs(path_ckpt, exist_ok=True)
91
- shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path)))
92
-
93
- checkpoint_path = os.path.join("ckpts", args.dataset_name)
94
 
95
  # Use the tokenizer and tokenizer_path provided in the command line arguments
96
  tokenizer = args.tokenizer
 
6
  from f5_tts.model import CFM, UNetT, DiT, Trainer
7
  from f5_tts.model.utils import get_tokenizer
8
  from f5_tts.model.dataset import load_dataset
9
+ from importlib.resources import files
10
 
11
 
12
  # -------------------------- Dataset Settings --------------------------- #
 
64
 
65
  def main():
66
  args = parse_args()
67
+ checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
68
 
69
  # Model parameters based on experiment name
70
  if args.exp_name == "F5TTS_Base":
 
87
  ckpt_path = args.pretrain
88
 
89
  if args.finetune:
90
+ if not os.path.isdir(checkpoint_path):
91
+ os.makedirs(checkpoint_path, exist_ok=True)
92
+ shutil.copy2(ckpt_path, os.path.join(checkpoint_path, os.path.basename(ckpt_path)))
 
 
 
93
 
94
  # Use the tokenizer and tokenizer_path provided in the command line arguments
95
  tokenizer = args.tokenizer