forked from mrq/DL-Art-School
conditioning masking is random
This commit is contained in:
parent
e06ee1b6f3
commit
286918c581
|
@ -207,9 +207,10 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
|||
if custom_conditioning_fetcher is not None:
|
||||
cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
|
||||
else:
|
||||
if self.conditioning_masking > 0:
|
||||
if self.training and self.conditioning_masking > 0:
|
||||
cond_op_len = x.shape[-1]
|
||||
mask_len = int(cond_op_len * self.conditioning_masking)
|
||||
mask_prop = random.random() * self.conditioning_masking
|
||||
mask_len = int(cond_op_len * mask_prop)
|
||||
if mask_len > 0:
|
||||
start = random.randint(0, (cond_op_len-mask_len)) + cond_start
|
||||
conditioning_input[:,:,start:(start+mask_len)] = 0
|
||||
|
|
Loading…
Reference in New Issue
Block a user