forked from mrq/DL-Art-School
whoops
This commit is contained in:
parent
a659cd865c
commit
56c4a00e71
|
@ -233,7 +233,7 @@ class TransformerDiffusionWithConditioningEncoder(nn.Module):
|
||||||
def get_grad_norm_parameter_groups(self):
|
def get_grad_norm_parameter_groups(self):
|
||||||
groups = self.diff.get_grad_norm_parameter_groups()
|
groups = self.diff.get_grad_norm_parameter_groups()
|
||||||
groups['conditioning_encoder'] = list(self.conditioning_encoder.parameters())
|
groups['conditioning_encoder'] = list(self.conditioning_encoder.parameters())
|
||||||
return
|
return groups
|
||||||
|
|
||||||
def before_step(self, step):
|
def before_step(self, step):
|
||||||
scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \
|
scaled_grad_parameters = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.diff.layers])) + \
|
||||||
|
|
Loading…
Reference in New Issue
Block a user