Allow RRDB to upscale 8x
This commit is contained in:
parent
087e9280ed
commit
ef7eabf457
|
@ -210,6 +210,10 @@ class RRDBNet(nn.Module):
|
|||
# upsample
|
||||
self.conv_up1 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
|
||||
self.conv_up2 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
|
||||
if scale >= 8:
|
||||
self.conv_up3 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
|
||||
else:
|
||||
self.conv_up3 = None
|
||||
self.conv_hr = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
|
||||
self.conv_last = nn.Conv2d(self.reduce_ch, out_channels, 3, 1, 1)
|
||||
|
||||
|
@ -221,7 +225,7 @@ class RRDBNet(nn.Module):
|
|||
|
||||
for m in [
|
||||
self.conv_first, self.conv_body, self.conv_up1,
|
||||
self.conv_up2, self.conv_hr, self.conv_last
|
||||
self.conv_up2, self.conv_up3, self.conv_hr, self.conv_last
|
||||
]:
|
||||
if m is not None:
|
||||
default_init_weights(m, 0.1)
|
||||
|
@ -262,9 +266,12 @@ class RRDBNet(nn.Module):
|
|||
# upsample
|
||||
out = self.lrelu(
|
||||
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
||||
if self.scale == 4:
|
||||
if self.scale >= 4:
|
||||
out = self.lrelu(
|
||||
self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))
|
||||
if self.scale >= 8:
|
||||
out = self.lrelu(
|
||||
self.conv_up3(F.interpolate(out, scale_factor=2, mode='nearest')))
|
||||
else:
|
||||
out = self.lrelu(self.conv_up2(out))
|
||||
out = self.conv_last(self.lrelu(self.conv_hr(out)))
|
||||
|
|
Loading…
Reference in New Issue
Block a user