From a826d5f658903e4329aa3a2f1bd0a8ceaafadd61 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sun, 15 Aug 2021 20:54:10 -0600 Subject: [PATCH] Mods to dvae - Add resblock to each layer - Increase filter size for each layer - Use SiLU --- codes/models/gpt_voice/lucidrains_dvae.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py index c59483df..6933a301 100644 --- a/codes/models/gpt_voice/lucidrains_dvae.py +++ b/codes/models/gpt_voice/lucidrains_dvae.py @@ -31,9 +31,9 @@ class ResBlock(nn.Module): super().__init__() self.net = nn.Sequential( conv(chan, chan, 3, padding = 1), - nn.ReLU(), + nn.SiLU(), conv(chan, chan, 3, padding = 1), - nn.ReLU(), + nn.SiLU(), conv(chan, chan, 1) ) @@ -74,7 +74,7 @@ class DiscreteVAE(nn.Module): conv = nn.Conv1d conv_transpose = nn.ConvTranspose1d - enc_chans = [hidden_dim] * num_layers + enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)] dec_chans = list(reversed(enc_chans)) enc_chans = [channels, *enc_chans] @@ -88,12 +88,14 @@ class DiscreteVAE(nn.Module): dec_layers = [] for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io): - enc_layers.append(nn.Sequential(conv(enc_in, enc_out, 4, stride = 2, padding = 1), nn.ReLU())) - dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, 4, stride = 2, padding = 1), nn.ReLU())) + for _ in range(num_resnet_blocks): + dec_layers.append(ResBlock(dec_in, conv)) - for _ in range(num_resnet_blocks): - dec_layers.insert(0, ResBlock(dec_chans[1], conv)) - enc_layers.append(ResBlock(enc_chans[-1], conv)) + enc_layers.append(nn.Sequential(conv(enc_in, enc_out, 4, stride = 2, padding = 1), nn.SiLU())) + dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, 4, stride = 2, padding = 1), nn.SiLU())) + + for _ in range(num_resnet_blocks): + enc_layers.append(ResBlock(enc_out, conv)) if num_resnet_blocks > 0: dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1)) @@ -202,7 +204,7 @@ if __name__ == '__main__': #v = DiscreteVAE() #o=v(torch.randn(1,3,256,256)) #print(o.shape) - v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256) + v = DiscreteVAE(channels=1, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=2048, hidden_dim=256, num_resnet_blocks=2) v.eval() o=v(torch.randn(1,1,256)) print(o[-1].shape)