Revert mods to lrdvae

They didn't really change anything
This commit is contained in:
James Betker 2021-08-17 09:09:29 -06:00
parent 8332923f5c
commit 17453ccbe8

View File

@ -31,9 +31,9 @@ class ResBlock(nn.Module):
super().__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
conv(chan, chan, 3, padding = 1), conv(chan, chan, 3, padding = 1),
nn.SiLU(), nn.ReLU(),
conv(chan, chan, 3, padding = 1), conv(chan, chan, 3, padding = 1),
nn.SiLU(), nn.ReLU(),
conv(chan, chan, 1) conv(chan, chan, 1)
) )
@ -88,14 +88,12 @@ class DiscreteVAE(nn.Module):
dec_layers = [] dec_layers = []
for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io): for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
for _ in range(num_resnet_blocks): enc_layers.append(nn.Sequential(conv(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU()))
dec_layers.append(ResBlock(dec_in, conv)) dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU()))
enc_layers.append(nn.Sequential(conv(enc_in, enc_out, 4, stride = 2, padding = 1), nn.SiLU())) for _ in range(num_resnet_blocks):
dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, 4, stride = 2, padding = 1), nn.SiLU())) dec_layers.insert(0, ResBlock(dec_chans[1], conv))
enc_layers.append(ResBlock(enc_chans[-1], conv))
for _ in range(num_resnet_blocks):
enc_layers.append(ResBlock(enc_out, conv))
if num_resnet_blocks > 0: if num_resnet_blocks > 0:
dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1)) dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1))
@ -204,7 +202,7 @@ if __name__ == '__main__':
#v = DiscreteVAE() #v = DiscreteVAE()
#o=v(torch.randn(1,3,256,256)) #o=v(torch.randn(1,3,256,256))
#print(o.shape) #print(o.shape)
v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256, num_resnet_blocks=2) v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256)
v.eval() v.eval()
o=v(torch.randn(1,1,256)) o=v(torch.randn(1,1,256))
print(o[-1].shape) print(o[-1].shape)