diff --git a/codes/models/audio/music/transformer_diffusion7.py b/codes/models/audio/music/transformer_diffusion7.py
index bc29014b..f8102690 100644
--- a/codes/models/audio/music/transformer_diffusion7.py
+++ b/codes/models/audio/music/transformer_diffusion7.py
@@ -220,7 +220,8 @@ class TransformerDiffusionWithQuantizer(nn.Module):
     def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
         quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
         with torch.set_grad_enabled(quant_grad_enabled):
-            proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True).permute(0,2,1)
+            proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True)
+            proj = proj.permute(0,2,1)

         # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing.
         if not quant_grad_enabled: