enable LoRA for targeted RVQ levels (to experiment with; seems to help)

mrq 2024-06-17 21:45:03 -05:00
parent 7047fcc6e2
commit 7cfb78fa64
7 changed files with 68 additions and 10 deletions

View File

@@ -104,6 +104,27 @@ You can enter `save` to save the state at any time, or `quit` to save and quit t
The `lr` command will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.
### Finetuning
Finetuning can be done by training the full model, or by using a LoRA.
Finetuning the full model is done the same way as training a model; just be sure the weights are in the same place they would be loaded from for inference.
For training a LoRA, add the following block to your `config.yaml`:
```
loras:
- name : "arbitrary name" # whatever you want
  rank: 128 # dimensionality of the LoRA
  alpha: 256 # scaling factor of the LoRA
  training: True
```
And that's it. Training the LoRA is done with the same command. Depending on the rank and alpha specified, the loss may start higher than it otherwise would, as the LoRA weights are initialized to appropriately random values.
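As a rough worked example of how those two numbers interact, here is a minimal sketch (not part of the repo) assuming the common LoRA convention where the low-rank update is scaled by `alpha / rank`:
```
rank, alpha = 128, 256   # the values from the config block above
scaling = alpha / rank   # = 2.0; a larger alpha amplifies whatever the LoRA branch contributes
print(f"LoRA update scaled by {scaling}x")
```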
To export your LoRA weights, run `python3 -m vall_e.export --lora --yaml="./training/config.yaml"`.
### Plotting Metrics
Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 -m vall_e.plot --yaml="./training/config.yaml"`

View File

@@ -211,6 +211,7 @@ class Model:
capabilities: list = field(default_factory=lambda: ["ar", "nar"])
experimental: str | None = None # for now it sets things to be HF compatible
kv_heads: int = 0 # MHA or GQA (for supported backends)
rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range
def get(self, name=None):
return [ self ] if not name or self.name == name else []
@@ -332,12 +333,17 @@ class LoRA:
rank: int = 8 # rank for the LoRA
alpha: int = 16 # scaling factor for the LoRA
training: bool = True # whether this LoRA is to be trained
rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA
@property
def full_name(self):
name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
return "-".join(name)
def active_level( self, level ):
if not self.rvq_levels:
return True
return level in self.rvq_levels
@dataclass()
class Hyperparameters:
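The new `rvq_levels` field is not documented in the README hunk above. As a hedged sketch (the name and level list here are illustrative), it sits alongside the other LoRA keys in `config.yaml`, with an empty list meaning the LoRA stays active for every level; the model-level `rvq_level_range` is set analogously on the model definition:
```
loras:
- name : "nar-lora"       # illustrative name
  rank: 128
  alpha: 256
  training: True
  rvq_levels: [1, 2, 3, 4, 5, 6, 7] # only activate the LoRA for these RVQ levels; [] = always active
```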

View File

@@ -131,6 +131,10 @@ def load_engines(training=True):
if "stats" in state:
stats = state["stats"]
# do not load stats if we're training a LoRA
if "lora" not in state:
stats = None
if "module" in state:
state = state["module"]

View File

@@ -344,7 +344,7 @@ class Engines(dict[str, Engine]):
state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
torch.save(state_dict, save_path)
print(f"Exported {name} to {outpath}")
print(f"Exported {name} to {save_path}")
def save_checkpoint(self, tag=None):
if not tag:

View File

@@ -63,9 +63,6 @@ class Engine(DeepSpeedEngine):
self.max_nan_losses = 8
def freeze(self, freeze_all=True):
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
# freeze non-LoRA params if requested
if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
for name, param in self.module.named_parameters():
@@ -75,6 +72,9 @@ class Engine(DeepSpeedEngine):
self._frozen_params.add(param)
return
if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")
for name, param in self.module.named_parameters():
if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
param.requires_grad_(False)
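The reordering above matters because a LoRA run usually has no `frozen_params` list configured; the LoRA branch now returns before the check that would otherwise raise. A minimal, self-contained sketch of what that branch effectively does (the helper name is hypothetical, not the repo's):
```
import torch.nn as nn

def freeze_all_but_lora(module: nn.Module) -> set:
    # freeze every parameter whose name does not mark it as a LoRA weight,
    # mirroring the "freeze non-LoRA params" branch above
    frozen = set()
    for name, param in module.named_parameters():
        if "lora_" in name:
            continue  # LoRA weights stay trainable
        param.requires_grad_(False)
        frozen.add(param)
    return frozen
```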

View File

@@ -20,6 +20,8 @@ from tqdm import trange
from ..emb.qnt import trim
from .lora import enable_lora
class AR_NAR(Base):
@property
def capabilities(self) -> list[str]:
@@ -27,6 +29,12 @@ class AR_NAR(Base):
return self.config.capabilities
return cfg.model.capabilities
@property
def quant_level_range(self) -> list[int]:
if hasattr(self, "config") and self.config.rvq_level_range:
return self.config.rvq_level_range
return [ 0 if self.causal else 1, self.n_resp_levels ]
@property
def causal(self):
return "ar" in self.capabilities
@@ -153,7 +161,7 @@ class AR_NAR(Base):
task_list = [ sample_task() for _ in range(batch_size) ]
# determines which RVQ level to target per batch
quant_level_range = [ 0 if self.causal else 1, self.n_resp_levels ]
quant_level_range = self.quant_level_range
if cfg.experimental:
# makes higher levels less likely
@@ -212,6 +220,9 @@ class AR_NAR(Base):
if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
break
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( level ) )
quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)
inputs = self.inputs(
@@ -246,7 +257,7 @@ class AR_NAR(Base):
# filter
"""
if self.arch_type in ["mamba2-hf"]:
if self.arch_type in ["mamba2-hf"] or cfg.lora is not None:
for batch_index, resp in enumerate(resps_list):
for i, token in enumerate(resp):
if token >= 1024:
@@ -255,9 +266,15 @@ class AR_NAR(Base):
prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]
if cfg.lora is not None:
enable_lora( self )
return prev_list
# is AR
if cfg.lora is not None:
enable_lora( self, cfg.lora.active_level( 0 ) )
sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
stopped = torch.zeros(batch_size, device=device).bool()
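For reference, a small standalone sketch of how the `quant_level_range` property above resolves (the function name is illustrative, not the repo's):
```
def resolve_quant_level_range(rvq_level_range, causal, n_resp_levels):
    # an explicit rvq_level_range in the model config wins; otherwise the AR
    # starts at level 0 and a NAR-only model starts at level 1
    if rvq_level_range:
        return rvq_level_range
    return [0 if causal else 1, n_resp_levels]

print(resolve_quant_level_range([], causal=True, n_resp_levels=8))      # [0, 8]
print(resolve_quant_level_range([1, 8], causal=True, n_resp_levels=8))  # [1, 8]
```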

View File

@@ -10,7 +10,7 @@ import math
from typing import Optional, List
# to-do: set cfg to decide
USE_PARAMETRIZATION = True
USE_PARAMETRIZATION = False
# LoRA Linear for replacement
# Pros: simple, just needs to reuse the replace_linear and copy weights
@@ -27,7 +27,7 @@ class Linear(nn.Linear):
alpha: int = 1,
dropout: float = 0.1,
merge_weights: bool = True,
merge_weights: bool = False,
**kwargs,
):
super().__init__(in_features=in_features, out_features=out_features, bias=bias, **kwargs)
@@ -37,6 +37,7 @@ class Linear(nn.Linear):
self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
self.merge_weights = merge_weights
self.merged = False
self.enabled = True
self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) )
self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
@@ -67,7 +68,7 @@ class Linear(nn.Linear):
self.merged = True
def forward(self, x: torch.Tensor):
if not self.merged:
if not self.merged and self.enabled:
result = F.linear(x, self.weight, bias=self.bias)
result += (self.dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
return result
@@ -120,7 +121,6 @@ class ParameterizedLinear(nn.Module):
def forward(self, x: torch.Tensor):
if self.enabled:
return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
return x
@classmethod
@@ -173,6 +173,16 @@ def apply_lora( model, register = True, merge = False, policy = None, **kwargs )
return model
def enable_lora( model, mode = True ):
for name, module in model.named_modules():
if not isinstance( module, ParameterizedLinear if USE_PARAMETRIZATION else Linear ):
continue
module.enabled = mode
return model
def disable_lora( model ):
return enable_lora( model, False )
def freeze_non_lora_weights( model ):
for name, param in model.named_parameters():
param.requires_grad_('lora_' in name)
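Putting the helpers in this file together, here is a hedged usage sketch on a toy module (the import path follows the relative `from .lora import enable_lora` in ar_nar.py, and the default `apply_lora` arguments are assumed to be enough to wrap plain `nn.Linear` layers):
```
import torch.nn as nn
from vall_e.models.lora import apply_lora, enable_lora, disable_lora, freeze_non_lora_weights

# a toy stand-in for the real model
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

model = apply_lora(model)       # replace/parametrize the Linear layers with LoRA-wrapped ones
freeze_non_lora_weights(model)  # only parameters with 'lora_' in their name stay trainable

disable_lora(model)             # bypass the LoRA path, e.g. for an RVQ level it does not target
enable_lora(model)              # turn it back on
```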