diff --git a/README.md b/README.md index d0bc303..ba705ad 100755 --- a/README.md +++ b/README.md @@ -104,6 +104,27 @@ You can enter `save` to save the state at any time, or `quit` to save and quit t The `lr` will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`. + +### Finetuning + +Finetuning can be done by training the full model, or using a LoRA. + +Finetuning the full model is done the same way as training a model, but be sure to have the weights in the correct spot, as if you're loading them for inference. + +For training a LoRA, add the following block to your `config.yaml`: + +``` +loras: +- name : "arbitrary name" # whatever you want + rank: 128 # dimensionality of the LoRA + alpha: 256 # scaling factor of the LoRA + training: True +``` + +And that's it. Training of the LoRA is done with the same command. Depending on the rank and alpha specified, the loss may be higher than it should be, as the LoRA weights are initialized to appropriately random values. + +To export your LoRA weights, run `python3 -m vall_e.export --lora --yaml="./training/config.yaml"`. + ### Plotting Metrics Included is a helper script to parse the training metrics. 
Simply invoke it with, for example: `python3 -m vall_e.plot --yaml="./training/config.yaml"` diff --git a/vall_e/config.py b/vall_e/config.py index 0fa98d2..daafd0b 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -211,6 +211,7 @@ class Model: capabilities: list = field(default_factory=lambda: ["ar", "nar"]) experimental: str | None = None # for now it sets things to be HF compatible kv_heads: int = 0 # MHA or GQA (for supported backends) + rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range def get(self, name=None): return [ self ] if not name or self.name == name else [] @@ -332,12 +333,17 @@ class LoRA: rank: int = 8 # rank for the LoRA alpha: int = 16 # rank for the LoRA training: bool = True # + rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA @property def full_name(self): name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ] return "-".join(name) + def active_level( self, level ): + if not self.rvq_levels: + return True + return level in self.rvq_levels @dataclass() class Hyperparameters: diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index aef8b92..7cc973d 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -131,6 +131,10 @@ def load_engines(training=True): if "stats" in state: stats = state["stats"] + # do not load stats if we're training a LoRA + if "lora" not in state: + stats = None + if "module" in state: state = state["module"] diff --git a/vall_e/engines/base.py b/vall_e/engines/base.py index 47e658a..7ec6c8c 100755 --- a/vall_e/engines/base.py +++ b/vall_e/engines/base.py @@ -344,7 +344,7 @@ class Engines(dict[str, Engine]): state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path ) torch.save(state_dict, save_path) - print(f"Exported {name} to {outpath}") + print(f"Exported {name} to {save_path}") def save_checkpoint(self, tag=None): if not tag: diff 
--git a/vall_e/engines/deepspeed.py b/vall_e/engines/deepspeed.py index 8196a8f..e6d3640 100755 --- a/vall_e/engines/deepspeed.py +++ b/vall_e/engines/deepspeed.py @@ -63,9 +63,6 @@ class Engine(DeepSpeedEngine): self.max_nan_losses = 8 def freeze(self, freeze_all=True): - if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"): - raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None") - # freeze non-LoRA params if requested if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None: for name, param in self.module.named_parameters(): @@ -75,6 +72,9 @@ class Engine(DeepSpeedEngine): self._frozen_params.add(param) return + if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"): + raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None") + for name, param in self.module.named_parameters(): if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params): param.requires_grad_(False) diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 81781f5..0e324d4 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -20,6 +20,8 @@ from tqdm import trange from ..emb.qnt import trim +from .lora import enable_lora + class AR_NAR(Base): @property def capabilities(self) -> list[str]: @@ -27,6 +29,12 @@ class AR_NAR(Base): return self.config.capabilities return cfg.model.capabilities + @property + def quant_level_range(self) -> list[int]: + if hasattr(self, "config") and self.config.rvq_level_range: + return self.config.rvq_level_range + return [ 0 if self.causal else 1, self.n_resp_levels ] + @property def causal(self): return "ar" in self.capabilities @@ -153,7 +161,7 @@ class AR_NAR(Base): task_list = [ sample_task() for _ in range(batch_size) ] # determines which RVQ level to target per batch - quant_level_range = [ 0 if self.causal else 1, self.n_resp_levels ] + quant_level_range = 
self.quant_level_range if cfg.experimental: # makes higher levels less likely @@ -212,6 +220,9 @@ class AR_NAR(Base): if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels break + if cfg.lora is not None: + enable_lora( self, cfg.lora.active_level( level ) ) + quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level) inputs = self.inputs( @@ -246,7 +257,7 @@ class AR_NAR(Base): # filter """ - if self.arch_type in ["mamba2-hf"]: + if self.arch_type in ["mamba2-hf"] or cfg.lora is not None: for batch_index, resp in enumerate(resps_list): for i, token in enumerate(resp): if token >= 1024: @@ -255,9 +266,15 @@ class AR_NAR(Base): prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ] + if cfg.lora is not None: + enable_lora( self ) + return prev_list # is AR + if cfg.lora is not None: + enable_lora( self, cfg.lora.active_level( 0 ) ) + sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ] stopped = torch.zeros(batch_size, device=device).bool() diff --git a/vall_e/models/lora.py b/vall_e/models/lora.py index a6f7a67..338af2e 100644 --- a/vall_e/models/lora.py +++ b/vall_e/models/lora.py @@ -10,7 +10,7 @@ import math from typing import Optional, List # to-do: set cfg to decide -USE_PARAMETRIZATION = True +USE_PARAMETRIZATION = False # LoRA Linear for replacement # Pros: simple, just needs to reuse the replace_linear and copy weights @@ -27,7 +27,7 @@ class Linear(nn.Linear): alpha: int = 1, dropout: float = 0.1, - merge_weights: bool = True, + merge_weights: bool = False, **kwargs, ): super().__init__(in_features=in_features, out_features=out_features, bias=bias, **kwargs) @@ -37,6 +37,7 @@ class Linear(nn.Linear): self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x self.merge_weights = merge_weights self.merged = False + self.enabled = True self.lora_B = 
nn.Parameter( self.weight.new_zeros( (out_features, rank) ) ) self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) ) @@ -67,7 +68,7 @@ class Linear(nn.Linear): self.merged = True def forward(self, x: torch.Tensor): - if not self.merged: + if not self.merged and self.enabled: result = F.linear(x, self.weight, bias=self.bias) result += (self.dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling return result @@ -120,7 +121,6 @@ class ParameterizedLinear(nn.Module): def forward(self, x: torch.Tensor): if self.enabled: return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling - return x @classmethod @@ -173,6 +173,16 @@ def apply_lora( model, register = True, merge = False, policy = None, **kwargs ) return model +def enable_lora( model, mode = True ): + for name, module in model.named_modules(): + if not isinstance( module, ParameterizedLinear if USE_PARAMETRIZATION else Linear ): + continue + module.enabled = mode + return model + +def disable_lora( model ): + return enable_lora( model, False ) + def freeze_non_lora_weights( model ): for name, param in model.named_parameters(): param.requires_grad_('lora_' in name)