training + LoRA training works? (keeps OOMing after a step)
parent d7b63d2f70
commit 7aae9d48ab
@@ -175,6 +175,7 @@ class Dataset:
 class Model:
    name: str = "" # vanity name for the model
    training: bool = False
+   frozen_params: list[str] = field(default_factory=lambda: []) # frozen parameters that are not updated when training

    def get(self, name=None):
        return [ self ] if not name or self.name == name else []
@@ -190,7 +191,7 @@ class Model:

    @property
    def lora_policy(self):
-       include = ["model"] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever)
+       include = ["gpt"] # by default only adapt the main model (not embeddings nor classifier/output projection/LM head/whatever)
        exclude = []

        return dict(include=include, exclude=exclude)
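The include/exclude policy returned here is consumed by the passes_policy() helper added further down this diff. A rough sketch of that filtering idea, with made-up module names standing in for the real model graph:

    # Illustrative only: gate module names with an include/exclude policy,
    # mirroring the passes_policy() helper shown later in this diff.
    policy = dict(include=["gpt"], exclude=["wte", "lm_head"])

    def passes(name):
        if any(term in name for term in policy["exclude"]):
            return False
        return any(term in name for term in policy["include"])

    names = ["gpt.h.0.attn.c_attn", "gpt.wte", "lm_head", "mel_head"]
    print([n for n in names if passes(n)])  # ['gpt.h.0.attn.c_attn']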
@@ -202,6 +203,8 @@ class LoRA:
    rank: int = 8 # rank for the LoRA
    alpha: int = 16 # alpha (scaling) for the LoRA
    training: bool = True # whether the LoRA is being trained
+   parametrize: bool = False # inject the LoRA through torch parametrization instead of replacing the module
+   module: str = "linear" # linear | conv1d

    @property
    def full_name(self):
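rank and alpha together set the size and weight of the low-rank update: the adapters contribute an (out x rank) @ (rank x in) delta on top of the frozen weight, scaled by a factor derived from alpha (commonly alpha / rank; the exact scaling lives in the adapter's __init__, which this diff does not show). A toy numeric sketch under that assumption:

    import torch

    rank, alpha = 8, 16
    in_features, out_features = 1024, 1024
    scaling = alpha / rank  # assumed scaling; 2.0 with these defaults

    W = torch.randn(out_features, in_features)   # frozen base weight
    lora_A = torch.zeros(rank, in_features)      # trainable
    lora_B = torch.zeros(out_features, rank)     # trainable, zero-init => no change at step 0

    x = torch.randn(4, in_features)
    y = x @ W.T + (x @ lora_A.T @ lora_B.T) * scaling  # same output shape as the frozen layer
    print(y.shape)  # torch.Size([4, 1024])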
@@ -497,9 +497,13 @@ class Dataset(_Dataset):
        if key not in cfg.hdf5:
            raise RuntimeError(f'Key of Path ({path}) not in HDF5: {key}')

-       text = cfg.hdf5[key]["text"][:]
-       mel = cfg.hdf5[key]["audio"][:]
-       latents = cfg.hdf5[key]["latents"][:]
+       try:
+           text = cfg.hdf5[key]["text"][:]
+           mel = cfg.hdf5[key]["audio"][:]
+           latents = cfg.hdf5[key]["latents"][:]
+       except Exception as e:
+           print( key, cfg.hdf5[key].keys() )
+           raise e

        text = torch.from_numpy(text).to(self.text_dtype)
        mel = torch.from_numpy(mel).to(torch.int16)
@@ -56,9 +56,11 @@ def load_engines(training=True):
            model.model = ml.replace_embedding( model.model )

        for lora in cfg.loras:
-           model.model = apply_lora( model.model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy )
+           if hasattr(model, "gpt"):
+               #model.gpt = apply_lora( model.gpt, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, parametrize = lora.parametrize )
+               model = apply_lora( model, rank = lora.rank, alpha = lora.alpha, policy = model.config.lora_policy, use_parametrize = lora.parametrize )

-       if backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
+       if not inferencing and (backend == "local" or (backend == "deepspeed" and cfg.hyperparameters.torch_optimizer)):
            optimizer_class = None
            scheduler_class = None
@@ -124,7 +126,7 @@ def load_engines(training=True):
        loads_state_dict = True

        stats = None
-       if loads_state_dict:
+       if loads_state_dict and load_path.exists():
            state = torch.load(load_path, map_location=torch.device(cfg.device))

            # state dict is not just the module, extract the extra trainer details
@@ -59,8 +59,8 @@ def get_model(config, training=True):
    name = config.name

    model = load_model(config.name)

-   config.training = False
+   config.training = "autoregressive" in config.name
    model.config = config

    print(f"{name} ({next(model.parameters()).dtype}): {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters")
@@ -4,18 +4,17 @@ import torch
 import torch.nn.functional as F
 import torch.nn.utils.parametrize as parametrize

+from transformers.pytorch_utils import Conv1D

 from torch import Tensor, nn

 import math
 from typing import Optional, List

-# to-do: set cfg to decide
-USE_PARAMETRIZATION = True

 # LoRA Linear for replacement
 # Pros: simple, just needs to reuse the replace_linear and copy weights
 # Cons: does not work with other Linears (bnb, bitnet, te's fp8, etc), cannot apply multiple LoRAs (although for audio why would you)
-class Linear(nn.Linear):
+class LoRALinear(nn.Linear):
    def __init__(
        self,
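The comments above describe the replacement strategy: subclass nn.Linear, keep the frozen weight, bolt trainable lora_A / lora_B onto it, and swap the new layer in where the old one sat. A self-contained toy version of that idea (the class and dimensions here are illustrative, not the code in this file):

    import torch
    from torch import nn

    # Toy "replacement" LoRA: a drop-in nn.Linear subclass with low-rank factors.
    class TinyLoRALinear(nn.Linear):
        def __init__(self, in_features, out_features, bias=True, rank=8, alpha=16):
            super().__init__(in_features, out_features, bias=bias)
            self.scaling = alpha / rank
            self.lora_A = nn.Parameter(self.weight.new_zeros(rank, in_features))
            self.lora_B = nn.Parameter(self.weight.new_zeros(out_features, rank))

        def forward(self, x):
            return super().forward(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

    model = nn.Sequential(nn.Linear(16, 16))
    old = model[0]
    new = TinyLoRALinear(old.in_features, old.out_features, bias=old.bias is not None)
    new.weight.data.copy_(old.weight.data)      # copy the frozen weights over
    if old.bias is not None:
        new.bias.data.copy_(old.bias.data)
    model[0] = new                              # swap the adapted layer into the parent
    print(model(torch.randn(2, 16)).shape)      # torch.Size([2, 16])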
@@ -27,7 +26,7 @@ class Linear(nn.Linear):
        alpha: int = 1,

        dropout: float = 0.1,
-       merge_weights: bool = True,
+       merge_weights: bool = False,
        **kwargs,
    ):
        super().__init__(in_features=in_features, out_features=out_features, bias=bias, **kwargs)
@@ -37,6 +36,7 @@ class Linear(nn.Linear):
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
        self.merge_weights = merge_weights
        self.merged = False
+       self.enabled = True

        self.lora_B = nn.Parameter( self.weight.new_zeros( (out_features, rank) ) )
        self.lora_A = nn.Parameter( self.weight.new_zeros( (rank, in_features) ) )
@@ -67,7 +67,7 @@ class Linear(nn.Linear):
            self.merged = True

    def forward(self, x: torch.Tensor):
-       if not self.merged:
+       if not self.merged and self.enabled:
            result = F.linear(x, self.weight, bias=self.bias)
            result += (self.dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
@@ -85,7 +85,7 @@ class Linear(nn.Linear):
 # Uses parametrization to inject LoRA weights
 # Pros: should work with any Linears
 # Cons: TBD
-class ParameterizedLinear(nn.Module):
+class ParameterizedLoRA(nn.Module):
    def __init__(
        self,
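The parametrization route relies on torch.nn.utils.parametrize.register_parametrization, which routes every access to a module's weight through an extra nn.Module instead of replacing the layer. A minimal sketch of that mechanism (names and dimensions are illustrative):

    import torch
    from torch import nn
    import torch.nn.utils.parametrize as parametrize

    # Toy parametrization: each access to layer.weight passes through forward(),
    # which adds a scaled low-rank delta without replacing the Linear itself.
    class ToyLoRAParametrization(nn.Module):
        def __init__(self, out_features, in_features, rank=8, alpha=16):
            super().__init__()
            self.scaling = alpha / rank
            self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
            self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
            self.enabled = True

        def forward(self, weight):
            if self.enabled:
                return weight + (self.lora_B @ self.lora_A) * self.scaling
            return weight

    layer = nn.Linear(16, 16)
    parametrize.register_parametrization(layer, "weight", ToyLoRAParametrization(16, 16))
    # the original tensor now lives at layer.parametrizations.weight.original
    print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 16])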
@@ -120,7 +120,6 @@ class ParameterizedLinear(nn.Module):
    def forward(self, x: torch.Tensor):
        if self.enabled:
            return x + torch.matmul(self.lora_B, self.dropout(self.lora_A)).view(x.shape) * self.scaling

        return x

    @classmethod
@@ -133,10 +132,21 @@ class ParameterizedLinear(nn.Module):
        # M$'s LoRA class arranges things to where this isn't necessary
        return cls( in_features = layer.out_features, out_features = layer.in_features, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)

+   @classmethod
+   def from_conv1d( cls, layer, device = None, dtype = None, **kwargs ):
+       if device is None:
+           device = layer.weight.device
+       if dtype is None:
+           dtype = layer.weight.dtype
+
+       in_channels, out_channels = layer.weight.shape
+       # swap because we're feeding the output as our input
+       # M$'s LoRA class arranges things to where this isn't necessary
+       return cls( in_features = out_channels, out_features = in_channels, bias = layer.bias is not None, **kwargs ).to(device=device, dtype=dtype)

 def passes_policy( policy, name ):
    if policy is None:
        return True

    if "exclude" in policy:
        for term in policy["exclude"]:
            if term in name:
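The channel swap in from_conv1d comes from the fact that Hugging Face's Conv1D (the GPT-2-style projection) stores its weight transposed relative to nn.Linear, i.e. as (in_features, out_features). A quick shape check:

    import torch
    from torch import nn
    from transformers.pytorch_utils import Conv1D

    linear = nn.Linear(1024, 3072)
    conv1d = Conv1D(3072, 1024)      # Conv1D(nf=out_features, nx=in_features)

    print(linear.weight.shape)       # torch.Size([3072, 1024])  -> (out, in)
    print(conv1d.weight.shape)       # torch.Size([1024, 3072])  -> (in, out), hence the swap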
@@ -149,30 +159,50 @@ def passes_policy( policy, name ):

    return False

-def apply_lora( model, register = True, merge = False, policy = None, **kwargs ):
+def apply_lora( model, register = True, merge = False, policy = None, use_parametrize = False, **kwargs ):
+   device = next(model.parameters()).device
+   dtype = next(model.parameters()).dtype
+
-   klass = Linear
-   target = nn.Linear
-
-   device = next(model.parameters()).device
-   dtype = next(model.parameters()).dtype
-   modules = [ k.split('.') for k, m in model.named_modules() if isinstance(m, target) and not isinstance(m, klass) and passes_policy( policy, k ) ]
+   modules = [ k.split('.') for k, m in model.named_modules() if passes_policy( policy, k ) ]

    for *parent, k in modules:
        name = '.'.join(parent)
        layer = getattr( model.get_submodule(name), k )

-       if USE_PARAMETRIZATION:
-           parametrize.register_parametrization( layer, "weight", ParameterizedLinear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
-           # parametrize.remove_parametrizations( layer, "weight", leave_parametrized=merge )
+       if isinstance( layer, nn.Linear ):
+           target = nn.Linear
+           klass = ParameterizedLoRA if use_parametrize else LoRALinear
+           replacer = klass.from_linear
+       elif isinstance( layer, nn.Conv1d ):
+           target = nn.Conv1d
+           klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
+           replacer = klass.from_conv1d
+       elif isinstance( layer, Conv1D ):
+           target = Conv1D
+           klass = ParameterizedLoRA if use_parametrize else LoRAConv1d
+           replacer = klass.from_conv1d
        else:
-           setattr( model.get_submodule(name), k, Linear.from_linear( layer, device=device, dtype=dtype, **kwargs ) )
+           continue
+
+       replacement = replacer( layer, device=device, dtype=dtype, **kwargs )
+
+       if use_parametrize:
+           parametrize.register_parametrization( layer, "weight", replacement )
+       else:
+           setattr( model.get_submodule(name), k, replacement )

    return model

 def enable_lora( model, mode = True ):
    for name, module in model.named_modules():
        if not isinstance( module, ParameterizedLoRA ) and not isinstance( module, LoRALinear ):
            continue
        module.enabled = mode
    return model

 def disable_lora( model ):
    return enable_lora( model, False )

 def freeze_non_lora_weights( model ):
    for name, param in model.named_parameters():
        param.requires_grad_('lora_' in name)
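How these helpers are meant to compose, per the load_engines() hunk earlier in this diff: apply LoRA under a policy, then freeze everything except the 'lora_' parameters. A rough sketch against a toy model; the import path and the stand-in model are assumptions, not code from this repo:

    import torch
    from torch import nn

    # Assumed import path for the LoRA helpers shown in this diff.
    from tortoise_tts.models.lora import apply_lora, freeze_non_lora_weights

    # Toy stand-in for the real model: only the "gpt" branch should be adapted.
    model = nn.ModuleDict(dict(
        gpt=nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)),
        mel_head=nn.Linear(64, 32),
    ))

    model = apply_lora(model, rank=8, alpha=16,
                       policy=dict(include=["gpt"], exclude=[]),
                       use_parametrize=True)
    model = freeze_non_lora_weights(model)

    # After freezing, only the injected lora_A / lora_B tensors require grad.
    print([n for n, p in model.named_parameters() if p.requires_grad])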
@@ -273,9 +273,14 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text
        n_embd=model_dim,
        n_layer=layers,
        n_head=heads,
        gradient_checkpointing=checkpointing,
        use_cache=not checkpointing)
    gpt = GPT2Model(gpt_config)

+   if checkpointing:
+       gpt.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
+           use_reentrant=False
+       ))
+
    # Override the built in positional embeddings
    del gpt.wpe
    gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
@@ -29,12 +29,14 @@ mel_stft_loss = auraloss.freq.MelSTFTLoss(cfg.sample_rate, device="cpu")

 def train_feeder(engine, batch):
    with torch.autocast("cuda", dtype=cfg.trainer.dtype, enabled=cfg.trainer.amp):
+       device = batch["text"][0].device
+       batch_size = len(batch["text"])

        conditioning_latents = pad_sequence([ latents[0] for latents in batch["latents"] ], batch_first = True)
        text_inputs = pad_sequence([ text for text in batch["text"] ], batch_first = True)
-       text_lengths = pad_sequence([ text.shape[0] for text in batch["text"] ], batch_first = True)
-       mel_codes = pad_sequence([ code for codes in batch["mel"] ], batch_first = True)
-       wav_lengths = pad_sequence([ length for length in batch["wav_length"] ], batch_first = True)
+       text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32)
+       mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True)
+       wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32)

        engine.forward(conditioning_latents, text_inputs, text_lengths, mel_codes, wav_lengths)
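The replaced lines above also show why the lengths moved off pad_sequence: pad_sequence expects a list of tensors, so the variable-length token tensors get padded into one batch tensor while the per-item lengths become a plain int32 vector. A small sketch:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    texts = [torch.tensor([5, 2, 9]), torch.tensor([7, 1])]

    text_inputs = pad_sequence(texts, batch_first=True)   # shape [2, 3], zero-padded
    text_lengths = torch.Tensor([t.shape[0] for t in texts]).to(dtype=torch.int32)

    print(text_inputs)    # tensor([[5, 2, 9], [7, 1, 0]])
    print(text_lengths)   # tensor([3, 2], dtype=torch.int32)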
@@ -48,7 +50,7 @@ def train_feeder(engine, batch):
        stats |= {k: v.item() for k, v in stat.items()}

        engine.tokens_processed += sum([ text.shape[0] for text in batch["text"] ])
-       engine.tokens_processed += sum([ resps.shape[0] for resps in batch["resps"] ])
+       engine.tokens_processed += sum([ mel.shape[-1] for mel in batch["mel"] ])

    return loss, stats
@@ -76,9 +78,11 @@ def run_eval(engines, eval_name, dl):
            ref_path.parent.mkdir(parents=True, exist_ok=True)
            prom_path.parent.mkdir(parents=True, exist_ok=True)

+           """
            ref_audio, sr = qnt.decode_to_file(ref, ref_path)
            hyp_audio, sr = qnt.decode_to_file(hyp, hyp_path)
            prom_audio, sr = qnt.decode_to_file(prom, prom_path)
+           """

            # pseudo loss calculation since we don't get the logits during eval
            min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
@@ -137,10 +141,10 @@ def train():
            print(traceback.format_exc())

        engines.train()
-       qnt.unload_model()
+       #qnt.unload_model()
        do_gc()

-   qnt.unload_model()
+   #qnt.unload_model()

    if args.eval:
        return eval_fn(engines=trainer.load_engines())