From b9d0f7e6de51b96ae34a9d55b6078f30f8076ce7 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 20 Jul 2022 23:41:54 -0600 Subject: [PATCH] simplify parameterization a bit --- .../audio/music/transformer_diffusion13.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py index a78e3cac..83c14f59 100644 --- a/codes/models/audio/music/transformer_diffusion13.py +++ b/codes/models/audio/music/transformer_diffusion13.py @@ -29,20 +29,21 @@ class SubBlock(nn.Module): self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads) self.register_buffer('mask', build_local_attention_mask(n=4000, l=64), persistent=False) self.pos_bias = RelativeQKBias(l=64) - self.attn_glu = cGLU(contraction_dim) - self.attnorm = nn.GroupNorm(8, contraction_dim) - self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1) - self.ff_glu = cGLU(contraction_dim) - self.ffnorm = nn.GroupNorm(8, contraction_dim) + ff_contract = contraction_dim//2 + self.ff1 = nn.Sequential(nn.Conv1d(inp_dim+contraction_dim, ff_contract, kernel_size=1), + nn.GroupNorm(8, ff_contract), + cGLU(ff_contract)) + self.ff2 = nn.Sequential(nn.Conv1d(inp_dim+contraction_dim*3//2, ff_contract, kernel_size=3, padding=1), + nn.GroupNorm(8, ff_contract), + cGLU(ff_contract)) def forward(self, x): ah = self.dropout(self.attn(x, mask=self.mask, qk_bias=self.pos_bias(x.shape[-1]))) - ah = self.attn_glu(self.attnorm(ah)) h = torch.cat([ah, x], dim=1) - hf = self.dropout(checkpoint(self.ff, h)) - hf = self.ff_glu(self.ffnorm(hf)) + hf = self.dropout(checkpoint(self.ff1, h)) h = torch.cat([h, hf], dim=1) - return h + hf = self.dropout(checkpoint(self.ff2, h)) + return torch.cat([h, hf], dim=1) class ConcatAttentionBlock(TimestepBlock): @@ -157,8 +158,10 @@ class TransformerDiffusion(nn.Module): def get_grad_norm_parameter_groups(self): attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers])) attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers])) - ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.layers])) - ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.layers])) + ff1 = list(itertools.chain.from_iterable([lyr.block1.ff1.parameters() for lyr in self.layers] + + [lyr.block1.ff2.parameters() for lyr in self.layers])) + ff2 = list(itertools.chain.from_iterable([lyr.block2.ff1.parameters() for lyr in self.layers] + + [lyr.block2.ff2.parameters() for lyr in self.layers])) blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers])) groups = { 'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])), @@ -295,6 +298,7 @@ def test_tfd(): model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512, num_heads=512//64, input_vec_dim=256, num_layers=12, dropout=.1, unconditioned_percentage=.6) + model.get_grad_norm_parameter_groups() for k in range(100): x = model.input_to_random_resolution_and_window(clip, ts, diffuser) model(x, ts, conditioning_input=cond)