forked from mrq/DL-Art-School
Remove rrdb normalization
This commit is contained in:
parent
f5e75602b9
commit
184e887122
|
@ -44,8 +44,6 @@ class ResidualDenseBlock(nn.Module):
|
||||||
default_init_weights(getattr(self, f'conv{i + 1}'), init_weight)
|
default_init_weights(getattr(self, f'conv{i + 1}'), init_weight)
|
||||||
default_init_weights(self.conv5, 0)
|
default_init_weights(self.conv5, 0)
|
||||||
|
|
||||||
self.normalize = nn.GroupNorm(num_groups=8, num_channels=mid_channels)
|
|
||||||
|
|
||||||
def forward(self, x, emb):
|
def forward(self, x, emb):
|
||||||
"""Forward function.
|
"""Forward function.
|
||||||
|
|
||||||
|
@ -69,7 +67,7 @@ class ResidualDenseBlock(nn.Module):
|
||||||
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
|
||||||
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
|
||||||
|
|
||||||
return self.normalize(x5 * .2 + x)
|
return x5 * .2 + x
|
||||||
|
|
||||||
|
|
||||||
class RRDB(nn.Module):
|
class RRDB(nn.Module):
|
||||||
|
@ -87,7 +85,6 @@ class RRDB(nn.Module):
|
||||||
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, embedding=True)
|
self.rdb1 = ResidualDenseBlock(mid_channels, growth_channels, embedding=True)
|
||||||
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
|
self.rdb2 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||||
self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
|
self.rdb3 = ResidualDenseBlock(mid_channels, growth_channels)
|
||||||
self.normalize = nn.GroupNorm(num_groups=8, num_channels=mid_channels)
|
|
||||||
self.residual_mult = nn.Parameter(torch.FloatTensor([.1]))
|
self.residual_mult = nn.Parameter(torch.FloatTensor([.1]))
|
||||||
|
|
||||||
def forward(self, x, emb):
|
def forward(self, x, emb):
|
||||||
|
@ -103,7 +100,7 @@ class RRDB(nn.Module):
|
||||||
out = self.rdb2(out, emb)
|
out = self.rdb2(out, emb)
|
||||||
out = self.rdb3(out, emb)
|
out = self.rdb3(out, emb)
|
||||||
|
|
||||||
return self.normalize(out * self.residual_mult + x)
|
return out * self.residual_mult + x
|
||||||
|
|
||||||
|
|
||||||
class RRDBNet(nn.Module):
|
class RRDBNet(nn.Module):
|
||||||
|
@ -136,9 +133,9 @@ class RRDBNet(nn.Module):
|
||||||
self.mid_channels = mid_channels
|
self.mid_channels = mid_channels
|
||||||
|
|
||||||
# The diffusion RRDB starts with a full resolution image and downsamples into a .25 working space
|
# The diffusion RRDB starts with a full resolution image and downsamples into a .25 working space
|
||||||
self.input_block = ConvGnLelu(in_channels, mid_channels, kernel_size=7, stride=1, activation=True, norm=True, bias=True)
|
self.input_block = ConvGnLelu(in_channels, mid_channels, kernel_size=7, stride=1, activation=True, norm=False, bias=True)
|
||||||
self.down1 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=True, bias=True)
|
self.down1 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=False, bias=True)
|
||||||
self.down2 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=True, bias=True)
|
self.down2 = ConvGnLelu(mid_channels, mid_channels, kernel_size=3, stride=2, activation=True, norm=False, bias=True)
|
||||||
|
|
||||||
# Guided diffusion uses a time embedding.
|
# Guided diffusion uses a time embedding.
|
||||||
time_embed_dim = mid_channels * 4
|
time_embed_dim = mid_channels * 4
|
||||||
|
@ -163,7 +160,6 @@ class RRDBNet(nn.Module):
|
||||||
self.conv_last = nn.Conv2d(self.mid_channels, out_channels, 3, 1, 1)
|
self.conv_last = nn.Conv2d(self.mid_channels, out_channels, 3, 1, 1)
|
||||||
|
|
||||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||||
self.normalize = nn.GroupNorm(num_groups=8, num_channels=self.mid_channels)
|
|
||||||
|
|
||||||
for m in [
|
for m in [
|
||||||
self.conv_body, self.conv_up1,
|
self.conv_body, self.conv_up1,
|
||||||
|
@ -186,17 +182,16 @@ class RRDBNet(nn.Module):
|
||||||
for bl in self.body:
|
for bl in self.body:
|
||||||
feat = checkpoint(bl, feat, emb)
|
feat = checkpoint(bl, feat, emb)
|
||||||
feat = feat[:, :self.mid_channels]
|
feat = feat[:, :self.mid_channels]
|
||||||
body_feat = self.conv_body(feat)
|
feat = self.conv_body(feat)
|
||||||
feat = self.normalize(feat + body_feat)
|
|
||||||
|
|
||||||
# upsample
|
# upsample
|
||||||
out = torch.cat([self.lrelu(
|
out = torch.cat([self.lrelu(
|
||||||
self.normalize(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))),
|
self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))),
|
||||||
d2], dim=1)
|
d2], dim=1)
|
||||||
out = torch.cat([self.lrelu(
|
out = torch.cat([self.lrelu(
|
||||||
self.normalize(self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest')))),
|
self.conv_up2(F.interpolate(out, scale_factor=2, mode='nearest'))),
|
||||||
d1], dim=1)
|
d1], dim=1)
|
||||||
out = self.conv_last(self.normalize(self.lrelu(self.conv_hr(out))))
|
out = self.conv_last(self.lrelu(self.conv_hr(out)))
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user