oops
This commit is contained in:
parent
1d460b9fe3
commit
3ef8894290
|
@ -255,7 +255,7 @@ class AR_NAR(Base):
|
|||
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
|
||||
start_noise = sampling_kwargs.get("denoise_start", 0.0)
|
||||
end_noise = sampling_kwargs.get("denoise_end", 1.0)
|
||||
remasking = sampling_kwargs.get("remasking", False)
|
||||
remasking = sampling_kwargs.get("remasking", True)
|
||||
max_steps = math.floor(max_steps * (end_noise - start_noise))
|
||||
|
||||
len_list = [ clamp(l, min_length, max_length) for l in len_list ]
|
||||
|
@ -297,7 +297,7 @@ class AR_NAR(Base):
|
|||
# get noise level, per cosine scheduling
|
||||
noise_p = math.cos( timestep * math.pi * 0.5 )
|
||||
# proportion of tokens to remask
|
||||
remask_p = 1.0 / max_steps if remasking else 0
|
||||
remask_p = 1.0 / (max_steps * 2) if remasking else 0
|
||||
# pick the worst scoring tokens to mask off
|
||||
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
|
||||
# mask off inputs
|
||||
|
|
Loading…
Reference in New Issue
Block a user