commit 712e0e82f7
parent de54be5570
Author: James Betker
Date:   2022-06-01 14:21:44 -06:00


@@ -220,7 +220,8 @@ class TransformerDiffusionWithQuantizer(nn.Module):
     def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
         quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
         with torch.set_grad_enabled(quant_grad_enabled):
-            proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True).permute(0,2,1)
+            proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True)
+            proj = proj.permute(0,2,1)
         # Make sure this does not cause issues in DDP by explicitly using the parameters for nothing.
         if not quant_grad_enabled:
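
Why the change: with return_decoder_latent=True, self.m2v(truth_mel, ...) returns a (latent, diversity_loss) tuple. The old line chained .permute(0,2,1) directly onto that call, but a Python tuple has no permute method, so the statement raises AttributeError before the unpacking ever happens. The fix unpacks first and permutes only the latent tensor. Below is a minimal runnable sketch of the failure and the fix; fake_m2v is a hypothetical stand-in for the repository's m2v module, and the tensor shapes are illustrative only.

import torch

# Stand-in for self.m2v(truth_mel, return_decoder_latent=True): returns a
# (latent, diversity_loss) tuple, like the real call in the diff above.
# The shapes here are assumptions for illustration, not the real model's.
def fake_m2v(mel):
    latent = torch.randn(mel.shape[0], 100, 512)   # (batch, time, channels)
    diversity_loss = torch.tensor(0.0)
    return latent, diversity_loss

mel = torch.randn(2, 80, 400)

# Old form: .permute() is called on the returned tuple, which has no such
# attribute, so this raises AttributeError.
try:
    proj, diversity_loss = fake_m2v(mel).permute(0, 2, 1)
except AttributeError as err:
    print("old form fails:", err)

# Fixed form: unpack first, then permute only the latent tensor.
proj, diversity_loss = fake_m2v(mel)
proj = proj.permute(0, 2, 1)                       # (batch, channels, time)
print(proj.shape)                                  # torch.Size([2, 512, 100])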