diff --git a/setup.py b/setup.py
index de7f8c9..ba14d79 100755
--- a/setup.py
+++ b/setup.py
@@ -53,6 +53,7 @@ setup(
 		# HF bloat
 		"tokenizers",
 		"transformers",
+		"safetensors",
 
 		# training bloat
 		"auraloss[all]", # [all] is needed for MelSTFTLoss
diff --git a/vall_e/config.py b/vall_e/config.py
index d830183..ea3f0cf 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -716,6 +716,8 @@ class Config(BaseConfig):
 	audio_backend: str = "vocos" # audio backend to use "encodec" | "vocos" | "dac""
 
+	weights_format: str = "pth" # "pth" | "sft"
+
 	@property
 	def model(self):
 		for i, model in enumerate(self.models):
@@ -882,10 +884,14 @@ class Config(BaseConfig):
 		try:
 			from transformers import PreTrainedTokenizerFast
 
-			tokenizer_path = cfg.rel_path / cfg.tokenizer_path
-			if not tokenizer_path.exists():
+			tokenizer_path = cfg.rel_path / cfg.tokenizer_path if cfg.yaml_path is not None else None
+			if tokenizer_path and not tokenizer_path.exists():
 				tokenizer_path = Path("./data/") / cfg.tokenizer_path
-			cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
+
+			if tokenizer_path and tokenizer_path.exists():
+				cfg.tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_path))
+			else:
+				cfg.tokenizer = NaiveTokenizer()
 		except Exception as e:
 			cfg.tokenizer = NaiveTokenizer()
 			print("Error while parsing tokenizer:", e)
diff --git a/vall_e/data.py b/vall_e/data.py
index 4610884..e20be63 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -14,6 +14,7 @@ from .config import cfg
 from .emb.qnt import trim, trim_random, repeat_extend_audio, concat_audio, merge_audio, decode_to_file, decode as decode_qnt, encode as encode_qnt
 from .utils.sampler import PoolSampler, OrderedSampler, BatchedOrderedSampler, RandomSampler
 from .utils.distributed import global_rank, local_rank, world_size
+from .utils.io import torch_save, torch_load
 
 from collections import defaultdict
 from functools import cache, cached_property
@@ -739,7 +740,7 @@ class Dataset(_Dataset):
 			"samplers": { name: sampler.get_state() for name, sampler in self.samplers.items() },
 			"spkr_samplers": { name: sampler.get_state() for name, sampler in self.spkr_samplers.items() },
 		}
-		torch.save(state_dict, path)
+		torch_save(state_dict, path)
 
 	def load_state_dict(self, path = None):
 		if path is None:
@@ -748,7 +749,7 @@
 		if not path.exists():
 			return
 
-		state_dict = torch.load(path)
+		state_dict = torch_load(path)
 		if self.sampler_type == "path":
 			state_dict = self.sampler.set_state(state_dict)
 		else:
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 3046aaf..f0c76d6 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -12,6 +12,7 @@ from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
 
 from ..models import get_models, get_model
 from ..utils import wrapper as ml
+from ..utils.io import torch_save, torch_load
 from ..models.lora import apply_lora, lora_load_state_dict
 
 import torch
@@ -42,7 +43,7 @@ def load_engines(training=True):
 		checkpoint_path = cfg.ckpt_dir / name / "latest"
 		# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
-		load_path = cfg.ckpt_dir / name / "fp32.pth"
+		load_path = cfg.ckpt_dir / name / f"fp32.{cfg.weights_format}"
 
 		# actually use the lora-specific checkpoint if available
 		if cfg.lora is not None:
@@ -51,15 +52,15 @@ def load_engines(training=True):
 		# to handle the issue of training with deepspeed, but inferencing with local
 		if checkpoint_path.exists() and backend == "local":
 			tag = open(checkpoint_path).read()
-			checkpoint_path = checkpoint_path.parent / tag / "state.pth"
+			checkpoint_path = checkpoint_path.parent / tag / f"state.{cfg.weights_format}"
 
 		if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
-			print("Checkpoint missing, but weights found.")
+			print("Checkpoint missing, but weights found:", load_path)
 			loads_state_dict = True
 
 		# load state early
 		if loads_state_dict:
-			state = torch.load(load_path, map_location=torch.device(cfg.device))
+			state = torch_load(load_path, device=cfg.device)
 
 			# check if config is defined in state, and re-initialize the model
 			if "config" in state and False:
@@ -196,11 +197,11 @@ def load_engines(training=True):
 
 		# load lora weights if exists
 		if cfg.lora is not None:
-			lora_path = cfg.ckpt_dir / cfg.lora.full_name / "lora.pth"
+			lora_path = cfg.ckpt_dir / cfg.lora.full_name / f"lora.{cfg.weights_format}"
 
 			if lora_path.exists():
 				print( "Loaded LoRA state dict:", lora_path )
-				state = torch.load(lora_path, map_location=torch.device(cfg.device))
+				state = torch_load(lora_path, device=cfg.device)
 				state = state['lora' if 'lora' in state else 'module']
 				lora_load_state_dict( model, state )
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 7313ab8..ed73888 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -29,6 +29,7 @@ def default_feeder(engine, batch):
 from ..config import cfg
 from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
 from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size, cleanup_distributed
+from ..utils.io import torch_save, torch_load
 from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
 
 import logging
@@ -136,10 +137,10 @@ class Engine():
 			lora, module = lora_get_state_dict( module, split = True )
 			save_dir = cfg.ckpt_dir / cfg.lora.full_name
 
-		save_path = save_dir / tag / "state.pth"
+		save_path = save_dir / tag / f"state.{cfg.weights_format}"
 		save_path.parent.mkdir(parents=True, exist_ok=True)
 
-		torch.save({
+		torch_save({
 			"module": module,
 			"lora": lora,
 			"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
@@ -170,12 +171,12 @@ class Engine():
 
 		tag = open(tag_path).read()
 
-		load_path = load_dir / tag / "state.pth"
+		load_path = load_dir / tag / f"state.{cfg.weights_format}"
 
 		if not load_path.exists():
 			return
 
-		state = torch.load(load_path, map_location=torch.device(cfg.device))
+		state = torch_load(load_path, device=cfg.device)
 
 		self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
 		self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
@@ -187,10 +188,10 @@ class Engine():
 		load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
 
 		if load_optimizer_states:
-			self.optimizer.load_state_dict(state['optimizer']) #, map_location=torch.device(cfg.device))
+			self.optimizer.load_state_dict(state['optimizer']) #, device=cfg.device)
 
 		if load_lr_scheduler_states:
-			self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, map_location=torch.device(cfg.device))
+			self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, device=cfg.device)
 
 		if 'lora' in state:
 			lora_load_state_dict( self.module, state['lora'] )
@@ -324,17 +325,25 @@ class Engines(dict[str, Engine]):
 		for engine in self.values():
 			engine.dispatch_attribute(*args, **kwargs)
 
-	def export(self, userdata={}, callback=None, dtype=None):
+	def export(self, userdata={}, callback=None, dtype=None, format=None):
+		if not format:
+			format = cfg.weights_format
+		format = format.lower()
+
 		if dtype is None:
 			dtype = cfg.trainer.dtype
 
 		for name, engine in self.items():
 			module = engine.module.state_dict()
 			lora = None
-			save_path = cfg.ckpt_dir / name / "fp32.pth"
+			save_path = cfg.ckpt_dir / name / f"fp32.{format}"
 			config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
+
+			# coerce
 			if not isinstance(config, dict):
 				config = config.__dict__
+			if not isinstance(config['experimental'], dict):
+				config['experimental'] = config['experimental'].__dict__
 
 			# safety
 			for k, v in module.items():
@@ -342,7 +351,7 @@ class Engines(dict[str, Engine]):
 			if cfg.lora is not None:
 				lora, module = lora_get_state_dict( module, split = True )
-				save_path = cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth"
+				save_path = cfg.ckpt_dir / cfg.lora.full_name / f"fp32.{format}"
 
 			state_dict = {
 				'module': module,
@@ -363,7 +372,7 @@ class Engines(dict[str, Engine]):
 			if callback:
 				state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
 
-			torch.save(state_dict, save_path)
+			torch_save(state_dict, save_path)
 			print(f"Exported {name} to {save_path}")
 
 	def save_checkpoint(self, tag=None):
diff --git a/vall_e/export.py b/vall_e/export.py
index d5d958b..a807b22 100755
--- a/vall_e/export.py
+++ b/vall_e/export.py
@@ -7,6 +7,7 @@ from .data import get_phone_symmap
 from .engines import load_engines
 from .config import cfg
 from .models.lora import lora_get_state_dict
+from .utils.io import torch_save, torch_load
 
 # stitches embeddings into one embedding & classifier => lm_head
 def convert_to_hf( state_dict, config = None, save_path = None ):
@@ -61,12 +62,16 @@ def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
 	if dtype is None:
 		dtype = cfg.inference.dtype
 
+	format = save_path.suffix[1:] # the extension without the dot, e.g. "pth" | "sft"
+
 	lora = state_dict["lora"] if "lora" in state_dict else None
 	# should always be included, but just in case
 	if lora is None and "module" in state_dict:
 		lora, module = lora_get_state_dict( state_dict["module"], split = True )
 		state_dict["module"] = module
-		state_dict["lora"] = lora
+
+	if "lora" in state_dict:
+		state_dict["lora"] = None
 
 	# should raise an exception since there's nothing to extract, or at least a warning
 	if not lora:
@@ -74,8 +79,8 @@
 
 	# save lora specifically
 	# should probably export other attributes, similar to what SD LoRAs do
-	save_path = save_path.parent / "lora.pth"
-	torch.save( {
+	save_path = save_path.parent / f"lora.{format}"
+	torch_save( {
 		"module": lora,
 		"config": cfg.lora.__dict__ if cfg.lora is not None else None,
 	}, save_path )
@@ -109,8 +114,12 @@ def main():
 	parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
 	parser.add_argument("--split-classifiers", action='store_true', default=None) # splits classifier heads
 	parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
+	parser.add_argument("--format", type=str, default="pth") # set target format to export weights under
 	args, unknown = parser.parse_known_args()
 
+	if args.format.lower() not in ["sft", "safetensors", "pt", "pth"]:
+		raise Exception(f"Unknown requested format: {args.format}")
+
 	if args.module_only:
 		cfg.trainer.load_module_only = True
@@ -132,7 +141,7 @@ def main():
 		cfg.inference.backend = cfg.trainer.backend
 
 	engines = load_engines(training=False) # to ignore loading optimizer state
-	engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
+	engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback, format=args.format)
 
 if __name__ == "__main__":
 	main()
\ No newline at end of file
diff --git a/vall_e/utils/io.py b/vall_e/utils/io.py
new file mode 100644
index 0000000..4dad4bf
--- /dev/null
+++ b/vall_e/utils/io.py
@@ -0,0 +1,78 @@
+import torch
+import json
+
+from pathlib import Path
+from safetensors import safe_open as sft_load
+from safetensors.torch import save_file as sft_save
+
+def coerce_path( path ):
+	return path if isinstance( path, Path ) else Path(path)
+
+def is_dict_of( d, t ):
+	if not isinstance( d, dict ):
+		return False
+
+	return all([ isinstance( v, t ) for v in d.values() ])
+
+# handles converting the usual pth state_dict into just the dict with the tensors + a dict of JSON strings, for safetensors
+def state_dict_to_tensor_metadata( data: dict, module_key=None ):
+	metadata = None
+
+	# is a state_dict, no need to coerce
+	if is_dict_of( data, torch.Tensor ):
+		return data, metadata
+
+	# is maybe a dict with a state dict + metadata, coerce it
+	metadata = {}
+	target = module_key
+	if not target:
+		for k, v in data.items():
+			# is a dict of tensors, our target
+			if is_dict_of( v, torch.Tensor ):
+				target = k
+				continue # continue to iterate to grab other metadata
+
+			# not a dict of tensors, put it as metadata
+			try:
+				metadata[k] = json.dumps(v)
+			except Exception as e:
+				pass
+
+	if not target:
+		raise Exception('Requesting to save safetensors of a state dict, but state dict contains no dict of torch.Tensor to save')
+
+	return data[target], metadata
+
+def torch_save( data, path, module_key=None ):
+	path = coerce_path(path)
+	ext = path.suffix
+
+	if ext in [".safetensor", ".safetensors", ".sft"]:
+		data, metadata = state_dict_to_tensor_metadata( data, module_key=module_key )
+
+		return sft_save( data, path, metadata )
+
+	return torch.save( data, path )
+
+def torch_load( path, device="cpu", framework="pt", unsafe=True, load_metadata=True, module_key="module" ):
+	path = coerce_path(path)
+	ext = path.suffix
+
+	if ext in [".safetensor", ".safetensors", ".sft"]:
+		state_dict = {}
+		with sft_load(path, framework=framework, device=device) as f:
+			for k in f.keys():
+				state_dict[k] = f.get_tensor(k)
+
+			if load_metadata:
+				metadata = f.metadata() or {} # metadata() is None when the file was saved without any
+				for k, v in metadata.items():
+					try:
+						metadata[k] = json.loads( v )
+					except Exception as e:
+						pass
+				state_dict = { module_key: state_dict } | metadata
+
+		return state_dict
+
+	return torch.load( path, map_location=torch.device(device), weights_only=not unsafe )
\ No newline at end of file
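
Usage sketch (commentary, not part of the patch): the new I/O helpers dispatch purely on the
file extension, so switching formats is just a matter of naming the target path. A minimal
round-trip, assuming the io.py above and an installed safetensors; the filename and dummy
tensors here are arbitrary:

    import torch
    from vall_e.utils.io import torch_save, torch_load

    state_dict = {
        "module": { "weight": torch.randn(4, 4) },  # the dict of tensors becomes the safetensors payload
        "userdata": { "symmap": { "a": 1 } },       # anything non-tensor is JSON-encoded into the metadata
    }

    # ".sft" / ".safetensors" routes through safetensors; ".pth" / ".pt" falls back to torch.save
    torch_save( state_dict, "fp32.sft" )

    # loading re-nests the tensors under module_key ("module" by default) and JSON-decodes the metadata
    loaded = torch_load( "fp32.sft", device="cpu" )
    assert torch.equal( loaded["module"]["weight"], state_dict["module"]["weight"] )
    assert loaded["userdata"] == { "symmap": { "a": 1 } }

Splitting the tensors from everything else is forced by the safetensors format itself: its
header only carries string-to-string metadata, hence the json.dumps/json.loads coercion on
every non-tensor entry in the state dict.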
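To opt in persistently rather than per-export, the new Config field can be set in the
training YAML ("pth" remains the default):

    weights_format: sft  # "pth" | "sft"

For one-off exports there is the new --format flag, which main() validates against
"sft" | "safetensors" | "pt" | "pth". A hypothetical invocation, assuming the repo's
usual --yaml entry point:

    python3 -m vall_e.export --yaml=./training/config.yaml --format=sft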