cannot get segmented mask to actually work without gradients exploding (need to find a different way to do duration prediction...)

mrq 2025-03-27 00:51:41 -05:00
parent 4d777b5618
commit 2fd82a7a22
6 changed files with 89 additions and 86 deletions

View File

@ -58,26 +58,6 @@ Like the previous implementation, this model can operate entirely non-autoregres
Unlike the previous implementation, duration prediction is trained in parallel with the base `tts` task, where the output feature is always at the separator after the input prompt. This moves away from the kludge of treating the duration as an extra "language" task with a vocab size of `11` that is decoded autoregressively, while allowing some wiggle room in the duration, as it's no longer sampled from logits.
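A rough sketch of what this parallel duration prediction could look like; the `DurationHead` module, tensor names, and loss wiring are illustrative assumptions, not the actual implementation:
```python
import torch
import torch.nn as nn

class DurationHead(nn.Module):
    """Hypothetical regression head: predicts the output duration (in frames)
    from the hidden state at the separator token after the input prompt."""
    def __init__(self, d_model: int):
        super().__init__()
        self.proj = nn.Linear(d_model, 1)

    def forward(self, hidden_states: torch.Tensor, sep_indices: torch.Tensor) -> torch.Tensor:
        # hidden_states: [batch, seq_len, d_model], sep_indices: [batch]
        batch = torch.arange(hidden_states.shape[0], device=hidden_states.device)
        sep_hidden = hidden_states[batch, sep_indices]   # [batch, d_model]
        return self.proj(sep_hidden).squeeze(-1)         # [batch] predicted frame count

# the duration loss is then simply added to the tts loss, scaled down by
# `len_loss_factor` so it doesn't overpower the model early on, e.g.:
#   loss = tts_loss + len_loss_factor * F.l1_loss(pred_durations, true_durations)
```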
#### Attention
Unlike the previous implementation, care needs to be paid to the attention mask used.
Previously, a full non-causal attention mask was employed, allowing every token to attend to every other token. This is *sort of* fine, but unnecessary, as some segments do not need to attend to other segments.
This new implementation aims to restrict each segment from attending to future segments. In other words, the input text does not need to attend to the audio tokens, and the reference audio does not need to attend to the output (a sketch of this mask follows the list below).
* *Technically*, the reference audio doesn't need to attend to the input text either, but allowing it could let the model explicitly map phonemes to the reference prompt.
* Unfortunately, Flash Attention through SDPA does not support this granularity in the attention mask (the Flash Attention kernel only handles plain causal masking, not arbitrary masks).
* Currently there's a problem with how this is implemented: training with the segmented mask causes gradients to explode.
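As a minimal sketch of the segment-restricted mask described above, assuming a `[text | prom | resp]` layout with known segment lengths (the helper name and boolean convention are illustrative, though they mirror the mask construction further down in this diff):
```python
import torch

def build_segmented_mask(text_len: int, prom_len: int, resp_len: int) -> torch.Tensor:
    # each segment attends to itself and to prior segments only, so text never
    # sees prom/resp, and prom never sees resp (boolean: True = attend).
    seq_len = text_len + prom_len + resp_len
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    text_end = text_len
    prom_end = text_len + prom_len
    mask[:text_end, :text_end] = True          # text -> text
    mask[text_end:prom_end, :prom_end] = True  # prom -> text + prom
    mask[prom_end:, :] = True                  # resp -> everything before it
    return mask  # [seq_len, seq_len]; expand to [bsz, 1, seq_len, seq_len] for SDPA
```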
Additionally, sliding window attention is supported in this implementation, but has shown large regressions when further training existing weights (a sketch of the window pattern follows this list).
* The fundamental principle behind this is that audio shouldn't be *that* directly dependent on an utterance X seconds in the past/future, so a sliding window should be beneficial. However, I imagine the reason this doesn't work so well is that the model has already established a non-trivial dependency on the entire utterance.
* In a broader sense, I imagine attending to the whole utterance ensures coherency by keeping the delivery consistent, or lets the model derive a "speaker" from the existing utterance tokens.
* A fresh model *could* have no issues, as the restriction wouldn't be enough of a detriment.
* An existing model *could* be coerced with enough time, but I am not patient enough of a man to wait.
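As referenced above, a sketch of the per-segment sliding window pattern, mirroring the `_sliding_window` helper added further down in this diff (the standalone helper and boolean convention are illustrative):
```python
import torch

def sliding_window_mask(seq_len: int, window_size: int) -> torch.Tensor:
    # boolean mask where each position only attends to neighbours within
    # +/- window_size // 2 positions; applied within a single segment (text/prom/resp).
    if not window_size:
        return torch.ones(seq_len, seq_len, dtype=torch.bool)  # no window: full attention
    half = window_size // 2
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    for i in range(seq_len):
        mask[i, max(0, i - half):min(seq_len, i + half + 1)] = True
    return mask
```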
This implementation could utilize a causal attention mask, but both prior "testing" (in loose quotes, as it was due to an oversight) in the previous implementation and careless testing with this implementation show that it's also a detriment to the model.
* Like the above, I imagine a fresh model *could* resolve this issue.
### Pure AR
Unlike the previous implementation, this model can also operate entirely autoregressively as a causal transformer, where each step samples *all* codebooks at one code-frame.
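A minimal sketch of what sampling *all* codebooks of one code-frame per AR step might look like; the shapes and helper below are assumptions for illustration, not the actual sampler:
```python
import torch

def sample_frame(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # logits: [n_codebooks, vocab_size] for the frame currently being decoded;
    # one AR step samples a code for every codebook of that frame at once.
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)  # [n_codebooks]

# the decode loop appends the sampled frame and feeds it back in, exactly like a
# normal causal transformer but emitting n_codebooks codes per step.
```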
@ -116,7 +96,6 @@ audio_level_loss_factors: "normal" # distribution of loss weights per codebook (
masking_train_p: 1.0 # pure AR
masking_ratio: 0.8 # fixed mask ratio proves to be better
ignore_inputs_for_loss: True # False is not implemented
use_segmented_attention_mask: True # restricts each segment to attend only to itself and prior segments (in other words, does not let the text/prom see further into the future outside of its segment)
use_streamlined_calc_loss: True # False has no effect now
len_loss_factor: 0.0001 # start with the default for a while to not let duration training overpower the model, then gradually increase this (but this may only be required when introducing duration training on existing weights)
@ -137,6 +116,8 @@ These settings should be avoided:
* `parallel_attention_mask_dropout`: this governs the rate of flipping to a causal (triangle) mask for training (a sketch of this flip follows the list below)
* there's *some* reason to do this ablation, but it ruins the model (although the model can easily recover if erroneously trained with this)
* the model might eventually train itself to work around this, or it might need to be aware of this from the beginning, but it's not something to toy with.
* `use_segmented_attention_mask`: this *should* apply the segment-restricted attention mask described above
* but in reality the model seems to fall apart after a while
* `use_sliding_attention_mask`: this applies a sliding attention mask within each segment of the input (for example, slide within the text, slide within the prom, slide within the resp), as something said at the beginning of the utterance shouldn't affect what's said at the end
* however, this seems to be a detriment to the model, I imagine because the model relies on how something sounds earlier on, even if there shouldn't be a direct causal relationship
* this could be something that needs to be present from the very beginning of training rather than introduced later, but further training existing models with it does not seem to fare well
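As a rough illustration of the `parallel_attention_mask_dropout` flip referenced above (the function name and boolean-mask convention are assumptions, not the actual implementation):
```python
import random
import torch

def maybe_flip_to_causal(noncausal_mask: torch.Tensor, dropout_p: float) -> torch.Tensor:
    # with probability dropout_p, swap the parallel (non-causal) mask for a plain
    # causal lower-triangular mask for this training step.
    if random.random() < dropout_p:
        seq_len = noncausal_mask.shape[-1]
        return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    return noncausal_mask
```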

View File

@ -799,6 +799,7 @@ class Trainer:
gc_mode: str | None = None # deprecated, but marks when to do GC
wandb: bool = False # use wandb, if available
wandb_params: dict = field(default_factory=dict) # extra kwargs merged into wandb.init
weight_dtype: str = "float16" # dtype to have the model under

View File

@ -307,21 +307,26 @@ def load_engines(training=True, **model_kwargs):
# setup wandb
if engine._training and cfg.trainer.wandb and wandb is not None:
key_name = name
kwargs = {}
if cfg.lora is not None:
if cfg.lora is not None:
key_name = cfg.lora.full_name
salt = "-run-2"
kwargs['id'] = f'{key_name}{salt}'
kwargs['resume'] = 'allow'
salt = cfg.trainer.wandb_params.pop("salt", "-run")
kwargs = {
'id': f'{key_name}{salt}',
'resume': 'allow',
"config": dict(
config = engine.hyper_config.__dict__,
hyperparameters = cfg.hyperparameters.__dict__,
),
}
if world_size() > 1:
kwargs["group"] = f"DDP{salt}"
kwargs['id'] = f'{key_name}{salt}-{global_rank()}'
kwargs |= {
"id": f'{key_name}{salt}-{global_rank()}',
"group": f"DDP{salt}",
}
kwargs['config'] = dict(
config = engine.hyper_config.__dict__,
hyperparameters = cfg.hyperparameters.__dict__,
)
kwargs.update( cfg.trainer.wandb_params )
try:
engine.wandb = wandb.init(project=key_name, **kwargs)

View File

@ -76,9 +76,6 @@ class AR_NAR_V2(Base_V2):
# RVQ levels to apply masking training on
masking_train_rvq_levels = [0,self.n_resp_levels] # self.config.experimental.masking_train_rvq_levels
# cringe
self.audio_frames_per_second = cfg.dataset.frames_per_second
# CFG
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
@ -604,12 +601,15 @@ class AR_NAR_V2(Base_V2):
batch_size = len(resps_list)
# implicitly set for training
if training is None and phns_list is not None and resps_list is not None:
if training is None and (phns_list is not None or text_list is not None) and resps_list is not None:
n_levels_set = {r.shape[-1] for r in resps_list}
n_levels = next(iter(n_levels_set))
training = n_levels == self.n_resp_levels
# cringe
self.audio_frames_per_second = cfg.dataset.frames_per_second
# is training
if training:
return self.forward_train(

View File

@ -155,8 +155,6 @@ class Attention(nn.Module):
self.attn_mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
elif self.attn_mode == "cudnn":
self.attn_mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
elif self.attn_mode == "sdpa":
self.attn_mode = torch.nn.attention.SDPBackend.MATH
self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias )
self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias )
@ -365,6 +363,8 @@ class Attention(nn.Module):
if attention_mask is not None:
x_mask = x_mask[:, :, :, : key_states.shape[-2]]
# is_causal = True if x_mask is None and q_len > 1 else False
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and x_mask is not None:
@ -407,7 +407,14 @@ class Attention(nn.Module):
elif mode in ["default"]:
attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
# cringe logic
attn_weights = (attn_scores + x_mask) if attention_mask is not None else (attn_scores)
if x_mask is not None:
if x_mask.dtype == torch.bool:
attn_weights = attn_scores.masked_fill_(x_mask.logical_not(), float("-inf"))
else:
attn_weights = attn_scores + x_mask
else:
attn_weights = attn_scores
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
@ -433,24 +440,21 @@ class Attention(nn.Module):
elif mode == "sdpa":
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# is_causal = True if x_mask is None and q_len > 1 else False
is_causal = True if x_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=x_mask,
attn_mask=None if is_causal else x_mask,
dropout_p=dropout_rate,
is_causal=is_causal,
)
else:
is_causal = True if x_mask is None and q_len > 1 else False
with torch.nn.attention.sdpa_kernel(self.attn_mode):
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=x_mask,
attn_mask=None if is_causal else x_mask,
dropout_p=dropout_rate,
is_causal=is_causal,
)
@ -588,33 +592,31 @@ class Model(LlamaPreTrainedModel):
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
bsz, seq_len, _ = inputs_embeds.size()
dtype = torch.bool
device = inputs_embeds.device
min_dtype = False # torch.iinfo(dtype).min # torch.finfo(dtype).min
# generate default mask based on input
if attention_mask is None:
attention_mask = torch.ones( (bsz, seq_len), dtype=torch.bool, device=inputs_embeds.device )
attention_mask = torch.ones( (bsz, seq_len), dtype=dtype, device=device )
# make square
expanded_mask = attention_mask[:, None, None, :].expand( bsz, 1, seq_len, seq_len ).to( dtype=inputs_embeds.dtype )
mask = attention_mask[:, None, None, :].expand( bsz, 1, seq_len, seq_len ).to(dtype)
# mask = AttentionMaskConverter._unmask_unattended(mask, min_dtype)
return mask
# invert from 1.0 = attend, 0.0 = masked to 0.0 = valid, -inf = masked
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill( inverted_mask.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min )
# generate a sliding window pattern
def _sliding_window( self, seq_len, window_size ):
if not window_size:
return True
half_window = int(window_size // 2)
mask = torch.zeros( seq_len, seq_len, dtype=torch.bool )
def _apply_sliding_window(self, mask, start_idx, end_idx, window_size):
window_size = int(window_size // 2) # ick
seq_len = mask.size(-1)
for i in range(start_idx, min(end_idx, seq_len)):
if not window_size:
break
window_start = max(start_idx, i - window_size)
window_end = min(end_idx, i + window_size + 1)
if window_start > start_idx:
mask[..., i, start_idx:window_start] = 0
if window_end < end_idx:
mask[..., i, window_end:end_idx] = 0
for i in range( seq_len ):
window_start = max( 0, i - half_window )
window_end = min( seq_len, i + half_window + 1 )
mask[i, window_start:window_end] = True
return mask
@ -629,15 +631,15 @@ class Model(LlamaPreTrainedModel):
):
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
bsz, seq_len, _ = inputs_embeds.size()
dtype = torch.bool
device = inputs_embeds.device
min_dtype = False # torch.iinfo(dtype).min # torch.finfo(dtype).min
if attention_mask is None:
attention_mask = torch.ones((bsz, seq_len), dtype=torch.bool, device=inputs_embeds.device)
attention_mask = torch.ones((bsz, seq_len), dtype=dtype, device=device)
expanded_mask = torch.zeros(
(bsz, 1, seq_len, seq_len),
dtype=inputs_embeds.dtype,
device=inputs_embeds.device
)
expanded_mask = torch.zeros( (bsz, 1, seq_len, seq_len), dtype=dtype, device=device )
for batch_index, aux_len in enumerate( aux_lens ):
window_size = window_sizes[batch_index] if window_sizes is not None else None
@ -653,25 +655,30 @@ class Model(LlamaPreTrainedModel):
prom_start, prom_end = text_end, text_end + prom_len
output_start, output_end = prom_end, prom_end + output_len
"""
output_start, output_end = prom_end, seq_len # prom_end + output_len
if text_len:
expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = 1.0
expanded_mask[batch_index, 0] = self._apply_sliding_window( expanded_mask[batch_index, 0], text_start, text_end, text_window )
expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = True
if prom_len:
expanded_mask[batch_index, 0, prom_start:prom_end, text_start:prom_end] = 1.0
expanded_mask[batch_index, 0] = self._apply_sliding_window( expanded_mask[batch_index, 0], prom_start, prom_end, prom_window )
expanded_mask[batch_index, 0, prom_start:prom_end, text_start:prom_end] = True
if output_len:
expanded_mask[batch_index, 0, output_start:output_end, text_start:output_end] = 1.0
expanded_mask[batch_index, 0] = self._apply_sliding_window( expanded_mask[batch_index, 0], output_start, output_end, output_window )
expanded_mask[batch_index, 0, output_start:output_end, text_start:output_end] = True
"""
if text_len:
expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = True if not text_window else self._sliding_window( text_len, text_window )
if prom_len:
expanded_mask[batch_index, 0, prom_start:prom_end, text_start:text_end] = True
expanded_mask[batch_index, 0, prom_start:prom_end, prom_start:prom_end] = True if not prom_window else self._sliding_window( prom_len, prom_window )
if output_len:
expanded_mask[batch_index, 0, output_start:output_end, text_start:text_end] = True
expanded_mask[batch_index, 0, output_start:output_end, prom_start:prom_end] = True
expanded_mask[batch_index, 0, output_start:output_end, output_start:output_end] = True if not output_window else self._sliding_window( output_len, output_window )
# apply the original attention mask
expanded_mask = expanded_mask * attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
# invert from 1.0 = attend, 0.0 = masked to 0.0 = valid, -inf = masked
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(dtype=torch.bool),
torch.finfo(inputs_embeds.dtype).min
)
mask = expanded_mask * attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len).to(dtype)
#mask = AttentionMaskConverter._unmask_unattended(mask, min_dtype)
return mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
@ -684,6 +691,14 @@ class Model(LlamaPreTrainedModel):
batch_size: int,
**kwargs,
):
"""
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
bsz, seq_len, _ = inputs_embeds.size()
if attention_mask is None:
attention_mask = torch.ones((bsz, seq_len), dtype=torch.bool, device=device)
"""
if attention_mask is not None and attention_mask.dim() == 4:
causal_mask = attention_mask
else:
@ -705,9 +720,10 @@ class Model(LlamaPreTrainedModel):
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
# gut out the things that just shove responsibility onto SDPA's is_causal mask generation, because this causes problems
def _update_causal_mask(
self,

View File

@ -62,7 +62,7 @@ def md5_hash( x ):
return hashlib.md5(str(x).encode("utf-8")).hexdigest()
# removes entries from a dict if that key is missing from the source
def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True, ignore=["optimizer_params"] ):
def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True, ignore=["optimizer_params", "wandb_params"] ):
is_obj = hasattr( source, "__dict__" )
if parent_is_obj is None:
parent_is_obj = is_obj