harharhack

This commit is contained in:
James Betker 2022-06-10 15:13:24 -06:00
parent 7198bd8bd0
commit 33178e89c4

View File

@ -228,7 +228,7 @@ class TransformerDiffusionWithQuantizer(nn.Module):
for p in self.diff.parameters():
unused = unused + p.mean() * 0
mse = mse + unused
return x, diversity_loss, mse
return x.repeat(1,2,1), diversity_loss, mse
quant_grad_enabled = self.internal_step >= self.freeze_quantizer_until
with torch.set_grad_enabled(quant_grad_enabled):