This commit is contained in:
mrq 2024-12-08 15:24:21 -06:00
parent 1d460b9fe3
commit 3ef8894290

View File

@ -255,7 +255,7 @@ class AR_NAR(Base):
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75) cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
start_noise = sampling_kwargs.get("denoise_start", 0.0) start_noise = sampling_kwargs.get("denoise_start", 0.0)
end_noise = sampling_kwargs.get("denoise_end", 1.0) end_noise = sampling_kwargs.get("denoise_end", 1.0)
remasking = sampling_kwargs.get("remasking", False) remasking = sampling_kwargs.get("remasking", True)
max_steps = math.floor(max_steps * (end_noise - start_noise)) max_steps = math.floor(max_steps * (end_noise - start_noise))
len_list = [ clamp(l, min_length, max_length) for l in len_list ] len_list = [ clamp(l, min_length, max_length) for l in len_list ]
@ -297,7 +297,7 @@ class AR_NAR(Base):
# get noise level, per cosine scheduling # get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 ) noise_p = math.cos( timestep * math.pi * 0.5 )
# proportion of tokens to remask # proportion of tokens to remask
remask_p = 1.0 / max_steps if remasking else 0 remask_p = 1.0 / (max_steps * 2) if remasking else 0
# pick the worst scoring tokens to mask off # pick the worst scoring tokens to mask off
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ] masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
# mask off inputs # mask off inputs