From b008a27d3920cdba2938845ee6c560bd72952601 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 17 Oct 2020 20:16:47 -0600 Subject: [PATCH] Spinenet should allow bypassing the initial conv This makes feeding in references for recurrence easier. --- codes/data/chunk_with_reference.py | 2 +- codes/models/archs/ChainedEmbeddingGen.py | 4 ++-- codes/models/archs/spinenet_arch.py | 11 +++-------- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/codes/data/chunk_with_reference.py b/codes/data/chunk_with_reference.py index 0393a67c..bab35f35 100644 --- a/codes/data/chunk_with_reference.py +++ b/codes/data/chunk_with_reference.py @@ -32,7 +32,7 @@ class ChunkWithReference: elif self.strict: raise FileNotFoundError(tile_id, self.tiles[item]) else: - center = torch.tensor([128,128], dtype=torch.long) + center = torch.tensor([128, 128], dtype=torch.long) tile_width = 256 mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype) mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1 diff --git a/codes/models/archs/ChainedEmbeddingGen.py b/codes/models/archs/ChainedEmbeddingGen.py index 6781630c..63dfbce4 100644 --- a/codes/models/archs/ChainedEmbeddingGen.py +++ b/codes/models/archs/ChainedEmbeddingGen.py @@ -58,8 +58,8 @@ class ChainedEmbeddingGen(nn.Module): self.upsample = FinalUpsampleBlock2x(64) def forward(self, x): - emb = checkpoint(self.spine, x) fea = self.initial_conv(x) + emb = checkpoint(self.spine, fea) for block in self.blocks: fea = fea + checkpoint(block, fea, *emb) return checkpoint(self.upsample, fea), @@ -82,11 +82,11 @@ class ChainedEmbeddingGenWithStructure(nn.Module): self.upsample = FinalUpsampleBlock2x(64) def forward(self, x, recurrent=None): - emb = checkpoint(self.spine, x) fea = self.initial_conv(x) if self.recurrent: rec = self.recurrent_process(recurrent) fea, _ = self.recurrent_join(fea, rec) + emb = checkpoint(self.spine, fea) grad = fea for i, block in enumerate(self.blocks): fea = fea + checkpoint(block, fea, *emb) diff --git a/codes/models/archs/spinenet_arch.py b/codes/models/archs/spinenet_arch.py index 86bac6f4..5dcd84a0 100644 --- a/codes/models/archs/spinenet_arch.py +++ b/codes/models/archs/spinenet_arch.py @@ -245,12 +245,7 @@ class SpineNet(nn.Module): stride=2) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) else: - self.conv1 = ConvGnSilu( - in_channels, - 64, - kernel_size=7, - stride=1) - self.maxpool = None + self.conv1 = None # Build the initial level 2 blocks. self.init_block1 = make_res_layer( @@ -311,8 +306,8 @@ class SpineNet(nn.Module): std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input.device) input = (input - mean) / std - feat = self.conv1(input) - if self.maxpool: + if self.conv1 is not None: + feat = self.conv1(input) feat = self.maxpool(feat) feat1 = self.init_block1(feat) feat2 = self.init_block2(feat1)