forked from mrq/DL-Art-School
Temporary commit - ref
This commit is contained in:
parent
df59d6c99d
commit
00da69d450
|
@ -443,22 +443,22 @@ class SwitchedSpsrWithRef2(nn.Module):
|
|||
|
||||
def forward(self, x, ref, center_coord):
|
||||
x_grad = self.get_g_nopadding(x)
|
||||
ref = self.reference_processor(ref, center_coord)
|
||||
#ref = self.reference_processor(ref, center_coord)
|
||||
|
||||
x = self.model_fea_conv(x)
|
||||
x1 = x
|
||||
x1 = self.ref_join1(x1, ref)
|
||||
#x1 = self.ref_join1(x1, ref)
|
||||
x1, a1 = self.sw1(x1, True, identity=x)
|
||||
|
||||
x2 = x1
|
||||
x2 = self.noise_ref_join(x2, torch.randn_like(x2))
|
||||
x2 = self.ref_join2(x2, ref)
|
||||
#x2 = self.ref_join2(x2, ref)
|
||||
x2, a2 = self.sw2(x2, True, identity=x1)
|
||||
|
||||
x_grad = self.grad_conv(x_grad)
|
||||
x_grad_identity = x_grad
|
||||
x_grad = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad))
|
||||
x_grad = self.ref_join3(x_grad, ref)
|
||||
#x_grad = self.ref_join3(x_grad, ref)
|
||||
x_grad = self.grad_ref_join(x_grad, x1)
|
||||
x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity)
|
||||
x_grad = self.grad_lr_conv(x_grad)
|
||||
|
@ -468,7 +468,7 @@ class SwitchedSpsrWithRef2(nn.Module):
|
|||
|
||||
x_out = x2
|
||||
x_out = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out))
|
||||
x_out = self.ref_join4(x_out, ref)
|
||||
#x_out = self.ref_join4(x_out, ref)
|
||||
x_out = self.conjoin_ref_join(x_out, x_grad)
|
||||
x_out, a4 = self.conjoin_sw(x_out, True, identity=x2)
|
||||
x_out = self.final_lr_conv(x_out)
|
||||
|
|
Loading…
Reference in New Issue
Block a user