From 7047fcc6e20a2aeb7ccb822257213fc52f9ab1b5 Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 17 Jun 2024 13:55:37 -0500
Subject: [PATCH] actually make deepspeed work with LoRAs

---
 vall_e/engines/__init__.py  | 12 +++++++-----
 vall_e/engines/base.py      | 21 +++++----------------
 vall_e/engines/deepspeed.py | 17 ++++++-----------
 vall_e/export.py            | 35 ++++++++++++++++++++++++++++++-----
 vall_e/models/lora.py       |  2 +-
 5 files changed, 49 insertions(+), 38 deletions(-)

diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 4088bdc..aef8b92 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -12,6 +12,8 @@ from .base import Engines, TrainFeeder, default_feeder, Engine as LocalEngine
 from ..models import get_models
 from ..utils import wrapper as ml
 
+from ..models.lora import apply_lora
+
 import torch
 import re
 
@@ -30,6 +32,8 @@ def load_engines(training=True):
     engines = dict()
 
     for name, model in models.items():
+        hyper_config = model.config
+
         optimizer = None
         lr_scheduler = None
 
@@ -51,6 +55,9 @@ def load_engines(training=True):
         if cfg.optimizations.replace and cfg.optimizations.embedding:
             model.model = ml.replace_embedding( model.model )
 
+        for lora in cfg.loras:
+            model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy )
+
         if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
             optimizer_class = None
             scheduler_class = None
@@ -153,8 +160,6 @@ def load_engines(training=True):
 
             model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
-        hyper_config = model.config
-
         # wrap if DDP is requested
         if ddp:
             model = ddp_model(model)
@@ -178,9 +183,6 @@ def load_engines(training=True):
     engines = Engines(engines)
     engines.setup()
 
-    for name, engine in engines.items():
-        engine.load_loras()
-
     if not cfg.trainer.load_state_dict:
         engines.load_checkpoint()
 
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 3803098..47e658a 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -29,7 +29,7 @@ def default_feeder(engine, batch):
 from ..config import cfg
 from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
 from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size
-from ..models.lora import apply_lora, freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
+from ..models.lora import freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
 
 import logging
 import time
@@ -190,17 +190,6 @@ class Engine():
 
         if 'lora' in state:
             lora_load_state_dict( self.module, state['lora'] )
 
-
-    def load_loras( self ):
-        # apply lora weights
-        for lora in cfg.loras:
-            self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha, policy = self.hyper_config.lora_policy )
-
-            lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
-            if lora_path.exists():
-                state_dict = torch.load(lora_path, map_location=torch.device(cfg.device))
-                self.module = lora_load_state_dict( self.module, state_dict )
-
     def eval(self):
         return self.module.eval()
@@ -334,11 +323,11 @@ class Engines(dict[str, Engine]):
         for name, engine in self.items():
             module = engine.module.state_dict()
             lora = None
-            save_dir = cfg.ckpt_dir / name / "fp32.pth"
+            save_path = cfg.ckpt_dir / name / "fp32.pth"
 
             if cfg.lora is not None:
                 lora, module = lora_get_state_dict( module, split = True )
-                save_dir = cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth"
+                save_path = cfg.ckpt_dir / cfg.lora.full_name / "fp32.pth"
 
             state_dict = {
                 'module': module,
@@ -352,9 +341,9 @@ class Engines(dict[str, Engine]):
                 "userdata": userdata
             }
             if callback:
-                state_dict = callback( state_dict, engine.hyper_config )
+                state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
 
-            torch.save(state_dict, save_dir)
+            torch.save(state_dict, save_path)
             print(f"Exported {name} to {outpath}")
 
     def save_checkpoint(self, tag=None):
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index 595f46c..8196a8f 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -68,7 +68,12 @@ class Engine(DeepSpeedEngine):
 
         # freeze non-LoRA params if requested
         if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
-            return freeze_non_lora_weights( self.module )
+            for name, param in self.module.named_parameters():
+                should = 'lora_' in name
+                param.requires_grad_(should)
+                if not should:
+                    self._frozen_params.add(param)
+            return
 
         for name, param in self.module.named_parameters():
             if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
@@ -128,16 +133,6 @@ class Engine(DeepSpeedEngine):
 
         return super().save_checkpoint( save_dir, **kwargs )
 
-    def load_loras( self ):
-        # apply lora weights
-        for lora in cfg.loras:
-            self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha, policy = self.hyper_config.lora_policy )
-
-            lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
-            if lora_path.exists():
-                state_dict = torch.load(lora_path, map_location=torch.device(cfg.device))
-                self.module = lora_load_state_dict( self.module, state_dict )
-
     def traverse(self, *args, **kwargs):
         with ml.autocast():
             self.forward(*args, **kwargs)
diff --git a/vall_e/export.py b/vall_e/export.py
index 3abbefd..27329a0 100755
--- a/vall_e/export.py
+++ b/vall_e/export.py
@@ -6,9 +6,10 @@ import torch.nn
 from .data import get_phone_symmap
 from .engines import load_engines
 from .config import cfg
+from .models.lora import lora_get_state_dict
 
-# stitches embeddings into one embedding + classifier => lm_head
-def convert_to_hf( state_dict, config = None ):
+# stitches embeddings into one embedding & classifier => lm_head
+def convert_to_hf( state_dict, config = None, save_path = None ):
     n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
     token_dim = 1024
     embedding = torch.nn.Embedding(n_tokens, token_dim)
@@ -53,22 +54,46 @@
     state_dict['module']['lm_head.weight'] = out_proj
     del state_dict['module']['classifier.bias']
 
-    torch.save(state_dict, "./data/export_test.pth")
+    return state_dict
 
-    raise Exception("!")
+def extract_lora( state_dict, config = None, save_path = None ):
+    lora = state_dict["lora"] if "lora" in state_dict else None
+    # should always be included, but just in case
+    if lora is None and "module" in state_dict:
+        lora, module = lora_get_state_dict( state_dict["module"], split = True )
+        state_dict["module"] = module
+        state_dict["lora"] = lora
+
+    # should raise an exception since there's nothing to extract, or at least a warning
+    if not lora:
+        return state_dict
+
+    # save lora specifically
+    # should probably export other attributes, similar to what SD LoRAs do
+    save_path = save_path.parent / "lora.pth"
+    torch.save( { "module": lora }, save_path )
 
     return state_dict
 
+
 def main():
     parser = argparse.ArgumentParser("Save trained model to path.")
     parser.add_argument("--module-only", action='store_true')
     parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
+    parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
     args, unknown = parser.parse_known_args()
 
     if args.module_only:
         cfg.trainer.load_module_only = True
 
-    callback = convert_to_hf if args.hf else None
+    callback = None
+    if args.hf:
+        callback = convert_to_hf
+    elif args.lora:
+        callback = extract_lora
+
+    if args.hf and args.lora:
+        raise Exception("Requesting more than one callback")
 
     engines = load_engines()
     engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
diff --git a/vall_e/models/lora.py b/vall_e/models/lora.py
index b0cf76c..a6f7a67 100644
--- a/vall_e/models/lora.py
+++ b/vall_e/models/lora.py
@@ -14,7 +14,7 @@ USE_PARAMETRIZATION = True
 
 # LoRA Linear for replacement
 # Pros: simple, just needs to reuse the replace_linear and copy weights
-# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc)
+# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
 class Linear(nn.Linear):
     def __init__(
         self,
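
A rough sketch of consuming the lora.pth that extract_lora() writes, assuming a `model` already constructed the way load_engines() builds it and a single configured LoRA; the calls mirror the removed Engine.load_loras() and the surviving Engine.load_checkpoint() path, and are not part of the commit itself:

    # sketch only: `model` is assumed to exist (built as in load_engines());
    # apply_lora / lora_load_state_dict are the helpers this patch already uses
    import torch

    from vall_e.config import cfg
    from vall_e.models.lora import apply_lora, lora_load_state_dict

    lora = cfg.loras[0]  # assumes exactly one LoRA entry in the config

    # parametrize the Linears before any engine wraps the model,
    # the same ordering load_engines() now enforces
    model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy )

    # extract_lora() saves the split-out weights as { "module": lora } next to fp32.pth
    lora_path = cfg.ckpt_dir / lora.full_name / "lora.pth"
    if lora_path.exists():
        state_dict = torch.load( lora_path, map_location = torch.device( cfg.device ) )
        lora_load_state_dict( model.model, state_dict["module"] )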