|
import base64 |
|
import pickle |
|
from dataclasses import dataclass |
|
from typing import Dict, Optional, Tuple |
|
|
|
import safetensors.torch |
|
import torch |
|
|
|
from .aliases import PathOrStr |
|
|
|
# Public API: helpers for writing/reading arbitrary nested state dicts
# in the safetensors format (non-tensor values are pickled into tensors).
__all__ = [
    "state_dict_to_safetensors_file",
    "safetensors_file_to_state_dict",
]
|
|
|
|
|
@dataclass(eq=True, frozen=True)
class STKey:
    """Hashable key for one entry of a flattened state dict.

    ``keys`` is the path of nested-dict keys leading from the root to the
    value. ``value_is_pickled`` marks entries whose tensor payload is a
    pickled non-tensor object (stored as uint8 bytes) rather than a real
    tensor, so it can be un-pickled on load.
    """

    # Path of dict keys from the root dict to this value.
    keys: Tuple
    # True when the associated tensor holds pickled bytes, not tensor data.
    value_is_pickled: bool
|
|
|
|
|
def encode_key(key: STKey) -> str:
    """Serialize an :class:`STKey` into an ASCII string usable as a safetensors key.

    The key path and pickled-flag are pickled together and then wrapped in
    URL-safe base64 so the result is a plain ASCII string.
    """
    raw = pickle.dumps((key.keys, key.value_is_pickled))
    return base64.urlsafe_b64encode(raw).decode("ASCII")
|
|
|
|
|
def decode_key(key: str) -> STKey:
    """Inverse of :func:`encode_key`: turn a safetensors string key back into an STKey.

    NOTE(review): this calls ``pickle.loads`` on data embedded in the file,
    so decoding keys from an untrusted file can execute arbitrary code —
    only load files from trusted sources.
    """
    payload = base64.urlsafe_b64decode(key)
    keys, value_is_pickled = pickle.loads(payload)
    return STKey(keys, value_is_pickled)
|
|
|
|
|
def flatten_dict(d: Dict) -> Dict[STKey, torch.Tensor]:
    """Flatten a nested dict into a flat ``{STKey: tensor}`` mapping.

    Tensor leaves are kept as-is; nested dicts are flattened recursively with
    the key path recorded in ``STKey.keys``; any other value is pickled into
    a uint8 tensor and flagged with ``value_is_pickled=True`` so
    ``unflatten_dict`` can restore it.

    Fix: an *empty* nested dict previously produced no entries at all and so
    silently disappeared on a save/load round trip. Empty dicts are now sent
    through the pickle branch instead, which ``unflatten_dict`` already
    restores correctly.
    """
    result: Dict[STKey, torch.Tensor] = {}
    for key, value in d.items():
        if isinstance(value, torch.Tensor):
            result[STKey((key,), False)] = value
        elif isinstance(value, dict) and value:
            # Recurse, then prefix this key onto each inner key path.
            for inner_key, inner_value in flatten_dict(value).items():
                result[STKey((key,) + inner_key.keys, inner_key.value_is_pickled)] = inner_value
        else:
            # Non-tensor leaf (also covers empty dicts, which the recursive
            # branch would drop): pickle it into a byte tensor.
            # bytearray gives torch.frombuffer the writable buffer it requires.
            pickled = bytearray(pickle.dumps(value))
            pickled_tensor = torch.frombuffer(pickled, dtype=torch.uint8)
            result[STKey((key,), True)] = pickled_tensor
    return result
|
|
|
|
|
def unflatten_dict(d: Dict[STKey, torch.Tensor]) -> Dict: |
|
result: Dict = {} |
|
|
|
for key, value in d.items(): |
|
if key.value_is_pickled: |
|
value = pickle.loads(value.numpy().data) |
|
|
|
target_dict = result |
|
for k in key.keys[:-1]: |
|
new_target_dict = target_dict.get(k) |
|
if new_target_dict is None: |
|
new_target_dict = {} |
|
target_dict[k] = new_target_dict |
|
target_dict = new_target_dict |
|
target_dict[key.keys[-1]] = value |
|
|
|
return result |
|
|
|
|
|
def state_dict_to_safetensors_file(state_dict: Dict, filename: PathOrStr):
    """Write a (possibly nested) state dict to *filename* in safetensors format.

    Nested keys are flattened and encoded into safetensors string keys;
    non-tensor values are pickled into uint8 tensors so the whole structure
    can be stored in a single safetensors file.
    """
    flat = flatten_dict(state_dict)
    encoded = {encode_key(key): tensor for key, tensor in flat.items()}
    safetensors.torch.save_file(encoded, filename)
|
|
|
|
|
def safetensors_file_to_state_dict(filename: PathOrStr, map_location: Optional[str] = None) -> Dict:
    """Load a state dict previously saved with :func:`state_dict_to_safetensors_file`.

    ``map_location`` names the device tensors are loaded onto; ``None``
    defaults to ``"cpu"``. Keys are decoded and the nested dict structure is
    reconstructed, un-pickling any non-tensor values.

    NOTE(review): decoding involves ``pickle`` on file contents — only load
    files from trusted sources.
    """
    device = "cpu" if map_location is None else map_location
    raw = safetensors.torch.load_file(filename, device=device)
    decoded = {decode_key(name): tensor for name, tensor in raw.items()}
    return unflatten_dict(decoded)
|
|