SPSR with ref joining

This commit is contained in:
James Betker 2020-09-09 11:17:07 -06:00
parent c41dc9a48c
commit 0ffac391c1

View File

@ -490,6 +490,22 @@ class MultiplexerWithReducer(nn.Module):
x = self.reduce(x)
return self.mplex(x, ref)
class RefJoiner(nn.Module):
def __init__(self, nf):
super(RefJoiner, self).__init__()
self.lin1 = nn.Linear(512, 256)
self.lin2 = nn.Linear(256, nf)
self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1, norm=False)
def forward(self, x, ref):
ref = self.lin1(ref)
ref = self.lin2(ref)
b, _, h, w = x.shape
ref = ref.view(b, -1, 1, 1)
return self.join(x, ref.repeat((1, 1, h, w)))
class SwitchedSpsrWithRef2(nn.Module):
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
super(SwitchedSpsrWithRef2, self).__init__()
@ -506,14 +522,18 @@ class SwitchedSpsrWithRef2(nn.Module):
transformation_filters, kernel_size=3, depth=3,
weight_init_factor=.1)
self.reference_processor = ReferenceImageBranch(transformation_filters)
# Feature branch
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1, norm=False)
self.ref_join1 = RefJoiner(nf)
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=False)
self.ref_join2 = RefJoiner(nf)
self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
attention_norm=True,
@ -536,7 +556,7 @@ class SwitchedSpsrWithRef2(nn.Module):
self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True)
# Join branch (grad+fea
# Join branch (grad+fea)
self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
@ -546,7 +566,7 @@ class SwitchedSpsrWithRef2(nn.Module):
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=True)
self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
self.attentions = None
self.init_temperature = init_temperature
@ -554,11 +574,16 @@ class SwitchedSpsrWithRef2(nn.Module):
def forward(self, x, ref, center_coord):
x_grad = self.get_g_nopadding(x)
ref = self.reference_processor(ref, center_coord)
x = self.model_fea_conv(x)
x = self.noise_ref_join(x, torch.randn_like(x))
x = self.ref_join1(x, ref)
x1, a1 = self.sw1(x, True)
x2, a2 = self.sw2(x, True)
x2 = x1
x2 = self.ref_join2(x2, ref)
x2, a2 = self.sw2(x2, True)
x_fea = self.feature_lr_conv(x2)
x_fea = self.feature_lr_conv2(x_fea)
@ -570,7 +595,8 @@ class SwitchedSpsrWithRef2(nn.Module):
x_grad_out = self.upsample_grad(x_grad)
x_grad_out = self.grad_branch_output_conv(x_grad_out)
x_out = self.conjoin_ref_join(x_fea, x_grad)
x_out = x_fea
x_out = self.conjoin_ref_join(x_out, x_grad)
x_out, a4 = self.conjoin_sw(x_out, True)
x_out = self.final_lr_conv(x_out)
x_out = self.upsample(x_out)