diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 8485f53e..865c374c 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -443,22 +443,22 @@ class SwitchedSpsrWithRef2(nn.Module): def forward(self, x, ref, center_coord): x_grad = self.get_g_nopadding(x) - ref = self.reference_processor(ref, center_coord) + #ref = self.reference_processor(ref, center_coord) x = self.model_fea_conv(x) x1 = x - x1 = self.ref_join1(x1, ref) + #x1 = self.ref_join1(x1, ref) x1, a1 = self.sw1(x1, True, identity=x) x2 = x1 x2 = self.noise_ref_join(x2, torch.randn_like(x2)) - x2 = self.ref_join2(x2, ref) + #x2 = self.ref_join2(x2, ref) x2, a2 = self.sw2(x2, True, identity=x1) x_grad = self.grad_conv(x_grad) x_grad_identity = x_grad x_grad = self.noise_ref_join_grad(x_grad, torch.randn_like(x_grad)) - x_grad = self.ref_join3(x_grad, ref) + #x_grad = self.ref_join3(x_grad, ref) x_grad = self.grad_ref_join(x_grad, x1) x_grad, a3 = self.sw_grad(x_grad, True, identity=x_grad_identity) x_grad = self.grad_lr_conv(x_grad) @@ -468,7 +468,7 @@ class SwitchedSpsrWithRef2(nn.Module): x_out = x2 x_out = self.noise_ref_join_conjoin(x_out, torch.randn_like(x_out)) - x_out = self.ref_join4(x_out, ref) + #x_out = self.ref_join4(x_out, ref) x_out = self.conjoin_ref_join(x_out, x_grad) x_out, a4 = self.conjoin_sw(x_out, True, identity=x2) x_out = self.final_lr_conv(x_out)