when you already had these ideas to stabilize training but you just ignored them
parent 0a45c9c042
commit f4f435d7f5
@@ -277,6 +277,7 @@ class ModelExperimentalSettings:
     predict_causally: bool = False # predicts the next token even for the non-causal/NAR tasks, in theory this should also bolster the model, as
     # * NAR-demask would semi-doubly train for AR
     # * the model wouldn't also need to learn when to predict the token in place
+    audio_encoder_mode: str = "sum" # audio encoder mode for version >= 7, because I cannot make up my damn mind

     # these technically should be as hyperparameters
     # performs token dropout to compensate for errors
@@ -737,6 +738,7 @@ class Trainer:

     activation_checkpointing: bool | None = None # deprecated, should technically be used for only on activations and not the entire gradients, but HF only has gradient checkpointing
     gradient_checkpointing: bool = True # enables gradient checkpointing to save VRAM at the cost of slightly reduced performance when training
+    detect_grad_anomaly: bool = False # torch.autograd.set_detect_anomaly

     check_for_oom: bool = True # checks for OOMs thrown during forward/backwards
     gc_mode: str | None = None # deprecated, but marks when to do GC
@@ -88,20 +88,42 @@ class AudioEncoder(nn.Module):
         n_tokens: int,
         n_levels: int,
         token_dim: int,
-        enc_mode: str = "sum"
+        enc_mode: str = "sum",
+        l_weights: list[float] | None = None,
     ):
         super().__init__()
         self.enc_mode = enc_mode

+        d_ffn = 4
+        if not l_weights:
+            l_weights = [1 for _ in range(n_levels)]
+
         if enc_mode == "sum":
             self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
             self.proj = None
+            self.weights = nn.Parameter(torch.tensor(l_weights))
         elif enc_mode == "sub_interleave":
             self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim // n_levels) for l in range(n_levels)])
             self.proj = None
         elif enc_mode == "interleave":
             self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
-            self.proj = nn.Linear(8 * token_dim, 1 * token_dim)
+            #self.proj = nn.Linear(n_levels * token_dim, token_dim)
+            self.proj = nn.Sequential(
+                nn.Linear(n_levels * token_dim, d_ffn * token_dim),
+                nn.GELU(),
+                nn.Linear(d_ffn * token_dim, token_dim)
+            )
+        elif enc_mode == "attn":
+            self.embs = nn.ModuleList([ml.Embedding(n_tokens, token_dim) for l in range(n_levels)])
+            self.cross_attn = nn.MultiheadAttention(embed_dim=token_dim, num_heads=n_levels, dropout=0.1)
+            self.proj = nn.Sequential(
+                nn.Linear(n_levels * token_dim, d_ffn * token_dim),
+                nn.GELU(),
+                nn.Linear(d_ffn * token_dim, token_dim)
+            )
+
+        for emb in self.embs:
+            nn.init.normal_(emb.weight, mean=0.0, std=0.02)

     def forward(self, xi: Tensor, dropout_mask = None, dropout_token = None ) -> Tensor:
         # empty
@@ -114,12 +136,26 @@ class AudioEncoder(nn.Module):
         # old way
         # in theory RVQ-based codecs should prefer this, but this doesn't yield good results
         if self.enc_mode == "sum":
-            x = sum([ emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
+            weights = F.softmax( self.weights, dim=0 )
+            x = sum([ weights[l] * emb( xi[:, l] ) for l, emb in enumerate(self.embs) ])
+        # attention-based crunge
+        elif self.enc_mode == "attn":
+            x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
+            attn, _ = self.cross_attn(
+                x.permute(1, 0, 2),
+                x.permute(1, 0, 2),
+                x.permute(1, 0, 2),
+            )
+            attn = attn.permute(1, 0, 2)
+            x = x + attn
+            x = x.view(x.shape[0], -1)
+            # x = attn.reshape(x.shape[0], -1)
         # encode by interleaving embeddings into one "token"
         # this "works" but I imagine it being excessive and doesn't seem to help the model all that much
         else:
             x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(self.embs)], dim=1)
             x = x.view(x.shape[0], -1)

         if self.proj is not None:
             x = self.proj(x)
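For context, a minimal self-contained sketch of what the reworked "sum" and new "attn" encoder paths compute, using dummy shapes. This is an illustration only, not the project's module: `ml.Embedding` is swapped for `nn.Embedding`, and the dropout-mask handling from the real `forward` is omitted.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    n_tokens, n_levels, token_dim, batch = 1026, 8, 1024, 4
    embs = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(n_levels)])
    weights = nn.Parameter(torch.tensor([1.0 / (i + 1) for i in range(n_levels)]))

    xi = torch.randint(0, n_tokens, (batch, n_levels))  # one token per RVQ level, per frame

    # "sum": softmax-normalized, learnable per-level weights over the summed embeddings
    w = F.softmax(weights, dim=0)
    x_sum = sum(w[l] * emb(xi[:, l]) for l, emb in enumerate(embs))  # [batch, token_dim]

    # "attn": stack per-level embeddings, mix them with multi-head attention, project back down
    cross_attn = nn.MultiheadAttention(embed_dim=token_dim, num_heads=n_levels, dropout=0.1)
    proj = nn.Sequential(
        nn.Linear(n_levels * token_dim, 4 * token_dim),
        nn.GELU(),
        nn.Linear(4 * token_dim, token_dim),
    )
    x = torch.stack([emb(xi[:, l]) for l, emb in enumerate(embs)], dim=1)  # [batch, n_levels, token_dim]
    attn, _ = cross_attn(x.permute(1, 0, 2), x.permute(1, 0, 2), x.permute(1, 0, 2))
    x = x + attn.permute(1, 0, 2)               # residual over the per-level embeddings
    x_attn = proj(x.reshape(batch, -1))          # [batch, token_dim]

    print(x_sum.shape, x_attn.shape)  # torch.Size([4, 1024]) torch.Size([4, 1024])

Both modes end at a single `token_dim`-wide vector per frame; the difference is whether the levels are blended by a learned scalar per level or by attention across the levels before the MLP projection.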
@@ -207,6 +243,7 @@ class Base_V2(nn.Module):
         if not attention:
             attention = config.attention if config is not None else "auto"

+        n_resp_levels = config.resp_levels if config is not None else 8
         attention_backend = attention
         unified_position_ids = config.experimental.unified_position_ids if config is not None else True
         noncausal_masks = config.experimental.noncausal_masks if config is not None else False
@@ -218,6 +255,8 @@ class Base_V2(nn.Module):
         resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True
         predict_causally = config.experimental.predict_causally if config is not None else False
         monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
+        audio_encoder_mode = config.experimental.audio_encoder_mode if config is not None else "sum"
+        audio_level_weights = [1.0 / (i + 1) for i in range(n_resp_levels)] # to-do: find the weights for FSQ

         n_vocab = 256
         n_tasks = config.tasks if config is not None else 8
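The level weights decay harmonically, so the lower (perceptually more important) RVQ levels dominate both the weighted embedding sum and the per-level loss; the to-do notes this assumption does not obviously carry over to FSQ codecs. Worked out for the default eight levels:

    n_resp_levels = 8
    audio_level_weights = [1.0 / (i + 1) for i in range(n_resp_levels)]
    # -> [1.0, 0.5, 0.3333, 0.25, 0.2, 0.1667, 0.1429, 0.125]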
@@ -283,6 +322,7 @@ class Base_V2(nn.Module):
         self.masking_ratio = masking_ratio
         self.ignore_inputs_for_loss = ignore_inputs_for_loss
         self.noncausal_masks = noncausal_masks
+        self.audio_level_weights = audio_level_weights

         self.sep = nn.Parameter(torch.randn(d_model))
@@ -302,17 +342,23 @@ class Base_V2(nn.Module):
                 n_tokens=n_audio_tokens + 2, # stop + masked token
                 n_levels=self.n_resp_levels,
                 token_dim=d_model,
+                enc_mode=audio_encoder_mode,
+                l_weights=audio_level_weights,
             )
         else:
             self.proms_emb = AudioEncoder(
                 n_tokens=n_audio_tokens,
                 n_levels=self.n_resp_levels,
                 token_dim=d_model,
+                enc_mode=audio_encoder_mode,
+                l_weights=audio_level_weights,
             )
             self.resps_emb = AudioEncoder(
                 n_tokens=n_audio_tokens + 2, # stop + masked token
                 n_levels=self.n_resp_levels,
                 token_dim=d_model,
+                enc_mode=audio_encoder_mode,
+                l_weights=audio_level_weights,
             )

         self.audio_decoder = AudioDecoder(
@@ -747,6 +793,7 @@ class Base_V2(nn.Module):
         device = logits[0].device
         batch_size = len(logits)
         classifier_levels = self.get_input( inputs, "classifier_level" )
+        level_weights = self.audio_level_weights

         # handles tasks where the prompt has task tokens injected in the middle
         def prompt_input_to_token( input, quant_level ):
@@ -755,7 +802,7 @@ class Base_V2(nn.Module):

             return input

-        def _calc_loss( logit, sequence, causal = True ):
+        def _calc_loss( logit, sequence, causal = True, level = None ):
             # filter tokens that exceed the vocab size
             sequence = torch.where( sequence >= logit.shape[-1], self.ignore_index, sequence )
             # drop if all tokens are ignored
@@ -769,7 +816,8 @@ class Base_V2(nn.Module):
                 sequence = sequence[..., l:] # ...predicts token n + 1

             # flatten batch
-            if sequence.dim() > 1:
+            parallel = sequence.dim() > 1
+            if parallel:
                 logit = logit.reshape(-1, logit.shape[-1])
                 sequence = sequence.reshape(-1)
@@ -777,7 +825,11 @@ class Base_V2(nn.Module):
             metrics = None

             if compute_hard_loss:
-                nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index )
+                nll = F.cross_entropy( logit, sequence, ignore_index=self.ignore_index, reduction='mean' if not parallel else 'none' ) * (level_weights[level] if level is not None and not parallel else 1)
+
+                # manually weigh each level
+                if parallel:
+                    nll = nll.view( self.n_resp_levels, -1 ).mean(dim=-1) * torch.tensor(level_weights, device=device)

             if compute_acc:
                 accuracy_metric = MulticlassAccuracy(
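A minimal sketch of what the parallel branch of `_calc_loss` now computes, with hypothetical shapes (ignore_index handling omitted): with `reduction='none'` the per-token NLL is kept, reshaped into one row per codebook level (the flattened targets are level-major), averaged within each level, then scaled by that level's weight. The non-parallel branch still reduces to a single mean and scales it by its own level's weight.

    import torch
    import torch.nn.functional as F

    n_levels, seq_len, vocab = 8, 100, 1026
    level_weights = [1.0 / (i + 1) for i in range(n_levels)]

    # flattened parallel case: logits/targets for all levels concatenated level-major
    logit = torch.randn(n_levels * seq_len, vocab)
    sequence = torch.randint(0, vocab, (n_levels * seq_len,))

    nll = F.cross_entropy(logit, sequence, reduction='none')                  # [n_levels * seq_len]
    nll = nll.view(n_levels, -1).mean(dim=-1) * torch.tensor(level_weights)   # [n_levels], weighted per level

    print(nll.shape)  # torch.Size([8])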
@@ -875,9 +927,6 @@ class Base_V2(nn.Module):

                 if logits[batch_index].dim() < 3:
                     nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
-
-                    if name == "resp":
-                        name = f'{name}[{quant_level}]'
                 elif not self.resp_parallel_training:
                     # cringe way to deduce "requested" level
                     level = quant_level
@@ -885,24 +934,35 @@ class Base_V2(nn.Module):
                         if classifier_level.endswith(f':{i}:{i}'):
                             level = i
                             break
-                    """
                     if name == "resp":
                         name = f'{name}[{level}]'
-                    """
                     sequence = token if token.dim() <= 1 else token[:, level]
-                    nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
+                    nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal, level )
                 else:
                     sequence = token.t()
                     nll, metrics = _calc_loss( logits[batch_index][:, start:end], sequence.long(), causal )

+                    for level in range(self.n_resp_levels):
+                        loss_key = f'{name}[{level}].nll'
+                        if loss_key not in loss:
+                            loss[loss_key] = []
+                        loss[loss_key].append( nll[level] * loss_factor )
+
+                    nll = None
+
+                loss_key = f'{name}.nll'
+                acc_key = f'{name}.acc'
                 if nll is not None:
-                    if f'{name}.nll' not in loss:
-                        loss[f'{name}.nll'] = []
-                    loss[f"{name}.nll"].append( nll * loss_factor )
+                    if loss_key not in loss:
+                        loss[loss_key] = []
+                    loss[loss_key].append( nll * loss_factor )

                 if metrics is not None:
-                    if f'{name}.acc' not in stats:
-                        stats[f'{name}.acc'] = []
-                    stats[f"{name}.acc"].append( metrics )
+                    if acc_key not in stats:
+                        stats[acc_key] = []
+                    stats[acc_key].append( metrics )
             # add to list
             else:
                 target.append( token )
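With this change, the parallel-training branch no longer logs a single aggregate NLL: the per-level vector returned by `_calc_loss` is split into one loss key per codebook level, then zeroed out so the generic `{name}.nll` path is skipped. A hypothetical illustration of the resulting keys for an eight-level "resp" target (names and values are stand-ins):

    import torch

    loss = {}
    name, loss_factor = "resp", 1.0
    nll = torch.tensor([1.0 / (i + 1) for i in range(8)])  # stand-in for the per-level weighted NLL
    for level in range(8):
        loss.setdefault(f'{name}[{level}].nll', []).append(nll[level] * loss_factor)
    print(sorted(loss))  # ['resp[0].nll', 'resp[1].nll', ..., 'resp[7].nll']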
@@ -922,7 +982,7 @@ class Base_V2(nn.Module):

                 sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
                 sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-                nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
+                nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal, level )
             else:
                 nlls = []
                 accs = []
@@ -930,7 +990,7 @@ class Base_V2(nn.Module):
                 for level, logit in enumerate( logits[batch_index] ):
                     sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
                     sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
-                    nll, metrics = _calc_loss( logit, sequence, causal )
+                    nll, metrics = _calc_loss( logit, sequence, causal, level )

                     if nll:
                         nlls.append( nll )
@@ -180,7 +180,9 @@ def train(
             break

         #batch = to_device(batch, torch.cuda.current_device())
-        stats = engines.step(batch=batch, feeder=train_feeder)
+        with torch.autograd.set_detect_anomaly(cfg.trainer.detect_grad_anomaly):
+            stats = engines.step(batch=batch, feeder=train_feeder)
+
         stats['epoch'] = engines.global_samples / (len(train_dl.dataset.paths) * world_size())

         elapsed_time = stats.get("elapsed_time", 0)
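Wrapping the step in `torch.autograd.set_detect_anomaly(...)` ties the new `trainer.detect_grad_anomaly` flag to PyTorch's anomaly mode: when enabled, autograd records forward tracebacks and raises on the first backward node that produces NaNs, pointing at the offending forward op, at a substantial speed cost, which is why the flag defaults to False and the context manager is effectively a no-op in normal runs. A toy repro of the behavior, not taken from the project:

    import torch

    x = torch.zeros(1, requires_grad=True)
    with torch.autograd.set_detect_anomaly(True):
        y = (x / x).sum()    # 0/0 produces NaN in the forward pass
        try:
            y.backward()     # anomaly mode raises here, with a traceback naming the bad forward op
        except RuntimeError as err:
            print(err)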