added LoRA policy to decide which layers of the model get adapted, based on simple inclusion/exclusion terms

mrq 2024-06-17 13:05:06 -05:00
parent be051d9544
commit bd0bc10ec0
3 changed files with 67 additions and 36 deletions

View File

@@ -314,17 +314,30 @@ class Model:
 	def gradient_checkpointing(self):
 		return cfg.trainer.gradient_checkpointing
 
+	@property
+	def lora_policy(self):
+		include = ["model"] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever)
+		exclude = []
+
+		if self.arch_type == "llama":
+			include = ["self_attn", "mlp"] # target only the attention + mlp
+			exclude = ["self_attn.k_proj"] # common literature says to ignore it
+
+		return dict(include=include, exclude=exclude)
+
 @dataclass()
 class LoRA:
 	name: str = "lora" # vanity name
 	# to-do: find sane default values
 	rank: int = 8 # rank for the LoRA
-	alpha: int = 1 # rank for the LoRA
+	alpha: int = 16 # alpha for the LoRA (the update is scaled by alpha / rank)
 	training: bool = True #
 
 	@property
 	def full_name(self):
 		name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
 		return "-".join(name)
 
 @dataclass()
 class Hyperparameters:
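One note on the alpha default change above: in the LoRA code later in this commit the low-rank update is scaled by alpha / rank (self.scaling = self.alpha / self.rank), so at the default rank = 8 the old alpha = 1 damped the adapter's contribution to 0.125x while the new alpha = 16 scales it to 2x. A tiny standalone sketch of that relationship (plain Python, not the project's config code):

    # mirrors `self.scaling = self.alpha / self.rank` from lora.py further down
    def lora_scaling(rank: int, alpha: int) -> float:
        return alpha / rank

    print(lora_scaling(8, 1))   # 0.125 -- old default, adapter update heavily damped
    print(lora_scaling(8, 16))  # 2.0   -- new default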

View File

@@ -157,6 +157,10 @@ class Engine():
 			torch.distributed.barrier()
 
 	def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False):
+		# override to load the lora instead
+		if cfg.lora is not None:
+			load_dir = cfg.ckpt_dir / cfg.lora.full_name
+
 		if tag is None:
 			tag_path = load_dir / "latest"
 			if not tag_path.exists():
@@ -190,7 +194,7 @@ class Engine():
 	def load_loras( self ):
 		# apply lora weights
 		for lora in cfg.loras:
-			self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha )
+			self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha, policy = self.hyper_config.lora_policy )
 
 			lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
 			if lora_path.exists():
@@ -327,6 +331,10 @@ class Engines(dict[str, Engine]):
 			engine.dispatch_attribute(*args, **kwargs)
 
 	def export(self, userdata={}, callback=None):
+		# to-do: lora exporting
+		if cfg.lora is not None:
+			return
+
 		for name, engine in self.items():
 			outpath = cfg.ckpt_dir / name / "fp32.pth"
 			state_dict = {
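For reference, full_name (defined in the config change above) is just the LoRA's name, rank, and alpha joined with dashes, and it is what both the load_checkpoint override and load_loras use to locate adapter weights. A small standalone sketch with the new defaults; the ./ckpt base directory is a placeholder, not the project's actual cfg.ckpt_dir:

    from pathlib import Path

    ckpt_dir = Path("./ckpt")  # placeholder for cfg.ckpt_dir, which comes from the YAML config

    name, rank, alpha = "lora", 8, 16
    full_name = "-".join([name, f"r{rank}", f"a{alpha}"])  # -> "lora-r8-a16"

    load_dir = ckpt_dir / full_name                 # where load_checkpoint() now looks for "latest"
    lora_path = ckpt_dir / full_name / "fp32.pth"   # what load_loras() tries to load
    print(full_name, load_dir, lora_path)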

View File

@@ -38,8 +38,8 @@ class Linear(nn.Linear):
 		self.merge_weights = merge_weights
 		self.merged = False
 
-		self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
 		self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) )
+		self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
 		self.scaling = self.alpha / self.rank
 
 		self.weight.requires_grad = False
@@ -75,8 +75,12 @@ class Linear(nn.Linear):
 		return F.linear(x, self.weight, bias=self.bias)
 
 	@classmethod
-	def from_linear( cls, layer, **kwargs ):
-		return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs )
+	def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
+		if device is None:
+			device = layer.weight.device
+		if dtype is None:
+			dtype = layer.weight.dtype
+		return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
 
 # Uses parametrization to inject LoRA weights
 # Pros: should work with any Linears
@@ -102,8 +106,8 @@ class ParameterizedLinear(nn.Module):
 		self.alpha = alpha
 		self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
 
-		self.lora_A = nn.Parameter( torch.zeros( (rank, in_features) ) ).to( device=device, dtype=dtype )
 		self.lora_B = nn.Parameter( torch.zeros( (out_features, rank) ) ).to( device=device, dtype=dtype )
+		self.lora_A = nn.Parameter( torch.zeros( (rank, in_features) ) ).to( device=device, dtype=dtype )
 		self.scaling = self.alpha / self.rank
 
 		self.enabled = True
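The parametrization route described by the comments above ("should work with any Linears") relies on torch.nn.utils.parametrize: the registered module receives the wrapped weight tensor in its forward and returns the adapted tensor, which is also why ParameterizedLinear.from_linear swaps in_features and out_features. Below is a minimal self-contained sketch of that mechanism using stock PyTorch only; it is a simplified stand-in for the project's class (no dropout, no enabled toggle, zero-initialized like the code above):

    import torch
    from torch import nn
    import torch.nn.utils.parametrize as parametrize

    class LoRAWeight(nn.Module):
        # simplified LoRA parametrization: returns W + (alpha / rank) * (B @ A)
        def __init__(self, out_features, in_features, rank=8, alpha=16):
            super().__init__()
            self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
            self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
            self.scaling = alpha / rank

        def forward(self, weight):
            # weight is the wrapped (out_features, in_features) matrix of the Linear
            return weight + (self.lora_B @ self.lora_A) * self.scaling

    layer = nn.Linear(16, 32)
    layer.weight.requires_grad_(False)  # freeze the base weight; only the adapter trains
    parametrize.register_parametrization(layer, "weight", LoRAWeight(32, 16))

    # every read of layer.weight (e.g. inside F.linear) now passes through LoRAWeight.forward
    print(layer.weight.shape)  # torch.Size([32, 16])

The non-parametrized path (with USE_PARAMETRIZATION disabled) instead swaps the whole nn.Linear for the Linear subclass above, which carries its own lora_A/lora_B.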
@@ -120,46 +124,52 @@ class ParameterizedLinear(nn.Module):
 		return x
 
 	@classmethod
-	def from_linear( cls, layer, **kwargs ):
+	def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
+		if device is None:
+			device = layer.weight.device
+		if dtype is None:
+			dtype = layer.weight.dtype
 		# swap because we're feeding the output as our input
-		return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs )
+		# M$'s LoRA class arranges things to where this isn't necessary
+		return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
 
-def parametrize_model( layer, register = True, merge = False, **kwargs ):
-	if not isinstance( layer, nn.Linear ):
-		return
+def passes_policy( policy, name ):
+	if policy is None:
+		return True
 
-	if register:
-		parametrize.register_parametrization( layer, "weight", ParameterizedLinear.from_linear( layer, **kwargs ) )
-	else:
-		parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
+	if "exclude" in policy:
+		for term in policy["exclude"]:
+			if term in name:
+				return False
 
-def apply_lora( model, **kwargs ):
+	if "include" in policy:
+		for term in policy["include"]:
+			if term in name:
+				return True
+
+	return False
+
+def apply_lora( model, register = True, merge = False, policy = None, **kwargs ):
 	device = next(model.parameters()).device
 	dtype = next(model.parameters()).dtype
 
-	if USE_PARAMETRIZATION:
-		model.apply( partial( parametrize_model, device=device, dtype=dtype, **kwargs ) )
-	else:
-		klass = Linear
-		target = nn.Linear
+	klass = Linear
+	target = nn.Linear
 
-		device = next(model.parameters()).device
-		dtype = next(model.parameters()).dtype
-		modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
+	device = next(model.parameters()).device
+	dtype = next(model.parameters()).dtype
+	modules = [ k.split('.') for k, m in model.named_modules() if isinstance(m, target) and not isinstance(m, klass) and passes_policy( policy, k ) ]
 
-		for *parent, k in modules:
-			name = '.'.join(parent)
+	for *parent, k in modules:
+		name = '.'.join(parent)
 
-			layer = getattr( model.get_submodule(name), k )
+		layer = getattr( model.get_submodule(name), k )
 
-			if isinstance(layer, klass):
-				continue
-			injected = klass( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
-			injected.weight = layer.weight
-			# overwrite
-			setattr( model.get_submodule(name), k, injected )
+		if USE_PARAMETRIZATION:
+			parametrize.register_parametrization( layer, "weight", ParameterizedLinear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
+			# parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
+		else:
+			setattr( model.get_submodule(name), k, Linear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
 
 	return model
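To make the effect of the new policy argument concrete, here is a self-contained sketch of just the selection step: it restates the passes_policy logic from this commit and applies the llama policy from the config change to a toy module tree with llama-like names (the toy model and its dimensions are invented for illustration; the real apply_lora then wraps each selected nn.Linear via parametrization or Linear.from_linear):

    from torch import nn

    # same substring logic as passes_policy() above
    def passes_policy(policy, name):
        if policy is None:
            return True
        if "exclude" in policy:
            for term in policy["exclude"]:
                if term in name:
                    return False
        if "include" in policy:
            for term in policy["include"]:
                if term in name:
                    return True
        return False

    # toy, llama-shaped module tree: only the names matter here, sizes are arbitrary
    class ToyBlock(nn.Module):
        def __init__(self, dim=32):
            super().__init__()
            self.self_attn = nn.ModuleDict(dict(
                q_proj=nn.Linear(dim, dim), k_proj=nn.Linear(dim, dim),
                v_proj=nn.Linear(dim, dim), o_proj=nn.Linear(dim, dim),
            ))
            self.mlp = nn.ModuleDict(dict(
                up_proj=nn.Linear(dim, 4 * dim), down_proj=nn.Linear(4 * dim, dim),
            ))
            self.lm_head = nn.Linear(dim, 256)

    model = ToyBlock()
    policy = dict(include=["self_attn", "mlp"], exclude=["self_attn.k_proj"])  # llama policy from the config

    selected = [k for k, m in model.named_modules() if isinstance(m, nn.Linear) and passes_policy(policy, k)]
    print(selected)
    # ['self_attn.q_proj', 'self_attn.v_proj', 'self_attn.o_proj', 'mlp.up_proj', 'mlp.down_proj']
    # k_proj is excluded, and lm_head matches no include term, so neither gets a LoRA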