Add BackboneEncoderNoRef
This commit is contained in:
parent
d0321ca5de
commit
6deab85b9b
|
@ -379,7 +379,7 @@ class BackboneEncoder(nn.Module):
|
||||||
ref_emb = checkpoint(self.ref_spine, ref)[0]
|
ref_emb = checkpoint(self.ref_spine, ref)[0]
|
||||||
ref_code = gather_2d(ref_emb, ref_center_point // 8) # Divide by 8 to bring the center point to the correct location.
|
ref_code = gather_2d(ref_emb, ref_center_point // 8) # Divide by 8 to bring the center point to the correct location.
|
||||||
|
|
||||||
patch = checkpoint(self.ref_spine, x)[0]
|
patch = checkpoint(self.patch_spine, x)[0]
|
||||||
ref_code_expanded = ref_code.view(-1, 256, 1, 1).repeat(1, 1, patch.shape[2], patch.shape[3])
|
ref_code_expanded = ref_code.view(-1, 256, 1, 1).repeat(1, 1, patch.shape[2], patch.shape[3])
|
||||||
combined = self.merge_process1(torch.cat([patch, ref_code_expanded], dim=1))
|
combined = self.merge_process1(torch.cat([patch, ref_code_expanded], dim=1))
|
||||||
combined = self.merge_process2(combined)
|
combined = self.merge_process2(combined)
|
||||||
|
@ -387,6 +387,29 @@ class BackboneEncoder(nn.Module):
|
||||||
|
|
||||||
return combined
|
return combined
|
||||||
|
|
||||||
|
|
||||||
|
class BackboneEncoderNoRef(nn.Module):
|
||||||
|
def __init__(self, interpolate_first=True, pretrained_backbone=None):
|
||||||
|
super(BackboneEncoderNoRef, self).__init__()
|
||||||
|
self.interpolate_first = interpolate_first
|
||||||
|
|
||||||
|
self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True)
|
||||||
|
|
||||||
|
if pretrained_backbone is not None:
|
||||||
|
loaded_params = torch.load(pretrained_backbone)
|
||||||
|
self.patch_spine.load_state_dict(loaded_params['state_dict'], strict=True)
|
||||||
|
|
||||||
|
# Returned embedding will have been reduced in size by a factor of 8 (4 if interpolate_first=True).
|
||||||
|
# Output channels are always 256.
|
||||||
|
# ex, 64x64 input with interpolate_first=True will result in tensor of shape [bx256x16x16]
|
||||||
|
def forward(self, x):
|
||||||
|
if self.interpolate_first:
|
||||||
|
x = F.interpolate(x, scale_factor=2, mode="bicubic")
|
||||||
|
|
||||||
|
patch = checkpoint(self.patch_spine, x)[0]
|
||||||
|
return patch
|
||||||
|
|
||||||
|
|
||||||
# Note to future self:
|
# Note to future self:
|
||||||
# Can I do a real transformer here? Such as by having the multiplexer be able to toggle off of transformations by
|
# Can I do a real transformer here? Such as by having the multiplexer be able to toggle off of transformations by
|
||||||
# their output? The embedding will be used as the "Query" to the "QueryxKey=Value" relationship.
|
# their output? The embedding will be used as the "Query" to the "QueryxKey=Value" relationship.
|
||||||
|
@ -456,6 +479,7 @@ class QueryKeyMultiplexer(nn.Module):
|
||||||
self.key_process = ConvGnSilu(nf, nf, kernel_size=1, activation=True, norm=False, bias=True)
|
self.key_process = ConvGnSilu(nf, nf, kernel_size=1, activation=True, norm=False, bias=True)
|
||||||
|
|
||||||
# Postprocessing blocks.
|
# Postprocessing blocks.
|
||||||
|
self.query_key_combine = ConvGnSilu(nf*2, nf, kernel_size=1, activation=True, norm=False, bias=False)
|
||||||
self.cbl1 = ConvGnSilu(nf, nf // 2, kernel_size=1, norm=True, bias=False, num_groups=4)
|
self.cbl1 = ConvGnSilu(nf, nf // 2, kernel_size=1, norm=True, bias=False, num_groups=4)
|
||||||
self.cbl2 = ConvGnSilu(nf // 2, 1, kernel_size=1, norm=False, bias=False)
|
self.cbl2 = ConvGnSilu(nf // 2, 1, kernel_size=1, norm=False, bias=False)
|
||||||
|
|
||||||
|
@ -474,10 +498,8 @@ class QueryKeyMultiplexer(nn.Module):
|
||||||
k = transformations.view(b * t, f, h, w)
|
k = transformations.view(b * t, f, h, w)
|
||||||
k = self.key_process(k)
|
k = self.key_process(k)
|
||||||
|
|
||||||
k = k.view(b, t, f, h, w) # Not sure if this is necessary..
|
q = q.view(b, 1, f, h, w).repeat(1, t, 1, 1, 1).view(b * t, f, h, w)
|
||||||
q = q.view(b, 1, f, h, w).repeat(1, t, 1, 1, 1)
|
v = self.query_key_combine(torch.cat([q, k], dim=1))
|
||||||
v = q * k
|
|
||||||
v = v.view(b * t, f, h, w)
|
|
||||||
|
|
||||||
v = self.cbl1(v)
|
v = self.cbl1(v)
|
||||||
v = self.cbl2(v)
|
v = self.cbl2(v)
|
||||||
|
|
|
@ -61,6 +61,8 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
|
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
|
||||||
elif which_model == "backbone_encoder":
|
elif which_model == "backbone_encoder":
|
||||||
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
|
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
|
||||||
|
elif which_model == "backbone_encoder_no_ref":
|
||||||
|
netG = SwitchedGen_arch.BackboneEncoderNoRef(pretrained_backbone=opt_net['pretrained_spinenet'])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user