Update multirrdb to do HR fixing in the base image dimension.
This commit is contained in:
parent
d3b2cbfe7c
commit
e89f28ead0
|
@ -354,14 +354,13 @@ class MultiRRDBNet(RRDBBase):
|
|||
def __init__(self, nf_base, gc_base, lo_blocks, hi_blocks, scale=2, rrdb_block_f=None):
|
||||
super(MultiRRDBNet, self).__init__()
|
||||
|
||||
# Initial downsampling.
|
||||
self.conv_first = nn.Conv2d(3, nf_base, 5, stride=2, padding=2, bias=True)
|
||||
|
||||
# Chained trunks
|
||||
lo_nf = nf_base * 4
|
||||
lo_nf_out = nf_base // 4
|
||||
hi_nf = nf_base
|
||||
self.lo_trunk = RRDBTrunk(nf_base, lo_nf, lo_blocks, gc_base * 2, initial_stride=2, rrdb_block_f=rrdb_block_f)
|
||||
self.hi_trunk = RRDBTrunk(nf_base, hi_nf, hi_blocks, gc_base, initial_stride=1, rrdb_block_f=rrdb_block_f)
|
||||
self.lo_trunk = RRDBTrunk(nf_base, lo_nf, lo_blocks, gc_base * 2, initial_stride=1, rrdb_block_f=rrdb_block_f, conv_first_block=PixShuffleInitialConv(4, lo_nf))
|
||||
self.skip_conv = nn.Conv2d(3, lo_nf_out, 1)
|
||||
self.hi_trunk = RRDBTrunk(lo_nf_out, hi_nf, hi_blocks, gc_base, initial_stride=1, rrdb_block_f=rrdb_block_f)
|
||||
self.trunks = [self.lo_trunk, self.hi_trunk]
|
||||
|
||||
# Upsampling
|
||||
|
@ -370,22 +369,19 @@ class MultiRRDBNet(RRDBBase):
|
|||
self.upconv2 = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True)
|
||||
self.HRconv = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True)
|
||||
self.conv_last = nn.Conv2d(hi_nf, 3, 3, 1, 1, bias=True)
|
||||
self.pixel_shuffle = nn.PixelShuffle(2)
|
||||
self.pixel_shuffle = nn.PixelShuffle(4)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
fea = self.conv_first(x)
|
||||
fea_lo = self.lo_trunk(fea)
|
||||
fea = self.pixel_shuffle(fea_lo) + fea
|
||||
fea_lo = self.lo_trunk(x)
|
||||
fea = self.pixel_shuffle(fea_lo) + self.skip_conv(x)
|
||||
fea = self.hi_trunk(fea)
|
||||
|
||||
# First, return image to original size and perform post-processing.
|
||||
# Upsampling.
|
||||
fea = F.interpolate(fea, scale_factor=2, mode='nearest')
|
||||
fea = self.lrelu(self.upconv1(fea))
|
||||
|
||||
# If 2x scaling is specified, do that too.
|
||||
if self.scale >= 2:
|
||||
if self.scale >= 4:
|
||||
fea = F.interpolate(fea, scale_factor=2, mode='nearest')
|
||||
fea = self.lrelu(self.upconv2(fea))
|
||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||
|
|
Loading…
Reference in New Issue
Block a user