another debugging fix

pull/9/head
James Betker 2022-06-15 09:19:34 +07:00
parent 6fc86bbbe7
commit 157d5d56c3
1 changed files with 19 additions and 2 deletions

@ -177,9 +177,25 @@ class TransformerDiffusion(nn.Module):
self.debug_codes = {}
def get_grad_norm_parameter_groups(self):
attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers]))
attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers]))
ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.layers]))
ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.layers]))
blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers]))
groups = {
'layers': list(self.layers.parameters()) + list(self.inp_block.parameters()),
'code_converters': list(self.input_converter.parameters()) + list(self.code_converter.parameters()),
'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])),
'blk1_attention_layers': attn1,
'blk2_attention_layers': attn2,
'attention_layers': attn1 + attn2,
'blk1_ff_layers': ff1,
'blk2_ff_layers': ff2,
'ff_layers': ff1 + ff2,
'block_out_layers': blkout_layers,
'rotary_embeddings': list(self.rotary_embeddings.parameters()),
'out': list(self.out.parameters()),
'x_proj': list(self.inp_block.parameters()),
'layers': list(self.layers.parameters()),
#'code_converters': list(self.input_converter.parameters()) + list(self.code_converter.parameters()),
'time_embed': list(self.time_embed.parameters()),
}
return groups
@ -605,6 +621,7 @@ def test_multi_vqvae_model():
print_network(model)
o = model(clip, ts, cond)
pg = model.get_grad_norm_parameter_groups()
model.diff.get_grad_norm_parameter_groups()
def test_ar_model():