diff --git a/codes/models/diffusion/rrdb_diffusion.py b/codes/models/diffusion/rrdb_diffusion.py index 64fde43b..5dc9c254 100644 --- a/codes/models/diffusion/rrdb_diffusion.py +++ b/codes/models/diffusion/rrdb_diffusion.py @@ -160,6 +160,7 @@ class RRDBNet(nn.Module): self.conv_last = nn.Conv2d(self.mid_channels, out_channels, 3, 1, 1) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + self.normalize = nn.GroupNorm(num_groups=8, num_channels=self.mid_channels) for m in [ self.conv_body, self.conv_up1, @@ -186,10 +187,10 @@ class RRDBNet(nn.Module): # upsample out = torch.cat([self.lrelu( - self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))), + self.normalize(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))), d2], dim=1) out = torch.cat([self.lrelu( - self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest'))), + self.normalize(self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))), d1], dim=1) out = self.conv_last(self.lrelu(self.conv_hr(out)))