and the other ones..

really need to unify this file better.
James Betker 2022-06-17 09:30:25 -06:00
parent 3081c893d4
commit e025183bfb

@@ -511,7 +511,8 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
         # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
         # directly fiddling with the gradients.
         for p in scaled_grad_parameters:
-            p.grad *= .2
+            if hasattr(p, 'grad') and p.grad is not None:
+                p.grad *= .2

 class TransformerDiffusionWithCheaterLatent(nn.Module):

@@ -566,7 +567,8 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
         # higher gradients. Ideally we would use parameter groups, but ZeroRedundancyOptimizer makes this trickier than
         # directly fiddling with the gradients.
         for p in scaled_grad_parameters:
-            p.grad *= .2
+            if hasattr(p, 'grad') and p.grad is not None:
+                p.grad *= .2

 @register_model
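
Below is a minimal, self-contained sketch of the pattern this diff guards, not the repository's actual training loop: gradients for a chosen subset of parameters are scaled in place after backward() and before optimizer.step(). The model, optimizer, and loop here are hypothetical; only the name scaled_grad_parameters and the guarded scaling come from the diff. Parameters that took no part in the forward pass end up with p.grad == None, which is presumably what the added check protects against.

import torch
import torch.nn as nn

# Hypothetical model with one branch that this forward pass never touches.
model = nn.ModuleDict({
    'used': nn.Linear(4, 4),
    'unused': nn.Linear(4, 4),
})
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Subset of parameters whose gradients should be damped, mirroring scaled_grad_parameters.
scaled_grad_parameters = list(model['used'].parameters()) + list(model['unused'].parameters())

x = torch.randn(2, 4)
loss = model['used'](x).sum()   # 'unused' never runs, so its .grad stays None
loss.backward()

# Scale the gradients directly instead of using optimizer parameter groups;
# skip parameters that received no gradient this step.
for p in scaled_grad_parameters:
    if hasattr(p, 'grad') and p.grad is not None:
        p.grad *= .2

optimizer.step()
optimizer.zero_grad()

Rescaling .grad directly keeps everything in a single optimizer instance, which matches the in-code comment about parameter groups being trickier under ZeroRedundancyOptimizer.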