From 286918c581a707ac6ff4011e35036949b4e16e22 Mon Sep 17 00:00:00 2001 From: James Betker Date: Fri, 1 Jul 2022 21:43:30 -0600 Subject: [PATCH] conditioning masking is random --- codes/models/audio/music/tfdpc_v5.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py index 5115e773..85734782 100644 --- a/codes/models/audio/music/tfdpc_v5.py +++ b/codes/models/audio/music/tfdpc_v5.py @@ -207,9 +207,10 @@ class TransformerDiffusionWithPointConditioning(nn.Module): if custom_conditioning_fetcher is not None: cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb) else: - if self.conditioning_masking > 0: + if self.training and self.conditioning_masking > 0: cond_op_len = x.shape[-1] - mask_len = int(cond_op_len * self.conditioning_masking) + mask_prop = random.random() * self.conditioning_masking + mask_len = int(cond_op_len * mask_prop) if mask_len > 0: start = random.randint(0, (cond_op_len-mask_len)) + cond_start conditioning_input[:,:,start:(start+mask_len)] = 0