when you already had these ideas to stabilize training but you just ignored them

This commit is contained in:
mrq 2025-02-27 23:39:20 -06:00
parent 0a45c9c042
commit f4f435d7f5
3 changed files with 86 additions and 22 deletions

View File

@@ -277,6 +277,7 @@ class ModelExperimentalSettings:
predict_causally: bool = False # predicts the next token even for the non-causal/NAR tasks; in theory this should also bolster the model, as:
# * NAR-demask would semi-doubly train for AR
# * the model wouldn't also need to learn when to predict the token in place
audio_encoder_mode: str = "sum" # audio encoder mode for version >= 7, because I cannot make up my damn mind
# these technically should be hyperparameters
# performs token dropout to compensate for errors
@@ -737,6 +738,7 @@ class Trainer:
activation_checkpointing: bool | None = None # deprecated; checkpointing should technically apply only to activations and not the entire gradients, but HF only has gradient checkpointing
gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
detect_grad_anomaly: bool = False # torch.autograd.set_detect_anomaly
check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
gc_mode: str | None = None # deprecated, but marks when to do GC
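The two checkpointing flags above describe the same VRAM-for-compute trade: intermediate activations are dropped during the forward pass and recomputed during backward. Below is a minimal sketch of the stock PyTorch utility that HF's gradient checkpointing wraps, not this repo's trainer code; CheckpointedBlock is a hypothetical stand-in module.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(nn.Module):
    def __init__(self, dim: int = 512):
        super().__init__()
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # activations inside self.ffn are not stored; they are recomputed during backward,
        # trading extra compute for lower VRAM while training
        return checkpoint(self.ffn, x, use_reentrant=False)

x = torch.randn(16, 512, requires_grad=True)
CheckpointedBlock()(x).sum().backward()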

View File

@@ -88,20 +88,42 @@ class AudioEncoder(nn.Module):
n_tokens: int,
n_levels: int,
token_dim: int,
enc_mode: str = "sum"
enc_mode: str = "sum",
l_weights: list[float] | None = None,
):
super().__init__()
self.enc_mode = enc_mode
d_ffn = 4
if not l_weights:
l_weights = [1 for _ in range(n_levels)]
if enc_mode == "sum":
self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
self.proj = None
self.weights = nn.Parameter(torch.tensor(l_weights))
elif enc_mode == "sub_interleave":
self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim // n_levels) for l in range(n_levels)])
self.proj = None
elif enc_mode == "interleave":
self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
#self.proj = nn.Linear(n_levels * token_dim, token_dim)
self.proj = nn.Sequential(
nn.Linear(n_levels * token_dim, d_ffn * token_dim),
nn.GELU(),
nn.Linear(d_ffn * token_dim, token_dim)
)
elif enc_mode == "attn":
self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
self.cross_attn = nn.MultiheadAttention(embed_dim=token_dim, num_heads=n_levels, dropout=0.1)
self.proj = nn.Sequential(
nn.Linear(n_levels * token_dim, d_ffn * token_dim),
nn.GELU(),
nn.Linear(d_ffn * token_dim, token_dim)
)
for emb in self.embs:
nn.init.normal_(emb.weight, mean=0.0, std=0.02)
def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
# empty
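For reference, the "attn" branch constructed above (its forward pass appears in the next hunk) amounts to: embed each codebook level, let the levels attend to one another, add a residual, then project the concatenated levels back down to token_dim. A standalone sketch under those assumptions, using plain nn.Embedding and batch_first attention rather than the repo's ml.Embedding wrapper and permutes:

import torch
import torch.nn as nn

class AttnLevelEncoder(nn.Module):
    def __init__(self, n_tokens: int, n_levels: int, token_dim: int, d_ffn: int = 4):
        super().__init__()
        self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
        # num_heads = n_levels assumes token_dim is divisible by n_levels
        self.cross_attn = nn.MultiheadAttention(token_dim, num_heads=n_levels, dropout=0.1, batch_first=True)
        self.proj = nn.Sequential(
            nn.Linear(n_levels * token_dim, d_ffn * token_dim),
            nn.GELU(),
            nn.Linear(d_ffn * token_dim, token_dim),
        )

    def forward(self, xi: torch.Tensor) -> torch.Tensor:
        # xi: [seq_len, n_levels] codes -> [seq_len, n_levels, token_dim] embeddings
        x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
        attn, _ = self.cross_attn(x, x, x)  # levels attend to each other at every position
        x = (x + attn).flatten(1)           # residual, then concatenate the levels
        return self.proj(x)                 # [seq_len, token_dim]

enc = AttnLevelEncoder(n_tokens=1024, n_levels=8, token_dim=512)
out = enc(torch.randint(0, 1024, (75, 8)))  # -> [75, 512]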
@@ -114,12 +136,26 @@ class AudioEncoder(nn.Module):
# old way
# in theory RVQ-based codecs should prefer this, but this doesn't yield good results
if self.enc_mode == "sum":
x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
weights = F.softmax( self.weights, dim=0 )
x = sum([ weights[l] * emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
# attention-based crunge
elif self.enc_mode == "attn":
x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
attn, _ = self.cross_attn(
x.permute(1, 0, 2),
x.permute(1, 0, 2),
x.permute(1, 0, 2),
)
attn = attn.permute(1, 0, 2)
x = x + attn
x = x.view(x.shape[0], -1)
# x = attn.reshape(x.shape[0], -1)
# encode by interleaving embeddings into one "token"
# this "works", but I imagine it's excessive and it doesn't seem to help the model all that much
else:
x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
x = x.view(x.shape[0], -1)
if self.proj is not None:
x = self.proj(x)
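The notable change to the "sum" branch above is that the per-level embeddings are no longer summed uniformly: a learnable weight per codebook level, initialized from l_weights and softmax-normalized on every forward pass, decides how much each RVQ level contributes. A minimal standalone sketch of just that mode, with plain nn.Embedding standing in for the repo's ml.Embedding wrapper:

import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightedSumEncoder(nn.Module):
    def __init__(self, n_tokens: int, n_levels: int, token_dim: int, l_weights: list[float] | None = None):
        super().__init__()
        if not l_weights:
            l_weights = [1.0 for _ in range(n_levels)]
        self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
        # initialized from the supplied per-level prior, then trained with the rest of the model
        self.weights = nn.Parameter(torch.tensor(l_weights))
        for emb in self.embs:
            nn.init.normal_(emb.weight, mean=0.0, std=0.02)

    def forward(self, xi: torch.Tensor) -> torch.Tensor:
        # xi: [seq_len, n_levels] codes; softmax keeps the level weights positive and summing to 1
        weights = F.softmax(self.weights, dim=0)
        return sum(weights[l] * emb(xi[:, l]) for l, emb in enumerate(self.embs))

enc = WeightedSumEncoder(n_tokens=1026, n_levels=8, token_dim=512, l_weights=[1.0 / (i + 1) for i in range(8)])
out = enc(torch.randint(0, 1026, (75, 8)))  # -> [75, 512]

Because the weights pass through a softmax, the harmonic prior only biases the starting point; training is free to flatten or sharpen the distribution.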
@@ -207,6 +243,7 @@ class Base_V2(nn.Module):
if not attention:
attention = config.attention if config is not None else "auto"
n_resp_levels = config.resp_levels if config is not None else 8
attention_backend = attention
unified_position_ids = config.experimental.unified_position_ids if config is not None else True
noncausal_masks = config.experimental.noncausal_masks if config is not None else False
@@ -218,6 +255,8 @@ class Base_V2(nn.Module):
resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True
predict_causally = config.experimental.predict_causally if config is not None else False
monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
audio_encoder_mode = config.experimental.audio_encoder_mode if config is not None else "sum"
audio_level_weights = [1.0 / (i + 1) for i in range(n_resp_levels)] # to-do: find the weights for FSQ
n_vocab = 256
n_tasks = config.tasks if config is not None else 8
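As a worked example (not repo code), the harmonic prior above comes out as follows for eight levels:

n_resp_levels = 8
audio_level_weights = [1.0 / (i + 1) for i in range(n_resp_levels)]
# -> [1.0, 0.5, 0.333, 0.25, 0.2, 0.167, 0.143, 0.125] (rounded)
# lower, coarser RVQ levels get the most weight in both the encoder sum and the loss;
# per the to-do, FSQ-style codecs (whose levels are not a residual hierarchy) likely want a flatter schedule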
@@ -283,6 +322,7 @@ class Base_V2(nn.Module):
self.masking_ratio = masking_ratio
self.ignore_inputs_for_loss = ignore_inputs_for_loss
self.noncausal_masks = noncausal_masks
self.audio_level_weights = audio_level_weights
self.sep = nn.Parameter(torch.randn(d_model))
@@ -302,17 +342,23 @@ class Base_V2(nn.Module):
n_tokens=n_audio_tokens + 2, # stop + masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
enc_mode=audio_encoder_mode,
l_weights=audio_level_weights,
)
else:
self.proms_emb = AudioEncoder(
n_tokens=n_audio_tokens,
n_levels=self.n_resp_levels,
token_dim=d_model,
enc_mode=audio_encoder_mode,
l_weights=audio_level_weights,
)
self.resps_emb = AudioEncoder(
n_tokens=n_audio_tokens + 2, # stop + masked token
n_levels=self.n_resp_levels,
token_dim=d_model,
enc_mode=audio_encoder_mode,
l_weights=audio_level_weights,
)
self.audio_decoder = AudioDecoder(
@@ -747,6 +793,7 @@ class Base_V2(nn.Module):
device = logits[0].device
batch_size = len(logits)
classifier_levels = self.get_input( inputs, "classifier_level" )
level_weights = self.audio_level_weights
# handles tasks where the prompt has task tokens injected in the middle
def prompt_input_to_token( input, quant_level ):
@@ -755,7 +802,7 @@
return input
def _calc_loss( logit, sequence, causal = True ):
def _calc_loss( logit, sequence, causal = True, level = None ):
# filter tokens that exceed the vocab size
sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
# drop if all tokens are ignored
@@ -769,7 +816,8 @@
sequence = sequence[..., l:] # ...predicts token n + 1
# flatten batch
if sequence.dim() > 1:
parallel = sequence.dim() > 1
if parallel:
logit = logit.reshape(-1, logit.shape[-1])
sequence = sequence.reshape(-1)
@@ -777,7 +825,11 @@
metrics = None
if compute_hard_loss:
nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index )
nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index, reduction='mean' if not parallel else 'none' ) * (level_weights[level] if level is not None and not parallel else 1)
# manually weigh each level
if parallel:
nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_weights, device=device)
if compute_acc:
accuracy_metric = MulticlassAccuracy(
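The reduction='none' path above is what makes per-level weighting possible: when the levels are trained in parallel, the flat per-token losses are reshaped to one row per level, averaged within each level, and scaled by that level's weight. A standalone sketch with assumed shapes (8 levels, 75 tokens per level, a 1026-entry vocab):

import torch
import torch.nn.functional as F

def level_weighted_nll(logit, sequence, level_weights, n_levels, ignore_index=-100):
    # per-token losses, left unreduced so each codebook level can be weighted separately
    nll = F.cross_entropy(logit, sequence, ignore_index=ignore_index, reduction='none')
    # flat [n_levels * seq_len] -> [n_levels, seq_len] -> per-level mean, scaled by its weight
    nll = nll.view(n_levels, -1).mean(dim=-1)
    return nll * torch.tensor(level_weights, device=logit.device)

logits = torch.randn(8 * 75, 1026)
targets = torch.randint(0, 1026, (8 * 75,))
weights = [1.0 / (i + 1) for i in range(8)]
per_level = level_weighted_nll(logits, targets, weights, n_levels=8)  # shape [8]
loss = per_level.sum()

Note that positions equal to ignore_index still count toward each level's mean here (as they appear to in the hunk above); a masked mean would be the stricter alternative.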
@@ -875,9 +927,6 @@ class Base_V2(nn.Module):
if logits[batch_index].dim() < 3:
nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
if name == "resp":
name = f'{name}[{quant_level}]'
elif not self.resp_parallel_training:
# cringe way to deduce "requested" level
level = quant_level
@@ -885,24 +934,35 @@ class Base_V2(nn.Module):
if classifier_level.endswith(f':{i}:{i}'):
level = i
break
"""
if name == "resp":
name = f'{name}[{level}]'
"""
sequence = token if token.dim() <= 1 else token[:, level]
nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
else:
sequence = token.t()
nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )
for level in range(self.n_resp_levels):
loss_key = f'{name}[{level}].nll'
if loss_key not in loss:
loss[loss_key] = []
loss[loss_key].append( nll[level] * loss_factor )
nll = None
loss_key = f'{name}.nll'
acc_key = f'{name}.acc'
if nll is not None:
if f'{name}.nll' not in loss:
loss[f'{name}.nll'] = []
loss[f"{name}.nll"].append( nll * loss_factor )
if loss_key not in loss:
loss[loss_key] = []
loss[loss_key].append( nll * loss_factor )
if metrics is not None:
if f'{name}.acc' not in stats:
stats[f'{name}.acc'] = []
stats[f"{name}.acc"].append( metrics )
if acc_key not in stats:
stats[acc_key] = []
stats[acc_key].append( metrics )
# add to list
else:
target.append( token )
@@ -922,7 +982,7 @@ class Base_V2(nn.Module):
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
else:
nlls = []
accs = []
@@ -930,7 +990,7 @@ class Base_V2(nn.Module):
for level, logit in enumerate( logits[batch_index] ):
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
nll, metrics = _calc_loss( logit, sequence, causal )
nll, metrics = _calc_loss( logit, sequence, causal, level )
if nll:
nlls.append( nll )
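The bookkeeping in the hunks above files each level's weighted NLL under its own key so the per-level losses stay visible in the stats. A simplified, self-contained sketch of that pattern with hypothetical values (in the real code the parallel and non-parallel branches are exclusive):

import torch

loss: dict[str, list[torch.Tensor]] = {}
name, loss_factor, n_resp_levels = "resp", 1.0, 8
per_level_nll = torch.rand(n_resp_levels)  # stand-in for the level-weighted NLLs above

# parallel training: one bucket per codebook level, e.g. 'resp[0].nll' ... 'resp[7].nll'
for level in range(n_resp_levels):
    loss.setdefault(f'{name}[{level}].nll', []).append(per_level_nll[level] * loss_factor)

# non-parallel training: a single aggregate bucket
loss.setdefault(f'{name}.nll', []).append(per_level_nll.mean() * loss_factor)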

View File

@@ -180,7 +180,9 @@ def train(
break
#batch = to_device(batch, torch.cuda.current_device())
stats = engines.step(batch=batch, feeder=train_feeder)
with torch.autograd.set_detect_anomaly(cfg.trainer.detect_grad_anomaly):
stats = engines.step(batch=batch, feeder=train_feeder)
stats['epoch'] = engines.global_samples / (len(train_dl.dataset.paths) * world_size())
elapsed_time = stats.get("elapsed_time", 0)
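The new with-block wraps the training step in torch.autograd.set_detect_anomaly, which doubles as a context manager: with cfg.trainer.detect_grad_anomaly left False it changes nothing, and when enabled autograd records forward tracebacks and raises as soon as a backward op produces NaNs, at a noticeable speed cost. A toy sketch of the pattern:

import torch

detect_grad_anomaly = True  # stand-in for cfg.trainer.detect_grad_anomaly

model = torch.nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

with torch.autograd.set_detect_anomaly(detect_grad_anomaly):
    loss = model(torch.randn(16, 8)).pow(2).mean()
    loss.backward()  # with anomaly detection on, a NaN gradient raises here with the offending op's forward traceback
    opt.step()
    opt.zero_grad()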