|
|
|
@ -540,6 +540,7 @@ class TransformerDiffusionWithCheaterLatent(nn.Module):
|
|
|
|
|
|
|
|
|
|
def get_debug_values(self, step, __):
|
|
|
|
|
self.internal_step = step
|
|
|
|
|
return {}
|
|
|
|
|
|
|
|
|
|
def get_grad_norm_parameter_groups(self):
|
|
|
|
|
attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.diff.layers]))
|
|
|
|
|