From 5b8a77f02c55ae06554e588d7caa358c4bfad8d8 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 28 Apr 2020 23:00:29 -0600 Subject: [PATCH] Discriminator part 1 New discriminator. Includes spectral norming. --- .../models/archs/DiscriminatorResnet_arch.py | 85 ++++++++++++ codes/models/archs/arch_util.py | 121 +++++++++++++++++- codes/models/networks.py | 4 + .../train/train_GAN_blacked_corrupt.yml | 24 ++-- 4 files changed, 221 insertions(+), 13 deletions(-) create mode 100644 codes/models/archs/DiscriminatorResnet_arch.py diff --git a/codes/models/archs/DiscriminatorResnet_arch.py b/codes/models/archs/DiscriminatorResnet_arch.py new file mode 100644 index 00000000..e30a3ab9 --- /dev/null +++ b/codes/models/archs/DiscriminatorResnet_arch.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn +import torchvision +import models.archs.arch_util as arch_util +import functools +import torch.nn.functional as F +import torch.nn.utils.spectral_norm as SpectralNorm + +# Class that halfs the image size (x4 complexity reduction) and doubles the filter size. Substantial resnet +# processing is also performed. +class ResnetDownsampleLayer(nn.Module): + def __init__(self, starting_channels: int, number_filters: int, filter_multiplier: int, residual_blocks_input: int, residual_blocks_skip_image: int, total_residual_blocks: int): + super(ResnetDownsampleLayer, self).__init__() + + self.skip_image_reducer = SpectralNorm(nn.Conv2d(starting_channels, number_filters, 3, stride=1, padding=1, bias=True)) + self.skip_image_res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlockSpectralNorm, nf=number_filters, total_residual_blocks=total_residual_blocks), residual_blocks_skip_image) + + self.input_reducer = SpectralNorm(nn.Conv2d(number_filters, number_filters*filter_multiplier, 3, stride=2, padding=1, bias=True)) + self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlockSpectralNorm, nf=number_filters*filter_multiplier, total_residual_blocks=total_residual_blocks), residual_blocks_input) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + arch_util.initialize_weights([self.input_reducer, self.skip_image_reducer], 1) + + def forward(self, x, skip_image): + # Process the skip image first. + skip = self.lrelu(self.skip_image_reducer(skip_image)) + skip = self.skip_image_res_trunk(skip) + + # Concat the processed skip image onto the input and perform processing. + out = (x + skip) / 2 + out = self.lrelu(self.input_reducer(out)) + out = self.res_trunk(out) + return out + +class DiscriminatorResnet(nn.Module): + # Discriminator that downsamples 5 times with resnet blocks at each layer. On each downsample, the filter size is + # increased by a factor of 2. Feeds the output of the convs into a dense for prediction at the logits. Scales the + # final dense based on the input image size. Intended for use with input images which are multiples of 32. + # + # This discriminator also includes provisions to pass an image at various downsample steps in directly. When this + # is done with a generator, it will allow much shorter gradient paths between the generator and discriminator. When + # no downsampled images are passed into the forward() pass, they will be automatically generated from the source + # image using interpolation. + # + # Uses spectral normalization rather than batch normalization. + def __init__(self, in_nc: int, nf: int, input_img_size: int, trunk_resblocks: int, skip_resblocks: int): + super(DiscriminatorResnet, self).__init__() + self.dimensionalize = nn.Conv2d(in_nc, nf, kernel_size=3, stride=1, padding=1, bias=True) + + # Trunk resblocks are the important things to get right, so use those. 5=number of downsample layers. + total_resblocks = trunk_resblocks * 5 + self.downsample1 = ResnetDownsampleLayer(in_nc, nf, 2, trunk_resblocks, skip_resblocks, total_resblocks) + self.downsample2 = ResnetDownsampleLayer(in_nc, nf*2, 2, trunk_resblocks, skip_resblocks, total_resblocks) + self.downsample3 = ResnetDownsampleLayer(in_nc, nf*4, 2, trunk_resblocks, skip_resblocks, total_resblocks) + # At the bottom layers, we cap the filter multiplier. We want this particular network to focus as much on the + # macro-details at higher image dimensionality as it does to the feature details. + self.downsample4 = ResnetDownsampleLayer(in_nc, nf*8, 1, trunk_resblocks, skip_resblocks, total_resblocks) + self.downsample5 = ResnetDownsampleLayer(in_nc, nf*8, 1, trunk_resblocks, skip_resblocks, total_resblocks) + self.downsamplers = [self.downsample1, self.downsample2, self.downsample3, self.downsample4, self.downsample5] + + downsampled_image_size = input_img_size / 32 + self.linear1 = nn.Linear(int(nf * 8 * downsampled_image_size * downsampled_image_size), 100) + self.linear2 = nn.Linear(100, 1) + + # activation function + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + arch_util.initialize_weights([self.dimensionalize, self.linear1, self.linear2], 1) + + def forward(self, x, skip_images=None): + if skip_images is None: + # Sythesize them from x. + skip_images = [] + for i in range(len(self.downsamplers)): + m = 2 ** i + skip_images.append(F.interpolate(x, scale_factor=1 / m, mode='bilinear', align_corners=False)) + + fea = self.dimensionalize(x) + for skip, d in zip(skip_images, self.downsamplers): + fea = d(fea, skip) + + fea = fea.view(fea.size(0), -1) + fea = self.lrelu(self.linear1(fea)) + out = self.linear2(fea) + return out diff --git a/codes/models/archs/arch_util.py b/codes/models/archs/arch_util.py index e2b4a0b9..c33af559 100644 --- a/codes/models/archs/arch_util.py +++ b/codes/models/archs/arch_util.py @@ -2,7 +2,16 @@ import torch import torch.nn as nn import torch.nn.init as init import torch.nn.functional as F +import torch.nn.utils.spectral_norm as SpectralNorm +from math import sqrt +def scale_conv_weights_fixup(conv, residual_block_count, m=2): + k = conv.kernel_size[0] + n = conv.out_channels + scaling_factor = residual_block_count ** (-1.0 / (2 * m - 2)) + sigma = sqrt(2 / (k * k * n)) * scaling_factor + conv.weight.data = conv.weight.data * sigma + return conv def initialize_weights(net_l, scale=1): if not isinstance(net_l, list): @@ -30,6 +39,89 @@ def make_layer(block, n_layers): layers.append(block()) return nn.Sequential(*layers) +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + +class FixupBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(FixupBasicBlock, self).__init__() + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.bias1a = nn.Parameter(torch.zeros(1)) + self.conv1 = conv3x3(inplanes, planes, stride) + self.bias1b = nn.Parameter(torch.zeros(1)) + self.relu = nn.ReLU(inplace=True) + self.bias2a = nn.Parameter(torch.zeros(1)) + self.conv2 = conv3x3(planes, planes) + self.scale = nn.Parameter(torch.ones(1)) + self.bias2b = nn.Parameter(torch.zeros(1)) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x + self.bias1a) + out = self.relu(out + self.bias1b) + + out = self.conv2(out + self.bias2a) + out = out * self.scale + self.bias2b + + if self.downsample is not None: + identity = self.downsample(x + self.bias1a) + + out += identity + out = self.relu(out) + + return out + +class FixupBottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(FixupBottleneck, self).__init__() + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.bias1a = nn.Parameter(torch.zeros(1)) + self.conv1 = conv1x1(inplanes, planes) + self.bias1b = nn.Parameter(torch.zeros(1)) + self.bias2a = nn.Parameter(torch.zeros(1)) + self.conv2 = conv3x3(planes, planes, stride) + self.bias2b = nn.Parameter(torch.zeros(1)) + self.bias3a = nn.Parameter(torch.zeros(1)) + self.conv3 = conv1x1(planes, planes * self.expansion) + self.scale = nn.Parameter(torch.ones(1)) + self.bias3b = nn.Parameter(torch.zeros(1)) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x + self.bias1a) + out = self.relu(out + self.bias1b) + + out = self.conv2(out + self.bias2a) + out = self.relu(out + self.bias2b) + + out = self.conv3(out + self.bias3a) + out = out * self.scale + self.bias3b + + if self.downsample is not None: + identity = self.downsample(x + self.bias1a) + + out += identity + out = self.relu(out) + + return out + class ResidualBlock(nn.Module): '''Residual block with BN ---Conv-BN-ReLU-Conv-+- @@ -38,6 +130,7 @@ class ResidualBlock(nn.Module): def __init__(self, nf=64): super(ResidualBlock, self).__init__() + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.BN1 = nn.BatchNorm2d(nf) self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) @@ -48,10 +141,33 @@ class ResidualBlock(nn.Module): def forward(self, x): identity = x - out = F.relu(self.BN1(self.conv1(x)), inplace=True) + out = self.lrelu(self.BN1(self.conv1(x))) out = self.BN2(self.conv2(out)) return identity + out +class ResidualBlockSpectralNorm(nn.Module): + '''Residual block with Spectral Normalization. + ---SpecConv-ReLU-SpecConv-+- + |________________| + ''' + + def __init__(self, nf, total_residual_blocks): + super(ResidualBlockSpectralNorm, self).__init__() + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + self.conv1 = SpectralNorm(nn.Conv2d(nf, nf, 3, 1, 1, bias=True)) + self.conv2 = SpectralNorm(nn.Conv2d(nf, nf, 3, 1, 1, bias=True)) + + # Initialize first. + initialize_weights([self.conv1, self.conv2], 1) + # Then perform fixup scaling + self.conv1 = scale_conv_weights_fixup(self.conv1, total_residual_blocks) + self.conv2 = scale_conv_weights_fixup(self.conv2, total_residual_blocks) + + def forward(self, x): + identity = x + out = self.lrelu(self.conv1(x)) + out = self.conv2(out) + return identity + out class ResidualBlock_noBN(nn.Module): '''Residual block w/o BN @@ -61,6 +177,7 @@ class ResidualBlock_noBN(nn.Module): def __init__(self, nf=64): super(ResidualBlock_noBN, self).__init__() + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) @@ -69,7 +186,7 @@ class ResidualBlock_noBN(nn.Module): def forward(self, x): identity = x - out = F.relu(self.conv1(x), inplace=True) + out = self.lrelu(self.conv1(x)) out = self.conv2(out) return identity + out diff --git a/codes/models/networks.py b/codes/models/networks.py index dfb9ad32..75d3cf11 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -1,6 +1,7 @@ import torch import models.archs.SRResNet_arch as SRResNet_arch import models.archs.discriminator_vgg_arch as SRGAN_arch +import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch import models.archs.RRDBNet_arch as RRDBNet_arch import models.archs.EDVR_arch as EDVR_arch import models.archs.HighToLowResNet as HighToLowResNet @@ -52,6 +53,9 @@ def define_D(opt): if which_model == 'discriminator_vgg_128': netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128) + elif which_model == 'discriminator_resnet': + netD = DiscriminatorResnet_arch.DiscriminatorResnet(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_size=img_sz, + trunk_resblocks=opt_net['trunk_resblocks'], skip_resblocks=opt_net['skip_resblocks']) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD diff --git a/codes/options/train/train_GAN_blacked_corrupt.yml b/codes/options/train/train_GAN_blacked_corrupt.yml index d9257128..8f8d0c99 100644 --- a/codes/options/train/train_GAN_blacked_corrupt.yml +++ b/codes/options/train/train_GAN_blacked_corrupt.yml @@ -16,7 +16,7 @@ datasets: dataroot_LQ: E:\\4k6k\\datasets\\ultra_lowq\\for_training mismatched_Data_OK: true use_shuffle: true - n_workers: 4 # per GPU + n_workers: 8 # per GPU batch_size: 32 target_size: 64 use_flip: false @@ -35,19 +35,21 @@ network_G: in_nc: 3 out_nc: 3 nf: 32 - ra_blocks: 5 - assembler_blocks: 3 + ra_blocks: 3 + assembler_blocks: 2 network_D: - which_model_D: discriminator_vgg_128 + which_model_D: discriminator_resnet in_nc: 3 - nf: 64 + nf: 32 + trunk_resblocks: 3 + skip_resblocks: 2 #### path path: - pretrain_model_G: ../experiments/corrupt_flatnet_G.pth - pretrain_model_D: ../experiments/corrupt_flatnet_D.pth - resume_state: ../experiments/corruptGAN_4k_lqprn_closeup_flat_net/training_state/3000.state + pretrain_model_G: ~ + pretrain_model_D: ~ + resume_state: ~ strict_load: true #### training settings: learning rate scheme, loss @@ -56,7 +58,7 @@ train: weight_decay_G: 0 beta1_G: 0.9 beta2_G: 0.99 - lr_D: !!float 4e-5 + lr_D: !!float 1e-5 weight_decay_D: 0 beta1_D: 0.9 beta2_D: 0.99 @@ -71,11 +73,11 @@ train: pixel_weight: !!float 1e-2 feature_criterion: l1 feature_weight: 0 - gan_type: ragan # gan | ragan + gan_type: gan # gan | ragan gan_weight: !!float 1e-1 D_update_ratio: 1 - D_init_iters: 0 + D_init_iters: 1500 manual_seed: 10 val_freq: !!float 5e2