when you uhh when you for once use your main rig to test and forgot to and when you port things back over

This commit is contained in:
mrq 2025-04-18 20:49:00 -05:00
parent d9e18037cc
commit f8e1d110dc

View File

@ -457,9 +457,9 @@ class AR_NAR(Base):
resps_list = [ torch.where( masked, input_ids, resps ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
# update scores, only updating tokens that were masked off, and force keeping unmasked tokens
if score_masked_only:
scores = [ torch.where( masked, scores.t(), smallest_score ) for masked, scores in zip( is_masked, sampled.scores ) ]
scores = [ torch.where( masked, scores.t(), smallest_score ) for masked, scores in zip( is_masked, unfiltered_sampled.scores ) ]
else:
scores = [ scores for scores in sampled.scores ]
scores = [ scores for scores in unfiltered_sampled.scores ]
return resps_list