diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py
index 77abe92b..69d96346 100644
--- a/codes/models/audio/music/tfdpc_v5.py
+++ b/codes/models/audio/music/tfdpc_v5.py
@@ -205,15 +205,15 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         # Break up conditioning input into two random segments aligned with the input.
         MIN_MARGIN = 8
         assert N > (MIN_MARGIN*2+4), f"Input size too small. Was {N} but requires at least {MIN_MARGIN*2+4}"
-        break_pt = random.randint(2, N-MIN_MARGIN*2) + MIN_MARGIN
+        break_pt = random.randint(2, N-MIN_MARGIN*2-2) + MIN_MARGIN
         cond_left = cond_aligned[:,:,:break_pt]
         cond_right = cond_aligned[:,:,break_pt:]
 
         # Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
-        to_remove = random.randint(0, cond_left.shape[-1]-MIN_MARGIN)
-        cond_left = cond_left[:,:,:-to_remove]
-        to_remove = random.randint(0, cond_right.shape[-1]-MIN_MARGIN)
-        cond_right = cond_right[:,:,to_remove:]
+        to_remove_left = random.randint(1, cond_left.shape[-1]-MIN_MARGIN)
+        cond_left = cond_left[:,:,:-to_remove_left]
+        to_remove_right = random.randint(1, cond_right.shape[-1]-MIN_MARGIN)
+        cond_right = cond_right[:,:,to_remove_right:]
 
         # Concatenate the _pre and _post back on.
         cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)
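
A note on why both bounds change together. In Python `-0 == 0`, so when the old `random.randint(0, ...)` drew 0, the slice `cond_left[:,:,:-to_remove]` emptied the segment entirely instead of leaving it untouched; raising the lower bound to 1 fixes that. But with a lower bound of 1, `randint(1, segment_len - MIN_MARGIN)` needs `segment_len >= MIN_MARGIN + 1` to have a non-empty range, which is why the `break_pt` ceiling also drops by 2. A minimal standalone sketch of both failure modes (`x` and the `N = 32` value are illustrative, not from the patch):

```python
import random

import torch

# Failure mode 1: the -0 slice.
x = torch.arange(10)
print(x[:-0].shape)  # torch.Size([0]) -- with to_remove == 0, the whole segment is dropped
print(x[:-1].shape)  # torch.Size([9]) -- the intended "remove one frame" at the new minimum

# Failure mode 2: the old break_pt ceiling combined with the new lower bound.
MIN_MARGIN = 8
N = 32  # illustrative input length
break_pt_old_max = (N - MIN_MARGIN * 2) + MIN_MARGIN      # == N - MIN_MARGIN == 24
break_pt_new_max = (N - MIN_MARGIN * 2 - 2) + MIN_MARGIN  # == N - MIN_MARGIN - 2 == 22
right_len_old = N - break_pt_old_max  # 8, exactly MIN_MARGIN
right_len_new = N - break_pt_new_max  # 10, at least MIN_MARGIN + 2
try:
    random.randint(1, right_len_old - MIN_MARGIN)  # randint(1, 0): empty range
except ValueError as e:
    print("old ceiling breaks:", e)
print("new ceiling ok:", random.randint(1, right_len_new - MIN_MARGIN))  # always valid
```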