diff --git a/vall_e/config.py b/vall_e/config.py
index 9d13c16..e8f7560 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -338,6 +338,9 @@ class Model:
 		if self.arch_type == "llama":
 			include = ["self_attn", "mlp"] # target only the attention + mlp
 			exclude = ["self_attn.k_proj"] # common literature says to ignore it
+		if self.arch_type == "retnet":
+			include = ["layers."] # target the core layers of the RetNet and ignore the auxiliary stuff
+			exclude = ["retention.k_proj"] # attention-based transformers ignore the K, so might as well ignore it for the retnet
 
 		return dict(include=include, exclude=exclude)
 
@@ -585,6 +588,7 @@ class Trainer:
 	load_disabled_engines: bool = False
 
 	weight_dtype: str = "float16"
+	amp: bool = False
 
 	ddp: bool = False
 
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 2d97cc2..93eb337 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -10,7 +10,7 @@ elif cfg.trainer.backend == "local":
 	from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
 
-from ..models import get_models
+from ..models import get_models, get_model
 from ..utils import wrapper as ml
 
 from ..models.lora import apply_lora
 
@@ -32,23 +32,49 @@ def load_engines(training=True):
 	engines = dict()
 
 	for name, model in models.items():
+		state = None
+		stats = None
+		lora = None
+
+		inferencing = cfg.mode == "inferencing" or not model.config.training or not training
+		loads_state_dict = cfg.trainer.load_state_dict or inferencing
+
+		checkpoint_path = cfg.ckpt_dir / name / "latest"
+		# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
+		load_path = cfg.ckpt_dir / name / "fp32.pth"
+
+		# actually use the lora-specific checkpoint if available
+		if cfg.lora is not None:
+			lora = cfg.lora
+			checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
+
+		if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
+			print("Checkpoint missing, but weights found.")
+			loads_state_dict = True
+
+		# load state early
+		if loads_state_dict:
+			state = torch.load(load_path, map_location=torch.device(cfg.device))
+
+			# check if config is defined in state, and re-initialize the model
+			if "config" in state:
+				print("Model config definition in weights, re-loading...")
+				config_state = state["config"]
+				model = get_model( config=cfg.model.__class__( **config_state ), training=training )
+
 		hyper_config = model.config
 
 		optimizer = None
 		lr_scheduler = None
 
-		inferencing = cfg.mode == "inferencing" or not model.config.training or not training
 		backend = cfg.inference.backend if inferencing else cfg.trainer.backend
 		dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
 		amp = cfg.inference.amp if inferencing else cfg.trainer.amp
-		loads_state_dict = cfg.trainer.load_state_dict or inferencing
 		ddp = cfg.trainer.ddp
 
 		engine_class = LocalEngine if backend == "local" or inferencing else Engine
 
-		if inferencing:
-			model.config.training = False
-
+		# apply model replacers
 		if cfg.optimizations.replace and cfg.optimizations.linear:
 			model.model = ml.replace_linear( model.model )
 
@@ -58,6 +84,9 @@ def load_engines(training=True):
 		for lora in cfg.loras:
 			model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )
 
+		if inferencing:
+			model.config.training = False
+
 		if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
 			optimizer_class = None
 			scheduler_class = None
@@ -116,22 +145,8 @@ def load_engines(training=True):
 			optimizer = None
 			lr_scheduler = None
 
-		checkpoint_path = cfg.ckpt_dir / name / "latest"
-		# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
-		load_path = cfg.ckpt_dir / name / "fp32.pth"
-
-		# actually use the lora-specific checkpoint if available
-		if cfg.lora is not None:
-			checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
-
-		if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
-			print("Checkpoint missing, but weights found.")
-			loads_state_dict = True
-
-		stats = None
+		# load state dict if requested / required
 		if loads_state_dict:
-			state = torch.load(load_path, map_location=torch.device(cfg.device))
-
 			# state dict is not just the module, extract the extra trainer details
 			if "stats" in state:
 				stats = state["stats"]
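
Taken together, the loader changes above make an exported fp32.pth self-describing: when the weights carry a "config" entry, the model is re-instantiated from it before the engine wraps it. A minimal sketch of that round-trip, where ModelConfig is a stand-in dataclass and not vall_e's actual Model config:

	# sketch only: ModelConfig stands in for the real config dataclass
	from dataclasses import dataclass
	import torch

	@dataclass
	class ModelConfig:
		name: str = "ar+nar"
		dim: int = 1024

	# export side: stash the config dict alongside the weights
	torch.save( { "module": {}, "config": ModelConfig().__dict__ }, "fp32.pth" )

	# load side: rebuild the config (and with it the model) from the weights
	state = torch.load( "fp32.pth", map_location="cpu" )
	if "config" in state:
		config = ModelConfig( **state["config"] )
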
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index bb2f532..0232f28 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -320,11 +320,21 @@ class Engines(dict[str, Engine]):
 		for engine in self.values():
 			engine.dispatch_attribute(*args, **kwargs)
 
-	def export(self, userdata={}, callback=None):
+	def export(self, userdata={}, callback=None, dtype=None):
+		if dtype is None:
+			dtype = cfg.trainer.dtype
+
 		for name, engine in self.items():
 			module = engine.module.state_dict()
 			lora = None
 			save_path = cfg.ckpt_dir / name / "fp32.pth"
+			config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
+			if not isinstance(config, dict):
+				config = config.__dict__
+
+			# safety
+			for k, v in module.items():
+				module[k] = v.to(dtype)
 
 			if cfg.lora is not None:
 				lora, module = lora_get_state_dict( module, split = True )
@@ -339,7 +349,8 @@ class Engines(dict[str, Engine]):
 					"global_samples": engine.global_samples,
 					"tokens_processed": engine.tokens_processed,
 				},
-				"userdata": userdata
+				"userdata": userdata,
+				"config": config
 			}
 			if callback:
 				state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
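
For reference, reading back a checkpoint written by the updated Engines.export() looks roughly like this; the key names follow the state_dict built above, while the path and the printed fields are illustrative:

	import torch

	state = torch.load( "ckpt/ar+nar/fp32.pth", map_location="cpu" )

	module = state["module"]  # weights, already cast to the export dtype
	stats = state["stats"]    # trainer counters, e.g. global_samples
	config = state["config"]  # the model's hyper-config as a plain dict

	print( stats["global_samples"], stats["tokens_processed"] )
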
diff --git a/vall_e/export.py b/vall_e/export.py
index 27329a0..b6a8b19 100755
--- a/vall_e/export.py
+++ b/vall_e/export.py
@@ -56,7 +56,10 @@ def convert_to_hf( state_dict, config = None, save_path = None ):
 
 	return state_dict
 
-def extract_lora( state_dict, config = None, save_path = None ):
+def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
+	if dtype is None:
+		dtype = cfg.inference.dtype
+
 	lora = state_dict["lora"] if "lora" in state_dict else None
 	# should always be included, but just in case
 	if lora is None and "module" in state_dict:
@@ -71,7 +74,10 @@ def extract_lora( state_dict, config = None, save_path = None ):
 	# save lora specifically
 	# should probably export other attributes, similar to what SD LoRAs do
 	save_path = save_path.parent / "lora.pth"
-	torch.save( { "module": lora }, save_path )
+	torch.save( {
+		"module": lora,
+		"config": cfg.lora.__dict__ if cfg.lora is not None else None,
+	}, save_path )
 
 	return state_dict
 
@@ -81,6 +87,7 @@ def main():
 	parser.add_argument("--module-only", action='store_true')
 	parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
 	parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
+	parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
 	args, unknown = parser.parse_known_args()
 
 	if args.module_only:
@@ -95,7 +102,10 @@ def main():
 	if args.hf and args.lora:
 		raise Exception("Requesting more than one callback")
 
-	engines = load_engines()
+	if args.dtype != "auto":
+		cfg.trainer.weight_dtype = args.dtype
+
+	engines = load_engines(training=False)
 	engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
 
 if __name__ == "__main__":
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 3be45aa..a784d4a 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -581,6 +581,18 @@ class Base(nn.Module):
 			))
 
 			self.model = RetNetDecoder(RetNetConfig(**kwargs))
+
+			# do some funny stuff for LoRA training
+			"""
+			if self.gradient_checkpointing:
+				def make_inputs_require_grads(module, input, output):
+					for i, t in enumerate(input):
+						if not isinstance(t, torch.Tensor):
+							continue
+						t.requires_grad_(True)
+
+				self.model.register_forward_hook(make_inputs_require_grads)
+			"""
 		elif self.arch_type == "retnet-hf":
 			kwargs = dict(
 				vocab_size=n_resp_tokens,
@@ -713,7 +725,7 @@ class Base(nn.Module):
 		x = inputs
 		m = mask.squeeze(-1).int()
 		aux_loss = None
-		
+
 		# HF transformer derived model
 		if self.arch_type in ["llama", "mistral", "mixtral"]:
 			kwargs = dict(
diff --git a/vall_e/models/lora.py b/vall_e/models/lora.py
index eb0cbb7..9b82b28 100644
--- a/vall_e/models/lora.py
+++ b/vall_e/models/lora.py
@@ -148,6 +148,7 @@ class ParameterizedLoRA(nn.Module):
 def passes_policy( policy, name ):
 	if policy is None:
 		return True
+
 	if "exclude" in policy:
 		for term in policy["exclude"]:
 			if term in name:
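
Usage-wise, exporting at a given precision is now just python3 -m vall_e.export --dtype float16, which overrides cfg.trainer.weight_dtype before load_engines(training=False) constructs the engines. The "config" entry that extract_lora() now writes into lora.pth can be read back the same way (a sketch; the rank/alpha fields assume the LoRA config object consumed by apply_lora above):

	import torch

	state = torch.load( "lora.pth", map_location="cpu" )
	lora_weights = state["module"]  # the split-out LoRA tensors
	lora_config = state["config"]   # cfg.lora.__dict__, or None
	if lora_config is not None:
		print( lora_config.get("rank"), lora_config.get("alpha") )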