From 75567a9814f693c64fc333db2cf50309a216d0bd Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 5 Jun 2021 23:29:11 -0600 Subject: [PATCH] Only head norm removed --- codes/models/diffusion/rrdb_diffusion.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/codes/models/diffusion/rrdb_diffusion.py b/codes/models/diffusion/rrdb_diffusion.py index 5dc9c254..0cbbe8ab 100644 --- a/codes/models/diffusion/rrdb_diffusion.py +++ b/codes/models/diffusion/rrdb_diffusion.py @@ -40,10 +40,12 @@ class ResidualDenseBlock(nn.Module): nn.Conv2d(mid_channels + i * growth_channels, out_channels, 3, 1, 1)) self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) - for i in range(4): + for i in range(5): default_init_weights(getattr(self, f'conv{i + 1}'), init_weight) default_init_weights(self.conv5, 0) + self.normalize = nn.GroupNorm(num_groups=8, num_channels=mid_channels) + def forward(self, x, emb): """Forward function. @@ -67,7 +69,7 @@ class ResidualDenseBlock(nn.Module): x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) - return x5 * .2 + x + return self.normalize(x5 * .2 + x) class RRDB(nn.Module): @@ -85,6 +87,7 @@ class RRDB(nn.Module): self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, embedding=True) self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels) self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels) + self.normalize = nn.GroupNorm(num_groups=8, num_channels=mid_channels) self.residual_mult = nn.Parameter(torch.FloatTensor([.1])) def forward(self, x, emb): @@ -100,7 +103,7 @@ class RRDB(nn.Module): out = self.rdb2(out, emb) out = self.rdb3(out, emb) - return out * self.residual_mult + x + return self.normalize(out * self.residual_mult + x) class RRDBNet(nn.Module): @@ -192,7 +195,7 @@ class RRDBNet(nn.Module): out = torch.cat([self.lrelu( self.normalize(self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))), d1], dim=1) - out = self.conv_last(self.lrelu(self.conv_hr(out))) + out = self.conv_last(self.normalize(self.lrelu(self.conv_hr(out)))) return out