Spaces:
Build error
Build error
import numpy as np | |
import torch | |
from typing import Dict | |
def run_assertion(
    orig_pt_state_dict: Dict[str, torch.Tensor],
    pt_state_dict_from_tf: Dict[str, torch.Tensor],
    *,
    rtol: float = 1e-7,
    atol: float = 0.0,
) -> None:
    """Check that every parameter in ``orig_pt_state_dict`` matches its counterpart.

    Each tensor in ``orig_pt_state_dict`` is compared element-wise against the
    tensor stored under the same key in ``pt_state_dict_from_tf`` using
    ``numpy.testing.assert_allclose``. Keys that appear only in
    ``pt_state_dict_from_tf`` are not checked.

    Args:
        orig_pt_state_dict: Reference tensors, keyed by parameter name.
        pt_state_dict_from_tf: Tensors to validate against the reference.
        rtol: Relative tolerance forwarded to ``assert_allclose``
            (default matches numpy's default).
        atol: Absolute tolerance forwarded to ``assert_allclose``
            (default matches numpy's default).

    Raises:
        ValueError: If a key is missing from ``pt_state_dict_from_tf``, the
            tensors differ beyond tolerance (values or shape), or they cannot
            be converted/compared. The underlying error is chained as
            ``__cause__`` and the offending key is named in the message.
    """
    for k, orig_tensor in orig_pt_state_dict.items():
        try:
            np.testing.assert_allclose(
                # detach()/cpu() allow tensors that require grad or live on an
                # accelerator to be converted; both are no-ops for plain CPU tensors.
                orig_tensor.detach().cpu().numpy(),
                pt_state_dict_from_tf[k].detach().cpu().numpy(),
                rtol=rtol,
                atol=atol,
            )
        # `Exception` rather than a bare `except:` so KeyboardInterrupt and
        # SystemExit propagate instead of being disguised as a ValueError.
        except Exception as err:
            raise ValueError(
                "There are problems in the parameter population process. "
                f"Cannot proceed :( (offending parameter: {k!r})"
            ) from err