88 lines
2.3 KiB
Python
88 lines
2.3 KiB
Python
import torch
|
|
import json
|
|
|
|
from pathlib import Path
|
|
from safetensors import safe_open as sft_load
|
|
from safetensors.torch import save_file as sft_save
|
|
|
|
def coerce_path(path):
    """Return ``path`` as a ``pathlib.Path``.

    An existing ``Path`` instance is handed back unchanged (same object);
    anything else is passed to the ``Path`` constructor.
    """
    if isinstance(path, Path):
        return path
    return Path(path)
|
|
|
|
def pick_path(path, *suffixes):
    """Return the first existing variant of ``path`` among the given suffixes.

    Each suffix is tried in order via ``path.with_suffix(suffix)``; the first
    candidate that exists on disk is returned. If none exists (or no suffixes
    are given), ``path`` itself is returned unchanged.

    ``path`` must already provide ``with_suffix`` (e.g. a ``pathlib.Path``).
    """
    # Lazy generator: candidates are built and probed one at a time, in order.
    candidates = (path.with_suffix(ext) for ext in suffixes)
    return next((found for found in candidates if found.exists()), path)
|
|
|
|
def is_dict_of(d, t):
    """Return True if ``d`` is a dict whose values are all instances of ``t``.

    Args:
        d: Object to inspect.
        t: A type, or tuple of types, as accepted by ``isinstance``.

    Returns:
        ``False`` if ``d`` is not a dict; otherwise whether every value is an
        instance of ``t``. An empty dict yields ``True``.
    """
    if not isinstance(d, dict):
        return False

    # Check against the requested type `t` rather than a hard-coded one, so
    # the function does what its signature promises.
    return all(isinstance(v, t) for v in d.values())
|
|
|
|
# Converts the usual .pth-style checkpoint dict into a plain dict of tensors plus a dict of JSON strings (metadata), as needed for safetensors
|
|
def state_dict_to_tensor_metadata(data: dict, module_key=None):
    """Split a checkpoint dict into a tensor dict and JSON-string metadata.

    Args:
        data: Either a flat state dict (every value is a ``torch.Tensor``) or
            a dict that holds such a state dict under one key alongside other
            entries.
        module_key: Key of ``data`` holding the state dict. When falsy, the
            key is auto-detected by scanning ``data`` for a dict of tensors,
            and every other entry that can be JSON-encoded is collected as
            metadata. If several values are dicts of tensors, the last one
            encountered is used.

    Returns:
        A ``(tensors, metadata)`` tuple. ``metadata`` is ``None`` when
        ``data`` is already a flat state dict; otherwise it is a dict mapping
        keys to JSON strings. It is empty when ``module_key`` is given
        explicitly, because no scan is performed in that case.

    Raises:
        Exception: If ``module_key`` is falsy and no value of ``data`` is a
            dict of tensors.
        KeyError: If ``module_key`` is given but is not a key of ``data``.
    """
    # Already a flat state dict: nothing to coerce, no metadata to report.
    if is_dict_of(data, torch.Tensor):
        return data, None

    # Possibly a dict with a state dict plus extra entries: separate them.
    metadata = {}
    target = module_key
    if not target:
        for k, v in data.items():
            # A dict of tensors is the payload we are looking for.
            if is_dict_of(v, torch.Tensor):
                target = k
                continue  # keep iterating to pick up the remaining metadata

            # Everything else becomes metadata, on a best-effort basis:
            # entries that cannot be JSON-encoded are deliberately dropped.
            try:
                metadata[k] = json.dumps(v)
            except Exception:
                continue

    if not target:
        # Report the keys that were inspected. The previous message tried to
        # interpolate a `path` variable that does not exist in this function,
        # which raised NameError instead of this error.
        raise Exception(
            'Requesting to save safetensors of a state dict, but state dict '
            f'contains no key of torch.Tensor: {list(data.keys())}'
        )

    return data[target], metadata
|
|
|
|
def torch_save(data, path, module_key=None):
    """Save ``data`` to ``path``, choosing the format from the file suffix.

    Paths ending in ``.safetensor`` or ``.sft`` are written with safetensors:
    ``data`` is first split into tensors and metadata by
    ``state_dict_to_tensor_metadata`` (``module_key`` is forwarded to it).
    Any other suffix is written with ``torch.save``.

    Returns whatever the underlying save call returns.
    """
    destination = coerce_path(path)

    # Anything that is not a safetensors suffix goes through torch.save as-is.
    if destination.suffix not in (".safetensor", ".sft"):
        return torch.save(data, destination)

    tensors, metadata = state_dict_to_tensor_metadata(data, module_key=module_key)
    return sft_save(tensors, destination, metadata)
|
|
|
|
def torch_load(path, device="cpu", framework="pt", unsafe=True, load_metadata=True, module_key="module"):
    """Load a checkpoint from ``path``, choosing the format from the file suffix.

    Args:
        path: File to load; coerced to a ``pathlib.Path``.
        device: Device passed to the safetensors reader, or used as
            ``map_location`` for ``torch.load``.
        framework: Framework identifier passed to the safetensors reader.
        unsafe: For non-safetensors files, ``torch.load`` is called with
            ``weights_only=not unsafe``.
        load_metadata: For safetensors files, whether to return the tensors
            nested under ``module_key`` together with the decoded metadata.
        module_key: Key under which the tensors are stored when
            ``load_metadata`` is true.

    Returns:
        For ``.safetensor`` / ``.sft`` files: a flat dict of tensors when
        ``load_metadata`` is false; otherwise ``{module_key: tensors}`` merged
        with the file's metadata, each value JSON-decoded where possible.
        For any other suffix: the result of ``torch.load``.
    """
    path = coerce_path(path)

    if path.suffix in (".safetensor", ".sft"):
        with sft_load(path, framework=framework, device=device) as f:
            state_dict = {k: f.get_tensor(k) for k in f.keys()}
            # metadata() yields None for files written without a metadata
            # header (e.g. a flat state dict saved with metadata=None), so
            # fall back to an empty dict instead of crashing on `.items()`.
            raw_metadata = (f.metadata() or {}) if load_metadata else None

        if raw_metadata is None:
            return state_dict

        metadata = {}
        for k, v in raw_metadata.items():
            # Best effort: values that are not valid JSON are kept as the
            # raw stored string.
            try:
                metadata[k] = json.loads(v)
            except Exception:
                metadata[k] = v

        return {module_key: state_dict} | metadata

    # NOTE(security): with the default unsafe=True this unpickles arbitrary
    # objects (weights_only=False). Only load files from trusted sources, or
    # pass unsafe=False.
    return torch.load(path, map_location=torch.device(device), weights_only=not unsafe)