diff --git a/codes/models/vqvae/vqvae_no_conv_transpose.py b/codes/models/vqvae/vqvae_no_conv_transpose.py
index d1418d1d..1d28a82d 100644
--- a/codes/models/vqvae/vqvae_no_conv_transpose.py
+++ b/codes/models/vqvae/vqvae_no_conv_transpose.py
@@ -14,11 +14,11 @@
 # ============================================================================
 
 
-# Borrowed from https://github.com/rosinality/vq-vae-2-pytorch
-# Which was itself orrowed from https://github.com/deepmind/sonnet
+# This is an alternative implementation of VQVAE that uses convolutions with kernels of size 5 and
+# a "standard" upsampler rather than ConvTranspose2d.
 
 import torch
 from torch import nn
 from torch.nn import functional as F
 
@@ -28,6 +28,17 @@
 from trainer.networks import register_model
 from utils.util import checkpoint, opt_get
 
+# Upsamples via interpolation followed by a conv (similar to StyleGAN). Replaces ConvTranspose2d from the original paper.
+class UpsampleConv(nn.Module):
+    def __init__(self, in_filters, out_filters, kernel_size, padding):
+        super().__init__()
+        self.conv = nn.Conv2d(in_filters, out_filters, kernel_size, padding=padding)
+
+    def forward(self, x):
+        up = torch.nn.functional.interpolate(x, scale_factor=2)
+        return self.conv(up)
+
+
 class Quantize(nn.Module):
     def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
         super().__init__()
@@ -106,16 +117,16 @@ class Encoder(nn.Module):
 
         if stride == 4:
             blocks = [
-                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
+                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                 nn.ReLU(inplace=True),
-                nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
+                nn.Conv2d(channel // 2, channel, 5, stride=2, padding=2),
                 nn.ReLU(inplace=True),
                 nn.Conv2d(channel, channel, 3, padding=1),
             ]
 
         elif stride == 2:
             blocks = [
-                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
+                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                 nn.ReLU(inplace=True),
                 nn.Conv2d(channel // 2, channel, 3, padding=1),
             ]
@@ -147,17 +158,17 @@
         if stride == 4:
             blocks.extend(
                 [
-                    nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1),
+                    UpsampleConv(channel, channel // 2, 5, padding=2),
                     nn.ReLU(inplace=True),
-                    nn.ConvTranspose2d(
-                        channel // 2, out_channel, 4, stride=2, padding=1
+                    UpsampleConv(
+                        channel // 2, out_channel, 5, padding=2
                     ),
                 ]
             )
 
         elif stride == 2:
             blocks.append(
-                nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)
+                UpsampleConv(channel, out_channel, 5, padding=2)
             )
 
         self.blocks = nn.Sequential(*blocks)
@@ -188,8 +199,8 @@
         )
         self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
         self.quantize_b = Quantize(codebook_dim, codebook_size)
-        self.upsample_t = nn.ConvTranspose2d(
-            codebook_dim, codebook_dim, 4, stride=2, padding=1
+        self.upsample_t = UpsampleConv(
+            codebook_dim, codebook_dim, 5, padding=2
         )
         self.dec = Decoder(
             codebook_dim + codebook_dim,
@@ -244,6 +255,11 @@
 
 
 @register_model
-def register_vqvae(opt_net, opt):
+def register_vqvae_normalized(opt_net, opt):
     kw = opt_get(opt_net, ['kwargs'], {})
     return VQVAE(**kw)
+
+
+if __name__ == '__main__':
+    v = VQVAE()
+    print(v(torch.randn(1, 3, 128, 128))[0].shape)
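
For reference: in the encoder hunks, the 4x4/stride-2/padding-1 downsampling convolutions become 5x5/stride-2/padding-2. Both settings halve even spatial dimensions, since H_out = floor((H + 2*padding - kernel)/stride) + 1, so the swap only widens the receptive field. A minimal standalone sketch of that check (illustrative names, not part of the patch):

import torch
from torch import nn

x = torch.randn(1, 3, 128, 128)
old_conv = nn.Conv2d(3, 64, 4, stride=2, padding=1)  # original downsampler
new_conv = nn.Conv2d(3, 64, 5, stride=2, padding=2)  # patched: larger kernel, same output size
# floor((128 + 2*1 - 4)/2) + 1 == floor((128 + 2*2 - 5)/2) + 1 == 64
assert old_conv(x).shape == new_conv(x).shape == (1, 64, 64, 64)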
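
Similarly, each ConvTranspose2d in the decoder is replaced by UpsampleConv, i.e. a fixed 2x interpolation followed by an ordinary 5x5 conv, a combination often used to avoid the checkerboard artifacts transposed convolutions can produce. A standalone sketch confirming the replacement preserves output shapes (it re-declares UpsampleConv so it runs without the repo; assumes even input sizes):

import torch
from torch import nn
from torch.nn import functional as F


class UpsampleConv(nn.Module):  # re-declared from the patch
    def __init__(self, in_filters, out_filters, kernel_size, padding):
        super().__init__()
        self.conv = nn.Conv2d(in_filters, out_filters, kernel_size, padding=padding)

    def forward(self, x):
        up = F.interpolate(x, scale_factor=2)  # nearest-neighbor 2x upsample
        return self.conv(up)


x = torch.randn(1, 128, 16, 16)
old_up = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)  # original upsampler
new_up = UpsampleConv(128, 64, 5, padding=2)                  # patched replacement
assert old_up(x).shape == new_up(x).shape == (1, 64, 32, 32)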