BackboneSpineNoHead takes ref

This commit is contained in:
James Betker 2020-09-26 21:25:04 -06:00
parent 5a27187c59
commit d8621e611a

View File

@ -413,11 +413,24 @@ class BackboneEncoderNoRef(nn.Module):
class BackboneSpinenetNoHead(nn.Module):
def __init__(self):
super(BackboneSpinenetNoHead, self).__init__()
self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True, double_reduce_early=False)
# Uses dual spinenets, one for the input patch and the other for the reference image.
self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=False, double_reduce_early=False)
self.ref_spine = SpineNet('49', in_channels=4, use_input_norm=False, double_reduce_early=False)
self.merge_process1 = ConvGnSilu(512, 512, kernel_size=1, activation=True, norm=False, bias=True)
self.merge_process2 = ConvGnSilu(512, 384, kernel_size=1, activation=True, norm=True, bias=False)
self.merge_process3 = ConvGnSilu(384, 256, kernel_size=1, activation=False, norm=False, bias=True)
def forward(self, x, ref, ref_center_point):
ref_emb = checkpoint(self.ref_spine, ref)[0]
ref_code = gather_2d(ref_emb, ref_center_point // 4) # Divide by 8 to bring the center point to the correct location.
def forward(self, x):
patch = checkpoint(self.patch_spine, x)[0]
return patch
ref_code_expanded = ref_code.view(-1, 256, 1, 1).repeat(1, 1, patch.shape[2], patch.shape[3])
combined = self.merge_process1(torch.cat([patch, ref_code_expanded], dim=1))
combined = self.merge_process2(combined)
combined = self.merge_process3(combined)
return combined
class ResBlock(nn.Module):