forked from mrq/DL-Art-School
Le encoder shalt always be frozen.
This commit is contained in:
parent
aeff1a4cc7
commit
b210e5025c
|
@ -520,7 +520,7 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class TransformerDiffusionWithCheaterLatent(nn.Module):
|
class TransformerDiffusionWithCheaterLatent(nn.Module):
|
||||||
def __init__(self, freeze_encoder_until=50000, **kwargs):
|
def __init__(self, freeze_encoder_until=None, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.internal_step = 0
|
self.internal_step = 0
|
||||||
self.freeze_encoder_until = freeze_encoder_until
|
self.freeze_encoder_until = freeze_encoder_until
|
||||||
|
@ -530,7 +530,7 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
|
||||||
|
|
||||||
def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
|
def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
|
||||||
unused_parameters = []
|
unused_parameters = []
|
||||||
encoder_grad_enabled = self.internal_step > self.freeze_encoder_until
|
encoder_grad_enabled = self.freeze_encoder_until is not None and self.internal_step > self.freeze_encoder_until
|
||||||
if not encoder_grad_enabled:
|
if not encoder_grad_enabled:
|
||||||
unused_parameters.extend(list(self.encoder.parameters()))
|
unused_parameters.extend(list(self.encoder.parameters()))
|
||||||
with torch.set_grad_enabled(encoder_grad_enabled):
|
with torch.set_grad_enabled(encoder_grad_enabled):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user