Add VQVAE with no Conv2dTranspose

James Betker 2021-01-18 08:49:59 -07:00
parent 587a4f4050
commit d919ae7148


@@ -14,11 +14,12 @@
 # ============================================================================
 # Borrowed from https://github.com/rosinality/vq-vae-2-pytorch
 # which was itself borrowed from https://github.com/deepmind/sonnet
+# This is an alternative implementation of VQVAE that uses convolutions with kernels of size 5 and
+# a "standard" upsampler rather than ConvTranspose.
 import torch
 from kornia import filter2D
 from torch import nn
 from torch.nn import functional as F
@@ -28,6 +29,17 @@ from trainer.networks import register_model
 from utils.util import checkpoint, opt_get

+# Upsamples (nearest-neighbor interpolation) then convolves, similar in spirit to StyleGAN's
+# upsampling. Replaces ConvTranspose2d from the original paper.
+class UpsampleConv(nn.Module):
+    def __init__(self, in_filters, out_filters, kernel_size, padding):
+        super().__init__()
+        self.conv = nn.Conv2d(in_filters, out_filters, kernel_size, padding=padding)
+
+    def forward(self, x):
+        up = torch.nn.functional.interpolate(x, scale_factor=2)
+        return self.conv(up)
+
 class Quantize(nn.Module):
     def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
         super().__init__()
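A minimal sketch (assumed shapes, not part of the commit) showing that UpsampleConv doubles spatial resolution exactly like the stride-2 ConvTranspose2d it replaces:

    import torch
    x = torch.randn(1, 128, 16, 16)           # hypothetical feature map
    up = UpsampleConv(128, 64, 5, padding=2)  # nearest-neighbor 2x, then 5x5 conv
    print(up(x).shape)                        # torch.Size([1, 64, 32, 32])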
@@ -106,16 +118,16 @@ class Encoder(nn.Module):
         if stride == 4:
             blocks = [
-                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
+                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                 nn.ReLU(inplace=True),
-                nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
+                nn.Conv2d(channel // 2, channel, 5, stride=2, padding=2),
                 nn.ReLU(inplace=True),
                 nn.Conv2d(channel, channel, 3, padding=1),
             ]

         elif stride == 2:
             blocks = [
-                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
+                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                 nn.ReLU(inplace=True),
                 nn.Conv2d(channel // 2, channel, 3, padding=1),
             ]
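The kernel/padding swap leaves the downsampling arithmetic unchanged: for an even input size H, a 5x5 conv with stride 2 and padding 2 yields floor((H + 4 - 5)/2) + 1 = H/2, exactly what the old 4x4/padding-1 conv produced. A quick check with illustrative shapes (not from the commit):

    import torch
    from torch import nn
    x = torch.randn(1, 3, 128, 128)
    old = nn.Conv2d(3, 64, 4, stride=2, padding=1)  # original downsampler
    new = nn.Conv2d(3, 64, 5, stride=2, padding=2)  # this commit's replacement
    print(old(x).shape, new(x).shape)               # both torch.Size([1, 64, 64, 64])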
@@ -147,17 +159,17 @@ class Decoder(nn.Module):
         if stride == 4:
             blocks.extend(
                 [
-                    nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1),
+                    UpsampleConv(channel, channel // 2, 5, padding=2),
                     nn.ReLU(inplace=True),
-                    nn.ConvTranspose2d(
-                        channel // 2, out_channel, 4, stride=2, padding=1
+                    UpsampleConv(
+                        channel // 2, out_channel, 5, padding=2
                     ),
                 ]
             )

         elif stride == 2:
             blocks.append(
-                nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)
+                UpsampleConv(channel, out_channel, 5, padding=2)
             )

         self.blocks = nn.Sequential(*blocks)
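With both decoder stages swapped, the stride-4 path still upsamples by 4x overall; a sketch with assumed channel counts (128 -> 64 -> 3), not taken from the commit:

    import torch
    from torch import nn
    stage = nn.Sequential(
        UpsampleConv(128, 64, 5, padding=2),  # 2x upsample
        nn.ReLU(inplace=True),
        UpsampleConv(64, 3, 5, padding=2),    # another 2x upsample
    )
    print(stage(torch.randn(1, 128, 32, 32)).shape)  # torch.Size([1, 3, 128, 128])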
@@ -188,8 +200,8 @@ class VQVAE(nn.Module):
         )
         self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
         self.quantize_b = Quantize(codebook_dim, codebook_size)
-        self.upsample_t = nn.ConvTranspose2d(
-            codebook_dim, codebook_dim, 4, stride=2, padding=1
+        self.upsample_t = UpsampleConv(
+            codebook_dim, codebook_dim, 5, padding=2
         )
         self.dec = Decoder(
             codebook_dim + codebook_dim,
@@ -244,6 +256,11 @@ class VQVAE(nn.Module):
 @register_model
-def register_vqvae(opt_net, opt):
+def register_vqvae_normalized(opt_net, opt):
     kw = opt_get(opt_net, ['kwargs'], {})
     return VQVAE(**kw)
+
+
+if __name__ == '__main__':
+    v = VQVAE()
+    print(v(torch.randn(1,3,128,128))[0].shape)
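Assuming the rosinality defaults (forward returns the reconstruction first), this smoke test should print torch.Size([1, 3, 128, 128]), i.e. the decoder brings the latents back to the input resolution.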