From 455943779ba0db99a97cdc366392d7ac2ca48b97 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 4 Jul 2022 08:16:14 -0600 Subject: [PATCH] Fix bug in conditioning segment fetching --- codes/models/audio/music/tfdpc_v5.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py index 2d20504d..a64f8368 100644 --- a/codes/models/audio/music/tfdpc_v5.py +++ b/codes/models/audio/music/tfdpc_v5.py @@ -203,9 +203,12 @@ class TransformerDiffusionWithPointConditioning(nn.Module): else: if self.training and self.conditioning_masking > 0: mask_prop = random.random() * self.conditioning_masking - mask_len = min(int(N * mask_prop), 4) + mask_len = max(int(N * mask_prop), 4) assert N-mask_len > 8, f"Use longer inputs or shorter conditioning_masking proportion. {N-mask_len}" seg_start = random.randint(8, (N-mask_len)) + cond_start + # Readjust mask_len to ensure at least 8 sequence elements on the end as well + mask_len = min(mask_len, conditioning_input.shape[-1]-seg_start-8) + assert mask_len > 0, f"Error with mask_len: {conditioning_input.shape[-1], seg_start, mask_len, mask_prop, N}" seg_end = seg_start+mask_len conditioning_input[:,:,seg_start:seg_end] = 0 else: @@ -276,14 +279,14 @@ def register_tfdpc5(opt_net, opt): def test_cheater_model(): - clip = torch.randn(2, 256, 200) - cl = torch.randn(2, 256, 500) + clip = torch.randn(2, 256, 350) + cl = torch.randn(2, 256, 646) ts = torch.LongTensor([600, 600]) # For music: model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024, contraction_dim=512, num_heads=8, num_layers=15, dropout=0, - unconditioned_percentage=.4, conditioning_masking=.5, + unconditioned_percentage=.4, conditioning_masking=.4, segregrate_conditioning_segments=True) print_network(model) for k in range(100):