fix bug
This commit is contained in:
parent
de54be5570
commit
712e0e82f7
|
@ -220,7 +220,8 @@ class TransformerDiffusionWithQuantizer(nn.Module):
|
|||
def forward(self, x, timesteps, truth_mel, conditioning_input, conditioning_free=False):
|
||||
quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
|
||||
with torch.set_grad_enabled(quant_grad_enabled):
|
||||
proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True).permute(0,2,1)
|
||||
proj, diversity_loss = self.m2v(truth_mel, return_decoder_latent=True)
|
||||
proj = proj.permute(0,2,1)
|
||||
|
||||
# Make sure this does not cause issues in DDP by explicitly using the parameters for nothing.
|
||||
if not quant_grad_enabled:
|
||||
|
|
Loading…
Reference in New Issue
Block a user