Allow RRDB to upscale 8x

This commit is contained in:
James Betker 2020-12-14 23:58:52 -07:00
parent 087e9280ed
commit ef7eabf457

View File

@ -210,6 +210,10 @@ class RRDBNet(nn.Module):
# upsample
self.conv_up1 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
self.conv_up2 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
if scale >= 8:
self.conv_up3 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
else:
self.conv_up3 = None
self.conv_hr = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
self.conv_last = nn.Conv2d(self.reduce_ch, out_channels, 3, 1, 1)
@ -221,7 +225,7 @@ class RRDBNet(nn.Module):
for m in [
self.conv_first, self.conv_body, self.conv_up1,
self.conv_up2, self.conv_hr, self.conv_last
self.conv_up2, self.conv_up3, self.conv_hr, self.conv_last
]:
if m is not None:
default_init_weights(m, 0.1)
@ -262,9 +266,12 @@ class RRDBNet(nn.Module):
# upsample
out = self.lrelu(
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
if self.scale == 4:
if self.scale >= 4:
out = self.lrelu(
self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))
if self.scale >= 8:
out = self.lrelu(
self.conv_up3(F.interpolate(out, scale_factor=2, mode='nearest')))
else:
out = self.lrelu(self.conv_up2(out))
out = self.conv_last(self.lrelu(self.conv_hr(out)))