From 188d116222c9a5c0831785fcb755534d23c0a815 Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 22 Jul 2024 20:47:24 -0500
Subject: [PATCH] some weird fixes for an equally weird regression with LoRA
 loading

---
 vall_e/__main__.py         |  4 ++-
 vall_e/engines/__init__.py | 70 +++++++++++++------------------------
 vall_e/engines/base.py     | 11 ++++--
 vall_e/models/lora.py      |  3 +-
 4 files changed, 35 insertions(+), 53 deletions(-)

diff --git a/vall_e/__main__.py b/vall_e/__main__.py
index 53ee7d3..f8e379a 100755
--- a/vall_e/__main__.py
+++ b/vall_e/__main__.py
@@ -4,12 +4,14 @@
 from .inference import TTS
 from .config import cfg
 
 def path_list(arg):
+	if not arg:
+		return None
 	return [Path(p) for p in arg.split(";")]
 
 def main():
 	parser = argparse.ArgumentParser("VALL-E TTS")
 	parser.add_argument("text")
-	parser.add_argument("references", type=path_list)
+	parser.add_argument("references", type=path_list, default=None)
 	parser.add_argument("--language", type=str, default="en")
 	parser.add_argument("--out-path", type=Path, default=None)
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index ebb1f38..f2029db 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -12,7 +12,7 @@ from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
 
 from ..models import get_models, get_model
 from ..utils import wrapper as ml
-from ..models.lora import apply_lora
+from ..models.lora import apply_lora, lora_load_state_dict
 
 import torch
 import re
@@ -37,16 +37,21 @@ def load_engines(training=True):
 		lora = None
 
 		inferencing = cfg.mode == "inferencing" or not model.config.training or not training
-		loads_state_dict = cfg.trainer.load_state_dict or inferencing
+		backend = cfg.inference.backend if inferencing else cfg.trainer.backend
+		loads_state_dict = cfg.trainer.load_state_dict # or inferencing
 
 		checkpoint_path = cfg.ckpt_dir / name / "latest"
 		# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
 		load_path = cfg.ckpt_dir / name / "fp32.pth"
 		# actually use the lora-specific checkpoint if available
-		if cfg.lora is not None:
-			lora = cfg.lora
-			checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
+		if cfg.lora is not None:
+			checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / "latest"
+
+		# to handle the issue of training with deepspeed, but inferencing with local
+		if checkpoint_path.exists() and backend == "local":
+			tag = open(checkpoint_path).read()
+			checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / tag / "state.pth"
 
 		if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
 			print("Checkpoint missing, but weights found.")
@@ -67,12 +72,11 @@ def load_engines(training=True):
 			optimizer = None
 			lr_scheduler = None
 
-		backend = cfg.inference.backend if inferencing else cfg.trainer.backend
 		dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
 		amp = cfg.inference.amp if inferencing else cfg.trainer.amp
 		ddp = cfg.trainer.ddp
 
-		engine_class = LocalEngine if backend == "local" or inferencing else Engine
+		engine_class = LocalEngine if backend == "local" else Engine
 
 		# apply model replacers
 		if cfg.optimizations.replace and cfg.optimizations.linear:
@@ -152,7 +156,7 @@ def load_engines(training=True):
 				stats = state["stats"]
 
 			# do not load stats if we're training a LoRA
-			if "lora" in state or cfg.lora is not None or cfg.trainer.restart_step_count:
+			if cfg.lora is not None or cfg.trainer.restart_step_count:
 				stats = None
 
 			if "module" in state:
@@ -182,13 +186,15 @@ def load_engines(training=True):
 
 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
-			# load lora weights if exists
-			if cfg.lora is not None:
-				lora_path = cfg.ckpt_dir / lora.full_name / "lora.pth"
-				if lora_path.exists():
-					state = torch.load(lora_path, map_location=torch.device(cfg.device))
-					state = state['lora' if 'lora' in state else 'module']
-					model.load_state_dict(state, strict=False)
+		# load lora weights if exists
+		if cfg.lora is not None:
+			lora_path = cfg.ckpt_dir / cfg.lora.full_name / "lora.pth"
+			if lora_path.exists():
+				print( "Loaded LoRA state dict:", lora_path )
+
+				state = torch.load(lora_path, map_location=torch.device(cfg.device))
+				state = state['lora' if 'lora' in state else 'module']
+				lora_load_state_dict( model, state )
 
 		# wrap if DDP is requested
 		if ddp:
@@ -213,42 +219,12 @@ def load_engines(training=True):
 	engines = Engines(engines)
 	engines.setup()
 
+	# this might bite me in the ass since technically this doesn't handle one engine loading fine but another engine not
 	if not cfg.trainer.load_state_dict:
-		engines.load_checkpoint()
+		engines.load_checkpoint(training=not inferencing)
 
 	# freeze requested params
 	for name, engine in engines.items():
 		engine.freeze(freeze_all=False)
 
-	"""
-	# copy embeddings if requested
-	if cfg.model._embeddings is not None:
-		embeddings_path = cfg.rel_path / cfg.model._embeddings
-
-		if embeddings_path.exists():
-			embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
-			if "module" in embeddings:
-				embeddings = embeddings["module"]
-
-			frozen_params = set()
-
-			for k in list(embeddings.keys()):
-				if re.findall(r'_emb\.', k):
-					frozen_params.add(k)
-				else:
-					del embeddings[k]
-
-			engine.module.load_state_dict(embeddings, strict=False)
-
-			# there's definitely a much better way but I can't be assed at the moment
-			for name, param in engine.module.named_parameters():
-				if name not in frozen_params:
-					continue
-				param.requires_grad_(False)
-				engine._frozen_params.add(param)
-	"""
-
-
-	#do_gc()
-
 	return engines
\ No newline at end of file
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index ac548e8..e41ed4c 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -164,15 +164,19 @@ class Engine():
 		if tag is None:
 			tag_path = load_dir / "latest"
+
 			if not tag_path.exists():
 				return
+
 			tag = open(tag_path).read()
 
 		load_path = load_dir / tag / "state.pth"
+
 		if not load_path.exists():
 			return
 
 		state = torch.load(load_path, map_location=torch.device(cfg.device))
+
 		self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
 		self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
 		self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
@@ -393,18 +397,19 @@ class Engines(dict[str, Engine]):
 				p.unlink()
 			d.rmdir()
 
-	def load_checkpoint(self, tag=None):
+	def load_checkpoint(self, tag=None, training=True):
 		if not tag:
 			tag = cfg.trainer.load_tag
 
 		for name, engine in self.items():
 			load_dir = cfg.ckpt_dir / name
+
 			engine.load_checkpoint(
 				tag=tag,
 				load_dir=load_dir,
 				load_module_strict=cfg.trainer.strict_loading,
-				load_optimizer_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
-				load_lr_scheduler_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
+				load_optimizer_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
+				load_lr_scheduler_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
 				load_module_only=cfg.trainer.load_module_only,
 			)
 
 		if cfg.trainer.restart_step_count:
diff --git a/vall_e/models/lora.py b/vall_e/models/lora.py
index fe93c92..ec7314d 100644
--- a/vall_e/models/lora.py
+++ b/vall_e/models/lora.py
@@ -118,7 +118,6 @@ class ParameterizedLoRA(nn.Module):
 		nn.init.zeros_( self.lora_B )
 
 	def forward(self, x: torch.Tensor):
-		print( self.enabled, x.shape )
 		if self.enabled:
 			return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
 		return x
@@ -193,7 +192,7 @@ def apply_lora( model, register = True, merge = False, policy = None, use_parame
 	else:
 		setattr( model.get_submodule(name), k, replacement )
 
-	return model
+	return enable_lora( model )
 
 def enable_lora( model, mode = True ):
 	for name, module in model.named_modules():
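
Note for reviewers, not part of the patch: the DeepSpeed-to-local handoff added in load_engines() amounts to the tag resolution sketched below. resolve_local_checkpoint() is a hypothetical standalone helper written only to illustrate the on-disk layout the patch relies on (a "latest" tag file naming the subdirectory that holds state.pth); the patch inlines the same logic, and reads the tag file without stripping whitespace.

from pathlib import Path

def resolve_local_checkpoint(ckpt_dir: Path, name: str) -> Path | None:
	# A DeepSpeed checkpoint directory holds a "latest" tag file whose
	# contents (e.g. "global_step1234") name the subdirectory containing
	# the actual state; the local backend wants a single state.pth instead.
	tag_path = ckpt_dir / name / "latest"
	if not tag_path.exists():
		return None

	# the patch reads the tag verbatim via open(...).read(); stripping the
	# trailing newline here is an extra precaution the patch does not take
	tag = tag_path.read_text().strip()

	load_path = ckpt_dir / name / tag / "state.pth"
	return load_path if load_path.exists() else None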