when you uhh when you for once use your main rig to test and forgot to and when you port things back over
This commit is contained in:
parent
d9e18037cc
commit
f8e1d110dc
|
@ -457,9 +457,9 @@ class AR_NAR(Base):
|
|||
resps_list = [ torch.where( masked, input_ids, resps ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
|
||||
# update scores, only updating tokens that were masked off, and force keeping unmasked tokens
|
||||
if score_masked_only:
|
||||
scores = [ torch.where( masked, scores.t(), smallest_score ) for masked, scores in zip( is_masked, sampled.scores ) ]
|
||||
scores = [ torch.where( masked, scores.t(), smallest_score ) for masked, scores in zip( is_masked, unfiltered_sampled.scores ) ]
|
||||
else:
|
||||
scores = [ scores for scores in sampled.scores ]
|
||||
scores = [ scores for scores in unfiltered_sampled.scores ]
|
||||
|
||||
return resps_list
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user