From 11a9e223a65003874be9dedf513413611fff9b48 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 27 Oct 2020 11:14:36 -0600 Subject: [PATCH] Retrofit SPSR_arch so it is capable of accepting a ref --- codes/models/archs/SPSR_arch.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py index 6f60c324..a0e44a28 100644 --- a/codes/models/archs/SPSR_arch.py +++ b/codes/models/archs/SPSR_arch.py @@ -94,10 +94,13 @@ class SPSRNet(nn.Module): n_upscale = int(math.log(upscale, 2)) + self.scale=n_upscale if upscale == 3: n_upscale = 1 - fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False) + fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=7, norm=False, activation=False) + self.ref_conv = ConvGnLelu(in_nc, nf//2, stride=n_upscale, kernel_size=7, norm=False, activation=False) + self.join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) rb_blocks = [RRDB(nf) for _ in range(nb)] LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) @@ -114,7 +117,9 @@ class SPSRNet(nn.Module): self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), \ *upsampler, self.HR_conv0_new) - self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False) + self.b_fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=3, norm=False, activation=False) + self.b_ref_conv = ConvGnLelu(in_nc, nf//2, stride=n_upscale, kernel_size=3, norm=False, activation=False) + self.b_join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False) self.b_concat_1 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False) self.b_block_1 = RRDB(nf * 2) @@ -151,10 +156,16 @@ class SPSRNet(nn.Module): self.get_g_nopadding = ImageGradientNoPadding() - def forward(self, x): + def forward(self, x, ref=None): + b,f,h,w = x.shape + if ref is None: + ref = torch.zeros((b,f,h*self.scale,w*self.scale), device=x.device, dtype=x.dtype) x_grad = self.get_g_nopadding(x) + ref_grad = self.get_g_nopadding(ref) x = self.model[0](x) + x_ref = self.ref_conv(ref) + x = self.join_conv(torch.cat([x, x_ref], dim=1)) x, block_list = self.model[1](x) @@ -182,6 +193,8 @@ class SPSRNet(nn.Module): x = self.HR_conv1_new(x) x_b_fea = self.b_fea_conv(x_grad) + x_b_ref = self.b_ref_conv(ref_grad) + x_b_fea = self.b_join_conv(torch.cat([x_b_fea, x_b_ref], dim=1)) x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1) x_cat_1 = self.b_block_1(x_cat_1)