misc
This commit is contained in:
parent
687e0746b3
commit
14f3155ec4
|
@ -264,8 +264,9 @@ if __name__ == '__main__':
|
|||
#v = DiscreteVAE()
|
||||
#o=v(torch.randn(1,3,256,256))
|
||||
#print(o.shape)
|
||||
v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096,
|
||||
hidden_dim=256, stride=2, num_resnet_blocks=2, kernel_size=3, num_layers=0, use_transposed_convs=False)
|
||||
v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048,
|
||||
hidden_dim=512, stride=2, num_resnet_blocks=3, kernel_size=4, num_layers=2, use_transposed_convs=True)
|
||||
v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth'))
|
||||
#v.eval()
|
||||
o=v(torch.randn(1,80,256))
|
||||
print(o[-1].shape)
|
||||
|
|
|
@ -15,7 +15,8 @@ class Injector(torch.nn.Module):
|
|||
self.env = env
|
||||
if 'in' in opt.keys():
|
||||
self.input = opt['in']
|
||||
self.output = opt['out']
|
||||
if 'out' in opt.keys():
|
||||
self.output = opt['out']
|
||||
|
||||
# This should return a dict of new state variables.
|
||||
def forward(self, state):
|
||||
|
|
Loading…
Reference in New Issue
Block a user