This commit is contained in:
James Betker 2022-07-14 21:52:23 -06:00
parent 51291ab070
commit e13b1adfdb

View File

@ -219,9 +219,9 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
if self.training: if self.training:
# Arbitrarily restrict the context given. We should support short contexts and without this they are never encountered. # Arbitrarily restrict the context given. We should support short contexts and without this they are never encountered.
arb_context_cap = random.randint(50, 100) arb_context_cap = random.randint(50, 100)
if cond_left.shape[-1] > arb_context_cap and random() > .5: if cond_left.shape[-1] > arb_context_cap and random.random() > .5:
cond_left = cond_left[:,:,-arb_context_cap:] cond_left = cond_left[:,:,-arb_context_cap:]
if cond_right.shape[-1] > arb_context_cap and random() > .5: if cond_right.shape[-1] > arb_context_cap and random.random() > .5:
cond_right = cond_right[:,:,:arb_context_cap] cond_right = cond_right[:,:,:arb_context_cap]
elif cond_left is None: elif cond_left is None: