diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py
index 3faf3f8b..d9f0fc64 100644
--- a/codes/models/audio/music/transformer_diffusion12.py
+++ b/codes/models/audio/music/transformer_diffusion12.py
@@ -99,6 +99,7 @@ class TransformerDiffusion(nn.Module):
                  dropout=0,
                  use_fp16=False,
                  ar_prior=False,
+                 code_expansion_mode='nearest',
                  # Parameters for regularization.
                  unconditioned_percentage=.1,  # This implements a mechanism similar to what is used in classifier-free training.
                  # Parameters for re-training head
@@ -114,6 +115,7 @@ class TransformerDiffusion(nn.Module):
         self.dropout = dropout
         self.unconditioned_percentage = unconditioned_percentage
         self.enable_fp16 = use_fp16
+        self.code_expansion_mode = code_expansion_mode
 
         self.inp_block = conv_nd(1, in_channels, prenet_channels, 3, 1, 1)
 
@@ -217,7 +219,7 @@ class TransformerDiffusion(nn.Module):
             code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(prior.shape[0], 1, 1),
                                    code_emb)
 
-        expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode='nearest').permute(0,2,1)
+        expanded_code_emb = F.interpolate(code_emb.permute(0,2,1), size=expected_seq_len, mode=self.code_expansion_mode).permute(0,2,1)
         return expanded_code_emb
 
     def forward(self, x, timesteps, codes=None, conditioning_input=None, precomputed_code_embeddings=None, conditioning_free=False):
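
For reference, a minimal standalone sketch (toy shapes, not taken from the repo) of what the new `code_expansion_mode` parameter controls: the code embeddings are stretched along the sequence axis with `F.interpolate`, and the mode now selects between repeating each code (`'nearest'`, the previously hard-coded behavior) and blending neighboring codes (`'linear'`).

```python
import torch
import torch.nn.functional as F

# Toy code embedding: (batch, code_seq_len, channels). Shapes are illustrative only.
code_emb = torch.randn(2, 100, 512)
expected_seq_len = 400

for mode in ('nearest', 'linear'):
    # Mirrors the call pattern in the diff: permute to (batch, channels, seq) for 1D
    # interpolation, expand to the diffusion sequence length, then permute back.
    expanded = F.interpolate(code_emb.permute(0, 2, 1), size=expected_seq_len,
                             mode=mode).permute(0, 2, 1)
    print(mode, expanded.shape)  # torch.Size([2, 400, 512]) in both cases
```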