vall-e/vall_e/utils/io.py

88 lines
2.3 KiB
Python

import torch
import json
from pathlib import Path
from safetensors import safe_open as sft_load
from safetensors.torch import save_file as sft_save
def coerce_path(path):
    """Return ``path`` as a ``pathlib.Path``.

    An existing ``Path`` instance is handed back unchanged (the very same
    object); any other value is passed to the ``Path`` constructor.
    """
    if isinstance(path, Path):
        return path
    return Path(path)
def pick_path(path, *suffixes):
    """Return the first existing variant of ``path`` among ``suffixes``.

    Each suffix is tried in the order given by swapping it in with
    ``path.with_suffix``; the first candidate that exists on disk is
    returned. When no candidate exists (or no suffixes were given), the
    original ``path`` is returned untouched.

    ``path`` must provide ``with_suffix`` (e.g. a ``pathlib.Path``) whenever
    at least one suffix is supplied.
    """
    # Lazy generator: candidates are built and probed one at a time, in order.
    candidates = (path.with_suffix(ext) for ext in suffixes)
    return next((found for found in candidates if found.exists()), path)
def is_dict_of(d, t):
    """Return True when ``d`` is a dict whose values are all instances of ``t``.

    Args:
        d: Object to inspect. Anything that is not a ``dict`` yields False.
        t: A type, or tuple of types, as accepted by ``isinstance``.

    Returns:
        True if ``d`` is a dict and every value is an instance of ``t``.
        An empty dict yields True, since there is no offending value.
    """
    if not isinstance(d, dict):
        return False
    # Check against the requested type; it used to be hard-coded to
    # torch.Tensor, which silently ignored the ``t`` argument.
    return all(isinstance(v, t) for v in d.values())
# handles converting the usual pth state_dict into just the dict with the tensors + a dict of JSON strings, for safetensors
def state_dict_to_tensor_metadata(data: dict, module_key=None):
    """Split a checkpoint dict into ``(tensors, metadata)`` for safetensors.

    Args:
        data: Either a flat state dict (every value is a ``torch.Tensor``) or
            a wrapper dict holding one state dict plus other entries.
        module_key: Key in ``data`` whose value is the state dict. When
            falsy, the key is auto-detected: every entry that is itself a
            dict of tensors is a candidate and the last one found wins.

    Returns:
        A ``(tensors, metadata)`` tuple. ``metadata`` is ``None`` when
        ``data`` is already a flat state dict. Otherwise it maps each
        remaining key to the JSON string of its value; entries that cannot
        be JSON-encoded are left out.

    Raises:
        Exception: No state dict could be located in ``data``.
        KeyError: ``module_key`` was given but is not a key of ``data``.
    """
    # Already a plain state dict: nothing to separate.
    if is_dict_of(data, torch.Tensor):
        return data, None

    target = module_key
    autodetect = not target
    metadata = {}

    # Always walk the entries so metadata is collected even when the caller
    # names the state-dict key explicitly (it used to be dropped in that case).
    for key, value in data.items():
        if autodetect:
            if is_dict_of(value, torch.Tensor):
                # Candidate state dict; keep scanning to gather the metadata.
                target = key
                continue
        elif key == target:
            # The explicitly requested state dict is returned as tensors.
            continue

        try:
            metadata[key] = json.dumps(value)
        except (TypeError, ValueError):
            # Best effort: values JSON cannot represent (tensors, custom
            # objects, circular structures) are skipped.
            continue

    if not target:
        # The old message interpolated an undefined ``path`` and therefore
        # raised NameError instead of this error.
        raise Exception(
            'Requesting to save safetensors of a state dict, but state dict '
            f'contains no key of torch.Tensor: {list(data.keys())}'
        )

    return data[target], metadata
# File suffixes handled through safetensors rather than torch's pickle format.
# ".safetensors" is the library's canonical extension; the other two are kept
# for backward compatibility with existing checkpoints.
_SAFETENSORS_SUFFIXES = (".safetensor", ".safetensors", ".sft")


def torch_save(data, path, module_key=None):
    """Save ``data`` to ``path``, choosing the format from the file suffix.

    Args:
        data: Object to save. For safetensors suffixes it must be a state
            dict, or a dict wrapping one (see
            ``state_dict_to_tensor_metadata``).
        path: Destination, as a ``str`` or ``Path``.
        module_key: Key of the state dict inside ``data``; forwarded to
            ``state_dict_to_tensor_metadata``. Only used for safetensors.

    Returns:
        Whatever the underlying saver (``sft_save`` or ``torch.save``)
        returns.
    """
    path = coerce_path(path)

    if path.suffix in _SAFETENSORS_SUFFIXES:
        tensors, metadata = state_dict_to_tensor_metadata(data, module_key=module_key)
        return sft_save(tensors, path, metadata)

    return torch.save(data, path)


def torch_load(path, device="cpu", framework="pt", unsafe=True, load_metadata=True, module_key="module"):
    """Load a checkpoint from ``path``, choosing the format from the suffix.

    Args:
        path: Source file, as a ``str`` or ``Path``.
        device: Device the tensors are mapped to.
        framework: Framework identifier passed to the safetensors reader.
        unsafe: For non-safetensors files, passed to ``torch.load`` as
            ``weights_only=not unsafe``.
        load_metadata: For safetensors files, also read the header metadata
            and merge it into the result.
        module_key: Key under which the tensors are placed when
            ``load_metadata`` is true.

    Returns:
        For safetensors files: the flat tensor dict when ``load_metadata``
        is false, otherwise ``{module_key: tensors}`` merged with the
        metadata entries (JSON-decoded where possible, raw strings
        otherwise). For any other suffix: the result of ``torch.load``.
    """
    path = coerce_path(path)

    if path.suffix not in _SAFETENSORS_SUFFIXES:
        # NOTE(security): with the default unsafe=True this performs a full
        # unpickle (weights_only=False), which can execute arbitrary code.
        # Only load files from trusted sources, or pass unsafe=False.
        return torch.load(path, map_location=torch.device(device), weights_only=not unsafe)

    with sft_load(path, framework=framework, device=device) as f:
        state_dict = {key: f.get_tensor(key) for key in f.keys()}
        # metadata() yields None for files written without header metadata
        # (e.g. a flat state dict saved by torch_save); treat that as empty.
        raw_metadata = (f.metadata() or {}) if load_metadata else None

    if raw_metadata is None:
        return state_dict

    # Decode into a fresh dict instead of mutating the one being iterated.
    metadata = {}
    for key, value in raw_metadata.items():
        try:
            metadata[key] = json.loads(value)
        except (TypeError, ValueError):
            # Not JSON (e.g. written by another tool): keep the raw value.
            metadata[key] = value

    return {module_key: state_dict} | metadata