forked from mrq/DL-Art-School
another debugging fix
This commit is contained in:
parent
6fc86bbbe7
commit
157d5d56c3
|
@ -177,9 +177,25 @@ class TransformerDiffusion(nn.Module):
|
|||
self.debug_codes = {}
|
||||
|
||||
def get_grad_norm_parameter_groups(self):
    """Group this module's parameters under descriptive names.

    Returns:
        dict: maps a group name to a list of the objects yielded by the
        corresponding sub-module's ``parameters()``. Per-layer groups
        (attention, feed-forward, block output, prenorm) are flattened
        across every entry of ``self.layers``; ``'attention_layers'`` and
        ``'ff_layers'`` are the block1 + block2 concatenations.

    NOTE(review): judging by the name, the caller presumably computes a
    gradient norm per group -- that consumer is not visible here.
    """
    def per_layer(select):
        # Flatten the parameters of one sub-module across all layers.
        return list(itertools.chain.from_iterable(
            select(lyr).parameters() for lyr in self.layers))

    attn1 = per_layer(lambda lyr: lyr.block1.attn)
    attn2 = per_layer(lambda lyr: lyr.block2.attn)
    ff1 = per_layer(lambda lyr: lyr.block1.ff)
    ff2 = per_layer(lambda lyr: lyr.block2.ff)
    blkout_layers = per_layer(lambda lyr: lyr.out)

    # Each key appears exactly once: a dict literal silently keeps only the
    # last value for a repeated key, so a second 'layers' entry would hide
    # the first. The input block is reported separately under 'x_proj'.
    groups = {
        'prenorms': per_layer(lambda lyr: lyr.prenorm),
        'blk1_attention_layers': attn1,
        'blk2_attention_layers': attn2,
        'attention_layers': attn1 + attn2,
        'blk1_ff_layers': ff1,
        'blk2_ff_layers': ff2,
        'ff_layers': ff1 + ff2,
        'block_out_layers': blkout_layers,
        'rotary_embeddings': list(self.rotary_embeddings.parameters()),
        'out': list(self.out.parameters()),
        'x_proj': list(self.inp_block.parameters()),
        'layers': list(self.layers.parameters()),
        # NOTE(review): the 'code_converters' group
        # (self.input_converter + self.code_converter) is intentionally
        # left disabled here; re-enable it if those parameters should be
        # tracked again.
        'time_embed': list(self.time_embed.parameters()),
    }
    return groups
|
||||
|
@ -605,6 +621,7 @@ def test_multi_vqvae_model():
|
|||
print_network(model)
|
||||
o = model(clip, ts, cond)
|
||||
pg = model.get_grad_norm_parameter_groups()
|
||||
model.diff.get_grad_norm_parameter_groups()
|
||||
|
||||
|
||||
def test_ar_model():
|
||||
|
|
Loading…
Reference in New Issue
Block a user