Try to squeeze a bit more performance out of this arch

This commit is contained in:
James Betker 2022-07-20 23:51:11 -06:00
parent b9d0f7e6de
commit 767f963392

View File

@ -29,6 +29,8 @@ class SubBlock(nn.Module):
self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads)
self.register_buffer('mask', build_local_attention_mask(n=4000, l=64), persistent=False)
self.pos_bias = RelativeQKBias(l=64)
# Feedforward is split into two groups: one with kernel_size=1, the other with kernel_size=3. The idea is that
# localized processing is valuable (and more efficient!)
ff_contract = contraction_dim//2
self.ff1 = nn.Sequential(nn.Conv1d(inp_dim+contraction_dim, ff_contract, kernel_size=1),
nn.GroupNorm(8, ff_contract),
@ -47,18 +49,20 @@ class SubBlock(nn.Module):
class ConcatAttentionBlock(TimestepBlock):
def __init__(self, trunk_dim, contraction_dim, blk_dim, heads, dropout):
def __init__(self, trunk_dim, contraction_dim, blk_in_dim, blk_proj_dim, heads, dropout):
super().__init__()
self.contraction_dim = contraction_dim
self.prenorm = nn.GroupNorm(8, trunk_dim)
self.block1 = SubBlock(trunk_dim+blk_dim, contraction_dim, heads, dropout)
self.block2 = SubBlock(trunk_dim+blk_dim+contraction_dim*2, contraction_dim, heads, dropout)
self.block_proj = nn.Linear(blk_in_dim, blk_proj_dim)
self.block1 = SubBlock(trunk_dim+blk_proj_dim, contraction_dim, heads, dropout)
self.block2 = SubBlock(trunk_dim+blk_proj_dim+contraction_dim*2, contraction_dim, heads, dropout)
self.out = nn.Conv1d(contraction_dim*4, trunk_dim, kernel_size=1, bias=False)
self.out.weight.data.zero_()
def forward(self, x, blk_emb):
h = self.prenorm(x)
h = torch.cat([h, blk_emb.unsqueeze(-1).repeat(1,1,x.shape[-1])], dim=1)
blk_enc = self.block_proj(blk_emb)
h = torch.cat([h, blk_enc.unsqueeze(-1).repeat(1,1,x.shape[-1])], dim=1)
h = self.block1(h)
h = self.block2(h)
h = self.out(h[:,-self.contraction_dim*4:])
@ -110,6 +114,7 @@ class TransformerDiffusion(nn.Module):
time_embed_dim=256,
time_proj_dim=64,
cond_proj_dim=256,
blk_op_dim=128,
num_heads=4,
dropout=0,
use_fp16=False,
@ -144,7 +149,7 @@ class TransformerDiffusion(nn.Module):
self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim))
self.inp_block = conv_nd(1, in_channels+input_vec_dim, model_channels, 3, 1, 1)
self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_proj_dim*3 + cond_proj_dim,
self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_proj_dim*3 + cond_proj_dim, blk_op_dim,
num_heads, dropout) for _ in range(num_layers)])
self.out = nn.Sequential(