diff --git a/vall_e/config.py b/vall_e/config.py
index ceb5278..990bdc4 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -314,17 +314,30 @@ class Model:
 	def gradient_checkpointing(self):
 		return cfg.trainer.gradient_checkpointing
 
+	@property
+	def lora_policy(self):
+		include = ["model"] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever)
+		exclude = []
+
+		if self.arch_type == "llama":
+			include = ["self_attn", "mlp"] # target only the attention + mlp
+			exclude = ["self_attn.k_proj"] # common literature says to ignore it
+
+		return dict(include=include, exclude=exclude)
+
 @dataclass()
 class LoRA:
 	name: str = "lora" # vanity name
+	# to-do: find sane default values
 	rank: int = 8 # rank for the LoRA
-	alpha: int = 1 # rank for the LoRA
+	alpha: int = 16 # alpha (scaling factor) for the LoRA
 	training: bool = True # 
 
 	@property
 	def full_name(self):
 		name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
 		return "-".join(name)
+
 
 @dataclass()
 class Hyperparameters:
diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py
index 306fa5c..f55a642 100755
--- a/vall_e/engines/base.py
+++ b/vall_e/engines/base.py
@@ -157,6 +157,10 @@ class Engine():
 		torch.distributed.barrier()
 
 	def load_checkpoint(self, load_dir, tag=None, load_module_strict=True, load_optimizer_states=True, load_lr_scheduler_states=True, load_module_only=False):
+		# override to load the lora instead
+		if cfg.lora is not None:
+			load_dir = cfg.ckpt_dir / cfg.lora.full_name
+
 		if tag is None:
 			tag_path = load_dir / "latest"
 			if not tag_path.exists():
@@ -190,7 +194,7 @@ class Engine():
 	def load_loras( self ):
 		# apply lora weights
 		for lora in cfg.loras:
-			self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha )
+			self.module = apply_lora( self.module, rank = lora.rank, alpha = lora.alpha, policy = self.hyper_config.lora_policy )
 
 			lora_path = cfg.ckpt_dir / lora.full_name / "fp32.pth"
 			if lora_path.exists():
@@ -327,6 +331,10 @@ class Engines(dict[str, Engine]):
 			engine.dispatch_attribute(*args, **kwargs)
 
 	def export(self, userdata={}, callback=None):
+		# to-do: lora exporting
+		if cfg.lora is not None:
+			return
+
 		for name, engine in self.items():
 			outpath = cfg.ckpt_dir / name / "fp32.pth"
 			state_dict = {
diff --git a/vall_e/models/lora.py b/vall_e/models/lora.py
index d63a9a9..926c3a8 100644
--- a/vall_e/models/lora.py
+++ b/vall_e/models/lora.py
@@ -38,8 +38,8 @@ class Linear(nn.Linear):
 		self.merge_weights = merge_weights
 		self.merged = False
 
-		self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
 		self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) )
+		self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
 		self.scaling = self.alpha / self.rank
 
 		self.weight.requires_grad = False
@@ -75,8 +75,12 @@ class Linear(nn.Linear):
 		return F.linear(x, self.weight, bias=self.bias)
 
 	@classmethod
-	def from_linear( cls, layer, **kwargs ):
-		return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs )
+	def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
+		if device is None:
+			device = layer.weight.device
+		if dtype is None:
+			dtype = layer.weight.dtype
+		return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
 
 # Uses parametrization to inject LoRA weights
 # Pros: should work with any Linears
@@ -102,8 +106,8 @@ class ParameterizedLinear(nn.Module):
 		self.alpha = alpha
 		self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
 
-		self.lora_A = nn.Parameter( torch.zeros( (rank, in_features) ) ).to( device=device, dtype=dtype )
 		self.lora_B = nn.Parameter( torch.zeros( (out_features, rank) ) ).to( device=device, dtype=dtype )
+		self.lora_A = nn.Parameter( torch.zeros( (rank, in_features) ) ).to( device=device, dtype=dtype )
 		self.scaling = self.alpha / self.rank
 		self.enabled = True
 
@@ -120,46 +124,54 @@ class ParameterizedLinear(nn.Module):
 		return x
 
 	@classmethod
-	def from_linear( cls, layer, **kwargs ):
+	def from_linear( cls, layer, device = None, dtype = None, **kwargs ):
+		if device is None:
+			device = layer.weight.device
+		if dtype is None:
+			dtype = layer.weight.dtype
 		# swap because we're feeding the output as our input
-		return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs )
+		# M$'s LoRA class arranges things to where this isn't necessary
+		return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
 
 
-def parametrize_model( layer, register = True, merge = False, **kwargs ):
-	if not isinstance( layer, nn.Linear ):
-		return
+def passes_policy( policy, name ):
+	if policy is None:
+		return True
 
-	if register:
-		parametrize.register_parametrization( layer, "weight", ParameterizedLinear.from_linear( layer, **kwargs ) )
-	else:
-		parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
+	if "exclude" in policy:
+		for term in policy["exclude"]:
+			if term in name:
+				return False
 
-def apply_lora( model, **kwargs ):
+	if "include" in policy:
+		for term in policy["include"]:
+			if term in name:
+				return True
+
+	return False
+
+
+def apply_lora( model, register = True, merge = False, policy = None, **kwargs ):
 	device = next(model.parameters()).device
 	dtype = next(model.parameters()).dtype
 
-	if USE_PARAMETRIZATION:
-		model.apply( partial( parametrize_model, device=device, dtype=dtype, **kwargs ) )
-	else:
-		klass = Linear
-		target = nn.Linear
+	klass = Linear
+	target = nn.Linear
 
-		device = next(model.parameters()).device
-		dtype = next(model.parameters()).dtype
-		modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
+	device = next(model.parameters()).device
+	dtype = next(model.parameters()).dtype
+	modules = [ k.split('.') for k, m in model.named_modules() if isinstance(m, target) and not isinstance(m, klass) and passes_policy( policy, k ) ]
 
-		for *parent, k in modules:
-			name = '.'.join(parent)
+	for *parent, k in modules:
+		name = '.'.join(parent)
+		layer = getattr( model.get_submodule(name), k )
 
-			layer = getattr( model.get_submodule(name), k )
-
-			if isinstance(layer, klass):
-				continue
-
-			injected = klass( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
-			injected.weight = layer.weight
-
-			# overwrite
-			setattr( model.get_submodule(name), k, injected )
+		if USE_PARAMETRIZATION:
+			parametrize.register_parametrization( layer, "weight", ParameterizedLinear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
+			# parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
+		else:
+			injected = Linear.from_linear( layer, device=device, dtype=dtype, **kwargs )
+			injected.weight = layer.weight # keep the pretrained weight rather than the fresh initialization from the constructor
+			setattr( model.get_submodule(name), k, injected )
 
 	return model
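
For context, a small usage sketch of the policy matching this patch introduces. Only passes_policy, apply_lora, and the include/exclude dict shape returned by Model.lora_policy come from the patch itself; the dotted module names are illustrative stand-ins for a llama-style backbone, not output from the repository.

# Usage sketch (not part of the patch): which Linears the llama lora_policy selects.
# The dotted module names below are made-up examples for illustration only.
from vall_e.models.lora import passes_policy

policy = dict( include=["self_attn", "mlp"], exclude=["self_attn.k_proj"] )

for name in [
	"model.layers.0.self_attn.q_proj",  # True:  contains "self_attn"
	"model.layers.0.self_attn.k_proj",  # False: vetoed by "self_attn.k_proj"
	"model.layers.0.mlp.gate_proj",     # True:  contains "mlp"
	"model.embed_tokens",               # False: matches no include term
]:
	print( name, passes_policy( policy, name ) )

# apply_lora() runs the same check against every nn.Linear it finds, e.g.:
#	model = apply_lora( model, rank = 8, alpha = 16, policy = policy )
# while policy = None keeps the old behavior of adapting every Linear.

Note that passes_policy falls through to False, so once an include list is given, any Linear whose name matches none of its terms is left untouched, and an exclude match always wins because it is checked first.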