when you already had these ideas to stabilize training but you just ignored them
commit f4f435d7f5
parent 0a45c9c042
@@ -277,6 +277,7 @@ class ModelExperimentalSettings:
	predict_causally: bool = False # predicts the next token even for the non-causal/NAR tasks; in theory this should also bolster the model, as:
	# * NAR-demask would semi-doubly train for AR
	# * the model wouldn't also need to learn when to predict the token in place
	audio_encoder_mode: str = "sum" # audio encoder mode for version >= 7, because I cannot make up my damn mind

	# these should technically be hyperparameters
	# performs token dropout to compensate for errors
@@ -737,6 +738,7 @@ class Trainer:

	activation_checkpointing: bool | None = None # deprecated; should technically apply only to activations and not the entire gradients, but HF only has gradient checkpointing
	gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
	detect_grad_anomaly: bool = False # torch.autograd.set_detect_anomaly

	check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
	gc_mode: str | None = None # deprecated, but marks when to do GC
@@ -88,20 +88,42 @@ class AudioEncoder(nn.Module):
		n_tokens: int,
		n_levels: int,
		token_dim: int,
		enc_mode: str = "sum"
		enc_mode: str = "sum",
		l_weights: list[float] | None = None,
	):
		super().__init__()
		self.enc_mode = enc_mode

		d_ffn = 4
		if not l_weights:
			l_weights = [1 for _ in range(n_levels)]

		if enc_mode == "sum":
			self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
			self.proj = None
			self.weights = nn.Parameter(torch.tensor(l_weights))
		elif enc_mode == "sub_interleave":
			self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim // n_levels) for l in range(n_levels)])
			self.proj = None
		elif enc_mode == "interleave":
			self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
			self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
			#self.proj = nn.Linear(n_levels * token_dim, token_dim)
			self.proj = nn.Sequential(
				nn.Linear(n_levels * token_dim, d_ffn * token_dim),
				nn.GELU(),
				nn.Linear(d_ffn * token_dim, token_dim)
			)
		elif enc_mode == "attn":
			self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
			self.cross_attn = nn.MultiheadAttention(embed_dim=token_dim,num_heads=n_levels,dropout=0.1)
			self.proj = nn.Sequential(
				nn.Linear(n_levels * token_dim, d_ffn * token_dim),
				nn.GELU(),
				nn.Linear(d_ffn * token_dim, token_dim)
			)

		for emb in self.embs:
			nn.init.normal_(emb.weight, mean=0.0, std=0.02)

	def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
		# empty
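For the interleave and attn modes, the per-level embeddings are concatenated along the feature axis and squeezed back to the model width through a small GELU MLP whose hidden width is d_ffn * token_dim (d_ffn = 4 above). A quick standalone shape check, not part of the commit, with illustrative sizes (8 levels, 512-wide model):

import torch
import torch.nn as nn

n_levels, token_dim, d_ffn = 8, 512, 4  # illustrative sizes, not from the repo config

proj = nn.Sequential(
	nn.Linear(n_levels * token_dim, d_ffn * token_dim),  # 4096 -> 2048
	nn.GELU(),
	nn.Linear(d_ffn * token_dim, token_dim),             # 2048 -> 512
)

x = torch.randn(75, n_levels * token_dim)  # 75 timesteps of concatenated level embeddings
print(proj(x).shape)                       # torch.Size([75, 512])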
@@ -114,12 +136,26 @@ class AudioEncoder(nn.Module):
		# old way
		# in theory RVQ-based codecs should prefer this, but this doesn't yield good results
		if self.enc_mode == "sum":
			x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
			weights = F.softmax( self.weights, dim=0 )
			x = sum([ weights[l] * emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
		# attention-based crunge
		elif self.enc_mode == "attn":
			x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
			attn, _ = self.cross_attn(
				x.permute(1, 0, 2),
				x.permute(1, 0, 2),
				x.permute(1, 0, 2),
			)
			attn = attn.permute(1, 0, 2)
			x = x + attn
			x = x.view(x.shape[0], -1)
			# x = attn.reshape(x.shape[0], -1)
		# encode by interleaving embeddings into one "token"
		# this "works" but I imagine it being excessive and doesn't seem to help the model all that much
		else:
			x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
			x = x.view(x.shape[0], -1)

		if self.proj is not None:
			x = self.proj(x)
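The attn branch above fuses the RVQ levels with self-attention across the level axis before flattening and projecting back to the model width. A minimal, self-contained sketch of that idea follows; AttnLevelFusion is an illustrative name, nn.Embedding stands in for the repo's ml.Embedding wrapper, and the sizes are assumptions:

import torch
import torch.nn as nn

class AttnLevelFusion(nn.Module):
	"""Illustrative fusion of per-level RVQ embeddings via self-attention over the level axis."""
	def __init__(self, n_tokens: int, n_levels: int, token_dim: int, d_ffn: int = 4):
		super().__init__()
		self.embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
		self.cross_attn = nn.MultiheadAttention(embed_dim=token_dim, num_heads=n_levels, dropout=0.1)
		self.proj = nn.Sequential(
			nn.Linear(n_levels * token_dim, d_ffn * token_dim),
			nn.GELU(),
			nn.Linear(d_ffn * token_dim, token_dim),
		)

	def forward(self, xi: torch.Tensor) -> torch.Tensor:
		# xi: [T, n_levels] codes; stack per-level embeddings -> [T, n_levels, D]
		x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
		# nn.MultiheadAttention defaults to (seq, batch, dim); here "seq" is the level axis
		attn, _ = self.cross_attn(x.permute(1, 0, 2), x.permute(1, 0, 2), x.permute(1, 0, 2))
		x = x + attn.permute(1, 0, 2)   # residual over the attended levels
		x = x.view(x.shape[0], -1)      # flatten levels -> [T, n_levels * D]
		return self.proj(x)             # project back to [T, D]

codes = torch.randint(0, 1024, (75, 8))
fused = AttnLevelFusion(n_tokens=1024, n_levels=8, token_dim=512)(codes)  # -> [75, 512]

Treating the levels as a length-8 sequence (with timesteps acting as the batch) lets each level attend to the others before they are merged into a single token embedding.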
@@ -207,6 +243,7 @@ class Base_V2(nn.Module):
		if not attention:
			attention = config.attention if config is not None else "auto"

		n_resp_levels = config.resp_levels if config is not None else 8
		attention_backend = attention
		unified_position_ids = config.experimental.unified_position_ids if config is not None else True
		noncausal_masks = config.experimental.noncausal_masks if config is not None else False
@@ -218,6 +255,8 @@ class Base_V2(nn.Module):
		resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True
		predict_causally = config.experimental.predict_causally if config is not None else False
		monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
		audio_encoder_mode = config.experimental.audio_encoder_mode if config is not None else "sum"
		audio_level_weights = [1.0 / (i + 1) for i in range(n_resp_levels)] # to-do: find the weights for FSQ

		n_vocab = 256
		n_tasks = config.tasks if config is not None else 8
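The default audio_level_weights follow a harmonic schedule, so the lower (more perceptually important) RVQ levels dominate; in the "sum" encoder these same values seed the learned weights that get softmax-normalized at forward time. Worked out for 8 levels (values rounded, for illustration only):

n_resp_levels = 8
audio_level_weights = [1.0 / (i + 1) for i in range(n_resp_levels)]
# [1.0, 0.5, 0.333, 0.25, 0.2, 0.167, 0.143, 0.125]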
@@ -283,6 +322,7 @@ class Base_V2(nn.Module):
		self.masking_ratio = masking_ratio
		self.ignore_inputs_for_loss = ignore_inputs_for_loss
		self.noncausal_masks = noncausal_masks
		self.audio_level_weights = audio_level_weights

		self.sep = nn.Parameter(torch.randn(d_model))
@@ -302,17 +342,23 @@ class Base_V2(nn.Module):
				n_tokens=n_audio_tokens + 2, # stop + masked token
				n_levels=self.n_resp_levels,
				token_dim=d_model,
				enc_mode=audio_encoder_mode,
				l_weights=audio_level_weights,
			)
		else:
			self.proms_emb = AudioEncoder(
				n_tokens=n_audio_tokens,
				n_levels=self.n_resp_levels,
				token_dim=d_model,
				enc_mode=audio_encoder_mode,
				l_weights=audio_level_weights,
			)
			self.resps_emb = AudioEncoder(
				n_tokens=n_audio_tokens + 2, # stop + masked token
				n_levels=self.n_resp_levels,
				token_dim=d_model,
				enc_mode=audio_encoder_mode,
				l_weights=audio_level_weights,
			)

		self.audio_decoder = AudioDecoder(
@@ -747,6 +793,7 @@ class Base_V2(nn.Module):
		device = logits[0].device
		batch_size = len(logits)
		classifier_levels = self.get_input( inputs, "classifier_level" )
		level_weights = self.audio_level_weights

		# handles tasks where the prompt has task tokens injected in the middle
		def prompt_input_to_token( input, quant_level ):
@@ -755,7 +802,7 @@ class Base_V2(nn.Module):

			return input

		def _calc_loss( logit, sequence, causal = True ):
		def _calc_loss( logit, sequence, causal = True, level = None ):
			# filter tokens that exceed the vocab size
			sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
			# drop if all tokens are ignored
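The torch.where line above remaps any target id the classifier head cannot produce to ignore_index, so stray out-of-vocab tokens drop out of the loss instead of crashing cross_entropy. A tiny standalone illustration; the sizes and the -100 sentinel are assumptions:

import torch
import torch.nn.functional as F

vocab, ignore_index = 1024, -100
logit  = torch.randn(6, vocab)
target = torch.tensor([3, 1024, 77, 2000, 5, 9])  # two ids fall outside the 1024-token vocab

target = torch.where(target >= logit.shape[-1], ignore_index, target)
loss = F.cross_entropy(logit, target, ignore_index=ignore_index)  # ignored positions contribute nothing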
@@ -769,7 +816,8 @@ class Base_V2(nn.Module):
				sequence = sequence[..., l:] # ...predicts token n + 1

			# flatten batch
			if sequence.dim() > 1:
			parallel = sequence.dim() > 1
			if parallel:
				logit = logit.reshape(-1, logit.shape[-1])
				sequence = sequence.reshape(-1)
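For causal targets the sequence is shifted so position n predicts token n + 1; the matching truncation on the logit side sits just above this hunk and is not shown, so the sketch below assumes the usual pairing of dropping the last l logits and the first l targets. Illustrative only:

import torch
import torch.nn.functional as F

T, vocab, l = 12, 1024, 1  # l = shift amount, mirroring the hunk's naming
logit  = torch.randn(T, vocab)
target = torch.randint(0, vocab, (T,))

# position n is scored against token n + l
nll = F.cross_entropy(logit[:T - l], target[l:])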
@@ -777,7 +825,11 @@ class Base_V2(nn.Module):
			metrics = None

			if compute_hard_loss:
				nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index )
				nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index, reduction='mean' if not parallel else 'none' ) * (level_weights[level] if level is not None and not parallel else 1)

				# manually weigh each level
				if parallel:
					nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_weights, device=device)

			if compute_acc:
				accuracy_metric = MulticlassAccuracy(
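The parallel path computes an unreduced cross-entropy over all levels at once, averages per level, then scales each level by its weight. A self-contained sketch of that reduction; the shapes and the final sum are assumptions for illustration:

import torch
import torch.nn.functional as F

n_resp_levels, T, vocab = 8, 10, 1024
ignore_index = -100
level_weights = torch.tensor([1.0 / (i + 1) for i in range(n_resp_levels)])

logits  = torch.randn(n_resp_levels, T, vocab)        # one logit stack per RVQ level
targets = torch.randint(0, vocab, (n_resp_levels, T))

# keep per-token losses so each level can be weighed separately afterwards
nll = F.cross_entropy(
	logits.reshape(-1, vocab),
	targets.reshape(-1),
	ignore_index=ignore_index,
	reduction='none',
)
per_level = nll.view(n_resp_levels, -1).mean(dim=-1) * level_weights  # -> [n_resp_levels]
loss = per_level.sum()  # or log each level separately, as the commit does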
@@ -875,9 +927,6 @@ class Base_V2(nn.Module):

				if logits[batch_index].dim() < 3:
					nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )

					if name == "resp":
						name = f'{name}[{quant_level}]'
				elif not self.resp_parallel_training:
					# cringe way to deduce "requested" level
					level = quant_level
@@ -885,24 +934,35 @@ class Base_V2(nn.Module):
						if classifier_level.endswith(f':{i}:{i}'):
							level = i
							break
					"""
					if name == "resp":
						name = f'{name}[{level}]'
					"""

					sequence = token if token.dim() <= 1 else token[:, level]
					nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
					nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
				else:
					sequence = token.t()
					sequence = token.t()
					nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )

					for level in range(self.n_resp_levels):
						loss_key = f'{name}[{level}].nll'
						if loss_key not in loss:
							loss[loss_key] = []
						loss[loss_key].append( nll[level] * loss_factor )

					nll = None

				loss_key = f'{name}.nll'
				acc_key = f'{name}.acc'
				if nll is not None:
					if f'{name}.nll' not in loss:
						loss[f'{name}.nll'] = []
					loss[f"{name}.nll"].append( nll * loss_factor )
					if loss_key not in loss:
						loss[loss_key] = []
					loss[loss_key].append( nll * loss_factor )

				if metrics is not None:
					if f'{name}.acc' not in stats:
						stats[f'{name}.acc'] = []
					stats[f"{name}.acc"].append( metrics )
					if acc_key not in stats:
						stats[acc_key] = []
					stats[acc_key].append( metrics )
			# add to list
			else:
				target.append( token )
@@ -922,7 +982,7 @@ class Base_V2(nn.Module):

				sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
				sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
				nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
				nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
			else:
				nlls = []
				accs = []
@@ -930,7 +990,7 @@ class Base_V2(nn.Module):
				for level, logit in enumerate( logits[batch_index] ):
					sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
					sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
					nll, metrics = _calc_loss( logit, sequence, causal )
					nll, metrics = _calc_loss( logit, sequence, causal, level )

					if nll:
						nlls.append( nll )
@@ -180,7 +180,9 @@ def train(
			break

		#batch = to_device(batch, torch.cuda.current_device())
		stats = engines.step(batch=batch, feeder=train_feeder)
		with torch.autograd.set_detect_anomaly(cfg.trainer.detect_grad_anomaly):
			stats = engines.step(batch=batch, feeder=train_feeder)

		stats['epoch'] = engines.global_samples / (len(train_dl.dataset.paths) * world_size())

		elapsed_time = stats.get("elapsed_time", 0)
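Wrapping the training step in torch.autograd.set_detect_anomaly means that, when the config flag is on, autograd tracks which forward op produced a NaN gradient and raises with that op's traceback, at a real speed cost (hence the default of False). A minimal standalone demonstration of what it catches; the sqrt-at-zero example is illustrative, not from the repo:

import torch

x = torch.tensor([0.0], requires_grad=True)
y = (torch.sqrt(x) * 0.0).sum()  # forward is a clean 0.0, but sqrt'(0) = inf poisons the backward pass

with torch.autograd.set_detect_anomaly(True):
	try:
		y.backward()
	except RuntimeError as err:
		print(err)  # e.g. "Function 'SqrtBackward0' returned nan values in its 0th output."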