Fix another edge case

This commit is contained in:
James Betker 2022-07-04 16:47:57 -06:00
parent 808a1a4a31
commit 802998674e

View File

@ -203,7 +203,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
else:
if self.training and self.conditioning_masking > 0:
mask_prop = random.random() * self.conditioning_masking
mask_len = max(int(N * mask_prop), 4)
mask_len = max(int(N * mask_prop), 16)
assert N-mask_len > 8, f"Use longer inputs or shorter conditioning_masking proportion. {N-mask_len}"
seg_start = random.randint(8, (N-mask_len)) + cond_start
# Readjust mask_len to ensure at least 8 sequence elements on the end as well