Allow RRDB to upscale 8x

This commit is contained in:
James Betker 2020-12-14 23:58:52 -07:00
parent 087e9280ed
commit ef7eabf457

View File

@ -210,6 +210,10 @@ class RRDBNet(nn.Module):
# upsample # upsample
self.conv_up1 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1) self.conv_up1 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
self.conv_up2 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1) self.conv_up2 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
if scale >= 8:
self.conv_up3 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
else:
self.conv_up3 = None
self.conv_hr = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1) self.conv_hr = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
self.conv_last = nn.Conv2d(self.reduce_ch, out_channels, 3, 1, 1) self.conv_last = nn.Conv2d(self.reduce_ch, out_channels, 3, 1, 1)
@ -221,7 +225,7 @@ class RRDBNet(nn.Module):
for m in [ for m in [
self.conv_first, self.conv_body, self.conv_up1, self.conv_first, self.conv_body, self.conv_up1,
self.conv_up2, self.conv_hr, self.conv_last self.conv_up2, self.conv_up3, self.conv_hr, self.conv_last
]: ]:
if m is not None: if m is not None:
default_init_weights(m, 0.1) default_init_weights(m, 0.1)
@ -262,9 +266,12 @@ class RRDBNet(nn.Module):
# upsample # upsample
out = self.lrelu( out = self.lrelu(
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
if self.scale == 4: if self.scale >= 4:
out = self.lrelu( out = self.lrelu(
self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest'))) self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))
if self.scale >= 8:
out = self.lrelu(
self.conv_up3(F.interpolate(out, scale_factor=2, mode='nearest')))
else: else:
out = self.lrelu(self.conv_up2(out)) out = self.lrelu(self.conv_up2(out))
out = self.conv_last(self.lrelu(self.conv_hr(out))) out = self.conv_last(self.lrelu(self.conv_hr(out)))