From 6141aa111078e030cf545575d47e7796f35c9be3 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 17 Oct 2020 08:35:46 -0600 Subject: [PATCH] More recurrence fixes for chainedgen --- codes/models/archs/ChainedEmbeddingGen.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/codes/models/archs/ChainedEmbeddingGen.py b/codes/models/archs/ChainedEmbeddingGen.py index 10c74ec7..d304458d 100644 --- a/codes/models/archs/ChainedEmbeddingGen.py +++ b/codes/models/archs/ChainedEmbeddingGen.py @@ -81,9 +81,10 @@ class ChainedEmbeddingGenWithStructure(nn.Module): self.grad_extract = ImageGradientNoPadding() self.upsample = FinalUpsampleBlock2x(64) - def forward(self, x): + def forward(self, x, recurrent=None): emb = checkpoint(self.spine, x) if self.recurrent: + fea = torch.cat([x,recurrent], dim=1) fea = self.initial_conv_rec(x) else: fea = self.initial_conv(x)