diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py index b52639fe..77e716ec 100644 --- a/codes/models/audio/music/transformer_diffusion12.py +++ b/codes/models/audio/music/transformer_diffusion12.py @@ -525,10 +525,16 @@ class TransformerDiffusionWithCheaterLatent(nn.Module): self.encoder = self.encoder.eval() def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False): + unused_parameters = [] encoder_grad_enabled = self.internal_step > self.freeze_encoder_until + if not encoder_grad_enabled: + unused_parameters.extend(list(self.encoder.parameters())) with torch.set_grad_enabled(encoder_grad_enabled): proj = self.encoder(truth_mel).permute(0,2,1) + for p in unused_parameters: + proj = proj + p.mean() * 0 + diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free) return diff