From 184e8871229547d88984885a27683a4d7f477036 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Sat, 5 Jun 2021 21:39:19 -0600
Subject: [PATCH] Remove rrdb normalization

---
 codes/models/diffusion/rrdb_diffusion.py | 23 +++++++++--------------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/codes/models/diffusion/rrdb_diffusion.py b/codes/models/diffusion/rrdb_diffusion.py
index 6a06c25f..64fde43b 100644
--- a/codes/models/diffusion/rrdb_diffusion.py
+++ b/codes/models/diffusion/rrdb_diffusion.py
@@ -44,8 +44,6 @@ class ResidualDenseBlock(nn.Module):
             default_init_weights(getattr(self, f'conv{i + 1}'), init_weight)
         default_init_weights(self.conv5, 0)
 
-        self.normalize = nn.GroupNorm(num_groups=8, num_channels=mid_channels)
-
     def forward(self, x, emb):
         """Forward function.
 
@@ -69,7 +67,7 @@ class ResidualDenseBlock(nn.Module):
         x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
-        return self.normalize(x5 * .2 + x)
+        return x5 * .2 + x
 
 
 class RRDB(nn.Module):
@@ -87,7 +85,6 @@ class RRDB(nn.Module):
         self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, embedding=True)
         self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
         self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
-        self.normalize = nn.GroupNorm(num_groups=8, num_channels=mid_channels)
         self.residual_mult = nn.Parameter(torch.FloatTensor([.1]))
 
     def forward(self, x, emb):
@@ -103,7 +100,7 @@ class RRDB(nn.Module):
         out = self.rdb1(x, emb)
         out = self.rdb2(out, emb)
         out = self.rdb3(out, emb)
-        return self.normalize(out * self.residual_mult + x)
+        return out * self.residual_mult + x
 
 
 class RRDBNet(nn.Module):
@@ -136,9 +133,9 @@ class RRDBNet(nn.Module):
         self.mid_channels = mid_channels
 
         # The diffusion RRDB starts with a full resolution image and downsamples into a .25 working space
-        self.input_block = ConvGnLelu(in_channels, mid_channels, kernel_size=7, stride=1, activation=True, norm=True, bias=True)
-        self.down1 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=True, bias=True)
-        self.down2 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=True, bias=True)
+        self.input_block = ConvGnLelu(in_channels, mid_channels, kernel_size=7, stride=1, activation=True, norm=False, bias=True)
+        self.down1 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=False, bias=True)
+        self.down2 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=False, bias=True)
 
         # Guided diffusion uses a time embedding.
         time_embed_dim = mid_channels * 4
@@ -163,7 +160,6 @@ class RRDBNet(nn.Module):
         self.conv_last = nn.Conv2d(self.mid_channels, out_channels, 3, 1, 1)
 
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-        self.normalize = nn.GroupNorm(num_groups=8, num_channels=self.mid_channels)
 
         for m in [
             self.conv_body, self.conv_up1,
@@ -186,17 +182,16 @@ class RRDBNet(nn.Module):
 
         for bl in self.body:
             feat = checkpoint(bl, feat, emb)
         feat = feat[:, :self.mid_channels]
-        body_feat = self.conv_body(feat)
-        feat = self.normalize(feat + body_feat)
+        feat = self.conv_body(feat)
 
         # upsample
         out = torch.cat([self.lrelu(
-            self.normalize(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))),
+            self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))),
             d2], dim=1)
         out = torch.cat([self.lrelu(
-            self.normalize(self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))),
+            self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest'))),
             d1], dim=1)
-        out = self.conv_last(self.normalize(self.lrelu(self.conv_hr(out))))
+        out = self.conv_last(self.lrelu(self.conv_hr(out)))
         return out
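
Note (appended for reference, not part of the patch): the change is uniform: every GroupNorm that wrapped a residual sum is deleted, and the three ConvGnLelu input/downsample blocks switch to norm=False. Below is a minimal, runnable Python sketch of the residual arithmetic the patch leaves behind. The 0.2 scale and the learned residual_mult initialized to 0.1 mirror the code above; the class names, the 64-channel default, and the single stand-in conv are illustrative assumptions, not the repo's actual modules.

    import torch
    import torch.nn as nn

    class DenseBlockTail(nn.Module):
        # Stand-in for the post-patch ResidualDenseBlock ending: the last conv
        # output is scaled by 0.2 and added to the input, with no GroupNorm
        # applied to the sum (previously: self.normalize(x5 * .2 + x)).
        def __init__(self, mid_channels=64):
            super().__init__()
            self.conv5 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)

        def forward(self, x):
            x5 = self.conv5(x)  # the dense concatenations are omitted here
            return x5 * .2 + x  # plain residual, no normalization

    class RRDBTail(nn.Module):
        # Stand-in for the post-patch RRDB ending: a learned scalar multiplier
        # (initialized to 0.1, as in the patch) scales the block output before
        # the skip connection, again without GroupNorm on the sum.
        def __init__(self, mid_channels=64):
            super().__init__()
            self.inner = DenseBlockTail(mid_channels)
            self.residual_mult = nn.Parameter(torch.FloatTensor([.1]))

        def forward(self, x):
            out = self.inner(x)
            return out * self.residual_mult + x  # previously wrapped in GroupNorm

    x = torch.randn(2, 64, 16, 16)
    print(RRDBTail()(x).shape)  # torch.Size([2, 64, 16, 16])

The RRDBNet hunks make the same substitution at the network level: bare conv outputs in the body and upsampling path where self.normalize calls used to wrap them.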