Add choke to lucidrains_dvae

This commit is contained in:
James Betker 2021-11-23 18:53:37 -07:00
parent 934395d4b8
commit 82d0e7720e
2 changed files with 12 additions and 4 deletions

View File

@@ -76,6 +76,8 @@ class DiscreteVAE(nn.Module):
normalization = None, # ((0.5,) * 3, (0.5,) * 3), normalization = None, # ((0.5,) * 3, (0.5,) * 3),
record_codes = False, record_codes = False,
discretization_loss_averaging_steps = 100, discretization_loss_averaging_steps = 100,
encoder_choke=False,
choke_dim=128,
): ):
super().__init__() super().__init__()
has_resblocks = num_resnet_blocks > 0 has_resblocks = num_resnet_blocks > 0
@@ -140,6 +142,10 @@ class DiscreteVAE(nn.Module):
if num_resnet_blocks > 0: if num_resnet_blocks > 0:
dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1)) dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))
if encoder_choke:
enc_layers.append(conv(innermost_dim, choke_dim, 1))
innermost_dim = choke_dim
enc_layers.append(conv(innermost_dim, codebook_dim, 1)) enc_layers.append(conv(innermost_dim, codebook_dim, 1))
dec_layers.append(conv(dec_out_chans, channels, 1)) dec_layers.append(conv(dec_out_chans, channels, 1))
@@ -265,8 +271,9 @@ if __name__ == '__main__':
#o=v(torch.randn(1,3,256,256)) #o=v(torch.randn(1,3,256,256))
#print(o.shape) #print(o.shape)
v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048, v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048,
hidden_dim=512, stride=2, num_resnet_blocks=3, kernel_size=4, num_layers=2, use_transposed_convs=True) hidden_dim=512, num_resnet_blocks=3, kernel_size=3, num_layers=1, use_transposed_convs=False,
v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth')) encoder_choke=True, choke_dim=256)
#v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth'))
#v.eval() #v.eval()
o=v(torch.randn(1,80,256)) o=v(torch.randn(1,80,256))
print(o[-1].shape) print(o[-1].shape)

View File

@@ -114,9 +114,10 @@ if __name__ == "__main__":
if audio_mode: if audio_mode:
data = { data = {
'clip': im.to('cuda'), 'clip': im.to('cuda'),
'alt_clips': torch.zeros_like(refs[:,0].to('cuda')), 'alt_clips': refs.to('cuda'),
'num_alt_clips': torch.tensor([refs.shape[1]], dtype=torch.int32, device='cuda'), 'num_alt_clips': torch.tensor([refs.shape[1]], dtype=torch.int32, device='cuda'),
'GT_path': opt['image'] 'GT_path': opt['image'],
'resampled_clip': refs[:, 0].to('cuda')
} }
else: else:
data = { data = {