From 8a986eb480a3b8aabc2da844f7ac719b6afaa65c Mon Sep 17 00:00:00 2001
From: mrq
Date: Tue, 18 Jun 2024 21:45:46 -0500
Subject: [PATCH] load exported LoRA weights if they exist (to-do: make a
 better LoRA loading mechanism)

---
 vall_e/config.py           |  1 +
 vall_e/engines/__init__.py | 12 ++++++++++--
 vall_e/models/lora.py      | 58 +++++++++++++++++++++++++++++++++++++-------------------
 3 files changed, 50 insertions(+), 21 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index daafd0b..a77bce7 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -333,6 +333,7 @@ class LoRA:
 	rank: int = 8 # rank for the LoRA
 	alpha: int = 16 # alpha for the LoRA
 	training: bool = True #
+	parametrize: bool = False # whether to use parametrization to inject the LoRA weights
 	rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
 
 	@property
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 7cc973d..707c61d 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -56,9 +56,9 @@ def load_engines(training=True):
 			model.model = ml.replace_embedding( model.model )
 
 		for lora in cfg.loras:
-			model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy )
+			model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )
 
-		if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
+		if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
 			optimizer_class = None
 			scheduler_class = None
 
@@ -164,6 +164,14 @@ def load_engines(training=True):
 
 		model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
+		# load LoRA weights if they exist
+		if cfg.lora is not None:
+			lora_path = cfg.ckpt_dir / lora.full_name / "lora.pth"
+			if lora_path.exists():
+				state = torch.load(lora_path, map_location=torch.device(cfg.device))
+				state = state['lora' if 'lora' in state else 'module']
+				model.load_state_dict(state, strict=False)
+
 		# wrap if DDP is requested
 		if ddp:
 			model = ddp_model(model)
diff --git a/vall_e/models/lora.py b/vall_e/models/lora.py
index 338af2e..eb0cbb7 100644
--- a/vall_e/models/lora.py
+++ b/vall_e/models/lora.py
@@ -4,18 +4,17 @@
 import torch
 import torch.nn.functional as F
 import torch.nn.utils.parametrize as parametrize
+from transformers.pytorch_utils import Conv1D
+
 from torch import Tensor, nn
 
 import math
 from typing import Optional, List
 
-# to-do: set cfg to decide
-USE_PARAMETRIZATION = False
-
 # LoRA Linear for replacement
 # Pros: simple, just needs to reuse the replace_linear and copy weights
 # Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
-class Linear(nn.Linear):
+class LoRALinear(nn.Linear):
 	def __init__(
 		self,
 		in_features: int,
@@ -86,7 +85,7 @@ class Linear(nn.Linear):
 # Uses parametrization to inject LoRA weights
 # Pros: should work with any Linears
 # Cons: TBD
-class ParameterizedLinear(nn.Module):
+class ParameterizedLoRA(nn.Module):
 	def __init__(
 		self,
 		in_features: int,
@@ -133,10 +132,21 @@ class ParameterizedLinear(nn.Module):
 		# M$'s LoRA class arranges things to where this isn't necessary
 		return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
 
+	@classmethod
+	def from_conv1d( cls, layer, device = None, dtype = None, **kwargs ):
+		if device is None:
+			device = layer.weight.device
+		if dtype is None:
+			dtype = layer.weight.dtype
+
+		in_channels, out_channels = layer.weight.shape
+		# swap because we're feeding the output as our input
+		# M$'s LoRA class arranges things to where this isn't necessary
+		return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
+
 def passes_policy( policy, name ):
 	if policy is None:
 		return True
-
 	if "exclude" in policy:
 		for term in policy["exclude"]:
 			if term in name:
@@ -149,33 +159,43 @@ def passes_policy( policy, name ):
 
 	return False
 
-
-def apply_lora( model, register = True, merge = False, policy = None, **kwargs ):
+def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ):
 	device = next(model.parameters()).device
 	dtype = next(model.parameters()).dtype
 
-	klass = Linear
-	target = nn.Linear
-
-	device = next(model.parameters()).device
-	dtype = next(model.parameters()).dtype
-
-	modules = [ k.split('.') for k, m in model.named_modules() if isinstance(m, target) and not isinstance(m, klass) and passes_policy( policy, k ) ]
+	modules = [ k.split('.') for k, m in model.named_modules() if passes_policy( policy, k ) ]
 
 	for *parent, k in modules:
 		name = '.'.join(parent)
 		layer = getattr( model.get_submodule(name), k )
 
-		if USE_PARAMETRIZATION:
-			parametrize.register_parametrization( layer, "weight", ParameterizedLinear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
-			# parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
+		if isinstance( layer, nn.Linear ):
+			target = nn.Linear
+			klass = ParameterizedLoRA if use_parametrize else LoRALinear
+			replacer = klass.from_linear
+		elif isinstance( layer, nn.Conv1d ):
+			target = nn.Conv1d
+			klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
+			replacer = klass.from_conv1d
+		elif isinstance( layer, Conv1D ):
+			target = Conv1D
+			klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
+			replacer = klass.from_conv1d
 		else:
-			setattr( model.get_submodule(name), k, Linear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
+			continue
+
+		replacement = replacer( layer, device=device, dtype=dtype, **kwargs )
+
+		if use_parametrize:
+			parametrize.register_parametrization( layer, "weight", replacement )
+		else:
+			setattr( model.get_submodule(name), k, replacement )
 
 	return model
 
 def enable_lora( model, mode = True ):
 	for name, module in model.named_modules():
-		if not isinstance( module, ParameterizedLinear if USE_PARAMETRIZATION else Linear ):
+		if not isinstance( module, ParameterizedLoRA ) and not isinstance( module, LoRALinear ):
 			continue
 		module.enabled = mode
 	return model
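
Notes (not part of the commit) follow. First, a minimal usage sketch of the new use_parametrize path, assuming the patched vall_e.models.lora is importable. The toy model, the "proj" module names, and the include policy are illustrative, not from the repository; a policy matching only the target Linears is assumed, since apply_lora() now walks every module name that passes the policy, and only the nn.Linear branch is exercised here.

    import torch
    from torch import nn
    from collections import OrderedDict

    from vall_e.models.lora import apply_lora, enable_lora

    # illustrative model; the "proj" names let the include policy match the
    # two Linears while skipping the root module and the activation
    model = nn.Sequential(OrderedDict([
        ("proj_in", nn.Linear(16, 32)),
        ("act", nn.ReLU()),
        ("proj_out", nn.Linear(32, 16)),
    ]))

    model = apply_lora(
        model,
        rank = 8,
        alpha = 16,
        policy = { "include": [ "proj" ] },  # passes_policy() matches on substrings
        use_parametrize = True,              # register_parametrization() instead of a module swap
    )

    x = torch.randn(2, 16)
    y = model(x)  # lora_B is zero-initialized, so this still matches the base model

    enable_lora(model, False)  # flips .enabled off on every ParameterizedLoRA module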
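
Second, forward() above applies the delta torch.matmul(lora_B, lora_A) scaled by self.scaling; the __init__ that sets self.scaling is elided from the diff, so the alpha / rank factor below is an assumption based on the standard LoRA convention rather than something the patch confirms.

    import torch

    rank, alpha = 8, 16
    scaling = alpha / rank  # assumed convention: delta_W = (alpha / rank) * B @ A

    lora_A = torch.randn(rank, 16)  # stand-in for A; LoRA conventionally uses a Kaiming-style init
    lora_B = torch.zeros(32, rank)  # zero-initialized (as in the patch), so the adapter starts as a no-op

    delta_w = torch.matmul(lora_B, lora_A) * scaling  # shape (32, 16), all zeros at init
    assert not delta_w.any()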
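
Lastly, a standalone restatement of the new loading branch in load_engines(); load_lora_weights and its arguments are hypothetical names, but the <ckpt_dir>/<lora.full_name>/lora.pth layout and the 'lora'-then-'module' key fallback mirror the hunk above.

    from pathlib import Path

    import torch

    def load_lora_weights(model, ckpt_dir: Path, lora_name: str, device: str = "cpu"):
        # hypothetical helper mirroring the patch's inline logic
        lora_path = ckpt_dir / lora_name / "lora.pth"
        if not lora_path.exists():
            return model
        state = torch.load(lora_path, map_location=torch.device(device))
        # exported checkpoints may nest the weights under 'lora' or 'module'
        state = state['lora' if 'lora' in state else 'module']
        # strict=False: the file carries only the LoRA tensors, not the full model
        model.load_state_dict(state, strict=False)
        return model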