Don't apply the grad scale when the grad doesn't exist!

James Betker 2022-06-17 09:27:04 -06:00
parent 3efd64ed7a
commit 3081c893d4


@@ -342,7 +342,8 @@ class TransformerDiffusionWithQuantizer(nn.Module):
         # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
         # directly fiddling with the gradients.
         for p in scaled_grad_parameters:
-            p.grad *= .2
+            if hasattr(p, 'grad') and p.grad is not None:
+                p.grad *= .2
 
 
 class TransformerDiffusionWithARPrior(nn.Module):
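
For context, a minimal, self-contained sketch (not part of this repository) of the failure mode the guard fixes: a parameter that never participates in the forward pass keeps `p.grad == None` after `backward()`, so the unguarded `p.grad *= .2` raises a TypeError. The toy `model`, the `unused` parameter, and the one-off training step below are illustrative assumptions; only the 0.2 factor and the `scaled_grad_parameters` name mirror the diff above.

import torch
import torch.nn as nn

# Toy setup: one module that participates in the forward pass and one extra
# parameter that does not, so its .grad stays None after backward().
model = nn.Linear(4, 4)
unused = nn.Parameter(torch.zeros(3))
scaled_grad_parameters = list(model.parameters()) + [unused]

loss = model(torch.randn(2, 4)).sum()
loss.backward()

for p in scaled_grad_parameters:
    # Guard against missing grads before scaling; without the check,
    # `p.grad *= .2` would fail on `unused` because None cannot be multiplied.
    if hasattr(p, 'grad') and p.grad is not None:
        p.grad *= .2  # down-weight these gradients in lieu of optimizer parameter groups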