nijisakai commited on
Commit
387abab
1 Parent(s): 9f4b797

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -8
app.py CHANGED
@@ -42,7 +42,7 @@ default_cluster_infer_ratio = 0.5
42
  duration_limit = int(os.environ.get("MAX_DURATION_SECONDS", 9e9))
43
  ###################################################################
44
 
45
- models = {}
46
  all_speakers = []
47
  for repo_id in repo_ids:
48
  # Figure out the latest generator by taking highest value one.
@@ -57,7 +57,7 @@ for repo_id in repo_ids:
57
  )[-1]
58
  ckpt_name = f"G_{latest_id}.pth"
59
 
60
- cluster_model_name = cluster_model_name or "kmeans.pt"
61
  if cluster_model_name in list_repo_files(repo_id):
62
  print(f"Found Cluster model - Downloading {cluster_model_name} from {repo_id}")
63
  cluster_model_path = hf_hub_download(repo_id, cluster_model_name)
@@ -70,16 +70,13 @@ for repo_id in repo_ids:
70
  config_path = hf_hub_download(repo_id, "config.json")
71
  hparams = HParams(**json.loads(Path(config_path).read_text()))
72
  speakers = list(hparams.spk.keys())
 
73
  device = "cuda" if torch.cuda.is_available() else "cpu"
74
  model = Svc(net_g_path=generator_path, config_path=config_path, device=device, cluster_model_path=cluster_model_path)
75
-
76
- for speaker in speakers:
77
- models[speaker] = model
78
- all_speakers.append(speaker)
79
 
80
- # Reset ckpt_name and cluster_model_name for the next iteration
81
  ckpt_name = None
82
- cluster_model_name = None
83
 
84
  demucs_model = get_model(DEFAULT_MODEL)
85
 
 
42
  duration_limit = int(os.environ.get("MAX_DURATION_SECONDS", 9e9))
43
  ###################################################################
44
 
45
+ models = []
46
  all_speakers = []
47
  for repo_id in repo_ids:
48
  # Figure out the latest generator by taking highest value one.
 
57
  )[-1]
58
  ckpt_name = f"G_{latest_id}.pth"
59
 
60
+ cluster_model_name = "kmeans.pt"
61
  if cluster_model_name in list_repo_files(repo_id):
62
  print(f"Found Cluster model - Downloading {cluster_model_name} from {repo_id}")
63
  cluster_model_path = hf_hub_download(repo_id, cluster_model_name)
 
70
  config_path = hf_hub_download(repo_id, "config.json")
71
  hparams = HParams(**json.loads(Path(config_path).read_text()))
72
  speakers = list(hparams.spk.keys())
73
+ all_speakers.extend(speakers)
74
  device = "cuda" if torch.cuda.is_available() else "cpu"
75
  model = Svc(net_g_path=generator_path, config_path=config_path, device=device, cluster_model_path=cluster_model_path)
76
+ models.append(model)
 
 
 
77
 
78
+ # Reset ckpt_name for the next iteration
79
  ckpt_name = None
 
80
 
81
  demucs_model = get_model(DEFAULT_MODEL)
82