From 7ca532c7cc26ce48d65310362cc14f628b41640c Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 17 Jun 2022 09:37:07 -0600
Subject: [PATCH] handle unused encoder parameters

---
 codes/models/audio/music/transformer_diffusion12.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py
index b52639fe..77e716ec 100644
--- a/codes/models/audio/music/transformer_diffusion12.py
+++ b/codes/models/audio/music/transformer_diffusion12.py
@@ -525,10 +525,16 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
         self.encoder = self.encoder.eval()
 
     def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
+        unused_parameters = []
         encoder_grad_enabled = self.internal_step > self.freeze_encoder_until
+        if not encoder_grad_enabled:
+            unused_parameters.extend(list(self.encoder.parameters()))
         with torch.set_grad_enabled(encoder_grad_enabled):
             proj = self.encoder(truth_mel).permute(0,2,1)
 
+        for p in unused_parameters:
+            proj = proj + p.mean() * 0
+
         diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
         return diff
 
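Note on the pattern: the added `p.mean() * 0` terms are numerically a no-op, but they splice the frozen encoder's parameters back into the autograd graph. The usual motivation for this trick is torch.nn.parallel.DistributedDataParallel, whose gradient reducer raises an error when some parameters receive no gradient during a forward pass (unless find_unused_parameters=True is set, which adds an extra graph traversal per step). Below is a minimal, self-contained sketch of the same technique; the module and names (FreezableEncoderModel, encoder_frozen, head) are hypothetical stand-ins for illustration, not part of this repository.

    # Sketch of the zero-contribution trick from the patch above.
    # FreezableEncoderModel is a hypothetical stand-in, not repo code.
    import torch
    import torch.nn as nn

    class FreezableEncoderModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(16, 16)  # stand-in for the mel encoder
            self.head = nn.Linear(16, 1)      # stand-in for the diffusion net

        def forward(self, x, encoder_frozen):
            unused_parameters = []
            if encoder_frozen:
                unused_parameters.extend(list(self.encoder.parameters()))
            with torch.set_grad_enabled(not encoder_frozen):
                proj = self.encoder(x)
            # Adding p.mean() * 0 leaves proj's value unchanged but records
            # each frozen parameter in the autograd graph, so DDP's reducer
            # sees an (exactly zero) gradient for it instead of flagging it
            # as unused.
            for p in unused_parameters:
                proj = proj + p.mean() * 0
            return self.head(proj)

    model = FreezableEncoderModel()
    out = model(torch.randn(4, 16), encoder_frozen=True)
    out.sum().backward()
    # Encoder grads are populated (with zeros) despite the frozen forward:
    print(all(p.grad is not None for p in model.encoder.parameters()))

After backward(), every encoder parameter holds an exactly-zero gradient, which satisfies the reducer's all-parameters-participated check without perturbing the training signal for the rest of the model.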