diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py index 688c63a7..a8cc36c7 100644 --- a/codes/models/gpt_voice/lucidrains_dvae.py +++ b/codes/models/gpt_voice/lucidrains_dvae.py @@ -76,6 +76,8 @@ class DiscreteVAE(nn.Module): normalization = None, # ((0.5,) * 3, (0.5,) * 3), record_codes = False, discretization_loss_averaging_steps = 100, + encoder_choke=False, + choke_dim=128, ): super().__init__() has_resblocks = num_resnet_blocks > 0 @@ -140,6 +142,10 @@ class DiscreteVAE(nn.Module): if num_resnet_blocks > 0: dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1)) + if encoder_choke: + enc_layers.append(conv(innermost_dim, choke_dim, 1)) + innermost_dim = choke_dim + enc_layers.append(conv(innermost_dim, codebook_dim, 1)) dec_layers.append(conv(dec_out_chans, channels, 1)) @@ -265,8 +271,9 @@ if __name__ == '__main__': #o=v(torch.randn(1,3,256,256)) #print(o.shape) v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048, - hidden_dim=512, stride=2, num_resnet_blocks=3, kernel_size=4, num_layers=2, use_transposed_convs=True) - v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth')) + hidden_dim=512, num_resnet_blocks=3, kernel_size=3, num_layers=1, use_transposed_convs=False, + encoder_choke=True, choke_dim=256) + #v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth')) #v.eval() o=v(torch.randn(1,80,256)) print(o[-1].shape) diff --git a/codes/scripts/diffusion/diffusion_noise_surfer.py b/codes/scripts/diffusion/diffusion_noise_surfer.py index 1b4cf9a2..f3d790ab 100644 --- a/codes/scripts/diffusion/diffusion_noise_surfer.py +++ b/codes/scripts/diffusion/diffusion_noise_surfer.py @@ -114,9 +114,10 @@ if __name__ == "__main__": if audio_mode: data = { 'clip': im.to('cuda'), - 'alt_clips': torch.zeros_like(refs[:,0].to('cuda')), + 'alt_clips': refs.to('cuda'), 'num_alt_clips': torch.tensor([refs.shape[1]], dtype=torch.int32, device='cuda'), - 'GT_path': opt['image'] + 'GT_path': opt['image'], + 'resampled_clip': refs[:, 0].to('cuda') } else: data = {