added safetensors support (with metadata) and routed everything that used torch.load/torch.save through it
parent 6a733eb2ed
commit c09133d00f
setup.py (+1)
@@ -53,6 +53,7 @@ setup(
         # HF bloat
         "tokenizers",
         "transformers",
+        "safetensors",
 
         # training bloat
         "auraloss[all]", # [all] is needed for MelSTFTLoss
@@ -716,6 +716,8 @@ class Config(BaseConfig):
 
     audio_backend: str = "vocos" # audio backend to use "encodec" | "vocos" | "dac"
 
+    weights_format: str = "pth" # "pth" | "sft"
+
     @property
     def model(self):
         for i, model in enumerate(self.models):
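Not part of the diff: the new knob just swaps the extension wherever a weights path is built, as the later hunks show. A minimal sketch:

    # cfg.weights_format selects the on-disk container for dumped weights:
    # "pth" keeps torch.save pickles, "sft" switches to safetensors
    load_path = cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}"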
@@ -882,10 +884,14 @@ class Config(BaseConfig):
 try:
     from transformers import PreTrainedTokenizerFast
 
-    tokenizer_path = cfg.rel_path / cfg.tokenizer_path
-    if not tokenizer_path.exists():
+    tokenizer_path = cfg.rel_path / cfg.tokenizer_path if cfg.yaml_path is not None else None
+    if tokenizer_path and not tokenizer_path.exists():
         tokenizer_path = Path("./data/") / cfg.tokenizer_path
 
-    cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
+    if tokenizer_path and tokenizer_path.exists():
+        cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
+    else:
+        cfg.tokenizer = NaiveTokenizer()
 except Exception as e:
     cfg.tokenizer = NaiveTokenizer()
+    print("Error while parsing tokenizer:", e)
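Not part of the diff: the rewritten block resolves the tokenizer in three steps. A hedged sketch (resolve_tokenizer_path is a hypothetical helper, not in the codebase):

    from pathlib import Path

    def resolve_tokenizer_path( rel_path, tokenizer_path, has_yaml ):
        # 1. path relative to the YAML config, only when one is loaded
        candidate = rel_path / tokenizer_path if has_yaml else None
        # 2. fall back to the bundled ./data/ copy
        if candidate and not candidate.exists():
            candidate = Path("./data/") / tokenizer_path
        # 3. None signals the caller to fall back to NaiveTokenizer()
        return candidate if candidate and candidate.exists() else None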
@@ -14,6 +14,7 @@ from .config import cfg
 from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file, decode as decode_qnt, encode as encode_qnt
 from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size
+from .utils.io import torch_save, torch_load
 
 from collections import defaultdict
 from functools import cache, cached_property
@@ -739,7 +740,7 @@ class Dataset(_Dataset):
             "samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
             "spkr_samplers": { name: sampler.get_state() for name, sampler in self.spkr_samplers.items() },
         }
-        torch.save(state_dict, path)
+        torch_save(state_dict, path)
 
     def load_state_dict(self, path = None):
         if path is None:
@@ -748,7 +749,7 @@ class Dataset(_Dataset):
         if not path.exists():
             return
 
-        state_dict = torch.load(path)
+        state_dict = torch_load(path)
         if self.sampler_type == "path":
            state_dict = self.sampler.set_state(state_dict)
         else:
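Not part of the diff: because the new helpers dispatch on the path's suffix (see vall_e/utils/io.py below), the sampler-state round-trip keeps working for either container. A sketch, assuming the enclosing save method is exposed as save_state_dict:

    dataset.save_state_dict(Path("sampler.pth"))  # plain torch.save pickle
    dataset.save_state_dict(Path("sampler.sft"))  # safetensors + JSON metadata
    dataset.load_state_dict(Path("sampler.sft"))  # torch_load picks the matching reader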
@@ -12,6 +12,7 @@ from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
 
 from ..models import get_models, get_model
 from ..utils import wrapper as ml
+from ..utils.io import torch_save, torch_load
 from ..models.lora import apply_lora, lora_load_state_dict
 
 import torch
@@ -42,7 +43,7 @@ def load_engines(training=True):
 
         checkpoint_path = cfg.ckpt_dir / name / "latest"
         # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
-        load_path = cfg.ckpt_dir / name / "fp32.pth"
+        load_path = cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}"
 
         # actually use the lora-specific checkpoint if available
         if cfg.lora is not None:
@@ -51,15 +52,15 @@ def load_engines(training=True):
         # to handle the issue of training with deepspeed, but inferencing with local
         if checkpoint_path.exists() and backend == "local":
             tag = open(checkpoint_path).read()
-            checkpoint_path = checkpoint_path.parent / tag / "state.pth"
+            checkpoint_path = checkpoint_path.parent / tag / f"state.{cfg.weights_format}"
 
         if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
-            print("Checkpoint missing, but weights found.")
+            print("Checkpoint missing, but weights found:", load_path)
             loads_state_dict = True
 
         # load state early
         if loads_state_dict:
-            state = torch.load(load_path, map_location=torch.device(cfg.device))
+            state = torch_load(load_path, device=cfg.device)
 
             # check if config is defined in state, and re-initialize the model
             if "config" in state and False:
@@ -196,11 +197,11 @@ def load_engines(training=True):
 
         # load lora weights if exists
         if cfg.lora is not None:
-            lora_path = cfg.ckpt_dir / cfg.lora.full_name / "lora.pth"
+            lora_path = cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}"
             if lora_path.exists():
                 print( "Loaded LoRA state dict:", lora_path )
 
-                state = torch.load(lora_path, map_location=torch.device(cfg.device))
+                state = torch_load(lora_path, device=cfg.device)
                 state = state['lora' if 'lora' in state else 'module']
                 lora_load_state_dict( model, state )
 
@@ -29,6 +29,7 @@ def default_feeder(engine, batch):
 from ..config import cfg
 from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
 from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size, cleanup_distributed
+from ..utils.io import torch_save, torch_load
 from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
 
 import logging
@@ -136,10 +137,10 @@ class Engine():
             lora, module = lora_get_state_dict( module, split = True )
             save_dir = cfg.ckpt_dir / cfg.lora.full_name
 
-        save_path = save_dir / tag / "state.pth"
+        save_path = save_dir / tag / f"state.{cfg.weights_format}"
         save_path.parent.mkdir(parents=True, exist_ok=True)
 
-        torch.save({
+        torch_save({
             "module": module,
             "lora": lora,
             "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
@@ -170,12 +171,12 @@ class Engine():
 
         tag = open(tag_path).read()
 
-        load_path = load_dir / tag / "state.pth"
+        load_path = load_dir / tag / f"state.{cfg.weights_format}"
 
         if not load_path.exists():
             return
 
-        state = torch.load(load_path, map_location=torch.device(cfg.device))
+        state = torch_load(load_path, device=cfg.device)
 
         self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
         self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
@@ -187,10 +188,10 @@ class Engine():
         load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
 
         if load_optimizer_states:
-            self.optimizer.load_state_dict(state['optimizer']) #, map_location=torch.device(cfg.device))
+            self.optimizer.load_state_dict(state['optimizer']) #, device=cfg.device)
 
         if load_lr_scheduler_states:
-            self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, map_location=torch.device(cfg.device))
+            self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, device=cfg.device)
 
         if 'lora' in state:
             lora_load_state_dict( self.module, state['lora'] )
@@ -324,17 +325,25 @@ class Engines(dict[str, Engine]):
         for engine in self.values():
             engine.dispatch_attribute(*args, **kwargs)
 
-    def export(self, userdata={}, callback=None, dtype=None):
+    def export(self, userdata={}, callback=None, dtype=None, format=None):
+        if not format:
+            format = cfg.weights_format
+        format = format.lower()
+
         if dtype is None:
             dtype = cfg.trainer.dtype
 
         for name, engine in self.items():
             module = engine.module.state_dict()
             lora = None
-            save_path = cfg.ckpt_dir / name / "fp32.pth"
+            save_path = cfg.ckpt_dir / name / f"fp32.{format}"
             config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
 
+            # coerce
+            if not isinstance(config, dict):
+                config = config.__dict__
+            if not isinstance(config['experimental'], dict):
+                config['experimental'] = config['experimental'].__dict__
+
             # safety
             for k, v in module.items():
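Not part of the diff: since the new format argument falls back to cfg.weights_format, these two calls are equivalent whenever the config already requests safetensors:

    engines.export(userdata={"symmap": get_phone_symmap()}, format="sft")
    engines.export(userdata={"symmap": get_phone_symmap()})  # None -> cfg.weights_format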
@@ -342,7 +351,7 @@ class Engines(dict[str, Engine]):
 
             if cfg.lora is not None:
                 lora, module = lora_get_state_dict( module, split = True )
-                save_path = cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth"
+                save_path = cfg.ckpt_dir / cfg.lora.full_name / f"fp32.{format}"
 
             state_dict = {
                 'module': module,
@@ -363,7 +372,7 @@ class Engines(dict[str, Engine]):
             if callback:
                 state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
 
-            torch.save(state_dict, save_path)
+            torch_save(state_dict, save_path)
             print(f"Exported {name} to {save_path}")
 
     def save_checkpoint(self, tag=None):
@@ -7,6 +7,7 @@ from .data import get_phone_symmap
 from .engines import load_engines
 from .config import cfg
 from .models.lora import lora_get_state_dict
+from .utils.io import torch_save, torch_load
 
 # stitches embeddings into one embedding & classifier => lm_head
 def convert_to_hf( state_dict, config = None, save_path = None ):
@@ -61,12 +62,16 @@ def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
     if dtype is None:
         dtype = cfg.inference.dtype
 
+    format = save_path.suffix[1:]
+
     lora = state_dict["lora"] if "lora" in state_dict else None
     # should always be included, but just in case
     if lora is None and "module" in state_dict:
         lora, module = lora_get_state_dict( state_dict["module"], split = True )
         state_dict["module"] = module
-        state_dict["lora"] = lora
+
+    if "lora" in state_dict:
+        state_dict["lora"] = None
 
     # should raise an exception since there's nothing to extract, or at least a warning
     if not lora:
@@ -74,8 +79,8 @@ def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
 
     # save lora specifically
     # should probably export other attributes, similar to what SD LoRAs do
-    save_path = save_path.parent / "lora.pth"
-    torch.save( {
+    save_path = save_path.parent / f"lora.{format}"
+    torch_save( {
         "module": lora,
         "config": cfg.lora.__dict__ if cfg.lora is not None else None,
     }, save_path )
@@ -109,8 +114,12 @@ def main():
     parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
     parser.add_argument("--split-classifiers", action='store_true', default=None) # splits classifier heads
     parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
+    parser.add_argument("--format", type=str, default="pth") # set target format to export weights under
     args, unknown = parser.parse_known_args()
 
+    if args.format.lower() not in ["sft", "safetensors", "pt", "pth"]:
+        raise Exception(f"Unknown requested format: {args.format}")
+
     if args.module_only:
         cfg.trainer.load_module_only = True
 
@@ -132,7 +141,7 @@ def main():
     cfg.inference.backend = cfg.trainer.backend
 
     engines = load_engines(training=False) # to ignore loading optimizer state
-    engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
+    engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback, format=args.format)
 
 if __name__ == "__main__":
     main()
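Not part of the diff: an export to safetensors would then look something like the following (the --yaml flag is assumed from the rest of the script, not shown in these hunks):

    python3 -m vall_e.export --yaml=./training/config.yaml --format=sft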
vall_e/utils/io.py (new file, +78)
@@ -0,0 +1,78 @@
+import torch
+import json
+
+from pathlib import Path
+from safetensors import safe_open as sft_load
+from safetensors.torch import save_file as sft_save
+
+def coerce_path( path ):
+    return path if isinstance( path, Path ) else Path(path)
+
+def is_dict_of( d, t ):
+    if not isinstance( d, dict ):
+        return False
+
+    return all([ isinstance(v, t) for v in d.values() ])
+
+# handles converting the usual pth state_dict into just the dict with the tensors + a dict of JSON strings, for safetensors
+def state_dict_to_tensor_metadata( data: dict, module_key=None ):
+    metadata = None
+
+    # is a state_dict, no need to coerce
+    if is_dict_of( data, torch.Tensor ):
+        return data, metadata
+
+    # is maybe a dict with a state dict + metadata, coerce it
+    metadata = {}
+    target = module_key
+    if not target:
+        for k, v in data.items():
+            # is a dict of tensors, our target
+            if is_dict_of( v, torch.Tensor ):
+                target = k
+                continue # continue to iterate to grab other metadata
+
+            # not a dict of tensors, put it as metadata
+            try:
+                metadata[k] = json.dumps(v)
+            except Exception as e:
+                pass
+
+    if not target:
+        raise Exception(f'Requesting to save safetensors of a state dict, but state dict contains no key of torch.Tensor: {list(data.keys())}')
+
+    return data[target], metadata
+
+def torch_save( data, path, module_key=None ):
+    path = coerce_path(path)
+    ext = path.suffix
+
+    if ext in [".safetensor", ".sft"]:
+        data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key )
+
+        return sft_save( data, path, metadata )
+
+    return torch.save( data, path )
+
+def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=True, module_key="module" ):
+    path = coerce_path(path)
+    ext = path.suffix
+
+    if ext in [".safetensor", ".sft"]:
+        state_dict = {}
+        with sft_load(path, framework=framework, device=device) as f:
+            for k in f.keys():
+                state_dict[k] = f.get_tensor(k)
+
+            if load_metadata:
+                metadata = f.metadata()
+                for k, v in metadata.items():
+                    try:
+                        metadata[k] = json.loads( v )
+                    except Exception as e:
+                        pass
+                state_dict = { module_key: state_dict } | metadata
+
+        return state_dict
+
+    return torch.load( path, map_location=torch.device(device), weights_only=not unsafe )
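Not part of the commit: a quick round-trip sketch of the new helpers, runnable with only torch and safetensors installed:

    import torch
    from vall_e.utils.io import torch_save, torch_load

    state = {
        "module": { "weight": torch.zeros(4, 4) },  # dict of tensors -> safetensors payload
        "symmap": { "a": 1 },                       # anything JSON-able -> metadata strings
    }
    torch_save(state, "example.sft")    # routed to safetensors.torch.save_file
    loaded = torch_load("example.sft")  # {"module": {...tensors...}} | decoded metadata
    assert torch.equal(loaded["module"]["weight"], state["module"]["weight"])
    assert loaded["symmap"] == {"a": 1}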