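Simplify SubBlock parameterization: replace the separate GLU/norm layers and the single feed-forward conv with two Conv1d -> GroupNorm -> cGLU stacks (ff1, ff2)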
This commit is contained in:
parent
ee8ceed6da
commit
b9d0f7e6de
|
@ -29,20 +29,21 @@ class SubBlock(nn.Module):
|
||||||
self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads)
|
self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads)
|
||||||
self.register_buffer('mask', build_local_attention_mask(n=4000, l=64), persistent=False)
|
self.register_buffer('mask', build_local_attention_mask(n=4000, l=64), persistent=False)
|
||||||
self.pos_bias = RelativeQKBias(l=64)
|
self.pos_bias = RelativeQKBias(l=64)
|
||||||
self.attn_glu = cGLU(contraction_dim)
|
ff_contract = contraction_dim//2
|
||||||
self.attnorm = nn.GroupNorm(8, contraction_dim)
|
self.ff1 = nn.Sequential(nn.Conv1d(inp_dim+contraction_dim, ff_contract, kernel_size=1),
|
||||||
self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1)
|
nn.GroupNorm(8, ff_contract),
|
||||||
self.ff_glu = cGLU(contraction_dim)
|
cGLU(ff_contract))
|
||||||
self.ffnorm = nn.GroupNorm(8, contraction_dim)
|
self.ff2 = nn.Sequential(nn.Conv1d(inp_dim+contraction_dim*3//2, ff_contract, kernel_size=3, padding=1),
|
||||||
|
nn.GroupNorm(8, ff_contract),
|
||||||
|
cGLU(ff_contract))
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
ah = self.dropout(self.attn(x, mask=self.mask, qk_bias=self.pos_bias(x.shape[-1])))
|
ah = self.dropout(self.attn(x, mask=self.mask, qk_bias=self.pos_bias(x.shape[-1])))
|
||||||
ah = self.attn_glu(self.attnorm(ah))
|
|
||||||
h = torch.cat([ah, x], dim=1)
|
h = torch.cat([ah, x], dim=1)
|
||||||
hf = self.dropout(checkpoint(self.ff, h))
|
hf = self.dropout(checkpoint(self.ff1, h))
|
||||||
hf = self.ff_glu(self.ffnorm(hf))
|
|
||||||
h = torch.cat([h, hf], dim=1)
|
h = torch.cat([h, hf], dim=1)
|
||||||
return h
|
hf = self.dropout(checkpoint(self.ff2, h))
|
||||||
|
return torch.cat([h, hf], dim=1)
|
||||||
|
|
||||||
|
|
||||||
class ConcatAttentionBlock(TimestepBlock):
|
class ConcatAttentionBlock(TimestepBlock):
|
||||||
|
@ -157,8 +158,10 @@ class TransformerDiffusion(nn.Module):
|
||||||
def get_grad_norm_parameter_groups(self):
|
def get_grad_norm_parameter_groups(self):
|
||||||
attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers]))
|
attn1 = list(itertools.chain.from_iterable([lyr.block1.attn.parameters() for lyr in self.layers]))
|
||||||
attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers]))
|
attn2 = list(itertools.chain.from_iterable([lyr.block2.attn.parameters() for lyr in self.layers]))
|
||||||
ff1 = list(itertools.chain.from_iterable([lyr.block1.ff.parameters() for lyr in self.layers]))
|
ff1 = list(itertools.chain.from_iterable([lyr.block1.ff1.parameters() for lyr in self.layers] +
|
||||||
ff2 = list(itertools.chain.from_iterable([lyr.block2.ff.parameters() for lyr in self.layers]))
|
[lyr.block1.ff2.parameters() for lyr in self.layers]))
|
||||||
|
ff2 = list(itertools.chain.from_iterable([lyr.block2.ff1.parameters() for lyr in self.layers] +
|
||||||
|
[lyr.block2.ff2.parameters() for lyr in self.layers]))
|
||||||
blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers]))
|
blkout_layers = list(itertools.chain.from_iterable([lyr.out.parameters() for lyr in self.layers]))
|
||||||
groups = {
|
groups = {
|
||||||
'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])),
|
'prenorms': list(itertools.chain.from_iterable([lyr.prenorm.parameters() for lyr in self.layers])),
|
||||||
|
@ -295,6 +298,7 @@ def test_tfd():
|
||||||
model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
|
model = TransformerDiffusion(in_channels=256, model_channels=1024, contraction_dim=512,
|
||||||
num_heads=512//64, input_vec_dim=256, num_layers=12, dropout=.1,
|
num_heads=512//64, input_vec_dim=256, num_layers=12, dropout=.1,
|
||||||
unconditioned_percentage=.6)
|
unconditioned_percentage=.6)
|
||||||
|
model.get_grad_norm_parameter_groups()
|
||||||
for k in range(100):
|
for k in range(100):
|
||||||
x = model.input_to_random_resolution_and_window(clip, ts, diffuser)
|
x = model.input_to_random_resolution_and_window(clip, ts, diffuser)
|
||||||
model(x, ts, conditioning_input=cond)
|
model(x, ts, conditioning_input=cond)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user