forked from mrq/DL-Art-School
spsr3 with conjoin stage as part of the switch
This commit is contained in:
parent
e0fc5eb50c
commit
992b0a8d98
|
@ -371,6 +371,17 @@ class RefJoiner(nn.Module):
|
|||
return self.join(x, ref.repeat((1, 1, h, w)))
|
||||
|
||||
|
||||
class ModuleWithRef(nn.Module):
|
||||
def __init__(self, nf, mcnv, *args):
|
||||
super(ModuleWithRef, self).__init__()
|
||||
self.join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
|
||||
self.multi = mcnv(*args)
|
||||
|
||||
def forward(self, x, ref):
|
||||
out = self.join(x, ref)
|
||||
return self.multi(out)
|
||||
|
||||
|
||||
class SwitchedSpsrWithRef2(nn.Module):
|
||||
def __init__(self, in_nc, out_nc, nf, xforms=8, upscale=4, init_temperature=10):
|
||||
super(SwitchedSpsrWithRef2, self).__init__()
|
||||
|
@ -426,9 +437,10 @@ class SwitchedSpsrWithRef2(nn.Module):
|
|||
# Join branch (grad+fea)
|
||||
self.ref_join4 = RefJoiner(nf)
|
||||
self.noise_ref_join_conjoin = ReferenceJoinBlock(nf, residual_weight_init_factor=.01, norm=False)
|
||||
self.conjoin_ref_join = ReferenceJoinBlock(nf, residual_weight_init_factor=.2, norm=False)
|
||||
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters, multiplx_fn,
|
||||
pre_transform_block=pretransform_fn, transform_block=transform_fn,
|
||||
self.conjoin_sw = ConfigurableSwitchComputer(transformation_filters,
|
||||
functools.partial(ModuleWithRef, nf, multiplx_fn),
|
||||
pre_transform_block=None,
|
||||
transform_block=functools.partial(ModuleWithRef, nf, transform_fn),
|
||||
attention_norm=True,
|
||||
transform_count=self.transformation_counts, init_temp=init_temperature,
|
||||
add_scalable_noise_to_transforms=False)
|
||||
|
@ -443,22 +455,22 @@ class SwitchedSpsrWithRef2(nn.Module):
|
|||
|
||||
def forward(self, x, ref, center_coord):
|
||||
x_grad = self.get_g_nopadding(x)
|
||||
#ref = self.reference_processor(ref, center_coord)
|
||||
ref = self.reference_processor(ref, center_coord)
|
||||
|
||||
x = self.model_fea_conv(x)
|
||||
x1 = x
|
||||
#x1 = self.ref_join1(x1, ref)
|
||||
x1 = self.ref_join1(x1, ref)
|
||||
x1, a1 = self.sw1(x1, True, identity=x)
|
||||
|
||||
x2 = x1
|
||||
#x2 = self.noise_ref_join(x2, torch.randn_like(x2))
|
||||
#x2 = self.ref_join2(x2, ref)
|
||||
x2 = self.noise_ref_join(x2, torch.randn_like(x2))
|
||||
x2 = self.ref_join2(x2, ref)
|
||||
x2, a2 = self.sw2(x2, True, identity=x1)
|
||||
|
||||
x_grad = self.grad_conv(x_grad)
|
||||
x_grad_identity = x_grad
|
||||
#x_grad = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad))
|
||||
#x_grad = self.ref_join3(x_grad, ref)
|
||||
x_grad = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad))
|
||||
x_grad = self.ref_join3(x_grad, ref)
|
||||
x_grad = self.grad_ref_join(x_grad, x1)
|
||||
x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity)
|
||||
x_grad = self.grad_lr_conv(x_grad)
|
||||
|
@ -467,10 +479,9 @@ class SwitchedSpsrWithRef2(nn.Module):
|
|||
x_grad_out = self.grad_branch_output_conv(x_grad_out)
|
||||
|
||||
x_out = x2
|
||||
#x_out = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out))
|
||||
#x_out = self.ref_join4(x_out, ref)
|
||||
x_out = self.conjoin_ref_join(x_out, x_grad)
|
||||
x_out, a4 = self.conjoin_sw(x_out, True, identity=x2)
|
||||
x_out = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out))
|
||||
x_out = self.ref_join4(x_out, ref)
|
||||
x_out, a4 = self.conjoin_sw((x_out, x_grad), True, identity=x2)
|
||||
x_out = self.final_lr_conv(x_out)
|
||||
x_out = self.upsample(x_out)
|
||||
x_out = self.final_hr_conv1(x_out)
|
||||
|
|
Loading…
Reference in New Issue
Block a user