Try to squeeze a bit more performance out of this arch

This commit is contained in:
James Betker 2022-07-20 23:51:11 -06:00
parent b9d0f7e6de
commit 767f963392

View File

@ -29,6 +29,8 @@ class SubBlock(nn.Module):
self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads) self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads)
self.register_buffer('mask', build_local_attention_mask(n=4000, l=64), persistent=False) self.register_buffer('mask', build_local_attention_mask(n=4000, l=64), persistent=False)
self.pos_bias = RelativeQKBias(l=64) self.pos_bias = RelativeQKBias(l=64)
# Feedforward is split into two groups: one with kernel_size=1, the other with kernel_size=3. The idea is that
# localized processing is valuable (and more efficient!)
ff_contract = contraction_dim//2 ff_contract = contraction_dim//2
self.ff1 = nn.Sequential(nn.Conv1d(inp_dim+contraction_dim, ff_contract, kernel_size=1), self.ff1 = nn.Sequential(nn.Conv1d(inp_dim+contraction_dim, ff_contract, kernel_size=1),
nn.GroupNorm(8, ff_contract), nn.GroupNorm(8, ff_contract),
@ -47,18 +49,20 @@ class SubBlock(nn.Module):
class ConcatAttentionBlock(TimestepBlock): class ConcatAttentionBlock(TimestepBlock):
def __init__(self, trunk_dim, contraction_dim, blk_dim, heads, dropout): def __init__(self, trunk_dim, contraction_dim, blk_in_dim, blk_proj_dim, heads, dropout):
super().__init__() super().__init__()
self.contraction_dim = contraction_dim self.contraction_dim = contraction_dim
self.prenorm = nn.GroupNorm(8, trunk_dim) self.prenorm = nn.GroupNorm(8, trunk_dim)
self.block1 = SubBlock(trunk_dim+blk_dim, contraction_dim, heads, dropout) self.block_proj = nn.Linear(blk_in_dim, blk_proj_dim)
self.block2 = SubBlock(trunk_dim+blk_dim+contraction_dim*2, contraction_dim, heads, dropout) self.block1 = SubBlock(trunk_dim+blk_proj_dim, contraction_dim, heads, dropout)
self.block2 = SubBlock(trunk_dim+blk_proj_dim+contraction_dim*2, contraction_dim, heads, dropout)
self.out = nn.Conv1d(contraction_dim*4, trunk_dim, kernel_size=1, bias=False) self.out = nn.Conv1d(contraction_dim*4, trunk_dim, kernel_size=1, bias=False)
self.out.weight.data.zero_() self.out.weight.data.zero_()
def forward(self, x, blk_emb): def forward(self, x, blk_emb):
h = self.prenorm(x) h = self.prenorm(x)
h = torch.cat([h, blk_emb.unsqueeze(-1).repeat(1,1,x.shape[-1])], dim=1) blk_enc = self.block_proj(blk_emb)
h = torch.cat([h, blk_enc.unsqueeze(-1).repeat(1,1,x.shape[-1])], dim=1)
h = self.block1(h) h = self.block1(h)
h = self.block2(h) h = self.block2(h)
h = self.out(h[:,-self.contraction_dim*4:]) h = self.out(h[:,-self.contraction_dim*4:])
@ -110,6 +114,7 @@ class TransformerDiffusion(nn.Module):
time_embed_dim=256, time_embed_dim=256,
time_proj_dim=64, time_proj_dim=64,
cond_proj_dim=256, cond_proj_dim=256,
blk_op_dim=128,
num_heads=4, num_heads=4,
dropout=0, dropout=0,
use_fp16=False, use_fp16=False,
@ -144,7 +149,7 @@ class TransformerDiffusion(nn.Module):
self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim)) self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim))
self.inp_block = conv_nd(1, in_channels+input_vec_dim, model_channels, 3, 1, 1) self.inp_block = conv_nd(1, in_channels+input_vec_dim, model_channels, 3, 1, 1)
self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_proj_dim*3 + cond_proj_dim, self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_proj_dim*3 + cond_proj_dim, blk_op_dim,
num_heads, dropout) for _ in range(num_layers)]) num_heads, dropout) for _ in range(num_layers)])
self.out = nn.Sequential( self.out = nn.Sequential(