diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py index 9e7d2e29..13df3a16 100644 --- a/codes/models/audio/music/tfdpc_v5.py +++ b/codes/models/audio/music/tfdpc_v5.py @@ -219,9 +219,9 @@ class TransformerDiffusionWithPointConditioning(nn.Module): if self.training: # Arbitrarily restrict the context given. We should support short contexts and without this they are never encountered. arb_context_cap = random.randint(50, 100) - if cond_left.shape[-1] > arb_context_cap and random() > .5: + if cond_left.shape[-1] > arb_context_cap and random.random() > .5: cond_left = cond_left[:,:,-arb_context_cap:] - if cond_right.shape[-1] > arb_context_cap and random() > .5: + if cond_right.shape[-1] > arb_context_cap and random.random() > .5: cond_right = cond_right[:,:,:arb_context_cap] elif cond_left is None: