diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py index cba5c728..9b171619 100644 --- a/codes/models/gpt_voice/lucidrains_dvae.py +++ b/codes/models/gpt_voice/lucidrains_dvae.py @@ -1,3 +1,4 @@ +import functools import math from math import sqrt @@ -41,6 +42,19 @@ class ResBlock(nn.Module): return self.net(x) + x +class UpsampledConv(nn.Module): + def __init__(self, conv, *args, **kwargs): + super().__init__() + assert 'stride' in kwargs.keys() + self.stride = kwargs['stride'] + del kwargs['stride'] + self.conv = conv(*args, **kwargs) + + def forward(self, x): + up = nn.functional.interpolate(x, scale_factor=self.stride, mode='nearest') + return self.conv(up) + + class DiscreteVAE(nn.Module): def __init__( self, @@ -51,6 +65,10 @@ class DiscreteVAE(nn.Module): num_resnet_blocks = 0, hidden_dim = 64, channels = 3, + stride = 2, + kernel_size = 4, + use_transposed_convs = True, + encoder_norm = False, smooth_l1_loss = False, straight_through = False, normalization = None, # ((0.5,) * 3, (0.5,) * 3), @@ -73,6 +91,8 @@ class DiscreteVAE(nn.Module): else: conv = nn.Conv1d conv_transpose = nn.ConvTranspose1d + if not use_transposed_convs: + conv_transpose = functools.partial(UpsampledConv, conv) enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)] dec_chans = list(reversed(enc_chans)) @@ -87,9 +107,12 @@ class DiscreteVAE(nn.Module): enc_layers = [] dec_layers = [] + pad = (kernel_size - 1) // 2 for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io): - enc_layers.append(nn.Sequential(conv(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU())) - dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU())) + enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), nn.ReLU())) + if encoder_norm: + enc_layers.append(nn.GroupNorm(8, enc_out)) + dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), nn.ReLU())) for _ in range(num_resnet_blocks): dec_layers.insert(0, ResBlock(dec_chans[1], conv)) @@ -171,7 +194,9 @@ class DiscreteVAE(nn.Module): sampled = sampled.permute((0,3,1,2) if len(img.shape) == 4 else (0,2,1)) if self.training: - out = self.decoder(sampled) + out = sampled + for d in self.decoder: + out = d(out) else: # This is non-differentiable, but gives a better idea of how the network is actually performing. out = self.decode(codes) @@ -202,7 +227,7 @@ if __name__ == '__main__': #v = DiscreteVAE() #o=v(torch.randn(1,3,256,256)) #print(o.shape) - v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256) - v.eval() - o=v(torch.randn(1,1,256)) + v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096, hidden_dim=256, stride=4, num_resnet_blocks=1, kernel_size=5, num_layers=5, use_transposed_convs=False) + #v.eval() + o=v(torch.randn(1,1,4096)) print(o[-1].shape)