forked from mrq/DL-Art-School
conditioning masking is random
This commit is contained in:
parent
e06ee1b6f3
commit
286918c581
|
@ -207,9 +207,10 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
||||||
if custom_conditioning_fetcher is not None:
|
if custom_conditioning_fetcher is not None:
|
||||||
cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
|
cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
|
||||||
else:
|
else:
|
||||||
if self.conditioning_masking > 0:
|
if self.training and self.conditioning_masking > 0:
|
||||||
cond_op_len = x.shape[-1]
|
cond_op_len = x.shape[-1]
|
||||||
mask_len = int(cond_op_len * self.conditioning_masking)
|
mask_prop = random.random() * self.conditioning_masking
|
||||||
|
mask_len = int(cond_op_len * mask_prop)
|
||||||
if mask_len > 0:
|
if mask_len > 0:
|
||||||
start = random.randint(0, (cond_op_len-mask_len)) + cond_start
|
start = random.randint(0, (cond_op_len-mask_len)) + cond_start
|
||||||
conditioning_input[:,:,start:(start+mask_len)] = 0
|
conditioning_input[:,:,start:(start+mask_len)] = 0
|
||||||
|
|
Loading…
Reference in New Issue
Block a user