diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py index 83c14f59..3500d627 100644 --- a/codes/models/audio/music/transformer_diffusion13.py +++ b/codes/models/audio/music/transformer_diffusion13.py @@ -29,6 +29,8 @@ class SubBlock(nn.Module): self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads) self.register_buffer('mask', build_local_attention_mask(n=4000, l=64), persistent=False) self.pos_bias = RelativeQKBias(l=64) + # Feedforward is split into two groups: one with kernel_size=1, the other with kernel_size=3. The idea is that + # localized processing is valuable (and more efficient!) ff_contract = contraction_dim//2 self.ff1 = nn.Sequential(nn.Conv1d(inp_dim+contraction_dim, ff_contract, kernel_size=1), nn.GroupNorm(8, ff_contract), @@ -47,18 +49,20 @@ class SubBlock(nn.Module): class ConcatAttentionBlock(TimestepBlock): - def __init__(self, trunk_dim, contraction_dim, blk_dim, heads, dropout): + def __init__(self, trunk_dim, contraction_dim, blk_in_dim, blk_proj_dim, heads, dropout): super().__init__() self.contraction_dim = contraction_dim self.prenorm = nn.GroupNorm(8, trunk_dim) - self.block1 = SubBlock(trunk_dim+blk_dim, contraction_dim, heads, dropout) - self.block2 = SubBlock(trunk_dim+blk_dim+contraction_dim*2, contraction_dim, heads, dropout) + self.block_proj = nn.Linear(blk_in_dim, blk_proj_dim) + self.block1 = SubBlock(trunk_dim+blk_proj_dim, contraction_dim, heads, dropout) + self.block2 = SubBlock(trunk_dim+blk_proj_dim+contraction_dim*2, contraction_dim, heads, dropout) self.out = nn.Conv1d(contraction_dim*4, trunk_dim, kernel_size=1, bias=False) self.out.weight.data.zero_() def forward(self, x, blk_emb): h = self.prenorm(x) - h = torch.cat([h, blk_emb.unsqueeze(-1).repeat(1,1,x.shape[-1])], dim=1) + blk_enc = self.block_proj(blk_emb) + h = torch.cat([h, blk_enc.unsqueeze(-1).repeat(1,1,x.shape[-1])], dim=1) h = self.block1(h) h = self.block2(h) h = self.out(h[:,-self.contraction_dim*4:]) @@ -110,6 +114,7 @@ class TransformerDiffusion(nn.Module): time_embed_dim=256, time_proj_dim=64, cond_proj_dim=256, + blk_op_dim=128, num_heads=4, dropout=0, use_fp16=False, @@ -144,7 +149,7 @@ class TransformerDiffusion(nn.Module): self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim)) self.inp_block = conv_nd(1, in_channels+input_vec_dim, model_channels, 3, 1, 1) - self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_proj_dim*3 + cond_proj_dim, + self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_proj_dim*3 + cond_proj_dim, blk_op_dim, num_heads, dropout) for _ in range(num_layers)]) self.out = nn.Sequential(