forked from mrq/DL-Art-School

Discriminator part 1

New discriminator. Includes spectral norming.

parent 2c145c39b6
commit 5b8a77f02c

codes/models/archs/DiscriminatorResnet_arch.py (new file, 85 lines)
@@ -0,0 +1,85 @@
import torch
import torch.nn as nn
import torchvision
import models.archs.arch_util as arch_util
import functools
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as SpectralNorm


# Class that halves the image size (4x complexity reduction) and doubles the filter count. Substantial resnet
# processing is also performed.
class ResnetDownsampleLayer(nn.Module):
    def __init__(self, starting_channels: int, number_filters: int, filter_multiplier: int, residual_blocks_input: int, residual_blocks_skip_image: int, total_residual_blocks: int):
        super(ResnetDownsampleLayer, self).__init__()

        self.skip_image_reducer = SpectralNorm(nn.Conv2d(starting_channels, number_filters, 3, stride=1, padding=1, bias=True))
        self.skip_image_res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlockSpectralNorm, nf=number_filters, total_residual_blocks=total_residual_blocks), residual_blocks_skip_image)

        self.input_reducer = SpectralNorm(nn.Conv2d(number_filters, number_filters*filter_multiplier, 3, stride=2, padding=1, bias=True))
        self.res_trunk = arch_util.make_layer(functools.partial(arch_util.ResidualBlockSpectralNorm, nf=number_filters*filter_multiplier, total_residual_blocks=total_residual_blocks), residual_blocks_input)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        arch_util.initialize_weights([self.input_reducer, self.skip_image_reducer], 1)

    def forward(self, x, skip_image):
        # Process the skip image first.
        skip = self.lrelu(self.skip_image_reducer(skip_image))
        skip = self.skip_image_res_trunk(skip)

        # Merge (average) the processed skip image into the input and perform processing.
        out = (x + skip) / 2
        out = self.lrelu(self.input_reducer(out))
        out = self.res_trunk(out)
        return out


class DiscriminatorResnet(nn.Module):
    # Discriminator that downsamples 5 times with resnet blocks at each layer. On each downsample, the filter count
    # is increased by a factor of 2. Feeds the output of the convs into a dense layer to predict the logits. Scales
    # the final dense layer based on the input image size. Intended for use with input images which are multiples of 32.
    #
    # This discriminator also includes provisions to pass in images at the various downsample steps directly. When
    # this is done with a generator, it allows much shorter gradient paths between the generator and discriminator.
    # When no downsampled images are passed into the forward() pass, they will be automatically generated from the
    # source image using interpolation.
    #
    # Uses spectral normalization rather than batch normalization.
    def __init__(self, in_nc: int, nf: int, input_img_size: int, trunk_resblocks: int, skip_resblocks: int):
        super(DiscriminatorResnet, self).__init__()
        self.dimensionalize = nn.Conv2d(in_nc, nf, kernel_size=3, stride=1, padding=1, bias=True)

        # Trunk resblocks are the important things to get right, so use those. 5=number of downsample layers.
        total_resblocks = trunk_resblocks * 5
        self.downsample1 = ResnetDownsampleLayer(in_nc, nf, 2, trunk_resblocks, skip_resblocks, total_resblocks)
        self.downsample2 = ResnetDownsampleLayer(in_nc, nf*2, 2, trunk_resblocks, skip_resblocks, total_resblocks)
        self.downsample3 = ResnetDownsampleLayer(in_nc, nf*4, 2, trunk_resblocks, skip_resblocks, total_resblocks)
        # At the bottom layers, we cap the filter multiplier. We want this particular network to focus as much on the
        # macro-details at higher image dimensionality as it does on the feature details.
        self.downsample4 = ResnetDownsampleLayer(in_nc, nf*8, 1, trunk_resblocks, skip_resblocks, total_resblocks)
        self.downsample5 = ResnetDownsampleLayer(in_nc, nf*8, 1, trunk_resblocks, skip_resblocks, total_resblocks)
        self.downsamplers = [self.downsample1, self.downsample2, self.downsample3, self.downsample4, self.downsample5]

        downsampled_image_size = input_img_size / 32
        self.linear1 = nn.Linear(int(nf * 8 * downsampled_image_size * downsampled_image_size), 100)
        self.linear2 = nn.Linear(100, 1)

        # Activation function.
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        arch_util.initialize_weights([self.dimensionalize, self.linear1, self.linear2], 1)

    def forward(self, x, skip_images=None):
        if skip_images is None:
            # Synthesize them from x.
            skip_images = []
            for i in range(len(self.downsamplers)):
                m = 2 ** i
                skip_images.append(F.interpolate(x, scale_factor=1 / m, mode='bilinear', align_corners=False))

        fea = self.dimensionalize(x)
        for skip, d in zip(skip_images, self.downsamplers):
            fea = d(fea, skip)

        fea = fea.view(fea.size(0), -1)
        fea = self.lrelu(self.linear1(fea))
        out = self.linear2(fea)
        return out
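A quick smoke test of the new discriminator, for reference (not part of the commit; assumes the repo's codes/ directory is on the import path and uses the nf/resblock values from the training config further below):

import torch
import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch

netD = DiscriminatorResnet_arch.DiscriminatorResnet(in_nc=3, nf=32, input_img_size=64,
                                                    trunk_resblocks=3, skip_resblocks=2)
x = torch.randn(4, 3, 64, 64)   # 64x64 crops, matching target_size below (must be a multiple of 32)
logits = netD(x)                # skip images are synthesized internally via bilinear interpolation
print(logits.shape)             # expected: torch.Size([4, 1])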
codes/models/archs/arch_util.py

@@ -2,7 +2,16 @@ import torch
 import torch.nn as nn
 import torch.nn.init as init
 import torch.nn.functional as F
+import torch.nn.utils.spectral_norm as SpectralNorm
+from math import sqrt
+
+
+def scale_conv_weights_fixup(conv, residual_block_count, m=2):
+    k = conv.kernel_size[0]
+    n = conv.out_channels
+    scaling_factor = residual_block_count ** (-1.0 / (2 * m - 2))
+    sigma = sqrt(2 / (k * k * n)) * scaling_factor
+    conv.weight.data = conv.weight.data * sigma
+    return conv
+
+
 def initialize_weights(net_l, scale=1):
     if not isinstance(net_l, list):
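For orientation, the numbers this helper produces for the settings used elsewhere in this commit (an illustrative sketch, not part of the diff; 15 comes from trunk_resblocks=3 times 5 downsample layers):

from math import sqrt

k, n = 3, 64                      # 3x3 kernel, e.g. 64 output filters
residual_block_count, m = 15, 2   # total_residual_blocks passed in by DiscriminatorResnet
scaling_factor = residual_block_count ** (-1.0 / (2 * m - 2))   # 15**-0.5, roughly 0.26
sigma = sqrt(2 / (k * k * n)) * scaling_factor                   # roughly 0.015
# scale_conv_weights_fixup() multiplies the conv's existing weights by this sigma.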
@@ -30,6 +39,89 @@ def make_layer(block, n_layers):
         layers.append(block())
     return nn.Sequential(*layers)
 
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class FixupBasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(FixupBasicBlock, self).__init__()
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.bias1a = nn.Parameter(torch.zeros(1))
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bias1b = nn.Parameter(torch.zeros(1))
+        self.relu = nn.ReLU(inplace=True)
+        self.bias2a = nn.Parameter(torch.zeros(1))
+        self.conv2 = conv3x3(planes, planes)
+        self.scale = nn.Parameter(torch.ones(1))
+        self.bias2b = nn.Parameter(torch.zeros(1))
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x + self.bias1a)
+        out = self.relu(out + self.bias1b)
+
+        out = self.conv2(out + self.bias2a)
+        out = out * self.scale + self.bias2b
+
+        if self.downsample is not None:
+            identity = self.downsample(x + self.bias1a)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class FixupBottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(FixupBottleneck, self).__init__()
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.bias1a = nn.Parameter(torch.zeros(1))
+        self.conv1 = conv1x1(inplanes, planes)
+        self.bias1b = nn.Parameter(torch.zeros(1))
+        self.bias2a = nn.Parameter(torch.zeros(1))
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bias2b = nn.Parameter(torch.zeros(1))
+        self.bias3a = nn.Parameter(torch.zeros(1))
+        self.conv3 = conv1x1(planes, planes * self.expansion)
+        self.scale = nn.Parameter(torch.ones(1))
+        self.bias3b = nn.Parameter(torch.zeros(1))
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x + self.bias1a)
+        out = self.relu(out + self.bias1b)
+
+        out = self.conv2(out + self.bias2a)
+        out = self.relu(out + self.bias2b)
+
+        out = self.conv3(out + self.bias3a)
+        out = out * self.scale + self.bias3b
+
+        if self.downsample is not None:
+            identity = self.downsample(x + self.bias1a)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
 class ResidualBlock(nn.Module):
     '''Residual block with BN
     ---Conv-BN-ReLU-Conv-+-
@@ -38,6 +130,7 @@ class ResidualBlock(nn.Module):
 
     def __init__(self, nf=64):
         super(ResidualBlock, self).__init__()
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
         self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
         self.BN1 = nn.BatchNorm2d(nf)
         self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
@@ -48,10 +141,33 @@ class ResidualBlock(nn.Module):
 
     def forward(self, x):
         identity = x
-        out = F.relu(self.BN1(self.conv1(x)), inplace=True)
+        out = self.lrelu(self.BN1(self.conv1(x)))
         out = self.BN2(self.conv2(out))
         return identity + out
 
 
+class ResidualBlockSpectralNorm(nn.Module):
+    '''Residual block with Spectral Normalization.
+    ---SpecConv-ReLU-SpecConv-+-
+     |________________|
+    '''
+
+    def __init__(self, nf, total_residual_blocks):
+        super(ResidualBlockSpectralNorm, self).__init__()
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.conv1 = SpectralNorm(nn.Conv2d(nf, nf, 3, 1, 1, bias=True))
+        self.conv2 = SpectralNorm(nn.Conv2d(nf, nf, 3, 1, 1, bias=True))
+
+        # Initialize first.
+        initialize_weights([self.conv1, self.conv2], 1)
+        # Then perform fixup scaling
+        self.conv1 = scale_conv_weights_fixup(self.conv1, total_residual_blocks)
+        self.conv2 = scale_conv_weights_fixup(self.conv2, total_residual_blocks)
+
+    def forward(self, x):
+        identity = x
+        out = self.lrelu(self.conv1(x))
+        out = self.conv2(out)
+        return identity + out
+
+
 class ResidualBlock_noBN(nn.Module):
     '''Residual block w/o BN
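A small sketch of how these blocks get composed by make_layer in the new discriminator (not part of the commit; assumes the repo's codes/ directory is importable):

import functools
import torch
import models.archs.arch_util as arch_util

trunk = arch_util.make_layer(
    functools.partial(arch_util.ResidualBlockSpectralNorm, nf=32, total_residual_blocks=15),
    3)                                    # nn.Sequential of three spectral-norm residual blocks
y = trunk(torch.randn(1, 32, 64, 64))     # residual blocks preserve the feature map shape
print(y.shape)                            # expected: torch.Size([1, 32, 64, 64])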
@@ -61,6 +177,7 @@ class ResidualBlock_noBN(nn.Module):
 
     def __init__(self, nf=64):
         super(ResidualBlock_noBN, self).__init__()
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
         self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
         self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
@@ -69,7 +186,7 @@ class ResidualBlock_noBN(nn.Module):
 
     def forward(self, x):
         identity = x
-        out = F.relu(self.conv1(x), inplace=True)
+        out = self.lrelu(self.conv1(x))
         out = self.conv2(out)
         return identity + out
 
codes/models/networks.py

@@ -1,6 +1,7 @@
 import torch
 import models.archs.SRResNet_arch as SRResNet_arch
 import models.archs.discriminator_vgg_arch as SRGAN_arch
+import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch
 import models.archs.RRDBNet_arch as RRDBNet_arch
 import models.archs.EDVR_arch as EDVR_arch
 import models.archs.HighToLowResNet as HighToLowResNet
@@ -52,6 +53,9 @@ def define_D(opt):
 
     if which_model == 'discriminator_vgg_128':
         netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
+    elif which_model == 'discriminator_resnet':
+        netD = DiscriminatorResnet_arch.DiscriminatorResnet(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_size=img_sz,
+                                                            trunk_resblocks=opt_net['trunk_resblocks'], skip_resblocks=opt_net['skip_resblocks'])
     else:
         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
     return netD
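For reference, the option fields the new branch reads (a sketch mirroring the network_D section of the training config below; not part of the diff):

opt_net = {'which_model_D': 'discriminator_resnet', 'in_nc': 3, 'nf': 32,
           'trunk_resblocks': 3, 'skip_resblocks': 2}
# img_sz is already defined earlier in define_D (outside this hunk).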
@@ -16,7 +16,7 @@ datasets:
     dataroot_LQ: E:\\4k6k\\datasets\\ultra_lowq\\for_training
     mismatched_Data_OK: true
     use_shuffle: true
-    n_workers: 4  # per GPU
+    n_workers: 8  # per GPU
     batch_size: 32
     target_size: 64
     use_flip: false
@@ -35,19 +35,21 @@ network_G:
   in_nc: 3
   out_nc: 3
   nf: 32
-  ra_blocks: 5
-  assembler_blocks: 3
+  ra_blocks: 3
+  assembler_blocks: 2
 
 network_D:
-  which_model_D: discriminator_vgg_128
+  which_model_D: discriminator_resnet
   in_nc: 3
-  nf: 64
+  nf: 32
+  trunk_resblocks: 3
+  skip_resblocks: 2
 
 #### path
 path:
-  pretrain_model_G: ../experiments/corrupt_flatnet_G.pth
-  pretrain_model_D: ../experiments/corrupt_flatnet_D.pth
-  resume_state: ../experiments/corruptGAN_4k_lqprn_closeup_flat_net/training_state/3000.state
+  pretrain_model_G: ~
+  pretrain_model_D: ~
+  resume_state: ~
   strict_load: true
 
 #### training settings: learning rate scheme, loss
@@ -56,7 +58,7 @@ train:
   weight_decay_G: 0
   beta1_G: 0.9
   beta2_G: 0.99
-  lr_D: !!float 4e-5
+  lr_D: !!float 1e-5
   weight_decay_D: 0
   beta1_D: 0.9
   beta2_D: 0.99
@@ -71,11 +73,11 @@ train:
   pixel_weight: !!float 1e-2
   feature_criterion: l1
   feature_weight: 0
-  gan_type: ragan  # gan | ragan
+  gan_type: gan  # gan | ragan
   gan_weight: !!float 1e-1
 
   D_update_ratio: 1
-  D_init_iters: 0
+  D_init_iters: 1500
 
   manual_seed: 10
   val_freq: !!float 5e2