forked from mrq/DL-Art-School
More recurrence fixes for chainedgen
This commit is contained in:
parent
cf8118a85b
commit
6141aa1110
|
@ -81,9 +81,10 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
|
|||
self.grad_extract = ImageGradientNoPadding()
|
||||
self.upsample = FinalUpsampleBlock2x(64)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, recurrent=None):
|
||||
emb = checkpoint(self.spine, x)
|
||||
if self.recurrent:
|
||||
fea = torch.cat([x,recurrent], dim=1)
|
||||
fea = self.initial_conv_rec(x)
|
||||
else:
|
||||
fea = self.initial_conv(x)
|
||||
|
|
Loading…
Reference in New Issue
Block a user