This commit is contained in:
James Betker 2021-11-20 17:45:14 -07:00
parent 687e0746b3
commit 14f3155ec4
2 changed files with 5 additions and 3 deletions

View File

@ -264,8 +264,9 @@ if __name__ == '__main__':
#v = DiscreteVAE()
#o=v(torch.randn(1,3,256,256))
#print(o.shape)
v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096,
hidden_dim=256, stride=2, num_resnet_blocks=2, kernel_size=3, num_layers=0, use_transposed_convs=False)
v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048,
hidden_dim=512, stride=2, num_resnet_blocks=3, kernel_size=4, num_layers=2, use_transposed_convs=True)
v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth'))
#v.eval()
o=v(torch.randn(1,80,256))
print(o[-1].shape)

View File

@ -15,7 +15,8 @@ class Injector(torch.nn.Module):
self.env = env
if 'in' in opt.keys():
self.input = opt['in']
self.output = opt['out']
if 'out' in opt.keys():
self.output = opt['out']
# This should return a dict of new state variables.
def forward(self, state):