diff --git a/codes/models/archs/DiscriminatorResnet_arch_passthrough.py b/codes/models/archs/DiscriminatorResnet_arch_passthrough.py
index f9e1e101..34729080 100644
--- a/codes/models/archs/DiscriminatorResnet_arch_passthrough.py
+++ b/codes/models/archs/DiscriminatorResnet_arch_passthrough.py
@@ -11,6 +11,11 @@ def conv3x3(in_planes, out_planes, stride=1):
     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                      padding=1, bias=False)
 
 
+def conv5x5(in_planes, out_planes, stride=1):
+    """5x5 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
+                     padding=2, bias=False)
+
 def conv1x1(in_planes, out_planes, stride=1):
     """1x1 convolution"""
@@ -20,27 +25,35 @@ def conv1x1(in_planes, out_planes, stride=1):
 class FixupBasicBlock(nn.Module):
     expansion = 1
 
-    def __init__(self, inplanes, planes, stride=1, downsample=None):
+    def __init__(self, inplanes, planes, stride=1, downsample=None, use_bn=False, conv_create=conv3x3):
         super(FixupBasicBlock, self).__init__()
         # Both self.conv1 and self.downsample layers downsample the input when stride != 1
         self.bias1a = nn.Parameter(torch.zeros(1))
-        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.conv1 = conv_create(inplanes, planes, stride)
         self.bias1b = nn.Parameter(torch.zeros(1))
         self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
         self.bias2a = nn.Parameter(torch.zeros(1))
-        self.conv2 = conv3x3(planes, planes)
+        self.conv2 = conv_create(planes, planes)
         self.scale = nn.Parameter(torch.ones(1))
         self.bias2b = nn.Parameter(torch.zeros(1))
         self.downsample = downsample
         self.stride = stride
+        self.use_bn = use_bn
+        if use_bn:
+            self.bn1 = nn.BatchNorm2d(planes)
+            self.bn2 = nn.BatchNorm2d(planes)
 
     def forward(self, x):
         identity = x
 
         out = self.conv1(x + self.bias1a)
+        if self.use_bn:
+            out = self.bn1(out)
         out = self.lrelu(out + self.bias1b)
 
         out = self.conv2(out + self.bias2a)
+        if self.use_bn:
+            out = self.bn2(out)
         out = out * self.scale + self.bias2b
 
         if self.downsample is not None:
@@ -94,26 +107,23 @@ class FixupBottleneck(nn.Module):
 
 class FixupResNet(nn.Module):
 
-    def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64):
+    def __init__(self, block, layers, num_filters=64, num_classes=1000, input_img_size=64, use_bn=False):
         super(FixupResNet, self).__init__()
         self.num_layers = sum(layers)
-        self.inplanes = num_filters
-        self.conv1 = nn.Conv2d(3, num_filters, kernel_size=7, stride=2, padding=3,
-                               bias=False)
-        self.bias1 = nn.Parameter(torch.zeros(1))
+        self.inplanes = 3
         self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
-        self.layer1 = self._make_layer(block, num_filters, layers[0], stride=1)
-        self.skip1 = nn.Conv2d(num_filters + 3, num_filters, kernel_size=5, stride=1, padding=2, bias=False)
-        self.skip1_bias = nn.Parameter(torch.zeros(1))
-        self.layer2 = self._make_layer(block, num_filters*2, layers[1], stride=2)
-        self.skip2 = nn.Conv2d(num_filters*2 + 3, num_filters*2, kernel_size=5, stride=1, padding=2, bias=False)
-        self.skip2_bias = nn.Parameter(torch.zeros(1))
-        self.layer3 = self._make_layer(block, num_filters*4, layers[2], stride=2)
-        self.layer4 = self._make_layer(block, num_filters*8, layers[3], stride=2)
-        self.layer5 = self._make_layer(block, num_filters*16, layers[4], stride=2)
+        self.layer0 = self._make_layer(block, num_filters*2, layers[0], stride=2, use_bn=use_bn, conv_type=conv5x5)
+        self.inplanes = self.inplanes + 3  # Accommodate a skip connection from the generator.
+        self.layer1 = self._make_layer(block, num_filters*4, layers[1], stride=2, use_bn=use_bn, conv_type=conv5x5)
+        self.inplanes = self.inplanes + 3  # Accommodate a skip connection from the generator.
+        self.layer2 = self._make_layer(block, num_filters*8, layers[2], stride=2, use_bn=use_bn)
+        # SRGAN already has a feature loss tied to a separate VGG discriminator. We really don't care about features.
+        # Therefore, level off the filter count from this block forwards.
+        self.layer3 = self._make_layer(block, num_filters*8, layers[3], stride=2, use_bn=use_bn)
+        self.layer4 = self._make_layer(block, num_filters*8, layers[4], stride=2, use_bn=use_bn)
         self.bias2 = nn.Parameter(torch.zeros(1))
         reduced_img_sz = int(input_img_size / 32)
-        self.fc1 = nn.Linear(num_filters * 16 * reduced_img_sz * reduced_img_sz, 100)
+        self.fc1 = nn.Linear(num_filters * 8 * reduced_img_sz * reduced_img_sz, 100)
         self.fc2 = nn.Linear(100, num_classes)
 
         for m in self.modules():
@@ -128,39 +138,31 @@ class FixupResNet(nn.Module):
                 nn.init.constant_(m.conv3.weight, 0)
                 if m.downsample is not None:
                     nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
-            '''
-            elif isinstance(m, nn.Linear):
-                nn.init.constant_(m.weight, 0)
-                nn.init.constant_(m.bias, 0)'''
-
-    def _make_layer(self, block, planes, blocks, stride=1):
-        downsample = None
-        if stride != 1 or self.inplanes != planes * block.expansion:
-            downsample = conv1x1(self.inplanes, planes * block.expansion, stride)
 
+    def _make_layer(self, block, outplanes, blocks, stride=1, use_bn=False, conv_type=conv3x3):
         layers = []
-        layers.append(block(self.inplanes, planes, stride, downsample))
-        self.inplanes = planes * block.expansion
         for _ in range(1, blocks):
-            layers.append(block(self.inplanes, planes))
+            layers.append(block(self.inplanes, self.inplanes))
+
+        downsample = None
+        if stride != 1 or self.inplanes != outplanes * block.expansion:
+            downsample = conv1x1(self.inplanes, outplanes * block.expansion, stride)
+        layers.append(block(self.inplanes, outplanes, stride, downsample, use_bn=use_bn, conv_create=conv_type))
+        self.inplanes = outplanes * block.expansion
 
         return nn.Sequential(*layers)
 
     def forward(self, x):
         # This class expects a medium skip (half-res) and low skip (quarter-res) provided as a tuple in the input.
-        hi, med_skip, lo_skip = x
+        x, med_skip, lo_skip = x
 
-        x = self.conv1(hi)
-        x = self.lrelu(x + self.bias1)
+        x = self.layer0(x)
+        x = torch.cat([x, med_skip], dim=1)
         x = self.layer1(x)
-        x = self.lrelu(self.skip1(torch.cat([x, med_skip], dim=1)) + self.skip1_bias)
-
+        x = torch.cat([x, lo_skip], dim=1)
         x = self.layer2(x)
-        x = self.lrelu(self.skip2(torch.cat([x, lo_skip], dim=1)) + self.skip2_bias)
-
         x = self.layer3(x)
         x = self.layer4(x)
-        x = self.layer5(x)
 
         x = x.view(x.size(0), -1)
         x = self.lrelu(self.fc1(x))
@@ -179,7 +181,7 @@ def fixup_resnet18(**kwargs):
 def fixup_resnet34(**kwargs):
     """Constructs a Fixup-ResNet-34 model.
""" - model = FixupResNet(FixupBasicBlock, [5, 4, 3, 3, 2], **kwargs) + model = FixupResNet(FixupBasicBlock, [5, 5, 3, 3, 3], **kwargs) return model diff --git a/codes/models/networks.py b/codes/models/networks.py index c23a69c0..f1f39640 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -34,7 +34,8 @@ def define_G(opt): nf=opt_net['nf'], nb_lo=opt_net['nblo'], nb_med=opt_net['nbmed'], nb_hi=opt_net['nbhi'], interpolation_scale_factor=scale_per_step) elif which_model == 'ResGen': - netG = ResGen_arch.fixup_resnet34(num_filters=opt_net['nf']) + netG = ResGen_arch.fixup_resnet34(nb_denoiser=opt_net['nb_denoiser'], nb_upsampler=opt_net['nb_upsampler'], + num_filters=opt_net['nf']) # image corruption elif which_model == 'HighToLowResNet': @@ -70,7 +71,7 @@ def define_D(opt): elif which_model == 'discriminator_resnet': netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz) elif which_model == 'discriminator_resnet_passthrough': - netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz) + netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz, use_bn=True) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD