spsr3 with conjoin stage as part of the switch

This commit is contained in:
James Betker 2020-09-10 09:11:37 -06:00
parent e0fc5eb50c
commit 992b0a8d98

View File

@ -371,6 +371,17 @@ class RefJoiner(nn.Module):
return self.join(x, ref.repeat((1, 1, h, w)))
class ModuleWithRef(nn.Module):
def __init__(self, nf, mcnv, *args):
super(ModuleWithRef, self).__init__()
self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
self.multi = mcnv(*args)
def forward(self, x, ref):
out = self.join(x, ref)
return self.multi(out)
class SwitchedSpsrWithRef2(nn.Module):
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
super(SwitchedSpsrWithRef2, self).__init__()
@ -426,9 +437,10 @@ class SwitchedSpsrWithRef2(nn.Module):
# Join branch (grad+fea)
self.ref_join4 = RefJoiner(nf)
self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False)
self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
pre_transform_block=pretransform_fn, transform_block=transform_fn,
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters,
functools.partial(ModuleWithRef, nf, multiplx_fn),
pre_transform_block=None,
transform_block=functools.partial(ModuleWithRef, nf, transform_fn),
attention_norm=True,
transform_count=self.transformation_counts, init_temp=init_temperature,
add_scalable_noise_to_transforms=False)
@ -443,22 +455,22 @@ class SwitchedSpsrWithRef2(nn.Module):
def forward(self, x, ref, center_coord):
x_grad = self.get_g_nopadding(x)
#ref = self.reference_processor(ref, center_coord)
ref = self.reference_processor(ref, center_coord)
x = self.model_fea_conv(x)
x1 = x
#x1 = self.ref_join1(x1, ref)
x1 = self.ref_join1(x1, ref)
x1, a1 = self.sw1(x1, True, identity=x)
x2 = x1
#x2 = self.noise_ref_join(x2, torch.randn_like(x2))
#x2 = self.ref_join2(x2, ref)
x2 = self.noise_ref_join(x2, torch.randn_like(x2))
x2 = self.ref_join2(x2, ref)
x2, a2 = self.sw2(x2, True, identity=x1)
x_grad = self.grad_conv(x_grad)
x_grad_identity = x_grad
#x_grad = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad))
#x_grad = self.ref_join3(x_grad, ref)
x_grad = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad))
x_grad = self.ref_join3(x_grad, ref)
x_grad = self.grad_ref_join(x_grad, x1)
x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity)
x_grad = self.grad_lr_conv(x_grad)
@ -467,10 +479,9 @@ class SwitchedSpsrWithRef2(nn.Module):
x_grad_out = self.grad_branch_output_conv(x_grad_out)
x_out = x2
#x_out = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out))
#x_out = self.ref_join4(x_out, ref)
x_out = self.conjoin_ref_join(x_out, x_grad)
x_out, a4 = self.conjoin_sw(x_out, True, identity=x2)
x_out = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out))
x_out = self.ref_join4(x_out, ref)
x_out, a4 = self.conjoin_sw((x_out, x_grad), True, identity=x2)
x_out = self.final_lr_conv(x_out)
x_out = self.upsample(x_out)
x_out = self.final_hr_conv1(x_out)