|
import os |
|
import re |
|
from fairseq import checkpoint_utils |
|
|
|
|
|
def get_index_path_from_model(sid):
    """Find the ``.index`` file that belongs to a model checkpoint.

    The model name is derived from ``sid`` in three steps:

    1. Drop a trailing ``.pth``/``.onnx`` extension.
    2. Drop any directory part.
    3. Remove a ``_e<digits>_s<digits>`` suffix, if present. This is
       presumably epoch/step numbering on intermediate checkpoints.

    The directory named by the ``index_root`` environment variable is then
    walked bottom-up (``os.walk(..., topdown=False)``). It looks for files
    ending in ``.index`` whose name does not contain ``"trained"``. The first
    one whose path *relative to* ``index_root`` contains the model name is
    returned.

    Args:
        sid: Model file name or path, e.g. ``"weights/voice.pth"``.

    Returns:
        The matching path (``index_root`` joined with the walked sub-path).
        Returns ``""`` if ``index_root`` is unset or empty, if the derived
        model name is empty, or if nothing matches.
    """
    # Strip only a *trailing* extension. The alternation is grouped so that
    # "$" anchors both branches; the ungrouped form r"\.pth|\.onnx$" also
    # deleted ".pth" occurring anywhere inside the name.
    stem = re.sub(r"\.(?:pth|onnx)$", "", sid)
    model_name = os.path.split(stem)[-1]

    # "<base>_e<N>_s<M>" -> "<base>": drop the last two "_"-separated fields.
    if re.match(r".+_e\d+_s\d+$", model_name):
        model_name = model_name.rsplit("_", 2)[0]

    index_root = os.getenv("index_root")
    # os.walk(None) raises TypeError, and an empty name would match every
    # file, so report "not found" in both cases instead.
    if not index_root or not model_name:
        return ""

    # Lazy walk in the same bottom-up order as before; stop at the first hit.
    for root, _, files in os.walk(index_root, topdown=False):
        for name in files:
            if not name.endswith(".index") or "trained" in name:
                continue
            path = os.path.join(root, name)
            # Match against the part below index_root. Otherwise a root
            # directory whose name contains the model name would make every
            # index file match.
            if model_name in os.path.relpath(path, index_root):
                return path
    return ""
|
|
|
|
|
def load_hubert(config, model_path="assets/hubert/hubert_base.pt"):
    """Load the HuBERT model from a fairseq checkpoint, ready for inference.

    Args:
        config: Runtime settings object. Only two attributes are read:
            ``config.device``, the target passed to ``.to()``, and
            ``config.is_half``, where truthy selects ``.half()`` and falsy
            selects ``.float()``.
        model_path: Checkpoint file handed to fairseq. It defaults to the
            previously hard-coded ``assets/hubert/hubert_base.pt``, so
            existing callers are unaffected.

    Returns:
        The first model of the loaded ensemble. It is moved to
        ``config.device``, cast to half or float precision, and put in
        eval mode.
    """
    # NOTE(review): fairseq checkpoint loading is presumably torch.load-based
    # (pickle) -- only pass trusted files as ``model_path``.
    models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
        [model_path],
        suffix="",
    )
    # A single checkpoint path was requested, so the ensemble's first entry
    # is that model.
    hubert_model = models[0].to(config.device)
    # Cast explicitly in both branches so the resulting dtype is always the
    # one the config asks for.
    hubert_model = hubert_model.half() if config.is_half else hubert_model.float()
    return hubert_model.eval()
|
|