diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py
index eedaa9ce..37aa8f47 100644
--- a/codes/models/gpt_voice/lucidrains_dvae.py
+++ b/codes/models/gpt_voice/lucidrains_dvae.py
@@ -78,7 +78,6 @@ class DiscreteVAE(nn.Module):
         discretization_loss_averaging_steps = 100,
     ):
         super().__init__()
-        assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
         has_resblocks = num_resnet_blocks > 0

         self.num_tokens = num_tokens
@@ -106,35 +105,43 @@ class DiscreteVAE(nn.Module):
         else:
             assert NotImplementedError()

-        enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)]
-        dec_chans = list(reversed(enc_chans))
-
-        enc_chans = [channels, *enc_chans]
-
-        dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
-        dec_chans = [dec_init_chan, *dec_chans]
-
-        enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))

         enc_layers = []
         dec_layers = []

-        pad = (kernel_size - 1) // 2
-        for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
-            enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), act()))
-            if encoder_norm:
-                enc_layers.append(nn.GroupNorm(8, enc_out))
-            dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), act()))
+        if num_layers > 0:
+            enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)]
+            dec_chans = list(reversed(enc_chans))
+
+            enc_chans = [channels, *enc_chans]
+
+            dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
+            dec_chans = [dec_init_chan, *dec_chans]
+
+            enc_chans_io, dec_chans_io = map(lambda t: list(zip(t[:-1], t[1:])), (enc_chans, dec_chans))
+
+            pad = (kernel_size - 1) // 2
+            for (enc_in, enc_out), (dec_in, dec_out) in zip(enc_chans_io, dec_chans_io):
+                enc_layers.append(nn.Sequential(conv(enc_in, enc_out, kernel_size, stride = stride, padding = pad), act()))
+                if encoder_norm:
+                    enc_layers.append(nn.GroupNorm(8, enc_out))
+                dec_layers.append(nn.Sequential(conv_transpose(dec_in, dec_out, kernel_size, stride = stride, padding = pad), act()))
+            dec_out_chans = dec_chans[-1]
+            innermost_dim = dec_chans[0]
+        else:
+            enc_layers.append(nn.Sequential(conv(channels, hidden_dim, 1), act()))
+            dec_out_chans = hidden_dim
+            innermost_dim = hidden_dim

         for _ in range(num_resnet_blocks):
-            dec_layers.insert(0, ResBlock(dec_chans[1], conv, act))
-            enc_layers.append(ResBlock(enc_chans[-1], conv, act))
+            dec_layers.insert(0, ResBlock(innermost_dim, conv, act))
+            enc_layers.append(ResBlock(innermost_dim, conv, act))

         if num_resnet_blocks > 0:
-            dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1))
+            dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))

-        enc_layers.append(conv(enc_chans[-1], codebook_dim, 1))
-        dec_layers.append(conv(dec_chans[-1], channels, 1))
+        enc_layers.append(conv(innermost_dim, codebook_dim, 1))
+        dec_layers.append(conv(dec_out_chans, channels, 1))

         self.encoder = nn.Sequential(*enc_layers)
         self.decoder = nn.Sequential(*dec_layers)
@@ -258,7 +265,7 @@ if __name__ == '__main__':
     #o=v(torch.randn(1,3,256,256))
     #print(o.shape)
     v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096,
-                    hidden_dim=256, stride=2, num_resnet_blocks=2, kernel_size=3, num_layers=2, use_transposed_convs=False)
+                    hidden_dim=256, stride=2, num_resnet_blocks=2, kernel_size=3, num_layers=0, use_transposed_convs=False)
     #v.eval()
     o=v(torch.randn(1,80,256))
     print(o[-1].shape)
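
Note: the sketch below is illustrative only; channel_plan is a hypothetical helper, not a function in the repo. It isolates the channel bookkeeping this diff rewrites. For num_layers > 0 it reproduces the indices the old code used (dec_chans[1], enc_chans[-1]), showing the innermost_dim/dec_out_chans rename preserves behavior on that path when resnet blocks are present; num_layers == 0 is the newly enabled degenerate path, where a single 1x1 conv stem replaces the strided stack, so the encoder never downsamples and the latent keeps the input's temporal resolution.

# Illustrative sketch, not repo code: the channel bookkeeping added above,
# factored into a pure function so the refactor can be checked in isolation.
def channel_plan(channels, hidden_dim, codebook_dim, num_layers, has_resblocks):
    if num_layers > 0:
        enc_chans = [hidden_dim * 2 ** i for i in range(num_layers)]
        dec_chans = list(reversed(enc_chans))
        enc_chans = [channels, *enc_chans]
        dec_init_chan = codebook_dim if not has_resblocks else dec_chans[0]
        dec_chans = [dec_init_chan, *dec_chans]
        innermost_dim = dec_chans[0]   # the diff's new name at the codebook end
        dec_out_chans = dec_chans[-1]  # the diff's new name for dec_chans[-1]
        if has_resblocks:
            # The old code indexed dec_chans[1] and enc_chans[-1] at the
            # points this diff rewrites; both equal innermost_dim here,
            # so the rename is behavior-preserving on this path.
            assert dec_chans[1] == innermost_dim == enc_chans[-1]
    else:
        # New path enabled by this diff: one 1x1 conv stem, no downsampling,
        # so both ends of the conv stack share hidden_dim.
        innermost_dim = dec_out_chans = hidden_dim
    return innermost_dim, dec_out_chans

# Matches the updated __main__ config (hidden_dim=256, num_layers=0):
print(channel_plan(80, 256, 4096, num_layers=0, has_resblocks=True))  # (256, 256)
# The previous __main__ config (num_layers=2) for comparison:
print(channel_plan(80, 256, 4096, num_layers=2, has_resblocks=True))  # (512, 256)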