Re-add normalization at the tail of the RRDB

This commit is contained in:
James Betker 2021-06-05 23:04:05 -06:00
parent 184e887122
commit 65d0376b90

View File

@ -160,6 +160,7 @@ class RRDBNet(nn.Module):
self.conv_last = nn.Conv2d(self.mid_channels, out_channels, 3, 1, 1) self.conv_last = nn.Conv2d(self.mid_channels, out_channels, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.normalize = nn.GroupNorm(num_groups=8, num_channels=self.mid_channels)
for m in [ for m in [
self.conv_body, self.conv_up1, self.conv_body, self.conv_up1,
@ -186,10 +187,10 @@ class RRDBNet(nn.Module):
# upsample # upsample
out = torch.cat([self.lrelu( out = torch.cat([self.lrelu(
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))), self.normalize(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))),
d2], dim=1) d2], dim=1)
out = torch.cat([self.lrelu( out = torch.cat([self.lrelu(
self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest'))), self.normalize(self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))),
d1], dim=1) d1], dim=1)
out = self.conv_last(self.lrelu(self.conv_hr(out))) out = self.conv_last(self.lrelu(self.conv_hr(out)))