From 3781ea725c9f0c095688ad248dc76c58f185224d Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 29 Apr 2020 20:51:57 -0600
Subject: [PATCH] Add Resnet Discriminator with BN

---
 codes/data/__init__.py                        |   2 +-
 .../archs/DiscriminatorResnetBN_arch.py       | 150 ++++++++++++++++++
 .../models/archs/DiscriminatorResnet_arch.py  |  12 +-
 codes/models/networks.py                      |   5 +-
 .../train/train_GAN_blacked_corrupt.yml       |  30 ++--
 5 files changed, 174 insertions(+), 25 deletions(-)
 create mode 100644 codes/models/archs/DiscriminatorResnetBN_arch.py

diff --git a/codes/data/__init__.py b/codes/data/__init__.py
index 8887eef3..3bfb6261 100644
--- a/codes/data/__init__.py
+++ b/codes/data/__init__.py
@@ -21,7 +21,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
                                            num_workers=num_workers, sampler=sampler, drop_last=True,
                                            pin_memory=False)
     else:
-        return torch.utils.data.DataLoader(dataset, batch_size=12, shuffle=False, num_workers=3,
+        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0,
                                            pin_memory=False)

diff --git a/codes/models/archs/DiscriminatorResnetBN_arch.py b/codes/models/archs/DiscriminatorResnetBN_arch.py
new file mode 100644
index 00000000..ccd25432
--- /dev/null
+++ b/codes/models/archs/DiscriminatorResnetBN_arch.py
@@ -0,0 +1,150 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+__all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+            identity = torch.cat((identity, torch.zeros_like(identity)), 1)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, num_filters=16, num_classes=10):
+        super(ResNet, self).__init__()
+        self.num_layers = sum(layers)
+        self.inplanes = num_filters
+        self.conv1 = conv3x3(3, num_filters)
+        self.bn1 = nn.BatchNorm2d(num_filters)
+        self.relu = nn.ReLU(inplace=True)
+        self.layer1 = self._make_layer(block, num_filters, layers[0])
+        self.layer2 = self._make_layer(block, num_filters * 2, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, num_filters * 4, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, num_filters * 8, layers[2], stride=2)
+        self.fc1 = nn.Linear(num_filters * 8 * 8 * 8, 64, bias=True)
+        self.fc2 = nn.Linear(64, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        for m in self.modules():
+            if isinstance(m, BasicBlock):
+                nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1:
+            downsample = nn.Sequential(
+                nn.AvgPool2d(1, stride=stride),
+                nn.BatchNorm2d(self.inplanes),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes
+        for _ in range(1, blocks):
+            layers.append(block(planes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = x.view(x.size(0), -1)
+        x = self.relu(self.fc1(x))
+        x = self.fc2(x)
+
+        return x
+
+
+def resnet20(**kwargs):
+    """Constructs a ResNet-20 model.
+    """
+    model = ResNet(BasicBlock, [3, 3, 3], **kwargs)
+    return model
+
+
+def resnet32(**kwargs):
+    """Constructs a ResNet-32 model.
+    """
+    model = ResNet(BasicBlock, [5, 5, 5], **kwargs)
+    return model
+
+
+def resnet44(**kwargs):
+    """Constructs a ResNet-44 model.
+    """
+    model = ResNet(BasicBlock, [7, 7, 7], **kwargs)
+    return model
+
+
+def resnet56(**kwargs):
+    """Constructs a ResNet-56 model.
+    """
+    model = ResNet(BasicBlock, [9, 9, 9], **kwargs)
+    return model
+
+
+def resnet110(**kwargs):
+    """Constructs a ResNet-110 model.
+    """
+    model = ResNet(BasicBlock, [18, 18, 18], **kwargs)
+    return model
+
+
+def resnet1202(**kwargs):
+    """Constructs a ResNet-1202 model.
+    """
+    model = ResNet(BasicBlock, [200, 200, 200], **kwargs)
+    return model
\ No newline at end of file
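Note on the new architecture: fc1 hard-codes num_filters * 8 * 8 * 8 input features, so the discriminator only accepts 64x64 images (the three stride-2 stages reduce 64 -> 8 spatially while channels grow to num_filters * 8), which matches the target_size: 64 in the config below. Also, layer4 reuses layers[2] for its block count, since the three-element configs such as [5, 5, 5] have no fourth entry. A minimal smoke-test sketch, assuming it is run from the codes/ directory like the rest of the repo:

import torch
import models.archs.DiscriminatorResnetBN_arch as DiscriminatorResnetBN_arch

# resnet32 with num_filters=64, matching what define_D() constructs below:
# 5 blocks per stage, 64x64 input -> 8x8 spatial at 512 channels, so fc1
# sees 64 * 8 * 8 * 8 = 32768 features and fc2 emits one real/fake logit.
netD = DiscriminatorResnetBN_arch.resnet32(num_filters=64, num_classes=1)
netD.eval()  # use the fresh BatchNorm running stats; avoids batch-of-1 train-mode stats
with torch.no_grad():
    score = netD(torch.randn(1, 3, 64, 64))
print(score.shape)  # torch.Size([1, 1])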
diff --git a/codes/models/archs/DiscriminatorResnet_arch.py b/codes/models/archs/DiscriminatorResnet_arch.py
index b1feea64..e0438ed3 100644
--- a/codes/models/archs/DiscriminatorResnet_arch.py
+++ b/codes/models/archs/DiscriminatorResnet_arch.py
@@ -101,15 +101,15 @@ class FixupResNet(nn.Module):
         self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                                bias=False)
         self.bias1 = nn.Parameter(torch.zeros(1))
-        self.relu = nn.ReLU(inplace=True)
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
         self.layer1 = self._make_layer(block, 64, layers[0])
         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
-        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
         self.bias2 = nn.Parameter(torch.zeros(1))
-        self.fc = nn.Linear(512 * block.expansion, num_classes)
+        self.fc1 = nn.Linear(512 * 2 * 2, 100)
+        self.fc2 = nn.Linear(100, num_classes)
 
         for m in self.modules():
             if isinstance(m, FixupBasicBlock):
@@ -142,7 +142,7 @@ class FixupResNet(nn.Module):
 
     def forward(self, x):
         x = self.conv1(x)
-        x = self.relu(x + self.bias1)
+        x = self.lrelu(x + self.bias1)
         x = self.maxpool(x)
 
         x = self.layer1(x)
@@ -150,9 +150,9 @@ class FixupResNet(nn.Module):
         x = self.layer3(x)
         x = self.layer4(x)
 
-        x = self.avgpool(x)
         x = x.view(x.size(0), -1)
-        x = self.fc(x + self.bias2)
+        x = self.lrelu(self.fc1(x))
+        x = self.fc2(x + self.bias2)
 
         return x
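Dropping the adaptive average pool means this FixupResNet variant is likewise pinned to one input size: fc1 = nn.Linear(512 * 2 * 2, 100) only lines up if the feature map reaching it is 2x2. A quick sanity check of that arithmetic, a sketch assuming the 64x64 target_size from the config below:

# conv1 (stride 2), maxpool (stride 2) and layer2/3/4 (stride 2 each)
# divide the input resolution by 2**5 = 32.
input_size = 64
feature_size = input_size // 32
assert feature_size == 2                           # matches nn.Linear(512 * 2 * 2, 100)
assert 512 * feature_size * feature_size == 2048   # flattened features entering fc1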
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 75d3cf11..546f4659 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -2,10 +2,12 @@ import torch
 import models.archs.SRResNet_arch as SRResNet_arch
 import models.archs.discriminator_vgg_arch as SRGAN_arch
 import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch
+import models.archs.DiscriminatorResnetBN_arch as DiscriminatorResnetBN_arch
 import models.archs.RRDBNet_arch as RRDBNet_arch
 import models.archs.EDVR_arch as EDVR_arch
 import models.archs.HighToLowResNet as HighToLowResNet
 import models.archs.FlatProcessorNet_arch as FlatProcessorNet_arch
+import models.archs.arch_util as arch_utils
 import math
 
 # Generator
@@ -54,8 +56,7 @@ def define_D(opt):
     if which_model == 'discriminator_vgg_128':
         netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
     elif which_model == 'discriminator_resnet':
-        netD = DiscriminatorResnet_arch.DiscriminatorResnet(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_size=img_sz,
-                                                            trunk_resblocks=opt_net['trunk_resblocks'], skip_resblocks=opt_net['skip_resblocks'])
+        netD = DiscriminatorResnetBN_arch.resnet32(num_filters=opt_net['nf'], num_classes=1)
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD

diff --git a/codes/options/train/train_GAN_blacked_corrupt.yml b/codes/options/train/train_GAN_blacked_corrupt.yml
index 8f8d0c99..f2f0edd6 100644
--- a/codes/options/train/train_GAN_blacked_corrupt.yml
+++ b/codes/options/train/train_GAN_blacked_corrupt.yml
@@ -16,8 +16,8 @@ datasets:
     dataroot_LQ: E:\\4k6k\\datasets\\ultra_lowq\\for_training
     mismatched_Data_OK: true
     use_shuffle: true
-    n_workers: 8  # per GPU
-    batch_size: 32
+    n_workers: 0  # per GPU
+    batch_size: 16
     target_size: 64
     use_flip: false
     use_rot: false
@@ -34,31 +34,29 @@ network_G:
   which_model_G: FlatProcessorNet
   in_nc: 3
   out_nc: 3
-  nf: 32
-  ra_blocks: 3
-  assembler_blocks: 2
+  nf: 48
+  ra_blocks: 4
+  assembler_blocks: 3
 
 network_D:
   which_model_D: discriminator_resnet
   in_nc: 3
-  nf: 32
-  trunk_resblocks: 3
-  skip_resblocks: 2
+  nf: 64
 
 #### path
path:
   pretrain_model_G: ~
-  pretrain_model_D: ~
+  pretrain_model_D: ~ #../experiments/resnet_corrupt_discriminator_fixup.pth
   resume_state: ~
   strict_load: true
 
 #### training settings: learning rate scheme, loss
 train:
-  lr_G: !!float 1e-5
+  lr_G: !!float 1e-4
   weight_decay_G: 0
   beta1_G: 0.9
   beta2_G: 0.99
-  lr_D: !!float 1e-5
+  lr_D: !!float 1e-4
   weight_decay_D: 0
   beta1_D: 0.9
   beta2_D: 0.99
@@ -66,18 +64,18 @@ train:
 
   niter: 400000
   warmup_iter: -1  # no warm up
-  lr_steps: [4000, 8000, 12000, 15000, 20000]
+  lr_steps: [12000, 24000, 36000, 48000, 64000]
   lr_gamma: 0.5
 
-  pixel_criterion: l1
+  pixel_criterion: l2
   pixel_weight: !!float 1e-2
   feature_criterion: l1
   feature_weight: 0
-  gan_type: gan  # gan | ragan
+  gan_type: ragan  # gan | ragan
   gan_weight: !!float 1e-1
 
-  D_update_ratio: 1
-  D_init_iters: 1500
+  D_update_ratio: 2
+  D_init_iters: 1200
 
   manual_seed: 10
   val_freq: !!float 5e2
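The retuned schedule raises both learning rates to 1e-4 and spreads the decay milestones out. A sketch of the resulting decay, assuming the training loop maps lr_steps/lr_gamma onto torch's MultiStepLR (the usual convention in BasicSR-derived codebases; not confirmed by this patch):

import torch

# Stand-in parameter; in training this would be netG.parameters() / netD.parameters().
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-4, weight_decay=0, betas=(0.9, 0.99))
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[12000, 24000, 36000, 48000, 64000], gamma=0.5)
# lr stays at 1e-4 until iteration 12000, then halves at each milestone,
# ending at 1e-4 * 0.5**5 = 3.125e-6 after iteration 64000.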