simplify parameterization a bit

James Betker 2022-07-20 23:41:54 -06:00
parent ee8ceed6da
commit b9d0f7e6de


@@ -29,20 +29,21 @@ class SubBlock(nn.Module):
         self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads)
         self.register_buffer('mask', build_local_attention_mask(n=4000, l=64), persistent=False)
         self.pos_bias = RelativeQKBias(l=64)
         self.attn_glu = cGLU(contraction_dim)
         self.attnorm = nn.GroupNorm(8, contraction_dim)
-        self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1)
-        self.ff_glu = cGLU(contraction_dim)
-        self.ffnorm = nn.GroupNorm(8, contraction_dim)
+        ff_contract = contraction_dim//2
+        self.ff1 = nn.Sequential(nn.Conv1d(inp_dim+contraction_dim, ff_contract, kernel_size=1),
+                                 nn.GroupNorm(8, ff_contract),
+                                 cGLU(ff_contract))
+        self.ff2 = nn.Sequential(nn.Conv1d(inp_dim+contraction_dim*3//2, ff_contract, kernel_size=3, padding=1),
+                                 nn.GroupNorm(8, ff_contract),
+                                 cGLU(ff_contract))
 
     def forward(self, x):
         ah = self.dropout(self.attn(x, mask=self.mask, qk_bias=self.pos_bias(x.shape[-1])))
         ah = self.attn_glu(self.attnorm(ah))
         h = torch.cat([ah, x], dim=1)
-        hf = self.dropout(checkpoint(self.ff, h))
-        hf = self.ff_glu(self.ffnorm(hf))
+        hf = self.dropout(checkpoint(self.ff1, h))
         h = torch.cat([h, hf], dim=1)
-        return h
+        hf = self.dropout(checkpoint(self.ff2, h))
+        return torch.cat([h, hf], dim=1)
 
 
 class ConcatAttentionBlock(TimestepBlock):
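
For orientation, a small sketch of the channel bookkeeping implied by the rewritten forward pass (plain Python, not repo code; the example sizes are taken from test_tfd below, and the actual inp_dim passed by ConcatAttentionBlock may differ). Note the block still emits inp_dim + 2*contraction_dim channels, just as before the change:

    # Sketch: channel counts through the new SubBlock.forward (assumed example sizes).
    inp_dim, contraction_dim = 1024, 512
    ff_contract = contraction_dim // 2          # 256
    ah_ch = contraction_dim                     # attention output: 512 channels
    h_ch = ah_ch + inp_dim                      # cat([ah, x]): 1536 == ff1 in_channels
    h_ch += ff_contract                         # cat([h, hf]): 1792 == inp_dim + contraction_dim*3//2 == ff2 in_channels
    out_ch = h_ch + ff_contract                 # final cat:    2048
    assert out_ch == inp_dim + 2 * contraction_dim   # same output width as the old single-ff block
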
@@ -157,8 +158,10 @@ class TransformerDiffusion(nn.Module):
     def get_grad_norm_parameter_groups(self):
         attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers]))
         attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers]))
-        ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.layers]))
-        ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.layers]))
+        ff1 = list(itertools.chain.from_iterable([lyr.block1.ff1.parameters() for lyr in self.layers] +
+                                                 [lyr.block1.ff2.parameters() for lyr in self.layers]))
+        ff2 = list(itertools.chain.from_iterable([lyr.block2.ff1.parameters() for lyr in self.layers] +
+                                                 [lyr.block2.ff2.parameters() for lyr in self.layers]))
         blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers]))
         groups = {
             'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])),
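
The grouping change above folds each block's ff1 and ff2 parameters into one flat list for gradient-norm logging. A minimal, self-contained sketch of that flattening pattern, using toy modules with hypothetical names rather than repo code:

    import itertools
    import torch.nn as nn

    class ToyBlock(nn.Module):
        # Stand-in carrying the same attribute names the grouping code relies on.
        def __init__(self):
            super().__init__()
            self.ff1 = nn.Conv1d(4, 4, kernel_size=1)
            self.ff2 = nn.Conv1d(4, 4, kernel_size=3, padding=1)

    layers = [ToyBlock() for _ in range(3)]

    # Concatenate the two list comprehensions (each a list of parameter iterators),
    # then chain them into a single flat list of tensors.
    ff = list(itertools.chain.from_iterable(
        [lyr.ff1.parameters() for lyr in layers] +
        [lyr.ff2.parameters() for lyr in layers]))
    assert len(ff) == sum(1 for lyr in layers for _ in lyr.parameters())  # 12 tensors here
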
@@ -295,6 +298,7 @@ def test_tfd():
     model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
                                  num_heads=512//64, input_vec_dim=256, num_layers=12, dropout=.1,
                                  unconditioned_percentage=.6)
+    model.get_grad_norm_parameter_groups()
     for k in range(100):
         x = model.input_to_random_resolution_and_window(clip, ts, diffuser)
         model(x, ts, conditioning_input=cond)
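
The added call in test_tfd exercises the grouping code once before the loop, so a stale reference to the removed .ff attribute would fail immediately. This diff does not show how the groups are consumed; below is a hedged sketch of one plausible use, computing a gradient norm per named group, under the assumption that the method returns the groups dict it builds:

    import torch

    def grad_norms_per_group(groups):
        # groups: dict mapping a group name to a list of parameters, e.g. as built
        # in get_grad_norm_parameter_groups(); parameters without gradients are skipped.
        norms = {}
        for name, params in groups.items():
            grads = [p.grad.flatten() for p in params if p.grad is not None]
            norms[name] = torch.cat(grads).norm().item() if grads else 0.0
        return norms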