forked from mrq/DL-Art-School

Add choke to lucidrains_dvae

parent 934395d4b8
commit 82d0e7720e
@@ -76,6 +76,8 @@ class DiscreteVAE(nn.Module):
         normalization = None, # ((0.5,) * 3, (0.5,) * 3),
         record_codes = False,
         discretization_loss_averaging_steps = 100,
+        encoder_choke=False,
+        choke_dim=128,
     ):
         super().__init__()
         has_resblocks = num_resnet_blocks > 0
@@ -140,6 +142,10 @@ class DiscreteVAE(nn.Module):
         if num_resnet_blocks > 0:
             dec_layers.insert(0, conv(codebook_dim, innermost_dim, 1))

+        if encoder_choke:
+            enc_layers.append(conv(innermost_dim, choke_dim, 1))
+            innermost_dim = choke_dim
+
         enc_layers.append(conv(innermost_dim, codebook_dim, 1))
         dec_layers.append(conv(dec_out_chans, channels, 1))
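Taken together, the two hunks add an optional bottleneck ("choke") between the encoder body and the final projection into codebook space: a 1x1 convolution first squeezes the channel count from innermost_dim down to choke_dim, and the codebook projection then starts from the narrower width. A minimal runnable sketch of just that wiring, assuming conv resolves to nn.Conv1d (the positional_dims=1 case) and using a placeholder encoder body and dimensions, not the repo's real stack:

import torch
import torch.nn as nn

innermost_dim, choke_dim, codebook_dim = 512, 128, 2048

# Stand-in for the real encoder stack; only the tail matters here.
enc_layers = [nn.Conv1d(80, innermost_dim, 3, padding=1)]

encoder_choke = True
if encoder_choke:
    # 1x1 conv bottleneck: squeeze channels before the codebook projection.
    enc_layers.append(nn.Conv1d(innermost_dim, choke_dim, 1))
    innermost_dim = choke_dim

# The final projection now maps choke_dim -> codebook_dim instead of 512 -> 2048.
enc_layers.append(nn.Conv1d(innermost_dim, codebook_dim, 1))

encoder = nn.Sequential(*enc_layers)
print(encoder(torch.randn(1, 80, 256)).shape)  # torch.Size([1, 2048, 256])

With encoder_choke left at its default of False, no extra layers are appended, so existing configs and checkpoints are unaffected.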
@@ -265,8 +271,9 @@ if __name__ == '__main__':
     #o=v(torch.randn(1,3,256,256))
     #print(o.shape)
     v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048,
-                    hidden_dim=512, stride=2, num_resnet_blocks=3, kernel_size=4, num_layers=2, use_transposed_convs=True)
-    v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth'))
+                    hidden_dim=512, num_resnet_blocks=3, kernel_size=3, num_layers=1, use_transposed_convs=False,
+                    encoder_choke=True, choke_dim=256)
+    #v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth'))
     #v.eval()
     o=v(torch.randn(1,80,256))
     print(o[-1].shape)
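One plausible effect of the choke, worth spelling out: it factors the single innermost_dim -> codebook_dim projection into two smaller ones. Back-of-the-envelope arithmetic with the dimensions from the updated test harness (choke_dim=256, codebook_dim=2048), assuming the encoder's innermost channel count is 512 (the harness's hidden_dim) and ignoring biases:

innermost_dim, choke_dim, codebook_dim = 512, 256, 2048
direct = innermost_dim * codebook_dim                          # 1,048,576 weights
choked = innermost_dim * choke_dim + choke_dim * codebook_dim  #   655,360 weights
print(direct, choked)  # the factored projection is 37.5% smaller

It also forces the pre-quantization representation through a narrower channel width, which bounds how much the codebook lookup has to discriminate.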
@ -114,9 +114,10 @@ if __name__ == "__main__":
|
||||||
if audio_mode:
|
if audio_mode:
|
||||||
data = {
|
data = {
|
||||||
'clip': im.to('cuda'),
|
'clip': im.to('cuda'),
|
||||||
'alt_clips': torch.zeros_like(refs[:,0].to('cuda')),
|
'alt_clips': refs.to('cuda'),
|
||||||
'num_alt_clips': torch.tensor([refs.shape[1]], dtype=torch.int32, device='cuda'),
|
'num_alt_clips': torch.tensor([refs.shape[1]], dtype=torch.int32, device='cuda'),
|
||||||
'GT_path': opt['image']
|
'GT_path': opt['image'],
|
||||||
|
'resampled_clip': refs[:, 0].to('cuda')
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
data = {
|
data = {
|
||||||
|
|
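The last hunk touches a different file's __main__ harness: the data dict now passes the real reference clips through 'alt_clips' (previously zeroed out) and adds a 'resampled_clip' entry taken from the first reference. A small sketch of the shapes this implies, assuming a hypothetical refs layout of (batch, num_refs, mel_channels, time), since the real loader's layout isn't shown in this diff, and with a CPU fallback so it runs without a GPU:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
refs = torch.randn(1, 3, 80, 256)  # hypothetical (batch, num_refs, mel, time)

data = {
    'alt_clips': refs.to(device),  # all reference clips, where zeros were passed before
    'num_alt_clips': torch.tensor([refs.shape[1]], dtype=torch.int32, device=device),
    'resampled_clip': refs[:, 0].to(device),  # first reference doubles as the resampled clip
}
print({k: tuple(v.shape) for k, v in data.items()})
# {'alt_clips': (1, 3, 80, 256), 'num_alt_clips': (1,), 'resampled_clip': (1, 80, 256)}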