Try to squeeze a bit more performance out of this arch
This commit is contained in:
parent
b9d0f7e6de
commit
767f963392
|
@ -29,6 +29,8 @@ class SubBlock(nn.Module):
|
|||
self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads)
|
||||
self.register_buffer('mask', build_local_attention_mask(n=4000, l=64), persistent=False)
|
||||
self.pos_bias = RelativeQKBias(l=64)
|
||||
# Feedforward is split into two groups: one with kernel_size=1, the other with kernel_size=3. The idea is that
|
||||
# localized processing is valuable (and more efficient!)
|
||||
ff_contract = contraction_dim//2
|
||||
self.ff1 = nn.Sequential(nn.Conv1d(inp_dim+contraction_dim, ff_contract, kernel_size=1),
|
||||
nn.GroupNorm(8, ff_contract),
|
||||
|
@ -47,18 +49,20 @@ class SubBlock(nn.Module):
|
|||
|
||||
|
||||
class ConcatAttentionBlock(TimestepBlock):
|
||||
def __init__(self, trunk_dim, contraction_dim, blk_dim, heads, dropout):
|
||||
def __init__(self, trunk_dim, contraction_dim, blk_in_dim, blk_proj_dim, heads, dropout):
|
||||
super().__init__()
|
||||
self.contraction_dim = contraction_dim
|
||||
self.prenorm = nn.GroupNorm(8, trunk_dim)
|
||||
self.block1 = SubBlock(trunk_dim+blk_dim, contraction_dim, heads, dropout)
|
||||
self.block2 = SubBlock(trunk_dim+blk_dim+contraction_dim*2, contraction_dim, heads, dropout)
|
||||
self.block_proj = nn.Linear(blk_in_dim, blk_proj_dim)
|
||||
self.block1 = SubBlock(trunk_dim+blk_proj_dim, contraction_dim, heads, dropout)
|
||||
self.block2 = SubBlock(trunk_dim+blk_proj_dim+contraction_dim*2, contraction_dim, heads, dropout)
|
||||
self.out = nn.Conv1d(contraction_dim*4, trunk_dim, kernel_size=1, bias=False)
|
||||
self.out.weight.data.zero_()
|
||||
|
||||
def forward(self, x, blk_emb):
|
||||
h = self.prenorm(x)
|
||||
h = torch.cat([h, blk_emb.unsqueeze(-1).repeat(1,1,x.shape[-1])], dim=1)
|
||||
blk_enc = self.block_proj(blk_emb)
|
||||
h = torch.cat([h, blk_enc.unsqueeze(-1).repeat(1,1,x.shape[-1])], dim=1)
|
||||
h = self.block1(h)
|
||||
h = self.block2(h)
|
||||
h = self.out(h[:,-self.contraction_dim*4:])
|
||||
|
@ -110,6 +114,7 @@ class TransformerDiffusion(nn.Module):
|
|||
time_embed_dim=256,
|
||||
time_proj_dim=64,
|
||||
cond_proj_dim=256,
|
||||
blk_op_dim=128,
|
||||
num_heads=4,
|
||||
dropout=0,
|
||||
use_fp16=False,
|
||||
|
@ -144,7 +149,7 @@ class TransformerDiffusion(nn.Module):
|
|||
self.unconditioned_embedding = nn.Parameter(torch.randn(1,cond_proj_dim))
|
||||
|
||||
self.inp_block = conv_nd(1, in_channels+input_vec_dim, model_channels, 3, 1, 1)
|
||||
self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_proj_dim*3 + cond_proj_dim,
|
||||
self.layers = TimestepEmbedSequential(*[ConcatAttentionBlock(model_channels, contraction_dim, time_proj_dim*3 + cond_proj_dim, blk_op_dim,
|
||||
num_heads, dropout) for _ in range(num_layers)])
|
||||
|
||||
self.out = nn.Sequential(
|
||||
|
|
Loading…
Reference in New Issue
Block a user