From 3b81712c4950903b673ee0d94a2cc992ceed28fc Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 19 Jun 2020 16:52:56 -0600
Subject: [PATCH] Remove BN from transforms

---
 .../archs/SwitchedResidualGenerator_arch.py   | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/codes/models/archs/SwitchedResidualGenerator_arch.py b/codes/models/archs/SwitchedResidualGenerator_arch.py
index 1a96a544..c35612ce 100644
--- a/codes/models/archs/SwitchedResidualGenerator_arch.py
+++ b/codes/models/archs/SwitchedResidualGenerator_arch.py
@@ -8,12 +8,15 @@ from switched_conv_util import save_attention_to_image
 
 
 class ConvBnLelu(nn.Module):
-    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True):
+    def __init__(self, filters_in, filters_out, kernel_size=3, stride=1, lelu=True, bn=True):
         super(ConvBnLelu, self).__init__()
         padding_map = {1: 0, 3: 1, 5: 2, 7: 3}
         assert kernel_size in padding_map.keys()
         self.conv = nn.Conv2d(filters_in, filters_out, kernel_size, stride, padding_map[kernel_size])
-        self.bn = nn.BatchNorm2d(filters_out)
+        if bn:
+            self.bn = nn.BatchNorm2d(filters_out)
+        else:
+            self.bn = None
         if lelu:
             self.lelu = nn.LeakyReLU(negative_slope=.1)
         else:
@@ -21,7 +24,8 @@ class ConvBnLelu(nn.Module):
 
     def forward(self, x):
         x = self.conv(x)
-        x = self.bn(x)
+        if self.bn:
+            x = self.bn(x)
         if self.lelu:
             return self.lelu(x)
         else:
@@ -32,9 +36,9 @@ class ResidualBranch(nn.Module):
     def __init__(self, filters_in, filters_mid, filters_out, kernel_size, depth):
         assert depth >= 2
         super(ResidualBranch, self).__init__()
-        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size)] +
-                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size) for i in range(depth-2)] +
-                                     [ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False)])
+        self.bnconvs = nn.ModuleList([ConvBnLelu(filters_in, filters_mid, kernel_size, bn=False)] +
+                                     [ConvBnLelu(filters_mid, filters_mid, kernel_size, bn=False) for i in range(depth-2)] +
+                                     [ConvBnLelu(filters_mid, filters_out, kernel_size, lelu=False, bn=False)])
         self.scale = nn.Parameter(torch.ones(1))
         self.bias = nn.Parameter(torch.zeros(1))
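
A minimal usage sketch (not part of the patch itself) of the ConvBnLelu block after this change, exercising the new bn flag. It assumes the script is run from the repo's codes/ directory so the module import resolves, and that torch plus the switched_conv dependency imported by this file are installed.

# Minimal sketch, assuming execution from the codes/ directory and installed deps.
import torch
from models.archs.SwitchedResidualGenerator_arch import ConvBnLelu

# With bn=False the block is just Conv2d -> LeakyReLU, which is how the
# residual-branch transforms now build it; the default bn=True keeps the
# original Conv2d -> BatchNorm2d -> LeakyReLU behavior.
block = ConvBnLelu(64, 64, kernel_size=3, bn=False)

x = torch.randn(1, 64, 32, 32)
y = block(x)
print(y.shape)  # expected torch.Size([1, 64, 32, 32]): 3x3 conv, stride 1, padding 1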