From a1184586efa011cf23f419b0837b234bde325ab5 Mon Sep 17 00:00:00 2001
From: mrq
Date: Mon, 31 Mar 2025 20:59:13 -0500
Subject: [PATCH] should never have trusted mse_loss, it never works

---
 vall_e/config.py           |  1 +
 vall_e/models/ar_nar_v2.py |  4 +++-
 vall_e/models/base_v2.py   | 35 ++++++++++++++++++++++++++---------
 3 files changed, 30 insertions(+), 10 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index 9b8c251..3422caa 100644
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -297,6 +297,7 @@ class ModelExperimentalSettings:
 	# * NAR-demask would semi-doubly train for AR
 	# * the model wouldn't also need to learn when to predict the token in place
 	len_parallel_training: bool = True # used for version >= 7, computes len loss alongside normal training through using the input sequence (surely nothing can go wrong)
+	len_use_logits: bool = False # whether to treat duration prediction as a nll/logits task or use a raw, continuous float
 	len_loss_factor: float = 0.00001 # loss factor for len calculation, very small because it mucks up loss scaling under float16
 	parallel_attention_mask_dropout: float = 0.0 # randomly sets to a causal attention mask when training NAR-len demasking
 	layer_dropout_p: float = 0.0 # performs layer dropout, which I readded because it might actually help since the reference model had this at 0.1
diff --git a/vall_e/models/ar_nar_v2.py b/vall_e/models/ar_nar_v2.py
index 7653687..ecb744b 100644
--- a/vall_e/models/ar_nar_v2.py
+++ b/vall_e/models/ar_nar_v2.py
@@ -959,8 +959,10 @@ def example_usage():
 		phns_list, proms_list, resp_list, task_list = sample_data( task )
 
 		if task == "tts-nar":
-			# len_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
+			len_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["len"], max_steps=5, temperature=0.0 )
+			print( len_list )
 			len_list = [ r.shape[0] for r in resp_list ]
+			print( len_list )
 			resps_list = engine( phns_list=phns_list, proms_list=proms_list, len_list=len_list )
 		else:
 			resps_list = engine( phns_list=phns_list, proms_list=proms_list, task_list=["tts"], max_duration=steps, temperature=1.0 )
diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index 3c1dc52..9e4a06d 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -293,6 +293,7 @@ class Base_V2(nn.Module):
 
 		resp_parallel_training = config.experimental.resp_parallel_training if config is not None else True
 		len_parallel_training = config.experimental.len_parallel_training if config is not None else False
+		len_use_logits = config.experimental.len_use_logits if config is not None else True
 		predict_causally = config.experimental.predict_causally if config is not None else False
 		monolithic_audio_encoder = config.experimental.monolithic_audio_encoder if config is not None else False
 		audio_level_loss_factors = config.experimental.audio_level_loss_factors if config is not None else "auto"
@@ -384,6 +385,7 @@
 		self.predict_causally = predict_causally
 		self.resp_parallel_training = resp_parallel_training
 		self.len_parallel_training = len_parallel_training
+		self.len_use_logits = len_use_logits
 		self.unified_position_ids = unified_position_ids
 		self.inject_timestep_embedding = False # results in bad output
 		self.masking_ratio = masking_ratio
@@ -433,7 +435,7 @@
 			self.n_resp_levels,
 			use_ln=per_level_normalization,
 		)
-		self.len_decoder = AuxDecoder( d_model, 1 )
+		self.len_decoder = AuxDecoder( d_model, 1 if not len_use_logits else (10 * 5) )
 		self.phn_decoder = AuxDecoder( d_model, n_phn_tokens )
 		self.text_decoder = AuxDecoder( d_model, n_text_tokens )
 
@@ -870,6 +872,10 @@ class Base_V2(nn.Module):
 		return input
 
+	# handles "tokenizing" an integer
+	def tokenize_duration( seq_lens, device, dtype=torch.int64 ):
+		return torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in seq_lens], device=device, dtype=dtype)
+
 		k_lo, k_hi = 1, 20
 
 		level_loss_factors = self.audio_level_loss_factors
 
@@ -1069,11 +1075,13 @@
 		if logits_aux is not None:
 			len_factor = self.len_loss_factor # 0.001 # to-do: user adjustable (it's really small because mse_loss causes wildly bigly losses)
 			aux_loss_logit = torch.cat( logits_aux )
-			#aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=torch.int64 )
-			#loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
-
-			aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second
-			loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor
+
+			if self.len_use_logits:
+				aux_loss_target = torch.tensor( [ [ int(i) for i in str( l ).zfill(5) ] for l in resp_durations ], device=aux_loss_logit.device, dtype=torch.int64).squeeze(0)
+				loss['len'] = F.cross_entropy( aux_loss_logit, aux_loss_target ) * len_factor
+			else:
+				aux_loss_target = torch.tensor( resp_durations, device=aux_loss_logit.device, dtype=aux_loss_logit.dtype ) / self.audio_frames_per_second
+				loss['len'] = F.mse_loss( aux_loss_logit, aux_loss_target ) * len_factor
 
 		return LossStats(loss, stats)
 
@@ -1225,17 +1233,26 @@
 			# do duration prediction
 			logits_aux = self.len_decoder( output.logits )
 			# it's more accurate this way
-			logits_aux = [ logit[..., -1, :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
+			logits_aux = [ logit[..., -1, :] for logit, aux_len in zip(logits_aux, aux_lens) ]
+			# reshape logits
+			if self.len_use_logits:
+				# get tokens
+				logits_aux = [ logit.view(5, 10).argmax(dim=-1) for logit in logits_aux ]
+				# stitch
+				logits_aux = [ int("".join([ str(t.item()) for t in logit ])) / self.audio_frames_per_second for logit in logits_aux ]
 
 			logits = logits_aux
 
-		# compute loss if the target is given
 		else:
 			# do duration prediction
 			if self.len_parallel_training:
 				logits_aux = self.len_decoder( output.logits )
 				# only keep the input
-				logits_aux = [ logit[..., aux_len[0] + aux_len[1], :1] for logit, aux_len in zip(logits_aux, aux_lens) ]
+				logits_aux = [ logit[..., aux_len[0] + aux_len[1], :] for logit, aux_len in zip(logits_aux, aux_lens) ]
+
+				# reshape logits
+				if self.len_use_logits:
+					logits_aux = [ logit.view(5, 10) for logit in logits_aux ]
 			else:
 				logits_aux = None
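-- 
Note on the change: instead of regressing a single float for the utterance duration (the mse_loss path), the patch zero-pads the target frame count to 5 decimal digits and classifies each digit over 10 classes, which is where the (10 * 5) output width on len_decoder comes from. Below is a minimal, self-contained sketch of that digit-classification scheme; the helper names and the toy Linear head are illustrative stand-ins, not code from the repo.

	# sketch: duration prediction as 5-digit / 10-class classification
	import torch
	import torch.nn.functional as F

	def duration_to_digits( seq_lens, device=None ):
		# mirrors the patch's tokenize_duration: 123 -> [0, 0, 1, 2, 3]
		return torch.tensor( [ [ int(d) for d in str( l ).zfill(5) ] for l in seq_lens ], device=device, dtype=torch.int64 )

	def digits_to_duration( digits ):
		# inverse of the above: [0, 0, 1, 2, 3] -> 123
		return int( "".join( str( d.item() ) for d in digits ) )

	batch, d_model = 4, 32
	head = torch.nn.Linear( d_model, 10 * 5 )     # stand-in for the len_decoder head
	hidden = torch.randn( batch, d_model )        # stand-in for the model's hidden states
	durations = [ 75, 120, 1042, 31 ]             # target lengths in audio frames

	logits = head( hidden ).view( batch, 5, 10 )  # (batch, digit position, digit class)
	targets = duration_to_digits( durations )     # (batch, 5)

	# F.cross_entropy expects the class dim second, so train on (batch, 10, 5)
	loss = F.cross_entropy( logits.transpose( 1, 2 ), targets )

	# inference: argmax each digit, then stitch the digits back into an integer
	pred_frames = [ digits_to_duration( row ) for row in logits.argmax( dim=-1 ) ]
	print( loss.item(), pred_frames )

Framing the duration as cross-entropy over digit tokens keeps loss['len'] on the same scale as the model's other NLL terms, sidestepping the float16 loss-scaling problem that forced len_loss_factor down to 0.00001 for the mse_loss path.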