Retrofit SPSR_arch so it is capable of accepting a ref
This commit is contained in:
parent
8202ee72b9
commit
11a9e223a6
|
@ -94,10 +94,13 @@ class SPSRNet(nn.Module):
|
|||
|
||||
n_upscale = int(math.log(upscale, 2))
|
||||
|
||||
self.scale=n_upscale
|
||||
if upscale == 3:
|
||||
n_upscale = 1
|
||||
|
||||
fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
|
||||
fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=7, norm=False, activation=False)
|
||||
self.ref_conv = ConvGnLelu(in_nc, nf//2, stride=n_upscale, kernel_size=7, norm=False, activation=False)
|
||||
self.join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
rb_blocks = [RRDB(nf) for _ in range(nb)]
|
||||
|
||||
LR_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
|
@ -114,7 +117,9 @@ class SPSRNet(nn.Module):
|
|||
self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), \
|
||||
*upsampler, self.HR_conv0_new)
|
||||
|
||||
self.b_fea_conv = ConvGnLelu(in_nc, nf, kernel_size=3, norm=False, activation=False)
|
||||
self.b_fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=3, norm=False, activation=False)
|
||||
self.b_ref_conv = ConvGnLelu(in_nc, nf//2, stride=n_upscale, kernel_size=3, norm=False, activation=False)
|
||||
self.b_join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
|
||||
self.b_concat_1 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
|
||||
self.b_block_1 = RRDB(nf * 2)
|
||||
|
@ -151,10 +156,16 @@ class SPSRNet(nn.Module):
|
|||
|
||||
self.get_g_nopadding = ImageGradientNoPadding()
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, ref=None):
|
||||
b,f,h,w = x.shape
|
||||
if ref is None:
|
||||
ref = torch.zeros((b,f,h*self.scale,w*self.scale), device=x.device, dtype=x.dtype)
|
||||
|
||||
x_grad = self.get_g_nopadding(x)
|
||||
ref_grad = self.get_g_nopadding(ref)
|
||||
x = self.model[0](x)
|
||||
x_ref = self.ref_conv(ref)
|
||||
x = self.join_conv(torch.cat([x, x_ref], dim=1))
|
||||
|
||||
x, block_list = self.model[1](x)
|
||||
|
||||
|
@ -182,6 +193,8 @@ class SPSRNet(nn.Module):
|
|||
x = self.HR_conv1_new(x)
|
||||
|
||||
x_b_fea = self.b_fea_conv(x_grad)
|
||||
x_b_ref = self.b_ref_conv(ref_grad)
|
||||
x_b_fea = self.b_join_conv(torch.cat([x_b_fea, x_b_ref], dim=1))
|
||||
x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
|
||||
|
||||
x_cat_1 = self.b_block_1(x_cat_1)
|
||||
|
|
Loading…
Reference in New Issue
Block a user