From 157d5d56c3722790c83de66f9c4643e7703b9dbd Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 15 Jun 2022 09:19:34 -0600 Subject: [PATCH] another debugging fix --- .../audio/music/transformer_diffusion12.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py index 372b7fc3..ec39c921 100644 --- a/codes/models/audio/music/transformer_diffusion12.py +++ b/codes/models/audio/music/transformer_diffusion12.py @@ -177,9 +177,25 @@ class TransformerDiffusion(nn.Module): self.debug_codes = {} def get_grad_norm_parameter_groups(self): + attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers])) + attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers])) + ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.layers])) + ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.layers])) + blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers])) groups = { - 'layers': list(self.layers.parameters()) + list(self.inp_block.parameters()), - 'code_converters': list(self.input_converter.parameters()) + list(self.code_converter.parameters()), + 'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])), + 'blk1_attention_layers': attn1, + 'blk2_attention_layers': attn2, + 'attention_layers': attn1 + attn2, + 'blk1_ff_layers': ff1, + 'blk2_ff_layers': ff2, + 'ff_layers': ff1 + ff2, + 'block_out_layers': blkout_layers, + 'rotary_embeddings': list(self.rotary_embeddings.parameters()), + 'out': list(self.out.parameters()), + 'x_proj': list(self.inp_block.parameters()), + 'layers': list(self.layers.parameters()), + #'code_converters': list(self.input_converter.parameters()) + list(self.code_converter.parameters()), 'time_embed': list(self.time_embed.parameters()), } return groups @@ -605,6 +621,7 @@ def test_multi_vqvae_model(): print_network(model) o = model(clip, ts, cond) pg = model.get_grad_norm_parameter_groups() + model.diff.get_grad_norm_parameter_groups() def test_ar_model():