pull/9/head
James Betker 2022-06-19 23:22:30 +07:00
parent a659cd865c
commit 56c4a00e71
1 changed files with 1 additions and 1 deletions

@ -233,7 +233,7 @@ class TransformerDiffusionWithConditioningEncoder(nn.Module):
def get_grad_norm_parameter_groups(self):
groups = self.diff.get_grad_norm_parameter_groups()
groups['conditioning_encoder'] = list(self.conditioning_encoder.parameters())
return
return groups
def before_step(self, step):
scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \