SPSR with ref joining
This commit is contained in:
parent
c41dc9a48c
commit
0ffac391c1
|
@ -490,6 +490,22 @@ class MultiplexerWithReducer(nn.Module):
|
||||||
x = self.reduce(x)
|
x = self.reduce(x)
|
||||||
return self.mplex(x, ref)
|
return self.mplex(x, ref)
|
||||||
|
|
||||||
|
|
||||||
|
class RefJoiner(nn.Module):
|
||||||
|
def __init__(self, nf):
|
||||||
|
super(RefJoiner, self).__init__()
|
||||||
|
self.lin1 = nn.Linear(512, 256)
|
||||||
|
self.lin2 = nn.Linear(256, nf)
|
||||||
|
self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1, norm=False)
|
||||||
|
|
||||||
|
def forward(self, x, ref):
|
||||||
|
ref = self.lin1(ref)
|
||||||
|
ref = self.lin2(ref)
|
||||||
|
b, _, h, w = x.shape
|
||||||
|
ref = ref.view(b, -1, 1, 1)
|
||||||
|
return self.join(x, ref.repeat((1, 1, h, w)))
|
||||||
|
|
||||||
|
|
||||||
class SwitchedSpsrWithRef2(nn.Module):
|
class SwitchedSpsrWithRef2(nn.Module):
|
||||||
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
|
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
|
||||||
super(SwitchedSpsrWithRef2, self).__init__()
|
super(SwitchedSpsrWithRef2, self).__init__()
|
||||||
|
@ -506,14 +522,18 @@ class SwitchedSpsrWithRef2(nn.Module):
|
||||||
transformation_filters, kernel_size=3, depth=3,
|
transformation_filters, kernel_size=3, depth=3,
|
||||||
weight_init_factor=.1)
|
weight_init_factor=.1)
|
||||||
|
|
||||||
|
self.reference_processor = ReferenceImageBranch(transformation_filters)
|
||||||
|
|
||||||
# Feature branch
|
# Feature branch
|
||||||
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
|
self.model_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
|
||||||
self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1, norm=False)
|
self.noise_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.1, norm=False)
|
||||||
|
self.ref_join1 = RefJoiner(nf)
|
||||||
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
self.sw1 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||||
attention_norm=True,
|
attention_norm=True,
|
||||||
transform_count=self.transformation_counts, init_temp=init_temperature,
|
transform_count=self.transformation_counts, init_temp=init_temperature,
|
||||||
add_scalable_noise_to_transforms=False)
|
add_scalable_noise_to_transforms=False)
|
||||||
|
self.ref_join2 = RefJoiner(nf)
|
||||||
self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
self.sw2 = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||||
attention_norm=True,
|
attention_norm=True,
|
||||||
|
@ -536,7 +556,7 @@ class SwitchedSpsrWithRef2(nn.Module):
|
||||||
self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
|
self.upsample_grad = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=False, activation=True, bias=False) for _ in range(n_upscale)])
|
||||||
self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True)
|
self.grad_branch_output_conv = ConvGnLelu(nf, out_nc, kernel_size=1, norm=False, activation=False, bias=True)
|
||||||
|
|
||||||
# Join branch (grad+fea
|
# Join branch (grad+fea)
|
||||||
self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
|
self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
|
||||||
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||||
|
@ -546,7 +566,7 @@ class SwitchedSpsrWithRef2(nn.Module):
|
||||||
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
|
self.final_lr_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=False)
|
||||||
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
|
self.upsample = nn.Sequential(*[UpconvBlock(nf, nf, block=ConvGnLelu, norm=True, activation=True, bias=False) for _ in range(n_upscale)])
|
||||||
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
|
self.final_hr_conv1 = ConvGnLelu(nf, nf, kernel_size=3, norm=True, activation=True, bias=False)
|
||||||
self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=True)
|
self.final_hr_conv2 = ConvGnLelu(nf, out_nc, kernel_size=3, norm=False, activation=False, bias=False)
|
||||||
self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
|
self.switches = [self.sw1, self.sw2, self.sw_grad, self.conjoin_sw]
|
||||||
self.attentions = None
|
self.attentions = None
|
||||||
self.init_temperature = init_temperature
|
self.init_temperature = init_temperature
|
||||||
|
@ -554,11 +574,16 @@ class SwitchedSpsrWithRef2(nn.Module):
|
||||||
|
|
||||||
def forward(self, x, ref, center_coord):
|
def forward(self, x, ref, center_coord):
|
||||||
x_grad = self.get_g_nopadding(x)
|
x_grad = self.get_g_nopadding(x)
|
||||||
|
ref = self.reference_processor(ref, center_coord)
|
||||||
|
|
||||||
x = self.model_fea_conv(x)
|
x = self.model_fea_conv(x)
|
||||||
x = self.noise_ref_join(x, torch.randn_like(x))
|
x = self.noise_ref_join(x, torch.randn_like(x))
|
||||||
|
x = self.ref_join1(x, ref)
|
||||||
x1, a1 = self.sw1(x, True)
|
x1, a1 = self.sw1(x, True)
|
||||||
x2, a2 = self.sw2(x, True)
|
|
||||||
|
x2 = x1
|
||||||
|
x2 = self.ref_join2(x2, ref)
|
||||||
|
x2, a2 = self.sw2(x2, True)
|
||||||
x_fea = self.feature_lr_conv(x2)
|
x_fea = self.feature_lr_conv(x2)
|
||||||
x_fea = self.feature_lr_conv2(x_fea)
|
x_fea = self.feature_lr_conv2(x_fea)
|
||||||
|
|
||||||
|
@ -570,7 +595,8 @@ class SwitchedSpsrWithRef2(nn.Module):
|
||||||
x_grad_out = self.upsample_grad(x_grad)
|
x_grad_out = self.upsample_grad(x_grad)
|
||||||
x_grad_out = self.grad_branch_output_conv(x_grad_out)
|
x_grad_out = self.grad_branch_output_conv(x_grad_out)
|
||||||
|
|
||||||
x_out = self.conjoin_ref_join(x_fea, x_grad)
|
x_out = x_fea
|
||||||
|
x_out = self.conjoin_ref_join(x_out, x_grad)
|
||||||
x_out, a4 = self.conjoin_sw(x_out, True)
|
x_out, a4 = self.conjoin_sw(x_out, True)
|
||||||
x_out = self.final_lr_conv(x_out)
|
x_out = self.final_lr_conv(x_out)
|
||||||
x_out = self.upsample(x_out)
|
x_out = self.upsample(x_out)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user