Add VQVAE with no ConvTranspose2d
parent 587a4f4050
commit d919ae7148
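In short: the encoder's strided 4x4 convolutions become 5x5 with padding 2, every nn.ConvTranspose2d in the decoder is replaced by a new UpsampleConv module (a nearest-neighbor 2x upsample followed by a stride-1 convolution), and the model is registered as register_vqvae_normalized. A shape-equivalence sketch follows the diff.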
@@ -14,11 +14,12 @@
 # ============================================================================
 
 
 # Borrowed from https://github.com/rosinality/vq-vae-2-pytorch
 # Which was itself borrowed from https://github.com/deepmind/sonnet
+# This is an alternative implementation of VQVAE that uses convolutions with kernels of size 5 and
+# a "standard" upsampler rather than ConvTranspose.
 
 
 import torch
+from kornia import filter2D
 from torch import nn
 from torch.nn import functional as F
 
@@ -28,6 +29,17 @@ from trainer.networks import register_model
 from utils.util import checkpoint, opt_get
 
 
+# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2d from the original paper.
+class UpsampleConv(nn.Module):
+    def __init__(self, in_filters, out_filters, kernel_size, padding):
+        super().__init__()
+        self.conv = nn.Conv2d(in_filters, out_filters, kernel_size, padding=padding)
+
+    def forward(self, x):
+        up = torch.nn.functional.interpolate(x, scale_factor=2)
+        return self.conv(up)
+
+
 class Quantize(nn.Module):
     def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
         super().__init__()
@@ -106,16 +118,16 @@ class Encoder(nn.Module):
 
         if stride == 4:
             blocks = [
-                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
+                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                 nn.ReLU(inplace=True),
-                nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
+                nn.Conv2d(channel // 2, channel, 5, stride=2, padding=2),
                 nn.ReLU(inplace=True),
                 nn.Conv2d(channel, channel, 3, padding=1),
             ]
 
         elif stride == 2:
             blocks = [
-                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
+                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                 nn.ReLU(inplace=True),
                 nn.Conv2d(channel // 2, channel, 3, padding=1),
             ]
@@ -147,17 +159,17 @@ class Decoder(nn.Module):
         if stride == 4:
             blocks.extend(
                 [
-                    nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1),
+                    UpsampleConv(channel, channel // 2, 5, padding=2),
                     nn.ReLU(inplace=True),
-                    nn.ConvTranspose2d(
-                        channel // 2, out_channel, 4, stride=2, padding=1
+                    UpsampleConv(
+                        channel // 2, out_channel, 5, padding=2
                     ),
                 ]
             )
 
         elif stride == 2:
             blocks.append(
-                nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)
+                UpsampleConv(channel, out_channel, 5, padding=2)
             )
 
         self.blocks = nn.Sequential(*blocks)
@@ -188,8 +200,8 @@ class VQVAE(nn.Module):
         )
         self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
         self.quantize_b = Quantize(codebook_dim, codebook_size)
-        self.upsample_t = nn.ConvTranspose2d(
-            codebook_dim, codebook_dim, 4, stride=2, padding=1
+        self.upsample_t = UpsampleConv(
+            codebook_dim, codebook_dim, 5, padding=2
         )
         self.dec = Decoder(
             codebook_dim + codebook_dim,
@@ -244,6 +256,11 @@ class VQVAE(nn.Module):
 
 
 @register_model
-def register_vqvae(opt_net, opt):
+def register_vqvae_normalized(opt_net, opt):
     kw = opt_get(opt_net, ['kwargs'], {})
     return VQVAE(**kw)
+
+
+if __name__ == '__main__':
+    v = VQVAE()
+    print(v(torch.randn(1,3,128,128))[0].shape)
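The swap is shape-preserving: a ConvTranspose2d with kernel 4, stride 2, padding 1 doubles the spatial size, and so does a 2x nearest-neighbor upsample followed by a kernel-5, padding-2, stride-1 convolution. A minimal sketch (not part of the commit, with arbitrary channel sizes) checking that equivalence:

import torch
from torch import nn

# Mirror of the UpsampleConv added in the diff above:
# 2x nearest-neighbor upsample, then a stride-1 convolution.
class UpsampleConv(nn.Module):
    def __init__(self, in_filters, out_filters, kernel_size, padding):
        super().__init__()
        self.conv = nn.Conv2d(in_filters, out_filters, kernel_size, padding=padding)

    def forward(self, x):
        up = torch.nn.functional.interpolate(x, scale_factor=2)  # default mode is 'nearest'
        return self.conv(up)

x = torch.randn(1, 64, 32, 32)
old = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)(x)  # (H-1)*2 - 2*1 + 4 = 2H
new = UpsampleConv(64, 32, 5, padding=2)(x)                  # 2H preserved by k=5, p=2, stride 1
assert old.shape == new.shape == (1, 32, 64, 64)  # both double the spatial dimensions

The upsample-then-conv form trades the learned stride-2 kernel for a fixed interpolation, a common way to avoid the checkerboard artifacts that ConvTranspose2d can introduce.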