Add recurrent support to chainedgenwithstructure
This commit is contained in:
parent
d4a3e11ab2
commit
fc4c064867
|
@ -66,9 +66,13 @@ class ChainedEmbeddingGen(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class ChainedEmbeddingGenWithStructure(nn.Module):
|
class ChainedEmbeddingGenWithStructure(nn.Module):
|
||||||
def __init__(self, depth=10):
|
def __init__(self, depth=10, recurrent=False):
|
||||||
super(ChainedEmbeddingGenWithStructure, self).__init__()
|
super(ChainedEmbeddingGenWithStructure, self).__init__()
|
||||||
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
self.recurrent = recurrent
|
||||||
|
if recurrent:
|
||||||
|
self.initial_conv_rec = ConvGnLelu(6, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
||||||
|
else:
|
||||||
|
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
||||||
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
|
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
|
||||||
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
|
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
|
||||||
self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
|
self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
|
||||||
|
@ -79,7 +83,10 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
emb = checkpoint(self.spine, x)
|
emb = checkpoint(self.spine, x)
|
||||||
fea = self.initial_conv(x)
|
if self.recurrent:
|
||||||
|
fea = self.initial_conv_rec(x)
|
||||||
|
else:
|
||||||
|
fea = self.initial_conv(x)
|
||||||
grad = fea
|
grad = fea
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
fea = fea + checkpoint(block, fea, *emb)
|
fea = fea + checkpoint(block, fea, *emb)
|
||||||
|
@ -88,31 +95,3 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
|
||||||
grad = grad + checkpoint(self.structure_blocks[i], structure_br)
|
grad = grad + checkpoint(self.structure_blocks[i], structure_br)
|
||||||
out = checkpoint(self.upsample, fea)
|
out = checkpoint(self.upsample, fea)
|
||||||
return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out)
|
return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out)
|
||||||
|
|
||||||
|
|
||||||
class ChainedEmbeddingGenWithStructureR2(nn.Module):
|
|
||||||
def __init__(self, depth=10):
|
|
||||||
super(ChainedEmbeddingGenWithStructureR2, self).__init__()
|
|
||||||
self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False)
|
|
||||||
self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False)
|
|
||||||
self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)])
|
|
||||||
self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)])
|
|
||||||
self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)])
|
|
||||||
self.structure_upsample = FinalUpsampleBlock2x(64)
|
|
||||||
self.structure_rejoin = ConjoinBlock(64)
|
|
||||||
self.grad_extract = ImageGradientNoPadding()
|
|
||||||
self.upsample = FinalUpsampleBlock2x(64)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
emb = checkpoint(self.spine, x)
|
|
||||||
fea = self.initial_conv(x)
|
|
||||||
grad = fea
|
|
||||||
for i, block in enumerate(self.blocks):
|
|
||||||
fea = fea + checkpoint(block, fea, *emb)
|
|
||||||
if i < 3:
|
|
||||||
structure_br = checkpoint(self.structure_joins[i], grad, fea)
|
|
||||||
grad = grad + checkpoint(self.structure_blocks[i], structure_br)
|
|
||||||
if i == 3:
|
|
||||||
fea = fea + self.structure_rejoin(fea, grad)
|
|
||||||
out = checkpoint(self.upsample, fea)
|
|
||||||
return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out)
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user