"""JSON and checkpoint I/O helpers.

JSON goes through ``orjson`` when it is installed and falls back to the
standard library ``json`` module otherwise. Checkpoints are read and written
with ``safetensors`` or ``torch.save`` / ``torch.load`` depending on the file
suffix.
"""

import json
from pathlib import Path

import torch
from safetensors import safe_open as sft_load
from safetensors.torch import save_file as sft_save

try:
    import orjson as json
    use_orjson = True
except ImportError:
    # The flag must be False here: the orjson-only code paths (OPT_INDENT_2,
    # bytes output, binary file modes) do not work with the stdlib module.
    import json
    use_orjson = False

from .utils import truncate_json

# File suffixes that are handled by safetensors instead of torch.save/load.
_SAFETENSORS_SUFFIXES = (".safetensor", ".safetensors", ".sft")


def json_stringify(data, truncate=False, pretty=False, raw=False):
    """Serialize *data* to JSON.

    Args:
        data: Any object the active JSON backend can serialize.
        truncate: When true, the serialized output is passed through
            ``truncate_json`` and that result is returned as-is.
        pretty: When true, indent the output (two spaces with orjson, a tab
            with the stdlib backend).
        raw: Only meaningful with ``pretty=True`` under orjson: return the
            backend's ``bytes`` instead of decoding them to ``str``.

    Returns:
        With ``pretty=True``: ``str``, or ``bytes`` when ``raw`` is set and
        orjson is in use. Otherwise whatever the backend's ``dumps`` returns
        (``bytes`` for orjson, ``str`` for the stdlib).
    """
    if truncate:
        return truncate_json(json.dumps(data))

    if pretty:
        if use_orjson:
            encoded = json.dumps(data, option=json.OPT_INDENT_2)
            return encoded if raw else encoded.decode("utf-8")
        # The stdlib backend already returns str; there is nothing to decode.
        return json.dumps(data, indent="\t")

    # NOTE: kept as the backend's native type (bytes under orjson) so existing
    # callers that decode the result themselves keep working.
    return json.dumps(data)


def json_parse(string):
    """Parse a JSON document given as ``str`` or ``bytes``."""
    return json.loads(string)


def json_read(path, default=None):
    """Read and parse the JSON file at *path*.

    Returns *default* when the file does not exist.
    """
    path = coerce_path(path)

    if not path.exists():
        return default

    if use_orjson:
        with open(path, "rb") as f:
            return json_parse(f.read())

    with open(path, "r", encoding="utf-8") as f:
        return json_parse(f.read())


def json_write(data, path, **kwargs):
    """Serialize *data* and write it to *path*.

    Extra keyword arguments are forwarded to ``json_stringify``.
    """
    path = coerce_path(path)

    # Serialize before opening the file: opening for writing truncates it, so
    # a serialization error would otherwise wipe out the existing contents.
    payload = json_stringify(data, raw=use_orjson, **kwargs)

    # Pick the mode from the payload itself so bytes and str are both written
    # correctly regardless of which backend or option produced them.
    if isinstance(payload, bytes):
        with open(path, "wb") as f:
            f.write(payload)
    else:
        with open(path, "w", encoding="utf-8") as f:
            f.write(payload)


def coerce_path(path):
    """Return *path* as a ``pathlib.Path`` (unchanged if it already is one)."""
    return path if isinstance(path, Path) else Path(path)


def pick_path(path, *suffixes):
    """Return the first existing variant of *path* among *suffixes*.

    Each suffix replaces the suffix of *path* in turn (``Path.with_suffix``).
    When none of the candidates exists, *path* itself is returned.
    """
    path = coerce_path(path)

    for suffix in suffixes:
        candidate = path.with_suffix(suffix)
        if candidate.exists():
            return candidate

    return path


def is_dict_of(d, t):
    """Return True when *d* is a dict whose values are all instances of *t*.

    An empty dict satisfies this vacuously.
    """
    if not isinstance(d, dict):
        return False

    return all(isinstance(v, t) for v in d.values())


def _metadata_value_to_str(value):
    """Encode one metadata value as a string, or return None if it cannot be.

    Strings pass through untouched, bytes are decoded as UTF-8, and anything
    else is JSON-encoded.
    """
    if isinstance(value, str):
        return value

    try:
        if isinstance(value, bytes):
            return value.decode("utf-8")

        encoded = json_stringify(value)
        return encoded.decode("utf-8") if isinstance(encoded, bytes) else encoded
    except (TypeError, ValueError, RecursionError):
        # Best effort: values that cannot be represented as JSON (tensors,
        # optimizer objects, ...) are left out rather than aborting the save.
        return None


def state_dict_to_tensor_metadata(data: dict, module_key=None):
    """Split a checkpoint dict into tensors and string metadata.

    Converts the usual ``.pth``-style checkpoint into the pair expected by
    safetensors: a flat dict of tensors plus a ``str -> str`` metadata dict.

    Args:
        data: Either a plain state dict (every value is a tensor) or a dict
            holding a state dict under one key alongside other entries.
        module_key: Key of the state dict inside *data*. When falsy, the last
            entry that is a non-empty dict of tensors is used.

    Returns:
        ``(tensors, metadata)``. For a plain state dict, *data* itself and an
        empty metadata dict.

    Raises:
        ValueError: No *module_key* was given and no entry of *data* is a
            non-empty dict of tensors.
        KeyError: *module_key* was given but is not a key of *data*.
    """
    metadata = {}

    # Already a plain state dict, nothing to split.
    if is_dict_of(data, torch.Tensor):
        return data, metadata

    explicit_target = bool(module_key)
    target = module_key if explicit_target else None

    for key, value in data.items():
        if explicit_target and key == module_key:
            continue

        # Non-empty dicts of tensors are weight containers, never metadata.
        # Empty dicts are excluded on purpose: they would vacuously count as
        # "all tensors" and could otherwise be picked as the weights.
        if is_dict_of(value, torch.Tensor) and len(value) > 0:
            if not explicit_target:
                target = key
            continue

        encoded = _metadata_value_to_str(value)
        if encoded is not None:
            metadata[key] = encoded

    if target is None:
        raise ValueError(
            "Requesting to save safetensors of a state dict, but state dict "
            f"contains no key of torch.Tensor: {list(data.keys())}"
        )

    return data[target], metadata


def torch_save(data, path, module_key=None):
    """Save *data* to *path*, choosing the format from the file suffix.

    Paths ending in ``.safetensor``, ``.safetensors`` or ``.sft`` are written
    with safetensors after splitting *data* via
    ``state_dict_to_tensor_metadata``; everything else goes to ``torch.save``.
    """
    path = coerce_path(path)

    if path.suffix in _SAFETENSORS_SUFFIXES:
        tensors, metadata = state_dict_to_tensor_metadata(data, module_key=module_key)
        return sft_save(tensors, path, metadata)

    return torch.save(data, path)


def torch_load(path, device="cpu", framework="pt", unsafe=True, load_metadata=True, module_key="module"):
    """Load a checkpoint from *path*, choosing the format from the suffix.

    Args:
        path: Checkpoint location.
        device: Device the tensors are loaded onto.
        framework: Framework identifier passed to safetensors.
        unsafe: For non-safetensors files, passed to ``torch.load`` as
            ``weights_only=not unsafe``.
        load_metadata: For safetensors files, also read the file metadata.
        module_key: Key under which the tensors are stored in the result when
            metadata is loaded.

    Returns:
        For safetensors files: the flat tensor dict when ``load_metadata`` is
        false, otherwise ``{module_key: tensors, **metadata}`` with metadata
        values JSON-decoded where possible. For other files, whatever
        ``torch.load`` returns.
    """
    path = coerce_path(path)

    if path.suffix not in _SAFETENSORS_SUFFIXES:
        # NOTE(security): with unsafe=True (the default) weights_only is False,
        # so torch.load unpickles arbitrary objects. Only use it on trusted files.
        return torch.load(path, map_location=torch.device(device), weights_only=not unsafe)

    with sft_load(path, framework=framework, device=device) as f:
        tensors = {key: f.get_tensor(key) for key in f.keys()}

        if not load_metadata:
            return tensors

        # metadata() is None for files written without a metadata header.
        metadata = f.metadata() or {}

    loaded = {module_key: tensors}
    for key, value in metadata.items():
        # Never let a metadata entry replace the tensors themselves.
        if key == module_key:
            continue
        try:
            loaded[key] = json_parse(value)
        except (TypeError, ValueError):
            # Not JSON, keep the raw string.
            loaded[key] = value

    return loaded