forked from mrq/DL-Art-School
better downsampler
This commit is contained in:
parent
28d95e3141
commit
41abf776d9
|
@ -18,17 +18,17 @@ class UpperEncoder(nn.Module):
|
|||
super().__init__()
|
||||
attn = []
|
||||
def edim(m):
|
||||
dd = max(hidden_dim // m, 128, spec_dim)
|
||||
dd = min(spec_dim + m * 128, hidden_dim)
|
||||
return ceil_multiple(dd, 8)
|
||||
self.downsampler = nn.Sequential(
|
||||
ResBlock(spec_dim, out_channels=edim(6), use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(6), out_channels=edim(5), use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(5), out_channels=edim(4), use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(4), out_channels=edim(3), use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(3), out_channels=edim(3), use_conv=True, dims=1),
|
||||
ResBlock(edim(3), out_channels=edim(2), use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(2), out_channels=edim(2), use_conv=True, dims=1),
|
||||
ResBlock(edim(2), out_channels=hidden_dim, use_conv=True, dims=1, down=True))
|
||||
ResBlock(spec_dim, out_channels=edim(1), use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(1), out_channels=edim(2), use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(4), out_channels=edim(4), use_conv=True, dims=1),
|
||||
ResBlock(edim(4), out_channels=edim(5), use_conv=True, dims=1, down=True),
|
||||
ResBlock(edim(5), out_channels=edim(5), use_conv=True, dims=1),
|
||||
ResBlock(edim(5), out_channels=hidden_dim, use_conv=True, dims=1, down=True))
|
||||
self.encoder = nn.Sequential(
|
||||
AttentionBlock(hidden_dim, 4, do_activation=True),
|
||||
ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1),
|
||||
|
|
Loading…
Reference in New Issue
Block a user