have loss for the NAR not-ignore the text prompt, I imagine this should help the NAR and explain why it's always had a bit of an issue with training
This commit is contained in:
parent
4aef798135
commit
22ffaf3a33
|
@ -415,16 +415,14 @@ class Base(nn.Module):
|
|||
|
||||
# process each batch
|
||||
for i in range(len(text_prom_list)):
|
||||
# for the AR, shift the text/input prompt and target prompt into the future by 1, and ignore the rolled back text token
|
||||
# for the AR and NAR, shift the text/input prompt into the future by 1, and ignore the rolled back token
|
||||
text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
|
||||
text_prom_list[i][-1] = self.ignore_index
|
||||
|
||||
# for the AR, shift the target response into the future by 1, and ignore the rolled back text token
|
||||
if quant_levels is None or quant_levels[i] == 0:
|
||||
text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
|
||||
targ_list[i] = targ_list[i].clone().roll(-1, dims=0) # clone ensures it's not an aliased copy/view of resps
|
||||
|
||||
text_prom_list[i][-1] = self.ignore_index
|
||||
targ_list[i][-1] = self.stop_token
|
||||
# for the NAR, ignore completely computing the loss against the text prompt
|
||||
else:
|
||||
text_prom_list[i][:] = self.ignore_index
|
||||
|
||||
# create the new target sequence to compute the loss against
|
||||
target = torch.cat( self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep ) )
|
||||
|
|
|
@ -102,11 +102,7 @@ def run_eval(engines, disabled_engines, eval_name, dl):
|
|||
min_length = min( ref_audio.shape[-1], hyp_audio.shape[-1] )
|
||||
ref_audio = ref_audio[..., 0:min_length]
|
||||
hyp_audio = hyp_audio[..., 0:min_length]
|
||||
try:
|
||||
stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
|
||||
except Exception as e:
|
||||
stats['loss'].append(0)
|
||||
print(traceback.format_exc())
|
||||
stats['loss'].append(mel_stft_loss(hyp_audio[None, :, :], ref_audio[None, :, :]).item())
|
||||
|
||||
processed = 0
|
||||
while processed < cfg.evaluation.size:
|
||||
|
|
Loading…
Reference in New Issue
Block a user