# NOTE: header recovered from a web file-viewer page; it is not part of the module.
# Author: sayakpaul (HF staff)
# Commit: ddc8a59 ("add: files.")
# File size: 490 Bytes
import numpy as np
import torch
from typing import Dict
def run_assertion(
    orig_pt_state_dict: Dict[str, torch.Tensor],
    pt_state_dict_from_tf: Dict[str, torch.Tensor],
) -> None:
    """Check that two state dicts hold numerically matching tensors.

    Every key of ``orig_pt_state_dict`` is looked up in
    ``pt_state_dict_from_tf`` and the two tensors are compared with
    ``np.testing.assert_allclose`` using NumPy's default tolerances.
    Keys that exist only in ``pt_state_dict_from_tf`` are not inspected.

    Args:
        orig_pt_state_dict: Reference mapping of parameter name to tensor.
        pt_state_dict_from_tf: Mapping to validate against the reference.

    Raises:
        ValueError: If a key is missing from ``pt_state_dict_from_tf`` or a
            pair of tensors cannot be compared or does not match. The
            message names the offending parameter and the underlying error
            is attached as ``__cause__``.
    """
    for name, expected in orig_pt_state_dict.items():
        try:
            actual = pt_state_dict_from_tf[name]
            # detach().cpu() lets tensors that require grad or live on an
            # accelerator be converted; a bare .numpy() raises for those.
            np.testing.assert_allclose(
                expected.detach().cpu().numpy(),
                actual.detach().cpu().numpy(),
            )
        except Exception as err:
            # Any ordinary failure (mismatch, missing key, unconvertible
            # tensor) is reported as ValueError, as callers expect.
            # KeyboardInterrupt/SystemExit are deliberately not intercepted.
            raise ValueError(
                "There are problems in the parameter population process. "
                f"Cannot proceed :( (parameter: {name!r})"
            ) from err