and the other ones..

really need to unify this file better.
James Betker 2022-06-17 09:30:25 +07:00
parent 3081c893d4
commit e025183bfb
1 changed file with 4 additions and 2 deletions

@@ -511,7 +511,8 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
         # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
         # directly fiddling with the gradients.
         for p in scaled_grad_parameters:
-            p.grad *= .2
+            if hasattr(p, 'grad') and p.grad is not None:
+                p.grad *= .2


 class TransformerDiffusionWithCheaterLatent(nn.Module):
@@ -566,7 +567,8 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
         # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
         # directly fiddling with the gradients.
         for p in scaled_grad_parameters:
-            p.grad *= .2
+            if hasattr(p, 'grad') and p.grad is not None:
+                p.grad *= .2


 @register_model
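
A minimal, self-contained sketch of the pattern this commit guards, under assumptions: the toy model, optimizer, and training step are illustrative stand-ins, and `scaled_grad_parameters` is bound here to an arbitrary layer; only the guard and the `p.grad *= .2` scaling mirror the diff above. The point of the added check is that parameters which are frozen or were unused in the forward pass have `p.grad == None` after `backward()`, and multiplying `None` by `.2` raises a TypeError.

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))

# Hypothetical stand-in for `scaled_grad_parameters`: here, the first layer.
scaled_grad_parameters = list(model[0].parameters())

opt = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(4, 8)
loss = model(x).square().mean()

opt.zero_grad()
loss.backward()

# Rescale gradients for the selected parameters directly, rather than via
# optimizer parameter groups (which, per the comment in the diff,
# ZeroRedundancyOptimizer makes awkward). The hasattr/None guard skips
# parameters that received no gradient this step.
for p in scaled_grad_parameters:
    if hasattr(p, 'grad') and p.grad is not None:
        p.grad *= .2

opt.step()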