From fe0f23533515c277dfca3a118d580fca587bd17d Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 16 Jul 2024 18:23:13 -0500 Subject: [PATCH] mechanism to store the model config inside the weights and load them, some other things to allow LoRA training on the RetNet (gradient checkpointing will gripe about inputs not having require_grad and nothing seems to remedy it) --- vall_e/config.py | 4 +++ vall_e/engines/__init__.py | 57 ++++++++++++++++++++++++-------------- vall_e/engines/base.py | 15 ++++++++-- vall_e/export.py | 16 +++++++++-- vall_e/models/base.py | 14 +++++++++- vall_e/models/lora.py | 1 + 6 files changed, 80 insertions(+), 27 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index 9d13c16..e8f7560 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -338,6 +338,9 @@ class Model: if self.arch_type == "llama": include = ["self_attn", "mlp"] # target only the attention + mlp exclude = ["self_attn.k_proj"] # common literature says to ignore it + if self.arch_type == "retnet": + include = ["layers."] # target the core layers of the RetNet and ignore the auxiliary stuff + exclude = ["retention.k_proj"] # attention-based transformers ignore the K, so might as well ignore it for the retnet return dict(include=include, exclude=exclude) @@ -585,6 +588,7 @@ class Trainer: load_disabled_engines: bool = False weight_dtype: str = "float16" + amp: bool = False ddp: bool = False diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 2d97cc2..93eb337 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -10,7 +10,7 @@ elif cfg.trainer.backend == "local": from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine -from ..models import get_models +from ..models import get_models, get_model from ..utils import wrapper as ml from ..models.lora import apply_lora @@ -32,23 +32,49 @@ def load_engines(training=True): engines = dict() for name, model in models.items(): + state = None + stats = None + lora = None + + inferencing = cfg.mode == "inferencing" or not model.config.training or not training + loads_state_dict = cfg.trainer.load_state_dict or inferencing + + checkpoint_path = cfg.ckpt_dir / name / "latest" + # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present + load_path = cfg.ckpt_dir / name / "fp32.pth" + + # actually use the lora-specific checkpoint if available + if cfg.lora is not None: + lora = cfg.lora + checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest" + + if not loads_state_dict and not checkpoint_path.exists() and load_path.exists(): + print("Checkpoint missing, but weights found.") + loads_state_dict = True + + # load state early + if loads_state_dict: + state = torch.load(load_path, map_location=torch.device(cfg.device)) + + # check if config is defined in state, and re-initialize the model + if "config" in state: + print("Model config definition in weights, re-loading...") + config_state = state["config"] + model = get_model( config=cfg.model.__class__( *config_state ), training=training ) + hyper_config = model.config optimizer = None lr_scheduler = None - inferencing = cfg.mode == "inferencing" or not model.config.training or not training backend = cfg.inference.backend if inferencing else cfg.trainer.backend dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype amp = cfg.inference.amp if inferencing else cfg.trainer.amp - loads_state_dict = cfg.trainer.load_state_dict or inferencing ddp = cfg.trainer.ddp engine_class = LocalEngine if backend == "local" or inferencing else Engine - if inferencing: - model.config.training = False - + # apply model replacers if cfg.optimizations.replace and cfg.optimizations.linear: model.model = ml.replace_linear( model.model ) @@ -58,6 +84,9 @@ def load_engines(training=True): for lora in cfg.loras: model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize ) + if inferencing: + model.config.training = False + if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)): optimizer_class = None scheduler_class = None @@ -116,22 +145,8 @@ def load_engines(training=True): optimizer = None lr_scheduler = None - checkpoint_path = cfg.ckpt_dir / name / "latest" - # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present - load_path = cfg.ckpt_dir / name / "fp32.pth" - - # actually use the lora-specific checkpoint if available - if cfg.lora is not None: - checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest" - - if not loads_state_dict and not checkpoint_path.exists() and load_path.exists(): - print("Checkpoint missing, but weights found.") - loads_state_dict = True - - stats = None + # load state dict if requested / required if loads_state_dict: - state = torch.load(load_path, map_location=torch.device(cfg.device)) - # state dict is not just the module, extract the extra trainer details if "stats" in state: stats = state["stats"] diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py index bb2f532..0232f28 100755 --- a/vall_e/engines/base.py +++ b/vall_e/engines/base.py @@ -320,11 +320,21 @@ class Engines(dict[str, Engine]): for engine in self.values(): engine.dispatch_attribute(*args, **kwargs) - def export(self, userdata={}, callback=None): + def export(self, userdata={}, callback=None, dtype=None): + if dtype is None: + dtype = cfg.trainer.dtype + for name, engine in self.items(): module = engine.module.state_dict() lora = None save_path = cfg.ckpt_dir / name / "fp32.pth" + config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config + if not isinstance(config, dict): + config = config.__dict__ + + # safety + for k, v in module.items(): + module[k] = v.to(dtype) if cfg.lora is not None: lora, module = lora_get_state_dict( module, split = True ) @@ -339,7 +349,8 @@ class Engines(dict[str, Engine]): "global_samples": engine.global_samples, "tokens_processed": engine.tokens_processed, }, - "userdata": userdata + "userdata": userdata, + "config": config } if callback: state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path ) diff --git a/vall_e/export.py b/vall_e/export.py index 27329a0..b6a8b19 100755 --- a/vall_e/export.py +++ b/vall_e/export.py @@ -56,7 +56,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ): return state_dict -def extract_lora( state_dict, config = None, save_path = None ): +def extract_lora( state_dict, config = None, save_path = None, dtype = None ): + if dtype is None: + dtype = cfg.inference.dtype + lora = state_dict["lora"] if "lora" in state_dict else None # should always be included, but just in case if lora is None and "module" in state_dict: @@ -71,7 +74,10 @@ def extract_lora( state_dict, config = None, save_path = None ): # save lora specifically # should probably export other attributes, similar to what SD LoRAs do save_path = save_path.parent / "lora.pth" - torch.save( { "module": lora }, save_path ) + torch.save( { + "module": lora, + "config": cfg.lora.__dict__ if cfg.lora is not None else None, + }, save_path ) return state_dict @@ -81,6 +87,7 @@ def main(): parser.add_argument("--module-only", action='store_true') parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style parser.add_argument("--lora", action='store_true', default=None) # exports LoRA + parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to args, unknown = parser.parse_known_args() if args.module_only: @@ -95,7 +102,10 @@ def main(): if args.hf and args.lora: raise Exception("Requesting more than one callback") - engines = load_engines() + if args.dtype != "auto": + cfg.trainer.weight_dtype = args.dtype + + engines = load_engines(training=False) engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback) if __name__ == "__main__": diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 3be45aa..a784d4a 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -581,6 +581,18 @@ class Base(nn.Module): )) self.model = RetNetDecoder(RetNetConfig(**kwargs)) + + # do some funny stuff for LoRA training + """ + if self.gradient_checkpointing: + def make_inputs_require_grads(module, input, output): + for i, t in enumerate(input): + if not isinstance(t, torch.Tensor): + continue + t.requires_grad_(True) + + self.model.register_forward_hook(make_inputs_require_grads) + """ elif self.arch_type == "retnet-hf": kwargs = dict( vocab_size=n_resp_tokens, @@ -713,7 +725,7 @@ class Base(nn.Module): x = inputs m = mask.squeeze(-1).int() aux_loss = None - + # HF transformer derived model if self.arch_type in ["llama", "mistral", "mixtral"]: kwargs = dict( diff --git a/vall_e/models/lora.py b/vall_e/models/lora.py index eb0cbb7..9b82b28 100644 --- a/vall_e/models/lora.py +++ b/vall_e/models/lora.py @@ -148,6 +148,7 @@ class ParameterizedLoRA(nn.Module): def passes_policy( policy, name ): if policy is None: return True + if "exclude" in policy: for term in policy["exclude"]: if term in name: