Retrofit SPSR_arch so it can accept a ref

James Betker 2020-10-27 11:14:36 -06:00
parent 8202ee72b9
commit 11a9e223a6


@@ -94,10 +94,13 @@ class SPSRNet(nn.Module):
         n_upscale = int(math.log(upscale, 2))
         self.scale=n_upscale
         if upscale == 3:
             n_upscale = 1
-        fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
+        fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=7, norm=False, activation=False)
+        self.ref_conv = ConvGnLelu(in_nc, nf//2, stride=n_upscale, kernel_size=7, norm=False, activation=False)
+        self.join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
         rb_blocks = [RRDB(nf) for _ in range(nb)]
         LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
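
The shape arithmetic in the new trunk: fea_conv and ref_conv each produce nf//2 channels, and ref_conv's stride of n_upscale brings the larger reference down to x's resolution, so the two halves concatenate into nf channels for join_conv. A minimal standalone sketch of that arithmetic, using plain nn.Conv2d stand-ins for ConvGnLelu (the stand-in layers and the "same"-style padding are assumptions, not the repo's code):

    import torch
    import torch.nn as nn

    in_nc, nf, n_upscale = 3, 64, 2
    # Stand-ins for ConvGnLelu: bare convs, no GroupNorm/LeakyReLU.
    fea_conv = nn.Conv2d(in_nc, nf // 2, kernel_size=7, padding=3)
    ref_conv = nn.Conv2d(in_nc, nf // 2, kernel_size=7, stride=n_upscale, padding=3)
    join_conv = nn.Conv2d(nf, nf, kernel_size=3, padding=1)

    x = torch.randn(1, in_nc, 32, 32)                            # LR input
    ref = torch.randn(1, in_nc, 32 * n_upscale, 32 * n_upscale)  # ref arrives n_upscale times larger

    fea = fea_conv(x)        # -> (1, 32, 32, 32)
    ref_fea = ref_conv(ref)  # strided conv downsamples ref back to (1, 32, 32, 32)
    joined = join_conv(torch.cat([fea, ref_fea], dim=1))  # -> (1, 64, 32, 32), i.e. nf channels
    print(fea.shape, ref_fea.shape, joined.shape)
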
@@ -114,7 +117,9 @@ class SPSRNet(nn.Module):
         self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), \
                                   *upsampler, self.HR_conv0_new)
-        self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
+        self.b_fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=3, norm=False, activation=False)
+        self.b_ref_conv = ConvGnLelu(in_nc, nf//2, stride=n_upscale, kernel_size=3, norm=False, activation=False)
+        self.b_join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
         self.b_concat_1 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
         self.b_block_1 = RRDB(nf * 2)
@@ -151,10 +156,16 @@ class SPSRNet(nn.Module):
         self.get_g_nopadding = ImageGradientNoPadding()

-    def forward(self, x):
+    def forward(self, x, ref=None):
+        b,f,h,w = x.shape
+        if ref is None:
+            ref = torch.zeros((b,f,h*self.scale,w*self.scale), device=x.device, dtype=x.dtype)
         x_grad = self.get_g_nopadding(x)
+        ref_grad = self.get_g_nopadding(ref)
         x = self.model[0](x)
+        x_ref = self.ref_conv(ref)
+        x = self.join_conv(torch.cat([x, x_ref], dim=1))
         x, block_list = self.model[1](x)
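
Note the backwards-compatible default: when no reference is supplied, forward synthesizes an all-zero one at self.scale times the LR resolution, so existing single-argument call sites keep working. The fallback in isolation (make_default_ref is a hypothetical helper name; the commit inlines this logic in forward):

    import torch

    def make_default_ref(x: torch.Tensor, scale: int) -> torch.Tensor:
        # Zero reference with the shape forward() expects when ref is None.
        b, f, h, w = x.shape
        return torch.zeros((b, f, h * scale, w * scale), device=x.device, dtype=x.dtype)

    x = torch.randn(2, 3, 24, 24)
    assert make_default_ref(x, scale=2).shape == (2, 3, 48, 48)
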
@@ -182,6 +193,8 @@ class SPSRNet(nn.Module):
         x = self.HR_conv1_new(x)
         x_b_fea = self.b_fea_conv(x_grad)
+        x_b_ref = self.b_ref_conv(ref_grad)
+        x_b_fea = self.b_join_conv(torch.cat([x_b_fea, x_b_ref], dim=1))
         x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
         x_cat_1 = self.b_block_1(x_cat_1)
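
Taken together, the retrofit threads the reference through both paths: ref_conv joins it into the main feature trunk, and b_ref_conv joins its image gradient into the gradient branch. A hedged usage sketch, assuming net is an already-constructed SPSRNet (constructor arguments omitted; they are unchanged by this commit):

    import torch

    lr = torch.randn(1, 3, 32, 32)
    out_legacy = net(lr)  # old call sites still work: ref defaults to zeros internally
    ref = torch.randn(1, 3, 32 * net.scale, 32 * net.scale)
    out_guided = net(lr, ref=ref)  # reference feeds ref_conv and b_ref_conv
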