Add VQVAE with no ConvTranspose2d

This commit is contained in:
James Betker 2021-01-18 08:49:59 -07:00
parent 587a4f4050
commit d919ae7148

View File

@ -14,11 +14,12 @@
# ============================================================================ # ============================================================================
# Borrowed from https://github.com/rosinality/vq-vae-2-pytorch # This is an alternative implementation of VQVAE that uses convolutions with kernels of size 5 and
# Which was itself borrowed from https://github.com/deepmind/sonnet # a "standard" upsampler rather than ConvTranspose.
import torch import torch
from kornia import filter2D
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
@ -28,6 +29,17 @@ from trainer.networks import register_model
from utils.util import checkpoint, opt_get from utils.util import checkpoint, opt_get
# Upsamples 2x and then convolves (StyleGAN-inspired; note that no blur is actually applied here). Replaces ConvTranspose2d from the original implementation.
class UpsampleConv(nn.Module):
    """Upsample the input by a factor of two, then apply a 2D convolution.

    Stands in for the ``nn.ConvTranspose2d`` layers this file previously used.

    Args:
        in_filters: number of channels in the incoming feature map.
        out_filters: number of channels produced by the convolution.
        kernel_size: kernel size forwarded to ``nn.Conv2d``.
        padding: padding forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_filters, out_filters, kernel_size, padding):
        super().__init__()
        # The attribute name `conv` is part of the state-dict keys; keep it.
        self.conv = nn.Conv2d(
            in_channels=in_filters,
            out_channels=out_filters,
            kernel_size=kernel_size,
            padding=padding,
        )

    def forward(self, x):
        """Return the convolution of ``x`` after 2x interpolation.

        The interpolation mode is left at ``interpolate``'s default.
        """
        return self.conv(torch.nn.functional.interpolate(x, scale_factor=2))
class Quantize(nn.Module): class Quantize(nn.Module):
def __init__(self, dim, n_embed, decay=0.99, eps=1e-5): def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
super().__init__() super().__init__()
@ -106,16 +118,16 @@ class Encoder(nn.Module):
if stride == 4: if stride == 4:
blocks = [ blocks = [
nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1), nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1), nn.Conv2d(channel // 2, channel, 5, stride=2, padding=2),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv2d(channel, channel, 3, padding=1), nn.Conv2d(channel, channel, 3, padding=1),
] ]
elif stride == 2: elif stride == 2:
blocks = [ blocks = [
nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1), nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv2d(channel // 2, channel, 3, padding=1), nn.Conv2d(channel // 2, channel, 3, padding=1),
] ]
@ -147,17 +159,17 @@ class Decoder(nn.Module):
if stride == 4: if stride == 4:
blocks.extend( blocks.extend(
[ [
nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1), UpsampleConv(channel, channel // 2, 5, padding=2),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.ConvTranspose2d( UpsampleConv(
channel // 2, out_channel, 4, stride=2, padding=1 channel // 2, out_channel, 5, padding=2
), ),
] ]
) )
elif stride == 2: elif stride == 2:
blocks.append( blocks.append(
nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1) UpsampleConv(channel, out_channel, 5, padding=2)
) )
self.blocks = nn.Sequential(*blocks) self.blocks = nn.Sequential(*blocks)
@ -188,8 +200,8 @@ class VQVAE(nn.Module):
) )
self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1) self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
self.quantize_b = Quantize(codebook_dim, codebook_size) self.quantize_b = Quantize(codebook_dim, codebook_size)
self.upsample_t = nn.ConvTranspose2d( self.upsample_t = UpsampleConv(
codebook_dim, codebook_dim, 4, stride=2, padding=1 codebook_dim, codebook_dim, 5, padding=2
) )
self.dec = Decoder( self.dec = Decoder(
codebook_dim + codebook_dim, codebook_dim + codebook_dim,
@ -244,6 +256,11 @@ class VQVAE(nn.Module):
@register_model @register_model
def register_vqvae(opt_net, opt): def register_vqvae_normalized(opt_net, opt):
kw = opt_get(opt_net, ['kwargs'], {}) kw = opt_get(opt_net, ['kwargs'], {})
return VQVAE(**kw) return VQVAE(**kw)
if __name__ == '__main__':
    # Smoke test: build a VQVAE with default arguments, feed it one random
    # 3-channel 128x128 input and print the shape of the first returned item.
    model = VQVAE()
    sample = torch.randn(1, 3, 128, 128)
    print(model(sample)[0].shape)