forked from mrq/DL-Art-School
ci not required
This commit is contained in:
parent 08597bfaf5
commit 5028703b3d
@@ -253,7 +253,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
                 self.quantizer.min_gumbel_temperature,
             )
 
-    def forward(self, x, timesteps, truth_mel, conditioning_input, disable_diversity=False, conditioning_free=False):
+    def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
         quant_grad_enabled = self.internal_step > self.freeze_quantizer_until
         with torch.set_grad_enabled(quant_grad_enabled):
            proj, diversity_loss = self.quantizer(truth_mel, return_decoder_latent=True)
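The change only makes conditioning_input optional; the surrounding context shows the step-gated quantizer freeze that this forward() relies on. Below is a minimal, self-contained sketch of that gating pattern. The names here (ToyQuantizer, GatedModel, freeze_until) are illustrative assumptions, not the repository's actual classes; the point is only how torch.set_grad_enabled keeps the quantizer frozen until the step counter passes the threshold, and how the new default lets callers omit conditioning_input.

import torch
import torch.nn as nn

# Toy stand-ins; ToyQuantizer / GatedModel / freeze_until are assumed names,
# not the classes defined in this repository.
class ToyQuantizer(nn.Module):
    def __init__(self, dim=80):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, mel):
        return self.proj(mel)


class GatedModel(nn.Module):
    def __init__(self, freeze_until=1000):
        super().__init__()
        self.quantizer = ToyQuantizer()
        self.freeze_until = freeze_until
        self.internal_step = 0

    def forward(self, truth_mel, conditioning_input=None):
        # Same gating idea as the diff: gradients only reach the quantizer
        # once internal_step has passed the freeze threshold.
        quant_grad_enabled = self.internal_step > self.freeze_until
        with torch.set_grad_enabled(quant_grad_enabled):
            proj = self.quantizer(truth_mel)
        self.internal_step += 1
        return proj


model = GatedModel(freeze_until=2)
mel = torch.randn(4, 16, 80)
out = model(mel)          # conditioning_input omitted, as the new default allows
print(out.requires_grad)  # False while the quantizer is still frozen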