load exported LoRA weights if exists (to-do: make a better LoRA loading mechanism)
This commit is contained in:
parent
2bfe786ebd
commit
8a986eb480
|
@ -333,6 +333,7 @@ class LoRA:
|
|||
rank: int = 8 # rank for the LoRA
|
||||
alpha: int = 16 # rank for the LoRA
|
||||
training: bool = True #
|
||||
parametrize: bool = False #
|
||||
rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
|
||||
|
||||
@property
|
||||
|
|
|
@ -56,9 +56,9 @@ def load_engines(training=True):
|
|||
model.model = ml.replace_embedding( model.model )
|
||||
|
||||
for lora in cfg.loras:
|
||||
model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy )
|
||||
model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )
|
||||
|
||||
if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
|
||||
if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
|
||||
optimizer_class = None
|
||||
scheduler_class = None
|
||||
|
||||
|
@ -164,6 +164,14 @@ def load_engines(training=True):
|
|||
|
||||
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
||||
|
||||
# load lora weights if exists
|
||||
if cfg.lora is not None:
|
||||
lora_path = cfg.ckpt_dir / lora.full_name / "lora.pth"
|
||||
if lora_path.exists():
|
||||
state = torch.load(lora_path, map_location=torch.device(cfg.device))
|
||||
state = state['lora' if 'lora' in state else 'module']
|
||||
model.load_state_dict(state, strict=False)
|
||||
|
||||
# wrap if DDP is requested
|
||||
if ddp:
|
||||
model = ddp_model(model)
|
||||
|
|
|
@ -4,18 +4,17 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
import torch.nn.utils.parametrize as parametrize
|
||||
|
||||
from transformers.pytorch_utils import Conv1D
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
import math
|
||||
from typing import Optional, List
|
||||
|
||||
# to-do: set cfg to decide
|
||||
USE_PARAMETRIZATION = False
|
||||
|
||||
# LoRA Linear for replacement
|
||||
# Pros: simple, just needs to reuse the replace_linear and copy weights
|
||||
# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
|
||||
class Linear(nn.Linear):
|
||||
class LoRALinear(nn.Linear):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
|
@ -86,7 +85,7 @@ class Linear(nn.Linear):
|
|||
# Uses parametrization to inject LoRA weights
|
||||
# Pros: should work with any Linears
|
||||
# Cons: TBD
|
||||
class ParameterizedLinear(nn.Module):
|
||||
class ParameterizedLoRA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
|
@ -119,6 +118,7 @@ class ParameterizedLinear(nn.Module):
|
|||
nn.init.zeros_( self.lora_B )
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
print( self.enabled, x.shape )
|
||||
if self.enabled:
|
||||
return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
|
||||
return x
|
||||
|
@ -133,10 +133,21 @@ class ParameterizedLinear(nn.Module):
|
|||
# M$'s LoRA class arranges things to where this isn't necessary
|
||||
return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
|
||||
|
||||
@classmethod
|
||||
def from_conv1d( cls, layer, device = None, dtype = None, **kwargs ):
|
||||
if device is None:
|
||||
device = layer.weight.device
|
||||
if dtype is None:
|
||||
dtype = layer.weight.dtype
|
||||
|
||||
in_channels, out_channels = layer.weight.shape
|
||||
# swap because we're feeding the output as our input
|
||||
# M$'s LoRA class arranges things to where this isn't necessary
|
||||
return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
|
||||
|
||||
def passes_policy( policy, name ):
|
||||
if policy is None:
|
||||
return True
|
||||
|
||||
if "exclude" in policy:
|
||||
for term in policy["exclude"]:
|
||||
if term in name:
|
||||
|
@ -149,33 +160,43 @@ def passes_policy( policy, name ):
|
|||
|
||||
return False
|
||||
|
||||
|
||||
def apply_lora( model, register = True, merge = False, policy = None, **kwargs ):
|
||||
def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ):
|
||||
device = next(model.parameters()).device
|
||||
dtype = next(model.parameters()).dtype
|
||||
|
||||
klass = Linear
|
||||
target = nn.Linear
|
||||
|
||||
device = next(model.parameters()).device
|
||||
dtype = next(model.parameters()).dtype
|
||||
modules = [ k.split('.') for k, m in model.named_modules() if isinstance(m, target) and not isinstance(m, klass) and passes_policy( policy, k ) ]
|
||||
modules = [ k.split('.') for k, m in model.named_modules() if passes_policy( policy, k ) ]
|
||||
|
||||
for *parent, k in modules:
|
||||
name = '.'.join(parent)
|
||||
layer = getattr( model.get_submodule(name), k )
|
||||
|
||||
if USE_PARAMETRIZATION:
|
||||
parametrize.register_parametrization( layer, "weight", ParameterizedLinear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
|
||||
# parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
|
||||
if isinstance( layer, nn.Linear ):
|
||||
target = nn.Linear
|
||||
klass = ParameterizedLoRA if use_parametrize else LoRALinear
|
||||
replacer = klass.from_linear
|
||||
elif isinstance( layer, nn.Conv1d ):
|
||||
target = nn.Conv1d
|
||||
klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
|
||||
replacer = klass.from_conv1d
|
||||
elif isinstance( layer, Conv1D ):
|
||||
target = Conv1D
|
||||
klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
|
||||
replacer = klass.from_conv1d
|
||||
else:
|
||||
setattr( model.get_submodule(name), k, Linear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
|
||||
continue
|
||||
|
||||
replacement = replacer( layer, device=device, dtype=dtype, **kwargs )
|
||||
|
||||
if use_parametrize:
|
||||
parametrize.register_parametrization( layer, "weight", replacement )
|
||||
else:
|
||||
setattr( model.get_submodule(name), k, replacement )
|
||||
|
||||
return model
|
||||
|
||||
def enable_lora( model, mode = True ):
|
||||
for name, module in model.named_modules():
|
||||
if not isinstance( module, ParameterizedLinear if USE_PARAMETRIZATION else Linear ):
|
||||
if not isinstance( module, ParameterizedLoRA ) and not isinstance( module, LoRALinear ):
|
||||
continue
|
||||
module.enabled = mode
|
||||
return model
|
||||
|
|
Loading…
Reference in New Issue
Block a user