From b210e5025c6ec52a4bf0fef0d624c88a39eabb5d Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 23 Jun 2022 11:34:46 -0600
Subject: [PATCH] Le encoder shalt always be frozen.

---
 codes/models/audio/music/transformer_diffusion12.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py
index 58b45c86..5e0904a7 100644
--- a/codes/models/audio/music/transformer_diffusion12.py
+++ b/codes/models/audio/music/transformer_diffusion12.py
@@ -520,7 +520,7 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
 
 
 class TransformerDiffusionWithCheaterLatent(nn.Module):
-    def __init__(self, freeze_encoder_until=50000, **kwargs):
+    def __init__(self, freeze_encoder_until=None, **kwargs):
         super().__init__()
         self.internal_step = 0
         self.freeze_encoder_until = freeze_encoder_until
@@ -530,7 +530,7 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
 
     def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
         unused_parameters = []
-        encoder_grad_enabled = self.internal_step > self.freeze_encoder_until
+        encoder_grad_enabled = self.freeze_encoder_until is not None and self.internal_step > self.freeze_encoder_until
         if not encoder_grad_enabled:
             unused_parameters.extend(list(self.encoder.parameters()))
         with torch.set_grad_enabled(encoder_grad_enabled):
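
Note: the patched condition makes freeze_encoder_until=None (the new default) mean "frozen
forever": encoder_grad_enabled is only True when a threshold step was explicitly given and
training has passed it. Below is a minimal, self-contained sketch of that freeze-until-step
pattern in isolation. ToyEncoder and ToyModel are hypothetical stand-ins, not the repo's
actual encoder or TransformerDiffusionWithCheaterLatent; only the gating logic mirrors the
patch.

    import torch
    import torch.nn as nn


    class ToyEncoder(nn.Module):
        # Hypothetical stand-in for the real encoder module.
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 16)

        def forward(self, x):
            return self.proj(x)


    class ToyModel(nn.Module):
        def __init__(self, freeze_encoder_until=None):
            super().__init__()
            self.internal_step = 0
            # None (the patched default) means the encoder is never unfrozen.
            self.freeze_encoder_until = freeze_encoder_until
            self.encoder = ToyEncoder()
            self.head = nn.Linear(16, 16)

        def forward(self, x):
            # Mirrors the patched condition: gradients flow into the encoder
            # only when a threshold was given and training has passed it.
            encoder_grad_enabled = (
                self.freeze_encoder_until is not None
                and self.internal_step > self.freeze_encoder_until
            )
            with torch.set_grad_enabled(encoder_grad_enabled):
                latent = self.encoder(x)
            # When frozen, the latent carries no grad history, but it still
            # feeds the rest of the network normally.
            return self.head(latent)


    if __name__ == "__main__":
        model = ToyModel()  # default: encoder stays frozen
        out = model(torch.randn(2, 16))
        out.sum().backward()
        # The frozen encoder received no gradients; the head did.
        assert model.encoder.proj.weight.grad is None
        assert model.head.weight.grad is not None

The unused_parameters bookkeeping in the real forward (extending the list with the frozen
encoder's parameters) is presumably there so that distributed training wrappers that expect
every parameter to participate in the backward pass can be told which ones were skipped;
the sketch omits it since it is orthogonal to the freezing logic itself.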