From 14f3155ec42d507dd7d2d548aafc615101fa8f4d Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Sat, 20 Nov 2021 17:45:14 -0700
Subject: [PATCH] misc

---
 codes/models/gpt_voice/lucidrains_dvae.py | 5 +++--
 codes/trainer/inject.py                   | 3 ++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/codes/models/gpt_voice/lucidrains_dvae.py b/codes/models/gpt_voice/lucidrains_dvae.py
index 673bf9eb..688c63a7 100644
--- a/codes/models/gpt_voice/lucidrains_dvae.py
+++ b/codes/models/gpt_voice/lucidrains_dvae.py
@@ -264,8 +264,9 @@ if __name__ == '__main__':
     #v = DiscreteVAE()
     #o=v(torch.randn(1,3,256,256))
     #print(o.shape)
-    v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096,
-                    hidden_dim=256, stride=2, num_resnet_blocks=2, kernel_size=3, num_layers=0, use_transposed_convs=False)
+    v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=8192, codebook_dim=2048,
+                    hidden_dim=512, stride=2, num_resnet_blocks=3, kernel_size=4, num_layers=2, use_transposed_convs=True)
+    v.load_state_dict(torch.load('../experiments/clips_dvae_8192_rev2.pth'))
     #v.eval()
     o=v(torch.randn(1,80,256))
     print(o[-1].shape)
diff --git a/codes/trainer/inject.py b/codes/trainer/inject.py
index 2f73fef3..a37f9351 100644
--- a/codes/trainer/inject.py
+++ b/codes/trainer/inject.py
@@ -15,7 +15,8 @@ class Injector(torch.nn.Module):
         self.env = env
         if 'in' in opt.keys():
             self.input = opt['in']
-        self.output = opt['out']
+        if 'out' in opt.keys():
+            self.output = opt['out']
 
     # This should return a dict of new state variables.
     def forward(self, state):