oops
This commit is contained in:
parent
1d460b9fe3
commit
3ef8894290
|
@ -255,7 +255,7 @@ class AR_NAR(Base):
|
||||||
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
|
cfg_rescale = sampling_kwargs.pop("cfg_rescale", 0.75)
|
||||||
start_noise = sampling_kwargs.get("denoise_start", 0.0)
|
start_noise = sampling_kwargs.get("denoise_start", 0.0)
|
||||||
end_noise = sampling_kwargs.get("denoise_end", 1.0)
|
end_noise = sampling_kwargs.get("denoise_end", 1.0)
|
||||||
remasking = sampling_kwargs.get("remasking", False)
|
remasking = sampling_kwargs.get("remasking", True)
|
||||||
max_steps = math.floor(max_steps * (end_noise - start_noise))
|
max_steps = math.floor(max_steps * (end_noise - start_noise))
|
||||||
|
|
||||||
len_list = [ clamp(l, min_length, max_length) for l in len_list ]
|
len_list = [ clamp(l, min_length, max_length) for l in len_list ]
|
||||||
|
@ -297,7 +297,7 @@ class AR_NAR(Base):
|
||||||
# get noise level, per cosine scheduling
|
# get noise level, per cosine scheduling
|
||||||
noise_p = math.cos( timestep * math.pi * 0.5 )
|
noise_p = math.cos( timestep * math.pi * 0.5 )
|
||||||
# proportion of tokens to remask
|
# proportion of tokens to remask
|
||||||
remask_p = 1.0 / max_steps if remasking else 0
|
remask_p = 1.0 / (max_steps * 2) if remasking else 0
|
||||||
# pick the worst scoring tokens to mask off
|
# pick the worst scoring tokens to mask off
|
||||||
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
|
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
|
||||||
# mask off inputs
|
# mask off inputs
|
||||||
|
|
Loading…
Reference in New Issue
Block a user