import numpy as np | |
import torch | |
from glob import glob | |
from safetensors.torch import save_file, load_file | |
# Overwrite selected tensors in every "model*.safetensors" shard in the current
# directory with same-named arrays from an .npz archive, writing each result to
# "patched_<shard name>". The input shards themselves are not modified.
# The archive is opened as a context manager so its file handle is released.
with np.load("589-20240113-071533.npz") as patch_weights:
    # Archive entries never matched by any shard are reported at the end, so a
    # misspelled or renamed key does not get silently dropped.
    unmatched = set(patch_weights.files)
    for file in sorted(glob("model*.safetensors")):
        print(f"{file=}")
        weights = load_file(file)
        # Iterate over a snapshot of the keys because values are reassigned below.
        for k in list(weights):
            if k not in patch_weights:
                continue
            print(f"patching {k}")
            original = weights[k]
            replacement = torch.from_numpy(patch_weights[k])
            if tuple(replacement.shape) != tuple(original.shape):
                raise ValueError(
                    f"shape mismatch for {k}: model has {tuple(original.shape)}, "
                    f"patch has {tuple(replacement.shape)}"
                )
            # numpy has no bfloat16, so the archive may store a different dtype
            # than the shard; cast back to keep the checkpoint's dtype uniform.
            # safetensors also refuses non-contiguous tensors.
            weights[k] = replacement.to(original.dtype).contiguous()
            unmatched.discard(k)
        save_file(weights, "patched_" + file, metadata={"format": "pt"})
    if unmatched:
        print(f"warning: patch keys not found in any shard: {sorted(unmatched)}")