From e89f28ead0594525cb77c960e5e696bade4a34df Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 11 Jun 2020 08:43:39 -0600 Subject: [PATCH] Update multirrdb to do HR fixing in the base image dimension. --- codes/models/archs/RRDBNet_arch.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py index 0c8852fb..558c24c9 100644 --- a/codes/models/archs/RRDBNet_arch.py +++ b/codes/models/archs/RRDBNet_arch.py @@ -354,14 +354,13 @@ class MultiRRDBNet(RRDBBase): def __init__(self, nf_base, gc_base, lo_blocks, hi_blocks, scale=2, rrdb_block_f=None): super(MultiRRDBNet, self).__init__() - # Initial downsampling. - self.conv_first = nn.Conv2d(3, nf_base, 5, stride=2, padding=2, bias=True) - # Chained trunks lo_nf = nf_base * 4 + lo_nf_out = nf_base // 4 hi_nf = nf_base - self.lo_trunk = RRDBTrunk(nf_base, lo_nf, lo_blocks, gc_base * 2, initial_stride=2, rrdb_block_f=rrdb_block_f) - self.hi_trunk = RRDBTrunk(nf_base, hi_nf, hi_blocks, gc_base, initial_stride=1, rrdb_block_f=rrdb_block_f) + self.lo_trunk = RRDBTrunk(nf_base, lo_nf, lo_blocks, gc_base * 2, initial_stride=1, rrdb_block_f=rrdb_block_f, conv_first_block=PixShuffleInitialConv(4, lo_nf)) + self.skip_conv = nn.Conv2d(3, lo_nf_out, 1) + self.hi_trunk = RRDBTrunk(lo_nf_out, hi_nf, hi_blocks, gc_base, initial_stride=1, rrdb_block_f=rrdb_block_f) self.trunks = [self.lo_trunk, self.hi_trunk] # Upsampling @@ -370,22 +369,19 @@ class MultiRRDBNet(RRDBBase): self.upconv2 = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True) self.HRconv = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True) self.conv_last = nn.Conv2d(hi_nf, 3, 3, 1, 1, bias=True) - self.pixel_shuffle = nn.PixelShuffle(2) + self.pixel_shuffle = nn.PixelShuffle(4) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) def forward(self, x): - fea = self.conv_first(x) - fea_lo = self.lo_trunk(fea) - fea = self.pixel_shuffle(fea_lo) + fea + fea_lo = self.lo_trunk(x) + fea = self.pixel_shuffle(fea_lo) + self.skip_conv(x) fea = self.hi_trunk(fea) - # First, return image to original size and perform post-processing. + # Upsampling. fea = F.interpolate(fea, scale_factor=2, mode='nearest') fea = self.lrelu(self.upconv1(fea)) - - # If 2x scaling is specified, do that too. - if self.scale >= 2: + if self.scale >= 4: fea = F.interpolate(fea, scale_factor=2, mode='nearest') fea = self.lrelu(self.upconv2(fea)) out = self.conv_last(self.lrelu(self.HRconv(fea)))