Spaces:
Runtime error
Runtime error
"""Copyright PolyAI Limited.""" | |
import logging | |
import pdb | |
import sys | |
import traceback | |
from functools import wraps | |
from time import time | |
from typing import List | |
import torch | |
from .symbol_table import SymbolTable | |
def load_checkpoint(ckpt_path: str) -> dict:
    """Load a checkpoint's state dict, dropping vocoder parameters.

    Args:
        ckpt_path: path to a file readable by ``torch.load`` whose top-level
            object contains a ``"state_dict"`` entry.

    Returns:
        A new dict holding every entry of the stored state dict whose name
        does not start with ``"vocoder"``, in the stored order.

    Raises:
        KeyError: if the loaded object has no ``"state_dict"`` entry.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints that come
    # from a trusted source.
    # map_location="cpu" loads tensors onto the CPU regardless of the device
    # they were saved from.
    state_dict: dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    return {
        p_name: param
        for p_name, param in state_dict.items()
        if not p_name.startswith("vocoder")
    }
def breakpoint_on_error(fn):
    """Decorator that opens a post-mortem debugger when ``fn`` raises.

    Use as a wrapper around a function while debugging.

    Args:
        fn: the function to wrap.

    Returns:
        A wrapper that calls ``fn`` with the same arguments and returns its
        result. If ``fn`` raises an ``Exception``, the wrapper prints the
        exception type, value and traceback, starts ``pdb.post_mortem`` on
        the traceback, and then returns ``None`` (the exception is not
        re-raised).
    """

    @wraps(fn)  # keep fn's __name__/__doc__ visible on the wrapper
    def inner(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            # Standard python way of creating a breakpoint on error
            extype, value, tb = sys.exc_info()
            print(f"extype={extype},\nvalue={value}")
            traceback.print_exc()
            pdb.post_mortem(tb)

    return inner
def measure_duration(f):
    """Decorator that logs, at DEBUG level, how long each call to ``f`` took.

    Args:
        f: the function to time.

    Returns:
        A wrapper that calls ``f`` with the same arguments, logs
        ``func:<name> took: <seconds> sec`` through the root logger, and
        returns ``f``'s result. Nothing is logged if ``f`` raises.
    """

    @wraps(f)  # keep f's __name__/__doc__ visible on the wrapper
    def wrap(*args, **kw):
        ts = time()
        result = f(*args, **kw)
        te = time()
        # Pass the values as logging args so the message is only formatted
        # when DEBUG records are actually emitted.
        logging.debug("func:%r took: %2.4f sec", f.__name__, te - ts)
        return result

    return wrap
def split_metapath(in_paths: List[str]) -> List[str]:
    """Return a new list containing the items of ``in_paths`` in order.

    NOTE(review): despite the name, no splitting or filtering is performed
    here; every input path is passed through unchanged. Confirm against
    callers whether more is expected.

    Args:
        in_paths: the paths to collect.

    Returns:
        A new list with the same items as ``in_paths``; the input list is
        not modified.
    """
    return list(in_paths)