Remove rrdb normalization

James Betker 2021-06-05 21:39:19 -06:00
parent f5e75602b9
commit 184e887122

@@ -44,8 +44,6 @@ class ResidualDenseBlock(nn.Module):
             default_init_weights(getattr(self, f'conv{i + 1}'), init_weight)
         default_init_weights(self.conv5, 0)
-        self.normalize = nn.GroupNorm(num_groups=8, num_channels=mid_channels)
 
     def forward(self, x, emb):
         """Forward function.
@@ -69,7 +67,7 @@ class ResidualDenseBlock(nn.Module):
         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
-        return self.normalize(x5 * .2 + x)
+        return x5 * .2 + x
 
 
 class RRDB(nn.Module):
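
For context, a minimal sketch (not code from this repo; shapes are illustrative) contrasting the residual form before and after this hunk — previously GroupNorm wrapped the scaled residual sum, now the block returns the plain scaled residual:

    import torch
    import torch.nn as nn

    x = torch.randn(2, 64, 16, 16)   # block input
    x5 = torch.randn(2, 64, 16, 16)  # stand-in for the conv5 output

    normalize = nn.GroupNorm(num_groups=8, num_channels=64)
    out_before = normalize(x5 * .2 + x)  # pre-commit: normalized residual
    out_after = x5 * .2 + x              # post-commit: plain scaled residual
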
@@ -87,7 +85,6 @@ class RRDB(nn.Module):
         self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, embedding=True)
         self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
         self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
-        self.normalize = nn.GroupNorm(num_groups=8, num_channels=mid_channels)
         self.residual_mult = nn.Parameter(torch.FloatTensor([.1]))
 
     def forward(self, x, emb):
@@ -103,7 +100,7 @@ class RRDB(nn.Module):
         out = self.rdb2(out, emb)
         out = self.rdb3(out, emb)
-        return self.normalize(out * self.residual_mult + x)
+        return out * self.residual_mult + x
 
 
 class RRDBNet(nn.Module):
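
With the GroupNorm gone, the RRDB residual is scaled only by the learnable scalar residual_mult (initialized to 0.1). A standalone sketch, with illustrative shapes:

    import torch
    import torch.nn as nn

    residual_mult = nn.Parameter(torch.FloatTensor([.1]))  # learnable scalar
    x = torch.randn(2, 64, 16, 16)    # RRDB input
    out = torch.randn(2, 64, 16, 16)  # stand-in for the rdb3 output
    result = out * residual_mult + x  # scalar broadcasts over all elements
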
@@ -136,9 +133,9 @@ class RRDBNet(nn.Module):
         self.mid_channels = mid_channels
         # The diffusion RRDB starts with a full resolution image and downsamples into a .25 working space
-        self.input_block = ConvGnLelu(in_channels, mid_channels, kernel_size=7, stride=1, activation=True, norm=True, bias=True)
-        self.down1 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=True, bias=True)
-        self.down2 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=True, bias=True)
+        self.input_block = ConvGnLelu(in_channels, mid_channels, kernel_size=7, stride=1, activation=True, norm=False, bias=True)
+        self.down1 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=False, bias=True)
+        self.down2 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=False, bias=True)
 
         # Guided diffusion uses a time embedding.
         time_embed_dim = mid_channels * 4
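
ConvGnLelu is defined elsewhere in the repo; the sketch below is a hypothetical reconstruction of a conv + GroupNorm + LeakyReLU block of that shape, only to illustrate what flipping norm=True to norm=False disables:

    import torch.nn as nn

    class ConvGnLeluSketch(nn.Module):
        """Hypothetical stand-in for the repo's ConvGnLelu block."""
        def __init__(self, in_ch, out_ch, kernel_size=3, stride=1,
                     activation=True, norm=True, bias=True, num_groups=8):
            super().__init__()
            self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride,
                                  padding=kernel_size // 2, bias=bias)
            self.norm = nn.GroupNorm(num_groups, out_ch) if norm else None
            self.act = nn.LeakyReLU(0.2, inplace=True) if activation else None

        def forward(self, x):
            x = self.conv(x)
            if self.norm is not None:  # skipped when norm=False
                x = self.norm(x)
            if self.act is not None:
                x = self.act(x)
            return x
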
@@ -163,7 +160,6 @@ class RRDBNet(nn.Module):
         self.conv_last = nn.Conv2d(self.mid_channels, out_channels, 3, 1, 1)
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-        self.normalize = nn.GroupNorm(num_groups=8, num_channels=self.mid_channels)
 
         for m in [
             self.conv_body, self.conv_up1,
@@ -186,17 +182,16 @@ class RRDBNet(nn.Module):
         for bl in self.body:
             feat = checkpoint(bl, feat, emb)
         feat = feat[:, :self.mid_channels]
-        body_feat = self.conv_body(feat)
-        feat = self.normalize(feat + body_feat)
+        feat = self.conv_body(feat)
         # upsample
         out = torch.cat([self.lrelu(
-            self.normalize(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))),
+            self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))),
             d2], dim=1)
         out = torch.cat([self.lrelu(
-            self.normalize(self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))),
+            self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest'))),
             d1], dim=1)
-        out = self.conv_last(self.normalize(self.lrelu(self.conv_hr(out))))
+        out = self.conv_last(self.lrelu(self.conv_hr(out)))
         return out