Remove rrdb normalization
This commit is contained in:
parent
f5e75602b9
commit
184e887122
|
@ -44,8 +44,6 @@ class ResidualDenseBlock(nn.Module):
|
|||
default_init_weights(getattr(self, f'conv{i + 1}'), init_weight)
|
||||
default_init_weights(self.conv5, 0)
|
||||
|
||||
self.normalize = nn.GroupNorm(num_groups=8, num_channels=mid_channels)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""Forward function.
|
||||
|
||||
|
@ -69,7 +67,7 @@ class ResidualDenseBlock(nn.Module):
|
|||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||
|
||||
return self.normalize(x5 * .2 + x)
|
||||
return x5 * .2 + x
|
||||
|
||||
|
||||
class RRDB(nn.Module):
|
||||
|
@ -87,7 +85,6 @@ class RRDB(nn.Module):
|
|||
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, embedding=True)
|
||||
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||
self.normalize = nn.GroupNorm(num_groups=8, num_channels=mid_channels)
|
||||
self.residual_mult = nn.Parameter(torch.FloatTensor([.1]))
|
||||
|
||||
def forward(self, x, emb):
|
||||
|
@ -103,7 +100,7 @@ class RRDB(nn.Module):
|
|||
out = self.rdb2(out, emb)
|
||||
out = self.rdb3(out, emb)
|
||||
|
||||
return self.normalize(out * self.residual_mult + x)
|
||||
return out * self.residual_mult + x
|
||||
|
||||
|
||||
class RRDBNet(nn.Module):
|
||||
|
@ -136,9 +133,9 @@ class RRDBNet(nn.Module):
|
|||
self.mid_channels = mid_channels
|
||||
|
||||
# The diffusion RRDB starts with a full resolution image and downsamples into a .25 working space
|
||||
self.input_block = ConvGnLelu(in_channels, mid_channels, kernel_size=7, stride=1, activation=True, norm=True, bias=True)
|
||||
self.down1 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=True, bias=True)
|
||||
self.down2 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=True, bias=True)
|
||||
self.input_block = ConvGnLelu(in_channels, mid_channels, kernel_size=7, stride=1, activation=True, norm=False, bias=True)
|
||||
self.down1 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=False, bias=True)
|
||||
self.down2 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=False, bias=True)
|
||||
|
||||
# Guided diffusion uses a time embedding.
|
||||
time_embed_dim = mid_channels * 4
|
||||
|
@ -163,7 +160,6 @@ class RRDBNet(nn.Module):
|
|||
self.conv_last = nn.Conv2d(self.mid_channels, out_channels, 3, 1, 1)
|
||||
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
self.normalize = nn.GroupNorm(num_groups=8, num_channels=self.mid_channels)
|
||||
|
||||
for m in [
|
||||
self.conv_body, self.conv_up1,
|
||||
|
@ -186,17 +182,16 @@ class RRDBNet(nn.Module):
|
|||
for bl in self.body:
|
||||
feat = checkpoint(bl, feat, emb)
|
||||
feat = feat[:, :self.mid_channels]
|
||||
body_feat = self.conv_body(feat)
|
||||
feat = self.normalize(feat + body_feat)
|
||||
feat = self.conv_body(feat)
|
||||
|
||||
# upsample
|
||||
out = torch.cat([self.lrelu(
|
||||
self.normalize(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))),
|
||||
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))),
|
||||
d2], dim=1)
|
||||
out = torch.cat([self.lrelu(
|
||||
self.normalize(self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))),
|
||||
self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest'))),
|
||||
d1], dim=1)
|
||||
out = self.conv_last(self.normalize(self.lrelu(self.conv_hr(out))))
|
||||
out = self.conv_last(self.lrelu(self.conv_hr(out)))
|
||||
|
||||
return out
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user