More recurrence fixes for chainedgen

This commit is contained in:
James Betker 2020-10-17 08:35:46 -06:00
parent cf8118a85b
commit 6141aa1110

View File

@ -81,9 +81,10 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
self.grad_extract = ImageGradientNoPadding()
self.upsample = FinalUpsampleBlock2x(64)
def forward(self, x):
def forward(self, x, recurrent=None):
emb = checkpoint(self.spine, x)
if self.recurrent:
fea = torch.cat([x,recurrent], dim=1)
fea = self.initial_conv_rec(x)
else:
fea = self.initial_conv(x)