cannot get segmented mask to actually work without gradients exploding (need to find a different way to do duration prediction...)
This commit is contained in:
parent 4d777b5618
commit 2fd82a7a22

@@ -58,26 +58,6 @@ Like the previous implementation, this model can operate entirely non-autoregres

Unlike the previous implementation, duration prediction is trained in parallel with the base `tts` task, where the output feature is always at the separator after the input prompt. This moves away from the kludge of treating the duration as an extra "language" task with a vocab size of `11` that is decoded autoregressively, while allowing some wiggle room in the duration, as it's no longer sampled using logits.
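
As a rough illustration of this parallel duration objective, here is a minimal sketch of a duration head and its loss (PyTorch; `DurationHead`, the separator-index lookup, and the L1 loss are assumptions for illustration, not the actual implementation; `len_loss_factor` corresponds to the config knob shown further down):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DurationHead(nn.Module):
    """Predicts utterance duration (in frames) from the hidden state at the separator token."""
    def __init__(self, d_model: int):
        super().__init__()
        self.proj = nn.Linear(d_model, 1)

    def forward(self, hidden: torch.Tensor, sep_idx: torch.Tensor) -> torch.Tensor:
        # hidden: [batch, seq_len, d_model], sep_idx: [batch] index of the separator after the prompt
        h_sep = hidden[torch.arange(hidden.shape[0]), sep_idx]  # [batch, d_model]
        return self.proj(h_sep).squeeze(-1)                     # [batch] predicted frame count

def duration_loss(pred_frames, target_frames, len_loss_factor=0.0001):
    # trained in parallel with the main tts loss, scaled down so it doesn't overpower the model
    return len_loss_factor * F.l1_loss(pred_frames, target_frames.float())
```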

#### Attention

Unlike the previous implementation, attention needs to be paid to the attention mask used.

Previously, a full non-causal attention mask was employed, allowing every token to attend to every other token. This is *sort of* fine, but unnecessary, as some segments do not need to attend to other segments.

This new implementation aims to restrict each segment from attending to future segments (a rough sketch follows the list below). In other words, the input text does not need to attend to the audio tokens, while the reference audio does not need to attend to the output.
* *Technically*, the reference audio doesn't need to attend to the input text, but doing so could allow the model to explicitly map phonemes to the reference prompt.
* Unfortunately, Flash Attention through SDPA does not offer this level of granularity in the attention mask.
* Currently there's a problem with how this is implemented...
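
A minimal sketch of the intended segmented mask, assuming per-item segment lengths for the text, reference prompt, and output (boolean convention, `True` = may attend); it mirrors the mask construction later in this diff but is not the verbatim implementation:

```python
import torch

def segmented_attention_mask(text_len: int, prom_len: int, out_len: int) -> torch.Tensor:
    """Each segment may attend to itself and to all *prior* segments, never to future ones."""
    seq_len = text_len + prom_len + out_len
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)

    text_end = text_len
    prom_end = text_end + prom_len

    # text attends only to text
    mask[:text_end, :text_end] = True
    # reference prompt attends to the text and to itself
    mask[text_end:prom_end, :prom_end] = True
    # output attends to everything before it (text + prompt) and to itself
    mask[prom_end:, :] = True
    return mask  # [seq_len, seq_len]; expand to [bsz, 1, seq_len, seq_len] before SDPA
```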

Additionally, sliding window attention is supported in this implementation, but has shown big regressions when performing additional training on existing weights (see the sketch after this list).
* The fundamental principle behind this is that audio shouldn't be *that* directly dependent on an utterance X seconds in the past/future, so a sliding window should be beneficial. However, I imagine the reason this doesn't work so well is that the model has already established a non-trivial dependency on the entire utterance.
* In a broader sense, I imagine this ensures coherency by keeping an utterance delivered in a similar way, or by letting the model derive a "speaker" from the existing utterance tokens.
* A fresh model *could* have no issues, as it wouldn't be enough of a detriment.
* An existing model *could* be coerced with enough time, but I am not patient enough of a man to wait.
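
A minimal sketch of a per-segment sliding-window (banded) mask, in the spirit of the `_sliding_window` helper added later in this diff (illustrative only):

```python
import torch

def sliding_window_mask(seq_len: int, window_size: int) -> torch.Tensor:
    """Boolean mask where position i may only attend within +/- window_size//2 positions."""
    if not window_size:
        return torch.ones(seq_len, seq_len, dtype=torch.bool)  # no window -> full attention
    half = window_size // 2
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    for i in range(seq_len):
        mask[i, max(0, i - half):min(seq_len, i + half + 1)] = True
    return mask

# e.g. sliding_window_mask(5, 3) is a banded matrix: row i attends to {i-1, i, i+1}, clipped at the edges
```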

This implementation could utilize a causal attention mask, but both prior "testing" (in loose quotes, as it was due to an oversight) in the previous implementation and careless testing with this implementation show that it's also a detriment to the model.
* Like the above, I imagine a fresh model *could* resolve this issue.

### Pure AR

Unlike the previous implementation, this model can also operate entirely autoregressively as a causal transformer, where each step samples *all* codebooks at one code-frame.
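
Conceptually, one AR step emits an entire code-frame at once; a hedged sketch of that sampling loop (the model call signature, shapes, and stop handling are assumptions, not the actual API):

```python
import torch

@torch.no_grad()
def ar_generate(model, prefix: torch.Tensor, n_codebooks: int, max_frames: int) -> torch.Tensor:
    """prefix: [1, t, n_codebooks] of codes; each step predicts all codebooks of the next frame."""
    frames = prefix
    for _ in range(max_frames):
        logits = model(frames)                 # assumed shape: [1, n_codebooks, vocab] for the next frame
        assert logits.shape[1] == n_codebooks
        probs = logits.softmax(dim=-1).squeeze(0)               # [n_codebooks, vocab]
        next_frame = torch.multinomial(probs, 1).T              # [1, n_codebooks]
        frames = torch.cat([frames, next_frame.unsqueeze(1)], dim=1)
    return frames
```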

@@ -116,7 +96,6 @@ audio_level_loss_factors: "normal" # distribution of loss weights per codebook (
masking_train_p: 1.0 # pure AR
masking_ratio: 0.8 # fixed mask ratio proves to be better
ignore_inputs_for_loss: True # False is not implemented
use_segmented_attention_mask: True # restricts each section within its own section + prior section (in other words, does not let the text/prom see further into the future outside of its segment)
use_streamlined_calc_loss: True # False has no effect now
len_loss_factor: 0.0001 # start with the default for a while to not let duration training overpower the model, then gradually increase this (but this may only be required when introducing duration training on existing weights)

@@ -137,6 +116,8 @@ These settings should be avoided:
* `parallel_attention_mask_dropout`: this governs the rate of flipping to a causal (triangle) mask during training
  * there's *some* reason to do this ablation, but it ruins the model (although the model can easily recover if erroneously trained with this)
  * the model might eventually train itself to work around this, or it might need to be aware of this from the beginning, but it's not something to toy with.
* `use_segmented_attention_mask`: this *should* apply a special attention mask
  * but in reality the model seems to fall apart after a while
* `use_sliding_attention_mask`: this applies a sliding attention mask within each segment of the input (for example, slide within the text, slide within the prom, slide within the resp), as something said at the beginning of the utterance shouldn't affect what's said at the end
  * however, this seems to be a detriment to the model; I imagine the model could rely on how something sounds earlier on, even if there shouldn't be a direct causal relationship
  * this could be something that needs to be trained in from the very beginning rather than introduced later, as training existing models with it does not seem to fare well

@@ -799,6 +799,7 @@ class Trainer:
    gc_mode: str | None = None # deprecated, but marks when to do GC

    wandb: bool = False # use wandb, if available
    wandb_params: dict = field(default_factory=lambda: dict)

    weight_dtype: str = "float16" # dtype to have the model under

@@ -307,21 +307,26 @@ def load_engines(training=True, **model_kwargs):
# setup wandb
if engine._training and cfg.trainer.wandb and wandb is not None:
    key_name = name
    kwargs = {}
    if cfg.lora is not None:
    if cfg.lora is not None:
        key_name = cfg.lora.full_name

    salt = "-run-2"
    kwargs['id'] = f'{key_name}{salt}'
    kwargs['resume'] = 'allow'
    salt = cfg.trainer.wandb_params.pop("salt", "-run")
    kwargs = {
        'id': f'{key_name}{salt}',
        'resume': 'allow',
        "config": dict(
            config = engine.hyper_config.__dict__,
            hyperparameters = cfg.hyperparameters.__dict__,
        ),
    }

    if world_size() > 1:
        kwargs["group"] = f"DDP{salt}"
        kwargs['id'] = f'{key_name}{salt}-{global_rank()}'
        kwargs |= {
            "id": f'{key_name}{salt}-{global_rank()}',
            "group": f"DDP{salt}",
        }

    kwargs['config'] = dict(
        config = engine.hyper_config.__dict__,
        hyperparameters = cfg.hyperparameters.__dict__,
    )
    kwargs.update( cfg.trainer.wandb_params )

    try:
        engine.wandb = wandb.init(project=key_name, **kwargs)
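
For reference, a condensed, hedged sketch of what the new wandb setup path above amounts to (assuming `wandb` is installed and `cfg`/`engine` fields as shown; not the verbatim code):

```python
import wandb

def init_wandb(key_name: str, cfg, engine, rank: int = 0, ddp: bool = False):
    # a user-supplied "salt" distinguishes runs that reuse the same key_name
    params = dict(cfg.trainer.wandb_params)
    salt = params.pop("salt", "-run")

    kwargs = {
        "id": f"{key_name}{salt}",
        "resume": "allow",
        "config": dict(
            config=engine.hyper_config.__dict__,
            hyperparameters=cfg.hyperparameters.__dict__,
        ),
    }
    if ddp:
        # each rank gets its own run id, grouped under one DDP group
        kwargs |= {"id": f"{key_name}{salt}-{rank}", "group": f"DDP{salt}"}

    kwargs.update(params)  # any remaining wandb_params override the defaults
    return wandb.init(project=key_name, **kwargs)
```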

@@ -76,9 +76,6 @@ class AR_NAR_V2(Base_V2):
# RVQ levels to apply masking training on
masking_train_rvq_levels = [0,self.n_resp_levels] # self.config.experimental.masking_train_rvq_levels

# cringe
self.audio_frames_per_second = cfg.dataset.frames_per_second

# CFG
cfg_text_dropout_p = self.config.experimental.cfg_text_dropout_p if self.config is not None else 0.0
cfg_cond_dropout_p = self.config.experimental.cfg_cond_dropout_p if self.config is not None else 0.0
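
For context, the `cfg_*_dropout_p` values are classifier-free-guidance dropout rates; a minimal sketch of how such dropout is commonly applied during training (illustrative only, with assumed `null_text`/`null_prom` placeholders; not the actual implementation):

```python
import random

def apply_cfg_dropout(text_tokens, prom_tokens, null_text, null_prom,
                      cfg_text_dropout_p=0.0, cfg_cond_dropout_p=0.0):
    """Randomly drop the text and/or acoustic prompt so the model also learns the
    unconditional distribution, enabling classifier-free guidance at inference."""
    if random.random() < cfg_text_dropout_p:
        text_tokens = null_text   # drop the text condition
    if random.random() < cfg_cond_dropout_p:
        prom_tokens = null_prom   # drop the reference-audio condition
    return text_tokens, prom_tokens
```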

@@ -604,12 +601,15 @@ class AR_NAR_V2(Base_V2):
batch_size = len(resps_list)

# implicitly set for training
if training is None and phns_list is not None and resps_list is not None:
if training is None and (phns_list is not None or text_list is not None) and resps_list is not None:
    n_levels_set = {r.shape[-1] for r in resps_list}
    n_levels = next(iter(n_levels_set))

    training = n_levels == self.n_resp_levels

# cringe
self.audio_frames_per_second = cfg.dataset.frames_per_second

# is training
if training:
    return self.forward_train(

@@ -155,8 +155,6 @@ class Attention(nn.Module):
    self.attn_mode = torch.nn.attention.SDPBackend.FLASH_ATTENTION
elif self.attn_mode == "cudnn":
    self.attn_mode = torch.nn.attention.SDPBackend.CUDNN_ATTENTION
elif self.attn_mode == "sdpa":
    self.attn_mode = torch.nn.attention.SDPBackend.MATH

self.q_proj = nn.Linear( config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias )
self.k_proj = nn.Linear( config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias )

@@ -365,6 +363,8 @@ class Attention(nn.Module):
if attention_mask is not None:
    x_mask = x_mask[:, :, :, : key_states.shape[-2]]

# is_causal = True if x_mask is None and q_len > 1 else False

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and x_mask is not None:

@@ -407,7 +407,14 @@ class Attention(nn.Module):
elif mode in ["default"]:
    attn_scores = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    # cringe logic
    attn_weights = (attn_scores + x_mask) if attention_mask is not None else (attn_scores)
    if x_mask is not None:
        if x_mask.dtype == torch.bool:
            attn_weights = attn_scores.masked_fill_(x_mask.logical_not(), float("-inf"))
        else:
            attn_weights = attn_scores + x_mask
    else:
        attn_weights = attn_scores

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
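
As the block above shows, the mask may now arrive either as a boolean "may attend" mask or as an additive float mask; a small hedged utility sketch for converting between the two conventions (not part of the diff):

```python
import torch

def bool_to_additive(mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # True = attend -> 0.0; False = masked -> a large negative value added to the scores pre-softmax
    return torch.zeros_like(mask, dtype=dtype).masked_fill(~mask, torch.finfo(dtype).min)

def additive_to_bool(mask: torch.Tensor) -> torch.Tensor:
    # 0.0 = attend -> True; large negative = masked -> False
    return mask == 0
```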

@@ -433,24 +440,21 @@ class Attention(nn.Module):
elif mode == "sdpa":
    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    # is_causal = True if x_mask is None and q_len > 1 else False
    is_causal = True if x_mask is None and q_len > 1 else False
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query_states,
        key_states,
        value_states,
        attn_mask=x_mask,
        attn_mask=None if is_causal else x_mask,
        dropout_p=dropout_rate,
        is_causal=is_causal,
    )
else:
    is_causal = True if x_mask is None and q_len > 1 else False
    with torch.nn.attention.sdpa_kernel(self.attn_mode):
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=x_mask,
            attn_mask=None if is_causal else x_mask,
            dropout_p=dropout_rate,
            is_causal=is_causal,
        )

@@ -588,33 +592,31 @@ class Model(LlamaPreTrainedModel):
# [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]

bsz, seq_len, _ = inputs_embeds.size()
dtype = torch.bool
device = inputs_embeds.device
min_dtype = False # torch.iinfo(dtype).min # torch.finfo(dtype).min

# generate default mask based on input
if attention_mask is None:
    attention_mask = torch.ones( (bsz, seq_len), dtype=torch.bool, device=inputs_embeds.device )
    attention_mask = torch.ones( (bsz, seq_len), dtype=dtype, device=device )

# make square
expanded_mask = attention_mask[:, None, None, :].expand( bsz, 1, seq_len, seq_len ).to( dtype=inputs_embeds.dtype )
mask = attention_mask[:, None, None, :].expand( bsz, 1, seq_len, seq_len ).to(dtype)
# mask = AttentionMaskConverter._unmask_unattended(mask, min_dtype)
return mask

# invert from 1.0 = attend, 0.0 = masked to 0.0 = valid, -inf = masked
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill( inverted_mask.to(dtype=torch.bool), torch.finfo(inputs_embeds.dtype).min )

# generate a sliding window pattern
def _sliding_window( self, seq_len, window_size ):
    if not window_size:
        return True

    half_window = int(window_size // 2)
    mask = torch.zeros( seq_len, seq_len, dtype=torch.bool )

def _apply_sliding_window(self, mask, start_idx, end_idx, window_size):
    window_size = int(window_size // 2) # ick

    seq_len = mask.size(-1)
    for i in range(start_idx, min(end_idx, seq_len)):
        if not window_size:
            break

        window_start = max(start_idx, i - window_size)
        window_end = min(end_idx, i + window_size + 1)

        if window_start > start_idx:
            mask[..., i, start_idx:window_start] = 0
        if window_end < end_idx:
            mask[..., i, window_end:end_idx] = 0
    for i in range( seq_len ):
        window_start = max( 0, i - half_window )
        window_end = min( seq_len, i + half_window + 1 )
        mask[i, window_start:window_end] = True

    return mask

@@ -629,15 +631,15 @@ class Model(LlamaPreTrainedModel):
):
    # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
    bsz, seq_len, _ = inputs_embeds.size()

    dtype = torch.bool
    device = inputs_embeds.device
    min_dtype = False # torch.iinfo(dtype).min # torch.finfo(dtype).min

    if attention_mask is None:
        attention_mask = torch.ones((bsz, seq_len), dtype=torch.bool, device=inputs_embeds.device)
        attention_mask = torch.ones((bsz, seq_len), dtype=dtype, device=device)

    expanded_mask = torch.zeros(
        (bsz, 1, seq_len, seq_len),
        dtype=inputs_embeds.dtype,
        device=inputs_embeds.device
    )
    expanded_mask = torch.zeros( (bsz, 1, seq_len, seq_len), dtype=dtype, device=device )

    for batch_index, aux_len in enumerate( aux_lens ):
        window_size = window_sizes[batch_index] if window_sizes is not None else None

@@ -653,25 +655,30 @@ class Model(LlamaPreTrainedModel):
prom_start, prom_end = text_end, text_end + prom_len
output_start, output_end = prom_end, prom_end + output_len

"""
output_start, output_end = prom_end, seq_len # prom_end + output_len
if text_len:
    expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = 1.0
    expanded_mask[batch_index, 0] = self._apply_sliding_window( expanded_mask[batch_index, 0], text_start, text_end, text_window )
    expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = True
if prom_len:
    expanded_mask[batch_index, 0, prom_start:prom_end, text_start:prom_end] = 1.0
    expanded_mask[batch_index, 0] = self._apply_sliding_window( expanded_mask[batch_index, 0], prom_start, prom_end, prom_window )
    expanded_mask[batch_index, 0, prom_start:prom_end, text_start:prom_end] = True
if output_len:
    expanded_mask[batch_index, 0, output_start:output_end, text_start:output_end] = 1.0
    expanded_mask[batch_index, 0] = self._apply_sliding_window( expanded_mask[batch_index, 0], output_start, output_end, output_window )
    expanded_mask[batch_index, 0, output_start:output_end, text_start:output_end] = True
"""

if text_len:
    expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = True if not text_window else self._sliding_window( text_len, text_window )
if prom_len:
    expanded_mask[batch_index, 0, prom_start:prom_end, text_start:text_end] = True
    expanded_mask[batch_index, 0, prom_start:prom_end, prom_start:prom_end] = True if not prom_window else self._sliding_window( prom_len, prom_window )
if output_len:
    expanded_mask[batch_index, 0, output_start:output_end, text_start:text_end] = True
    expanded_mask[batch_index, 0, output_start:output_end, prom_start:prom_end] = True
    expanded_mask[batch_index, 0, output_start:output_end, output_start:output_end] = True if not output_window else self._sliding_window( output_len, output_window )

# apply the original attention mask
expanded_mask = expanded_mask * attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len)

# invert from 1.0 = attend, 0.0 = masked to 0.0 = valid, -inf = masked
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
    inverted_mask.to(dtype=torch.bool),
    torch.finfo(inputs_embeds.dtype).min
)
mask = expanded_mask * attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len).to(dtype)
#mask = AttentionMaskConverter._unmask_unattended(mask, min_dtype)
return mask

@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(

@@ -684,6 +691,14 @@ class Model(LlamaPreTrainedModel):
    batch_size: int,
    **kwargs,
):
    """
    # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
    bsz, seq_len, _ = inputs_embeds.size()

    if attention_mask is None:
        attention_mask = torch.ones((bsz, seq_len), dtype=torch.bool, device=device)
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        causal_mask = attention_mask
    else:

@@ -705,9 +720,10 @@ class Model(LlamaPreTrainedModel):
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    padding_mask, min_dtype
)

return causal_mask

# gut out the things that just shoves responsibility on SDPA's is_causal generating a mask because this causes problems
def _update_causal_mask(
    self,

@@ -62,7 +62,7 @@ def md5_hash( x ):
    return hashlib.md5(str(x).encode("utf-8")).hexdigest()

# removes entries from a dict if that key is missing from the source
def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True, ignore=["optimizer_params"] ):
def prune_missing( source, dest, recurse=True, path=[], parent_is_obj=None, return_missing=True, ignore=["optimizer_params", "wandb_params"] ):
    is_obj = hasattr( source, "__dict__" )
    if parent_is_obj is None:
        parent_is_obj = is_obj