Add more batch norms to FlatProcessorNet_arch

This commit is contained in:
James Betker 2020-04-30 11:47:21 -06:00
parent 66e91a3d9e
commit b6e036147a

View File

@ -23,10 +23,12 @@ class ReduceAnnealer(nn.Module):
self.annealer = nn.Conv2d(number_filters*4, number_filters, 3, stride=1, padding=1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
arch_util.initialize_weights([self.reducer, self.annealer], .1)
self.bn_reduce = nn.BatchNorm2d(number_filters*4, affine=True)
self.bn_anneal = nn.BatchNorm2d(number_filters*4, affine=True)
def forward(self, x, interpolated_trunk):
out = self.lrelu(self.reducer(x))
out = self.lrelu(self.res_trunk(out))
out = self.lrelu(self.bn_reduce(self.reducer(x)))
out = self.lrelu(self.bn_anneal(self.res_trunk(out)))
annealed = self.lrelu(self.annealer(out)) + interpolated_trunk
return annealed, out
@ -41,11 +43,13 @@ class Assembler(nn.Module):
self.upsampler = nn.Conv2d(number_filters, number_filters*4, 3, stride=1, padding=1, bias=True)
self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
self.bn = nn.BatchNorm2d(number_filters*4, affine=True)
self.bn_up = nn.BatchNorm2d(number_filters*4, affine=True)
def forward(self, input, skip_raw):
out = self.pixel_shuffle(input)
out = self.upsampler(out) + skip_raw
out = self.lrelu(self.res_trunk(out))
out = self.bn_up(self.upsampler(out)) + skip_raw
out = self.lrelu(self.bn(self.res_trunk(out)))
return out
class FlatProcessorNet(nn.Module):