From b6e036147a06e83104112a45e80d64a3bf069465 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 30 Apr 2020 11:47:21 -0600
Subject: [PATCH] Add more batch norms to FlatProcessorNet_arch

---
 codes/models/archs/FlatProcessorNet_arch.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/codes/models/archs/FlatProcessorNet_arch.py b/codes/models/archs/FlatProcessorNet_arch.py
index 2ce1b978..504487b8 100644
--- a/codes/models/archs/FlatProcessorNet_arch.py
+++ b/codes/models/archs/FlatProcessorNet_arch.py
@@ -23,10 +23,12 @@ class ReduceAnnealer(nn.Module):
         self.annealer = nn.Conv2d(number_filters*4, number_filters, 3, stride=1, padding=1, bias=True)
         self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
         arch_util.initialize_weights([self.reducer, self.annealer], .1)
+        self.bn_reduce = nn.BatchNorm2d(number_filters*4, affine=True)
+        self.bn_anneal = nn.BatchNorm2d(number_filters*4, affine=True)
 
     def forward(self, x, interpolated_trunk):
-        out = self.lrelu(self.reducer(x))
-        out = self.lrelu(self.res_trunk(out))
+        out = self.lrelu(self.bn_reduce(self.reducer(x)))
+        out = self.lrelu(self.bn_anneal(self.res_trunk(out)))
         annealed = self.lrelu(self.annealer(out)) + interpolated_trunk
         return annealed, out
 
@@ -41,11 +43,13 @@ class Assembler(nn.Module):
         self.upsampler = nn.Conv2d(number_filters, number_filters*4, 3, stride=1, padding=1, bias=True)
         self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
         self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.bn = nn.BatchNorm2d(number_filters*4, affine=True)
+        self.bn_up = nn.BatchNorm2d(number_filters*4, affine=True)
 
     def forward(self, input, skip_raw):
         out = self.pixel_shuffle(input)
-        out = self.upsampler(out) + skip_raw
-        out = self.lrelu(self.res_trunk(out))
+        out = self.bn_up(self.upsampler(out)) + skip_raw
+        out = self.lrelu(self.bn(self.res_trunk(out)))
         return out
 
 class FlatProcessorNet(nn.Module):
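Note: the pattern this patch applies in both modules is conv -> BatchNorm2d
-> LeakyReLU, normalizing each convolution's output before the nonlinearity.
Below is a minimal standalone sketch of that ordering; it assumes a stride-2
reducer and omits the repo's arch_util helpers and residual trunk, so the
module name and shapes are illustrative only, not the actual FlatProcessorNet
code.

import torch
import torch.nn as nn

class ReduceSketch(nn.Module):
    """Illustrative conv -> BN -> LeakyReLU block, not the repo's module."""
    def __init__(self, number_filters=64):
        super().__init__()
        # Assumed stride-2 conv: halves resolution, widens channels 4x.
        self.reducer = nn.Conv2d(number_filters, number_filters * 4, 3,
                                 stride=2, padding=1, bias=True)
        self.bn_reduce = nn.BatchNorm2d(number_filters * 4, affine=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        # Normalize the conv output before the activation, as in the patch.
        return self.lrelu(self.bn_reduce(self.reducer(x)))

if __name__ == "__main__":
    net = ReduceSketch()
    y = net(torch.randn(1, 64, 32, 32))
    print(y.shape)  # torch.Size([1, 256, 16, 16])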