diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py index e376209..155a9c3 100644 --- a/vall_e/models/base_v2.py +++ b/vall_e/models/base_v2.py @@ -941,24 +941,26 @@ class Base_V2(nn.Module): if name != task_outputs.get(task_type, name): continue + sequence = token if token.dim() == 1: loss_factor = self.loss_factor(name) if loss_factor == 0.0: continue logit = logits[batch_index][start:end] - if causal or self.predict_causally: - l = self.causal_size - logit = logit[..., :-l, :] # shift the target so that token n... - token = token[..., l:] # ...predicts token n + 1 - + """ if self.logit_normalization: logit = logit_normalization( logit, self.logit_normalization ) """ - loss_targets.append( token.long() ) - loss_logits.append( logit ) + if causal or self.predict_causally: + l = self.causal_size + loss_targets.append( token[l:].long() ) # shift the target so that token n... + loss_logits.append( logit[..., :-l, :] ) # ...predicts token n + 1 + else: + loss_targets.append( token.long() ) + loss_logits.append( logit ) loss_factors.append( loss_factor ) loss_names.append( name ) else: @@ -969,18 +971,21 @@ class Base_V2(nn.Module): continue logit = logits[batch_index][level][start:end] - if causal or self.predict_causally: - l = self.causal_size - logit = logit[..., :-l, :] # shift the target so that token n... - token = token[..., l:] # ...predicts token n + 1 """ if self.logit_normalization: logit = logit_normalization( logit, self.logit_normalization ) """ - loss_targets.append( token[:, level].long() ) - loss_logits.append( logit ) + if causal or self.predict_causally: + l = self.causal_size + loss_targets.append( token[l:, level].long() ) # shift the target so that token n... + loss_logits.append( logit[..., :-l, :] ) # ...predicts token n + 1 + else: + loss_targets.append( token[:, level].long() ) + loss_logits.append( logit ) + + loss_factors.append( level_loss_factors[level] ) loss_names.append( name )