Have the NAR's loss not ignore the text prompt; I imagine this should help the NAR, and it may explain why it has always had a bit of an issue with training.

This commit is contained in:
mrq 2023-09-15 19:08:44 -05:00
parent 4aef798135
commit 22ffaf3a33
2 changed files with 6 additions and 12 deletions

View File

@ -415,16 +415,14 @@ class Base(nn.Module):
# process each batch
for i in range(len(text_prom_list)):
# for the AR, shift the text/input prompt and target prompt into the future by 1, and ignore the rolled back text token
# for the AR and NAR, shift the text/input prompt into the future by 1, and ignore the rolled back token
text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
text_prom_list[i][-1] = self.ignore_index
# for the AR, shift the target response into the future by 1, and ignore the rolled back text token
if quant_levels is None or quant_levels[i] == 0:
text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
targ_list[i] = targ_list[i].clone().roll(-1, dims=0) # clone ensures it's not an aliased copy/view of resps
text_prom_list[i][-1] = self.ignore_index
targ_list[i][-1] = self.stop_token
# for the NAR, ignore completely computing the loss against the text prompt
else:
text_prom_list[i][:] = self.ignore_index
# create the new target sequence to compute the loss against
target = torch.cat( self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep ) )

View File

@ -102,11 +102,7 @@ def run_eval(engines, disabled_engines, eval_name, dl):
min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
ref_audio = ref_audio[..., 0:min_length]
hyp_audio = hyp_audio[..., 0:min_length]
try:
stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
except Exception as e:
stats['loss'].append(0)
print(traceback.format_exc())
stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
processed = 0
while processed < cfg.evaluation.size: