forked from mrq/DL-Art-School
dvae mods
Trying to squeeze as much performance out of this net as possible
This commit is contained in:
parent
d05cc1f46c
commit
67bf7f5219
|
@ -1,3 +1,4 @@
|
||||||
|
import functools
|
||||||
import math
|
import math
|
||||||
from math import sqrt
|
from math import sqrt
|
||||||
|
|
||||||
|
@ -41,6 +42,19 @@ class ResBlock(nn.Module):
|
||||||
return self.net(x) + x
|
return self.net(x) + x
|
||||||
|
|
||||||
|
|
||||||
|
class UpsampledConv(nn.Module):
    """Upsample-then-convolve stand-in for a strided transposed convolution.

    The wrapper accepts the same call shape as a transposed-conv constructor
    (positional/keyword conv arguments plus a ``stride`` keyword). Instead of
    handing ``stride`` to the convolution, it uses it as the scale factor of
    a nearest-neighbour interpolation applied to the input before the
    convolution runs.

    Args:
        conv: Callable that builds the wrapped convolution module. It is
            called with ``*args`` and ``**kwargs`` minus ``stride``.
        *args: Positional arguments forwarded to ``conv``.
        **kwargs: Keyword arguments forwarded to ``conv``. Must contain
            ``stride``, which is consumed here and not forwarded.

    Raises:
        AssertionError: If ``stride`` is not supplied as a keyword argument.
    """

    def __init__(self, conv, *args, **kwargs):
        super().__init__()
        # Raise explicitly rather than via an `assert` statement so the check
        # still runs under `python -O`, while keeping the exception type that
        # existing callers would have seen.
        if 'stride' not in kwargs:
            raise AssertionError(
                "UpsampledConv requires 'stride' to be passed as a keyword argument"
            )
        # `stride` drives the interpolation only; removing it from kwargs
        # means the wrapped conv is built with its own default stride.
        self.stride = kwargs.pop('stride')
        self.conv = conv(*args, **kwargs)

    def forward(self, x):
        """Upsample ``x`` by ``self.stride`` (nearest), then apply the conv."""
        up = nn.functional.interpolate(x, scale_factor=self.stride, mode='nearest')
        return self.conv(up)
|
||||||
|
|
||||||
|
|
||||||
class DiscreteVAE(nn.Module):
|
class DiscreteVAE(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -51,6 +65,10 @@ class DiscreteVAE(nn.Module):
|
||||||
num_resnet_blocks = 0,
|
num_resnet_blocks = 0,
|
||||||
hidden_dim = 64,
|
hidden_dim = 64,
|
||||||
channels = 3,
|
channels = 3,
|
||||||
|
stride = 2,
|
||||||
|
kernel_size = 4,
|
||||||
|
use_transposed_convs = True,
|
||||||
|
encoder_norm = False,
|
||||||
smooth_l1_loss = False,
|
smooth_l1_loss = False,
|
||||||
straight_through = False,
|
straight_through = False,
|
||||||
normalization = None, # ((0.5,) * 3, (0.5,) * 3),
|
normalization = None, # ((0.5,) * 3, (0.5,) * 3),
|
||||||
|
@ -73,6 +91,8 @@ class DiscreteVAE(nn.Module):
|
||||||
else:
|
else:
|
||||||
conv = nn.Conv1d
|
conv = nn.Conv1d
|
||||||
conv_transpose = nn.ConvTranspose1d
|
conv_transpose = nn.ConvTranspose1d
|
||||||
|
if not use_transposed_convs:
|
||||||
|
conv_transpose = functools.partial(UpsampledConv, conv)
|
||||||
|
|
||||||
enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)]
|
enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)]
|
||||||
dec_chans = list(reversed(enc_chans))
|
dec_chans = list(reversed(enc_chans))
|
||||||
|
@ -87,9 +107,12 @@ class DiscreteVAE(nn.Module):
|
||||||
enc_layers = []
|
enc_layers = []
|
||||||
dec_layers = []
|
dec_layers = []
|
||||||
|
|
||||||
|
pad = (kernel_size - 1) // 2
|
||||||
for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
|
for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
|
||||||
enc_layers.append(nn.Sequential(conv(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU()))
|
enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), nn.ReLU()))
|
||||||
dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU()))
|
if encoder_norm:
|
||||||
|
enc_layers.append(nn.GroupNorm(8, enc_out))
|
||||||
|
dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), nn.ReLU()))
|
||||||
|
|
||||||
for _ in range(num_resnet_blocks):
|
for _ in range(num_resnet_blocks):
|
||||||
dec_layers.insert(0, ResBlock(dec_chans[1], conv))
|
dec_layers.insert(0, ResBlock(dec_chans[1], conv))
|
||||||
|
@ -171,7 +194,9 @@ class DiscreteVAE(nn.Module):
|
||||||
sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
|
sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1))
|
||||||
|
|
||||||
if self.training:
|
if self.training:
|
||||||
out = self.decoder(sampled)
|
out = sampled
|
||||||
|
for d in self.decoder:
|
||||||
|
out = d(out)
|
||||||
else:
|
else:
|
||||||
# This is non-differentiable, but gives a better idea of how the network is actually performing.
|
# This is non-differentiable, but gives a better idea of how the network is actually performing.
|
||||||
out = self.decode(codes)
|
out = self.decode(codes)
|
||||||
|
@ -202,7 +227,7 @@ if __name__ == '__main__':
|
||||||
#v = DiscreteVAE()
|
#v = DiscreteVAE()
|
||||||
#o=v(torch.randn(1,3,256,256))
|
#o=v(torch.randn(1,3,256,256))
|
||||||
#print(o.shape)
|
#print(o.shape)
|
||||||
v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256)
|
v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096, hidden_dim=256, stride=4, num_resnet_blocks=1, kernel_size=5, num_layers=5, use_transposed_convs=False)
|
||||||
v.eval()
|
#v.eval()
|
||||||
o=v(torch.randn(1,1,256))
|
o=v(torch.randn(1,1,4096))
|
||||||
print(o[-1].shape)
|
print(o[-1].shape)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user