Fix some bugs

This commit is contained in:
James Betker 2022-07-05 10:37:47 -06:00
parent 2b128730e7
commit 7440e43531

View File

@ -205,15 +205,15 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
# Break up conditioning input into two random segments aligned with the input.
MIN_MARGIN = 8
assert N > (MIN_MARGIN*2+4), f"Input size too small. Was {N} but requires at least {MIN_MARGIN*2+4}"
break_pt = random.randint(2, N-MIN_MARGIN*2) + MIN_MARGIN
break_pt = random.randint(2, N-MIN_MARGIN*2-2) + MIN_MARGIN
cond_left = cond_aligned[:,:,:break_pt]
cond_right = cond_aligned[:,:,break_pt:]
# Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
to_remove = random.randint(0, cond_left.shape[-1]-MIN_MARGIN)
cond_left = cond_left[:,:,:-to_remove]
to_remove = random.randint(0, cond_right.shape[-1]-MIN_MARGIN)
cond_right = cond_right[:,:,to_remove:]
to_remove_left = random.randint(1, cond_left.shape[-1]-MIN_MARGIN)
cond_left = cond_left[:,:,:-to_remove_left]
to_remove_right = random.randint(1, cond_right.shape[-1]-MIN_MARGIN)
cond_right = cond_right[:,:,to_remove_right:]
# Concatenate the _pre and _post back on.
cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)