Add more batch norms to FlatProcessorNet_arch
This commit is contained in:
parent
66e91a3d9e
commit
b6e036147a
|
@ -23,10 +23,12 @@ class ReduceAnnealer(nn.Module):
|
|||
self.annealer = nn.Conv2d(number_filters*4, number_filters, 3, stride=1, padding=1, bias=True)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
arch_util.initialize_weights([self.reducer, self.annealer], .1)
|
||||
self.bn_reduce = nn.BatchNorm2d(number_filters*4, affine=True)
|
||||
self.bn_anneal = nn.BatchNorm2d(number_filters*4, affine=True)
|
||||
|
||||
def forward(self, x, interpolated_trunk):
|
||||
out = self.lrelu(self.reducer(x))
|
||||
out = self.lrelu(self.res_trunk(out))
|
||||
out = self.lrelu(self.bn_reduce(self.reducer(x)))
|
||||
out = self.lrelu(self.bn_anneal(self.res_trunk(out)))
|
||||
annealed = self.lrelu(self.annealer(out)) + interpolated_trunk
|
||||
return annealed, out
|
||||
|
||||
|
@ -41,11 +43,13 @@ class Assembler(nn.Module):
|
|||
self.upsampler = nn.Conv2d(number_filters, number_filters*4, 3, stride=1, padding=1, bias=True)
|
||||
self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlock, nf=number_filters*4), residual_blocks)
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.bn = nn.BatchNorm2d(number_filters*4, affine=True)
|
||||
self.bn_up = nn.BatchNorm2d(number_filters*4, affine=True)
|
||||
|
||||
def forward(self, input, skip_raw):
|
||||
out = self.pixel_shuffle(input)
|
||||
out = self.upsampler(out) + skip_raw
|
||||
out = self.lrelu(self.res_trunk(out))
|
||||
out = self.bn_up(self.upsampler(out)) + skip_raw
|
||||
out = self.lrelu(self.bn(self.res_trunk(out)))
|
||||
return out
|
||||
|
||||
class FlatProcessorNet(nn.Module):
|
||||
|
|
Loading…
Reference in New Issue
Block a user