diff --git a/codes/models/audio/music/tfdpc_v5.py b/codes/models/audio/music/tfdpc_v5.py
index 85734782..2d20504d 100644
--- a/codes/models/audio/music/tfdpc_v5.py
+++ b/codes/models/audio/music/tfdpc_v5.py
@@ -126,6 +126,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
                  num_heads=8,
                  dropout=0,
                  use_fp16=False,
+                 segregrate_conditioning_segments=False,
                  # Parameters for regularization.
                  unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
                  conditioning_masking=0,
@@ -136,6 +137,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         self.model_channels = model_channels
         self.time_embed_dim = time_embed_dim
         self.out_channels = out_channels
+        self.segregrate_conditioning_segments = segregrate_conditioning_segments
         self.dropout = dropout
         self.unconditioned_percentage = unconditioned_percentage
         self.conditioning_masking = conditioning_masking
@@ -195,6 +197,33 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
         }
         return groups
 
+    def process_conditioning(self, conditioning_input, time_emb, N, cond_start, custom_conditioning_fetcher):
+        if custom_conditioning_fetcher is not None:
+            cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
+        else:
+            if self.training and self.conditioning_masking > 0:
+                mask_prop = random.random() * self.conditioning_masking
+                mask_len = min(int(N * mask_prop), 4)
+                assert N-mask_len > 8, f"Use longer inputs or shorter conditioning_masking proportion. {N-mask_len}"
+                seg_start = random.randint(8, (N-mask_len)) + cond_start
+                seg_end = seg_start+mask_len
+                conditioning_input[:,:,seg_start:seg_end] = 0
+            else:
+                seg_start = cond_start + N // 2
+                seg_end = seg_start
+            if self.segregrate_conditioning_segments:
+                cond_enc1 = self.conditioning_encoder(conditioning_input[:,:,:seg_start], time_emb)
+                cs = cond_enc1[:,:,cond_start]
+                cond_enc2 = self.conditioning_encoder(conditioning_input[:,:,seg_end:], time_emb)
+                ce = cond_enc2[:,:,(N+cond_start)-seg_end]
+            else:
+                cond_enc = self.conditioning_encoder(conditioning_input, time_emb)
+                cs = cond_enc[:,:,cond_start]
+                ce = cond_enc[:,:,N+cond_start]
+        cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
+        cond = F.interpolate(cond_enc, size=(N,), mode='linear').permute(0,2,1)
+        return cond
+
     def forward(self, x, timesteps, conditioning_input=None, conditioning_free=False, cond_start=0, custom_conditioning_fetcher=None):
         unused_params = []
@@ -204,21 +233,7 @@ class TransformerDiffusionWithPointConditioning(nn.Module):
             cond = self.unconditioned_embedding
             cond = cond.repeat(1,x.shape[-1],1)
         else:
-            if custom_conditioning_fetcher is not None:
-                cs, ce = custom_conditioning_fetcher(self.conditioning_encoder, time_emb)
-            else:
-                if self.training and self.conditioning_masking > 0:
-                    cond_op_len = x.shape[-1]
-                    mask_prop = random.random() * self.conditioning_masking
-                    mask_len = int(cond_op_len * mask_prop)
-                    if mask_len > 0:
-                        start = random.randint(0, (cond_op_len-mask_len)) + cond_start
-                        conditioning_input[:,:,start:(start+mask_len)] = 0
-                cond_enc = self.conditioning_encoder(conditioning_input, time_emb)
-                cs = cond_enc[:,:,cond_start]
-                ce = cond_enc[:,:,x.shape[-1]+cond_start]
-            cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)
-            cond = F.interpolate(cond_enc, size=(x.shape[-1],), mode='linear').permute(0,2,1)
+            cond = self.process_conditioning(conditioning_input, time_emb, x.shape[-1], cond_start, custom_conditioning_fetcher)
         # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
         if self.training and self.unconditioned_percentage > 0:
             unconditioned_batches = torch.rand((cond.shape[0], 1, 1),
@@ -261,16 +276,18 @@ def register_tfdpc5(opt_net, opt):
 
 
 def test_cheater_model():
-    clip = torch.randn(2, 256, 400)
-    cl = torch.randn(2, 256, 400)
+    clip = torch.randn(2, 256, 200)
+    cl = torch.randn(2, 256, 500)
     ts = torch.LongTensor([600, 600])
 
     # For music:
     model = TransformerDiffusionWithPointConditioning(in_channels=256, out_channels=512, model_channels=1024,
                                                       contraction_dim=512, num_heads=8, num_layers=15, dropout=0,
-                                                      unconditioned_percentage=.4, conditioning_masking=.5)
+                                                      unconditioned_percentage=.4, conditioning_masking=.5,
+                                                      segregrate_conditioning_segments=True)
     print_network(model)
-    o = model(clip, ts, cl)
+    for k in range(100):
+        o = model(clip, ts, cl)
     pg = model.get_grad_norm_parameter_groups()
     def prmsz(lp):
         sz = 0
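For readers skimming the diff: the new process_conditioning reduces the conditioning clip to two boundary embeddings, cs taken at frame cond_start and ce at frame N + cond_start, then linearly interpolates between them so each of the N generated frames receives its own conditioning vector; with segregrate_conditioning_segments=True the clip is encoded as two separate halves split around the masked span, so neither boundary embedding is computed from the region being generated. Below is a minimal sketch of just that interpolation step, taken from the cat/interpolate lines in the diff; the tensor sizes are illustrative assumptions, not values read from the model.

import torch
import torch.nn.functional as F

# Sketch of the boundary-embedding interpolation done at the end of
# process_conditioning. B, C and N are assumed, illustrative sizes.
B, C, N = 2, 1024, 200   # batch, conditioning channels, frames being generated
cs = torch.randn(B, C)   # embedding at frame cond_start (just before the generated span)
ce = torch.randn(B, C)   # embedding at frame N + cond_start (just after the generated span)

cond_enc = torch.cat([cs.unsqueeze(-1), ce.unsqueeze(-1)], dim=-1)          # (B, C, 2)
cond = F.interpolate(cond_enc, size=(N,), mode='linear').permute(0, 2, 1)   # (B, N, C)
assert cond.shape == (B, N, C)

The updated test also runs the forward pass 100 times in a loop, presumably so the randomized masking branch and the new segregated-segment path are exercised across many draws rather than a single one.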