Update multirrdb to do HR fixing in the base image dimension.

This commit is contained in:
James Betker 2020-06-11 08:43:39 -06:00
parent d3b2cbfe7c
commit e89f28ead0

View File

@ -354,14 +354,13 @@ class MultiRRDBNet(RRDBBase):
def __init__(self, nf_base, gc_base, lo_blocks, hi_blocks, scale=2, rrdb_block_f=None):
super(MultiRRDBNet, self).__init__()
# Initial downsampling.
self.conv_first = nn.Conv2d(3, nf_base, 5, stride=2, padding=2, bias=True)
# Chained trunks
lo_nf = nf_base * 4
lo_nf_out = nf_base // 4
hi_nf = nf_base
self.lo_trunk = RRDBTrunk(nf_base, lo_nf, lo_blocks, gc_base * 2, initial_stride=2, rrdb_block_f=rrdb_block_f)
self.hi_trunk = RRDBTrunk(nf_base, hi_nf, hi_blocks, gc_base, initial_stride=1, rrdb_block_f=rrdb_block_f)
self.lo_trunk = RRDBTrunk(nf_base, lo_nf, lo_blocks, gc_base * 2, initial_stride=1, rrdb_block_f=rrdb_block_f, conv_first_block=PixShuffleInitialConv(4, lo_nf))
self.skip_conv = nn.Conv2d(3, lo_nf_out, 1)
self.hi_trunk = RRDBTrunk(lo_nf_out, hi_nf, hi_blocks, gc_base, initial_stride=1, rrdb_block_f=rrdb_block_f)
self.trunks = [self.lo_trunk, self.hi_trunk]
# Upsampling
@ -370,22 +369,19 @@ class MultiRRDBNet(RRDBBase):
self.upconv2 = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True)
self.HRconv = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True)
self.conv_last = nn.Conv2d(hi_nf, 3, 3, 1, 1, bias=True)
self.pixel_shuffle = nn.PixelShuffle(2)
self.pixel_shuffle = nn.PixelShuffle(4)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.conv_first(x)
fea_lo = self.lo_trunk(fea)
fea = self.pixel_shuffle(fea_lo) + fea
fea_lo = self.lo_trunk(x)
fea = self.pixel_shuffle(fea_lo) + self.skip_conv(x)
fea = self.hi_trunk(fea)
# First, return image to original size and perform post-processing.
# Upsampling.
fea = F.interpolate(fea, scale_factor=2, mode='nearest')
fea = self.lrelu(self.upconv1(fea))
# If 2x scaling is specified, do that too.
if self.scale >= 2:
if self.scale >= 4:
fea = F.interpolate(fea, scale_factor=2, mode='nearest')
fea = self.lrelu(self.upconv2(fea))
out = self.conv_last(self.lrelu(self.HRconv(fea)))