forked from mrq/DL-Art-School
Only head norm removed
This commit is contained in:
parent
65d0376b90
commit
75567a9814
|
@ -40,10 +40,12 @@ class ResidualDenseBlock(nn.Module):
|
||||||
nn.Conv2d(mid_channels + i * growth_channels, out_channels, 3,
|
nn.Conv2d(mid_channels + i * growth_channels, out_channels, 3,
|
||||||
1, 1))
|
1, 1))
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
for i in range(4):
|
for i in range(5):
|
||||||
default_init_weights(getattr(self, f'conv{i + 1}'), init_weight)
|
default_init_weights(getattr(self, f'conv{i + 1}'), init_weight)
|
||||||
default_init_weights(self.conv5, 0)
|
default_init_weights(self.conv5, 0)
|
||||||
|
|
||||||
|
self.normalize = nn.GroupNorm(num_groups=8, num_channels=mid_channels)
|
||||||
|
|
||||||
def forward(self, x, emb):
|
def forward(self, x, emb):
|
||||||
"""Forward function.
|
"""Forward function.
|
||||||
|
|
||||||
|
@ -67,7 +69,7 @@ class ResidualDenseBlock(nn.Module):
|
||||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||||
|
|
||||||
return x5 * .2 + x
|
return self.normalize(x5 * .2 + x)
|
||||||
|
|
||||||
|
|
||||||
class RRDB(nn.Module):
|
class RRDB(nn.Module):
|
||||||
|
@ -85,6 +87,7 @@ class RRDB(nn.Module):
|
||||||
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, embedding=True)
|
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, embedding=True)
|
||||||
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
|
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||||
self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
|
self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||||
|
self.normalize = nn.GroupNorm(num_groups=8, num_channels=mid_channels)
|
||||||
self.residual_mult = nn.Parameter(torch.FloatTensor([.1]))
|
self.residual_mult = nn.Parameter(torch.FloatTensor([.1]))
|
||||||
|
|
||||||
def forward(self, x, emb):
|
def forward(self, x, emb):
|
||||||
|
@ -100,7 +103,7 @@ class RRDB(nn.Module):
|
||||||
out = self.rdb2(out, emb)
|
out = self.rdb2(out, emb)
|
||||||
out = self.rdb3(out, emb)
|
out = self.rdb3(out, emb)
|
||||||
|
|
||||||
return out * self.residual_mult + x
|
return self.normalize(out * self.residual_mult + x)
|
||||||
|
|
||||||
|
|
||||||
class RRDBNet(nn.Module):
|
class RRDBNet(nn.Module):
|
||||||
|
@ -192,7 +195,7 @@ class RRDBNet(nn.Module):
|
||||||
out = torch.cat([self.lrelu(
|
out = torch.cat([self.lrelu(
|
||||||
self.normalize(self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))),
|
self.normalize(self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))),
|
||||||
d1], dim=1)
|
d1], dim=1)
|
||||||
out = self.conv_last(self.lrelu(self.conv_hr(out)))
|
out = self.conv_last(self.normalize(self.lrelu(self.conv_hr(out))))
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user