Add BackboneEncoderNoRef

This commit is contained in:
James Betker 2020-09-15 16:55:38 -06:00
parent d0321ca5de
commit 6deab85b9b
2 changed files with 29 additions and 5 deletions

View File

@ -379,7 +379,7 @@ class BackboneEncoder(nn.Module):
ref_emb = checkpoint(self.ref_spine, ref)[0] ref_emb = checkpoint(self.ref_spine, ref)[0]
ref_code = gather_2d(ref_emb, ref_center_point // 8) # Divide by 8 to bring the center point to the correct location. ref_code = gather_2d(ref_emb, ref_center_point // 8) # Divide by 8 to bring the center point to the correct location.
patch = checkpoint(self.ref_spine, x)[0] patch = checkpoint(self.patch_spine, x)[0]
ref_code_expanded = ref_code.view(-1, 256, 1, 1).repeat(1, 1, patch.shape[2], patch.shape[3]) ref_code_expanded = ref_code.view(-1, 256, 1, 1).repeat(1, 1, patch.shape[2], patch.shape[3])
combined = self.merge_process1(torch.cat([patch, ref_code_expanded], dim=1)) combined = self.merge_process1(torch.cat([patch, ref_code_expanded], dim=1))
combined = self.merge_process2(combined) combined = self.merge_process2(combined)
@ -387,6 +387,29 @@ class BackboneEncoder(nn.Module):
return combined return combined
class BackboneEncoderNoRef(nn.Module):
def __init__(self, interpolate_first=True, pretrained_backbone=None):
super(BackboneEncoderNoRef, self).__init__()
self.interpolate_first = interpolate_first
self.patch_spine = SpineNet('49', in_channels=3, use_input_norm=True)
if pretrained_backbone is not None:
loaded_params = torch.load(pretrained_backbone)
self.patch_spine.load_state_dict(loaded_params['state_dict'], strict=True)
# Returned embedding will have been reduced in size by a factor of 8 (4 if interpolate_first=True).
# Output channels are always 256.
# ex, 64x64 input with interpolate_first=True will result in tensor of shape [bx256x16x16]
def forward(self, x):
if self.interpolate_first:
x = F.interpolate(x, scale_factor=2, mode="bicubic")
patch = checkpoint(self.patch_spine, x)[0]
return patch
# Note to future self: # Note to future self:
# Can I do a real transformer here? Such as by having the multiplexer be able to toggle off of transformations by # Can I do a real transformer here? Such as by having the multiplexer be able to toggle off of transformations by
# their output? The embedding will be used as the "Query" to the "QueryxKey=Value" relationship. # their output? The embedding will be used as the "Query" to the "QueryxKey=Value" relationship.
@ -456,6 +479,7 @@ class QueryKeyMultiplexer(nn.Module):
self.key_process = ConvGnSilu(nf, nf, kernel_size=1, activation=True, norm=False, bias=True) self.key_process = ConvGnSilu(nf, nf, kernel_size=1, activation=True, norm=False, bias=True)
# Postprocessing blocks. # Postprocessing blocks.
self.query_key_combine = ConvGnSilu(nf*2, nf, kernel_size=1, activation=True, norm=False, bias=False)
self.cbl1 = ConvGnSilu(nf, nf // 2, kernel_size=1, norm=True, bias=False, num_groups=4) self.cbl1 = ConvGnSilu(nf, nf // 2, kernel_size=1, norm=True, bias=False, num_groups=4)
self.cbl2 = ConvGnSilu(nf // 2, 1, kernel_size=1, norm=False, bias=False) self.cbl2 = ConvGnSilu(nf // 2, 1, kernel_size=1, norm=False, bias=False)
@ -474,10 +498,8 @@ class QueryKeyMultiplexer(nn.Module):
k = transformations.view(b * t, f, h, w) k = transformations.view(b * t, f, h, w)
k = self.key_process(k) k = self.key_process(k)
k = k.view(b, t, f, h, w) # Not sure if this is necessary.. q = q.view(b, 1, f, h, w).repeat(1, t, 1, 1, 1).view(b * t, f, h, w)
q = q.view(b, 1, f, h, w).repeat(1, t, 1, 1, 1) v = self.query_key_combine(torch.cat([q, k], dim=1))
v = q * k
v = v.view(b * t, f, h, w)
v = self.cbl1(v) v = self.cbl1(v)
v = self.cbl2(v) v = self.cbl2(v)

View File

@ -61,6 +61,8 @@ def define_G(opt, net_key='network_G', scale=None):
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
elif which_model == "backbone_encoder": elif which_model == "backbone_encoder":
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet']) netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
elif which_model == "backbone_encoder_no_ref":
netG = SwitchedGen_arch.BackboneEncoderNoRef(pretrained_backbone=opt_net['pretrained_spinenet'])
else: else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))