From 712e0e82f75d60d337142ac94efd5e796c7f31d2 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 1 Jun 2022 14:21:44 -0600
Subject: [PATCH] fix bug

---
 codes/models/audio/music/transformer_diffusion7.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/codes/models/audio/music/transformer_diffusion7.py b/codes/models/audio/music/transformer_diffusion7.py
index bc29014b..f8102690 100644
--- a/codes/models/audio/music/transformer_diffusion7.py
+++ b/codes/models/audio/music/transformer_diffusion7.py
@@ -220,7 +220,8 @@ class TransformerDiffusionWithQuantizer(nn.Module):
     def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
         quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
         with torch.set_grad_enabled(quant_grad_enabled):
-            proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True).permute(0,2,1)
+            proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True)
+            proj = proj.permute(0,2,1)
 
         # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing.
         if not quant_grad_enabled: