added safetensors support (with metadata) and routed everything that previously went through torch.load/torch.save into it

mrq 2024-08-03 23:15:20 -05:00
parent 6a733eb2ed
commit c09133d00f
7 changed files with 130 additions and 25 deletions

View File

@ -53,6 +53,7 @@ setup(
# HF bloat
"tokenizers",
"transformers",
"safetensors",
# training bloat
"auraloss[all]", # [all] is needed for MelSTFTLoss

View File

@ -716,6 +716,8 @@ class Config(BaseConfig):
audio_backend: str = "vocos" # audio backend to use: "encodec" | "vocos" | "dac"
weights_format: str = "pth" # "pth" | "sft"
@property
def model(self):
for i, model in enumerate(self.models):
@ -882,10 +884,14 @@ class Config(BaseConfig):
try:
from transformers import PreTrainedTokenizerFast
tokenizer_path = cfg.rel_path / cfg.tokenizer_path
if not tokenizer_path.exists():
tokenizer_path = cfg.rel_path / cfg.tokenizer_path if cfg.yaml_path is not None else None
if tokenizer_path and not tokenizer_path.exists():
tokenizer_path = Path("./data/") / cfg.tokenizer_path
if tokenizer_path and tokenizer_path.exists():
cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
else:
cfg.tokenizer = NaiveTokenizer()
except Exception as e:
cfg.tokenizer = NaiveTokenizer()
print("Error while parsing tokenizer:", e)
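Editor's note: the new weights_format field picks the extension used for exported weights and local checkpoints ("pth" for torch pickles, "sft" for safetensors), and the tokenizer lookup now only consults cfg.rel_path when a YAML config was actually loaded, before falling back to ./data/ and finally the NaiveTokenizer. A minimal sketch of how later code derives paths from it (names here are illustrative stand-ins, not the real config object):

# hedged sketch: how weight paths follow cfg.weights_format
from pathlib import Path

weights_format = "sft"          # stand-in for cfg.weights_format
ckpt_dir = Path("ckpt") / "model"  # stand-in for cfg.ckpt_dir / name
load_path = ckpt_dir / f"fp32.{weights_format}"  # e.g. ckpt/model/fp32.sft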

View File

@ -14,6 +14,7 @@ from .config import cfg
from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file, decode as decode_qnt, encode as encode_qnt
from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
from .utils.distributed import global_rank, local_rank, world_size
from .utils.io import torch_save, torch_load
from collections import defaultdict
from functools import cache, cached_property
@ -739,7 +740,7 @@ class Dataset(_Dataset):
"samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
"spkr_samplers": { name: sampler.get_state() for name, sampler in self.spkr_samplers.items() },
}
torch.save(state_dict, path)
torch_save(state_dict, path)
def load_state_dict(self, path = None):
if path is None:
@ -748,7 +749,7 @@ class Dataset(_Dataset):
if not path.exists():
return
state_dict = torch.load(path)
state_dict = torch_load(path)
if self.sampler_type == "path":
state_dict = self.sampler.set_state(state_dict)
else:

View File

@ -12,6 +12,7 @@ from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
from ..models import get_models, get_model
from ..utils import wrapper as ml
from ..utils.io import torch_save, torch_load
from ..models.lora import apply_lora, lora_load_state_dict
import torch
@ -42,7 +43,7 @@ def load_engines(training=True):
checkpoint_path = cfg.ckpt_dir / name / "latest"
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
load_path = cfg.ckpt_dir / name / "fp32.pth"
load_path = cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}"
# actually use the lora-specific checkpoint if available
if cfg.lora is not None:
@ -51,15 +52,15 @@ def load_engines(training=True):
# to handle the issue of training with deepspeed, but inferencing with local
if checkpoint_path.exists() and backend == "local":
tag = open(checkpoint_path).read()
checkpoint_path = checkpoint_path.parent / tag / "state.pth"
checkpoint_path = checkpoint_path.parent / tag / f"state.{cfg.weights_format}"
if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
print("Checkpoint missing, but weights found.")
print("Checkpoint missing, but weights found:", load_path)
loads_state_dict = True
# load state early
if loads_state_dict:
state = torch.load(load_path, map_location=torch.device(cfg.device))
state = torch_load(load_path, device=cfg.device)
# check if config is defined in state, and re-initialize the model
if "config" in state and False:
@ -196,11 +197,11 @@ def load_engines(training=True):
# load lora weights if exists
if cfg.lora is not None:
lora_path = cfg.ckpt_dir / cfg.lora.full_name / "lora.pth"
lora_path = cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}"
if lora_path.exists():
print( "Loaded LoRA state dict:", lora_path )
state = torch.load(lora_path, map_location=torch.device(cfg.device))
state = torch_load(lora_path, device=cfg.device)
state = state['lora' if 'lora' in state else 'module']
lora_load_state_dict( model, state )
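Editor's note: the load path is now format-agnostic; torch_load() (see vall_e/utils/io.py below) dispatches on the file suffix, so the same code handles pickled and safetensors weights. A hedged illustration of what the loaded object looks like in each case:

# paths are examples; behaviour follows the io.py helpers introduced in this commit
state = torch_load(cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}", device=cfg.device)
# fp32.pth: exactly the dict that was saved (module / lora / optimizer / userdata ...)
# fp32.sft: {"module": {name: tensor, ...}} merged with the JSON-decoded metadata entries
module_state = state["module"]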

View File

@ -29,6 +29,7 @@ def default_feeder(engine, batch):
from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size, cleanup_distributed
from ..utils.io import torch_save, torch_load
from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
import logging
@ -136,10 +137,10 @@ class Engine():
lora, module = lora_get_state_dict( module, split = True )
save_dir = cfg.ckpt_dir / cfg.lora.full_name
save_path = save_dir / tag / "state.pth"
save_path = save_dir / tag / f"state.{cfg.weights_format}"
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save({
torch_save({
"module": module,
"lora": lora,
"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
@ -170,12 +171,12 @@ class Engine():
tag = open(tag_path).read()
load_path = load_dir / tag / "state.pth"
load_path = load_dir / tag / f"state.{cfg.weights_format}"
if not load_path.exists():
return
state = torch.load(load_path, map_location=torch.device(cfg.device))
state = torch_load(load_path, device=cfg.device)
self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
@ -187,10 +188,10 @@ class Engine():
load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
if load_optimizer_states:
self.optimizer.load_state_dict(state['optimizer']) #, map_location=torch.device(cfg.device))
self.optimizer.load_state_dict(state['optimizer']) #, device=cfg.device)
if load_lr_scheduler_states:
self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, map_location=torch.device(cfg.device))
self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, device=cfg.device)
if 'lora' in state:
lora_load_state_dict( self.module, state['lora'] )
@ -324,17 +325,25 @@ class Engines(dict[str, Engine]):
for engine in self.values():
engine.dispatch_attribute(*args, **kwargs)
def export(self, userdata={}, callback=None, dtype=None):
def export(self, userdata={}, callback=None, dtype=None, format=None):
if not format:
format = cfg.weights_format
format = format.lower()
if dtype is None:
dtype = cfg.trainer.dtype
for name, engine in self.items():
module = engine.module.state_dict()
lora = None
save_path = cfg.ckpt_dir / name / "fp32.pth"
save_path = cfg.ckpt_dir / name / f"fp32.{format}"
config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
# coerce
if not isinstance(config, dict):
config = config.__dict__
if not isinstance(config['experimental'], dict):
config['experimental'] = config['experimental'].__dict__
# safety
for k, v in module.items():
@ -342,7 +351,7 @@ class Engines(dict[str, Engine]):
if cfg.lora is not None:
lora, module = lora_get_state_dict( module, split = True )
save_path = cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth"
save_path = cfg.ckpt_dir / cfg.lora.full_name / f"fp32.{format}"
state_dict = {
'module': module,
@ -363,7 +372,7 @@ class Engines(dict[str, Engine]):
if callback:
state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
torch.save(state_dict, save_path)
torch_save(state_dict, save_path)
print(f"Exported {name} to {save_path}")
def save_checkpoint(self, tag=None):
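Editor's note: Engines.export() now accepts an explicit format override on top of cfg.weights_format. A hedged usage sketch, mirroring how vall_e/export.py calls it below:

# export current weights as safetensors regardless of cfg.weights_format
engines = load_engines(training=False)
engines.export(userdata={"symmap": get_phone_symmap()}, format="sft")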

View File

@ -7,6 +7,7 @@ from .data import get_phone_symmap
from .engines import load_engines
from .config import cfg
from .models.lora import lora_get_state_dict
from .utils.io import torch_save, torch_load
# stitches embeddings into one embedding & classifier => lm_head
def convert_to_hf( state_dict, config = None, save_path = None ):
@ -61,12 +62,16 @@ def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
if dtype is None:
dtype = cfg.inference.dtype
format = save_path.suffix[1:] # e.g. "sft" from "fp32.sft"
lora = state_dict["lora"] if "lora" in state_dict else None
# should always be included, but just in case
if lora is None and "module" in state_dict:
lora, module = lora_get_state_dict( state_dict["module"], split = True )
state_dict["module"] = module
state_dict["lora"] = lora
if "lora" in state_dict:
state_dict["lora"] = None
# should raise an exception since there's nothing to extract, or at least a warning
if not lora:
@ -74,8 +79,8 @@ def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
# save lora specifically
# should probably export other attributes, similar to what SD LoRAs do
save_path = save_path.parent / "lora.pth"
torch.save( {
save_path = save_path.parent / f"lora.{format}"
torch_save( {
"module": lora,
"config": cfg.lora.__dict__ if cfg.lora is not None else None,
}, save_path )
@ -109,8 +114,12 @@ def main():
parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
parser.add_argument("--split-classifiers", action='store_true', default=None) # splits classifier heads
parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
parser.add_argument("--format", type=str, default="pth") # set target format to export weights under
args, unknown = parser.parse_known_args()
if args.format.lower() not in ["sft", "safetensors", "pt", "pth"]:
raise Exception(f"Unknown requested format: {args.format}")
if args.module_only:
cfg.trainer.load_module_only = True
@ -132,7 +141,7 @@ def main():
cfg.inference.backend = cfg.trainer.backend
engines = load_engines(training=False) # to ignore loading optimizer state
engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback, format=args.format)
if __name__ == "__main__":
main()
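Editor's note: with the new flag, exporting safetensors weights is just a matter of passing --format; the invocation below is hedged, assuming the exporter is run as a module (the other flags such as --lora and --dtype behave as before):

python3 -m vall_e.export --format sft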

vall_e/utils/io.py (new file, 78 lines added)
View File

@ -0,0 +1,78 @@
import torch
import json
from pathlib import Path
from safetensors import safe_open as sft_load
from safetensors.torch import save_file as sft_save
def coerce_path( path ):
return path if isinstance( path, Path ) else Path(path)
def is_dict_of( d, t ):
if not isinstance( d, dict ):
return False
return all([ isinstance(v, t) for v in d.values() ])
# handles converting the usual pth state_dict into just the dict with the tensors + a dict of JSON strings, for safetensors
def state_dict_to_tensor_metadata( data: dict, module_key=None ):
metadata = None
# is a state_dict, no need to coerce
if is_dict_of( data, torch.Tensor ):
return data, metadata
# is maybe a dict with a state dict + metadata, coerce it
metadata = {}
target = module_key
if not target:
for k, v in data.items():
# is a dict of tensors, our target
if is_dict_of( v, torch.Tensor ):
target = k
continue # continue to iterate to grab other metadata
# not a dict of tensors, put it as metadata
try:
metadata[k] = json.dumps(v)
except Exception as e:
pass
if not target:
raise Exception(f'Requesting to save safetensors of a state dict, but the state dict contains no dict of torch.Tensor values (keys: {list(data.keys())})')
return data[target], metadata
def torch_save( data, path, module_key=None ):
path = coerce_path(path)
ext = path.suffix
if ext in [".safetensor", ".safetensors", ".sft"]:
data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key )
return sft_save( data, path, metadata )
return torch.save( data, path )
def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=True, module_key="module" ):
path = coerce_path(path)
ext = path.suffix
if ext in [".safetensor", ".safetensors", ".sft"]:
state_dict = {}
with sft_load(path, framework=framework, device=device) as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if load_metadata:
metadata = f.metadata() or {}
for k, v in metadata.items():
try:
metadata[k] = json.loads( v )
except Exception as e:
pass
state_dict = { module_key: state_dict } | metadata
return state_dict
return torch.load( path, map_location=torch.device(device), weights_only=not unsafe )
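Editor's note: a hedged round-trip sketch of the helpers above; non-tensor entries of the state dict are stored as JSON strings in the safetensors metadata and decoded back on load (the file name and contents are illustrative):

import torch
from vall_e.utils.io import torch_save, torch_load

state = {
    "module": { "weight": torch.zeros(4, 4) },  # the dict of tensors that actually gets stored
    "config": { "dim": 4 },                     # JSON-encoded into the file's metadata
}
torch_save( state, "example.sft" )

loaded = torch_load( "example.sft", device="cpu" )
# loaded["module"]["weight"] is the tensor, loaded["config"] == {"dim": 4}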