Only head norm removed

James Betker 2021-06-05 23:29:11 -06:00
parent 65d0376b90
commit 75567a9814


@@ -40,10 +40,12 @@ class ResidualDenseBlock(nn.Module):
                 nn.Conv2d(mid_channels + i * growth_channels, out_channels, 3,
                           1, 1))
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-        for i in range(4):
+        for i in range(5):
             default_init_weights(getattr(self, f'conv{i + 1}'), init_weight)
-        default_init_weights(self.conv5, 0)
+        self.normalize = nn.GroupNorm(num_groups=8, num_channels=mid_channels)

     def forward(self, x, emb):
         """Forward function.
@@ -67,7 +69,7 @@ class ResidualDenseBlock(nn.Module):
         x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
         x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
-        return x5 * .2 + x
+        return self.normalize(x5 * .2 + x)


 class RRDB(nn.Module):
@@ -85,6 +87,7 @@ class RRDB(nn.Module):
         self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, embedding=True)
         self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
         self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
+        self.normalize = nn.GroupNorm(num_groups=8, num_channels=mid_channels)
         self.residual_mult = nn.Parameter(torch.FloatTensor([.1]))

     def forward(self, x, emb):
@@ -100,7 +103,7 @@ class RRDB(nn.Module):
         out = self.rdb2(out, emb)
         out = self.rdb3(out, emb)
-        return out * self.residual_mult + x
+        return self.normalize(out * self.residual_mult + x)


 class RRDBNet(nn.Module):
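
Note: the hunks above wrap the residual output of both ResidualDenseBlock and RRDB in an nn.GroupNorm(num_groups=8, num_channels=mid_channels). A minimal sketch of that pattern, with the dense conv chain and the emb plumbing collapsed into a single placeholder conv and mid_channels assumed to be 64 (a hypothetical value, not stated in these hunks):

import torch
import torch.nn as nn

class NormalizedResidual(nn.Module):
    """Reduction of the pattern added here: scale the residual branch,
    add the skip connection, then group-normalize the sum."""

    def __init__(self, mid_channels=64, num_groups=8):
        super().__init__()
        # Placeholder for the dense convs of ResidualDenseBlock / the three RDBs inside RRDB.
        self.body = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
        self.normalize = nn.GroupNorm(num_groups=num_groups, num_channels=mid_channels)
        # Learnable residual scale, as in RRDB above (ResidualDenseBlock uses a fixed .2).
        self.residual_mult = nn.Parameter(torch.FloatTensor([.1]))

    def forward(self, x):
        out = self.body(x)
        return self.normalize(out * self.residual_mult + x)

nn.GroupNorm requires num_channels to be divisible by num_groups, so this configuration assumes mid_channels is a multiple of 8.
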
@@ -192,7 +195,7 @@ class RRDBNet(nn.Module):
         out = torch.cat([self.lrelu(
             self.normalize(self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))),
             d1], dim=1)
-        out = self.conv_last(self.lrelu(self.conv_hr(out)))
+        out = self.conv_last(self.normalize(self.lrelu(self.conv_hr(out))))
         return out
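
In RRDBNet's forward, the normalization is now also applied on the output path, between self.lrelu(self.conv_hr(out)) and self.conv_last (it was already present around conv_up2, per the unchanged context line). A quick, self-contained shape check of that GroupNorm configuration, again assuming 64 feature channels (hypothetical, not fixed by this hunk):

import torch
import torch.nn as nn
import torch.nn.functional as F

normalize = nn.GroupNorm(num_groups=8, num_channels=64)  # 8 groups of 8 channels each
feat = torch.randn(2, 64, 32, 32)                        # (N, C, H, W) feature map
# Stand-in for normalize(lrelu(conv_hr(out))): GroupNorm normalizes per sample
# over channel groups, so the result does not depend on batch size.
out = normalize(F.leaky_relu(feat, negative_slope=0.2))
print(out.shape)  # torch.Size([2, 64, 32, 32])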