forked from mrq/DL-Art-School
More recurrence fixes for chainedgen
This commit is contained in:
parent
cf8118a85b
commit
6141aa1110
|
@ -81,9 +81,10 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
|
||||||
self.grad_extract = ImageGradientNoPadding()
|
self.grad_extract = ImageGradientNoPadding()
|
||||||
self.upsample = FinalUpsampleBlock2x(64)
|
self.upsample = FinalUpsampleBlock2x(64)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, recurrent=None):
|
||||||
emb = checkpoint(self.spine, x)
|
emb = checkpoint(self.spine, x)
|
||||||
if self.recurrent:
|
if self.recurrent:
|
||||||
|
fea = torch.cat([x,recurrent], dim=1)
|
||||||
fea = self.initial_conv_rec(x)
|
fea = self.initial_conv_rec(x)
|
||||||
else:
|
else:
|
||||||
fea = self.initial_conv(x)
|
fea = self.initial_conv(x)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user