diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py index 5115e773..85734782 100644 --- a/codes/models/audio/music/tfdpc_v5.py +++ b/codes/models/audio/music/tfdpc_v5.py @@ -207,9 +207,10 @@ class TransformerDiffusionWithPointConditioning(nn.Module): if custom_conditioning_fetcher is not None: cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb) else: - if self.conditioning_masking > 0: + if self.training and self.conditioning_masking > 0: cond_op_len = x.shape[-1] - mask_len = int(cond_op_len * self.conditioning_masking) + mask_prop = random.random() * self.conditioning_masking + mask_len = int(cond_op_len * mask_prop) if mask_len > 0: start = random.randint(0, (cond_op_len-mask_len)) + cond_start conditioning_input[:,:,start:(start+mask_len)] = 0