added another LoRA method using parametrization rather than linear injection
This commit is contained in:
parent 45a39fb79f
commit be051d9544
@@ -9,6 +9,12 @@ from torch import Tensor, nn
 import math
 from typing import Optional, List
 
+# to-do: set cfg to decide
+USE_PARAMETRIZATION = False
+
+# LoRA Linear for replacement
+# Pros: simple, just needs to reuse the replace_linear and copy weights
+# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc)
 class Linear(nn.Linear):
 	def __init__(
 		self,
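For reference, the "replacement" route described in the comments above amounts to walking the module tree, building the LoRA Linear for each nn.Linear via from_linear, and copying the frozen base weights across. The repo's actual replace_linear helper is not part of this diff, so the sketch below is a hypothetical stand-in rather than the real implementation:

import torch
from torch import nn

def replace_linear_sketch( model: nn.Module, klass, target = nn.Linear, **kwargs ):
	# hypothetical stand-in for the repo's replace_linear helper:
	# swap every `target` module for the LoRA `klass` and copy the base weights over
	for name, child in model.named_children():
		if isinstance( child, target ) and not isinstance( child, klass ):
			replacement = klass.from_linear( child, **kwargs )
			# keep the wrapper on the same device/dtype as the layer it replaces
			replacement = replacement.to( device=child.weight.device, dtype=child.weight.dtype )
			replacement.weight.data.copy_( child.weight.data )
			if child.bias is not None:
				replacement.bias.data.copy_( child.bias.data )
			setattr( model, name, replacement )
		else:
			replace_linear_sketch( child, klass, target, **kwargs )
	return model

With the low-rank B matrix zero-initialised (as in the parametrized variant added below), the swapped-in layer starts out numerically identical to the original.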
@@ -72,17 +78,68 @@ class Linear(nn.Linear):
 	def from_linear( cls, layer, **kwargs ):
 		return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs )
 
-# broken, the in_features / out_features change somehow
-def parameterize_model( layer, register = True, merge = False, **kwargs ):
+# Uses parametrization to inject LoRA weights
+# Pros: should work with any Linears
+# Cons: TBD
+class ParameterizedLinear(nn.Module):
+	def __init__(
+		self,
+
+		in_features: int,
+		out_features: int,
+		bias: bool = True,
+
+		rank: int = 4,
+		alpha: int = 1,
+
+		dropout: float = 0.1,
+
+		device = None,
+		dtype = None
+	):
+		super().__init__()
+		self.rank = rank
+		self.alpha = alpha
+		self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
+
+		self.lora_A = nn.Parameter( torch.zeros( (rank, in_features) ) ).to( device=device, dtype=dtype )
+		self.lora_B = nn.Parameter( torch.zeros( (out_features, rank) ) ).to( device=device, dtype=dtype )
+		self.scaling = self.alpha / self.rank
+		self.enabled = True
+
+		self.reset_parameters()
+
+	def reset_parameters(self):
+		nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
+		nn.init.zeros_( self.lora_B )
+
+	def forward(self, x: torch.Tensor):
+		if self.enabled:
+			return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
+
+		return x
+
+	@classmethod
+	def from_linear( cls, layer, **kwargs ):
+		# swap because we're feeding the output as our input
+		return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs )
+
+def parametrize_model( layer, register = True, merge = False, **kwargs ):
 	if not isinstance( layer, nn.Linear ):
 		return
 
 	if register:
-		parametrize.register_parametrization( layer, "weight", Linear.from_linear( layer, **kwargs ) )
+		parametrize.register_parametrization( layer, "weight", ParameterizedLinear.from_linear( layer, **kwargs ) )
 	else:
 		parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
 
 def apply_lora( model, **kwargs ):
+	device = next(model.parameters()).device
+	dtype = next(model.parameters()).dtype
+
+	if USE_PARAMETRIZATION:
+		model.apply( partial( parametrize_model, device=device, dtype=dtype, **kwargs ) )
+	else:
 		klass = Linear
 		target = nn.Linear
 
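One detail worth spelling out about the parametrization route: torch.nn.utils.parametrize calls the registered module's forward with the wrapped layer's weight tensor, of shape (out_features, in_features), not with activations, and whatever it returns becomes the weight the layer actually uses. That is presumably what the "swap because we're feeding the output as our input" comment and the .view(x.shape) in ParameterizedLinear.forward are working around. A minimal, self-contained sketch of the mechanism (the names here are illustrative, not part of this commit):

import math
import torch
from torch import nn
import torch.nn.utils.parametrize as parametrize

class LoRADelta(nn.Module):
	# illustrative parametrization: adds a low-rank update onto the weight it is handed
	def __init__( self, out_features, in_features, rank = 4, alpha = 1 ):
		super().__init__()
		self.lora_A = nn.Parameter( torch.zeros( rank, in_features ) )
		self.lora_B = nn.Parameter( torch.zeros( out_features, rank ) )
		self.scaling = alpha / rank
		nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )

	def forward( self, weight ):
		# `weight` is the wrapped layer's weight, shape (out_features, in_features);
		# with lora_B zeroed the delta starts at zero, so behaviour is initially unchanged
		return weight + ( self.lora_B @ self.lora_A ) * self.scaling

layer = nn.Linear( 16, 32 )
parametrize.register_parametrization( layer, "weight", LoRADelta( layer.out_features, layer.in_features ) )
y = layer( torch.randn( 2, 16 ) )  # forward now sees weight + B @ A * scaling

# merging, i.e. what parametrize_model( ..., register=False, merge=True ) above maps to:
# bake the delta into the plain weight and drop the parametrization machinery
parametrize.remove_parametrizations( layer, "weight", leave_parametrized=True )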