# Adapted from https://github.com/microsoft/LoRA/blob/main/loralib/layers.py

import math

from functools import partial
from typing import Optional, List

import torch
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize

from torch import Tensor, nn
from transformers.pytorch_utils import Conv1D

# LoRA Linear for replacement
# Pros: simple, just needs to reuse the replace_linear and copy weights
# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
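#
# The effective weight follows the usual LoRA update:
#   W' = W + (alpha / rank) * (lora_B @ lora_A)
# lora_B starts at zero, so before any training the layer behaves exactly like the frozen base Linear.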
class LoRALinear(nn.Linear):
    def __init__(
        self,

        in_features: int,
        out_features: int,
        bias: bool = True,

        rank: int = 4,
        alpha: int = 1,

        dropout: float = 0.1,
        merge_weights: bool = False,
        **kwargs,
    ):
        super().__init__(in_features=in_features, out_features=out_features, bias=bias, **kwargs)

        self.rank = rank
        self.alpha = alpha
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
        self.merge_weights = merge_weights
        self.merged = False
        self.enabled = True

        # low-rank factors: delta_W = lora_B @ lora_A, scaled by alpha / rank
        self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) )
        self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
        self.scaling = self.alpha / self.rank

        # freeze the base weight; only the LoRA factors are trained
        self.weight.requires_grad = False

        self.reset_parameters()

    def reset_parameters(self):
        super().reset_parameters()
        # super silly but necessary because nn.Linear's constructor calls this
        if hasattr(self, 'lora_A'):
            nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
            nn.init.zeros_( self.lora_B )

    def train(self, mode: bool = True):
        super().train(mode)

        # training: keep the LoRA delta separate from the base weights
        if mode and self.merge_weights and self.merged:
            self.weight.data -= (self.lora_B @ self.lora_A) * self.scaling
            self.merged = False

        # not training: merge the LoRA delta into the base weights
        if not mode and self.merge_weights and not self.merged:
            self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
            self.merged = True

    def forward(self, x: torch.Tensor):
        if not self.merged and self.enabled:
            result = F.linear(x, self.weight, bias=self.bias)
            result += (self.dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
            return result

        return F.linear(x, self.weight, bias=self.bias)

    @classmethod
    def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
        if device is None:
            device = layer.weight.device
        if dtype is None:
            dtype = layer.weight.dtype
        return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
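
# Example usage (a sketch; `layer`, `parent`, and the attribute name are stand-ins for your own modules):
#
#   lora = LoRALinear.from_linear( layer, rank=8, alpha=16 )
#   lora.weight.data.copy_( layer.weight.data )   # from_linear() builds a fresh layer and does not copy the base weight itself
#   if layer.bias is not None:
#       lora.bias.data.copy_( layer.bias.data )
#   setattr( parent, "proj", lora )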
# Uses parametrization to inject LoRA weights
# Pros: should work with any Linears
# Cons: TBD
class ParameterizedLoRA(nn.Module):
    def __init__(
        self,

        in_features: int,
        out_features: int,
        bias: bool = True,

        rank: int = 4,
        alpha: int = 1,

        dropout: float = 0.1,

        device = None,
        dtype = None
    ):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x

        # low-rank factors, created directly on the requested device / dtype
        self.lora_B = nn.Parameter( torch.zeros( (out_features, rank), device=device, dtype=dtype ) )
        self.lora_A = nn.Parameter( torch.zeros( (rank, in_features), device=device, dtype=dtype ) )
        self.scaling = self.alpha / self.rank
        self.enabled = True

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
        nn.init.zeros_( self.lora_B )

    def forward(self, x: torch.Tensor):
        # x is the weight tensor being parametrized, not an activation
        if self.enabled:
            return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
        return x

    @classmethod
    def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
        if device is None:
            device = layer.weight.device
        if dtype is None:
            dtype = layer.weight.dtype
        # swap because we're feeding the output as our input
        # M$'s LoRA class arranges things to where this isn't necessary
        return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)

    @classmethod
    def from_conv1d( cls, layer, device = None, dtype = None, **kwargs ):
        if device is None:
            device = layer.weight.device
        if dtype is None:
            dtype = layer.weight.dtype

        in_channels, out_channels = layer.weight.shape
        # swap because we're feeding the output as our input
        # M$'s LoRA class arranges things to where this isn't necessary
        return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
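
# Example usage (a sketch): attach LoRA to an existing nn.Linear without replacing it.
# torch's parametrize utility routes every access of `layer.weight` through
# ParameterizedLoRA.forward(), and keeps the original weight at
# layer.parametrizations.weight.original:
#
#   lora = ParameterizedLoRA.from_linear( layer, rank=8, alpha=16 )
#   parametrize.register_parametrization( layer, "weight", lora )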
def passes_policy( policy, name ):
    if policy is None:
        return True

    if "exclude" in policy:
        for term in policy["exclude"]:
            if term in name:
                return False

    if "include" in policy:
        for term in policy["include"]:
            if term in name:
                return True

    return False
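
# A policy is a plain dict of substrings matched against module names; "exclude"
# is checked before "include". For example (hypothetical module names):
#
#   policy = {
#       "include": [ "attn", "mlp" ],
#       "exclude": [ "lm_head" ],
#   }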
def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ):
    device = next(model.parameters()).device
    dtype = next(model.parameters()).dtype

    # skip the root module (k == "") and anything the policy filters out
    modules = [ k.split('.') for k, m in model.named_modules() if k and passes_policy( policy, k ) ]

    for *parent, k in modules:
        name = '.'.join(parent)
        layer = getattr( model.get_submodule(name), k )

        if isinstance( layer, nn.Linear ):
            target = nn.Linear
            klass = ParameterizedLoRA if use_parametrize else LoRALinear
            replacer = klass.from_linear
        elif isinstance( layer, nn.Conv1d ):
            # NOTE: LoRAConv1d is not defined in this file; taking the
            # non-parametrized path for convolutions will raise a NameError
            target = nn.Conv1d
            klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
            replacer = klass.from_conv1d
        elif isinstance( layer, Conv1D ):
            target = Conv1D
            klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
            replacer = klass.from_conv1d
        else:
            continue

        replacement = replacer( layer, device=device, dtype=dtype, **kwargs )

        if use_parametrize:
            parametrize.register_parametrization( layer, "weight", replacement )
        else:
            setattr( model.get_submodule(name), k, replacement )

    return model
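
# Example usage (a sketch): wrap every policy-matched layer in the model, forwarding
# rank / alpha / dropout through **kwargs to the LoRA constructors:
#
#   model = apply_lora( model, policy=policy, use_parametrize=True, rank=8, alpha=16 )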
def enable_lora( model, mode = True ):
    for name, module in model.named_modules():
        if not isinstance( module, ( ParameterizedLoRA, LoRALinear ) ):
            continue
        module.enabled = mode
    return model

def disable_lora( model ):
    return enable_lora( model, False )
def freeze_non_lora_weights( model ):
    for name, param in model.named_parameters():
        param.requires_grad_('lora_' in name)
    return model
def lora_get_state_dict( state_dict, split = True ):
    lora = { name: param for name, param in state_dict.items() if "lora_" in name }
    if not split:
        return lora

    return lora, { name: param for name, param in state_dict.items() if "lora_" not in name }
def lora_load_state_dict( model, state_dict ):
    return model.load_state_dict( state_dict, strict = False )
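
# Sketch of the overall training / checkpointing flow built from the helpers above
# (the training loop and file paths are placeholders):
#
#   model = apply_lora( model, policy=policy, use_parametrize=True, rank=8, alpha=16 )
#   model = freeze_non_lora_weights( model )
#
#   ... train ...
#
#   lora_weights, _ = lora_get_state_dict( model.state_dict() )
#   torch.save( lora_weights, "lora.pth" )
#
#   # later: re-apply LoRA to a fresh base model, then restore only the adapters
#   lora_load_state_dict( model, torch.load("lora.pth") )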