dvae mods
Trying to squeeze as much performance out of this net as possible
This commit is contained in:
parent d05cc1f46c
commit 67bf7f5219
@@ -1,3 +1,4 @@
+import functools
 import math
 from math import sqrt
 
@@ -41,6 +42,19 @@ class ResBlock(nn.Module):
         return self.net(x) + x


+class UpsampledConv(nn.Module):
+    def __init__(self, conv, *args, **kwargs):
+        super().__init__()
+        assert 'stride' in kwargs.keys()
+        self.stride = kwargs['stride']
+        del kwargs['stride']
+        self.conv = conv(*args, **kwargs)
+
+    def forward(self, x):
+        up = nn.functional.interpolate(x, scale_factor=self.stride, mode='nearest')
+        return self.conv(up)
+
+
 class DiscreteVAE(nn.Module):
     def __init__(
         self,
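The new UpsampledConv swaps a transposed conv for a nearest-neighbor upsample followed by a plain conv, a common alternative when transposed convs produce checkerboard artifacts (the commit message does not state the motivation, so that reading is an assumption). A minimal usage sketch, assuming the UpsampledConv class from the hunk above is in scope and using hypothetical channel sizes:

import torch
import torch.nn as nn

# stride is consumed by UpsampledConv for the upsample; the wrapped Conv1d
# keeps its default stride of 1.
up = UpsampledConv(nn.Conv1d, 64, 32, kernel_size=5, stride=4, padding=2)
x = torch.randn(1, 64, 16)
y = up(x)                    # nearest-neighbor upsample x4, then Conv1d
print(y.shape)               # torch.Size([1, 32, 64])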
@@ -51,6 +65,10 @@ class DiscreteVAE(nn.Module):
         num_resnet_blocks = 0,
         hidden_dim = 64,
         channels = 3,
+        stride = 2,
+        kernel_size = 4,
+        use_transposed_convs = True,
+        encoder_norm = False,
         smooth_l1_loss = False,
         straight_through = False,
         normalization = None, # ((0.5,) * 3, (0.5,) * 3),
@@ -73,6 +91,8 @@ class DiscreteVAE(nn.Module):
         else:
             conv = nn.Conv1d
             conv_transpose = nn.ConvTranspose1d
+        if not use_transposed_convs:
+            conv_transpose = functools.partial(UpsampledConv, conv)

         enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)]
         dec_chans = list(reversed(enc_chans))
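Because UpsampledConv takes the conv class as its first positional argument, functools.partial turns it into a factory with the same call signature as nn.ConvTranspose1d, so the layer-building code further down needs no changes. A sketch of the equivalence, with hypothetical channel sizes:

import functools
import torch.nn as nn

# Both callables accept (in_ch, out_ch, kernel_size, stride=..., padding=...).
conv_transpose = functools.partial(UpsampledConv, nn.Conv1d)
layer = conv_transpose(128, 64, 5, stride=2, padding=2)
# equivalent to: UpsampledConv(nn.Conv1d, 128, 64, 5, stride=2, padding=2)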
@@ -87,9 +107,12 @@ class DiscreteVAE(nn.Module):
         enc_layers = []
         dec_layers = []

+        pad = (kernel_size - 1) // 2
         for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
-            enc_layers.append(nn.Sequential(conv(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU()))
-            dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU()))
+            enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), nn.ReLU()))
+            if encoder_norm:
+                enc_layers.append(nn.GroupNorm(8, enc_out))
+            dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), nn.ReLU()))

         for _ in range(num_resnet_blocks):
             dec_layers.insert(0, ResBlock(dec_chans[1], conv))
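The new pad = (kernel_size - 1) // 2 generalizes the previously hard-coded padding=1: both the old kernel-4/stride-2 case and the new odd-kernel configurations keep each layer's output length at exactly L_in // stride. A quick check, assuming standard Conv1d length arithmetic:

# L_out = (L_in + 2*pad - kernel) // stride + 1 for Conv1d.
# Old hard-coded case (4, 2) and the new test config (5, 4) both divide exactly.
for kernel, stride, L_in in [(4, 2, 256), (5, 4, 4096)]:
    pad = (kernel - 1) // 2
    assert (L_in + 2 * pad - kernel) // stride + 1 == L_in // stride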
@@ -171,7 +194,9 @@ class DiscreteVAE(nn.Module):
         sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))

         if self.training:
-            out = self.decoder(sampled)
+            out = sampled
+            for d in self.decoder:
+                out = d(out)
         else:
             # This is non-differentiable, but gives a better idea of how the network is actually performing.
             out = self.decode(codes)
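Unrolling self.decoder(sampled) into a per-module loop is functionally identical for a plain nn.Sequential; the commit does not say why it was done (per-layer hooks or gradient checkpointing are plausible guesses). A self-contained equivalence check with a toy decoder:

import torch
import torch.nn as nn

decoder = nn.Sequential(nn.Conv1d(8, 8, 3, padding=1), nn.ReLU())
sampled = torch.randn(1, 8, 16)

out_a = decoder(sampled)     # single call through the Sequential
out_b = sampled
for d in decoder:            # explicit per-module loop, as in the diff
    out_b = d(out_b)
assert torch.equal(out_a, out_b)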
@@ -202,7 +227,7 @@ if __name__ == '__main__':
     #v = DiscreteVAE()
     #o=v(torch.randn(1,3,256,256))
     #print(o.shape)
-    v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256)
-    v.eval()
-    o=v(torch.randn(1,1,256))
+    v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096, hidden_dim=256, stride=4, num_resnet_blocks=1, kernel_size=5, num_layers=5, use_transposed_convs=False)
+    #v.eval()
+    o=v(torch.randn(1,1,4096))
     print(o[-1].shape)
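Back-of-envelope compression for the new test configuration, assuming each of the num_layers=5 encoder stages downsamples by stride=4 as the layer construction above implies:

stride, num_layers, L_in = 4, 5, 4096
print(L_in // stride ** num_layers)   # 4096 // 4**5 = 4 latent positions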