Spinenet should allow bypassing the initial conv
This makes feeding in references for recurrence easier.
This commit is contained in:
parent
c7f3fc4dd9
commit
b008a27d39
|
@ -32,7 +32,7 @@ class ChunkWithReference:
|
||||||
elif self.strict:
|
elif self.strict:
|
||||||
raise FileNotFoundError(tile_id, self.tiles[item])
|
raise FileNotFoundError(tile_id, self.tiles[item])
|
||||||
else:
|
else:
|
||||||
center = torch.tensor([128,128], dtype=torch.long)
|
center = torch.tensor([128, 128], dtype=torch.long)
|
||||||
tile_width = 256
|
tile_width = 256
|
||||||
mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
|
mask = np.full(tile.shape[:2] + (1,), fill_value=.1, dtype=tile.dtype)
|
||||||
mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1
|
mask[center[0] - tile_width // 2:center[0] + tile_width // 2, center[1] - tile_width // 2:center[1] + tile_width // 2] = 1
|
||||||
|
|
|
@ -58,8 +58,8 @@ class ChainedEmbeddingGen(nn.Module):
|
||||||
self.upsample = FinalUpsampleBlock2x(64)
|
self.upsample = FinalUpsampleBlock2x(64)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
emb = checkpoint(self.spine, x)
|
|
||||||
fea = self.initial_conv(x)
|
fea = self.initial_conv(x)
|
||||||
|
emb = checkpoint(self.spine, fea)
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
fea = fea + checkpoint(block, fea, *emb)
|
fea = fea + checkpoint(block, fea, *emb)
|
||||||
return checkpoint(self.upsample, fea),
|
return checkpoint(self.upsample, fea),
|
||||||
|
@ -82,11 +82,11 @@ class ChainedEmbeddingGenWithStructure(nn.Module):
|
||||||
self.upsample = FinalUpsampleBlock2x(64)
|
self.upsample = FinalUpsampleBlock2x(64)
|
||||||
|
|
||||||
def forward(self, x, recurrent=None):
|
def forward(self, x, recurrent=None):
|
||||||
emb = checkpoint(self.spine, x)
|
|
||||||
fea = self.initial_conv(x)
|
fea = self.initial_conv(x)
|
||||||
if self.recurrent:
|
if self.recurrent:
|
||||||
rec = self.recurrent_process(recurrent)
|
rec = self.recurrent_process(recurrent)
|
||||||
fea, _ = self.recurrent_join(fea, rec)
|
fea, _ = self.recurrent_join(fea, rec)
|
||||||
|
emb = checkpoint(self.spine, fea)
|
||||||
grad = fea
|
grad = fea
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
fea = fea + checkpoint(block, fea, *emb)
|
fea = fea + checkpoint(block, fea, *emb)
|
||||||
|
|
|
@ -245,12 +245,7 @@ class SpineNet(nn.Module):
|
||||||
stride=2)
|
stride=2)
|
||||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
else:
|
else:
|
||||||
self.conv1 = ConvGnSilu(
|
self.conv1 = None
|
||||||
in_channels,
|
|
||||||
64,
|
|
||||||
kernel_size=7,
|
|
||||||
stride=1)
|
|
||||||
self.maxpool = None
|
|
||||||
|
|
||||||
# Build the initial level 2 blocks.
|
# Build the initial level 2 blocks.
|
||||||
self.init_block1 = make_res_layer(
|
self.init_block1 = make_res_layer(
|
||||||
|
@ -311,8 +306,8 @@ class SpineNet(nn.Module):
|
||||||
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input.device)
|
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input.device)
|
||||||
input = (input - mean) / std
|
input = (input - mean) / std
|
||||||
|
|
||||||
feat = self.conv1(input)
|
if self.conv1 is not None:
|
||||||
if self.maxpool:
|
feat = self.conv1(input)
|
||||||
feat = self.maxpool(feat)
|
feat = self.maxpool(feat)
|
||||||
feat1 = self.init_block1(feat)
|
feat1 = self.init_block1(feat)
|
||||||
feat2 = self.init_block2(feat1)
|
feat2 = self.init_block2(feat1)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user