Fix some bugs
This commit is contained in:
parent
2b128730e7
commit
7440e43531
|
@ -205,15 +205,15 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
|
|||
# Break up conditioning input into two random segments aligned with the input.
|
||||
MIN_MARGIN = 8
|
||||
assert N > (MIN_MARGIN*2+4), f"Input size too small. Was {N} but requires at least {MIN_MARGIN*2+4}"
|
||||
break_pt = random.randint(2, N-MIN_MARGIN*2) + MIN_MARGIN
|
||||
break_pt = random.randint(2, N-MIN_MARGIN*2-2) + MIN_MARGIN
|
||||
cond_left = cond_aligned[:,:,:break_pt]
|
||||
cond_right = cond_aligned[:,:,break_pt:]
|
||||
|
||||
# Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
|
||||
to_remove = random.randint(0, cond_left.shape[-1]-MIN_MARGIN)
|
||||
cond_left = cond_left[:,:,:-to_remove]
|
||||
to_remove = random.randint(0, cond_right.shape[-1]-MIN_MARGIN)
|
||||
cond_right = cond_right[:,:,to_remove:]
|
||||
to_remove_left = random.randint(1, cond_left.shape[-1]-MIN_MARGIN)
|
||||
cond_left = cond_left[:,:,:-to_remove_left]
|
||||
to_remove_right = random.randint(1, cond_right.shape[-1]-MIN_MARGIN)
|
||||
cond_right = cond_right[:,:,to_remove_right:]
|
||||
|
||||
# Concatenate the _pre and _post back on.
|
||||
cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)
|
||||
|
|
Loading…
Reference in New Issue
Block a user