eval bug fix

This commit is contained in:
James Betker 2022-06-10 13:51:06 -06:00
parent 84469f3538
commit 89bd40d39f

View File

@ -234,6 +234,8 @@ class TransformerDiffusionWithQuantizer(nn.Module):
diff = self.diff(x, timesteps, codes=proj, conditioning_input=conditioning_input,
conditioning_free=conditioning_free)
if disable_diversity:
return diff
if mse is None:
return diff, diversity_loss
return diff, diversity_loss, mse