File size: 484 Bytes
7730b56 2f1b075 7730b56 2f1b075 7730b56 2f1b075 7730b56 2f1b075 7730b56 2f1b075 0e6349c 7730b56 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import os
from glob import glob

import numpy as np
import torch
from safetensors.torch import save_file, load_file
# Default patch archive and shard pattern used when this file is run as a script.
PATCH_PATH = "589-20240113-071533.npz"
SHARD_PATTERN = "model*.safetensors"


def patch_state_dict(weights, patch_weights):
    """Overwrite tensors in *weights* with same-named arrays from *patch_weights*.

    Args:
        weights: mapping of tensor name -> ``torch.Tensor``; modified in place.
        patch_weights: mapping of tensor name -> ``numpy.ndarray`` (e.g. an
            ``NpzFile`` or a plain dict); must support ``in`` and ``[]``.

    Returns:
        List of the keys that were replaced, in ``weights`` iteration order.

    Raises:
        ValueError: if a patch array's shape differs from the tensor it replaces.
    """
    patched = []
    for key, original in weights.items():
        if key not in patch_weights:
            continue
        print(f"patching {key}")
        replacement = torch.from_numpy(patch_weights[key])
        if replacement.shape != original.shape:
            raise ValueError(
                f"shape mismatch for {key}: checkpoint {tuple(original.shape)} "
                f"vs patch {tuple(replacement.shape)}"
            )
        # Keep the checkpoint's dtype: npz cannot hold bfloat16, so patches often
        # arrive as float32. Also force C-contiguity, which safetensors requires
        # when serialising.
        weights[key] = replacement.to(dtype=original.dtype).contiguous()
        # Reassigning an existing key does not resize the dict, so this is safe
        # while iterating over items().
        patched.append(key)
    return patched


def patch_checkpoints(patch_path=PATCH_PATH, pattern=SHARD_PATTERN, prefix="patched_"):
    """Apply the arrays in an ``.npz`` archive to every matching safetensors shard.

    Each shard matching *pattern* is loaded, patched via ``patch_state_dict``,
    and written next to the original with *prefix* prepended to its file name,
    carrying ``{"format": "pt"}`` metadata. Patch keys that match no tensor in
    any shard are reported at the end.

    Args:
        patch_path: path of the ``.npz`` archive holding replacement arrays.
        pattern: ``glob`` pattern selecting the input shards.
        prefix: string prepended to each output file's base name.
    """
    # Context manager closes the zip file handle that np.load keeps open for .npz.
    with np.load(patch_path) as patch_weights:
        unused = set(patch_weights.files)
        for file in sorted(glob(pattern)):  # sorted -> deterministic order
            print(f"{file=}")
            weights = load_file(file)
            unused.difference_update(patch_state_dict(weights, patch_weights))
            # Prefix only the base name so patterns containing directories work.
            head, tail = os.path.split(file)
            save_file(weights, os.path.join(head, prefix + tail), metadata={"format": "pt"})
    if unused:
        print(f"warning: {len(unused)} patch key(s) matched no tensor: {sorted(unused)}")


if __name__ == "__main__":
    patch_checkpoints()
|