diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 813fdda..bf89fdc 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -457,9 +457,9 @@ class AR_NAR(Base):
 			resps_list = [ torch.where( masked, input_ids, resps ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
 			# update scores, only updating tokens that were masked off, and force keeping unmasked tokens
 			if score_masked_only:
-				scores = [ torch.where( masked, scores.t(), smallest_score ) for masked, scores in zip( is_masked, sampled.scores ) ]
+				scores = [ torch.where( masked, scores.t(), smallest_score ) for masked, scores in zip( is_masked, unfiltered_sampled.scores ) ]
 			else:
-				scores = [ scores for scores in sampled.scores ]
+				scores = [ scores for scores in unfiltered_sampled.scores ]
 
 		return resps_list