diff --git a/codes/models/audio/music/transformer_diffusion8.py b/codes/models/audio/music/transformer_diffusion8.py index ab45fb0f..8521437b 100644 --- a/codes/models/audio/music/transformer_diffusion8.py +++ b/codes/models/audio/music/transformer_diffusion8.py @@ -234,6 +234,8 @@ class TransformerDiffusionWithQuantizer(nn.Module): diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free) + if disable_diversity: + return diff if mse is None: return diff, diversity_loss return diff, diversity_loss, mse