Add VQVAE with no Conv2dTranspose
parent 587a4f4050
commit d919ae7148
@@ -14,11 +14,12 @@
 # ============================================================================


-# Borrowed from https://github.com/rosinality/vq-vae-2-pytorch
-# Which was itself borrowed from https://github.com/deepmind/sonnet
+# This is an alternative implementation of VQVAE that uses convolutions with kernels of size 5 and
+# a "standard" upsampler rather than ConvTranspose.


 import torch
+from kornia import filter2D
 from torch import nn
 from torch.nn import functional as F

@@ -28,6 +29,17 @@ from trainer.networks import register_model
 from utils.util import checkpoint, opt_get


+# Upsamples and blurs (similar to StyleGAN). Replaces ConvTranspose2D from the original paper.
+class UpsampleConv(nn.Module):
+    def __init__(self, in_filters, out_filters, kernel_size, padding):
+        super().__init__()
+        self.conv = nn.Conv2d(in_filters, out_filters, kernel_size, padding=padding)
+
+    def forward(self, x):
+        up = torch.nn.functional.interpolate(x, scale_factor=2)
+        return self.conv(up)
+
+
 class Quantize(nn.Module):
     def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
         super().__init__()
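Aside: the hunk above adds `from kornia import filter2D`, but the UpsampleConv it introduces never calls it; the upsampling is plain nearest-neighbor interpolate followed by a stride-1 convolution. If the "upsamples and blurs" behavior named in the comment were actually wired in, a sketch might look like the following. This BlurUpsampleConv is hypothetical and not part of this commit; it assumes kornia's filter2D, which takes a (1, kH, kW) kernel.

    import torch
    from kornia import filter2D
    from torch import nn

    # Hypothetical variant, NOT in this commit: blur after nearest-neighbor
    # upsampling, as StyleGAN does, to suppress aliasing/checkerboard artifacts.
    class BlurUpsampleConv(nn.Module):
        def __init__(self, in_filters, out_filters, kernel_size, padding):
            super().__init__()
            f = torch.Tensor([1., 2., 1.])
            # Outer product builds the 2D binomial blur kernel; shape (1, 3, 3).
            self.register_buffer('blur_kernel', (f[:, None] * f[None, :]).unsqueeze(0))
            self.conv = nn.Conv2d(in_filters, out_filters, kernel_size, padding=padding)

        def forward(self, x):
            up = torch.nn.functional.interpolate(x, scale_factor=2)
            blurred = filter2D(up, self.blur_kernel, normalized=True)
            return self.conv(blurred)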
@@ -106,16 +118,16 @@ class Encoder(nn.Module):

         if stride == 4:
             blocks = [
-                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
+                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                 nn.ReLU(inplace=True),
-                nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
+                nn.Conv2d(channel // 2, channel, 5, stride=2, padding=2),
                 nn.ReLU(inplace=True),
                 nn.Conv2d(channel, channel, 3, padding=1),
             ]

         elif stride == 2:
             blocks = [
-                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
+                nn.Conv2d(in_channel, channel // 2, 5, stride=2, padding=2),
                 nn.ReLU(inplace=True),
                 nn.Conv2d(channel // 2, channel, 3, padding=1),
             ]
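A note on the kernel change: with output size floor((H + 2*padding - kernel) / stride) + 1, the original kernel-4/padding-1/stride-2 convolution maps even H to (H - 2)/2 + 1 = H/2, and the new kernel-5/padding-2/stride-2 one maps H to floor((H - 1)/2) + 1 = H/2 as well, so the encoder's downsampling factor is untouched; only the (now odd, centered) receptive field changes. A standalone shape check:

    import torch
    from torch import nn

    x = torch.randn(1, 3, 128, 128)
    old = nn.Conv2d(3, 64, 4, stride=2, padding=1)  # original kernel
    new = nn.Conv2d(3, 64, 5, stride=2, padding=2)  # this commit's kernel
    print(old(x).shape, new(x).shape)  # both: torch.Size([1, 64, 64, 64])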
@@ -147,17 +159,17 @@ class Decoder(nn.Module):
         if stride == 4:
             blocks.extend(
                 [
-                    nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1),
+                    UpsampleConv(channel, channel // 2, 5, padding=2),
                     nn.ReLU(inplace=True),
-                    nn.ConvTranspose2d(
-                        channel // 2, out_channel, 4, stride=2, padding=1
-                    ),
+                    UpsampleConv(
+                        channel // 2, out_channel, 5, padding=2
+                    ),
                 ]
             )

         elif stride == 2:
             blocks.append(
-                nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)
+                UpsampleConv(channel, out_channel, 5, padding=2)
             )

         self.blocks = nn.Sequential(*blocks)
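The replacement is shape-preserving here too: doubling via interpolate and then a kernel-5/padding-2, stride-1 conv yields exactly the 2x output that ConvTranspose2d(..., 4, stride=2, padding=1) produced, so the Decoder keeps its output resolution. A quick sanity check, assuming the UpsampleConv class from this diff is in scope:

    import torch
    from torch import nn

    x = torch.randn(1, 64, 32, 32)
    transposed = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)   # old path
    upsample_conv = UpsampleConv(64, 32, 5, padding=2)                # new path
    print(transposed(x).shape, upsample_conv(x).shape)  # both: torch.Size([1, 32, 64, 64])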
@@ -188,8 +200,8 @@ class VQVAE(nn.Module):
         )
         self.quantize_conv_b = nn.Conv2d(codebook_dim + channel, codebook_dim, 1)
         self.quantize_b = Quantize(codebook_dim, codebook_size)
-        self.upsample_t = nn.ConvTranspose2d(
-            codebook_dim, codebook_dim, 4, stride=2, padding=1
-        )
+        self.upsample_t = UpsampleConv(
+            codebook_dim, codebook_dim, 5, padding=2
+        )
         self.dec = Decoder(
             codebook_dim + codebook_dim,
@@ -244,6 +256,11 @@ class VQVAE(nn.Module):


 @register_model
-def register_vqvae(opt_net, opt):
+def register_vqvae_normalized(opt_net, opt):
     kw = opt_get(opt_net, ['kwargs'], {})
     return VQVAE(**kw)
+
+
+if __name__ == '__main__':
+    v = VQVAE()
+    print(v(torch.randn(1,3,128,128))[0].shape)
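With the default constructor arguments, the smoke test added above should print torch.Size([1, 3, 128, 128]): forward appears to return a (reconstruction, latent-loss) pair, so the [0] index selects the reconstruction, which matches the input resolution.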