diff --git a/codes/models/archs/ChainedEmbeddingGen.py b/codes/models/archs/ChainedEmbeddingGen.py index 10c74ec7..d304458d 100644 --- a/codes/models/archs/ChainedEmbeddingGen.py +++ b/codes/models/archs/ChainedEmbeddingGen.py @@ -81,9 +81,10 @@ class ChainedEmbeddingGenWithStructure(nn.Module): self.grad_extract = ImageGradientNoPadding() self.upsample = FinalUpsampleBlock2x(64) - def forward(self, x): + def forward(self, x, recurrent=None): emb = checkpoint(self.spine, x) if self.recurrent: + fea = torch.cat([x,recurrent], dim=1) fea = self.initial_conv_rec(x) else: fea = self.initial_conv(x)