forked from mrq/DL-Art-School
Mods to dvae
- Add resblock to each layer - Increase filter size for each layer - Use SiLU
This commit is contained in:
parent
b8bec22f1a
commit
a826d5f658
|
@ -31,9 +31,9 @@ class ResBlock(nn.Module):
|
|||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
conv(chan, chan, 3, padding = 1),
|
||||
nn.ReLU(),
|
||||
nn.SiLU(),
|
||||
conv(chan, chan, 3, padding = 1),
|
||||
nn.ReLU(),
|
||||
nn.SiLU(),
|
||||
conv(chan, chan, 1)
|
||||
)
|
||||
|
||||
|
@ -74,7 +74,7 @@ class DiscreteVAE(nn.Module):
|
|||
conv = nn.Conv1d
|
||||
conv_transpose = nn.ConvTranspose1d
|
||||
|
||||
enc_chans = [hidden_dim] * num_layers
|
||||
enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)]
|
||||
dec_chans = list(reversed(enc_chans))
|
||||
|
||||
enc_chans = [channels, *enc_chans]
|
||||
|
@ -88,12 +88,14 @@ class DiscreteVAE(nn.Module):
|
|||
dec_layers = []
|
||||
|
||||
for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
|
||||
enc_layers.append(nn.Sequential(conv(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU()))
|
||||
dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU()))
|
||||
for _ in range(num_resnet_blocks):
|
||||
dec_layers.append(ResBlock(dec_in, conv))
|
||||
|
||||
for _ in range(num_resnet_blocks):
|
||||
dec_layers.insert(0, ResBlock(dec_chans[1], conv))
|
||||
enc_layers.append(ResBlock(enc_chans[-1], conv))
|
||||
enc_layers.append(nn.Sequential(conv(enc_in, enc_out, 4, stride = 2, padding = 1), nn.SiLU()))
|
||||
dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, 4, stride = 2, padding = 1), nn.SiLU()))
|
||||
|
||||
for _ in range(num_resnet_blocks):
|
||||
enc_layers.append(ResBlock(enc_out, conv))
|
||||
|
||||
if num_resnet_blocks > 0:
|
||||
dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1))
|
||||
|
@ -202,7 +204,7 @@ if __name__ == '__main__':
|
|||
#v = DiscreteVAE()
|
||||
#o=v(torch.randn(1,3,256,256))
|
||||
#print(o.shape)
|
||||
v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256)
|
||||
v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256, num_resnet_blocks=2)
|
||||
v.eval()
|
||||
o=v(torch.randn(1,1,256))
|
||||
print(o[-1].shape)
|
||||
|
|
Loading…
Reference in New Issue
Block a user