added other LoRA method using parametrization rather than linear injection

mrq 2024-06-17 09:58:34 -05:00
parent 45a39fb79f
commit be051d9544


@@ -9,6 +9,12 @@ from torch import Tensor, nn
import math
from typing import Optional, List

# to-do: set cfg to decide
USE_PARAMETRIZATION = False

# LoRA Linear for replacement
# Pros: simple, just needs to reuse the replace_linear and copy weights
# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc)
class Linear(nn.Linear):
    def __init__(
        self,

@@ -72,37 +78,88 @@ class Linear(nn.Linear):
    def from_linear( cls, layer, **kwargs ):
        return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs )
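
# A hedged sketch of what an injection-style LoRA Linear like the class above is
# expected to compute (its body is elided in this hunk, so this follows the usual
# LoRA formulation); the toy sizes and names here are illustrative, not from this commit.
def _lora_linear_sketch():
    x = torch.randn( 2, 16 )
    W = torch.randn( 32, 16 )   # frozen base weight, copied from the replaced nn.Linear
    A = torch.zeros( 4, 16 )    # rank x in_features, kaiming-initialized in practice
    B = torch.zeros( 32, 4 )    # out_features x rank, zero-initialized so training starts at the base model
    alpha, rank = 1, 4
    # y = x W^T + (x A^T B^T) * (alpha / rank)
    return x @ W.t() + ( x @ A.t() @ B.t() ) * ( alpha / rank )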

# Uses parametrization to inject LoRA weights
# Pros: should work with any Linears
# Cons: TBD
class ParameterizedLinear(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        rank: int = 4,
        alpha: int = 1,
        dropout: float = 0.1,
        device = None,
        dtype = None
    ):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
        self.lora_A = nn.Parameter( torch.zeros( (rank, in_features) ) ).to( device=device, dtype=dtype )
        self.lora_B = nn.Parameter( torch.zeros( (out_features, rank) ) ).to( device=device, dtype=dtype )
        self.scaling = self.alpha / self.rank
        self.enabled = True

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
        nn.init.zeros_( self.lora_B )

    def forward(self, x: torch.Tensor):
        if self.enabled:
            return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
        return x

    @classmethod
    def from_linear( cls, layer, **kwargs ):
        # swap because we're feeding the output as our input
        return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs )

def parametrize_model( layer, register = True, merge = False, **kwargs ):
    if not isinstance( layer, nn.Linear ):
        return

    if register:
        parametrize.register_parametrization( layer, "weight", ParameterizedLinear.from_linear( layer, **kwargs ) )
    else:
        parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
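
# A hedged sketch of how the parametrization path behaves: after registering, each
# access to `layer.weight` returns the parametrization applied to the stored original
# weight, and removing with leave_parametrized=True bakes ("merges") the LoRA delta
# into a plain weight tensor. Toy sizes here are illustrative, not from this commit.
def _parametrization_sketch():
    layer = nn.Linear( 16, 32 )
    parametrize.register_parametrization( layer, "weight", ParameterizedLinear.from_linear( layer, rank=4, alpha=16 ) )
    y = layer( torch.randn( 2, 16 ) )  # forward now sees W + (lora_B @ lora_A).view(W.shape) * (alpha / rank)
    parametrize.remove_parametrizations( layer, "weight", leave_parametrized=True )  # the merge path
    return y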

def apply_lora( model, **kwargs ):
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    if USE_PARAMETRIZATION:
        model.apply( partial( parametrize_model, device=device, dtype=dtype, **kwargs ) )
    else:
        klass = Linear
        target = nn.Linear

        modules = [ k.split('.') for k, m in model.named_modules() if isinstance(m, target) ]

        for *parent, k in modules:
            name = '.'.join(parent)

            layer = getattr( model.get_submodule(name), k )

            if isinstance(layer, klass):
                continue

            injected = klass( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
            injected.weight = layer.weight

            # overwrite
            setattr( model.get_submodule(name), k, injected )

    return model
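
A minimal usage sketch of apply_lora on the default (injection) path, assuming the definitions above are in scope; the toy model, sizes, and printed values are illustrative, not part of this commit:

import torch
from torch import nn

model = nn.Sequential( nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8) )
model = apply_lora( model )        # USE_PARAMETRIZATION = False: each nn.Linear is swapped for the LoRA Linear above

x = torch.randn( 2, 16 )
print( model(x).shape )            # torch.Size([2, 8]); shapes are preserved and the base weights are copied over
print( type( model[0] ).__name__ ) # "Linear", i.e. the LoRA subclass defined in this file

With USE_PARAMETRIZATION = True the same call instead routes through model.apply( partial( parametrize_model, ... ) ), leaving the original nn.Linear modules in place and wrapping only their weight access.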