forked from mrq/DL-Art-School
Re-add normalization at the tail of the RRDB
This commit is contained in:
parent
184e887122
commit
65d0376b90
|
@ -160,6 +160,7 @@ class RRDBNet(nn.Module):
|
||||||
self.conv_last = nn.Conv2d(self.mid_channels, out_channels, 3, 1, 1)
|
self.conv_last = nn.Conv2d(self.mid_channels, out_channels, 3, 1, 1)
|
||||||
|
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
|
self.normalize = nn.GroupNorm(num_groups=8, num_channels=self.mid_channels)
|
||||||
|
|
||||||
for m in [
|
for m in [
|
||||||
self.conv_body, self.conv_up1,
|
self.conv_body, self.conv_up1,
|
||||||
|
@ -186,10 +187,10 @@ class RRDBNet(nn.Module):
|
||||||
|
|
||||||
# upsample
|
# upsample
|
||||||
out = torch.cat([self.lrelu(
|
out = torch.cat([self.lrelu(
|
||||||
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))),
|
self.normalize(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))),
|
||||||
d2], dim=1)
|
d2], dim=1)
|
||||||
out = torch.cat([self.lrelu(
|
out = torch.cat([self.lrelu(
|
||||||
self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest'))),
|
self.normalize(self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))),
|
||||||
d1], dim=1)
|
d1], dim=1)
|
||||||
out = self.conv_last(self.lrelu(self.conv_hr(out)))
|
out = self.conv_last(self.lrelu(self.conv_hr(out)))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user