load exported LoRA weights if they exist (to-do: make a better LoRA loading mechanism)

This commit is contained in:
mrq 2024-06-18 21:45:46 -05:00
parent 2bfe786ebd
commit 8a986eb480
3 changed files with 51 additions and 21 deletions


@@ -333,6 +333,7 @@ class LoRA:
rank: int = 8 # rank for the LoRA
alpha: int = 16 # alpha for the LoRA
training: bool = True # whether the LoRA is being trained
parametrize: bool = False # whether to inject the LoRA through torch parametrization instead of replacing the module outright
rvq_levels: list[int] = field(default_factory=lambda: []) # determines which RVQ levels to activate the LoRA for
@property
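
As a reference for what the rank/alpha pair above configures, a minimal sketch of the low-rank update in plain tensors; the dimensions and the alpha/rank scaling convention here are standard LoRA, not values taken from this repo's code:

import torch

rank, alpha = 8, 16
scaling = alpha / rank                    # common LoRA convention for scaling the update

d_out, d_in = 1024, 1024                  # hypothetical weight shape
lora_A = torch.randn( rank, d_in ) * 0.01 # down-projection, small random init
lora_B = torch.zeros( d_out, rank )       # up-projection, zero init so the delta starts as a no-op

delta_W = ( lora_B @ lora_A ) * scaling   # low-rank delta added onto the frozen weight
print( delta_W.shape )                    # torch.Size([1024, 1024])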


@@ -56,9 +56,9 @@ def load_engines(training=True):
model.model = ml.replace_embedding( model.model )
for lora in cfg.loras:
model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy )
model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )
if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
optimizer_class = None
scheduler_class = None
@@ -164,6 +164,14 @@ def load_engines(training=True):
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
# load LoRA weights if they exist
if cfg.lora is not None:
lora_path = cfg.ckpt_dir / lora.full_name / "lora.pth"
if lora_path.exists():
state = torch.load(lora_path, map_location=torch.device(cfg.device))
state = state['lora' if 'lora' in state else 'module']
model.load_state_dict(state, strict=False)
# wrap if DDP is requested
if ddp:
model = ddp_model(model)
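
The loader above expects a lora.pth whose tensors sit under a 'lora' (or 'module') key and applies them with strict=False. The export side isn't part of this commit, so the following is only a hedged sketch of what could produce such a file, assuming the LoRA tensors are identifiable by a "lora_" substring in their parameter names:

import torch

def export_lora( model, path ):
    # keep only the LoRA tensors; the frozen base weights stay in the normal checkpoint
    lora_state = { k: v for k, v in model.state_dict().items() if "lora_" in k }
    torch.save( { "lora": lora_state }, path )

# e.g. export_lora( model, cfg.ckpt_dir / lora.full_name / "lora.pth" )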


@@ -4,18 +4,17 @@ import torch
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from transformers.pytorch_utils import Conv1D
from torch import Tensor, nn
import math
from typing import Optional, List
# to-do: set cfg to decide
USE_PARAMETRIZATION = False
# LoRA Linear for replacement
# Pros: simple, just needs to reuse the replace_linear and copy weights
# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
class Linear(nn.Linear):
class LoRALinear(nn.Linear):
def __init__(
self,
@@ -86,7 +85,7 @@ class Linear(nn.Linear):
# Uses parametrization to inject LoRA weights
# Pros: should work with any Linears
# Cons: TBD
class ParameterizedLinear(nn.Module):
class ParameterizedLoRA(nn.Module):
def __init__(
self,
@@ -119,6 +118,7 @@ class ParameterizedLinear(nn.Module):
nn.init.zeros_( self.lora_B )
def forward(self, x: torch.Tensor):
if self.enabled:
return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
return x
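
For context on the parametrization path, a self-contained sketch of the same idea applied to a plain nn.Linear: the registered module receives the frozen weight and returns it plus the scaled low-rank delta. The class name, init values, and dimensions are illustrative, not the repo's:

import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize

class LoRAParametrization(nn.Module):
    def __init__( self, out_features, in_features, rank = 8, alpha = 16 ):
        super().__init__()
        self.lora_A = nn.Parameter( torch.randn( rank, in_features ) * 0.01 )
        self.lora_B = nn.Parameter( torch.zeros( out_features, rank ) )
        self.scaling = alpha / rank
        self.enabled = True

    def forward( self, weight ):
        # called by parametrize with the original weight; returns the effective weight
        if self.enabled:
            return weight + ( self.lora_B @ self.lora_A ) * self.scaling
        return weight

layer = nn.Linear( 16, 32 )
parametrize.register_parametrization( layer, "weight", LoRAParametrization( 32, 16 ) )
y = layer( torch.randn( 4, 16 ) )  # forward now sees W + B@A * scaling transparently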
@@ -133,10 +133,21 @@ class ParameterizedLinear(nn.Module):
# M$'s LoRA class arranges things to where this isn't necessary
return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
@classmethod
def from_conv1d( cls, layer, device = None, dtype = None, **kwargs ):
if device is None:
device = layer.weight.device
if dtype is None:
dtype = layer.weight.dtype
in_channels, out_channels = layer.weight.shape
# swap because we're feeding the output as our input
# M$'s LoRA class arranges things to where this isn't necessary
return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
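
The in/out swap in from_conv1d comes from transformers' Conv1D storing its weight transposed relative to nn.Linear; a quick sketch of the two conventions (nothing repo-specific assumed):

from torch import nn
from transformers.pytorch_utils import Conv1D

linear = nn.Linear( in_features = 4, out_features = 8 )
conv1d = Conv1D( nf = 8, nx = 4 )     # nf = output width, nx = input width

print( linear.weight.shape )          # torch.Size([8, 4])  -> (out, in)
print( conv1d.weight.shape )          # torch.Size([4, 8])  -> (in, out)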
def passes_policy( policy, name ):
if policy is None:
return True
if "exclude" in policy:
for term in policy["exclude"]:
if term in name:
@@ -149,33 +160,43 @@ def passes_policy( policy, name ):
return False
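
A hedged example of the policy shape the exclude handling above consumes; the key mirrors the code, the substrings themselves are made up:

# any module whose dotted name contains one of these substrings gets filtered out
lora_policy = {
    "exclude": [ "classifier", "embedding" ],   # hypothetical terms, not the repo's actual policy
}
# "model.classifier.proj" contains "classifier", so it would be filtered out;
# "model.layers.0.attn.qkv" matches neither term, so it stays eligible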
def apply_lora( model, register = True, merge = False, policy = None, **kwargs ):
def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
klass = Linear
target = nn.Linear
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [ k.split('.') for k, m in model.named_modules() if isinstance(m, target) and not isinstance(m, klass) and passes_policy( policy, k ) ]
modules = [ k.split('.') for k, m in model.named_modules() if passes_policy( policy, k ) ]
for *parent, k in modules:
name = '.'.join(parent)
layer = getattr( model.get_submodule(name), k )
if USE_PARAMETRIZATION:
parametrize.register_parametrization( layer, "weight", ParameterizedLinear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
# parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
if isinstance( layer, nn.Linear ):
target = nn.Linear
klass = ParameterizedLoRA if use_parametrize else LoRALinear
replacer = klass.from_linear
elif isinstance( layer, nn.Conv1d ):
target = nn.Conv1d
klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
replacer = klass.from_conv1d
elif isinstance( layer, Conv1D ):
target = Conv1D
klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
replacer = klass.from_conv1d
else:
setattr( model.get_submodule(name), k, Linear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
continue
replacement = replacer( layer, device=device, dtype=dtype, **kwargs )
if use_parametrize:
parametrize.register_parametrization( layer, "weight", replacement )
else:
setattr( model.get_submodule(name), k, replacement )
return model
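
The commented-out remove_parametrizations call in the removed branch above hints at how a parametrized LoRA could later be merged into the base weights; a sketch of that merge step using the stock torch API (not code from this commit):

import torch.nn.utils.parametrize as parametrize

def merge_lora_parametrizations( model ):
    # bake W + B@A * scaling into the plain weight tensor and drop the parametrization modules
    for module in list( model.modules() ):
        if parametrize.is_parametrized( module, "weight" ):
            parametrize.remove_parametrizations( module, "weight", leave_parametrized = True )
    return model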
def enable_lora( model, mode = True ):
for name, module in model.named_modules():
if not isinstance( module, ParameterizedLinear if USE_PARAMETRIZATION else Linear ):
if not isinstance( module, ParameterizedLoRA ) and not isinstance( module, LoRALinear ):
continue
module.enabled = mode
return model
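
A short usage sketch for the toggle above, assuming model has already been passed through apply_lora and x is a suitable input batch:

enable_lora( model, False )   # forward passes use only the frozen base weights
base_out = model( x )

enable_lora( model, True )    # forward passes add the low-rank delta back in
lora_out = model( x )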