load exported LoRA weights if exists (to-do: make a better LoRA loading mechanism)
This commit is contained in:
parent 2bfe786ebd
commit 8a986eb480
@@ -333,6 +333,7 @@ class LoRA:
     rank: int = 8 # rank for the LoRA
     alpha: int = 16 # alpha for the LoRA
     training: bool = True #
+    parametrize: bool = False # whether to inject the LoRA via parametrization rather than by replacing modules
     rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
 
     @property
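For reference (outside the diff): rank and alpha together set the strength of the adapter. Under the usual LoRA convention, which is consistent with the zero-initialised lora_B and the self.scaling factor seen later in this commit, the effective weight is the frozen base weight plus a low-rank product scaled by alpha / rank. A minimal sketch with hypothetical shapes:

import torch

out_features, in_features, rank, alpha = 1024, 1024, 8, 16   # illustrative sizes only

W       = torch.randn(out_features, in_features)   # frozen base weight
lora_A  = torch.randn(rank, in_features) * 0.01    # trained down-projection
lora_B  = torch.zeros(out_features, rank)          # trained up-projection, zero-initialised so training starts from W
scaling = alpha / rank                             # assumed convention for self.scaling

W_effective = W + (lora_B @ lora_A) * scaling      # what the adapted layer effectively multiplies by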
@@ -56,9 +56,9 @@ def load_engines(training=True):
             model.model = ml.replace_embedding( model.model )
 
         for lora in cfg.loras:
-            model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy )
+            model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )
 
-        if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
+        if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
             optimizer_class = None
             scheduler_class = None
 
@@ -164,6 +164,14 @@ def load_engines(training=True):
 
         model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
+        # load lora weights if exists
+        if cfg.lora is not None:
+            lora_path = cfg.ckpt_dir / lora.full_name / "lora.pth"
+            if lora_path.exists():
+                state = torch.load(lora_path, map_location=torch.device(cfg.device))
+                state = state['lora' if 'lora' in state else 'module']
+                model.load_state_dict(state, strict=False)
+
         # wrap if DDP is requested
         if ddp:
             model = ddp_model(model)
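For reference (outside the diff): the loader above accepts a lora.pth whose tensors sit under either a 'lora' or a 'module' key and applies them with strict=False so the base weights are left untouched. A minimal sketch of an export step that would produce a compatible file, assuming LoRA parameters can be identified by 'lora_' in their names (an assumption, not something this commit shows):

import torch

def export_lora(model, path):
    # keep only the adapter tensors; the base weights stay in the regular checkpoint
    lora_state = { k: v for k, v in model.state_dict().items() if "lora_" in k }   # assumed naming scheme
    torch.save({ "lora": lora_state }, path)                                       # matches the 'lora' key probed by the loader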
@@ -4,18 +4,17 @@ import torch
 import torch.nn.functional as F
 import torch.nn.utils.parametrize as parametrize
 
+from transformers.pytorch_utils import Conv1D
 
 from torch import Tensor, nn
 
 import math
 from typing import Optional, List
 
-# to-do: set cfg to decide
-USE_PARAMETRIZATION = False
 
 # LoRA Linear for replacement
 # Pros: simple, just needs to reuse the replace_linear and copy weights
 # Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
-class Linear(nn.Linear):
+class LoRALinear(nn.Linear):
     def __init__(
         self,
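For reference (outside the diff): the replacement path swaps each matching nn.Linear for a LoRA-aware subclass that keeps the frozen base weight and adds two trainable low-rank matrices on top of its output. A stripped-down sketch of the idea; this is not the project's actual class, whose full body falls outside this hunk:

import math
import torch
from torch import nn

class TinyLoRALinear(nn.Linear):
    # illustrative only: a Linear whose output gains a scaled low-rank update when enabled
    def __init__(self, in_features, out_features, bias = True, rank = 8, alpha = 16):
        super().__init__(in_features, out_features, bias = bias)
        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        self.scaling = alpha / rank
        self.enabled = True
        nn.init.kaiming_uniform_(self.lora_A, a = math.sqrt(5))
        self.weight.requires_grad_(False)   # base weight stays frozen

    def forward(self, x):
        y = super().forward(x)
        if self.enabled:
            y = y + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
        return y

A from_linear-style constructor would then copy weight and bias from the original layer into the replacement, which matches the "reuse the replace_linear and copy weights" note above and is how apply_lora below reuses existing modules.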
@@ -86,7 +85,7 @@ class Linear(nn.Linear):
 # Uses parametrization to inject LoRA weights
 # Pros: should work with any Linears
 # Cons: TBD
-class ParameterizedLinear(nn.Module):
+class ParameterizedLoRA(nn.Module):
     def __init__(
         self,
 
@@ -119,6 +118,7 @@ class ParameterizedLinear(nn.Module):
         nn.init.zeros_( self.lora_B )
 
     def forward(self, x: torch.Tensor):
+        print( self.enabled, x.shape )
         if self.enabled:
             return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
         return x
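For reference (outside the diff): on the parametrization route nothing is replaced; torch.nn.utils.parametrize intercepts reads of an existing module's weight and feeds the stored tensor through ParameterizedLoRA.forward. That is why the forward above receives a weight tensor as x and returns it with the scaled B @ A delta added (the print added above will therefore log a weight shape, not an activation shape). A self-contained toy illustration of the mechanism; AddLowRank is a stand-in, not the project's class:

import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize

class AddLowRank(nn.Module):
    # receives the original weight and returns it with a low-rank delta added
    def __init__(self, out_features, in_features, rank = 8, alpha = 16):
        super().__init__()
        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
        self.scaling = alpha / rank

    def forward(self, weight):
        return weight + (self.lora_B @ self.lora_A) * self.scaling

layer = nn.Linear(16, 32)                                                  # weight shape: (32, 16)
parametrize.register_parametrization(layer, "weight", AddLowRank(32, 16))
# layer.weight is now recomputed on access; the untouched tensor lives at
# layer.parametrizations.weight.original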
@@ -133,10 +133,21 @@ class ParameterizedLinear(nn.Module):
         # M$'s LoRA class arranges things to where this isn't necessary
         return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
 
+    @classmethod
+    def from_conv1d( cls, layer, device = None, dtype = None, **kwargs ):
+        if device is None:
+            device = layer.weight.device
+        if dtype is None:
+            dtype = layer.weight.dtype
+
+        in_channels, out_channels = layer.weight.shape
+        # swap because we're feeding the output as our input
+        # M$'s LoRA class arranges things to where this isn't necessary
+        return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
 
 def passes_policy( policy, name ):
     if policy is None:
         return True
 
     if "exclude" in policy:
         for term in policy["exclude"]:
             if term in name:
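For reference (outside the diff): the new from_conv1d targets transformers' Conv1D, the GPT-2-style projection layer, whose weight is stored as (in_features, out_features), transposed relative to nn.Linear. That is why layer.weight.shape unpacks directly as (in_channels, out_channels) here, and the return line then applies the same output-as-input swap as from_linear above. A quick check of the layout:

import torch
from torch import nn
from transformers.pytorch_utils import Conv1D

linear = nn.Linear(16, 32)   # weight: (out_features, in_features) == (32, 16)
conv1d = Conv1D(32, 16)      # Conv1D(nf, nx): weight is (nx, nf) == (16, 32), i.e. transposed

print(linear.weight.shape, conv1d.weight.shape)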
@@ -149,33 +160,43 @@ def passes_policy( policy, name ):
 
     return False
 
-def apply_lora( model, register = True, merge = False, policy = None, **kwargs ):
+def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ):
     device = next(model.parameters()).device
     dtype = next(model.parameters()).dtype
 
-    klass = Linear
-    target = nn.Linear
-
-    device = next(model.parameters()).device
-    dtype = next(model.parameters()).dtype
-    modules = [ k.split('.') for k, m in model.named_modules() if isinstance(m, target) and not isinstance(m, klass) and passes_policy( policy, k ) ]
+    modules = [ k.split('.') for k, m in model.named_modules() if passes_policy( policy, k ) ]
 
     for *parent, k in modules:
         name = '.'.join(parent)
         layer = getattr( model.get_submodule(name), k )
 
-        if USE_PARAMETRIZATION:
-            parametrize.register_parametrization( layer, "weight", ParameterizedLinear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
-            # parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
+        if isinstance( layer, nn.Linear ):
+            target = nn.Linear
+            klass = ParameterizedLoRA if use_parametrize else LoRALinear
+            replacer = klass.from_linear
+        elif isinstance( layer, nn.Conv1d ):
+            target = nn.Conv1d
+            klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
+            replacer = klass.from_conv1d
+        elif isinstance( layer, Conv1D ):
+            target = Conv1D
+            klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
+            replacer = klass.from_conv1d
         else:
-            setattr( model.get_submodule(name), k, Linear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
+            continue
+
+        replacement = replacer( layer, device=device, dtype=dtype, **kwargs )
+
+        if use_parametrize:
+            parametrize.register_parametrization( layer, "weight", replacement )
+        else:
+            setattr( model.get_submodule(name), k, replacement )
 
     return model
 
 def enable_lora( model, mode = True ):
     for name, module in model.named_modules():
-        if not isinstance( module, ParameterizedLinear if USE_PARAMETRIZATION else Linear ):
+        if not isinstance( module, ParameterizedLoRA ) and not isinstance( module, LoRALinear ):
             continue
         module.enabled = mode
     return model
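For reference (outside the diff): apply_lora still accepts merge = False but nothing acts on it in this version. The remove_parametrizations call left commented out in the removed branch hints at how merging would work on the parametrized path, since leave_parametrized=True bakes the computed weight back into the module. A self-contained toy illustration; AddDelta is a stand-in, not the project's class:

import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize

class AddDelta(nn.Module):
    # stand-in for a trained LoRA parametrization: returns the weight plus a learned delta
    def __init__(self, shape):
        super().__init__()
        self.delta = nn.Parameter(torch.zeros(shape))
    def forward(self, weight):
        return weight + self.delta

layer = nn.Linear(16, 32)
parametrize.register_parametrization(layer, "weight", AddDelta(layer.weight.shape))

# ... train only the parametrization's parameters ...

# bake the delta into the stored weight and drop the parametrization
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=True)
assert not parametrize.is_parametrized(layer)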