training + LoRA training works? (keeps OOMing after a step)

This commit is contained in:
mrq 2024-06-18 13:28:50 -05:00
parent d7b63d2f70
commit 7aae9d48ab
7 changed files with 86 additions and 38 deletions

View File

@@ -175,6 +175,7 @@ class Dataset:
class Model:
name: str = "" # vanity name for the model
training: bool = False
frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training
def get(self, name=None):
return [ self ] if not name or self.name == name else []
@@ -190,7 +191,7 @@ class Model:
@property
def lora_policy(self):
include = ["model"] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever)
include = ["gpt"] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever)
exclude = []
return dict(include=include, exclude=exclude)
@@ -202,6 +203,8 @@ class LoRA:
rank: int = 8 # rank for the LoRA
alpha: int = 16 # alpha (scaling) for the LoRA
training: bool = True # whether this LoRA is being trained
parametrize: bool = False # inject via torch parametrization instead of replacing the module
module: str = "linear" # linear | conv1d
@property
def full_name(self):
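The rank/alpha pair follows the usual LoRA convention (the low-rank update is scaled by alpha / rank), and lora_policy is a pair of substring filters over module names. A minimal sketch of how such a policy could be evaluated, purely illustrative since the repo's own passes_policy appears in lora.py further down:

# Illustrative sketch: "exclude" terms veto a module, otherwise the module
# must contain one of the "include" terms to be adapted.
from typing import Optional

def matches_policy(policy: Optional[dict], name: str) -> bool:
    if policy is None:
        return True
    if any(term in name for term in policy.get("exclude", [])):
        return False
    return any(term in name for term in policy.get("include", []))

policy = dict(include=["gpt"], exclude=[])
assert matches_policy(policy, "gpt.h.0.attn.c_attn")   # adapted
assert not matches_policy(policy, "mel_embedding")     # left alone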

View File

@@ -497,9 +497,13 @@ class Dataset(_Dataset):
if key not in cfg.hdf5:
raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')
text = cfg.hdf5[key]["text"][:]
mel = cfg.hdf5[key]["audio"][:]
latents = cfg.hdf5[key]["latents"][:]
try:
text = cfg.hdf5[key]["text"][:]
mel = cfg.hdf5[key]["audio"][:]
latents = cfg.hdf5[key]["latents"][:]
except Exception as e:
print( key, cfg.hdf5[key].keys() )
raise e
text = torch.from_numpy(text).to(self.text_dtype)
mel = torch.from_numpy(mel).to(torch.int16)
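The try/except above only exists to report which key failed and what the group actually contains before re-raising; the same pattern in isolation with h5py (file layout and dataset names are assumptions):

# Minimal sketch, assuming per-key "text", "audio" and "latents" datasets.
import h5py

def read_entry(hdf5_path: str, key: str):
    with h5py.File(hdf5_path, "r") as f:
        if key not in f:
            raise RuntimeError(f"Key not in HDF5: {key}")
        group = f[key]
        try:
            text = group["text"][:]
            mel = group["audio"][:]
            latents = group["latents"][:]
        except Exception:
            # show what the group really holds before failing
            print(key, list(group.keys()))
            raise
    return text, mel, latents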

View File

@@ -56,9 +56,11 @@ def load_engines(training=True):
model.model = ml.replace_embedding( model.model )
for lora in cfg.loras:
model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy )
if hasattr(model, "gpt"):
#model.gpt = apply_lora( model.gpt, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, parametrize = lora.parametrize )
model = apply_lora( model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )
if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
optimizer_class = None
scheduler_class = None
@@ -124,7 +126,7 @@ def load_engines(training=True):
loads_state_dict = True
stats = None
if loads_state_dict:
if loads_state_dict and load_path.exists():
state = torch.load(load_path, map_location=torch.device(cfg.device))
# state dict is not just the module, extract the extra trainer details
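Both guards are defensive: the torch optimizer/scheduler wiring is skipped when only inferencing, and the checkpoint is only loaded when it actually exists. A standalone sketch of the load guard (the path and device are placeholders):

# Minimal sketch, assuming a pathlib.Path checkpoint and a device string.
from pathlib import Path
import torch

def maybe_load_state(load_path: Path, device: str = "cpu"):
    if not load_path.exists():
        # nothing to resume from; keep the freshly initialized weights
        return None
    # map_location places the loaded tensors on the requested device
    return torch.load(load_path, map_location=torch.device(device))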

View File

@@ -59,8 +59,8 @@ def get_model(config, training=True):
name = config.name
model = load_model(config.name)
config.training = False
config.training = "autoregressive" in config.name
model.config = config
print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")

View File

@@ -4,18 +4,17 @@ import torch
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from transformers.pytorch_utils import Conv1D
from torch import Tensor, nn
import math
from typing import Optional, List
# to-do: set cfg to decide
USE_PARAMETRIZATION = True
# LoRA Linear for replacement
# Pros: simple, just needs to reuse the replace_linear and copy weights
# Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
class Linear(nn.Linear):
class LoRALinear(nn.Linear):
def __init__(
self,
@@ -27,7 +26,7 @@ class Linear(nn.Linear):
alpha: int = 1,
dropout: float = 0.1,
merge_weights: bool = True,
merge_weights: bool = False,
**kwargs,
):
super().__init__(in_features=in_features, out_features=out_features, bias=bias, **kwargs)
@@ -37,6 +36,7 @@ class Linear(nn.Linear):
self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
self.merge_weights = merge_weights
self.merged = False
self.enabled = True
self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) )
self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
@@ -67,7 +67,7 @@ class Linear(nn.Linear):
self.merged = True
def forward(self, x: torch.Tensor):
if not self.merged:
if not self.merged and self.enabled:
result = F.linear(x, self.weight, bias=self.bias)
result += (self.dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
return result
@@ -85,7 +85,7 @@ class Linear(nn.Linear):
# Uses parametrization to inject LoRA weights
# Pros: should work with any Linears
# Cons: TBD
class ParameterizedLinear(nn.Module):
class ParameterizedLoRA(nn.Module):
def __init__(
self,
@@ -120,7 +120,6 @@ class ParameterizedLinear(nn.Module):
def forward(self, x: torch.Tensor):
if self.enabled:
return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling
return x
@classmethod
@@ -133,10 +132,21 @@ class ParameterizedLinear(nn.Module):
# M$'s LoRA class arranges things so that this isn't necessary
return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
@classmethod
def from_conv1d( cls, layer, device = None, dtype = None, **kwargs ):
if device is None:
device = layer.weight.device
if dtype is None:
dtype = layer.weight.dtype
in_channels, out_channels = layer.weight.shape
# swap because we're feeding the output as our input
# M$'s LoRA class arranges things so that this isn't necessary
return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)
def passes_policy( policy, name ):
if policy is None:
return True
if "exclude" in policy:
for term in policy["exclude"]:
if term in name:
@@ -149,30 +159,50 @@ def passes_policy( policy, name ):
return False
def apply_lora( model, register = True, merge = False, policy = None, **kwargs ):
def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ):
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
klass = Linear
target = nn.Linear
device = next(model.parameters()).device
dtype = next(model.parameters()).dtype
modules = [ k.split('.') for k, m in model.named_modules() if isinstance(m, target) and not isinstance(m, klass) and passes_policy( policy, k ) ]
modules = [ k.split('.') for k, m in model.named_modules() if passes_policy( policy, k ) ]
for *parent, k in modules:
name = '.'.join(parent)
layer = getattr( model.get_submodule(name), k )
if USE_PARAMETRIZATION:
parametrize.register_parametrization( layer, "weight", ParameterizedLinear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
# parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
if isinstance( layer, nn.Linear ):
target = nn.Linear
klass = ParameterizedLoRA if use_parametrize else LoRALinear
replacer = klass.from_linear
elif isinstance( layer, nn.Conv1d ):
target = nn.Conv1d
klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
replacer = klass.from_conv1d
elif isinstance( layer, Conv1D ):
target = Conv1D
klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
replacer = klass.from_conv1d
else:
setattr( model.get_submodule(name), k, Linear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
continue
replacement = replacer( layer, device=device, dtype=dtype, **kwargs )
if use_parametrize:
parametrize.register_parametrization( layer, "weight", replacement )
else:
setattr( model.get_submodule(name), k, replacement )
return model
def enable_lora( model, mode = True ):
for name, module in model.named_modules():
if not isinstance( module, ParameterizedLoRA ) and not isinstance( module, LoRALinear ):
continue
module.enabled = mode
return model
def disable_lora( model ):
return enable_lora( model, False )
def freeze_non_lora_weights( model ):
for name, param in model.named_parameters():
param.requires_grad_('lora_' in name)
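Taken together, a usage sketch on a toy module; the Toy class, sizes and policy terms are made up, and LoRALinear.from_linear is assumed to exist as it did on the old Linear class:

# Minimal sketch, assuming the helpers above are importable from this file.
import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
    def forward(self, x):
        return self.layers(x)

model = Toy()
policy = dict(include=["layers"], exclude=[])  # only adapt Linears under .layers

model = apply_lora(model, rank=8, alpha=16, policy=policy, use_parametrize=False)
freeze_non_lora_weights(model)  # only lora_A / lora_B keep requires_grad

print([n for n, p in model.named_parameters() if p.requires_grad])

disable_lora(model)               # adapters bypassed, e.g. for a baseline forward
model = enable_lora(model, True)  # switched back on for training
out = model(torch.randn(2, 16))   # forward now adds the scaled low-rank update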

View File

@@ -273,9 +273,14 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
n_embd=model_dim,
n_layer=layers,
n_head=heads,
gradient_checkpointing=checkpointing,
use_cache=not checkpointing)
gpt = GPT2Model(gpt_config)
if checkpointing:
gpt.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
use_reentrant=False
))
# Override the built-in positional embeddings
del gpt.wpe
gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
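The config-level gradient_checkpointing flag is dropped in favour of calling gradient_checkpointing_enable after construction; a standalone sketch with placeholder sizes, assuming a transformers release recent enough to accept gradient_checkpointing_kwargs:

# Minimal sketch; n_embd/n_layer/n_head are placeholders for model_dim/layers/heads.
from transformers import GPT2Config, GPT2Model

gpt_config = GPT2Config(
    n_embd=1024,
    n_layer=30,
    n_head=16,
    use_cache=False,  # caching and checkpointing don't mix during training
)
gpt = GPT2Model(gpt_config)
gpt.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs=dict(use_reentrant=False)
)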

View File

@ -29,12 +29,14 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")
def train_feeder(engine, batch):
with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
device = batch["text"][0].device
batch_size = len(batch["text"])
conditioning_latents = pad_sequence([ latents[0] for latents in batch["latents"] ], batch_first = True)
text_inputs = pad_sequence([ text for text in batch["text"] ], batch_first = True)
text_lengths = pad_sequence([ text.shape[0] for text in batch["text"] ], batch_first = True)
mel_codes = pad_sequence([ code for codes in batch["mel"] ], batch_first = True)
wav_lengths = pad_sequence([ length for length in batch["wav_length"] ], batch_first = True)
text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32)
mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True)
wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32)
engine.forward(conditioning_latents, text_inputs, text_lengths, mel_codes, wav_lengths)
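The collation change pads the variable-length items and keeps the lengths as plain int32 tensors rather than trying to pad Python scalars; the same idea in isolation (shapes and vocab sizes are placeholders):

# Minimal sketch of the padding/length handling; values are placeholders.
import torch
from torch.nn.utils.rnn import pad_sequence

texts = [torch.randint(0, 256, (n,)) for n in (12, 7, 19)]     # token id sequences
mels = [torch.randint(0, 8192, (1, t)) for t in (80, 55, 90)]  # [1, T] code sequences

text_inputs = pad_sequence(texts, batch_first=True)            # [B, max_text_len]
text_lengths = torch.tensor([t.shape[0] for t in texts], dtype=torch.int32)
mel_codes = pad_sequence([codes[0] for codes in mels], batch_first=True)  # [B, max_T]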
@@ -48,7 +50,7 @@ def train_feeder(engine, batch):
stats |= {k: v.item() for k, v in stat.items()}
engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ])
engine.tokens_processed += sum([ mel.shape[-1] for mel in batch["mel"] ])
return loss, stats
@@ -76,9 +78,11 @@ def run_eval(engines, eval_name, dl):
ref_path.parent.mkdir(parents=True, exist_ok=True)
prom_path.parent.mkdir(parents=True, exist_ok=True)
"""
ref_audio, sr = qnt.decode_to_file(ref, ref_path)
hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
prom_audio, sr = qnt.decode_to_file(prom, prom_path)
"""
# pseudo loss calculation since we don't get the logits during eval
min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
@@ -137,10 +141,10 @@ def train():
print(traceback.format_exc())
engines.train()
qnt.unload_model()
#qnt.unload_model()
do_gc()
qnt.unload_model()
#qnt.unload_model()
if args.eval:
return eval_fn(engines=trainer.load_engines())