Only head norm removed
This commit is contained in:
parent
65d0376b90
commit
75567a9814
|
@ -40,10 +40,12 @@ class ResidualDenseBlock(nn.Module):
|
|||
nn.Conv2d(mid_channels + i * growth_channels, out_channels, 3,
|
||||
1, 1))
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
for i in range(4):
|
||||
for i in range(5):
|
||||
default_init_weights(getattr(self, f'conv{i + 1}'), init_weight)
|
||||
default_init_weights(self.conv5, 0)
|
||||
|
||||
self.normalize = nn.GroupNorm(num_groups=8, num_channels=mid_channels)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""Forward function.
|
||||
|
||||
|
@ -67,7 +69,7 @@ class ResidualDenseBlock(nn.Module):
|
|||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
|
||||
return x5 * .2 + x
|
||||
return self.normalize(x5 * .2 + x)
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
|
@ -85,6 +87,7 @@ class RRDB(nn.Module):
|
|||
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, embedding=True)
|
||||
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
self.normalize = nn.GroupNorm(num_groups=8, num_channels=mid_channels)
|
||||
self.residual_mult = nn.Parameter(torch.FloatTensor([.1]))
|
||||
|
||||
def forward(self, x, emb):
|
||||
|
@ -100,7 +103,7 @@ class RRDB(nn.Module):
|
|||
out = self.rdb2(out, emb)
|
||||
out = self.rdb3(out, emb)
|
||||
|
||||
return out * self.residual_mult + x
|
||||
return self.normalize(out * self.residual_mult + x)
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
|
@ -192,7 +195,7 @@ class RRDBNet(nn.Module):
|
|||
out = torch.cat([self.lrelu(
|
||||
self.normalize(self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))),
|
||||
d1], dim=1)
|
||||
out = self.conv_last(self.lrelu(self.conv_hr(out)))
|
||||
out = self.conv_last(self.normalize(self.lrelu(self.conv_hr(out))))
|
||||
|
||||
return out
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user