Implement ResGen arch

This is a simpler ResNet-based generator which applies residual
transformations to an input, interspersed with interpolation-based
upsampling. It is a two-part generator:
1) A component that "fixes" LQ images with a long chain of ResNet
    blocks. This component is intended to remove compression artifacts
    and other noise from an LQ image.
2) A component that doubles the image size. The idea is that this
    component is trained to work at most reasonable resolutions, so
    that it can be applied repeatedly to its own output to perform
    multiple 2x upsamples (see the sketch below).

The motivation here is to simplify what is being done inside of RRDB.
I don't believe the complexity inside of that network is justified.
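
A rough sketch of that composition (illustrative only, not code from this
commit; the nn.Identity modules are hypothetical stand-ins for the two
components above):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-ins for the two components described above.
fix_blocks = nn.Identity()     # part 1: removes compression artifacts/noise at input resolution
upsample_core = nn.Identity()  # part 2: repairs the artifacts left behind by 2x interpolation

def sketch_forward(lq, scale=4):
    x = fix_blocks(lq)
    while scale > 1:
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = upsample_core(x)  # the same module is reused at every resolution
        scale //= 2
    return x

print(sketch_forward(torch.zeros(1, 3, 64, 64)).shape)  # torch.Size([1, 3, 256, 256])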
James Betker 2020-05-05 11:59:46 -06:00
parent 9f4581aacb
commit 3cd85f8073
6 changed files with 243 additions and 24 deletions

models/archs/ResGen_arch.py

@@ -0,0 +1,141 @@
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

__all__ = ['FixupResNet', 'fixup_resnet34']


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def conv5x5(in_planes, out_planes, stride=1):
    """5x5 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
                     padding=2, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class FixupBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, conv_create=conv3x3):
        super(FixupBasicBlock, self).__init__()
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.bias1a = nn.Parameter(torch.zeros(1))
        self.conv1 = conv_create(inplanes, planes, stride)
        self.bias1b = nn.Parameter(torch.zeros(1))
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.bias2a = nn.Parameter(torch.zeros(1))
        self.conv2 = conv_create(planes, planes)
        self.scale = nn.Parameter(torch.ones(1))
        self.bias2b = nn.Parameter(torch.zeros(1))
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x + self.bias1a)
        out = self.lrelu(out + self.bias1b)

        out = self.conv2(out + self.bias2a)
        out = out * self.scale + self.bias2b

        if self.downsample is not None:
            identity = self.downsample(x + self.bias1a)

        out += identity
        out = self.lrelu(out)
        return out


class FixupResNet(nn.Module):

    def __init__(self, block, layers, num_filters=64):
        super(FixupResNet, self).__init__()
        self.num_layers = sum(layers) + layers[-1]  # The last layer is applied twice to achieve 4x upsampling.
        self.inplanes = num_filters
        # Part 1 - Process the raw input image. Most denoising should happen here, and this should be the most
        # complicated part of the network.
        self.conv1 = nn.Conv2d(3, num_filters, kernel_size=5, stride=1, padding=2,
                               bias=False)
        self.bias1 = nn.Parameter(torch.zeros(1))
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.layer1 = self._make_layer(block, num_filters, layers[0], stride=1)
        self.skip1 = nn.Conv2d(num_filters, 3, kernel_size=5, stride=1, padding=2, bias=False)
        self.skip1_bias = nn.Parameter(torch.zeros(1))

        # Part 2 - The upsampler core. It consists of a normal multiplicative conv followed by several residual
        # convs which are intended to repair artifacts caused by 2x interpolation. This core should by itself
        # accomplish 2x super-resolution; it is applied repeatedly to reach the requested scale.
        nf2 = int(num_filters / 4)
        # This part isn't repeated. It de-filters the output of the previous step down to the filter count used
        # by the upsampler core.
        self.upsampler_conv = nn.Conv2d(num_filters, nf2, kernel_size=3, stride=1, padding=1, bias=False)
        self.uc_bias = nn.Parameter(torch.zeros(1))
        self.inplanes = nf2
        # This is the repeated part.
        self.layer2 = self._make_layer(block, nf2, layers[1], stride=1, conv_type=conv5x5)
        self.skip2 = nn.Conv2d(nf2, 3, kernel_size=5, stride=1, padding=2, bias=False)
        self.skip2_bias = nn.Parameter(torch.zeros(1))

        self.final_defilter = nn.Conv2d(nf2, 3, kernel_size=5, stride=1, padding=2, bias=True)
        self.bias2 = nn.Parameter(torch.zeros(1))

        # Fixup initialization: conv1 is scaled down by num_layers ** -0.5 and conv2 starts at zero, so each
        # residual branch initially contributes nothing. This replaces batch normalization.
        for m in self.modules():
            if isinstance(m, FixupBasicBlock):
                nn.init.normal_(m.conv1.weight, mean=0,
                                std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5))
                nn.init.constant_(m.conv2.weight, 0)
                if m.downsample is not None:
                    nn.init.normal_(m.downsample.weight, mean=0,
                                    std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
            # elif isinstance(m, nn.Linear):
            #     nn.init.constant_(m.weight, 0)
            #     nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, conv_type=conv3x3):
        # When the filter count changes, a 1x1 conv "defilter" brings the identity path to the right width.
        defilter = None
        if self.inplanes != planes * block.expansion:
            defilter = conv1x1(self.inplanes, planes * block.expansion, stride)

        # Note: the first block always uses the default conv3x3; subsequent blocks use conv_type.
        layers = [block(self.inplanes, planes, stride, defilter)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, conv_create=conv_type))

        return nn.Sequential(*layers)

    def forward(self, x):
        # Part 1: fix/denoise at the input resolution.
        x = self.conv1(x)
        x = self.lrelu(x + self.bias1)
        x = self.layer1(x)
        skip_lo = self.skip1(x) + self.skip1_bias

        # Part 2: the upsampler core, applied twice for a fixed 4x output.
        x = self.lrelu(self.upsampler_conv(x) + self.uc_bias)
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.layer2(x)
        skip_med = self.skip2(x) + self.skip2_bias
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.layer2(x)
        x = self.final_defilter(x) + self.bias2
        return x, skip_med, skip_lo


def fixup_resnet34(**kwargs):
    """Constructs a Fixup-ResNet-34 model."""
    model = FixupResNet(FixupBasicBlock, [2, 28], **kwargs)
    return model
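
A quick smoke test for the module above (a sketch, assuming the file lands at
models/archs/ResGen_arch.py as imported below; the shapes follow from forward(),
which applies the upsampler core twice for a fixed 4x scale):

import torch
from models.archs.ResGen_arch import fixup_resnet34

netG = fixup_resnet34(num_filters=64)
lq = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    hq, skip_med, skip_lo = netG(lq)
print(hq.shape)        # torch.Size([1, 3, 256, 256]) - 4x output
print(skip_med.shape)  # torch.Size([1, 3, 128, 128]) - 2x skip output
print(skip_lo.shape)   # torch.Size([1, 3, 64, 64])   - 1x "fixed" skip output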

models/networks.py

@@ -10,6 +10,7 @@ import models.archs.RRDBNetXL_arch as RRDBNetXL_arch
 import models.archs.HighToLowResNet as HighToLowResNet
 import models.archs.FlatProcessorNet_arch as FlatProcessorNet_arch
 import models.archs.arch_util as arch_utils
+import models.archs.ResGen_arch as ResGen_arch
 import math

 # Generator
@@ -32,6 +33,9 @@ def define_G(opt):
         netG = RRDBNetXL_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'], nb_lo=opt_net['nblo'], nb_med=opt_net['nbmed'], nb_hi=opt_net['nbhi'],
                                       interpolation_scale_factor=scale_per_step)
+    elif which_model == 'ResGen':
+        netG = ResGen_arch.fixup_resnet34(num_filters=opt_net['nf'])
     # image corruption
     elif which_model == 'HighToLowResNet':
         netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
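
With the configs below setting which_model_G: ResGen and nf: 256, this branch
reduces to the following construction (nf maps to num_filters; the block counts
[2, 28] are hard-coded by fixup_resnet34):

import models.archs.ResGen_arch as ResGen_arch

netG = ResGen_arch.fixup_resnet34(num_filters=256)  # FixupResNet(FixupBasicBlock, [2, 28], num_filters=256)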

(training config, RRDBNet variant)

@@ -17,7 +17,7 @@ datasets:
     use_shuffle: true
     n_workers: 16  # per GPU
-    batch_size: 32
+    batch_size: 16
     target_size: 128
     use_flip: true
     use_rot: true
@@ -30,15 +30,11 @@ datasets:
 #### network structures
 network_G:
-  which_model_G: RRDBNet
-  in_nc: 3
-  out_nc: 3
-  nf: 64
-  nb: 23
+  which_model_G: ResGen
+  nf: 256
 network_D:
-  which_model_D: discriminator_resnet
-  in_nc: 3
-  nf: 64
+  which_model_D: discriminator_resnet_passthrough
+  nf: 42
 #### path
 path:
@@ -62,6 +58,7 @@ train:
   warmup_iter: -1  # no warm up
   lr_steps: [50000, 100000, 200000, 300000]
   lr_gamma: 0.5
+  mega_batch_factor: 1
   pixel_criterion: l1
   pixel_weight: !!float 1e-2
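
mega_batch_factor first appears in these configs; its implementation is not
part of this diff. Assuming it behaves like conventional gradient accumulation
(an assumption, not something this commit shows), a factor of N would split
each batch into N chunks and accumulate gradients before stepping:

import torch
import torch.nn as nn

# Assumption: mega_batch_factor = N accumulates gradients over N sub-batches.
# The model below is a dummy stand-in, not the generator from this commit.
model = nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
batch = torch.randn(16, 8)
mega_batch_factor = 2

optimizer.zero_grad()
for chunk in torch.chunk(batch, mega_batch_factor, dim=0):
    loss = model(chunk).pow(2).mean() / mega_batch_factor  # scale so the accumulated grad matches the full batch
    loss.backward()                                        # gradients sum across chunks
optimizer.step()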

options/train/train_ESRGAN_blacked_xl.yml

@@ -16,8 +16,8 @@ datasets:
     dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted
     doCrop: false
     use_shuffle: true
-    n_workers: 12  # per GPU
-    batch_size: 24
+    n_workers: 10  # per GPU
+    batch_size: 16
     target_size: 256
     color: RGB
   val:
@@ -28,16 +28,10 @@ datasets:
 #### network structures
 network_G:
-  which_model_G: RRDBNetXL
-  in_nc: 3
-  out_nc: 3
-  nf: 64
-  nblo: 18
-  nbmed: 8
-  nbhi: 6
+  which_model_G: ResGen
+  nf: 256
 network_D:
   which_model_D: discriminator_resnet_passthrough
-  in_nc: 3
   nf: 42
 #### path
@@ -49,11 +43,11 @@ path:
 #### training settings: learning rate scheme, loss
 train:
-  lr_G: !!float 2e-4
+  lr_G: !!float 1e-4
   weight_decay_G: 0
   beta1_G: 0.9
   beta2_G: 0.99
-  lr_D: !!float 4e-4
+  lr_D: !!float 2e-4
   weight_decay_D: 0
   beta1_D: 0.9
   beta2_D: 0.99
@@ -63,7 +57,7 @@ train:
   warmup_iter: -1  # no warm up
   lr_steps: [20000, 40000, 50000, 60000]
   lr_gamma: 0.5
-  mega_batch_factor: 3
+  mega_batch_factor: 2
   pixel_criterion: l1
   pixel_weight: !!float 1e-2

options/train/train_ESRGAN_res.yml

@@ -0,0 +1,83 @@
#### general settings
name: esrgan_res
use_tb_logger: true
model: srgan
distortion: sr
scale: 4
gpu_ids: [0]
amp_opt_level: O1

#### datasets
datasets:
  train:
    name: DIV2K
    mode: LQGT
    dataroot_GT: E:/4k6k/datasets/div2k/DIV2K800_sub
    dataroot_LQ: E:/4k6k/datasets/div2k/DIV2K800_sub_bicLRx4
    use_shuffle: true
    n_workers: 10  # per GPU
    batch_size: 24
    target_size: 128
    use_flip: true
    use_rot: true
    color: RGB
  val:
    name: div2kval
    mode: LQGT
    dataroot_GT: E:/4k6k/datasets/div2k/div2k_valid_hr
    dataroot_LQ: E:/4k6k/datasets/div2k/div2k_valid_lr_bicubic

#### network structures
network_G:
  which_model_G: ResGen
  nf: 256
network_D:
  which_model_D: discriminator_resnet_passthrough
  nf: 42

#### path
path:
  #pretrain_model_G: ../experiments/blacked_fix_and_upconv_xl_part1/models/3000_G.pth
  #pretrain_model_D: ~
  strict_load: true
  resume_state: ~

#### training settings: learning rate scheme, loss
train:
  lr_G: !!float 1e-4
  weight_decay_G: 0
  beta1_G: 0.9
  beta2_G: 0.99
  lr_D: !!float 1e-4
  weight_decay_D: 0
  beta1_D: 0.9
  beta2_D: 0.99
  lr_scheme: MultiStepLR

  niter: 400000
  warmup_iter: -1  # no warm up
  lr_steps: [20000, 40000, 50000, 60000]
  lr_gamma: 0.5
  mega_batch_factor: 2

  pixel_criterion: l1
  pixel_weight: !!float 1e-2
  feature_criterion: l1
  feature_weight: 1
  feature_weight_decay: 1
  feature_weight_decay_steps: 500
  feature_weight_minimum: 1
  gan_type: gan  # gan | ragan
  gan_weight: !!float 1e-2

  D_update_ratio: 1
  D_init_iters: -1

  manual_seed: 10
  val_freq: !!float 5e2

#### logger
logger:
  print_freq: 50
  save_checkpoint_freq: !!float 5e2
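
Given the new default in train.py (next file), this config is used when no
-opt flag is passed; it can also be given explicitly (assuming the working
directory contains train.py and options/):

python train.py -opt options/train/train_ESRGAN_res.yml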

train.py

@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_ESRGAN_blacked_xl.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_ESRGAN_res.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -147,7 +147,7 @@ def main():
         current_step = resume_state['iter']
         model.resume_training(resume_state)  # handle optimizers and schedulers
     else:
-        current_step = -1
+        current_step = 0
         start_epoch = 0

     #### training