From f25e7656827febd51d8e10c3bd753f3246c41965 Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 22 Jul 2024 20:48:06 -0500
Subject: [PATCH] maybe backported some weird fixes for LoRA loading from
 mrq/vall-e ?

---
 tortoise_tts/config.py            |   1 +
 tortoise_tts/engines/__init__.py  | 133 ++++++++++++++----------------
 tortoise_tts/engines/base.py      |  32 +++++--
 tortoise_tts/engines/deepspeed.py |  11 +--
 tortoise_tts/export.py            |  17 +++-
 tortoise_tts/models/lora.py       |  18 ++--
 6 files changed, 120 insertions(+), 92 deletions(-)

diff --git a/tortoise_tts/config.py b/tortoise_tts/config.py
index 305022a..bfccefa 100755
--- a/tortoise_tts/config.py
+++ b/tortoise_tts/config.py
@@ -218,6 +218,7 @@ class LoRA:
 	rank: int = 8 # rank for the LoRA
 	alpha: int = 16 # rank for the LoRA
 	training: bool = True #
+	embeddings: bool = False
 	parametrize: bool = False #
 	module: str = "linear" # linear | conv1d
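Note: the new `embeddings` flag is what later gates whether embedding weights stay trainable alongside the LoRA (see the freeze_non_lora_weights() change at the end of this patch). A minimal sketch of flipping it on programmatically; the LoRA dataclass is real, but constructing it by hand like this (rather than through the usual config loading) is only for illustration:

    from tortoise_tts.config import LoRA

    lora = LoRA( rank=8, alpha=16, training=True, embeddings=True )
    # embeddings=True will also leave "_emb" parameters trainable
    # once freeze_non_lora_weights( model, embeddings=... ) consumes it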
diff --git a/tortoise_tts/engines/__init__.py b/tortoise_tts/engines/__init__.py
index d356614..51174f1 100755
--- a/tortoise_tts/engines/__init__.py
+++ b/tortoise_tts/engines/__init__.py
@@ -10,9 +10,9 @@
 elif cfg.trainer.backend == "local":
 	from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
 
-from ..models import get_models
+from ..models import get_models, get_model
 from ..utils import wrapper as ml
-from ..models.lora import apply_lora
+from ..models.lora import apply_lora, lora_load_state_dict
 
 import torch
 import re
@@ -32,23 +32,53 @@ def load_engines(training=True):
 	engines = dict()
 
 	for name, model in models.items():
+		state = None
+		stats = None
+		lora = None
+
+		inferencing = cfg.mode == "inferencing" or not model.config.training or not training
+		backend = cfg.inference.backend if inferencing else cfg.trainer.backend
+		loads_state_dict = cfg.trainer.load_state_dict # or inferencing
+
+		checkpoint_path = cfg.ckpt_dir / name / "latest"
+		# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
+		load_path = cfg.ckpt_dir / name / "fp32.pth"
+
+		# actually use the lora-specific checkpoint if available
+		if cfg.lora is not None:
+			checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / "latest"
+
+			# to handle the issue of training with deepspeed, but inferencing with local
+			if checkpoint_path.exists() and backend == "local":
+				tag = open(checkpoint_path).read()
+				checkpoint_path = cfg.ckpt_dir / cfg.lora.full_name / tag / "state.pth"
+
+		if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
+			print("Checkpoint missing, but weights found.")
+			loads_state_dict = True
+
+		# load state early
+		if loads_state_dict:
+			state = torch.load(load_path, map_location=torch.device(cfg.device))
+
+			# check if config is defined in state, and re-initialize the model
+			if "config" in state and False:
+				print("Model config definition in weights, re-loading...")
+				config_state = state["config"]
+				model = get_model( config=cfg.model.__class__( *config_state ), training=training )
+
 		hyper_config = model.config
 
 		optimizer = None
 		lr_scheduler = None
 
-		inferencing = cfg.mode == "inferencing" or not model.config.training
-		backend = cfg.inference.backend if inferencing else cfg.trainer.backend
 		dtype = cfg.inference.dtype if inferencing else cfg.trainer.dtype
 		amp = cfg.inference.amp if inferencing else cfg.trainer.amp
-		loads_state_dict = cfg.trainer.load_state_dict or inferencing
 		ddp = cfg.trainer.ddp
 
-		engine_class = LocalEngine if backend == "local" or inferencing else Engine
-
-		if inferencing:
-			model.config.training = False
+		engine_class = LocalEngine if backend == "local" else Engine
 
+		# apply model replacers
 		if cfg.optimizations.replace and cfg.optimizations.linear:
 			model.model = ml.replace_linear( model.model )
@@ -60,6 +90,9 @@ def load_engines(training=True):
 			#model.gpt = apply_lora( model.gpt, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, parametrize = lora.parametrize )
 			model = apply_lora( model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )
 
+		if inferencing:
+			model.config.training = False
+
 		if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
 			optimizer_class = None
 			scheduler_class = None
@@ -118,28 +151,14 @@ def load_engines(training=True):
 			optimizer = None
 			lr_scheduler = None
 
-		checkpoint_path = cfg.ckpt_dir / name / "latest"
-		# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
-		load_path = cfg.ckpt_dir / name / "fp32.pth"
-
-		# actually use the lora-specific checkpoint if available
-		if cfg.lora is not None:
-			checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
-
-		if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
-			print("Checkpoint missing, but weights found.")
-			loads_state_dict = True
-
-		stats = None
-		if loads_state_dict and load_path.exists():
-			state = torch.load(load_path, map_location=torch.device(cfg.device))
-
+		# load state dict if requested / required
+		if loads_state_dict:
 			# state dict is not just the module, extract the extra trainer details
 			if "stats" in state:
 				stats = state["stats"]
 
 			# do not load stats if we're training a LoRA
-			if "lora" not in state:
+			if cfg.lora is not None or cfg.trainer.restart_step_count:
 				stats = None
 
 			if "module" in state:
@@ -161,23 +180,23 @@ def load_engines(training=True):
 			for k in erase:
 				del state[k]
 
-			# resize text embedding
-			if "text_emb.weight" in state and model.config.text_tokens != state["text_emb.weight"].shape[0]:
-				state["text_emb.weight"] = state["text_emb.weight"][:model.config.text_tokens]
-
-			# resize text embedding
-			if "rvq_l_emb.weight" in state and model.config.resp_levels != state["rvq_l_emb.weight"].shape[0]:
-				state["rvq_l_emb.weight"] = state["rvq_l_emb.weight"][:model.config.resp_levels]
+			# resize embeddings
+			if "text_emb.weight" in state:
+				state["text_emb.weight"] = ml.resize_weight( state["text_emb.weight"], model.config.text_tokens )
+			if "rvq_l_emb.weight" in state:
+				state["rvq_l_emb.weight"] = ml.resize_weight( state["rvq_l_emb.weight"], model.config.resp_levels )
 
 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
-			# load lora weights if exists
-			if cfg.lora is not None:
-				lora_path = cfg.ckpt_dir / lora.full_name / "lora.pth"
-				if lora_path.exists():
-					state = torch.load(lora_path, map_location=torch.device(cfg.device))
-					state = state['lora' if 'lora' in state else 'module']
-					model.load_state_dict(state, strict=False)
+		# load lora weights if exists
+		if cfg.lora is not None:
+			lora_path = cfg.ckpt_dir / cfg.lora.full_name / "lora.pth"
+			if lora_path.exists():
+				print( "Loaded LoRA state dict:", lora_path )
+
+				state = torch.load(lora_path, map_location=torch.device(cfg.device))
+				state = state['lora' if 'lora' in state else 'module']
+				lora_load_state_dict( model, state )
 
 		# wrap if DDP is requested
 		if ddp:
@@ -202,42 +221,12 @@
 	engines = Engines(engines)
 	engines.setup()
 
+	# this might bite me in the ass since technically this doesn't handle one engine loading fine but another engine not
 	if not cfg.trainer.load_state_dict:
-		engines.load_checkpoint()
+		engines.load_checkpoint(training=not inferencing)
 
 	# freeze requested params
 	for name, engine in engines.items():
 		engine.freeze(freeze_all=False)
 
-	"""
-	# copy embeddings if requested
-	if cfg.model._embeddings is not None:
-		embeddings_path = cfg.rel_path / cfg.model._embeddings
-
-		if embeddings_path.exists():
-			embeddings = torch.load(embeddings_path, map_location=torch.device(cfg.device))
-			if "module" in embeddings:
-				embeddings = embeddings["module"]
-
-			frozen_params = set()
-
-			for k in list(embeddings.keys()):
-				if re.findall(r'_emb\.', k):
-					frozen_params.add(k)
-				else:
-					del embeddings[k]
-
-			engine.module.load_state_dict(embeddings, strict=False)
-
-			# there's definitely a much better way but I can't be assed at the moment
-			for name, param in engine.module.named_parameters():
-				if name not in frozen_params:
-					continue
-				param.requires_grad_(False)
-				engine._frozen_params.add(param)
-	"""
-
-
-	#do_gc()
-
 	return engines
\ No newline at end of file
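Note: ml.resize_weight() is called above but not defined anywhere in this patch. A plausible stand-in consistent with its call sites, generalizing the old truncate-only logic; the actual helper in ..utils.wrapper may differ (e.g. it might initialize new rows randomly instead of zeroing them):

    import torch

    def resize_weight( weight: torch.Tensor, rows: int ) -> torch.Tensor:
        # truncate, or zero-pad, dim 0 to the requested size
        if weight.shape[0] >= rows:
            return weight[:rows]
        padded = weight.new_zeros( rows, *weight.shape[1:] )
        padded[:weight.shape[0]] = weight
        return padded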
diff --git a/tortoise_tts/engines/base.py b/tortoise_tts/engines/base.py
index bb2f532..e41ed4c 100755
--- a/tortoise_tts/engines/base.py
+++ b/tortoise_tts/engines/base.py
@@ -81,7 +81,7 @@ class Engine():
 		# freeze non-LoRA params if requested
 		if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
-			return freeze_non_lora_weights( self.module )
+			return freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
 
 		for name, param in self.module.named_parameters():
 			if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
@@ -164,15 +164,19 @@ class Engine():
 		if tag is None:
 			tag_path = load_dir / "latest"
+
 			if not tag_path.exists():
 				return
+
 			tag = open(tag_path).read()
 
 		load_path = load_dir / tag / "state.pth"
+
 		if not load_path.exists():
 			return
 
 		state = torch.load(load_path, map_location=torch.device(cfg.device))
+
 		self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
 		self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
 		self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
@@ -320,11 +324,21 @@ class Engines(dict[str, Engine]):
 		for engine in self.values():
 			engine.dispatch_attribute(*args, **kwargs)
 
-	def export(self, userdata={}, callback=None):
+	def export(self, userdata={}, callback=None, dtype=None):
+		if dtype is None:
+			dtype = cfg.trainer.dtype
+
 		for name, engine in self.items():
 			module = engine.module.state_dict()
 			lora = None
 			save_path = cfg.ckpt_dir / name / "fp32.pth"
 
+			config = engine.module.config if hasattr(engine.module, "config") else engine.hyper_config
+			if not isinstance(config, dict):
+				config = config.__dict__
+
+			# safety
+			for k, v in module.items():
+				module[k] = v.to(dtype)
 			if cfg.lora is not None:
 				lora, module = lora_get_state_dict( module, split = True )
@@ -339,8 +353,13 @@ class Engines(dict[str, Engine]):
 					"global_samples": engine.global_samples,
 					"tokens_processed": engine.tokens_processed,
 				},
-				"userdata": userdata
+				"userdata": userdata,
+				"config": config
 			}
+
+			if lora is None:
+				del state_dict['lora']
+
 			if callback:
 				state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
@@ -378,18 +397,19 @@ class Engines(dict[str, Engine]):
 				p.unlink()
 			d.rmdir()
 
-	def load_checkpoint(self, tag=None):
+	def load_checkpoint(self, tag=None, training=True):
 		if not tag:
 			tag = cfg.trainer.load_tag
 
 		for name, engine in self.items():
 			load_dir = cfg.ckpt_dir / name
+
 			engine.load_checkpoint(
 				tag=tag,
 				load_dir=load_dir,
 				load_module_strict=cfg.trainer.strict_loading,
-				load_optimizer_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
-				load_lr_scheduler_states=False if cfg.trainer.load_module_only else cfg.trainer.load_states,
+				load_optimizer_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
+				load_lr_scheduler_states=False if cfg.trainer.load_module_only or not training else cfg.trainer.load_states,
 				load_module_only=cfg.trainer.load_module_only,
 			)
 			if cfg.trainer.restart_step_count:
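Note: Engines.export() now embeds the model config and casts every weight before saving, so a reduced-precision export is a one-liner from the caller's side. A rough call-site sketch; `engines` is as returned by load_engines(), the userdata payload is illustrative, and the dtype argument is a torch dtype to match the v.to(dtype) cast above:

    import torch

    engines.export( userdata={"exported_by": "sketch"}, dtype=torch.float16 )
    # each fp32.pth now carries: module (cast weights), lora (only when one
    # is attached), stats, userdata, and the model config as a plain dict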
diff --git a/tortoise_tts/engines/deepspeed.py b/tortoise_tts/engines/deepspeed.py
index 5d3e2b1..e31da2d 100755
--- a/tortoise_tts/engines/deepspeed.py
+++ b/tortoise_tts/engines/deepspeed.py
@@ -27,6 +27,8 @@ from deepspeed.accelerator import get_accelerator
 from ..utils.distributed import init_distributed, distributed_initialized
 from ..utils import wrapper as ml
 
+from ..models.lora import freeze_non_lora_weights
+
 if not distributed_initialized() and cfg.trainer.backend == "deepspeed":
 	init_distributed(init_deepspeed_dist)
@@ -66,11 +68,10 @@ class Engine(DeepSpeedEngine):
 	def freeze(self, freeze_all=True):
 		# freeze non-LoRA params if requested
 		if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
-			for name, param in self.module.named_parameters():
-				should = 'lora_' in name
-				param.requires_grad_(should)
-				if not should:
-					self._frozen_params.add(param)
+			frozen_params = freeze_non_lora_weights( self.module, embeddings=cfg.lora.embeddings )
+			for param in frozen_params:
+				self._frozen_params.add( param )
+
 			return
 
 		if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
diff --git a/tortoise_tts/export.py b/tortoise_tts/export.py
index 9ee285b..8af9c7a 100755
--- a/tortoise_tts/export.py
+++ b/tortoise_tts/export.py
@@ -8,7 +8,10 @@ from .engines import load_engines
 from .config import cfg
 from .models.lora import lora_get_state_dict
 
-def extract_lora( state_dict, config = None, save_path = None ):
+def extract_lora( state_dict, config = None, save_path = None, dtype = None ):
+	if dtype is None:
+		dtype = cfg.inference.dtype
+
 	lora = state_dict["lora"] if "lora" in state_dict else None
 	# should always be included, but just in case
 	if lora is None and "module" in state_dict:
@@ -23,15 +26,18 @@ def extract_lora( state_dict, config = None, save_path = None ):
 	# save lora specifically
 	# should probably export other attributes, similar to what SD LoRAs do
 	save_path = save_path.parent / "lora.pth"
-	torch.save( { "module": lora }, save_path )
+	torch.save( {
+		"module": lora,
+		"config": cfg.lora.__dict__ if cfg.lora is not None else None,
+	}, save_path )
 
 	return state_dict
 
-
 def main():
 	parser = argparse.ArgumentParser("Save trained model to path.")
 	parser.add_argument("--module-only", action='store_true')
 	parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
+	parser.add_argument("--dtype", type=str, default="auto") # set target dtype to export to
 	args, unknown = parser.parse_known_args()
 
 	if args.module_only:
@@ -41,7 +47,10 @@ def main():
 	if args.lora:
 		callback = extract_lora
 
-	engines = load_engines()
+	if args.dtype != "auto":
+		cfg.trainer.weight_dtype = args.dtype
+
+	engines = load_engines(training=False)
 	engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
 
 if __name__ == "__main__":
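Note: with the new --dtype flag, a LoRA export run would look roughly like this (module invocation assumed from the if __name__ == "__main__" entry point; "auto" leaves cfg.trainer.weight_dtype untouched):

    python3 -m tortoise_tts.export --lora --dtype float16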
diff --git a/tortoise_tts/models/lora.py b/tortoise_tts/models/lora.py
index cee0605..ec7314d 100644
--- a/tortoise_tts/models/lora.py
+++ b/tortoise_tts/models/lora.py
@@ -1,5 +1,4 @@
 # Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
-
 from functools import partial
 import torch
 import torch.nn.functional as F
@@ -148,6 +147,7 @@ class ParameterizedLoRA(nn.Module):
 def passes_policy( policy, name ):
 	if policy is None:
 		return True
+
 	if "exclude" in policy:
 		for term in policy["exclude"]:
 			if term in name:
@@ -192,7 +192,7 @@ def apply_lora( model, register = True, merge = False, policy = None, use_parame
 		else:
 			setattr( model.get_submodule(name), k, replacement )
 
-	return model
+	return enable_lora( model )
 
 def enable_lora( model, mode = True ):
 	for name, module in model.named_modules():
@@ -204,10 +204,18 @@ def enable_lora( model, mode = True ):
 def disable_lora( model ):
 	return enable_lora( model, False )
 
-def freeze_non_lora_weights( model ):
+def freeze_non_lora_weights( model, embeddings = False ):
+	frozen_params = []
+
 	for name, param in model.named_parameters():
-		param.requires_grad_('lora_' in name)
-	return model
+		should = 'lora_' in name or (embeddings and "_emb" in name)
+
+		param.requires_grad_(should)
+
+		if not should:
+			frozen_params.append( param )
+
+	return frozen_params
 
 def lora_get_state_dict( state_dict, split = True ):
 	lora = { name: param for name, param in state_dict.items() if "lora_" in name }
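Note: a short usage sketch of the revised helpers, using only names defined in this file (model construction elided):

    model = apply_lora( model, rank=8, alpha=16, policy=None, use_parametrize=False )
    # apply_lora() now returns enable_lora( model ), so the adapters start active

    frozen = freeze_non_lora_weights( model, embeddings=True )
    # every param whose name lacks 'lora_' (and, here, '_emb') is now frozen
    print( f"froze {len(frozen)} parameters" )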