hopefully the final tweaks needed for this bastard of a model
parent 00d1fed217
commit 5670fcb23f
@@ -277,6 +277,7 @@ class ModelExperimentalSettings:
	predict_causally: bool = False # predicts the next token even for the non-causal/NAR tasks; in theory this should also bolster the model, as:
	# * NAR-demask would semi-doubly train for AR
	# * the model wouldn't also need to learn when to predict the token in place
	len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through the input sequence (surely nothing can go wrong)

	#
	logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310), as applied in https://arxiv.org/abs/2406.05298
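For context, logit normalization here means rescaling the logits by their L2 norm before the loss, per the cited LogitNorm paper. A minimal sketch of that transform, assuming a temperature-style scale; the helper below is illustrative and not the repo's actual implementation:

import torch

def normalize_logits( logits: torch.Tensor, temperature: float = 1.0, eps: float = 1e-7 ) -> torch.Tensor:
	# LogitNorm (https://arxiv.org/abs/2205.09310): divide logits by their L2 norm over the
	# vocabulary dimension, scaled by a temperature, before computing cross-entropy
	norms = torch.norm( logits, p=2, dim=-1, keepdim=True ) + eps
	return logits / ( norms * temperature )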
@@ -548,7 +548,12 @@ class TTS():
	)
else:
	raise Exception("!")

"""
len_list = [ 3 * cfg.dataset.frames_per_second ]
resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"],
	**(sampling_kwargs),
)
"""

# to-do: care about batching later
resps = resps_list[0]
@@ -386,6 +386,21 @@ class Attention(nn.Module):
		tensor_layout="HND",
		is_causal=is_causal
	)
elif mode in ["flex"]:
	def causal_mod(score, b, h, q_idx, kv_idx):
		if x_mask is not None:
			score = score + x_mask[b][0][q_idx][kv_idx]
		return score

	attn_output, attn_weights = flex_attention(
		query_states,
		key_states,
		value_states,
		score_mod=causal_mod,
		enable_gqa=True,
		scale=self.head_dim**-0.5,
		return_lse=True,
	)
elif mode in ["fused_attn"]:
	attn_output = fused_attn_func(
		query_states,
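As an aside, the score_mod hook used in the flex branch above is flex_attention's general mechanism for folding an additive bias into the attention scores. A self-contained sketch of the same pattern with made-up shapes (it assumes a PyTorch release that ships torch.nn.attention.flex_attention), not the repo's code:

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 1, 4, 16, 32
q = torch.randn( B, H, S, D )
k = torch.randn( B, H, S, D )
v = torch.randn( B, H, S, D )
# additive bias: 0.0 where attention is allowed, -inf where it should be masked out
bias = torch.zeros( B, 1, S, S )

def score_mod( score, b, h, q_idx, kv_idx ):
	# called per (batch, head, query, key) score; fold the bias in additively
	return score + bias[b][0][q_idx][kv_idx]

out = flex_attention( q, k, v, score_mod=score_mod )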
@@ -411,6 +426,8 @@ class Attention(nn.Module):
		)
elif mode in [torch.nn.attention.SDPBackend.FLASH_ATTENTION]:
	with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
		if isinstance( is_causal, list ):
			is_causal = is_causal[0]
		attn_output = torch.nn.functional.scaled_dot_product_attention(
			query_states,
			key_states,
@@ -419,10 +436,20 @@
			dropout_p=dropout_rate,
			is_causal=is_causal,
		)
else:
elif mode == "sdpa":
	# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
	# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
	# is_causal = True if x_mask is None and q_len > 1 else False
	is_causal = True if x_mask is None and q_len > 1 else False
	attn_output = torch.nn.functional.scaled_dot_product_attention(
		query_states,
		key_states,
		value_states,
		attn_mask=x_mask,
		dropout_p=dropout_rate,
		is_causal=is_causal,
	)
else:
	is_causal = True if x_mask is None and q_len > 1 else False
	with torch.nn.attention.sdpa_kernel(self.attn_mode):
		attn_output = torch.nn.functional.scaled_dot_product_attention(
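A small hedged illustration of the dispatch rule described in the comment above: scaled_dot_product_attention only takes its fast causal kernels when no explicit mask is passed, so is_causal is derived from the presence of x_mask rather than set inline. Shapes here are invented:

import torch
import torch.nn.functional as F

q = torch.randn( 1, 4, 16, 32 )
k = torch.randn( 1, 4, 16, 32 )
v = torch.randn( 1, 4, 16, 32 )
x_mask = None  # e.g. a (1, 1, 16, 16) additive/boolean mask when one is provided

# mirror of the branch above: only claim causality when there is no explicit mask and more than one query
is_causal = x_mask is None and q.shape[-2] > 1
out = F.scaled_dot_product_attention( q, k, v, attn_mask=x_mask, is_causal=is_causal )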
@@ -599,11 +626,11 @@ class Model(LlamaPreTrainedModel):
text_start, text_end = 0, aux_len[0]

prom_start, prom_end = text_end, text_end + aux_len[1]
output_start = prom_end
output_start, output_end = prom_end, prom_end + aux_len[2]

expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = 1.0
expanded_mask[batch_index, 0, prom_start:prom_end, text_start:prom_end] = 1.0
expanded_mask[batch_index, 0, output_start:, :] = 1.0
expanded_mask[batch_index, 0, output_start:output_end, text_start:output_end] = 1.0

# apply the original attention mask
expanded_mask = expanded_mask * attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
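To make the segment layout above easier to picture, here is a small self-contained sketch of the same masking rule: text attends only to text, the prompt attends to text + prompt, and the output span attends to everything up to its own end. The lengths are invented and this is not the repo's _update_segmented_mask itself:

import torch

bsz, seq_len = 1, 10
text_len, prom_len, resp_len = 3, 4, 3  # hypothetical segment lengths

text_start, text_end = 0, text_len
prom_start, prom_end = text_end, text_end + prom_len
output_start, output_end = prom_end, prom_end + resp_len

expanded_mask = torch.zeros( bsz, 1, seq_len, seq_len )
expanded_mask[0, 0, text_start:text_end, text_start:text_end] = 1.0        # text -> text
expanded_mask[0, 0, prom_start:prom_end, text_start:prom_end] = 1.0        # prom -> text + prom
expanded_mask[0, 0, output_start:output_end, text_start:output_end] = 1.0  # resp -> text + prom + resp

print( expanded_mask[0, 0] )  # block-structured mask over the three segments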
@@ -329,7 +329,7 @@ class Base_V2(nn.Module):
ignore_inputs_for_loss = config.experimental.ignore_inputs_for_loss if config is not None else False

resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True
len_parallel_training = False # config.experimental.len_parallel_training if config is not None else True
len_parallel_training = config.experimental.len_parallel_training if config is not None else False
predict_causally = config.experimental.predict_causally if config is not None else False
monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto"
@@ -434,8 +434,7 @@ class Base_V2(nn.Module):
self.langs_emb = ml.Embedding(n_langs, d_model) if n_langs > 0 else None
self.tasks_emb = ml.Embedding(n_tasks, d_model) if n_tasks > 0 else None
self.tones_emb = ml.Embedding(n_tones, d_model) if n_tones > 0 else None
self.len_emb = ml.Embedding(11, d_model)
# to-do: un-autoregressivefy len inferencing, and have it trained parallel to normal training through a separate head or something
self.len_emb = ml.Embedding(11, d_model) # unused

self.audio_emb = None
self.proms_emb = None
@@ -477,7 +476,7 @@ class Base_V2(nn.Module):
	training=training,
	use_ln=per_level_normalization,
)
self.len_decoder = AuxDecoder( d_model, 11 )
self.len_decoder = AuxDecoder( d_model, 11 ) # to-do: adjust this
self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )
self.text_decoder = AuxDecoder( d_model, n_text_tokens )
@@ -1015,9 +1014,8 @@ class Base_V2(nn.Module):
	loss_factors.append( loss_factor )
	loss_names.append( name )
else:
	if name == "resp" and self.len_parallel_training:
	if name == "resp":
		resp_durations.append( token.shape[0] )

	for level in range( self.n_resp_levels ):
		if not self.resp_parallel_training and not classifier_level.endswith(f':{level}:{level}'):
			continue
@@ -1108,6 +1106,16 @@ class Base_V2(nn.Module):
stats[f'acc[k={k_hi}]'] = []
stats[f"acc[k={k_hi}]"] = acc_k_hi

# check if len logits are provided
if logits_aux is not None:
	len_factor = 0.01
	aux_loss_logit = torch.cat( logits_aux )
	#aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=torch.int64 )
	#loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor

	aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype )
	loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor

return LossStats(loss, stats)
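In other words, the len head is now trained as a regressor: the commented-out cross-entropy path gives way to an MSE between the predicted duration logit and the ground-truth frame count, scaled down by len_factor. A minimal sketch with invented numbers:

import torch
import torch.nn.functional as F

len_factor = 0.01
# hypothetical per-sample duration predictions (one scalar each) and ground-truth frame counts
aux_loss_logit = torch.tensor( [ 87.3, 152.9 ] )
resp_durations = [ 90, 150 ]

aux_loss_target = torch.tensor( resp_durations, dtype=aux_loss_logit.dtype )
len_loss = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor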

def forward(
@@ -1163,20 +1171,22 @@ class Base_V2(nn.Module):

# create special masks
# to-do, create it if mixed (although I expect this model to be purely non-causal)
if self.use_segmented_attention_mask and not any(is_causal):
	aux_lens = torch.ones((batch_size, 2), device=x.device, dtype=torch.int32) * 2
	# fill aux lens
	for batch_index, batch_input in enumerate( inputs ):
		for name, input in batch_input:
			if name in ["phn", "text"]:
				aux_lens[batch_index][0] = input.shape[0]
			elif name == "lang":
				aux_lens[batch_index][0] += 2
			elif name == "prom":
				aux_lens[batch_index][1] = input.shape[0]
			elif name == "tone":
				aux_lens[batch_index][1] += 2
aux_lens = torch.tensor([[2, 2, 0]] * batch_size, device=x.device, dtype=torch.int32)
# fill aux lens
for batch_index, batch_input in enumerate( inputs ):
	for name, input in batch_input:
		if name in ["phn", "text"]:
			aux_lens[batch_index][0] = input.shape[0]
		elif name == "lang":
			aux_lens[batch_index][0] += 2
		elif name == "prom":
			aux_lens[batch_index][1] = input.shape[0]
		elif name == "tone":
			aux_lens[batch_index][1] += 2
		elif name == "resp":
			aux_lens[batch_index][2] = input.shape[0]

if self.use_segmented_attention_mask and not any(is_causal):
	mask = self.model._update_segmented_mask( mask, x, aux_lens )
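As a worked example of the new bookkeeping (numbers invented): a sample with a 12-token phn sequence plus a lang marker pair, a 40-frame prom plus a tone pair, and a 100-frame resp would come out of the loop above as aux_lens = [[12 + 2, 40 + 2, 100]], i.e. [[14, 42, 100]]; those three spans are exactly what the segmented mask and the len head consume.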

output = self._forward(
@@ -1190,43 +1200,44 @@ class Base_V2(nn.Module):

hidden_states = output.hidden_states

if self.use_streamlined_calc_loss:
	logits = self.audio_decoder( output.logits )
	# to-do: get len logits
else:
	logits = [ logit for logit in output.logits ]
	grouped_logits = {}
logits = self.audio_decoder( output.logits )
"""
logits = [ logit for logit in output.logits ]
logits_aux = None

grouped_logits = {}

for batch_index in range( batch_size ):
	classifier_level = classifier_levels[batch_index]
	if classifier_level.startswith("AR:") or classifier_level.startswith("NAR:"):
		classifier_level = "audio"

	if classifier_level not in ["audio", "phn", "text", "len"]:
		continue

for batch_index in range( batch_size ):
	classifier_level = classifier_levels[batch_index]
	if classifier_level.startswith("AR:") or classifier_level.startswith("NAR:"):
		classifier_level = "audio"
	if classifier_level not in grouped_logits:
		grouped_logits[classifier_level] = []

	grouped_logits[classifier_level].append(batch_index)

	if classifier_level not in ["audio", "phn", "text", "len"]:
		continue

	if classifier_level not in grouped_logits:
		grouped_logits[classifier_level] = []

	grouped_logits[classifier_level].append(batch_index)
for classifier_level, decoders_indices in grouped_logits.items():
	if classifier_level == "audio":
		head = self.audio_decoder
	elif classifier_level == "phn":
		head = self.phn_decoder
	elif classifier_level == "text":
		head = self.text_decoder
	elif classifier_level == "len":
		head = self.len_decoder

for classifier_level, decoders_indices in grouped_logits.items():
	if classifier_level == "audio":
		head = self.audio_decoder
	elif classifier_level == "phn":
		head = self.phn_decoder
	elif classifier_level == "text":
		head = self.text_decoder
	elif classifier_level == "len":
		head = self.len_decoder

	decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
	decoders_logits = head( decoders_logits )
	for batch_index, logit in zip( decoders_indices, decoders_logits ):
		logits[batch_index] = logit
	decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
	decoders_logits = head( decoders_logits )
	for batch_index, logit in zip( decoders_indices, decoders_logits ):
		logits[batch_index] = logit
"""

# Remove padding
logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]
logits = [ logit[..., :l, :] for logit, l in zip(logits, map(len, x_list)) ]

if not training:
	loss = None
@@ -1235,9 +1246,26 @@ class Base_V2(nn.Module):
self.loss = None
self.stats = None

# grab duration if no resp is provided
if aux_lens[0][2] == 0:
	# do duration prediction
	logits_aux = self.len_decoder( output.logits )
	# only keep the input
	logits_aux = [ logit[..., aux_len[2], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]

	logits = logits_aux

# compute loss if the target is given
else:
	loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
	# do duration prediction
	if self.len_parallel_training:
		logits_aux = self.len_decoder( output.logits )
		# only keep the input
		logits_aux = [ logit[..., aux_len[2], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
	else:
		logits_aux = None

	loss, stats = self.calc_loss( inputs=inputs, logits=logits, logits_aux=logits_aux, quant_levels=quant_levels )

# include any additional losses (for example: MoE router)
if output.loss is not None: