This commit is contained in:
James Betker 2022-06-17 09:40:11 -06:00
parent 7ca532c7cc
commit c000e489fa

View File

@ -540,6 +540,7 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
def get_debug_values(self, step, __): def get_debug_values(self, step, __):
self.internal_step = step self.internal_step = step
return {}
def get_grad_norm_parameter_groups(self): def get_grad_norm_parameter_groups(self):
attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.diff.layers])) attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.diff.layers]))