very rudimentary lora support (no deepspeed support, tested training and saving but not loading yet)
parent 19410a919e · commit 45a39fb79f
@@ -188,9 +188,6 @@ class Dataset:
 # I really need to clean this up
 @dataclass()
 class Model:
-	_max_levels: int = 0
-	_embeddings: str | None = None
-
 	name: str = "" # vanity name for the model
 	version: int = 1 # 1 = old with MultiEmbedding, 2 = new with AudioEmbedding
 	size: str | dict = "full" # preset string or explicitly defined dimensionality
@@ -223,7 +220,7 @@ class Model:

 	@property
 	def max_levels(self):
-		return self._max_levels if self._max_levels > 0 else self.prom_levels
+		return max(self.prom_levels, self.resp_levels)

 	@property
 	# required for fp8 as the lengths needs to be divisible by 8
@@ -316,6 +313,18 @@ class Model:
 	@property
 	def gradient_checkpointing(self):
 		return cfg.trainer.gradient_checkpointing

+@dataclass()
+class LoRA:
+	name: str = "lora" # vanity name
+	rank: int = 8 # rank for the LoRA
+	alpha: int = 1 # rank for the LoRA
+	training: bool = True #
+
+	@property
+	def full_name(self):
+		name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
+		return "-".join(name)
+
 @dataclass()
 class Hyperparameters:
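For orientation, a minimal sketch of how an entry for the new dataclass above resolves; the dict literal is a hypothetical stand-in for whatever the YAML loader hands to `Config` (the `loras` field and the `LoRA(**lora)` expansion appear in the hunks below):

from vall_e.config import LoRA   # module path inferred from the relative imports later in this diff

lora = LoRA(**{ "name": "lora", "rank": 8, "alpha": 1, "training": True })
print(lora.full_name)   # "lora-r8-a1" -- later used as the checkpoint directory name under cfg.ckpt_dir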
@@ -622,7 +631,8 @@ class Config(BaseConfig):
 	experimental: bool = False # So I can stop commenting out things when committing

 	dataset: Dataset = field(default_factory=lambda: Dataset)
-	models: dict | list | None = field(default_factory=lambda: [Model])
+	models: dict | list | None = field(default_factory=lambda: [])
+	loras: dict | list | None = field(default_factory=lambda: [])
 	hyperparameters: Hyperparameters = field(default_factory=lambda: Hyperparameters)
 	evaluation: Evaluation = field(default_factory=lambda: Evaluation)
 	trainer: Trainer = field(default_factory=lambda: Trainer)
@@ -643,7 +653,15 @@ class Config(BaseConfig):
 			if model.training:
 				return model

-		return self.models[0]
+		return self.models[0] if len(self.models) > 0 else None
+
+	@property
+	def lora(self):
+		for i, lora in enumerate(self.loras):
+			if lora.training:
+				return lora
+
+		return self.loras[0] if len(self.loras) > 0 else None

 	@property
 	def distributed(self):
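A quick, hedged illustration of the selection rule in the new `lora` property above; the entries are hypothetical and `cfg` is the global `Config` instance imported elsewhere as `from ..config import cfg`:

cfg.loras = [ LoRA( name="style", training=False ), LoRA( name="speaker", training=True ) ]
assert cfg.lora.full_name == "speaker-r8-a1"   # first LoRA flagged training=True wins
cfg.loras = []
assert cfg.lora is None                        # mirrors the guarded models[0] fallback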
@@ -686,6 +704,9 @@ class Config(BaseConfig):

 		if isinstance(self.models, type):
 			self.models = dict()

+		if isinstance(self.loras, type):
+			self.loras = dict()
+
 		if isinstance(self.hyperparameters, type):
 			self.hyperparameters = dict()
@@ -715,6 +736,7 @@ class Config(BaseConfig):
 		"""

 		self.models = [ Model(**model) for model in self.models ]
+		self.loras = [ LoRA(**lora) for lora in self.loras ]

 		self.hyperparameters = Hyperparameters(**self.hyperparameters)

@@ -758,6 +780,10 @@ class Config(BaseConfig):
 		if not training:
 			self.dataset.use_hdf5 = False

+		# raise error if DeepSpeed and a LoRA is loaded, because I don't support it yet
+		if self.trainer.backend == "deepspeed" and self.lora is not None:
+			raise Exception("LoRAs are currently unsupported with deepspeed backend")
+
 		# load our HDF5 file if requested here
 		if self.dataset.use_hdf5:
 			self.load_hdf5()
@@ -112,7 +112,7 @@ def load_engines(training=True):
 		# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
 		load_path = cfg.ckpt_dir / name / "fp32.pth"

-		if not loads_state_dict and backend == "deepspeed" and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
+		if not loads_state_dict and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
 			print("DeepSpeed checkpoint missing, but weights found.")
 			loads_state_dict = True

@@ -178,6 +178,9 @@ def load_engines(training=True):
 	engines = Engines(engines)
 	engines.setup()

+	for name, engine in engines.items():
+		engine.load_loras()
+
 	if not cfg.trainer.load_state_dict:
 		engines.load_checkpoint()

@@ -185,6 +188,7 @@ def load_engines(training=True):
 	for name, engine in engines.items():
 		engine.freeze(freeze_all=False)

+	"""
 	# copy embeddings if requested
 	if cfg.model._embeddings is not None:
 		embeddings_path = cfg.rel_path / cfg.model._embeddings
@@ -210,6 +214,7 @@ def load_engines(training=True):
 					continue
 				param.requires_grad_(False)
 				engine._frozen_params.add(param)
+	"""


 	#do_gc()
@@ -29,6 +29,7 @@ def default_feeder(engine, batch):
 from ..config import cfg
 from ..utils import dispatch_attribute, flatten_dict, gather_attribute, do_gc, to_device
 from ..utils.distributed import init_distributed, distributed_initialized, is_global_leader, world_size
+from ..models.lora import apply_lora, freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict

 import logging
 import time
@@ -70,11 +71,17 @@ class Engine():
 		self.max_nan_losses = 8
 		self.loss_scaler = torch.cuda.amp.GradScaler() if cfg.trainer.scale_loss else None
+
+		self._global_grad_norm = None

 	def freeze(self, freeze_all=True):
 		# set to freeze
 		if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
 			raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")

+		# freeze non-LoRA params if requested
+		if not self.hyper_config.frozen_params and not freeze_all:
+			return freeze_non_lora_weights( self.module )
+
 		for name, param in self.module.named_parameters():
 			if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
 				param.requires_grad_(False)
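The early return added to `freeze()` defers to `freeze_non_lora_weights()` from the new `vall_e/models/lora.py` at the bottom of this commit. A small sketch of what that path does, using a hypothetical toy module rather than a real vall_e model:

import torch.nn as nn
from vall_e.models.lora import apply_lora, freeze_non_lora_weights

toy = apply_lora( nn.Sequential( nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2) ), rank = 4, alpha = 1 )
freeze_non_lora_weights( toy )
for name, param in toy.named_parameters():
	assert param.requires_grad == ('lora_' in name)   # only lora_A / lora_B remain trainable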
@@ -119,10 +126,21 @@ class Engine():

 	def save_checkpoint(self, save_dir, tag ):
 		if is_global_leader():
+			module = self.module.state_dict()
+
+			# if training lora
+			# this is a separate path to override saving the weights
+			lora = None
+			if cfg.lora is not None:
+				lora, module = lora_get_state_dict( module, split = True )
+				save_dir = cfg.ckpt_dir / cfg.lora.full_name
+
 			save_path = save_dir / tag / "state.pth"
 			save_path.parent.mkdir(parents=True, exist_ok=True)
+
 			torch.save({
-				"module": self.module.state_dict(),
+				"module": module,
+				"lora": lora,
 				"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
 				"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,

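To make the new save path above concrete: `lora_get_state_dict( module, split = True )` (defined in the new `lora.py` below) partitions the state dict by the `lora_` name prefix, so `state.pth` stores the LoRA factors under a separate `"lora"` key. A tiny sketch with a hypothetical state dict:

import torch
from vall_e.models.lora import lora_get_state_dict

state = {
	"proj.weight": torch.zeros(2, 2),   # hypothetical base weight
	"proj.lora_A": torch.zeros(1, 2),   # hypothetical LoRA factors
	"proj.lora_B": torch.zeros(2, 1),
}
lora, module = lora_get_state_dict( state, split = True )
assert set(lora) == { "proj.lora_A", "proj.lora_B" }   # ends up under "lora"
assert set(module) == { "proj.weight" }                # ends up under "module"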
@@ -165,6 +183,20 @@ class Engine():
 		if load_lr_scheduler_states:
 			self.lr_scheduler.load_state_dict(state['lr_scheduler']) #, map_location=torch.device(cfg.device))
+
+		if 'lora' in state:
+			lora_load_state_dict( self.module, state['lora'] )
+
+
+	def load_loras( self ):
+		# apply lora weights
+		for lora in cfg.loras:
+			self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha )
+
+			lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
+			if lora_path.exists():
+				state_dict = torch.load(lora_path, map_location=torch.device(cfg.device))
+				self.module = lora_load_state_dict( self.module, state_dict )

 	def eval(self):
 		return self.module.eval()

@@ -108,6 +108,9 @@ class Engine(DeepSpeedEngine):
 		except Exception as e:
 			print(str(e))

+	def load_loras(self):
+		...
+
 	def traverse(self, *args, **kwargs):
 		with ml.autocast():
 			self.forward(*args, **kwargs)
vall_e/models/lora.py (new file, 122 lines)
@@ -0,0 +1,122 @@
+# Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
+from functools import partial
+import torch
+import torch.nn.functional as F
+import torch.nn.utils.parametrize as parametrize
+
+from torch import Tensor, nn
+
+import math
+from typing import Optional, List
+
+class Linear(nn.Linear):
+	def __init__(
+		self,
+
+		in_features: int,
+		out_features: int,
+		bias: bool = True,
+
+		rank: int = 4,
+		alpha: int = 1,
+
+		dropout: float = 0.1,
+		merge_weights: bool = True,
+		**kwargs,
+	):
+		super().__init__(in_features=in_features, out_features=out_features, bias=bias, **kwargs)
+
+		self.rank = rank
+		self.alpha = alpha
+		self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
+		self.merge_weights = merge_weights
+		self.merged = False
+
+		self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
+		self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) )
+		self.scaling = self.alpha / self.rank
+
+		self.weight.requires_grad = False
+
+		self.reset_parameters()
+
+	def reset_parameters(self):
+		super().reset_parameters()
+		# super silly but necessary because nn.Linear's constructor calls this
+		if hasattr(self, 'lora_A'):
+			nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
+			nn.init.zeros_( self.lora_B )
+
+	def train(self, mode: bool = True):
+		super().train(mode)
+
+		# training, separate lora from base weights
+		if mode and self.merge_weights and self.merged:
+			self.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
+			self.merged = False
+
+		# not training, merge lora to base weights
+		if not mode and self.merge_weights and not self.merged:
+			self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
+			self.merged = True
+
+	def forward(self, x: torch.Tensor):
+		if not self.merged:
+			result = F.linear(x, self.weight, bias=self.bias)
+			result += (self.dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
+			return result
+
+		return F.linear(x, self.weight, bias=self.bias)
+
+	@classmethod
+	def from_linear( cls, layer, **kwargs ):
+		return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs )
+
+# broken, the in_features / out_features change somehow
+def parameterize_model( layer, register = True, merge = False, **kwargs ):
+	if not isinstance( layer, nn.Linear ):
+		return
+
+	if register:
+		parametrize.register_parametrization( layer, "weight", Linear.from_linear( layer, **kwargs ) )
+	else:
+		parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
+
+def apply_lora( model, **kwargs ):
+	klass = Linear
+	target = nn.Linear
+
+	device = next(model.parameters()).device
+	dtype = next(model.parameters()).dtype
+	modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
+
+	for *parent, k in modules:
+		name = '.'.join(parent)
+
+		layer = getattr( model.get_submodule(name), k )
+
+		if isinstance(layer, klass):
+			continue
+
+		injected = klass( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
+		injected.weight = layer.weight
+
+		# overwrite
+		setattr( model.get_submodule(name), k, injected )
+
+	return model
+
+def freeze_non_lora_weights( model ):
+	for name, param in model.named_parameters():
+		param.requires_grad_('lora_' in name)
+
+	return model
+
+def lora_get_state_dict( state_dict, split = True ):
+	lora = { name: param for name, param in state_dict.items() if "lora_" in name }
+	if not split:
+		return lora
+
+	return lora, { name: param for name, param in state_dict.items() if "lora_" not in name }
+
+def lora_load_state_dict( model, state_dict ):
+	return model.load_state_dict( state_dict, strict = False )
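A short, hedged usage sketch tying the new helpers together; the tiny model is hypothetical, whereas in the commit `Engine.load_loras()` applies the same calls to the real modules:

import torch
import torch.nn as nn
from vall_e.models.lora import apply_lora, freeze_non_lora_weights, lora_get_state_dict, lora_load_state_dict

# swap every nn.Linear for the LoRA-augmented Linear, then freeze everything but lora_A / lora_B
model = apply_lora( nn.Sequential( nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4) ), rank = 8, alpha = 1 )
freeze_non_lora_weights( model )

# what Engine.save_checkpoint now does: split the LoRA factors out of the full state dict
lora, module = lora_get_state_dict( model.state_dict(), split = True )
assert all( 'lora_' in k for k in lora ) and not any( 'lora_' in k for k in module )

# what Engine.load_loras does after apply_lora: a non-strict load of just the LoRA weights,
# so the base weights absent from `lora` are merely reported as missing keys
result = lora_load_state_dict( model, lora )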