From 1dbe0b6b2e8108a1a984f86a88a1eec7a12ed3f6 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 26 May 2022 10:13:27 -0600 Subject: [PATCH] a --- codes/models/audio/music/transformer_diffusion3.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/codes/models/audio/music/transformer_diffusion3.py b/codes/models/audio/music/transformer_diffusion3.py index 6183e1ba..ea936984 100644 --- a/codes/models/audio/music/transformer_diffusion3.py +++ b/codes/models/audio/music/transformer_diffusion3.py @@ -151,9 +151,7 @@ class TransformerDiffusion(nn.Module): def get_grad_norm_parameter_groups(self): groups = { 'contextual_embedder': list(self.conditioning_embedder.parameters()), - 'top_layers': list(self.top_layers.parameters()) + list(self.inp_block.parameters()), - 'mid_layers': list(self.mid_layers.parameters()), - 'final_layers': list(self.final_layers.parameters()), + 'layers': list(self.layers.parameters()) + list(self.inp_block.parameters()), 'code_converters': list(self.embeddings.parameters()) + list(self.code_converter.parameters()) + list(self.latent_conditioner.parameters()), 'time_embed': list(self.time_embed.parameters()), }