From f8e1d110dc9708d0f37ae78b0bf382bdea098df4 Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 18 Apr 2025 20:49:00 -0500 Subject: [PATCH] when you uhh when you for once use your main rig to test and forgot to and when you port things back over --- vall_e/models/ar_nar.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 813fdda..bf89fdc 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -457,9 +457,9 @@ class AR_NAR(Base): resps_list = [ torch.where( masked, input_ids, resps ).to(torch.int16) for masked, input_ids, resps in zip( is_masked, sampled_ids, resps_list ) ] # update scores, only updating tokens that were masked off, and force keeping unmasked tokens if score_masked_only: - scores = [ torch.where( masked, scores.t(), smallest_score ) for masked, scores in zip( is_masked, sampled.scores ) ] + scores = [ torch.where( masked, scores.t(), smallest_score ) for masked, scores in zip( is_masked, unfiltered_sampled.scores ) ] else: - scores = [ scores for scores in sampled.scores ] + scores = [ scores for scores in unfiltered_sampled.scores ] return resps_list