diff --git a/codes/models/archs/ChainedEmbeddingGen.py b/codes/models/archs/ChainedEmbeddingGen.py index 00ea4d13..10c74ec7 100644 --- a/codes/models/archs/ChainedEmbeddingGen.py +++ b/codes/models/archs/ChainedEmbeddingGen.py @@ -66,9 +66,13 @@ class ChainedEmbeddingGen(nn.Module): class ChainedEmbeddingGenWithStructure(nn.Module): - def __init__(self, depth=10): + def __init__(self, depth=10, recurrent=False): super(ChainedEmbeddingGenWithStructure, self).__init__() - self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False) + self.recurrent = recurrent + if recurrent: + self.initial_conv_rec = ConvGnLelu(6, 64, kernel_size=7, bias=True, norm=False, activation=False) + else: + self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False) self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False) self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)]) self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)]) @@ -79,7 +83,10 @@ class ChainedEmbeddingGenWithStructure(nn.Module): def forward(self, x): emb = checkpoint(self.spine, x) - fea = self.initial_conv(x) + if self.recurrent: + fea = self.initial_conv_rec(x) + else: + fea = self.initial_conv(x) grad = fea for i, block in enumerate(self.blocks): fea = fea + checkpoint(block, fea, *emb) @@ -88,31 +95,3 @@ class ChainedEmbeddingGenWithStructure(nn.Module): grad = grad + checkpoint(self.structure_blocks[i], structure_br) out = checkpoint(self.upsample, fea) return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out) - - -class ChainedEmbeddingGenWithStructureR2(nn.Module): - def __init__(self, depth=10): - super(ChainedEmbeddingGenWithStructureR2, self).__init__() - self.initial_conv = ConvGnLelu(3, 64, kernel_size=7, bias=True, norm=False, activation=False) - self.spine = SpineNet(arch='49', output_level=[3, 4], double_reduce_early=False) - self.blocks = nn.ModuleList([BasicEmbeddingPyramid() for i in range(depth)]) - self.structure_joins = nn.ModuleList([ConjoinBlock(64) for i in range(3)]) - self.structure_blocks = nn.ModuleList([ConvGnLelu(64, 64, kernel_size=3, bias=False, norm=False, activation=False, weight_init_factor=.1) for i in range(3)]) - self.structure_upsample = FinalUpsampleBlock2x(64) - self.structure_rejoin = ConjoinBlock(64) - self.grad_extract = ImageGradientNoPadding() - self.upsample = FinalUpsampleBlock2x(64) - - def forward(self, x): - emb = checkpoint(self.spine, x) - fea = self.initial_conv(x) - grad = fea - for i, block in enumerate(self.blocks): - fea = fea + checkpoint(block, fea, *emb) - if i < 3: - structure_br = checkpoint(self.structure_joins[i], grad, fea) - grad = grad + checkpoint(self.structure_blocks[i], structure_br) - if i == 3: - fea = fea + self.structure_rejoin(fea, grad) - out = checkpoint(self.upsample, fea) - return out, self.grad_extract(checkpoint(self.structure_upsample, grad)), self.grad_extract(out)