dvae mods

Trying to squeeze as much performance out of this net as possible
This commit is contained in:
James Betker 2021-08-25 08:55:13 -06:00
parent d05cc1f46c
commit 67bf7f5219

View File

@ -1,3 +1,4 @@
import functools
import math
from math import sqrt
@ -41,6 +42,19 @@ class ResBlock(nn.Module):
return self.net(x) + x
class UpsampledConv(nn.Module):
def __init__(self, conv, *args, **kwargs):
super().__init__()
assert 'stride' in kwargs.keys()
self.stride = kwargs['stride']
del kwargs['stride']
self.conv = conv(*args, **kwargs)
def forward(self, x):
up = nn.functional.interpolate(x, scale_factor=self.stride, mode='nearest')
return self.conv(up)
class DiscreteVAE(nn.Module):
def __init__(
self,
@ -51,6 +65,10 @@ class DiscreteVAE(nn.Module):
num_resnet_blocks = 0,
hidden_dim = 64,
channels = 3,
stride = 2,
kernel_size = 4,
use_transposed_convs = True,
encoder_norm = False,
smooth_l1_loss = False,
straight_through = False,
normalization = None, # ((0.5,) * 3, (0.5,) * 3),
@ -73,6 +91,8 @@ class DiscreteVAE(nn.Module):
else:
conv = nn.Conv1d
conv_transpose = nn.ConvTranspose1d
if not use_transposed_convs:
conv_transpose = functools.partial(UpsampledConv, conv)
enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)]
dec_chans = list(reversed(enc_chans))
@ -87,9 +107,12 @@ class DiscreteVAE(nn.Module):
enc_layers = []
dec_layers = []
pad = (kernel_size - 1) // 2
for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
enc_layers.append(nn.Sequential(conv(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU()))
dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU()))
enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), nn.ReLU()))
if encoder_norm:
enc_layers.append(nn.GroupNorm(8, enc_out))
dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), nn.ReLU()))
for _ in range(num_resnet_blocks):
dec_layers.insert(0, ResBlock(dec_chans[1], conv))
@ -171,7 +194,9 @@ class DiscreteVAE(nn.Module):
sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
if self.training:
out = self.decoder(sampled)
out = sampled
for d in self.decoder:
out = d(out)
else:
# This is non-differentiable, but gives a better idea of how the network is actually performing.
out = self.decode(codes)
@ -202,7 +227,7 @@ if __name__ == '__main__':
#v = DiscreteVAE()
#o=v(torch.randn(1,3,256,256))
#print(o.shape)
v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256)
v.eval()
o=v(torch.randn(1,1,256))
v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096, hidden_dim=256, stride=4, num_resnet_blocks=1, kernel_size=5, num_layers=5, use_transposed_convs=False)
#v.eval()
o=v(torch.randn(1,1,4096))
print(o[-1].shape)