diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 965a19a..ec2d611 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -255,7 +255,7 @@ class AR_NAR(Base): cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75) start_noise = sampling_kwargs.get("denoise_start", 0.0) end_noise = sampling_kwargs.get("denoise_end", 1.0) - remasking = sampling_kwargs.get("remasking", False) + remasking = sampling_kwargs.get("remasking", True) max_steps = math.floor(max_steps * (end_noise - start_noise)) len_list = [ clamp(l, min_length, max_length) for l in len_list ] @@ -297,7 +297,7 @@ class AR_NAR(Base): # get noise level, per cosine scheduling noise_p = math.cos( timestep * math.pi * 0.5 ) # proportion of tokens to remask - remask_p = 1.0 / max_steps if remasking else 0 + remask_p = 1.0 / (max_steps * 2) if remasking else 0 # pick the worst scoring tokens to mask off masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ] # mask off inputs