class BackboneEncoderNoRef(nn.Module):
    """Patch encoder built on a single SpineNet-49 backbone, with no reference image.

    The returned embedding is spatially reduced by a factor of 8 relative to the
    input (a factor of 4 when ``interpolate_first=True``, because the input is
    upsampled 2x first). Output channels are always 256.
    Example: a 64x64 input with ``interpolate_first=True`` yields a tensor of
    shape [b x 256 x 16 x 16].

    Args:
        interpolate_first: When True, bicubically upsample the input by 2x
            before it is fed to the backbone.
        pretrained_backbone: Optional path to a checkpoint whose ``'state_dict'``
            entry holds the SpineNet weights. Loaded with ``strict=True``, so
            the keys must match the backbone exactly.
    """

    def __init__(self, interpolate_first=True, pretrained_backbone=None):
        super().__init__()
        self.interpolate_first = interpolate_first

        self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True)

        if pretrained_backbone is not None:
            # Deserialize onto the CPU so a checkpoint saved from a (possibly
            # unavailable) CUDA device still loads; load_state_dict() copies the
            # values into the backbone's existing parameters afterwards.
            loaded_params = torch.load(pretrained_backbone, map_location='cpu')
            self.patch_spine.load_state_dict(loaded_params['state_dict'], strict=True)

    def forward(self, x):
        """Encode ``x`` and return the first output of the backbone."""
        if self.interpolate_first:
            # align_corners=False is the library default for bicubic; passing it
            # explicitly keeps the result identical and avoids the
            # "default behavior changed" warning on older torch releases.
            x = F.interpolate(x, scale_factor=2, mode="bicubic", align_corners=False)

        # The backbone returns a sequence; only its first element is used here.
        patch = checkpoint(self.patch_spine, x)[0]
        return patch
@@ -456,6 +479,7 @@ class QueryKeyMultiplexer(nn.Module): self.key_process = ConvGnSilu(nf, nf, kernel_size=1, activation=True, norm=False, bias=True) # Postprocessing blocks. + self.query_key_combine = ConvGnSilu(nf*2, nf, kernel_size=1, activation=True, norm=False, bias=False) self.cbl1 = ConvGnSilu(nf, nf // 2, kernel_size=1, norm=True, bias=False, num_groups=4) self.cbl2 = ConvGnSilu(nf // 2, 1, kernel_size=1, norm=False, bias=False) @@ -474,10 +498,8 @@ class QueryKeyMultiplexer(nn.Module): k = transformations.view(b * t, f, h, w) k = self.key_process(k) - k = k.view(b, t, f, h, w) # Not sure if this is necessary.. - q = q.view(b, 1, f, h, w).repeat(1, t, 1, 1, 1) - v = q * k - v = v.view(b * t, f, h, w) + q = q.view(b, 1, f, h, w).repeat(1, t, 1, 1, 1).view(b * t, f, h, w) + v = self.query_key_combine(torch.cat([q, k], dim=1)) v = self.cbl1(v) v = self.cbl2(v) diff --git a/codes/models/networks.py b/codes/models/networks.py index 9c4a855c..e5d66744 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -61,6 +61,8 @@ def define_G(opt, net_key='network_G', scale=None): init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) elif which_model == "backbone_encoder": netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet']) + elif which_model == "backbone_encoder_no_ref": + netG = SwitchedGen_arch.BackboneEncoderNoRef(pretrained_backbone=opt_net['pretrained_spinenet']) else: raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))