Fix another edge case
This commit is contained in:
parent
808a1a4a31
commit
802998674e
|
@ -203,7 +203,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
|||
else:
|
||||
if self.training and self.conditioning_masking > 0:
|
||||
mask_prop = random.random() * self.conditioning_masking
|
||||
mask_len = max(int(N * mask_prop), 4)
|
||||
mask_len = max(int(N * mask_prop), 16)
|
||||
assert N-mask_len > 8, f"Use longer inputs or shorter conditioning_masking proportion. {N-mask_len}"
|
||||
seg_start = random.randint(8, (N-mask_len)) + cond_start
|
||||
# Readjust mask_len to ensure at least 8 sequence elements on the end as well
|
||||
|
|
Loading…
Reference in New Issue
Block a user