diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py
index d9f0fc64..0ca0c029 100644
--- a/codes/models/audio/music/transformer_diffusion12.py
+++ b/codes/models/audio/music/transformer_diffusion12.py
@@ -99,7 +99,7 @@ class TransformerDiffusion(nn.Module):
             dropout=0,
             use_fp16=False,
             ar_prior=False,
-            code_expansion_mode='nearest',
+            new_code_expansion=False,
             # Parameters for regularization.
             unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
             # Parameters for re-training head
@@ -115,7 +115,7 @@ class TransformerDiffusion(nn.Module):
         self.dropout = dropout
         self.unconditioned_percentage = unconditioned_percentage
         self.enable_fp16 = use_fp16
-        self.code_expansion_mode = code_expansion_mode
+        self.new_code_expansion = new_code_expansion
 
         self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)
 
@@ -209,7 +209,10 @@ class TransformerDiffusion(nn.Module):
         return groups
 
     def timestep_independent(self, prior, expected_seq_len):
-        code_emb = self.ar_input(prior) if self.ar_prior else self.input_converter(prior)
+        code_emb = prior  # fall back to the raw prior when using the original expansion path
+        if self.new_code_expansion:
+            code_emb = F.interpolate(prior.permute(0,2,1), size=expected_seq_len, mode='linear').permute(0,2,1)
+        code_emb = self.ar_input(code_emb) if self.ar_prior else self.input_converter(code_emb)
         code_emb = self.ar_prior_intg(code_emb) if self.ar_prior else self.code_converter(code_emb)
 
         # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
@@ -219,8 +222,9 @@ class TransformerDiffusion(nn.Module):
             code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(prior.shape[0], 1, 1),
                                    code_emb)
 
-        expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode=self.code_expansion_mode).permute(0,2,1)
-        return expanded_code_emb
+        if not self.new_code_expansion:
+            code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1)
+        return code_emb
 
     def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False):
         if precomputed_code_embeddings is not None:
@@ -722,7 +726,7 @@ def test_cheater_model():
                                  model_channels=1024, contraction_dim=512,
                                  prenet_channels=1024, num_heads=8,
                                  input_vec_dim=256, num_layers=12, prenet_layers=6,
-                                 dropout=.1,
+                                 dropout=.1, new_code_expansion=True,
                                  )
    diff_weights = torch.load('extracted_diff.pth')
    model.diff.load_state_dict(diff_weights, strict=False)
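
Note on the change (not part of the patch; tensor shapes below are illustrative assumptions): with new_code_expansion=False the model keeps its original behavior of nearest-neighbor-expanding the *embedded* codes to the diffusion sequence length, while new_code_expansion=True linearly interpolates the raw prior up to expected_seq_len before it is embedded. A minimal standalone sketch of the two orderings:

import torch
import torch.nn.functional as F

prior = torch.randn(1, 100, 256)      # [batch, seq, dim]; 100 codes of dim 256 (illustrative)
expected_seq_len = 400                # target diffusion sequence length

# Old path: embed first (stand-in tensor here for input_converter(prior)),
# then nearest-neighbor expand the embeddings to the target length.
embedded = torch.randn(1, 100, 1024)
old = F.interpolate(embedded.permute(0, 2, 1), size=expected_seq_len, mode='nearest').permute(0, 2, 1)

# New path: linearly interpolate the raw prior first, then embed the expanded sequence.
new = F.interpolate(prior.permute(0, 2, 1), size=expected_seq_len, mode='linear').permute(0, 2, 1)

print(old.shape, new.shape)  # torch.Size([1, 400, 1024]) torch.Size([1, 400, 256])

Linear interpolation of the prior gives smooth transitions between adjacent codes rather than the blocky repeats produced by nearest-neighbor expansion of the embeddings, which is presumably why test_cheater_model() enables the new mode.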