Implement ResGen arch

This is a simpler resnet-based generator which applies a series of
mutations to an input, interspersed with interpolation-based
upsampling. It is a two-part generator:
1) A component that "fixes" LQ images with a long chain of resnet
    blocks. This component is intended to remove compression artifacts
    and other noise from an LQ image.
2) A component that doubles the image size. The idea is that this
    component is trained to work at most reasonable resolutions, so
    that it can be applied repeatedly to its own output to perform
    multiple 2x upsamples (see the sketch below).

The motivation here is to simplify what is being done inside of RRDB.
I don't believe the complexity inside of that network is justified.
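
As a rough illustration of that intent (not code from this commit;
denoise, upsample_core and finalize are hypothetical stand-ins for the
two components and the final output conv), the repeated-2x idea looks
roughly like this:

import math
import torch.nn.functional as F

def super_resolve(lq, denoise, upsample_core, finalize, scale=4):
    # Part 1: strip compression artifacts / noise at the input resolution.
    feat = denoise(lq)
    # Part 2: reuse the same 2x core once per doubling; each pass interpolates
    # by 2x and then repairs the interpolation artifacts.
    for _ in range(int(math.log2(scale))):
        feat = F.interpolate(feat, scale_factor=2, mode='nearest')
        feat = upsample_core(feat)
    return finalize(feat)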
James Betker 2020-05-05 11:59:46 -06:00
parent 9f4581aacb
commit 3cd85f8073
6 changed files with 243 additions and 24 deletions


@@ -0,0 +1,141 @@
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

__all__ = ['FixupResNet', 'fixup_resnet34']


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def conv5x5(in_planes, out_planes, stride=1):
    """5x5 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
                     padding=2, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class FixupBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, conv_create=conv3x3):
        super(FixupBasicBlock, self).__init__()
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.bias1a = nn.Parameter(torch.zeros(1))
        self.conv1 = conv_create(inplanes, planes, stride)
        self.bias1b = nn.Parameter(torch.zeros(1))
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.bias2a = nn.Parameter(torch.zeros(1))
        self.conv2 = conv_create(planes, planes)
        self.scale = nn.Parameter(torch.ones(1))
        self.bias2b = nn.Parameter(torch.zeros(1))
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x + self.bias1a)
        out = self.lrelu(out + self.bias1b)

        out = self.conv2(out + self.bias2a)
        out = out * self.scale + self.bias2b

        if self.downsample is not None:
            identity = self.downsample(x + self.bias1a)

        out += identity
        out = self.lrelu(out)

        return out


class FixupResNet(nn.Module):
    def __init__(self, block, layers, num_filters=64):
        super(FixupResNet, self).__init__()
        self.num_layers = sum(layers) + layers[-1]  # The last layer is applied twice to achieve 4x upsampling.
        self.inplanes = num_filters

        # Part 1 - Process raw input image. Most denoising should appear here and this should be the most complicated
        #          part of the block.
        self.conv1 = nn.Conv2d(3, num_filters, kernel_size=5, stride=1, padding=2,
                               bias=False)
        self.bias1 = nn.Parameter(torch.zeros(1))
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.layer1 = self._make_layer(block, num_filters, layers[0], stride=1)
        self.skip1 = nn.Conv2d(num_filters, 3, kernel_size=5, stride=1, padding=2, bias=False)
        self.skip1_bias = nn.Parameter(torch.zeros(1))

        # Part 2 - This is the upsampler core. It consists of a normal multiplicative conv followed by several residual
        #          convs which are intended to repair artifacts caused by 2x interpolation.
        #          This core layer should by itself accomplish 2x super-resolution. We use it in repeat to do the
        #          requested SR.
        nf2 = int(num_filters/4)
        # This part isn't repeated. It de-filters the output from the previous step to fit the filter size used in the
        # upsampler-conv.
        self.upsampler_conv = nn.Conv2d(num_filters, nf2, kernel_size=3, stride=1, padding=1, bias=False)
        self.uc_bias = nn.Parameter(torch.zeros(1))
        self.inplanes = nf2
        # This is the repeated part.
        self.layer2 = self._make_layer(block, int(nf2), layers[1], stride=1, conv_type=conv5x5)
        self.skip2 = nn.Conv2d(nf2, 3, kernel_size=5, stride=1, padding=2, bias=False)
        self.skip2_bias = nn.Parameter(torch.zeros(1))

        self.final_defilter = nn.Conv2d(nf2, 3, kernel_size=5, stride=1, padding=2, bias=True)
        self.bias2 = nn.Parameter(torch.zeros(1))

        for m in self.modules():
            if isinstance(m, FixupBasicBlock):
                nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5))
                nn.init.constant_(m.conv2.weight, 0)
                if m.downsample is not None:
                    nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
            '''
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.weight, 0)
                nn.init.constant_(m.bias, 0)'''

    def _make_layer(self, block, planes, blocks, stride=1, conv_type=conv3x3):
        defilter = None
        if self.inplanes != planes * block.expansion:
            defilter = conv1x1(self.inplanes, planes * block.expansion, stride)

        layers = []
        layers.append(block(self.inplanes, planes, stride, defilter))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, conv_create=conv_type))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.lrelu(x + self.bias1)
        x = self.layer1(x)
        skip_lo = self.skip1(x) + self.skip1_bias

        x = self.lrelu(self.upsampler_conv(x) + self.uc_bias)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.layer2(x)
        skip_med = self.skip2(x) + self.skip2_bias

        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.layer2(x)
        x = self.final_defilter(x) + self.bias2
        return x, skip_med, skip_lo


def fixup_resnet34(**kwargs):
    """Constructs a Fixup-ResNet-34 model.
    """
    model = FixupResNet(FixupBasicBlock, [2, 28], **kwargs)
    return model
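
For reference, a minimal usage sketch (not part of the diff; the shapes
reflect my reading of the forward pass above, which returns the 4x
output together with 2x and 1x skip predictions):

import torch

net = fixup_resnet34(num_filters=64)
lq = torch.randn(1, 3, 64, 64)   # a batch of LQ RGB images
hq, skip_med, skip_lo = net(lq)
# expected: hq (1, 3, 256, 256), skip_med (1, 3, 128, 128), skip_lo (1, 3, 64, 64)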


@@ -10,6 +10,7 @@ import models.archs.RRDBNetXL_arch as RRDBNetXL_arch
import models.archs.HighToLowResNet as HighToLowResNet
import models.archs.FlatProcessorNet_arch as FlatProcessorNet_arch
import models.archs.arch_util as arch_utils
+import models.archs.ResGen_arch as ResGen_arch
import math
# Generator
@@ -32,6 +33,9 @@ def define_G(opt):
        netG = RRDBNetXL_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                      nf=opt_net['nf'], nb_lo=opt_net['nblo'], nb_med=opt_net['nbmed'], nb_hi=opt_net['nbhi'],
                                      interpolation_scale_factor=scale_per_step)
+    elif which_model == 'ResGen':
+        netG = ResGen_arch.fixup_resnet34(num_filters=opt_net['nf'])
    # image corruption
    elif which_model == 'HighToLowResNet':
        netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
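
So selecting the new generator only takes which_model_G: ResGen plus an
nf value in the network options; define_G then reduces to roughly the
following (a sketch with a hypothetical minimal opt_net dict):

import models.archs.ResGen_arch as ResGen_arch

opt_net = {'which_model_G': 'ResGen', 'nf': 256}  # hypothetical minimal generator options
netG = ResGen_arch.fixup_resnet34(num_filters=opt_net['nf'])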


@@ -17,7 +17,7 @@ datasets:
    use_shuffle: true
    n_workers: 16  # per GPU
-    batch_size: 32
+    batch_size: 16
    target_size: 128
    use_flip: true
    use_rot: true
@@ -30,15 +30,11 @@ datasets:
#### network structures
network_G:
-  which_model_G: RRDBNet
-  in_nc: 3
-  out_nc: 3
-  nf: 64
-  nb: 23
+  which_model_G: ResGen
+  nf: 256
network_D:
-  which_model_D: discriminator_resnet
-  in_nc: 3
-  nf: 64
+  which_model_D: discriminator_resnet_passthrough
+  nf: 42

#### path
path:
@@ -62,6 +58,7 @@ train:
  warmup_iter: -1  # no warm up
  lr_steps: [50000, 100000, 200000, 300000]
  lr_gamma: 0.5
+  mega_batch_factor: 1
  pixel_criterion: l1
  pixel_weight: !!float 1e-2


@@ -16,8 +16,8 @@ datasets:
    dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted
    doCrop: false
    use_shuffle: true
-    n_workers: 12  # per GPU
-    batch_size: 24
+    n_workers: 10  # per GPU
+    batch_size: 16
    target_size: 256
    color: RGB
  val:
@@ -28,16 +28,10 @@ datasets:
#### network structures
network_G:
-  which_model_G: RRDBNetXL
-  in_nc: 3
-  out_nc: 3
-  nf: 64
-  nblo: 18
-  nbmed: 8
-  nbhi: 6
+  which_model_G: ResGen
+  nf: 256
network_D:
  which_model_D: discriminator_resnet_passthrough
-  in_nc: 3
  nf: 42

#### path
@@ -49,11 +43,11 @@ path:
#### training settings: learning rate scheme, loss
train:
-  lr_G: !!float 2e-4
+  lr_G: !!float 1e-4
  weight_decay_G: 0
  beta1_G: 0.9
  beta2_G: 0.99
-  lr_D: !!float 4e-4
+  lr_D: !!float 2e-4
  weight_decay_D: 0
  beta1_D: 0.9
  beta2_D: 0.99
@@ -63,7 +57,7 @@ train:
  warmup_iter: -1  # no warm up
  lr_steps: [20000, 40000, 50000, 60000]
  lr_gamma: 0.5
-  mega_batch_factor: 3
+  mega_batch_factor: 2
  pixel_criterion: l1
  pixel_weight: !!float 1e-2


@@ -0,0 +1,83 @@
#### general settings
name: esrgan_res
use_tb_logger: true
model: srgan
distortion: sr
scale: 4
gpu_ids: [0]
amp_opt_level: O1

#### datasets
datasets:
  train:
    name: DIV2K
    mode: LQGT
    dataroot_GT: E:/4k6k/datasets/div2k/DIV2K800_sub
    dataroot_LQ: E:/4k6k/datasets/div2k/DIV2K800_sub_bicLRx4

    use_shuffle: true
    n_workers: 10  # per GPU
    batch_size: 24
    target_size: 128
    use_flip: true
    use_rot: true
    color: RGB
  val:
    name: div2kval
    mode: LQGT
    dataroot_GT: E:/4k6k/datasets/div2k/div2k_valid_hr
    dataroot_LQ: E:/4k6k/datasets/div2k/div2k_valid_lr_bicubic

#### network structures
network_G:
  which_model_G: ResGen
  nf: 256
network_D:
  which_model_D: discriminator_resnet_passthrough
  nf: 42

#### path
path:
  #pretrain_model_G: ../experiments/blacked_fix_and_upconv_xl_part1/models/3000_G.pth
  #pretrain_model_D: ~
  strict_load: true
  resume_state: ~

#### training settings: learning rate scheme, loss
train:
  lr_G: !!float 1e-4
  weight_decay_G: 0
  beta1_G: 0.9
  beta2_G: 0.99
  lr_D: !!float 1e-4
  weight_decay_D: 0
  beta1_D: 0.9
  beta2_D: 0.99

  lr_scheme: MultiStepLR
  niter: 400000
  warmup_iter: -1  # no warm up
  lr_steps: [20000, 40000, 50000, 60000]
  lr_gamma: 0.5
  mega_batch_factor: 2

  pixel_criterion: l1
  pixel_weight: !!float 1e-2
  feature_criterion: l1
  feature_weight: 1
  feature_weight_decay: 1
  feature_weight_decay_steps: 500
  feature_weight_minimum: 1
  gan_type: gan  # gan | ragan
  gan_weight: !!float 1e-2

  D_update_ratio: 1
  D_init_iters: -1

  manual_seed: 10
  val_freq: !!float 5e2

#### logger
logger:
  print_freq: 50
  save_checkpoint_freq: !!float 5e2


@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
    #### options
    parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_ESRGAN_blacked_xl.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_ESRGAN_res.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
@@ -147,7 +147,7 @@ def main():
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
-        current_step = -1
+        current_step = 0
        start_epoch = 0

    #### training