From 61de653ad990bfcc11be04176fabc00860c026fa Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 19 Mar 2025 14:20:19 -0500
Subject: [PATCH] now causal training should work again

---
 vall_e/models/base_v2.py | 31 ++++++++++++++++++-------------
 1 file changed, 18 insertions(+), 13 deletions(-)

diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index e376209..155a9c3 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -941,24 +941,26 @@ class Base_V2(nn.Module):
 				if name != task_outputs.get(task_type, name):
 					continue
 
+				sequence = token
 				if token.dim() == 1:
 					loss_factor = self.loss_factor(name)
 					if loss_factor == 0.0:
 						continue
 
 					logit = logits[batch_index][start:end]
-					if causal or self.predict_causally:
-						l = self.causal_size
-						logit = logit[..., :-l, :] # shift the target so that token n...
-						token = token[..., l:] # ...predicts token n + 1
-
+
 					"""
 					if self.logit_normalization:
 						logit = logit_normalization( logit, self.logit_normalization )
 					"""
 
-					loss_targets.append( token.long() )
-					loss_logits.append( logit )
+					if causal or self.predict_causally:
+						l = self.causal_size
+						loss_targets.append( token[l:].long() ) # shift the target so that token n...
+						loss_logits.append( logit[..., :-l, :] ) # ...predicts token n + 1
+					else:
+						loss_targets.append( token.long() )
+						loss_logits.append( logit )
 					loss_factors.append( loss_factor )
 					loss_names.append( name )
 				else:
@@ -969,18 +971,21 @@ class Base_V2(nn.Module):
 							continue
 
 						logit = logits[batch_index][level][start:end]
-						if causal or self.predict_causally:
-							l = self.causal_size
-							logit = logit[..., :-l, :] # shift the target so that token n...
-							token = token[..., l:] # ...predicts token n + 1
 						"""
 						if self.logit_normalization:
 							logit = logit_normalization( logit, self.logit_normalization )
 						"""
 
-						loss_targets.append( token[:, level].long() )
-						loss_logits.append( logit )
+						if causal or self.predict_causally:
+							l = self.causal_size
+							loss_targets.append( token[l:, level].long() ) # shift the target so that token n...
+							loss_logits.append( logit[..., :-l, :] ) # ...predicts token n + 1
+						else:
+							loss_targets.append( token[:, level].long() )
+							loss_logits.append( logit )
+
+						loss_factors.append( level_loss_factors[level] )
 						loss_names.append( name )
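
Editor's note (not part of the patch): the slices the patch moves into the append calls implement the usual next-token shift, so that the logit emitted at position n is scored against the token at position n + causal_size, and the non-causal branch keeps the full, unshifted sequence. The sketch below illustrates that pairing in isolation under assumed toy shapes; the tensor names are hypothetical and only the slicing mirrors the diff.

	# Minimal sketch of the causal target/logit shift, assuming one utterance,
	# sequence length T, vocab size V, and causal_size l. Names are illustrative.
	import torch
	import torch.nn.functional as F

	T, V, l = 8, 32, 1                      # sequence length, vocab size, causal_size
	logit = torch.randn(T, V)               # per-position logits for one sequence
	token = torch.randint(0, V, (T,))       # ground-truth token ids

	# Shift as in the patch: logits at positions [0, T - l) predict tokens at [l, T).
	target = token[l:].long()               # mirrors loss_targets.append( token[l:].long() )
	pred   = logit[..., :-l, :]             # mirrors loss_logits.append( logit[..., :-l, :] )

	assert pred.shape[0] == target.shape[0] # both have length T - l
	loss = F.cross_entropy(pred, target)    # standard next-token cross-entropy
	print(loss.item())

Keeping both slices next to the append calls (rather than mutating token and logit in place beforehand) means the shift can no longer leak into the non-causal path or into anything that later reads the unshifted sequence, which is what the commit subject refers to.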