diff --git a/vall_e/models/lora.py b/vall_e/models/lora.py
index 6e288b4..d63a9a9 100644
--- a/vall_e/models/lora.py
+++ b/vall_e/models/lora.py
@@ -9,6 +9,12 @@ from torch import Tensor, nn
 import math
 from typing import Optional, List
 
+# to-do: set cfg to decide
+USE_PARAMETRIZATION = False
+
+# LoRA Linear for replacement
+# Pros: simple, just needs to reuse the replace_linear and copy weights
+# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc)
 class Linear(nn.Linear):
 	def __init__(
 		self,
@@ -72,37 +78,88 @@ class Linear(nn.Linear):
 	def from_linear( cls, layer, **kwargs ):
 		return cls( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs )
 
-# broken, the in_features / out_features change somehow
-def parameterize_model( layer, register = True, merge = False, **kwargs ):
+# Uses parametrization to inject LoRA weights
+# Pros: should work with any Linears
+# Cons: TBD
+class ParameterizedLinear(nn.Module):
+	def __init__(
+		self,
+
+		in_features: int,
+		out_features: int,
+		bias: bool = True,
+
+		rank: int = 4,
+		alpha: int = 1,
+
+		dropout: float = 0.1,
+
+		device = None,
+		dtype = None
+	):
+		super().__init__()
+		self.rank = rank
+		self.alpha = alpha
+		self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
+
+		self.lora_A = nn.Parameter( torch.zeros( (rank, in_features) ) ).to( device=device, dtype=dtype )
+		self.lora_B = nn.Parameter( torch.zeros( (out_features, rank) ) ).to( device=device, dtype=dtype )
+		self.scaling = self.alpha / self.rank
+		self.enabled = True
+
+		self.reset_parameters()
+
+	def reset_parameters(self):
+		nn.init.kaiming_uniform_( self.lora_A, a=math.sqrt(5) )
+		nn.init.zeros_( self.lora_B )
+
+	def forward(self, x: torch.Tensor):
+		if self.enabled:
+			return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
+
+		return x
+
+	@classmethod
+	def from_linear( cls, layer, **kwargs ):
+		# swap because we're feeding the output as our input
+		return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs )
+
+def parametrize_model( layer, register = True, merge = False, **kwargs ):
 	if not isinstance( layer, nn.Linear ):
 		return
 
 	if register:
-		parametrize.register_parametrization( layer, "weight", Linear.from_linear( layer, **kwargs ) )
+		parametrize.register_parametrization( layer, "weight", ParameterizedLinear.from_linear( layer, **kwargs ) )
 	else:
 		parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
 
 def apply_lora( model, **kwargs ):
-	klass = Linear
-	target = nn.Linear
-
 	device = next(model.parameters()).device
 	dtype = next(model.parameters()).dtype
-	modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
 
-	for *parent, k in modules:
-		name = '.'.join(parent)
+	if USE_PARAMETRIZATION:
+		model.apply( partial( parametrize_model, device=device, dtype=dtype, **kwargs ) )
+	else:
+		klass = Linear
+		target = nn.Linear
 
-		layer = getattr( model.get_submodule(name), k )
+		device = next(model.parameters()).device
+		dtype = next(model.parameters()).dtype
+		modules = [k.split('.') for k, m in model.named_modules() if isinstance(m, target)]
 
-		if isinstance(layer, klass):
-			continue
+		for *parent, k in modules:
+			name = '.'.join(parent)
 
-		injected = klass( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
-		injected.weight = layer.weight
+			layer = getattr( model.get_submodule(name), k )
 
-		# overwrite
-		setattr( model.get_submodule(name), k, injected )
+			if isinstance(layer, klass):
+				continue
+
+			injected = klass( in_features = layer.in_features, out_features = layer.out_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
+			injected.weight = layer.weight
+
+			# overwrite
+			setattr( model.get_submodule(name), k, injected )
 
 	return model
 
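A minimal usage sketch of the two injection paths this patch introduces, not part of the diff itself: it assumes the patched file is importable as `vall_e.models.lora`, and the toy `nn.Sequential` model is a hypothetical stand-in for the real network. Only names that appear in the patch (`USE_PARAMETRIZATION`, `apply_lora`, `Linear`, `ParameterizedLinear`) are referenced.

```python
import torch
from torch import nn

from vall_e.models import lora  # the module patched above

# hypothetical toy model containing nn.Linear layers to inject into
model = nn.Sequential( nn.Linear( 16, 32 ), nn.ReLU(), nn.Linear( 32, 16 ) )

# replacement path (USE_PARAMETRIZATION = False): each nn.Linear is swapped for a
# lora.Linear that reuses the original weight and adds the low-rank delta
lora.USE_PARAMETRIZATION = False
model = lora.apply_lora( model )

# parametrization path (USE_PARAMETRIZATION = True): the original nn.Linear stays in
# place and ParameterizedLinear is registered as a parametrization on its "weight"
# lora.USE_PARAMETRIZATION = True
# model = lora.apply_lora( model )

x = torch.randn( 2, 16 )
y = model( x )
```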