From 5670fcb23fad897f4826d862b1422dc7cd20aa46 Mon Sep 17 00:00:00 2001 From: mrq Date: Mon, 10 Mar 2025 20:59:11 -0500 Subject: [PATCH] hopefully the final tweaks needed for this bastard of a model --- vall_e/config.py | 1 + vall_e/inference.py | 7 +- vall_e/models/arch/llama.py | 33 ++++++++- vall_e/models/base_v2.py | 132 ++++++++++++++++++++++-------------- 4 files changed, 117 insertions(+), 56 deletions(-) diff --git a/vall_e/config.py b/vall_e/config.py index 4f403e5..363e003 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -277,6 +277,7 @@ class ModelExperimentalSettings: predict_causally: bool = False # predicts the next token even for the non-causal/NAR tasks, in theory this should also bolster the model, as # * NAR-demask would semi-doubly train for AR # * the model wouldn't also need to learn when to predict the token in place + len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong) # logit_normalization: float = 0 # performs logit normalization against the norms per the paper (https://arxiv.org/abs/2205.09310) per https://arxiv.org/abs/2406.05298 diff --git a/vall_e/inference.py b/vall_e/inference.py index 14b127d..7f7edb5 100755 --- a/vall_e/inference.py +++ b/vall_e/inference.py @@ -548,7 +548,12 @@ class TTS(): ) else: raise Exception("!") - + """ + len_list = [ 3 * cfg.dataset.frames_per_second ] + resps_list = model_nar( **input_kwargs, len_list=len_list, task_list=["tts"], + **(sampling_kwargs), + ) + """ # to-do: care about batching later resps = resps_list[0] diff --git a/vall_e/models/arch/llama.py b/vall_e/models/arch/llama.py index c1fb2f0..0a89c0c 100644 --- a/vall_e/models/arch/llama.py +++ b/vall_e/models/arch/llama.py @@ -386,6 +386,21 @@ class Attention(nn.Module): tensor_layout="HND", is_causal=is_causal ) + elif mode in ["flex"]: + def causal_mod(score, b, h, q_idx, kv_idx): + if x_mask is not None: + score = score + x_mask[b][0][q_idx][kv_idx] + return score + + attn_output, attn_weights = flex_attention( + query_states, + key_states, + value_states, + score_mod=causal_mod, + enable_gqa=True, + scale=self.head_dim**-0.5, + return_lse=True, + ) elif mode in ["fused_attn"]: attn_output = fused_attn_func( query_states, @@ -411,6 +426,8 @@ class Attention(nn.Module): ) elif mode in [torch.nn.attention.SDPBackend.FLASH_ATTENTION]: with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): + if isinstance( is_causal, list ): + is_causal = is_causal[0] attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, @@ -419,10 +436,20 @@ class Attention(nn.Module): dropout_p=dropout_rate, is_causal=is_causal, ) - else: + elif mode == "sdpa": # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
# is_causal = True if x_mask is None and q_len > 1 else False + is_causal = True if x_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=x_mask, + dropout_p=dropout_rate, + is_causal=is_causal, + ) + else: is_causal = True if x_mask is None and q_len > 1 else False with torch.nn.attention.sdpa_kernel(self.attn_mode): attn_output = torch.nn.functional.scaled_dot_product_attention( @@ -599,11 +626,11 @@ class Model(LlamaPreTrainedModel): text_start, text_end = 0, aux_len[0] prom_start, prom_end = text_end, text_end + aux_len[1] - output_start = prom_end + output_start, output_end = prom_end, prom_end + aux_len[2] expanded_mask[batch_index, 0, text_start:text_end, text_start:text_end] = 1.0 expanded_mask[batch_index, 0, prom_start:prom_end, text_start:prom_end] = 1.0 - expanded_mask[batch_index, 0, output_start:, :] = 1.0 + expanded_mask[batch_index, 0, output_start:output_end, text_start:output_end] = 1.0 # apply the original attention mask expanded_mask = expanded_mask * attention_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len) diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py index e608aef..7323424 100644 --- a/vall_e/models/base_v2.py +++ b/vall_e/models/base_v2.py @@ -329,7 +329,7 @@ class Base_V2(nn.Module): ignore_inputs_for_loss = config.experimental.ignore_inputs_for_loss if config is not None else False resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True - len_parallel_training = False # config.experimental.len_parallel_training if config is not None else True + len_parallel_training = config.experimental.len_parallel_training if config is not None else False predict_causally = config.experimental.predict_causally if config is not None else False monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto" @@ -434,8 +434,7 @@ class Base_V2(nn.Module): self.langs_emb = ml.Embedding(n_langs, d_model) if n_langs > 0 else None self.tasks_emb = ml.Embedding(n_tasks, d_model) if n_tasks > 0 else None self.tones_emb = ml.Embedding(n_tones, d_model) if n_tones > 0 else None - self.len_emb = ml.Embedding(11, d_model) - # to-do: un-autoregressivefy len inferencing, and have it trained parallel to normal training through a separate head or something + self.len_emb = ml.Embedding(11, d_model) # unused self.audio_emb = None self.proms_emb = None @@ -477,7 +476,7 @@ class Base_V2(nn.Module): training=training, use_ln=per_level_normalization, ) - self.len_decoder = AuxDecoder( d_model, 11 ) + self.len_decoder = AuxDecoder( d_model, 11 ) # to-do: adjust this self.phn_decoder = AuxDecoder( d_model, n_phn_tokens ) self.text_decoder = AuxDecoder( d_model, n_text_tokens ) @@ -1015,9 +1014,8 @@ class Base_V2(nn.Module): loss_factors.append( loss_factor ) loss_names.append( name ) else: - if name == "resp" and self.len_parallel_training: + if name == "resp": resp_durations.append( token.shape[0] ) - for level in range( self.n_resp_levels ): if not self.resp_parallel_training and not classifier_level.endswith(f':{level}:{level}'): continue @@ -1108,6 +1106,16 @@ class Base_V2(nn.Module): stats[f'acc[k={k_hi}]'] = [] stats[f"acc[k={k_hi}]"] = acc_k_hi + # check if len logits are provided + if logits_aux is not None: + len_factor = 0.01 + aux_loss_logit = torch.cat( logits_aux ) + 
#aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=torch.int64 ) + #loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor + + aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) + loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor + return LossStats(loss, stats) def forward( @@ -1163,20 +1171,22 @@ class Base_V2(nn.Module): # create special masks # to-do, create it if mixed (although I expect this model to be purely non-causal) - if self.use_segmented_attention_mask and not any(is_causal): - aux_lens = torch.ones((batch_size, 2), device=x.device, dtype=torch.int32) * 2 - # fill aux lens - for batch_index, batch_input in enumerate( inputs ): - for name, input in batch_input: - if name in ["phn", "text"]: - aux_lens[batch_index][0] = input.shape[0] - elif name == "lang": - aux_lens[batch_index][0] += 2 - elif name == "prom": - aux_lens[batch_index][1] = input.shape[0] - elif name == "tone": - aux_lens[batch_index][1] += 2 + aux_lens = torch.tensor([[2, 2, 0]] * batch_size, device=x.device, dtype=torch.int32) + # fill aux lens + for batch_index, batch_input in enumerate( inputs ): + for name, input in batch_input: + if name in ["phn", "text"]: + aux_lens[batch_index][0] = input.shape[0] + elif name == "lang": + aux_lens[batch_index][0] += 2 + elif name == "prom": + aux_lens[batch_index][1] = input.shape[0] + elif name == "tone": + aux_lens[batch_index][1] += 2 + elif name == "resp": + aux_lens[batch_index][2] = input.shape[0] + if self.use_segmented_attention_mask and not any(is_causal): mask = self.model._update_segmented_mask( mask, x, aux_lens ) output = self._forward( @@ -1190,43 +1200,44 @@ class Base_V2(nn.Module): hidden_states = output.hidden_states - if self.use_streamlined_calc_loss: - logits = self.audio_decoder( output.logits ) - # to-do: get len logits - else: - logits = [ logit for logit in output.logits ] - grouped_logits = {} + logits = self.audio_decoder( output.logits ) + """ + logits = [ logit for logit in output.logits ] + logits_aux = None + + grouped_logits = {} + + for batch_index in range( batch_size ): + classifier_level = classifier_levels[batch_index] + if classifier_level.startswith("AR:") or classifier_level.startswith("NAR:"): + classifier_level = "audio" + + if classifier_level not in ["audio", "phn", "text", "len"]: + continue - for batch_index in range( batch_size ): - classifier_level = classifier_levels[batch_index] - if classifier_level.startswith("AR:") or classifier_level.startswith("NAR:"): - classifier_level = "audio" + if classifier_level not in grouped_logits: + grouped_logits[classifier_level] = [] + + grouped_logits[classifier_level].append(batch_index) - if classifier_level not in ["audio", "phn", "text", "len"]: - continue - - if classifier_level not in grouped_logits: - grouped_logits[classifier_level] = [] - - grouped_logits[classifier_level].append(batch_index) + for classifier_level, decoders_indices in grouped_logits.items(): + if classifier_level == "audio": + head = self.audio_decoder + elif classifier_level == "phn": + head = self.phn_decoder + elif classifier_level == "text": + head = self.text_decoder + elif classifier_level == "len": + head = self.len_decoder - for classifier_level, decoders_indices in grouped_logits.items(): - if classifier_level == "audio": - head = self.audio_decoder - elif classifier_level == "phn": - head = self.phn_decoder - elif classifier_level == "text": - head = 
self.text_decoder - elif classifier_level == "len": - head = self.len_decoder - - decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ]) - decoders_logits = head( decoders_logits ) - for batch_index, logit in zip( decoders_indices, decoders_logits ): - logits[batch_index] = logit + decoders_logits = torch.stack([ logits[batch_index] for batch_index in decoders_indices ]) + decoders_logits = head( decoders_logits ) + for batch_index, logit in zip( decoders_indices, decoders_logits ): + logits[batch_index] = logit + """ # Remove padding - logits = [ hi[..., :li, :] for hi, li in zip(logits, map(len, x_list)) ] + logits = [ logit[..., :l, :] for logit, l in zip(logits, map(len, x_list)) ] if not training: loss = None @@ -1235,9 +1246,26 @@ class Base_V2(nn.Module): self.loss = None self.stats = None + # grab duration if no resp is provided + if aux_lens[0][2] == 0: + # do duration prediction + logits_aux = self.len_decoder( output.logits ) + # only keep the input + logits_aux = [ logit[..., aux_len[2], :1] for logit, aux_len in zip(logits_aux, aux_lens) ] + + logits = logits_aux + # compute loss if the target is given else: - loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels ) + # do duration prediction + if self.len_parallel_training: + logits_aux = self.len_decoder( output.logits ) + # only keep the input + logits_aux = [ logit[..., aux_len[2], :1] for logit, aux_len in zip(logits_aux, aux_lens) ] + else: + logits_aux = None + + loss, stats = self.calc_loss( inputs=inputs, logits=logits, logits_aux=logits_aux, quant_levels=quant_levels ) # include any additional losses (for example: MoE router) if output.loss is not None:
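
The `_update_segmented_mask` change above bounds the response rows at `output_end` instead of letting them attend through to the padded end of the sequence. Below is a minimal sketch of the intended visibility pattern, assuming per-sample lengths `(text_len, prom_len, resp_len)`; all names are placeholders and this is an illustration of the idea, not the repo's implementation:

    # sketch: text tokens see only text, prompt tokens see text + prompt,
    # response tokens see text + prompt + response (and nothing past resp_end)
    import torch

    def build_segmented_mask(aux_lens: torch.Tensor, seq_len: int) -> torch.Tensor:
        bsz = aux_lens.shape[0]
        mask = torch.zeros(bsz, 1, seq_len, seq_len)
        for b, (text_len, prom_len, resp_len) in enumerate(aux_lens.tolist()):
            text_end = text_len
            prom_end = text_end + prom_len
            resp_end = prom_end + resp_len
            mask[b, 0, :text_end, :text_end] = 1.0          # text -> text
            mask[b, 0, text_end:prom_end, :prom_end] = 1.0  # prom -> text + prom
            mask[b, 0, prom_end:resp_end, :resp_end] = 1.0  # resp -> text + prom + resp
        return mask

    aux_lens = torch.tensor([[4, 6, 10], [3, 5, 0]])
    print(build_segmented_mask(aux_lens, seq_len=24)[0, 0])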
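
The `len_parallel_training` path adds an auxiliary duration loss: the `len_decoder` output is regressed against the number of response frames with MSE, scaled by a small `len_factor`, alongside the usual token losses. A toy sketch of that idea, with hypothetical names (`hidden`, `len_head`, `resp_durations`) and a different pooling position than the patch uses:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)

    batch_size, seq_len, d_model = 2, 128, 16
    len_factor = 0.01  # small weight so the auxiliary loss doesn't dominate

    hidden = torch.randn(batch_size, seq_len, d_model)  # stand-in transformer outputs
    len_head = torch.nn.Linear(d_model, 1)              # stand-in for the len decoder head
    resp_durations = torch.tensor([75.0, 100.0])        # ground-truth resp lengths in frames

    # read one scalar duration prediction per sample
    # (here from the last position; the patch indexes relative to aux_lens instead)
    pred_len = len_head(hidden[:, -1, :]).squeeze(-1)

    len_loss = F.mse_loss(pred_len, resp_durations) * len_factor
    print(len_loss)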
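
The new "flex" attention mode feeds the additive attention mask through a `score_mod` closure. A rough usage sketch, assuming PyTorch >= 2.5 and toy shapes; the bias tensor and its values are made up for illustration:

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    bsz, n_heads, q_len, head_dim = 1, 2, 8, 16
    q = torch.randn(bsz, n_heads, q_len, head_dim)
    k = torch.randn(bsz, n_heads, q_len, head_dim)
    v = torch.randn(bsz, n_heads, q_len, head_dim)

    # additive mask: 0 where attention is allowed, -inf where it is blocked
    x_mask = torch.zeros(bsz, 1, q_len, q_len)
    x_mask[..., 4:] = float("-inf")  # e.g. hide the tail keys from every query

    def score_mod(score, b, h, q_idx, kv_idx):
        # add the per-(batch, query, key) bias before softmax
        return score + x_mask[b][0][q_idx][kv_idx]

    out = flex_attention(q, k, v, score_mod=score_mod, scale=head_dim ** -0.5)
    print(out.shape)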