From d8621e611a650acbe9dee6a75637c2124d99924a Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 26 Sep 2020 21:25:04 -0600 Subject: [PATCH] BackboneSpineNoHead takes ref --- .../archs/SwitchedResidualGenerator_arch.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py index 01034863..f9bf0799 100644 --- a/codes/models/archs/SwitchedResidualGenerator_arch.py +++ b/codes/models/archs/SwitchedResidualGenerator_arch.py @@ -413,11 +413,24 @@ class BackboneEncoderNoRef(nn.Module): class BackboneSpinenetNoHead(nn.Module): def __init__(self): super(BackboneSpinenetNoHead, self).__init__() - self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True, double_reduce_early=False) + # Uses dual spinenets, one for the input patch and the other for the reference image. + self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=False, double_reduce_early=False) + self.ref_spine = SpineNet('49', in_channels=4, use_input_norm=False, double_reduce_early=False) + + self.merge_process1 = ConvGnSilu(512, 512, kernel_size=1, activation=True, norm=False, bias=True) + self.merge_process2 = ConvGnSilu(512, 384, kernel_size=1, activation=True, norm=True, bias=False) + self.merge_process3 = ConvGnSilu(384, 256, kernel_size=1, activation=False, norm=False, bias=True) + + def forward(self, x, ref, ref_center_point): + ref_emb = checkpoint(self.ref_spine, ref)[0] + ref_code = gather_2d(ref_emb, ref_center_point // 4) # Divide by 8 to bring the center point to the correct location. - def forward(self, x): patch = checkpoint(self.patch_spine, x)[0] - return patch + ref_code_expanded = ref_code.view(-1, 256, 1, 1).repeat(1, 1, patch.shape[2], patch.shape[3]) + combined = self.merge_process1(torch.cat([patch, ref_code_expanded], dim=1)) + combined = self.merge_process2(combined) + combined = self.merge_process3(combined) + return combined class ResBlock(nn.Module):