I'll just cope and say I cannot apply segmented attention masks to the smaller model as it's too trained on not doing it, and the regression came from dumb python aliasing rules

This commit is contained in:
mrq 2025-03-27 13:27:51 -05:00
parent 2fd82a7a22
commit 90b3509404
3 changed files with 66 additions and 53 deletions

View File

@ -58,6 +58,26 @@ Like the previous implementation, this model can operate entirely non-autoregres
Unlike the previous implementation, duration prediction is trained in parallel with the base `tts` task, where the output feature is always at the separator after the input prompt. This moves away from the kludge of treating the duration as an extra "language" task with a vocab size of `11`, and decoded autoregressively, while allowing some wiggle room in the duration as it's no longer sampled using logits.
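To make the parallel duration objective concrete, here is a minimal sketch, assuming a hypothetical `duration_head` (e.g. a linear layer to one scalar) and a plain L1 regression on the hidden state taken at the separator position; the actual naming and loss shaping live in the model code:

```python
import torch
import torch.nn.functional as F

def duration_loss(hidden, sep_index, target_frames, duration_head, len_loss_factor=0.0001):
    # hidden: [batch, seq_len, dim] hidden states from the same forward pass as the `tts` task
    # sep_index: [batch] position of the separator right after the input prompt
    # target_frames: [batch] ground-truth output length in frames
    batch = torch.arange(hidden.shape[0], device=hidden.device)
    sep_hidden = hidden[batch, sep_index]          # [batch, dim] feature at the separator
    pred = duration_head(sep_hidden).squeeze(-1)   # regress a scalar duration per sample
    # nothing is sampled from logits, so the prediction has wiggle room by construction
    return len_loss_factor * F.l1_loss(pred, target_frames.float())
```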
#### Attention
Unlike the previous implementation, attention needs to be paid towards the attention mask used.
Previously, a full non-causal attention mask was employed, allowing every token to attend to every other token. This is *sort of* fine, but unnecessary, as some segments do not need to attend to other segments.
This new implementation aims to restrict each segment from attending to future segments (see the sketch after the list below). In other words, the input text does not need to attend to the audio tokens, while the reference audio does not need to attend to the output.
* *Technically*, the reference audio doesn't need to attend to the input text, but it could allow for the model to explicitly map phonemes to the reference prompt.
* Unfortunately, this does not seem to work for the `nemo-smaller` model
* I'm not too sure why this is the case, but I suppose it's just how that model's weights progressed, as the `nemo-larger` model seems fine.
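As a rough illustration of the segment rule described above (text attends to text, the reference audio attends to the text and itself, the output attends to everything before it), here is a minimal single-sample sketch; the actual `_update_segmented_mask` shown further below additionally handles batching, padding, and optional windows:

```python
import torch

def segmented_mask(text_len, prom_len, resp_len, device="cpu"):
    # boolean [seq_len, seq_len] mask, True = may attend
    seq_len = text_len + prom_len + resp_len
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool, device=device)
    t_end = text_len
    p_end = text_len + prom_len
    mask[:t_end, :t_end] = True       # text only attends to itself
    mask[t_end:p_end, :p_end] = True  # prom attends to the text and itself
    mask[p_end:, :] = True            # output attends to all prior segments and itself
    return mask
```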
Additionally, sliding window attention is supported in this implementation, but it has shown significant regressions when performing additional training on existing weights (a sketch of the windowing pattern follows this list).
* The fundamental principle is that audio shouldn't be *that* directly dependent on an utterance X seconds in the past/future, so a sliding window should be beneficial. However, I imagine the reason this doesn't work so well is that the model has already established a non-trivial dependency on the entire utterance.
* In a broader sense, I imagine attending to the whole utterance helps maintain coherency by keeping the delivery consistent, or lets the model derive a "speaker" from the existing utterance tokens.
* A fresh model *could* have no issues, as it wouldn't be enough of a detriment.
* An existing model *could* be coerced with enough time, but I am not patient enough of a man to wait.
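For reference, the windowing pattern discussed above is essentially a band around the diagonal within a single segment; a minimal sketch follows, where the window width and whether the band is one- or two-sided are the tunable parts:

```python
import torch

def sliding_window(length, window, device="cpu"):
    # True where the query and key positions are within `window` of each other;
    # meant to be applied inside a single segment (e.g. only within the resp tokens)
    idx = torch.arange(length, device=device)
    return (idx[:, None] - idx[None, :]).abs() < window

# e.g. a window of roughly half a second of audio frames on either side
mask = sliding_window(length=6, window=2)
```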
This implementation could utilize a causal attention mask, but both prior "testing" (in loose quotes, as it was due to an oversight) in the previous implementation and careless testing with this implementation show that it's also a detriment to the model.
* Like the above, I imagine a fresh model *could* resolve this issue.
### Pure AR
Unlike the previous implementation, this model can also operate entirely autoregressively as a causal transformer, where each step samples *all* codebooks at one code-frame.
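A small sketch of what sampling *all* codebooks at one code-frame per autoregressive step looks like, with hypothetical shapes (the real sampler also handles temperature, top-k, and so on):

```python
import torch

def sample_frame(logits):
    # logits: [codebooks, vocab] for the single next code-frame
    # one AR step emits one token per codebook, not one token total
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)  # [codebooks]
```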
@ -96,6 +116,7 @@ audio_level_loss_factors: "normal" # distribution of loss weights per codebook (
masking_train_p: 1.0 # pure AR
masking_ratio: 0.8 # fixed mask ratio proves to be better
ignore_inputs_for_loss: True # False is not implemented
use_segmented_attention_mask: True # restricts each segment to its own segment + prior segments (in other words, does not let the text/prom attend further into the future, outside of its segment)
use_streamlined_calc_loss: True # False has no effect now
len_loss_factor: 0.0001 # start with the default for a while to not let duration training overpower the model, then gradually increase this (but this may only be required when introducing duration training on existing weights)
@ -116,8 +137,6 @@ These settings should be avoided:
* `parallel_attention_mask_dropout`: this governs the rate of flipping to a causal (triangle) mask for training
* there's *some* reason to do this ablation, but it ruins the model (although the model can easily recover if erroneously trained with this)
* the model might eventually train itself to work around this, or it might need to be aware of this from the beginning, but it's not something to toy with.
* `use_segmented_attention_mask`: this *should* apply a special attention mask
* but in reality the model seems to fall apart after a while
* `use_sliding_attention_mask`: this applies a sliding attention mask within each segment of the input (for example, slide within the text, slide within the prom, slide within the resp), as something said at the beginning of the utterance shouldn't affect what's said at the end
* however, it seems this is a detriment to the model; I imagine this is because the model relies on how something sounds earlier on, even if there shouldn't be a direct causal relationship
* this might need to be trained in from the very beginning rather than introduced later, as training existing models with it does not seem to fare well

View File

@ -363,7 +363,15 @@ class Attention(nn.Module):
if attention_mask is not None:
x_mask = x_mask[:, :, :, : key_states.shape[-2]]
# is_causal = True if x_mask is None and q_len > 1 else False
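# collapse a per-sample list of causal flags into a single bool when every sample agrees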
if isinstance( is_causal, list ):
count = sum( is_causal )
if count == 0:
is_causal = False
elif count == len( is_causal ):
is_causal = True
if self.attn_mode in [torch.nn.attention.SDPBackend.FLASH_ATTENTION] or is_causal:
x_mask = None
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
@ -406,14 +414,13 @@ class Attention(nn.Module):
)
elif mode in ["default"]:
attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
# cringe logic
if x_mask is not None:
if x_mask.dtype == torch.bool:
attn_weights = attn_scores.masked_fill_(x_mask.logical_not(), float("-inf"))
else:
attn_weights = attn_scores + x_mask
else:
if x_mask is None:
attn_weights = attn_scores
elif x_mask.dtype == torch.bool:
attn_weights = attn_scores.masked_fill(x_mask.logical_not(), float("-inf"))
else:
attn_weights = attn_scores + x_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@ -425,26 +432,12 @@ class Attention(nn.Module):
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
elif mode in [torch.nn.attention.SDPBackend.FLASH_ATTENTION]:
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
if isinstance( is_causal, list ):
is_causal = is_causal[0]
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=None, # ROCm FA2 through SDPA doesn't allow masks, bummer
dropout_p=dropout_rate,
is_causal=is_causal,
)
elif mode == "sdpa":
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=None if is_causal else x_mask,
attn_mask=x_mask,
dropout_p=dropout_rate,
is_causal=is_causal,
)
@ -454,7 +447,7 @@ class Attention(nn.Module):
query_states,
key_states,
value_states,
attn_mask=None if is_causal else x_mask,
attn_mask=x_mask,
dropout_p=dropout_rate,
is_causal=is_causal,
)
@ -586,7 +579,7 @@ class Model(LlamaPreTrainedModel):
self,
attention_mask,
inputs_embeds,
past_key_values_length,
past_key_values_length = 0,
):
# create noncausal mask
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
@ -602,7 +595,6 @@ class Model(LlamaPreTrainedModel):
# make square
mask = attention_mask[:, None, None, :].expand( bsz, 1, seq_len, seq_len ).to(dtype)
# mask = AttentionMaskConverter._unmask_unattended(mask, min_dtype)
return mask
# generate a sliding window pattern
@ -627,7 +619,7 @@ class Model(LlamaPreTrainedModel):
inputs_embeds,
aux_lens, # (bsz, lens), where [batch_index, 0] = text_len, and [batch_index, 1] = prom_len
window_sizes = None, # (bsz, lens), same as above
past_key_values_length=0,
past_key_values_length = 0,
):
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
bsz, seq_len, _ = inputs_embeds.size()
@ -639,7 +631,7 @@ class Model(LlamaPreTrainedModel):
if attention_mask is None:
attention_mask = torch.ones((bsz, seq_len), dtype=dtype, device=device)
expanded_mask = torch.zeros( (bsz, 1, seq_len, seq_len), dtype=dtype, device=device )
mask = torch.full( (bsz, 1, seq_len, seq_len), min_dtype, dtype=dtype, device=device )
for batch_index, aux_len in enumerate( aux_lens ):
window_size = window_sizes[batch_index] if window_sizes is not None else None
@ -655,29 +647,27 @@ class Model(LlamaPreTrainedModel):
prom_start, prom_end = text_end, text_end + prom_len
output_start, output_end = prom_end, prom_end + output_len
"""
output_start, output_end = prom_end, seq_len # prom_end + output_len
if text_len:
expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = True
mask[batch_index, 0, text_start:text_end, text_start:text_end] = True
if prom_len:
expanded_mask[batch_index, 0, prom_start:prom_end, text_start:prom_end] = True
mask[batch_index, 0, prom_start:prom_end, text_start:prom_end] = True
if output_len:
expanded_mask[batch_index, 0, output_start:output_end, text_start:output_end] = True
"""
mask[batch_index, 0, output_start:output_end, text_start:output_end] = True
"""
if text_len:
expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = True if not text_window else self._sliding_window( text_len, text_window )
mask[batch_index, 0, text_start:text_end, text_start:text_end] = True if not text_window else self._sliding_window( text_len, text_window )
if prom_len:
expanded_mask[batch_index, 0, prom_start:prom_end, text_start:text_end] = True
expanded_mask[batch_index, 0, prom_start:prom_end, prom_start:prom_end] = True if not prom_window else self._sliding_window( prom_len, prom_window )
mask[batch_index, 0, prom_start:prom_end, text_start:text_end] = True
mask[batch_index, 0, prom_start:prom_end, prom_start:prom_end] = True if not prom_window else self._sliding_window( prom_len, prom_window )
if output_len:
expanded_mask[batch_index, 0, output_start:output_end, text_start:text_end] = True
expanded_mask[batch_index, 0, output_start:output_end, prom_start:prom_end] = True
expanded_mask[batch_index, 0, output_start:output_end, output_start:output_end] = True if not output_window else self._sliding_window( output_len, output_window )
mask[batch_index, 0, output_start:output_end, text_start:text_end] = True
mask[batch_index, 0, output_start:output_end, prom_start:prom_end] = True
mask[batch_index, 0, output_start:output_end, output_start:output_end] = True if not output_window else self._sliding_window( output_len, output_window )
"""
# apply the original attention mask
mask = expanded_mask * attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len).to(dtype)
#mask = AttentionMaskConverter._unmask_unattended(mask, min_dtype)
mask = mask * attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len).to(dtype)
return mask
@staticmethod

View File

@ -1117,8 +1117,6 @@ class Base_V2(nn.Module):
padding = torch.zeros(shape[:2], dtype=x.dtype, device=x.device)
mask = torch.cat([mask, padding], dim=1)
m = mask.unsqueeze(dim=-1)
# needs to be done here as we still have our raw inputs
position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
classifier_levels = self.get_input( inputs, name="classifier_level" )
@ -1136,21 +1134,27 @@ class Base_V2(nn.Module):
text_window = 32 if self.use_sliding_attention_mask else 0
audio_window = self.audio_frames_per_second // 2 if self.use_sliding_attention_mask else 0
aux_lens = [[2, 0, 0]] * batch_size
aux_windows = [[text_window, audio_window, audio_window]] * batch_size
aux_lens = []
aux_windows = []
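# built per batch item; [[2, 0, 0]] * batch_size would alias the same inner list across the whole batch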
# fill aux lens
for batch_index, batch_input in enumerate( inputs ):
lens = [2, 0, 0]
windows = [text_window, audio_window, audio_window]
for name, input in batch_input:
if name in ["phn", "text"]:
aux_lens[batch_index][0] = input.shape[0] + 1
lens[0] = input.shape[0] + 1
elif name == "lang":
aux_lens[batch_index][0] += 2
lens[0] += 2
elif name == "prom":
aux_lens[batch_index][1] = input.shape[0] + 1
lens[1] = input.shape[0] + 1
elif name == "tone":
aux_lens[batch_index][1] += 2
lens[1] += 2
elif name == "resp":
aux_lens[batch_index][2] = input.shape[0]
lens[2] = input.shape[0]
aux_lens.append( lens )
aux_windows.append( windows )
if self.use_segmented_attention_mask and not any(is_causal):
mask = self.model._update_segmented_mask( mask, x, aux_lens, window_sizes=aux_windows )