From 2b128730e72498156975ccf3b4f57d95838fa397 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 5 Jul 2022 10:30:28 -0600
Subject: [PATCH] Improve conditioning separation logic

Always split the input-aligned conditioning window into a left and a
right segment at a random break point, trim a random amount from the
inner edge of each segment, and run the two segments through the
conditioning encoder separately. This replaces the optional
segregrate_conditioning_segments / conditioning_masking code paths; both
constructor arguments are removed.

The left segment is trimmed with an explicit end index rather than
`[:-to_remove]`, because a negative-zero stop (`[:-0]`) selects nothing
and would drop the whole segment whenever to_remove is 0.

---
 codes/models/audio/music/tfdpc_v5.py | 56 ++++++++++++++--------------
 1 file changed, 28 insertions(+), 28 deletions(-)

diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py
index c2186133..77abe92b 100644
--- a/codes/models/audio/music/tfdpc_v5.py
+++ b/codes/models/audio/music/tfdpc_v5.py
@@ -126,10 +126,8 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
             num_heads=8,
             dropout=0,
             use_fp16=False,
-            segregrate_conditioning_segments=False,
             # Parameters for regularization.
             unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
-            conditioning_masking=0,
     ):
         super().__init__()
 
@@ -137,10 +135,8 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         self.model_channels = model_channels
         self.time_embed_dim = time_embed_dim
         self.out_channels = out_channels
-        self.segregrate_conditioning_segments = segregrate_conditioning_segments
         self.dropout = dropout
         self.unconditioned_percentage = unconditioned_percentage
-        self.conditioning_masking = conditioning_masking
         self.enable_fp16 = use_fp16
 
         self.inp_block = conv_nd(1, in_channels, model_channels, 3, 1, 1)
@@ -201,28 +197,33 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         if custom_conditioning_fetcher is not None:
             cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
         else:
-            if self.training and self.conditioning_masking > 0:
-                mask_prop = random.random() * self.conditioning_masking
-                mask_len = max(int(N * mask_prop), 16)
-                assert N-mask_len > 8, f"Use longer inputs or shorter conditioning_masking proportion. {N-mask_len}"
-                seg_start = random.randint(8, (N-mask_len)) + cond_start
-                # Readjust mask_len to ensure at least 8 sequence elements on the end as well
-                mask_len = min(mask_len, conditioning_input.shape[-1]-seg_start-8)
-                assert mask_len > 0, f"Error with mask_len: {conditioning_input.shape[-1], seg_start, mask_len, mask_prop, N}"
-                seg_end = seg_start+mask_len
-                conditioning_input[:,:,seg_start:seg_end] = 0
-            else:
-                seg_start = cond_start + N // 2
-                seg_end = seg_start
-            if self.segregrate_conditioning_segments:
-                cond_enc1 = self.conditioning_encoder(conditioning_input[:,:,:seg_start], time_emb)
-                cs = cond_enc1[:,:,cond_start]
-                cond_enc2 = self.conditioning_encoder(conditioning_input[:,:,seg_end:], time_emb)
-                ce = cond_enc2[:,:,(N+cond_start)-seg_end]
-            else:
-                cond_enc = self.conditioning_encoder(conditioning_input, time_emb)
-                cs = cond_enc[:,:,cond_start]
-                ce = cond_enc[:,:,N+cond_start]
+            assert conditioning_input.shape[-1] - cond_start - N > 0, f'Some sort of conditioning misalignment, {conditioning_input.shape[-1], cond_start, N}'
+            cond_pre = conditioning_input[:,:,:cond_start]
+            cond_aligned = conditioning_input[:,:,cond_start:N+cond_start]
+            cond_post = conditioning_input[:,:,N+cond_start:]
+
+            # Break up conditioning input into two random segments aligned with the input.
+            MIN_MARGIN = 8
+            assert N > (MIN_MARGIN*2+4), f"Input size too small. Was {N} but requires at least {MIN_MARGIN*2+4}"
+            break_pt = random.randint(2, N-MIN_MARGIN*2) + MIN_MARGIN
+            cond_left = cond_aligned[:,:,:break_pt]
+            cond_right = cond_aligned[:,:,break_pt:]
+
+            # Drop out a random amount of the aligned data. The network will need to figure out how to reconstruct this.
+            to_remove = random.randint(0, cond_left.shape[-1]-MIN_MARGIN)
+            cond_left = cond_left[:,:,:cond_left.shape[-1]-to_remove]  # Explicit end index: [:-0] would select nothing.
+            to_remove = random.randint(0, cond_right.shape[-1]-MIN_MARGIN)
+            cond_right = cond_right[:,:,to_remove:]
+
+            # Concatenate the _pre and _post back on.
+            cond_left_full = torch.cat([cond_pre, cond_left], dim=-1)
+            cond_right_full = torch.cat([cond_right, cond_post], dim=-1)
+
+            # Propagate through the encoder.
+            cond_left_enc = self.conditioning_encoder(cond_left_full, time_emb)
+            cs = cond_left_enc[:,:,cond_start]
+            cond_right_enc = self.conditioning_encoder(cond_right_full, time_emb)
+            ce = cond_right_enc[:,:,cond_right.shape[-1]-1]
         cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
         cond = F.interpolate(cond_enc, size=(N,), mode='linear').permute(0,2,1)
         return cond
@@ -286,8 +287,7 @@ def test_cheater_model():
     # For music:
     model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
                                                       contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
-                                                      unconditioned_percentage=.4, conditioning_masking=.4,
-                                                      segregrate_conditioning_segments=True)
+                                                      unconditioned_percentage=.4)
     print_network(model)
     for k in range(100):
         o = model(clip, ts, cl)