From 22ffaf3a33d77b67d815d8541f3c8b0f2402d205 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 15 Sep 2023 19:08:44 -0500
Subject: [PATCH] have the loss for the NAR not ignore the text prompt; I
 imagine this should help the NAR and explain why it's always had a bit of an
 issue with training

---
 vall_e/models/base.py | 12 +++++-------
 vall_e/train.py       |  6 +-----
 2 files changed, 6 insertions(+), 12 deletions(-)

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index f8aacf6..69377dd 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -415,16 +415,14 @@ class Base(nn.Module):
 
 		# process each batch
 		for i in range(len(text_prom_list)):
-			# for the AR, shift the text/input prompt and target prompt into the future by 1, and ignore the rolled back text token
+			# for the AR and NAR, shift the text/input prompt into the future by 1, and ignore the rolled back token
+			text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
+			text_prom_list[i][-1] = self.ignore_index
+
+			# for the AR, shift the target response into the future by 1, and ignore the rolled back text token
 			if quant_levels is None or quant_levels[i] == 0:
-				text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
 				targ_list[i] = targ_list[i].clone().roll(-1, dims=0) # clone ensures it's not an aliased copy/view of resps
-
-				text_prom_list[i][-1] = self.ignore_index
 				targ_list[i][-1] = self.stop_token
-			# for the NAR, ignore completely computing the loss against the text prompt
-			else:
-				text_prom_list[i][:] = self.ignore_index
 
 		# create the new target sequence to compute the loss against
 		target = torch.cat( self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep ) )
diff --git a/vall_e/train.py b/vall_e/train.py
index 289b925..8617ea5 100755
--- a/vall_e/train.py
+++ b/vall_e/train.py
@@ -102,11 +102,7 @@ def run_eval(engines, disabled_engines, eval_name, dl):
 			min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
 			ref_audio = ref_audio[..., 0:min_length]
 			hyp_audio = hyp_audio[..., 0:min_length]
-			try:
-				stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
-			except Exception as e:
-				stats['loss'].append(0)
-				print(traceback.format_exc())
+			stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
 
 	processed = 0
 	while processed < cfg.evaluation.size:
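
Note (not part of the patch): a minimal, self-contained sketch of the masking mechanism the change relies on. The token values and n_tokens below are hypothetical, and ignore_index is assumed to be PyTorch's default of -100; positions set to ignore_index are skipped by F.cross_entropy, so only the rolled-back token is excluded while the rest of the shifted text prompt contributes to the loss for both the AR and the NAR.

    # sketch only: illustrates the roll + ignore_index masking, not the actual model code
    import torch
    import torch.nn.functional as F

    ignore_index = -100   # assumed; PyTorch's default ignore_index for cross_entropy
    n_tokens = 8          # hypothetical vocabulary size

    # hypothetical text-prompt token ids and per-position logits
    text_prom = torch.tensor([3, 1, 4, 1, 5])
    logits = torch.randn(text_prom.shape[0], n_tokens)

    # shift the prompt into the future by 1 and mask the rolled-back token,
    # mirroring the code path that now runs for the AR and the NAR alike
    target = text_prom.roll(-1, dims=0)
    target[-1] = ignore_index

    # masked positions contribute nothing to the loss
    loss = F.cross_entropy(logits, target, ignore_index=ignore_index)
    print(loss)  # averaged over the unmasked positions only

The practical effect of the patch is that the NAR no longer blanks the entire text prompt out of the loss; it now supervises the shifted text prompt the same way the AR does, while only the AR additionally shifts the target response and appends the stop token.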