diff --git a/codes/models/audio/music/transformer_diffusion8.py b/codes/models/audio/music/transformer_diffusion8.py
index 145a7260..0037ec5c 100644
--- a/codes/models/audio/music/transformer_diffusion8.py
+++ b/codes/models/audio/music/transformer_diffusion8.py
@@ -228,7 +228,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
             for p in self.diff.parameters():
                 unused = unused + p.mean() * 0
             mse = mse + unused
-            return x, diversity_loss, mse
+            return x.repeat(1,2,1), diversity_loss, mse
 
         quant_grad_enabled = self.internal_step >= self.freeze_quantizer_until
         with torch.set_grad_enabled(quant_grad_enabled):