From f9ebcf11d8e121603f724aed9b9795bb01bc37e3 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 5 Jun 2022 01:31:37 -0600 Subject: [PATCH] fix2 --- codes/models/audio/music/transformer_diffusion8.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/codes/models/audio/music/transformer_diffusion8.py b/codes/models/audio/music/transformer_diffusion8.py index a3c22551..9a9ae82b 100644 --- a/codes/models/audio/music/transformer_diffusion8.py +++ b/codes/models/audio/music/transformer_diffusion8.py @@ -1,3 +1,5 @@ +import itertools + import torch import torch.nn as nn import torch.nn.functional as F @@ -224,8 +226,8 @@ class TransformerDiffusionWithQuantizer(nn.Module): def get_grad_norm_parameter_groups(self): groups = { - 'attention_layers': [lyr.attn.parameters() for lyr in self.diff.layers], - 'ff_layers': [lyr.ff.parameters() for lyr in self.diff.layers], + 'attention_layers': list(itertools.chain.from_iterable([lyr.attn.parameters() for lyr in self.diff.layers])), + 'ff_layers': list(itertools.chain.from_iterable([lyr.ff.parameters() for lyr in self.diff.layers])), 'quantizer_encoder': list(self.quantizer.encoder.parameters()), 'quant_codebook': [self.quantizer.quantizer.codevectors], 'rotary_embeddings': list(self.diff.rotary_embeddings.parameters()),