diff --git a/vall_e/models/base_v2.py b/vall_e/models/base_v2.py
index e376209..155a9c3 100644
--- a/vall_e/models/base_v2.py
+++ b/vall_e/models/base_v2.py
@@ -941,24 +941,26 @@ class Base_V2(nn.Module):
 				if name != task_outputs.get(task_type, name):
 					continue
 
+				sequence = token
 				if token.dim() == 1:
 					loss_factor = self.loss_factor(name)
 					if loss_factor == 0.0:
 						continue
 
 					logit = logits[batch_index][start:end]
-					if causal or self.predict_causally:
-						l = self.causal_size
-						logit = logit[..., :-l, :] # shift the target so that token n...
-						token = token[..., l:] # ...predicts token n + 1
-
+
 					"""
 					if self.logit_normalization:
 						logit = logit_normalization( logit, self.logit_normalization )
 					"""
 
-					loss_targets.append( token.long() )
-					loss_logits.append( logit )
+					if causal or self.predict_causally:
+						l = self.causal_size
+						loss_targets.append( token[l:].long() ) # shift the target so that token n...
+						loss_logits.append( logit[..., :-l, :] ) # ...predicts token n + 1
+					else:
+						loss_targets.append( token.long() )
+						loss_logits.append( logit )
 					loss_factors.append( loss_factor )
 					loss_names.append( name )
 				else:
@@ -969,18 +971,19 @@
 							continue
 
 						logit = logits[batch_index][level][start:end]
-						if causal or self.predict_causally:
-							l = self.causal_size
-							logit = logit[..., :-l, :] # shift the target so that token n...
-							token = token[..., l:] # ...predicts token n + 1
 
 						"""
 						if self.logit_normalization:
 							logit = logit_normalization( logit, self.logit_normalization )
 						"""
 
-						loss_targets.append( token[:, level].long() )
-						loss_logits.append( logit )
+						if causal or self.predict_causally:
+							l = self.causal_size
+							loss_targets.append( token[l:, level].long() ) # shift the target so that token n...
+							loss_logits.append( logit[..., :-l, :] ) # ...predicts token n + 1
+						else:
+							loss_targets.append( token[:, level].long() )
+							loss_logits.append( logit )
 						loss_factors.append( level_loss_factors[level] )
 						loss_names.append( name )
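
For reference, a minimal standalone sketch of the alignment both patched branches now apply at append time, i.e. slicing the targets and logits when building the loss lists instead of mutating token and logit in place. Shapes, example values, and causal_size = 1 are illustrative assumptions, not part of the patch:

import torch
import torch.nn.functional as F

causal_size = 1
token = torch.tensor([5, 2, 7, 3])   # (seq_len,) target ids for one sequence
logit = torch.randn(4, 8)            # (seq_len, n_vocab) model outputs

# logits at positions 0..n-2 are scored against tokens at 1..n-1,
# so position n predicts token n + 1, and token itself is left untouched
targets = token[causal_size:].long()       # (seq_len - 1,)
shifted = logit[..., :-causal_size, :]     # (seq_len - 1, n_vocab)

loss = F.cross_entropy(shifted, targets)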