simplify parameterization a bit
This commit is contained in:
parent
ee8ceed6da
commit
b9d0f7e6de
|
@ -29,20 +29,21 @@ class SubBlock(nn.Module):
|
|||
self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads)
|
||||
self.register_buffer('mask', build_local_attention_mask(n=4000, l=64), persistent=False)
|
||||
self.pos_bias = RelativeQKBias(l=64)
|
||||
self.attn_glu = cGLU(contraction_dim)
|
||||
self.attnorm = nn.GroupNorm(8, contraction_dim)
|
||||
self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1)
|
||||
self.ff_glu = cGLU(contraction_dim)
|
||||
self.ffnorm = nn.GroupNorm(8, contraction_dim)
|
||||
ff_contract = contraction_dim//2
|
||||
self.ff1 = nn.Sequential(nn.Conv1d(inp_dim+contraction_dim, ff_contract, kernel_size=1),
|
||||
nn.GroupNorm(8, ff_contract),
|
||||
cGLU(ff_contract))
|
||||
self.ff2 = nn.Sequential(nn.Conv1d(inp_dim+contraction_dim*3//2, ff_contract, kernel_size=3, padding=1),
|
||||
nn.GroupNorm(8, ff_contract),
|
||||
cGLU(ff_contract))
|
||||
|
||||
def forward(self, x):
|
||||
ah = self.dropout(self.attn(x, mask=self.mask, qk_bias=self.pos_bias(x.shape[-1])))
|
||||
ah = self.attn_glu(self.attnorm(ah))
|
||||
h = torch.cat([ah, x], dim=1)
|
||||
hf = self.dropout(checkpoint(self.ff, h))
|
||||
hf = self.ff_glu(self.ffnorm(hf))
|
||||
hf = self.dropout(checkpoint(self.ff1, h))
|
||||
h = torch.cat([h, hf], dim=1)
|
||||
return h
|
||||
hf = self.dropout(checkpoint(self.ff2, h))
|
||||
return torch.cat([h, hf], dim=1)
|
||||
|
||||
|
||||
class ConcatAttentionBlock(TimestepBlock):
|
||||
|
@ -157,8 +158,10 @@ class TransformerDiffusion(nn.Module):
|
|||
def get_grad_norm_parameter_groups(self):
|
||||
attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers]))
|
||||
attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers]))
|
||||
ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.layers]))
|
||||
ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.layers]))
|
||||
ff1 = list(itertools.chain.from_iterable([lyr.block1.ff1.parameters() for lyr in self.layers] +
|
||||
[lyr.block1.ff2.parameters() for lyr in self.layers]))
|
||||
ff2 = list(itertools.chain.from_iterable([lyr.block2.ff1.parameters() for lyr in self.layers] +
|
||||
[lyr.block2.ff2.parameters() for lyr in self.layers]))
|
||||
blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers]))
|
||||
groups = {
|
||||
'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])),
|
||||
|
@ -295,6 +298,7 @@ def test_tfd():
|
|||
model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
|
||||
num_heads=512//64, input_vec_dim=256, num_layers=12, dropout=.1,
|
||||
unconditioned_percentage=.6)
|
||||
model.get_grad_norm_parameter_groups()
|
||||
for k in range(100):
|
||||
x = model.input_to_random_resolution_and_window(clip, ts, diffuser)
|
||||
model(x, ts, conditioning_input=cond)
|
||||
|
|
Loading…
Reference in New Issue
Block a user