Allow RRDB to upscale 8x
This commit is contained in:
parent
087e9280ed
commit
ef7eabf457
|
@ -210,6 +210,10 @@ class RRDBNet(nn.Module):
|
||||||
# upsample
|
# upsample
|
||||||
self.conv_up1 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
|
self.conv_up1 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
|
||||||
self.conv_up2 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
|
self.conv_up2 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
|
||||||
|
if scale >= 8:
|
||||||
|
self.conv_up3 = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
|
||||||
|
else:
|
||||||
|
self.conv_up3 = None
|
||||||
self.conv_hr = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
|
self.conv_hr = nn.Conv2d(self.reduce_ch, self.reduce_ch, 3, 1, 1)
|
||||||
self.conv_last = nn.Conv2d(self.reduce_ch, out_channels, 3, 1, 1)
|
self.conv_last = nn.Conv2d(self.reduce_ch, out_channels, 3, 1, 1)
|
||||||
|
|
||||||
|
@ -221,7 +225,7 @@ class RRDBNet(nn.Module):
|
||||||
|
|
||||||
for m in [
|
for m in [
|
||||||
self.conv_first, self.conv_body, self.conv_up1,
|
self.conv_first, self.conv_body, self.conv_up1,
|
||||||
self.conv_up2, self.conv_hr, self.conv_last
|
self.conv_up2, self.conv_up3, self.conv_hr, self.conv_last
|
||||||
]:
|
]:
|
||||||
if m is not None:
|
if m is not None:
|
||||||
default_init_weights(m, 0.1)
|
default_init_weights(m, 0.1)
|
||||||
|
@ -262,9 +266,12 @@ class RRDBNet(nn.Module):
|
||||||
# upsample
|
# upsample
|
||||||
out = self.lrelu(
|
out = self.lrelu(
|
||||||
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
|
||||||
if self.scale == 4:
|
if self.scale >= 4:
|
||||||
out = self.lrelu(
|
out = self.lrelu(
|
||||||
self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))
|
self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))
|
||||||
|
if self.scale >= 8:
|
||||||
|
out = self.lrelu(
|
||||||
|
self.conv_up3(F.interpolate(out, scale_factor=2, mode='nearest')))
|
||||||
else:
|
else:
|
||||||
out = self.lrelu(self.conv_up2(out))
|
out = self.lrelu(self.conv_up2(out))
|
||||||
out = self.conv_last(self.lrelu(self.conv_hr(out)))
|
out = self.conv_last(self.lrelu(self.conv_hr(out)))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user