diff --git a/codes/models/audio/music/transformer_diffusion9.py b/codes/models/audio/music/transformer_diffusion9.py index e9c30ef6..6670b7f3 100644 --- a/codes/models/audio/music/transformer_diffusion9.py +++ b/codes/models/audio/music/transformer_diffusion9.py @@ -253,7 +253,7 @@ class TransformerDiffusionWithQuantizer(nn.Module): self.quantizer.min_gumbel_temperature, ) - def forward(self, x, timesteps, truth_mel, conditioning_input, disable_diversity=False, conditioning_free=False): + def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False): quant_grad_enabled = self.internal_step > self.freeze_quantizer_until with torch.set_grad_enabled(quant_grad_enabled): proj, diversity_loss = self.quantizer(truth_mel, return_decoder_latent=True)