diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py
index 58b45c86..5e0904a7 100644
--- a/codes/models/audio/music/transformer_diffusion12.py
+++ b/codes/models/audio/music/transformer_diffusion12.py
@@ -520,7 +520,7 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
 
 
 class TransformerDiffusionWithCheaterLatent(nn.Module):
-    def __init__(self, freeze_encoder_until=50000, **kwargs):
+    def __init__(self, freeze_encoder_until=None, **kwargs):
         super().__init__()
         self.internal_step = 0
         self.freeze_encoder_until = freeze_encoder_until
@@ -530,7 +530,7 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
 
     def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
         unused_parameters = []
-        encoder_grad_enabled = self.internal_step > self.freeze_encoder_until
+        encoder_grad_enabled = self.freeze_encoder_until is not None and self.internal_step > self.freeze_encoder_until
         if not encoder_grad_enabled:
             unused_parameters.extend(list(self.encoder.parameters()))
         with torch.set_grad_enabled(encoder_grad_enabled):