From 3ef889429077358244c2e7cfe022940e0515dad2 Mon Sep 17 00:00:00 2001 From: mrq Date: Sun, 8 Dec 2024 15:24:21 -0600 Subject: [PATCH] oops --- vall_e/models/ar_nar.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 965a19a..ec2d611 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -255,7 +255,7 @@ class AR_NAR(Base): cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75) start_noise = sampling_kwargs.get("denoise_start", 0.0) end_noise = sampling_kwargs.get("denoise_end", 1.0) - remasking = sampling_kwargs.get("remasking", False) + remasking = sampling_kwargs.get("remasking", True) max_steps = math.floor(max_steps * (end_noise - start_noise)) len_list = [ clamp(l, min_length, max_length) for l in len_list ] @@ -297,7 +297,7 @@ class AR_NAR(Base): # get noise level, per cosine scheduling noise_p = math.cos( timestep * math.pi * 0.5 ) # proportion of tokens to remask - remask_p = 1.0 / max_steps if remasking else 0 + remask_p = 1.0 / (max_steps * 2) if remasking else 0 # pick the worst scoring tokens to mask off masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ] # mask off inputs