enable LoRA for targeted RVQ levels (to experiment with, seems to help)
parent 7047fcc6e2
commit 7cfb78fa64

README.md | 21

@@ -104,6 +104,27 @@ You can enter `save` to save the state at any time, or `quit` to save and quit t

The `lr` command will also let you adjust the learning rate on the fly. For example: `lr 1.0e-3` will set the learning rate to `0.001`.

### Finetuning

Finetuning can be done by training the full model, or by using a LoRA.

Finetuning the full model is done the same way as training a model, but be sure to have the weights in the correct spot, as if you're loading them for inferencing.

For training a LoRA, add the following block to your `config.yaml`:

```
loras:
- name : "arbitrary name" # whatever you want
  rank: 128 # dimensionality of the LoRA
  alpha: 256 # scaling factor of the LoRA
  training: True
```

And that's it. Training of the LoRA is done with the same command. Depending on the rank and alpha specified, the loss may be higher than it should be, as the LoRA weights are initialized to appropriately random values.
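
A LoRA can also be limited to specific RVQ levels through its `rvq_levels` field (an empty list keeps it active at every level). Below is a minimal sketch of the same block with that field set — the level values are only an example, and assume the key maps straight onto the `LoRA` config entry:

```
loras:
- name : "arbitrary name"
  rank: 128
  alpha: 256
  training: True
  rvq_levels: [0, 1] # example: only apply the LoRA for RVQ levels 0 and 1; leave empty to apply it at every level
```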

To export your LoRA weights, run `python3 -m vall_e.export --lora --yaml="./training/config.yaml"`.

### Plotting Metrics

Included is a helper script to parse the training metrics. Simply invoke it with, for example: `python3 -m vall_e.plot --yaml="./training/config.yaml"`

@@ -211,6 +211,7 @@ class Model:
	capabilities: list = field(default_factory=lambda: ["ar", "nar"])
	experimental: str | None = None # for now it sets things to be HF compatible
	kv_heads: int = 0 # MHA or GQA (for supported backends)
	rvq_level_range: list = field(default_factory=lambda: []) # some cringe to try and limit the RVQ training range

	def get(self, name=None):
		return [ self ] if not name or self.name == name else []

@@ -332,12 +333,17 @@ class LoRA:
	rank: int = 8 # rank for the LoRA
	alpha: int = 16 # alpha (scaling factor) for the LoRA
	training: bool = True # whether this LoRA is being trained
	rvq_levels: list[int] = field(default_factory=lambda: []) # determines RVQ levels to activate the LoRA

	@property
	def full_name(self):
		name = [ self.name, f"r{self.rank}", f"a{self.alpha}" ]
		return "-".join(name)

	def active_level( self, level ):
		if not self.rvq_levels:
			return True
		return level in self.rvq_levels

@dataclass()
class Hyperparameters:
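
For illustration, a minimal sketch of how the new `rvq_levels` field behaves through `active_level`, and how `full_name` is composed — the values and the `vall_e.config` import path are assumptions, not part of the diff:

```
from vall_e.config import LoRA  # assumed import path for the dataclass above

# a LoRA restricted to the first two RVQ levels
lora = LoRA(name="test", rank=128, alpha=256, rvq_levels=[0, 1])
print(lora.full_name)        # test-r128-a256
print(lora.active_level(0))  # True  -> the LoRA is enabled for this level
print(lora.active_level(2))  # False -> the LoRA is bypassed for this level

# an empty rvq_levels list keeps the LoRA active at every level
print(LoRA(name="test", rank=128, alpha=256).active_level(7))  # True
```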

@@ -131,6 +131,10 @@ def load_engines(training=True):
		if "stats" in state:
			stats = state["stats"]

		# do not load stats if we're training a LoRA
		if "lora" not in state:
			stats = None

		if "module" in state:
			state = state["module"]

@@ -344,7 +344,7 @@ class Engines(dict[str, Engine]):
			state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )

			torch.save(state_dict, save_path)
			print(f"Exported {name} to {outpath}")
			print(f"Exported {name} to {save_path}")

	def save_checkpoint(self, tag=None):
		if not tag:

@@ -63,9 +63,6 @@ class Engine(DeepSpeedEngine):
		self.max_nan_losses = 8

	def freeze(self, freeze_all=True):
		if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
			raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")

		# freeze non-LoRA params if requested
		if not self.hyper_config.frozen_params and not freeze_all and cfg.lora is not None:
			for name, param in self.module.named_parameters():

@@ -75,6 +72,9 @@ class Engine(DeepSpeedEngine):
					self._frozen_params.add(param)
			return

		if self.hyper_config is None or not hasattr(self.hyper_config, "frozen_params"):
			raise Exception("freeze_all=False yet self.hyper_config.frozen_params is None")

		for name, param in self.module.named_parameters():
			if (freeze_all and param.requires_grad) or (not freeze_all and name in self.hyper_config.frozen_params):
				param.requires_grad_(False)

@@ -20,6 +20,8 @@ from tqdm import trange

from ..emb.qnt import trim

from .lora import enable_lora

class AR_NAR(Base):
	@property
	def capabilities(self) -> list[str]:

@@ -27,6 +29,12 @@ class AR_NAR(Base):
			return self.config.capabilities
		return cfg.model.capabilities

	@property
	def quant_level_range(self) -> list[int]:
		if hasattr(self, "config") and self.config.rvq_level_range:
			return self.config.rvq_level_range
		return [ 0 if self.causal else 1, self.n_resp_levels ]

	@property
	def causal(self):
		return "ar" in self.capabilities

@@ -153,7 +161,7 @@
		task_list = [ sample_task() for _ in range(batch_size) ]

		# determines which RVQ level to target per batch
		quant_level_range = [ 0 if self.causal else 1, self.n_resp_levels ]
		quant_level_range = self.quant_level_range

		if cfg.experimental:
			# makes higher levels less likely

@@ -212,6 +220,9 @@
			if level >= max_levels + 1: # min(max_levels + 1, self.n_resp_levels): # commented out to experiment with exceeding trained levels
				break

			if cfg.lora is not None:
				enable_lora( self, cfg.lora.active_level( level ) )

			quant_levels = [ level for _ in range(batch_size) ] # torch.full((len(text_list),), level)

			inputs = self.inputs(

@@ -246,7 +257,7 @@

			# filter
			"""
			if self.arch_type in ["mamba2-hf"]:
			if self.arch_type in ["mamba2-hf"] or cfg.lora is not None:
				for batch_index, resp in enumerate(resps_list):
					for i, token in enumerate(resp):
						if token >= 1024:

@@ -255,9 +266,15 @@

			prev_list = [ torch.cat([rs, r.unsqueeze(-1).to(device)], dim=-1) for rs, r in zip(prev_list, resps_list) ]

		if cfg.lora is not None:
			enable_lora( self )

		return prev_list

		# is AR
		if cfg.lora is not None:
			enable_lora( self, cfg.lora.active_level( 0 ) )

		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
		stopped = torch.zeros(batch_size, device=device).bool()

@@ -10,7 +10,7 @@ import math
from typing import Optional, List

# to-do: set cfg to decide
USE_PARAMETRIZATION = True
USE_PARAMETRIZATION = False

# LoRA Linear for replacement
# Pros: simple, just needs to reuse the replace_linear and copy weights

@@ -27,7 +27,7 @@ class Linear(nn.Linear):
		alpha: int = 1,

		dropout: float = 0.1,
		merge_weights: bool = True,
		merge_weights: bool = False,
		**kwargs,
	):
		super().__init__(in_features=in_features, out_features=out_features, bias=bias, **kwargs)

@@ -37,6 +37,7 @@ class Linear(nn.Linear):
		self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
		self.merge_weights = merge_weights
		self.merged = False
		self.enabled = True

		self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) )
		self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )

@@ -67,7 +68,7 @@ class Linear(nn.Linear):
			self.merged = True

	def forward(self, x: torch.Tensor):
		if not self.merged:
		if not self.merged and self.enabled:
			result = F.linear(x, self.weight, bias=self.bias)
			result += (self.dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
			return result

@@ -120,7 +121,6 @@ class ParameterizedLinear(nn.Module):
	def forward(self, x: torch.Tensor):
		if self.enabled:
			return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling

		return x

	@classmethod

@@ -173,6 +173,16 @@ def apply_lora( model, register = True, merge = False, policy = None, **kwargs )

	return model

def enable_lora( model, mode = True ):
	for name, module in model.named_modules():
		if not isinstance( module, ParameterizedLinear if USE_PARAMETRIZATION else Linear ):
			continue
		module.enabled = mode
	return model

def disable_lora( model ):
	return enable_lora( model, False )

def freeze_non_lora_weights( model ):
	for name, param in model.named_parameters():
		param.requires_grad_('lora_' in name)
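
As a rough usage sketch (not part of the diff — the helper name, the `run_level` callback, and the relative import are assumptions), the pieces above can be wired together the way `AR_NAR` uses them, enabling the LoRA only for the RVQ levels it targets:

```
from .lora import apply_lora, enable_lora, disable_lora, freeze_non_lora_weights  # assumes this sits alongside lora.py

def finetune_lora_per_level( model, lora_cfg, n_resp_levels, run_level ):
	"""Hypothetical helper: inject a LoRA, freeze the base weights, and only keep the
	LoRA enabled for the RVQ levels it targets (mirrors the AR_NAR changes above)."""
	model = apply_lora( model )       # swap out nn.Linear layers for LoRA-wrapped ones
	freeze_non_lora_weights( model )  # only lora_A / lora_B keep requires_grad=True
	for level in range( n_resp_levels ):
		enable_lora( model, lora_cfg.active_level( level ) )  # LoRA active only for its targeted levels
		run_level( model, level )     # caller-supplied forward/step for this RVQ level
	return disable_lora( model )      # leave the plain base weights active afterwards
```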