forked from mrq/DL-Art-School
Update multirrdb to do HR fixing in the base image dimension.
This commit is contained in:
parent
d3b2cbfe7c
commit
e89f28ead0
|
@ -354,14 +354,13 @@ class MultiRRDBNet(RRDBBase):
|
||||||
def __init__(self, nf_base, gc_base, lo_blocks, hi_blocks, scale=2, rrdb_block_f=None):
|
def __init__(self, nf_base, gc_base, lo_blocks, hi_blocks, scale=2, rrdb_block_f=None):
|
||||||
super(MultiRRDBNet, self).__init__()
|
super(MultiRRDBNet, self).__init__()
|
||||||
|
|
||||||
# Initial downsampling.
|
|
||||||
self.conv_first = nn.Conv2d(3, nf_base, 5, stride=2, padding=2, bias=True)
|
|
||||||
|
|
||||||
# Chained trunks
|
# Chained trunks
|
||||||
lo_nf = nf_base * 4
|
lo_nf = nf_base * 4
|
||||||
|
lo_nf_out = nf_base // 4
|
||||||
hi_nf = nf_base
|
hi_nf = nf_base
|
||||||
self.lo_trunk = RRDBTrunk(nf_base, lo_nf, lo_blocks, gc_base * 2, initial_stride=2, rrdb_block_f=rrdb_block_f)
|
self.lo_trunk = RRDBTrunk(nf_base, lo_nf, lo_blocks, gc_base * 2, initial_stride=1, rrdb_block_f=rrdb_block_f, conv_first_block=PixShuffleInitialConv(4, lo_nf))
|
||||||
self.hi_trunk = RRDBTrunk(nf_base, hi_nf, hi_blocks, gc_base, initial_stride=1, rrdb_block_f=rrdb_block_f)
|
self.skip_conv = nn.Conv2d(3, lo_nf_out, 1)
|
||||||
|
self.hi_trunk = RRDBTrunk(lo_nf_out, hi_nf, hi_blocks, gc_base, initial_stride=1, rrdb_block_f=rrdb_block_f)
|
||||||
self.trunks = [self.lo_trunk, self.hi_trunk]
|
self.trunks = [self.lo_trunk, self.hi_trunk]
|
||||||
|
|
||||||
# Upsampling
|
# Upsampling
|
||||||
|
@ -370,22 +369,19 @@ class MultiRRDBNet(RRDBBase):
|
||||||
self.upconv2 = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True)
|
self.upconv2 = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True)
|
||||||
self.HRconv = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True)
|
self.HRconv = nn.Conv2d(hi_nf, hi_nf, 5, 1, padding=2, bias=True)
|
||||||
self.conv_last = nn.Conv2d(hi_nf, 3, 3, 1, 1, bias=True)
|
self.conv_last = nn.Conv2d(hi_nf, 3, 3, 1, 1, bias=True)
|
||||||
self.pixel_shuffle = nn.PixelShuffle(2)
|
self.pixel_shuffle = nn.PixelShuffle(4)
|
||||||
|
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
fea = self.conv_first(x)
|
fea_lo = self.lo_trunk(x)
|
||||||
fea_lo = self.lo_trunk(fea)
|
fea = self.pixel_shuffle(fea_lo) + self.skip_conv(x)
|
||||||
fea = self.pixel_shuffle(fea_lo) + fea
|
|
||||||
fea = self.hi_trunk(fea)
|
fea = self.hi_trunk(fea)
|
||||||
|
|
||||||
# First, return image to original size and perform post-processing.
|
# Upsampling.
|
||||||
fea = F.interpolate(fea, scale_factor=2, mode='nearest')
|
fea = F.interpolate(fea, scale_factor=2, mode='nearest')
|
||||||
fea = self.lrelu(self.upconv1(fea))
|
fea = self.lrelu(self.upconv1(fea))
|
||||||
|
if self.scale >= 4:
|
||||||
# If 2x scaling is specified, do that too.
|
|
||||||
if self.scale >= 2:
|
|
||||||
fea = F.interpolate(fea, scale_factor=2, mode='nearest')
|
fea = F.interpolate(fea, scale_factor=2, mode='nearest')
|
||||||
fea = self.lrelu(self.upconv2(fea))
|
fea = self.lrelu(self.upconv2(fea))
|
||||||
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
out = self.conv_last(self.lrelu(self.HRconv(fea)))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user