hopefully the final tweaks needed for this bastard of a model

mrq 2025-03-10 20:59:11 -05:00
parent 00d1fed217
commit 5670fcb23f
4 changed files with 117 additions and 56 deletions

View File

@ -277,6 +277,7 @@ class ModelExperimentalSettings:
predict_causally: bool = False # predicts the next token even for the non-causal/NAR tasks; in theory this should also bolster the model, as:
# * NAR-demask would semi-doubly train for AR
# * the model wouldn't also need to learn when to predict the token in place
len_parallel_training: bool = True # used for version >= 7, computes the len loss alongside normal training by using the input sequence (surely nothing can go wrong)
#
logit_normalization: float = 0 # performs logit normalization against the logit norms per the LogitNorm paper (https://arxiv.org/abs/2205.09310), as applied in https://arxiv.org/abs/2406.05298
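For reference, the logit normalization referred to here (LogitNorm, arXiv:2205.09310) rescales each logit vector by its L2 norm and a temperature before the cross-entropy. A minimal sketch assuming a generic classification setting; the helper name and `tau` value are illustrative, not this repo's API:

```python
# Minimal sketch of LogitNorm (arXiv:2205.09310): divide logits by their L2 norm
# (scaled by a temperature) before computing cross-entropy.
# `tau` and the function name are illustrative, not taken from this repo.
import torch
import torch.nn.functional as F

def logitnorm_cross_entropy(logits: torch.Tensor, targets: torch.Tensor, tau: float = 0.04) -> torch.Tensor:
    norms = logits.norm(p=2, dim=-1, keepdim=True) + 1e-7
    return F.cross_entropy(logits / (norms * tau), targets)

# usage: logitnorm_cross_entropy(torch.randn(8, 1024), torch.randint(0, 1024, (8,)))
```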

View File

@ -548,7 +548,12 @@ class TTS():
)
else:
raise Exception("!")
"""
len_list = [ 3 * cfg.dataset.frames_per_second ]
resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"],
**(sampling_kwargs),
)
"""
# to-do: care about batching later
resps = resps_list[0]
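For context, the commented-out fallback above fixes the output length at three seconds by multiplying seconds by the dataset frame rate; any externally predicted duration would be converted into a resp token count the same way. A tiny illustration, with a placeholder frame rate:

```python
# Illustrative only: seconds -> resp token count, mirroring
# `3 * cfg.dataset.frames_per_second` above. 75 fps is a placeholder value.
frames_per_second = 75
duration_seconds = 3.0
len_list = [ int(duration_seconds * frames_per_second) ]  # -> [225]
```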

View File

@ -386,6 +386,21 @@ class Attention(nn.Module):
tensor_layout="HND",
is_causal=is_causal
)
elif mode in ["flex"]:
def causal_mod(score, b, h, q_idx, kv_idx):
if x_mask is not None:
score = score + x_mask[b][0][q_idx][kv_idx]
return score
attn_output, attn_weights = flex_attention(
query_states,
key_states,
value_states,
score_mod=causal_mod,
enable_gqa=True,
scale=self.head_dim**-0.5,
return_lse=True,
)
elif mode in ["fused_attn"]:
attn_output = fused_attn_func(
query_states,
@ -411,6 +426,8 @@ class Attention(nn.Module):
)
elif mode in [torch.nn.attention.SDPBackend.FLASH_ATTENTION]:
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
if isinstance( is_causal, list ):
is_causal = is_causal[0]
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
@ -419,10 +436,20 @@ class Attention(nn.Module):
dropout_p=dropout_rate,
is_causal=is_causal,
)
else:
elif mode == "sdpa":
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# is_causal = True if x_mask is None and q_len > 1 else False
is_causal = True if x_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=x_mask,
dropout_p=dropout_rate,
is_causal=is_causal,
)
else:
is_causal = True if x_mask is None and q_len > 1 else False
with torch.nn.attention.sdpa_kernel(self.attn_mode):
attn_output = torch.nn.functional.scaled_dot_product_attention(
@ -599,11 +626,11 @@ class Model(LlamaPreTrainedModel):
text_start, text_end = 0, aux_len[0]
prom_start, prom_end = text_end, text_end + aux_len[1]
output_start = prom_end
output_start, output_end = prom_end, prom_end + aux_len[2]
expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = 1.0
expanded_mask[batch_index, 0, prom_start:prom_end, text_start:prom_end] = 1.0
expanded_mask[batch_index, 0, output_start:, :] = 1.0
expanded_mask[batch_index, 0, output_start:output_end, text_start:output_end] = 1.0
# apply the original attention mask
expanded_mask = expanded_mask * attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
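The change above bounds the output rows at `output_end` instead of letting them attend over the whole (padded) sequence. A standalone sketch of the resulting block pattern, assuming per-item (text, prom, resp) lengths in the same layout as the three-entry `aux_lens` built later in this commit; the function name is illustrative, not the repo's `_update_segmented_mask`:

```python
# Sketch of the segmented attention pattern: text attends to text, prom attends
# to text + prom, and the resp span attends to everything up to its own end.
# The result would still be multiplied by the padding mask, as in the hunk above.
import torch

def build_segmented_mask(aux_lens: torch.Tensor, seq_len: int) -> torch.Tensor:
    bsz = aux_lens.shape[0]
    mask = torch.zeros(bsz, 1, seq_len, seq_len)
    for b, (text_len, prom_len, resp_len) in enumerate(aux_lens.tolist()):
        text_end = text_len
        prom_end = text_end + prom_len
        output_end = prom_end + resp_len
        mask[b, 0, :text_end, :text_end] = 1.0
        mask[b, 0, text_end:prom_end, :prom_end] = 1.0
        mask[b, 0, prom_end:output_end, :output_end] = 1.0
    return mask
```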

View File

@ -329,7 +329,7 @@ class Base_V2(nn.Module):
ignore_inputs_for_loss = config.experimental.ignore_inputs_for_loss if config is not None else False
resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True
len_parallel_training = False # config.experimental.len_parallel_training if config is not None else True
len_parallel_training = config.experimental.len_parallel_training if config is not None else False
predict_causally = config.experimental.predict_causally if config is not None else False
monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto"
@ -434,8 +434,7 @@ class Base_V2(nn.Module):
self.langs_emb = ml.Embedding(n_langs, d_model) if n_langs > 0 else None
self.tasks_emb = ml.Embedding(n_tasks, d_model) if n_tasks > 0 else None
self.tones_emb = ml.Embedding(n_tones, d_model) if n_tones > 0 else None
self.len_emb = ml.Embedding(11, d_model)
# to-do: un-autoregressivefy len inferencing, and have it trained parallel to normal training through a separate head or something
self.len_emb = ml.Embedding(11, d_model) # unused
self.audio_emb = None
self.proms_emb = None
@ -477,7 +476,7 @@ class Base_V2(nn.Module):
training=training,
use_ln=per_level_normalization,
)
self.len_decoder = AuxDecoder( d_model, 11 )
self.len_decoder = AuxDecoder( d_model, 11 ) # to-do: adjust this
self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )
self.text_decoder = AuxDecoder( d_model, n_text_tokens )
@ -1015,9 +1014,8 @@ class Base_V2(nn.Module):
loss_factors.append( loss_factor )
loss_names.append( name )
else:
if name == "resp" and self.len_parallel_training:
if name == "resp":
resp_durations.append( token.shape[0] )
for level in range( self.n_resp_levels ):
if not self.resp_parallel_training and not classifier_level.endswith(f':{level}:{level}'):
continue
@ -1108,6 +1106,16 @@ class Base_V2(nn.Module):
stats[f'acc[k={k_hi}]'] = []
stats[f"acc[k={k_hi}]"] = acc_k_hi
# check if len logits are provided
if logits_aux is not None:
len_factor = 0.01
aux_loss_logit = torch.cat( logits_aux )
#aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=torch.int64 )
#loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype )
loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor
return LossStats(loss, stats)
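The new `len` term regresses one scalar per batch item against the resp duration in frames and scales it down by `len_factor` so it does not dominate the token losses. A toy-shaped illustration of the same computation (the values are made up):

```python
# Toy illustration of the parallel duration loss: one scalar prediction per
# batch item vs. the ground-truth resp length in frames, scaled by a small factor.
import torch
import torch.nn.functional as F

len_factor = 0.01
pred_durations = torch.tensor([148.2, 301.7,  90.5])  # e.g. concatenated len-head outputs
resp_durations = torch.tensor([150.0, 300.0,  92.0])  # resp token counts per batch item
len_loss = F.mse_loss(pred_durations, resp_durations) * len_factor
```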
def forward(
@ -1163,20 +1171,22 @@ class Base_V2(nn.Module):
# create special masks
# to-do: create it if mixed (although I expect this model to be purely non-causal)
if self.use_segmented_attention_mask and not any(is_causal):
aux_lens = torch.ones((batch_size, 2), device=x.device, dtype=torch.int32) * 2
# fill aux lens
for batch_index, batch_input in enumerate( inputs ):
for name, input in batch_input:
if name in ["phn", "text"]:
aux_lens[batch_index][0] = input.shape[0]
elif name == "lang":
aux_lens[batch_index][0] += 2
elif name == "prom":
aux_lens[batch_index][1] = input.shape[0]
elif name == "tone":
aux_lens[batch_index][1] += 2
aux_lens = torch.tensor([[2, 2, 0]] * batch_size, device=x.device, dtype=torch.int32)
# fill aux lens
for batch_index, batch_input in enumerate( inputs ):
for name, input in batch_input:
if name in ["phn", "text"]:
aux_lens[batch_index][0] = input.shape[0]
elif name == "lang":
aux_lens[batch_index][0] += 2
elif name == "prom":
aux_lens[batch_index][1] = input.shape[0]
elif name == "tone":
aux_lens[batch_index][1] += 2
elif name == "resp":
aux_lens[batch_index][2] = input.shape[0]
if self.use_segmented_attention_mask and not any(is_causal):
mask = self.model._update_segmented_mask( mask, x, aux_lens )
output = self._forward(
@ -1190,43 +1200,44 @@ class Base_V2(nn.Module):
hidden_states = output.hidden_states
if self.use_streamlined_calc_loss:
logits = self.audio_decoder( output.logits )
# to-do: get len logits
else:
logits = [ logit for logit in output.logits ]
grouped_logits = {}
logits = self.audio_decoder( output.logits )
"""
logits = [ logit for logit in output.logits ]
logits_aux = None
grouped_logits = {}
for batch_index in range( batch_size ):
classifier_level = classifier_levels[batch_index]
if classifier_level.startswith("AR:") or classifier_level.startswith("NAR:"):
classifier_level = "audio"
if classifier_level not in ["audio", "phn", "text", "len"]:
continue
for batch_index in range( batch_size ):
classifier_level = classifier_levels[batch_index]
if classifier_level.startswith("AR:") or classifier_level.startswith("NAR:"):
classifier_level = "audio"
if classifier_level not in grouped_logits:
grouped_logits[classifier_level] = []
grouped_logits[classifier_level].append(batch_index)
if classifier_level not in ["audio", "phn", "text", "len"]:
continue
if classifier_level not in grouped_logits:
grouped_logits[classifier_level] = []
grouped_logits[classifier_level].append(batch_index)
for classifier_level, decoders_indices in grouped_logits.items():
if classifier_level == "audio":
head = self.audio_decoder
elif classifier_level == "phn":
head = self.phn_decoder
elif classifier_level == "text":
head = self.text_decoder
elif classifier_level == "len":
head = self.len_decoder
for classifier_level, decoders_indices in grouped_logits.items():
if classifier_level == "audio":
head = self.audio_decoder
elif classifier_level == "phn":
head = self.phn_decoder
elif classifier_level == "text":
head = self.text_decoder
elif classifier_level == "len":
head = self.len_decoder
decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
decoders_logits = head( decoders_logits )
for batch_index, logit in zip( decoders_indices, decoders_logits ):
logits[batch_index] = logit
decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
decoders_logits = head( decoders_logits )
for batch_index, logit in zip( decoders_indices, decoders_logits ):
logits[batch_index] = logit
"""
# Remove padding
logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]
logits = [ logit[..., :l, :] for logit, l in zip(logits, map(len, x_list)) ]
if not training:
loss = None
@ -1235,9 +1246,26 @@ class Base_V2(nn.Module):
self.loss = None
self.stats = None
# grab duration if no resp is provided
if aux_lens[0][2] == 0:
# do duration prediction
logits_aux = self.len_decoder( output.logits )
# only keep the input
logits_aux = [ logit[..., aux_len[2], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
logits = logits_aux
# compute loss if the target is given
else:
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
# do duration prediction
if self.len_parallel_training:
logits_aux = self.len_decoder( output.logits )
# only keep the input
logits_aux = [ logit[..., aux_len[2], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
else:
logits_aux = None
loss, stats = self.calc_loss( inputs=inputs, logits=logits, logits_aux=logits_aux, quant_levels=quant_levels )
# include any additional losses (for example: MoE router)
if output.loss is not None:
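Since the duration head is now trained as a scalar regression rather than digit-by-digit classification, the inference side would presumably recover a usable length by rounding and clamping that scalar. A hedged sketch of that readout; the function name and clamp bounds are illustrative:

```python
# Illustrative readout of the regression-style duration head: round the scalar
# prediction and clamp it to a sane range of resp frames. Bounds are placeholders.
import torch

def duration_from_len_logit(len_logit: torch.Tensor, min_frames: int = 1, max_frames: int = 2250) -> int:
    # len_logit: a single-element tensor produced by the len decoder for one item
    return int(len_logit.round().clamp(min_frames, max_frames).item())
```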