hopefully the final tweaks needed for this bastard of a model
This commit is contained in:
parent 00d1fed217
commit 5670fcb23f
@@ -277,6 +277,7 @@ class ModelExperimentalSettings:
     predict_causally: bool = False # predicts the next token even for the non-causal/NAR tasks, in theory this should also bolster the model, as
     # * NAR-demask would semi-doubly train for AR
     # * the model wouldn't also need to learn when to predict the token in place
+    len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
 
     #
     logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298
 
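For context on the `logit_normalization` field: the referenced LogitNorm paper normalizes each logit vector by its L2 norm before the cross entropy. A minimal sketch of that idea, assuming the float acts as the temperature and that 0 keeps it disabled; the names below are illustrative, not the repo's actual implementation:

```python
# Hedged sketch of logit normalization (LogitNorm, arXiv:2205.09310).
# Assumption: `tau` plays the role of the config's logit_normalization float; 0 disables it.
import torch
import torch.nn.functional as F

def normalize_logits( logits: torch.Tensor, tau: float, eps: float = 1e-7 ) -> torch.Tensor:
    if tau <= 0:
        return logits
    # divide each logit vector by (tau * ||logits||), so the loss mostly cares about direction
    norms = logits.norm( p=2, dim=-1, keepdim=True ) + eps
    return logits / ( norms * tau )

logits = torch.randn( 4, 1024 )           # [batch, vocab]
targets = torch.randint( 0, 1024, (4,) )
loss = F.cross_entropy( normalize_logits( logits, tau=0.1 ), targets )
```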
@@ -548,7 +548,12 @@ class TTS():
                 )
             else:
                 raise Exception("!")
+            """
+            len_list = [ 3 * cfg.dataset.frames_per_second ]
+            resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"],
+                **(sampling_kwargs),
+            )
+            """
 
             # to-do: care about batching later
             resps = resps_list[0]
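The docstring'd block above is the old fixed-duration path: it simply budgeted three seconds of output, which in frames is just duration times the codec frame rate. A tiny sketch of that arithmetic; the frame rate value is illustrative, the repo reads it from `cfg.dataset.frames_per_second`:

```python
# Sketch: a fixed three-second budget expressed in codec frames.
frames_per_second = 75                 # illustrative; cfg.dataset.frames_per_second in the repo
len_list = [ 3 * frames_per_second ]   # one duration per batch item -> [225]
```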
@@ -386,6 +386,21 @@ class Attention(nn.Module):
                 tensor_layout="HND",
                 is_causal=is_causal
             )
+        elif mode in ["flex"]:
+            def causal_mod(score, b, h, q_idx, kv_idx):
+                if x_mask is not None:
+                    score = score + x_mask[b][0][q_idx][kv_idx]
+                return score
+
+            attn_output, attn_weights = flex_attention(
+                query_states,
+                key_states,
+                value_states,
+                score_mod=causal_mod,
+                enable_gqa=True,
+                scale=self.head_dim**-0.5,
+                return_lse=True,
+            )
 
         elif mode in ["fused_attn"]:
             attn_output = fused_attn_func(
                 query_states,
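A self-contained sketch of what the new `flex` path does, assuming a PyTorch build that ships `torch.nn.attention.flex_attention` (2.5+); shapes and mask values are illustrative. The `score_mod` callback receives per-element indices and can add an arbitrary bias, which is how the additive `x_mask` is folded in above:

```python
# Hedged sketch of the flex-attention path; assumes torch >= 2.5 with flex_attention available.
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, L, E = 1, 2, 16, 64
q, k, v = (torch.randn(B, H, L, E) for _ in range(3))

# additive bias laid out like the diff's x_mask: [batch, 1, q_len, kv_len]
x_mask = torch.zeros(B, 1, L, L)
x_mask[..., 8:] = float("-inf")   # illustrative: hide the last eight key positions

def score_mod(score, b, h, q_idx, kv_idx):
    # same idea as causal_mod above: add the (possibly -inf) bias for this (query, key) pair
    return score + x_mask[b, 0, q_idx, kv_idx]

out, lse = flex_attention(q, k, v, score_mod=score_mod, enable_gqa=True, scale=E ** -0.5, return_lse=True)
print(out.shape)   # torch.Size([1, 2, 16, 64]); lse holds the per-query log-sum-exp
```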
@@ -411,6 +426,8 @@ class Attention(nn.Module):
             )
         elif mode in [torch.nn.attention.SDPBackend.FLASH_ATTENTION]:
             with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+                if isinstance( is_causal, list ):
+                    is_causal = is_causal[0]
                 attn_output = torch.nn.functional.scaled_dot_product_attention(
                     query_states,
                     key_states,
@@ -419,10 +436,20 @@ class Attention(nn.Module):
                     dropout_p=dropout_rate,
                     is_causal=is_causal,
                 )
-        else:
+        elif mode == "sdpa":
             # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
             # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
             # is_causal = True if x_mask is None and q_len > 1 else False
+            is_causal = True if x_mask is None and q_len > 1 else False
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=x_mask,
+                dropout_p=dropout_rate,
+                is_causal=is_causal,
+            )
+        else:
             is_causal = True if x_mask is None and q_len > 1 else False
             with torch.nn.attention.sdpa_kernel(self.attn_mode):
                 attn_output = torch.nn.functional.scaled_dot_product_attention(
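The new explicit `sdpa` branch passes the additive `x_mask` through `attn_mask` and only flips `is_causal` when no mask is supplied. The two dispatch styles should give the same result for a causal query; a small sketch of that equivalence, with illustrative shapes:

```python
# Sketch: is_causal=True versus an explicit additive causal mask give the same SDPA output.
import torch
import torch.nn.functional as F

B, H, L, E = 1, 2, 8, 16
q, k, v = (torch.randn(B, H, L, E) for _ in range(3))

# path 1: let SDPA build the causal mask internally (eligible for its fused kernels)
out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# path 2: materialize the same mask as an additive float bias, as the x_mask branch does
causal_bias = torch.zeros(L, L).masked_fill(torch.ones(L, L, dtype=torch.bool).triu(1), float("-inf"))
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_bias)

print(torch.allclose(out_flag, out_mask, atol=1e-5))   # True
```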
@@ -599,11 +626,11 @@ class Model(LlamaPreTrainedModel):
             text_start, text_end = 0, aux_len[0]
 
             prom_start, prom_end = text_end, text_end + aux_len[1]
-            output_start = prom_end
+            output_start, output_end = prom_end, prom_end + aux_len[2]
 
             expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = 1.0
             expanded_mask[batch_index, 0, prom_start:prom_end, text_start:prom_end] = 1.0
-            expanded_mask[batch_index, 0, output_start:, :] = 1.0
+            expanded_mask[batch_index, 0, output_start:output_end, text_start:output_end] = 1.0
 
             # apply the original attention mask
             expanded_mask = expanded_mask * attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
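The change above bounds the response segment by `output_end` instead of letting it run to the padded end of the sequence, so response rows no longer attend into padding. A single-item sketch of the segment layout this builds, with illustrative lengths:

```python
# Sketch of the segmented mask for one batch item: text sees text, prom sees text+prom,
# resp sees text+prom+resp, and padding rows/columns stay zeroed out.
import torch

seq_len = 10
aux_len = [3, 3, 2]   # [text_len, prom_len, resp_len]; 8 real positions, 2 padding

text_start, text_end = 0, aux_len[0]
prom_start, prom_end = text_end, text_end + aux_len[1]
output_start, output_end = prom_end, prom_end + aux_len[2]

mask = torch.zeros(seq_len, seq_len)
mask[text_start:text_end, text_start:text_end] = 1.0
mask[prom_start:prom_end, text_start:prom_end] = 1.0
mask[output_start:output_end, text_start:output_end] = 1.0   # was output_start: over all columns

print(mask.int())
```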
@@ -329,7 +329,7 @@ class Base_V2(nn.Module):
         ignore_inputs_for_loss = config.experimental.ignore_inputs_for_loss if config is not None else False
 
         resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True
-        len_parallel_training = False # config.experimental.len_parallel_training if config is not None else True
+        len_parallel_training = config.experimental.len_parallel_training if config is not None else False
         predict_causally = config.experimental.predict_causally if config is not None else False
         monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
         audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto"
@@ -434,8 +434,7 @@ class Base_V2(nn.Module):
         self.langs_emb = ml.Embedding(n_langs, d_model) if n_langs > 0 else None
         self.tasks_emb = ml.Embedding(n_tasks, d_model) if n_tasks > 0 else None
         self.tones_emb = ml.Embedding(n_tones, d_model) if n_tones > 0 else None
-        self.len_emb = ml.Embedding(11, d_model)
-        # to-do: un-autoregressivefy len inferencing, and have it trained parallel to normal training through a separate head or something
+        self.len_emb = ml.Embedding(11, d_model) # unused
 
         self.audio_emb = None
         self.proms_emb = None
@@ -477,7 +476,7 @@ class Base_V2(nn.Module):
             training=training,
             use_ln=per_level_normalization,
         )
-        self.len_decoder = AuxDecoder( d_model, 11 )
+        self.len_decoder = AuxDecoder( d_model, 11 ) # to-do: adjust this
         self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )
         self.text_decoder = AuxDecoder( d_model, n_text_tokens )
 
@@ -1015,9 +1014,8 @@ class Base_V2(nn.Module):
                     loss_factors.append( loss_factor )
                     loss_names.append( name )
                 else:
-                    if name == "resp" and self.len_parallel_training:
+                    if name == "resp":
                         resp_durations.append( token.shape[0] )
 
                     for level in range( self.n_resp_levels ):
                         if not self.resp_parallel_training and not classifier_level.endswith(f':{level}:{level}'):
                             continue
@@ -1108,6 +1106,16 @@ class Base_V2(nn.Module):
                 stats[f'acc[k={k_hi}]'] = []
             stats[f"acc[k={k_hi}]"] = acc_k_hi
 
+        # check if len logits are provided
+        if logits_aux is not None:
+            len_factor = 0.01
+            aux_loss_logit = torch.cat( logits_aux )
+            #aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=torch.int64 )
+            #loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
+
+            aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype )
+            loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor
+
         return LossStats(loss, stats)
 
     def forward(
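So the duration head is now trained as a regression: the concatenated per-item `len` logits are pushed toward the response lengths (in frames) with a small weight, instead of the commented-out cross-entropy over duration buckets. A toy sketch under the assumption that each `logits_aux` entry is a single scalar per item:

```python
# Hedged sketch of the added duration loss; shapes are assumptions, the repo's
# logits_aux comes from self.len_decoder over the model's hidden logits.
import torch
import torch.nn.functional as F

resp_durations = [120, 85, 240]   # response lengths in frames, one per batch item
logits_aux = [ torch.tensor([118.3]), torch.tensor([90.1]), torch.tensor([239.0]) ]

len_factor = 0.01   # keeps the duration loss from dominating the token losses
aux_loss_logit = torch.cat( logits_aux )                                       # shape [batch]
aux_loss_target = torch.tensor( resp_durations, dtype=aux_loss_logit.dtype )

loss_len = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor
print(loss_len)   # small scalar, stored as loss['len']
```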
@@ -1163,20 +1171,22 @@ class Base_V2(nn.Module):
 
         # create special masks
         # to-do, create it if mixed (although I expect this model to be purely non-causal)
-        if self.use_segmented_attention_mask and not any(is_causal):
-            aux_lens = torch.ones((batch_size, 2), device=x.device, dtype=torch.int32) * 2
-            # fill aux lens
-            for batch_index, batch_input in enumerate( inputs ):
-                for name, input in batch_input:
-                    if name in ["phn", "text"]:
-                        aux_lens[batch_index][0] = input.shape[0]
-                    elif name == "lang":
-                        aux_lens[batch_index][0] += 2
-                    elif name == "prom":
-                        aux_lens[batch_index][1] = input.shape[0]
-                    elif name == "tone":
-                        aux_lens[batch_index][1] += 2
-
+        aux_lens = torch.tensor([[2, 2, 0]] * batch_size, device=x.device, dtype=torch.int32)
+        # fill aux lens
+        for batch_index, batch_input in enumerate( inputs ):
+            for name, input in batch_input:
+                if name in ["phn", "text"]:
+                    aux_lens[batch_index][0] = input.shape[0]
+                elif name == "lang":
+                    aux_lens[batch_index][0] += 2
+                elif name == "prom":
+                    aux_lens[batch_index][1] = input.shape[0]
+                elif name == "tone":
+                    aux_lens[batch_index][1] += 2
+                elif name == "resp":
+                    aux_lens[batch_index][2] = input.shape[0]
+
+        if self.use_segmented_attention_mask and not any(is_causal):
             mask = self.model._update_segmented_mask( mask, x, aux_lens )
 
         output = self._forward(
@@ -1190,43 +1200,44 @@ class Base_V2(nn.Module):
 
         hidden_states = output.hidden_states
 
-        if self.use_streamlined_calc_loss:
-            logits = self.audio_decoder( output.logits )
-            # to-do: get len logits
-        else:
-            logits = [ logit for logit in output.logits ]
+        logits = self.audio_decoder( output.logits )
+        """
+        logits = [ logit for logit in output.logits ]
+        logits_aux = None
 
         grouped_logits = {}
 
         for batch_index in range( batch_size ):
             classifier_level = classifier_levels[batch_index]
             if classifier_level.startswith("AR:") or classifier_level.startswith("NAR:"):
                 classifier_level = "audio"
 
             if classifier_level not in ["audio", "phn", "text", "len"]:
                 continue
 
             if classifier_level not in grouped_logits:
                 grouped_logits[classifier_level] = []
 
             grouped_logits[classifier_level].append(batch_index)
 
         for classifier_level, decoders_indices in grouped_logits.items():
             if classifier_level == "audio":
                 head = self.audio_decoder
             elif classifier_level == "phn":
                 head = self.phn_decoder
             elif classifier_level == "text":
                 head = self.text_decoder
             elif classifier_level == "len":
                 head = self.len_decoder
 
             decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ])
             decoders_logits = head( decoders_logits )
             for batch_index, logit in zip( decoders_indices, decoders_logits ):
                 logits[batch_index] = logit
+        """
 
         # Remove padding
-        logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ]
+        logits = [ logit[..., :l, :] for logit, l in zip(logits, map(len, x_list)) ]
 
         if not training:
             loss = None
@@ -1235,9 +1246,26 @@ class Base_V2(nn.Module):
             self.loss = None
             self.stats = None
 
+            # grab duration if no resp is provided
+            if aux_lens[0][2] == 0:
+                # do duration prediction
+                logits_aux = self.len_decoder( output.logits )
+                # only keep the input
+                logits_aux = [ logit[..., aux_len[2], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
+
+                logits = logits_aux
+
         # compute loss if the target is given
         else:
-            loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
+            # do duration prediction
+            if self.len_parallel_training:
+                logits_aux = self.len_decoder( output.logits )
+                # only keep the input
+                logits_aux = [ logit[..., aux_len[2], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
+            else:
+                logits_aux = None
+
+            loss, stats = self.calc_loss( inputs=inputs, logits=logits, logits_aux=logits_aux, quant_levels=quant_levels )
 
         # include any additional losses (for example: MoE router)
         if output.loss is not None:
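At inference, when no `resp` is in the batch (`aux_lens[0][2] == 0`), the `len` head is run over the hidden logits and only one scalar per item is kept as the predicted duration, which the caller can then round into a frame budget. A hedged sketch of that last step, with illustrative values standing in for the sliced `logits_aux`:

```python
# Sketch: turning the regressed duration scalar into a frame count for the NAR-demask pass.
import torch

predicted = torch.tensor([187.4])                     # hypothetical len-head output for one item
duration_frames = max( 1, int( predicted.round().item() ) )
len_list = [ duration_frames ]                        # e.g. [187], fed back in as the output length
print(len_list)
```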