This commit is contained in:
mrq 2025-05-05 13:03:44 -05:00
parent 5fe01ffc6c
commit b2b243e7e7
2 changed files with 4 additions and 4 deletions

View File

@ -298,7 +298,7 @@ class AR_NAR(Base):
end_noise = sampling_kwargs.get("denoise_end", 1.0)
max_steps = math.floor(max_steps * (end_noise - start_noise))
largest_score = 1.0
largest_score = 1.0 # to-do: validate that the scores I do return are normalized via softmax
smallest_score = 0.0 # -float("inf")
score_masked_only = sampling_kwargs.pop("sampling_scores_masked_only", False)
@ -457,7 +457,7 @@ class AR_NAR(Base):
resps_list = [ torch.where( masked, input_ids, resps ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ]
# update scores, only updating tokens that were masked off, and force keeping unmasked tokens
if score_masked_only:
scores = [ torch.where( masked, scores.t(), smallest_score ) for masked, scores in zip( is_masked, unfiltered_sampled.scores ) ]
scores = [ torch.where( masked, scores.t(), largest_score ) for masked, scores in zip( is_masked, unfiltered_sampled.scores ) ]
else:
scores = [ scores for scores in unfiltered_sampled.scores ]

View File

@ -257,7 +257,7 @@ class AR_NAR_V2(Base_V2):
end_noise = sampling_kwargs.get("denoise_end", 1.0)
max_steps = math.floor(max_steps * (end_noise - start_noise))
largest_score = 1.0
largest_score = 1.0 # to-do: validate that the scores I do return are normalized via softmax
smallest_score = 0.0 # -float("inf")
score_masked_only = sampling_kwargs.pop("sampling_scores_masked_only", False)
@ -361,7 +361,7 @@ class AR_NAR_V2(Base_V2):
resps_list = [ torch.where( masked, ids.t(), resps ).to(torch.int16) for masked, ids, resps in zip( is_masked, sampled.ids, resps_list ) ]
# update scores, only updating tokens that were masked off, and force keeping unmasked tokens
if score_masked_only:
scores = [ torch.where( masked, scores.t(), smallest_score ) for masked, scores in zip( is_masked, sampled.scores ) ]
scores = [ torch.where( masked, scores.t(), largest_score ) for masked, scores in zip( is_masked, sampled.scores ) ]
else:
scores = [ scores.t() for scores in sampled.scores ]