when you uhh when you for once use your main rig to test and forgot to and when you port things back over
This commit is contained in:
parent
d9e18037cc
commit
f8e1d110dc
|
@ -457,9 +457,9 @@ class AR_NAR(Base):
|
||||||
resps_list = [ torch.where( masked, input_ids, resps ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
|
resps_list = [ torch.where( masked, input_ids, resps ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
|
||||||
# update scores, only updating tokens that were masked off, and force keeping unmasked tokens
|
# update scores, only updating tokens that were masked off, and force keeping unmasked tokens
|
||||||
if score_masked_only:
|
if score_masked_only:
|
||||||
scores = [ torch.where( masked, scores.t(), smallest_score ) for masked, scores in zip( is_masked, sampled.scores ) ]
|
scores = [ torch.where( masked, scores.t(), smallest_score ) for masked, scores in zip( is_masked, unfiltered_sampled.scores ) ]
|
||||||
else:
|
else:
|
||||||
scores = [ scores for scores in sampled.scores ]
|
scores = [ scores for scores in unfiltered_sampled.scores ]
|
||||||
|
|
||||||
return resps_list
|
return resps_list
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user