...and the other ones.

Really need to unify this file better.
James Betker 2022-06-17 09:30:25 -06:00
parent 3081c893d4
commit e025183bfb

@@ -511,6 +511,7 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
         # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
         # directly fiddling with the gradients.
         for p in scaled_grad_parameters:
-            p.grad *= .2
+            if hasattr(p, 'grad') and p.grad is not None:
+                p.grad *= .2
@@ -566,6 +567,7 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
         # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
         # directly fiddling with the gradients.
         for p in scaled_grad_parameters:
-            p.grad *= .2
+            if hasattr(p, 'grad') and p.grad is not None:
+                p.grad *= .2
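
For context, here is a minimal, self-contained sketch of the pattern both hunks now follow: after backward(), the gradients of a selected parameter subset are scaled in place, guarded so that parameters with no gradient (e.g. unused in the current forward pass) are skipped. The toy model, the parameter selection, and the names below are illustrative assumptions, not the repository's actual modules.

import torch
import torch.nn as nn

# Hypothetical two-layer model standing in for the diffusion modules above.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))

# Assumption: the first layer's parameters play the role of
# scaled_grad_parameters, i.e. the ones that should see smaller updates.
scaled_grad_parameters = list(model[0].parameters())

loss = model(torch.randn(4, 8)).sum()
loss.backward()

# The pattern from the diff: scale gradients directly rather than via
# optimizer parameter groups (awkward under ZeroRedundancyOptimizer), and
# skip any parameter whose .grad is None so the in-place multiply is safe.
for p in scaled_grad_parameters:
    if hasattr(p, 'grad') and p.grad is not None:
        p.grad *= .2

Scaling p.grad in place before optimizer.step() is equivalent to giving those parameters a lower effective learning rate, which is exactly what a separate parameter group would otherwise provide.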