Add choke to lucidrains_dvae
This commit is contained in:
parent
934395d4b8
commit
82d0e7720e
|
@ -76,6 +76,8 @@ class DiscreteVAE(nn.Module):
|
|||
normalization = None, # ((0.5,) * 3, (0.5,) * 3),
|
||||
record_codes = False,
|
||||
discretization_loss_averaging_steps = 100,
|
||||
encoder_choke=False,
|
||||
choke_dim=128,
|
||||
):
|
||||
super().__init__()
|
||||
has_resblocks = num_resnet_blocks > 0
|
||||
|
@ -140,6 +142,10 @@ class DiscreteVAE(nn.Module):
|
|||
if num_resnet_blocks > 0:
|
||||
dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))
|
||||
|
||||
if encoder_choke:
|
||||
enc_layers.append(conv(innermost_dim, choke_dim, 1))
|
||||
innermost_dim = choke_dim
|
||||
|
||||
enc_layers.append(conv(innermost_dim, codebook_dim, 1))
|
||||
dec_layers.append(conv(dec_out_chans, channels, 1))
|
||||
|
||||
|
@ -265,8 +271,9 @@ if __name__ == '__main__':
|
|||
#o=v(torch.randn(1,3,256,256))
|
||||
#print(o.shape)
|
||||
v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048,
|
||||
hidden_dim=512, stride=2, num_resnet_blocks=3, kernel_size=4, num_layers=2, use_transposed_convs=True)
|
||||
v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth'))
|
||||
hidden_dim=512, num_resnet_blocks=3, kernel_size=3, num_layers=1, use_transposed_convs=False,
|
||||
encoder_choke=True, choke_dim=256)
|
||||
#v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth'))
|
||||
#v.eval()
|
||||
o=v(torch.randn(1,80,256))
|
||||
print(o[-1].shape)
|
||||
|
|
|
@ -114,9 +114,10 @@ if __name__ == "__main__":
|
|||
if audio_mode:
|
||||
data = {
|
||||
'clip': im.to('cuda'),
|
||||
'alt_clips': torch.zeros_like(refs[:,0].to('cuda')),
|
||||
'alt_clips': refs.to('cuda'),
|
||||
'num_alt_clips': torch.tensor([refs.shape[1]], dtype=torch.int32, device='cuda'),
|
||||
'GT_path': opt['image']
|
||||
'GT_path': opt['image'],
|
||||
'resampled_clip': refs[:, 0].to('cuda')
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
|
|
Loading…
Reference in New Issue
Block a user