@@ -525,10 +525,16 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
        self.encoder = self.encoder.eval()

    def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
        unused_parameters = []
        encoder_grad_enabled = self.internal_step > self.freeze_encoder_until
        if not encoder_grad_enabled:
            unused_parameters.extend(list(self.encoder.parameters()))
        with torch.set_grad_enabled(encoder_grad_enabled):
            proj = self.encoder(truth_mel).permute(0,2,1)

        for p in unused_parameters:
            proj = proj + p.mean() * 0

        diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input, conditioning_free=conditioning_free)
        return diff
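
A note on the `proj = proj + p.mean() * 0` line: it ties every frozen encoder parameter into the autograd graph without changing the output. The `unused_parameters` naming suggests this is the usual workaround for torch.nn.parallel.DistributedDataParallel, which raises an error for parameters that receive no gradient unless find_unused_parameters=True is set. A minimal, self-contained sketch of the same trick follows; the encoder/head modules here are illustrative stand-ins, not the classes from this patch.

import torch
import torch.nn as nn

encoder = nn.Linear(16, 8)   # stand-in for the frozen encoder
head = nn.Linear(8, 4)       # stand-in for the trainable downstream module

x = torch.randn(2, 16)
with torch.set_grad_enabled(False):
    h = encoder(x)           # no autograd graph is built through the encoder

# Zero-weighted tie-in: routes each frozen parameter through the graph so it
# receives an all-zero gradient instead of none, satisfying DDP's reducer.
for p in encoder.parameters():
    h = h + p.mean() * 0

loss = head(h).sum()
loss.backward()
assert encoder.weight.grad is not None               # gradient exists...
assert encoder.weight.grad.abs().sum().item() == 0   # ...and is exactly zero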
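
The gate `encoder_grad_enabled = self.internal_step > self.freeze_encoder_until` implies a step counter that the training loop advances; this hunk does not show that bookkeeping. Below is a hedged end-to-end sketch of the schedule, using a toy module that mirrors only the gating logic of the patched forward(); the module layout and the loop advancing internal_step are assumptions for illustration.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self, freeze_encoder_until):
        super().__init__()
        self.internal_step = 0
        self.freeze_encoder_until = freeze_encoder_until
        self.encoder = nn.Linear(16, 8)
        self.head = nn.Linear(8, 4)

    def forward(self, x):
        unused_parameters = []
        encoder_grad_enabled = self.internal_step > self.freeze_encoder_until
        if not encoder_grad_enabled:
            unused_parameters.extend(list(self.encoder.parameters()))
        with torch.set_grad_enabled(encoder_grad_enabled):
            h = self.encoder(x)
        for p in unused_parameters:
            h = h + p.mean() * 0
        return self.head(h)

model = ToyModel(freeze_encoder_until=2)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
for step in range(5):
    model.internal_step = step  # assumed to be advanced externally each step
    loss = model(torch.randn(4, 16)).sum()
    opt.zero_grad()
    loss.backward()
    opt.step()
    # steps 0-2: encoder frozen (zero grads); steps 3-4: encoder trains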