From 11a9e223a65003874be9dedf513413611fff9b48 Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Tue, 27 Oct 2020 11:14:36 -0600
Subject: [PATCH] Retrofit SPSR_arch so it is capable of accepting a ref

---
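Notes (kept below the --- line so they stay out of the commit message): the
first convolution's channel budget is split in half, nf//2 from the LR input
and nf//2 from a strided projection of the ref, and a 3x3 join conv fuses them
back to nf channels; the gradient branch mirrors the same split. A minimal
shape sketch of that join, substituting plain nn.Conv2d for ConvGnLelu (an
assumption; ConvGnLelu appears to wrap a conv with optional norm/activation):

    import torch
    import torch.nn as nn

    nf, n_upscale = 64, 2                  # e.g. upscale=4 -> int(log2(4)) = 2
    fea_conv = nn.Conv2d(3, nf // 2, 7, padding=3)
    ref_conv = nn.Conv2d(3, nf // 2, 7, stride=n_upscale, padding=3)
    join_conv = nn.Conv2d(nf, nf, 3, padding=1)

    x = torch.randn(1, 3, 32, 32)          # LR input
    ref = torch.randn(1, 3, 64, 64)        # ref at (h * n_upscale, w * n_upscale)
    fused = join_conv(torch.cat([fea_conv(x), ref_conv(ref)], dim=1))
    print(fused.shape)                     # torch.Size([1, 64, 32, 32])
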
 codes/models/archs/SPSR_arch.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py
index 6f60c324..a0e44a28 100644
--- a/codes/models/archs/SPSR_arch.py
+++ b/codes/models/archs/SPSR_arch.py
@@ -94,10 +94,13 @@ class SPSRNet(nn.Module):
 
         n_upscale = int(math.log(upscale, 2))
 
+        self.scale = n_upscale
         if upscale == 3:
             n_upscale = 1
 
-        fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
+        fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=7, norm=False, activation=False)
+        self.ref_conv = ConvGnLelu(in_nc, nf//2, stride=n_upscale, kernel_size=7, norm=False, activation=False)
+        self.join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
         rb_blocks = [RRDB(nf) for _ in range(nb)]
 
         LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
@@ -114,7 +117,9 @@ class SPSRNet(nn.Module):
         self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), \
                                   *upsampler, self.HR_conv0_new)
 
-        self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
+        self.b_fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=3, norm=False, activation=False)
+        self.b_ref_conv = ConvGnLelu(in_nc, nf//2, stride=n_upscale, kernel_size=3, norm=False, activation=False)
+        self.b_join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
 
         self.b_concat_1 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
         self.b_block_1 = RRDB(nf * 2)
@@ -151,10 +156,16 @@ class SPSRNet(nn.Module):
 
         self.get_g_nopadding = ImageGradientNoPadding()
 
-    def forward(self, x):
+    def forward(self, x, ref=None):
+        b, f, h, w = x.shape
+        if ref is None:
+            ref = torch.zeros((b, f, h * self.scale, w * self.scale), device=x.device, dtype=x.dtype)
 
         x_grad = self.get_g_nopadding(x)
+        ref_grad = self.get_g_nopadding(ref)
         x = self.model[0](x)
+        x_ref = self.ref_conv(ref)
+        x = self.join_conv(torch.cat([x, x_ref], dim=1))
 
         x, block_list = self.model[1](x)
 
@@ -182,6 +193,8 @@ class SPSRNet(nn.Module):
         x = self.HR_conv1_new(x)
 
         x_b_fea = self.b_fea_conv(x_grad)
+        x_b_ref = self.b_ref_conv(ref_grad)
+        x_b_fea = self.b_join_conv(torch.cat([x_b_fea, x_b_ref], dim=1))
         x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
 
         x_cat_1 = self.b_block_1(x_cat_1)
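
A usage sketch against the retrofitted forward() (the SPSRNet constructor
arguments below are assumptions inferred from the names visible in the hunks;
check __init__ for the real signature):

    import torch
    from codes.models.archs.SPSR_arch import SPSRNet

    net = SPSRNet(in_nc=3, out_nc=3, nf=64, nb=23, upscale=4)  # hypothetical args

    lr = torch.randn(1, 3, 32, 32)

    # No ref: forward() substitutes zeros of shape (b, c, h*self.scale, w*self.scale).
    # self.scale stores n_upscale = int(log2(upscale)), so for upscale=4 the implied
    # ref resolution is 2x the input (64x64), which matches ref_conv's stride.
    out = net(lr)

    # Explicit ref: it must be sized so that ref_conv (stride=n_upscale) lands back
    # on the 32x32 grid of fea_conv's output before the channel-wise concat.
    out = net(lr, ref=torch.randn(1, 3, 64, 64))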