added LoRA policy to decide which layers of the model get adapted based on simple inclusion/exclusion terms

parent be051d9544
commit bd0bc10ec0
@@ -314,17 +314,30 @@ class Model:
     def gradient_checkpointing(self):
         return cfg.trainer.gradient_checkpointing

+    @property
+    def lora_policy(self):
+        include = ["model"] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever)
+        exclude = []
+
+        if self.arch_type == "llama":
+            include = ["self_attn", "mlp"] # target only the attention + mlp
+            exclude = ["self_attn.k_proj"] # common literature says to ignore it
+
+        return dict(include=include, exclude=exclude)
+
 @dataclass()
 class LoRA:
     name: str = "lora" # vanity name
     # to-do: find sane default values
     rank: int = 8 # rank for the LoRA
-    alpha: int = 1 # rank for the LoRA
+    alpha: int = 16 # alpha scaling factor for the LoRA
     training: bool = True #

     @property
     def full_name(self):
         name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
         return "-".join(name)


 @dataclass()
 class Hyperparameters:

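Taken together, the additions above amount to a policy dict keyed on module-name substrings plus a checkpoint folder name derived from the rank and alpha. A minimal, self-contained sketch of what they evaluate to (the real Model/LoRA classes live in the project's config, so this is illustrative only):

    from dataclasses import dataclass

    @dataclass
    class LoRA:
        name: str = "lora"
        rank: int = 8
        alpha: int = 16

        @property
        def full_name(self):
            # mirrors the property in the diff: "<name>-r<rank>-a<alpha>"
            return "-".join([ self.name, f"r{self.rank}", f"a{self.alpha}" ])

    # what Model.lora_policy returns when arch_type == "llama"
    llama_policy = dict(include=["self_attn", "mlp"], exclude=["self_attn.k_proj"])

    print(LoRA().full_name)  # "lora-r8-a16" -> subfolder under cfg.ckpt_dir used by the engine below
    print(llama_policy)
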
@@ -157,6 +157,10 @@ class Engine():
         torch.distributed.barrier()

     def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False):
+        # override to load the lora instead
+        if cfg.lora is not None:
+            load_dir = cfg.ckpt_dir / cfg.lora.full_name
+
         if tag is None:
             tag_path = load_dir / "latest"
             if not tag_path.exists():

@@ -190,7 +194,7 @@ class Engine():
     def load_loras( self ):
         # apply lora weights
         for lora in cfg.loras:
-            self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha )
+            self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha, policy = self.hyper_config.lora_policy )

             lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
             if lora_path.exists():

@@ -327,6 +331,10 @@ class Engines(dict[str, Engine]):
            engine.dispatch_attribute(*args, **kwargs)

     def export(self, userdata={}, callback=None):
+        # to-do: lora exporting
+        if cfg.lora is not None:
+            return
+
         for name, engine in self.items():
             outpath = cfg.ckpt_dir / name / "fp32.pth"
             state_dict = {

@@ -38,8 +38,8 @@ class Linear(nn.Linear):
         self.merge_weights = merge_weights
         self.merged = False

-        self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
         self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) )
+        self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
         self.scaling = self.alpha / self.rank

         self.weight.requires_grad = False

@@ -75,8 +75,12 @@ class Linear(nn.Linear):
         return F.linear(x, self.weight, bias=self.bias)

     @classmethod
-    def from_linear( cls, layer, **kwargs ):
-        return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs )
+    def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
+        if device is None:
+            device = layer.weight.device
+        if dtype is None:
+            dtype = layer.weight.dtype
+        return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)

 # Uses parametrization to inject LoRA weights
 # Pros: should work with any Linears

@@ -102,8 +106,8 @@ class ParameterizedLinear(nn.Module):
         self.alpha = alpha
         self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x

-        self.lora_A = nn.Parameter( torch.zeros( (rank, in_features) ) ).to( device=device, dtype=dtype )
         self.lora_B = nn.Parameter( torch.zeros( (out_features, rank) ) ).to( device=device, dtype=dtype )
+        self.lora_A = nn.Parameter( torch.zeros( (rank, in_features) ) ).to( device=device, dtype=dtype )
         self.scaling = self.alpha / self.rank
         self.enabled = True

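Both the plain Linear and the parametrized variant keep lora_A of shape (rank, in_features), lora_B of shape (out_features, rank), and scaling = alpha / rank. The forward passes are outside the hunks shown here, but these shapes imply the standard LoRA update W x + scaling * B A x; a standalone sketch of that math (shapes copied from the diff, values arbitrary):

    import torch
    import torch.nn.functional as F

    in_features, out_features, rank, alpha = 16, 32, 8, 16
    weight = torch.randn(out_features, in_features)   # stands in for the frozen base weight
    lora_A = torch.zeros(rank, in_features)           # same shapes as in the diff
    lora_B = torch.zeros(out_features, rank)
    scaling = alpha / rank                            # = LoRA.alpha / LoRA.rank

    x = torch.randn(4, in_features)
    base = F.linear(x, weight)
    adapted = base + (x @ lora_A.T @ lora_B.T) * scaling

    # both adapter matrices start zero-initialised here, so the LoRA path is a no-op at first
    assert torch.allclose(base, adapted)
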
@@ -120,46 +124,52 @@ class ParameterizedLinear(nn.Module):
         return x

     @classmethod
-    def from_linear( cls, layer, **kwargs ):
+    def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
+        if device is None:
+            device = layer.weight.device
+        if dtype is None:
+            dtype = layer.weight.dtype
         # swap because we're feeding the output as our input
-        return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs )
+        # M$'s LoRA class arranges things to where this isn't necessary
+        return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)

-def parametrize_model( layer, register = True, merge = False, **kwargs ):
-    if not isinstance( layer, nn.Linear ):
-        return
+def passes_policy( policy, name ):
+    if policy is None:
+        return True

-    if register:
-        parametrize.register_parametrization( layer, "weight", ParameterizedLinear.from_linear( layer, **kwargs ) )
-    else:
-        parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
+    if "exclude" in policy:
+        for term in policy["exclude"]:
+            if term in name:
+                return False

-def apply_lora( model, **kwargs ):
+    if "include" in policy:
+        for term in policy["include"]:
+            if term in name:
+                return True
+
+    return False
+
+
+def apply_lora( model, register = True, merge = False, policy = None, **kwargs ):
     device = next(model.parameters()).device
     dtype = next(model.parameters()).dtype

-    if USE_PARAMETRIZATION:
-        model.apply( partial( parametrize_model, device=device, dtype=dtype, **kwargs ) )
-    else:
-        klass = Linear
-        target = nn.Linear
+    klass = Linear
+    target = nn.Linear

-        device = next(model.parameters()).device
-        dtype = next(model.parameters()).dtype
-        modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
+    device = next(model.parameters()).device
+    dtype = next(model.parameters()).dtype
+    modules = [ k.split('.') for k, m in model.named_modules() if isinstance(m, target) and not isinstance(m, klass) and passes_policy( policy, k ) ]

-        for *parent, k in modules:
-            name = '.'.join(parent)
+    for *parent, k in modules:
+        name = '.'.join(parent)
+        layer = getattr( model.get_submodule(name), k )

-            layer = getattr( model.get_submodule(name), k )
-
-            if isinstance(layer, klass):
-                continue
-
-            injected = klass( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
-            injected.weight = layer.weight
-
-            # overwrite
-            setattr( model.get_submodule(name), k, injected )
+        if USE_PARAMETRIZATION:
+            parametrize.register_parametrization( layer, "weight", ParameterizedLinear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
+            # parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
+        else:
+            setattr( model.get_submodule(name), k, Linear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )

     return model

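To make the "simple inclusion/exclusion terms" from the commit message concrete: exclusion terms are checked first and win, and a name must then contain at least one inclusion term to get an adapter. A self-contained sketch using passes_policy as added above (the module names are illustrative llama-style names, not taken from the diff):

    def passes_policy( policy, name ):
        if policy is None:
            return True

        if "exclude" in policy:
            for term in policy["exclude"]:
                if term in name:
                    return False

        if "include" in policy:
            for term in policy["include"]:
                if term in name:
                    return True

        return False

    policy = dict(include=["self_attn", "mlp"], exclude=["self_attn.k_proj"])

    print(passes_policy(policy, "model.layers.0.self_attn.q_proj"))  # True  -> gets a LoRA
    print(passes_policy(policy, "model.layers.0.self_attn.k_proj"))  # False -> excluded
    print(passes_policy(policy, "model.embed_tokens"))               # False -> not included

In the engine, this is the policy that load_loras() now forwards via apply_lora( self.module, rank = lora.rank, alpha = lora.alpha, policy = self.hyper_config.lora_policy ).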