conditioning masking is random

This commit is contained in:
James Betker 2022-07-01 21:43:30 -06:00
parent e06ee1b6f3
commit 286918c581

View File

@ -207,9 +207,10 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
if custom_conditioning_fetcher is not None:
cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
else:
if self.conditioning_masking > 0:
if self.training and self.conditioning_masking > 0:
cond_op_len = x.shape[-1]
mask_len = int(cond_op_len * self.conditioning_masking)
mask_prop = random.random() * self.conditioning_masking
mask_len = int(cond_op_len * mask_prop)
if mask_len > 0:
start = random.randint(0, (cond_op_len-mask_len)) + cond_start
conditioning_input[:,:,start:(start+mask_len)] = 0