From 45a39fb79f28328252f18ef380a2e7a63d8c1347 Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 17 Jun 2024 00:09:16 -0500
Subject: [PATCH] very rudimentary lora support (no deepspeed support, tested
 training and saving but not loading yet)

---
 vall_e/config.py            |  38 +++++++++--
 vall_e/engines/__init__.py  |   7 ++-
 vall_e/engines/base.py      |  34 +++++++++-
 vall_e/engines/deepspeed.py |   3 +
 vall_e/models/lora.py       | 122 ++++++++++++++++++++++++++++++++++++
 5 files changed, 196 insertions(+), 8 deletions(-)
 create mode 100644 vall_e/models/lora.py

diff --git a/vall_e/config.py b/vall_e/config.py
index d639b02..ceb5278 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -188,9 +188,6 @@ class Dataset:
 
 # I really need to clean this up
 @dataclass()
 class Model:
-	_max_levels: int = 0
-	_embeddings: str | None = None
-
 	name: str = "" # vanity name for the model
 	version: int = 1 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding
 	size: str | dict = "full" # preset string or explicitly defined dimensionality
@@ -223,7 +220,7 @@ class Model:
 
 	@property
 	def max_levels(self):
-		return self._max_levels if self._max_levels > 0 else self.prom_levels
+		return max(self.prom_levels, self.resp_levels)
 
 	@property
 	# required for fp8 as the lengths needs to be divisible by 8
@@ -316,6 +313,18 @@ class Model:
 	@property
 	def gradient_checkpointing(self):
 		return cfg.trainer.gradient_checkpointing
+
+@dataclass()
+class LoRA:
+	name: str = "lora" # vanity name
+	rank: int = 8 # rank for the LoRA
+	alpha: int = 1 # alpha scaling factor for the LoRA
+	training: bool = True # whether this LoRA is the one being trained
+
+	@property
+	def full_name(self):
+		name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
+		return "-".join(name)
 
 @dataclass()
 class Hyperparameters:
@@ -622,7 +631,8 @@ class Config(BaseConfig):
 	experimental: bool = False # So I can stop commenting out things when committing
 
 	dataset: Dataset = field(default_factory=lambda: Dataset)
-	models: dict | list | None = field(default_factory=lambda: [Model])
+	models: dict | list | None = field(default_factory=lambda: [])
+	loras: dict | list | None = field(default_factory=lambda: [])
 	hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
 	evaluation: Evaluation = field(default_factory=lambda: Evaluation)
 	trainer: Trainer = field(default_factory=lambda: Trainer)
@@ -643,7 +653,15 @@ class Config(BaseConfig):
 			if model.training:
 				return model
 
-		return self.models[0]
+		return self.models[0] if len(self.models) > 0 else None
+
+	@property
+	def lora(self):
+		for i, lora in enumerate(self.loras):
+			if lora.training:
+				return lora
+
+		return self.loras[0] if len(self.loras) > 0 else None
 
 	@property
 	def distributed(self):
@@ -686,6 +704,9 @@ class Config(BaseConfig):
 
 		if isinstance(self.models, type):
 			self.models = dict()
+
+		if isinstance(self.loras, type):
+			self.loras = dict()
 
 		if isinstance(self.hyperparameters, type):
 			self.hyperparameters = dict()
@@ -715,6 +736,7 @@ class Config(BaseConfig):
 		"""
 
 		self.models = [ Model(**model) for model in self.models ]
+		self.loras = [ LoRA(**lora) for lora in self.loras ]
 
 		self.hyperparameters = Hyperparameters(**self.hyperparameters)
 
@@ -758,6 +780,10 @@ class Config(BaseConfig):
 		if not training:
 			self.dataset.use_hdf5 = False
 
+		# raise error if DeepSpeed and a LoRA is loaded, because I don't support it yet
+		if self.trainer.backend == "deepspeed" and self.lora is not None:
+			raise Exception("LoRAs are currently unsupported with deepspeed backend")
+
 		# load our HDF5 file if requested here
 		if self.dataset.use_hdf5:
 			self.load_hdf5()
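A quick illustration (not part of the patch) of how the new `loras` list is expected to be fed from the config and what `LoRA.full_name` resolves to; the YAML key layout here mirrors the existing `models` list and the values are just the dataclass defaults, so treat it as a sketch rather than a documented format:

# Hypothetical config excerpt (illustrative only):
#
#   loras:
#   - name: "lora"
#     rank: 8
#     alpha: 1
#     training: True
#
# Config.format() turns each dict into a LoRA dataclass, and cfg.lora returns
# the first entry with training=True (else the first entry, else None).
from vall_e.config import LoRA

lora = LoRA(name="lora", rank=8, alpha=1, training=True)
assert lora.full_name == "lora-r8-a1"  # used as the checkpoint subdirectory name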
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index b8c9d1b..5d1175f 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -112,7 +112,7 @@ def load_engines(training=True):
 
 		# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
 		load_path = cfg.ckpt_dir / name / "fp32.pth"
-		if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
+		if not loads_state_dict and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
 			print("DeepSpeed checkpoint missing, but weights found.")
 			loads_state_dict = True
 
@@ -178,6 +178,9 @@ def load_engines(training=True):
 	engines = Engines(engines)
 	engines.setup()
 
+	for name, engine in engines.items():
+		engine.load_loras()
+
 	if not cfg.trainer.load_state_dict:
 		engines.load_checkpoint()
 
@@ -185,6 +188,7 @@ def load_engines(training=True):
 	for name, engine in engines.items():
 		engine.freeze(freeze_all=False)
 
+	"""
 	# copy embeddings if requested
 	if cfg.model._embeddings is not None:
 		embeddings_path = cfg.rel_path / cfg.model._embeddings
@@ -210,6 +214,7 @@ def load_engines(training=True):
 					continue
 				param.requires_grad_(False)
 				engine._frozen_params.add(param)
+	"""
 
 	#do_gc()
 
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 3eb1996..306fa5c 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -29,6 +29,7 @@ def default_feeder(engine, batch):
 from ..config import cfg
 from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
 from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size
+from ..models.lora import apply_lora, freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
 
 import logging
 import time
@@ -70,11 +71,17 @@ class Engine():
 		self.max_nan_losses = 8
 		self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
 
+		self._global_grad_norm = None
+
 	def freeze(self, freeze_all=True):
 		# set to freeze
 		if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
 			raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
 
+		# freeze non-LoRA params if requested
+		if not self.hyper_config.frozen_params and not freeze_all:
+			return freeze_non_lora_weights( self.module )
+
 		for name, param in self.module.named_parameters():
 			if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
 				param.requires_grad_(False)
@@ -119,10 +126,21 @@ class Engine():
 	def save_checkpoint(self, save_dir, tag ):
 		if is_global_leader():
+			module = self.module.state_dict()
+
+			# if training lora
+			# this is a separate path to override saving the weights
+			lora = None
+			if cfg.lora is not None:
+				lora, module = lora_get_state_dict( module, split = True )
+				save_dir = cfg.ckpt_dir / cfg.lora.full_name
+
 			save_path = save_dir / tag / "state.pth"
 			save_path.parent.mkdir(parents=True, exist_ok=True)
+
 			torch.save({
-				"module": self.module.state_dict(),
+				"module": module,
+				"lora": lora,
 				"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
 				"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
@@ -165,6 +183,20 @@ class Engine():
 		if load_lr_scheduler_states:
 			self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, map_location=torch.device(cfg.device))
 
+		if 'lora' in state:
+			lora_load_state_dict( self.module, state['lora'] )
+
+
+	def load_loras( self ):
+		# apply lora weights
+		for lora in cfg.loras:
+			self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha )
+
+			lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
+			if lora_path.exists():
+				state_dict = torch.load(lora_path, map_location=torch.device(cfg.device))
+				lora_load_state_dict( self.module, state_dict )
+
 	def eval(self):
 		return self.module.eval()
 
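To make the new checkpoint layout concrete, here is a rough sketch of the split that save_checkpoint performs via lora_get_state_dict when cfg.lora is set; the parameter names and shapes below are made up for illustration:

# Sketch of the state-dict split used in save_checkpoint (names/shapes hypothetical).
# Anything whose key contains "lora_" (the lora_A / lora_B tensors) is peeled off
# into the "lora" entry of the saved dict; everything else stays under "module".
import torch
from vall_e.models.lora import lora_get_state_dict

state_dict = {
    "blocks.0.attn.q_proj.weight": torch.zeros(1024, 1024),  # frozen base weight
    "blocks.0.attn.q_proj.lora_A": torch.zeros(8, 1024),     # LoRA down-projection
    "blocks.0.attn.q_proj.lora_B": torch.zeros(1024, 8),     # LoRA up-projection
}

lora, module = lora_get_state_dict( state_dict, split = True )
assert set(lora) == {"blocks.0.attn.q_proj.lora_A", "blocks.0.attn.q_proj.lora_B"}
assert set(module) == {"blocks.0.attn.q_proj.weight"}
# torch.save() then writes {"module": module, "lora": lora, ...} to
# cfg.ckpt_dir / cfg.lora.full_name / tag / "state.pth"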
diff --git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py
index 08258ae..afa7c87 100755
--- a/vall_e/engines/deepspeed.py
+++ b/vall_e/engines/deepspeed.py
@@ -108,6 +108,9 @@ class Engine(DeepSpeedEngine):
 		except Exception as e:
 			print(str(e))
 
+	def load_loras(self):
+		...
+
 	def traverse(self, *args, **kwargs):
 		with ml.autocast():
 			self.forward(*args, **kwargs)
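The new vall_e/models/lora.py below follows the standard LoRA recipe: each wrapped nn.Linear keeps its frozen weight W and learns a low-rank delta B·A scaled by alpha/rank, added on the fly in forward() during training or folded into W for inference. A small equivalence sketch (dimensions chosen arbitrarily):

# Equivalence sketch for the Linear class below (dimensions are arbitrary).
import torch

in_features, out_features, rank, alpha = 16, 16, 4, 1
scaling = alpha / rank

x = torch.randn(2, in_features)                 # a toy batch
W = torch.randn(out_features, in_features)      # frozen base weight
A = torch.randn(rank, in_features)              # lora_A (kaiming-init in the class)
B = torch.randn(out_features, rank)             # lora_B (zero-init in the class, random here)

# un-merged path, as in Linear.forward() when self.merged is False (dropout omitted)
y = x @ W.t() + (x @ A.t() @ B.t()) * scaling

# merged path, as in Linear.train(False): fold the delta into the base weight
W_merged = W + (B @ A) * scaling
assert torch.allclose(y, x @ W_merged.t(), atol=1e-5)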
diff --git a/vall_e/models/lora.py b/vall_e/models/lora.py
new file mode 100644
index 0000000..6e288b4
--- /dev/null
+++ b/vall_e/models/lora.py
@@ -0,0 +1,122 @@
+# Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+from functools import partial
+import torch
+import torch.nn.functional as F
+import torch.nn.utils.parametrize as parametrize
+
+from torch import Tensor, nn
+
+import math
+from typing import Optional, List
+
+class Linear(nn.Linear):
+	def __init__(
+		self,
+
+		in_features: int,
+		out_features: int,
+		bias: bool = True,
+
+		rank: int = 4,
+		alpha: int = 1,
+
+		dropout: float = 0.1,
+		merge_weights: bool = True,
+		**kwargs,
+	):
+		super().__init__(in_features=in_features, out_features=out_features, bias=bias, **kwargs)
+
+		self.rank = rank
+		self.alpha = alpha
+		self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
+		self.merge_weights = merge_weights
+		self.merged = False
+
+		self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
+		self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) )
+		self.scaling = self.alpha / self.rank
+
+		self.weight.requires_grad = False
+
+		self.reset_parameters()
+
+	def reset_parameters(self):
+		super().reset_parameters()
+		# super silly but necessary because nn.Linear's constructor calls this
+		if hasattr(self, 'lora_A'):
+			nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
+			nn.init.zeros_( self.lora_B )
+
+	def train(self, mode: bool = True):
+		super().train(mode)
+
+		# training, separate lora from base weights
+		if mode and self.merge_weights and self.merged:
+			self.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
+			self.merged = False
+
+		# not training, merge lora to base weights
+		if not mode and self.merge_weights and not self.merged:
+			self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
+			self.merged = True
+
+	def forward(self, x: torch.Tensor):
+		if not self.merged:
+			result = F.linear(x, self.weight, bias=self.bias)
+			result += (self.dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
+			return result
+
+		return F.linear(x, self.weight, bias=self.bias)
+
+	@classmethod
+	def from_linear( cls, layer, **kwargs ):
+		return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs )
+
+# broken, the in_features / out_features change somehow
+def parameterize_model( layer, register = True, merge = False, **kwargs ):
+	if not isinstance( layer, nn.Linear ):
+		return
+
+	if register:
+		parametrize.register_parametrization( layer, "weight", Linear.from_linear( layer, **kwargs ) )
+	else:
+		parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
+
+def apply_lora( model, **kwargs ):
+	klass = Linear
+	target = nn.Linear
+
+	device = next(model.parameters()).device
+	dtype = next(model.parameters()).dtype
+	modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
+
+	for *parent, k in modules:
+		name = '.'.join(parent)
+
+		layer = getattr( model.get_submodule(name), k )
+
+		if isinstance(layer, klass):
+			continue
+
+		injected = klass( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
+		injected.weight = layer.weight
+
+		# overwrite
+		setattr( model.get_submodule(name), k, injected )
+
+	return model
+
+def freeze_non_lora_weights( model ):
+	for name, param in model.named_parameters():
+		param.requires_grad_('lora_' in name)
+	return model
+
+def lora_get_state_dict( state_dict, split = True ):
+	lora = { name: param for name, param in state_dict.items() if "lora_" in name }
+	if not split:
+		return lora
+
+	return lora, { name: param for name, param in state_dict.items() if "lora_" not in name }
+
+def lora_load_state_dict( model, state_dict ):
+	return model.load_state_dict( state_dict, strict = False )
\ No newline at end of file
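For reference, a rough end-to-end sketch of how these helpers compose outside of the Engine class: inject adapters, freeze the base weights, train only the lora_A / lora_B tensors, then split the adapter out for saving and load it back non-strictly. The toy model, optimizer settings, and file name are placeholders, not anything prescribed by the patch:

# Toy usage sketch of the helpers in vall_e/models/lora.py
# (the two-layer model, learning rate, and output path are placeholders).
import torch
from torch import nn
from vall_e.models.lora import (
    apply_lora, freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict,
)

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

# swap every nn.Linear for the LoRA-injected Linear, then freeze the base weights
model = apply_lora( model, rank = 8, alpha = 1 )
model = freeze_non_lora_weights( model )

params = [ p for p in model.parameters() if p.requires_grad ]  # only lora_A / lora_B remain trainable
optimizer = torch.optim.AdamW( params, lr = 1.0e-4 )

x = torch.randn(4, 64)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()

# only the adapter weights go into the saved file
lora, base = lora_get_state_dict( model.state_dict(), split = True )
torch.save( lora, "lora-r8-a1.pth" )  # the engine instead uses cfg.ckpt_dir / cfg.lora.full_name

# to restore: re-inject adapters into a fresh model, then load the split weights non-strictly
restored = apply_lora( nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8)), rank = 8, alpha = 1 )
lora_load_state_dict( restored, torch.load("lora-r8-a1.pth") )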