BackboneSpineNoHead takes ref
This commit is contained in:
parent
5a27187c59
commit
d8621e611a
|
@ -413,11 +413,24 @@ class BackboneEncoderNoRef(nn.Module):
|
|||
class BackboneSpinenetNoHead(nn.Module):
|
||||
def __init__(self):
|
||||
super(BackboneSpinenetNoHead, self).__init__()
|
||||
self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True, double_reduce_early=False)
|
||||
# Uses dual spinenets, one for the input patch and the other for the reference image.
|
||||
self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=False, double_reduce_early=False)
|
||||
self.ref_spine = SpineNet('49', in_channels=4, use_input_norm=False, double_reduce_early=False)
|
||||
|
||||
self.merge_process1 = ConvGnSilu(512, 512, kernel_size=1, activation=True, norm=False, bias=True)
|
||||
self.merge_process2 = ConvGnSilu(512, 384, kernel_size=1, activation=True, norm=True, bias=False)
|
||||
self.merge_process3 = ConvGnSilu(384, 256, kernel_size=1, activation=False, norm=False, bias=True)
|
||||
|
||||
def forward(self, x, ref, ref_center_point):
|
||||
ref_emb = checkpoint(self.ref_spine, ref)[0]
|
||||
ref_code = gather_2d(ref_emb, ref_center_point // 4) # Divide by 8 to bring the center point to the correct location.
|
||||
|
||||
def forward(self, x):
|
||||
patch = checkpoint(self.patch_spine, x)[0]
|
||||
return patch
|
||||
ref_code_expanded = ref_code.view(-1, 256, 1, 1).repeat(1, 1, patch.shape[2], patch.shape[3])
|
||||
combined = self.merge_process1(torch.cat([patch, ref_code_expanded], dim=1))
|
||||
combined = self.merge_process2(combined)
|
||||
combined = self.merge_process3(combined)
|
||||
return combined
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
|
|
Loading…
Reference in New Issue
Block a user