very rudimentary lora support (no deepspeed support, tested training and saving but not loading yet)

This commit is contained in:
mrq 2024-06-17 00:09:16 -05:00
parent 19410a919e
commit 45a39fb79f
5 changed files with 196 additions and 8 deletions

View File

@ -188,9 +188,6 @@ class Dataset:
# I really need to clean this up
@dataclass()
class Model:
_max_levels: int = 0
_embeddings: str | None = None
name: str = "" # vanity name for the model
version: int = 1 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding
size: str | dict = "full" # preset string or explicitly defined dimensionality
@ -223,7 +220,7 @@ class Model:
@property
def max_levels(self):
return self._max_levels if self._max_levels > 0 else self.prom_levels
return max(self.prom_levels, self.resp_levels)
@property
# required for fp8 as the lengths needs to be divisible by 8
@ -316,6 +313,18 @@ class Model:
@property
def gradient_checkpointing(self):
return cfg.trainer.gradient_checkpointing
@dataclass()
class LoRA:
name: str = "lora" # vanity name
rank: int = 8 # rank for the LoRA
alpha: int = 1 # rank for the LoRA
training: bool = True #
@property
def full_name(self):
name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
return "-".join(name)
@dataclass()
class Hyperparameters:
@ -622,7 +631,8 @@ class Config(BaseConfig):
experimental: bool = False # So I can stop commenting out things when committing
dataset: Dataset = field(default_factory=lambda: Dataset)
models: dict | list | None = field(default_factory=lambda: [Model])
models: dict | list | None = field(default_factory=lambda: [])
loras: dict | list | None = field(default_factory=lambda: [])
hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
evaluation: Evaluation = field(default_factory=lambda: Evaluation)
trainer: Trainer = field(default_factory=lambda: Trainer)
@ -643,7 +653,15 @@ class Config(BaseConfig):
if model.training:
return model
return self.models[0]
return self.models[0] if len(self.models) > 0 else None
@property
def lora(self):
for i, lora in enumerate(self.loras):
if lora.training:
return lora
return self.loras[0] if len(self.loras) > 0 else None
@property
def distributed(self):
@ -686,6 +704,9 @@ class Config(BaseConfig):
if isinstance(self.models, type):
self.models = dict()
if isinstance(self.loras, type):
self.loras = dict()
if isinstance(self.hyperparameters, type):
self.hyperparameters = dict()
@ -715,6 +736,7 @@ class Config(BaseConfig):
"""
self.models = [ Model(**model) for model in self.models ]
self.loras = [ LoRA(**lora) for lora in self.loras ]
self.hyperparameters = Hyperparameters(**self.hyperparameters)
@ -758,6 +780,10 @@ class Config(BaseConfig):
if not training:
self.dataset.use_hdf5 = False
# raise error if DeepSpeed and a LoRA is loaded, because I don't support it yet
if self.trainer.backend == "deepspeed" and self.lora is not None:
raise Exception("LoRAs are currently unsupported with deepspeed backend")
# load our HDF5 file if requested here
if self.dataset.use_hdf5:
self.load_hdf5()

View File

@ -112,7 +112,7 @@ def load_engines(training=True):
# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
load_path = cfg.ckpt_dir / name / "fp32.pth"
if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
if not loads_state_dict and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
print("DeepSpeed checkpoint missing, but weights found.")
loads_state_dict = True
@ -178,6 +178,9 @@ def load_engines(training=True):
engines = Engines(engines)
engines.setup()
for name, engine in engines.items():
engine.load_loras()
if not cfg.trainer.load_state_dict:
engines.load_checkpoint()
@ -185,6 +188,7 @@ def load_engines(training=True):
for name, engine in engines.items():
engine.freeze(freeze_all=False)
"""
# copy embeddings if requested
if cfg.model._embeddings is not None:
embeddings_path = cfg.rel_path / cfg.model._embeddings
@ -210,6 +214,7 @@ def load_engines(training=True):
continue
param.requires_grad_(False)
engine._frozen_params.add(param)
"""
#do_gc()

View File

@ -29,6 +29,7 @@ def default_feeder(engine, batch):
from ..config import cfg
from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size
from ..models.lora import apply_lora, freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict
import logging
import time
@ -70,11 +71,17 @@ class Engine():
self.max_nan_losses = 8
self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
self._global_grad_norm = None
def freeze(self, freeze_all=True):
# set to freeze
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
# freeze non-LoRA params if requested
if not self.hyper_config.frozen_params and not freeze_all:
return freeze_non_lora_weights( self.module )
for name, param in self.module.named_parameters():
if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
param.requires_grad_(False)
@ -119,10 +126,21 @@ class Engine():
def save_checkpoint(self, save_dir, tag ):
if is_global_leader():
module = self.module.state_dict()
# if training lora
# this is a separate path to override saving the weights
lora = None
if cfg.lora is not None:
lora, module = lora_get_state_dict( module, split = True )
save_dir = cfg.ckpt_dir / cfg.lora.full_name
save_path = save_dir / tag / "state.pth"
save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save({
"module": self.module.state_dict(),
"module": module,
"lora": lora,
"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
@ -165,6 +183,20 @@ class Engine():
if load_lr_scheduler_states:
self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, map_location=torch.device(cfg.device))
if 'lora' in state:
lora_load_state_dict( self.module, state['lora'] )
def load_loras( self ):
# apply lora weights
for lora in cfg.loras:
self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha )
lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
if lora_path.exists():
state_dict = torch.load(lora_path, map_location=torch.device(cfg.device))
self.module = lora_load_state_dict( self.module, state_dict )
def eval(self):
return self.module.eval()

View File

@ -108,6 +108,9 @@ class Engine(DeepSpeedEngine):
except Exception as e:
print(str(e))
def load_loras(self):
...
def traverse(self, *args, **kwargs):
with ml.autocast():
self.forward(*args, **kwargs)

122
vall_e/models/lora.py Normal file
View File

@ -0,0 +1,122 @@
# Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
from functools import partial
import torch
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from torch import Tensor, nn
import math
from typing import Optional, List
class Linear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
rank: int = 4,
alpha: int = 1,
dropout: float = 0.1,
merge_weights: bool = True,
**kwargs,
):
super().__init__(in_features=in_features, out_features=out_features, bias=bias, **kwargs)
self.rank = rank
self.alpha = alpha
self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
self.merge_weights = merge_weights
self.merged = False
self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) )
self.scaling = self.alpha / self.rank
self.weight.requires_grad = False
self.reset_parameters()
def reset_parameters(self):
super().reset_parameters()
# super silly but necessary because nn.Linear's constructor calls this
if hasattr(self, 'lora_A'):
nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
nn.init.zeros_( self.lora_B )
def train(self, mode: bool = True):
super().train(mode)
# training, separate lora from base weights
if mode and self.merge_weights and self.merged:
self.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
self.merged = False
# not training, merge lora to base weights
if not mode and self.merge_weights and not self.merged:
self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
self.merged = True
def forward(self, x: torch.Tensor):
if not self.merged:
result = F.linear(x, self.weight, bias=self.bias)
result += (self.dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
return result
return F.linear(x, self.weight, bias=self.bias)
@classmethod
def from_linear( cls, layer, **kwargs ):
return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs )
# broken, the in_features / out_features change somehow
def parameterize_model( layer, register = True, merge = False, **kwargs ):
if not isinstance( layer, nn.Linear ):
return
if register:
parametrize.register_parametrization( layer, "weight", Linear.from_linear( layer, **kwargs ) )
else:
parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
def apply_lora( model, **kwargs ):
klass = Linear
target = nn.Linear
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
for *parent, k in modules:
name = '.'.join(parent)
layer = getattr( model.get_submodule(name), k )
if isinstance(layer, klass):
continue
injected = klass( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
injected.weight = layer.weight
# overwrite
setattr( model.get_submodule(name), k, injected )
return model
def freeze_non_lora_weights( model ):
for name, param in model.named_parameters():
param.requires_grad_('lora_' in name)
return model
def lora_get_state_dict( state_dict, split = True ):
lora = { name: param for name, param in state_dict.items() if "lora_" in name }
if not split:
return lora
return lora, { name: param for name, param in state_dict.items() if "lora_" not in name }
def lora_load_state_dict( model, state_dict ):
return model.load_state_dict( state_dict, strict = False )