forked from mrq/DL-Art-School
Implement ResGen arch
This is a simpler resnet-based generator which performs mutations on an input interspersed with interpolate-upsampling. It is a two part generator: 1) A component that "fixes" LQ images with a long string of resnet blocks. This component is intended to remove compression artifacts and other noise from a LQ image. 2) A component that can double the image size. The idea is that this component be trained so that it can work at most reasonable resolutions, such that it can be repeatedly applied to itself to perform multiple upsamples. The motivation here is to simplify what is being done inside of RRDB. I don't believe the complexity inside of that network is justified.
This commit is contained in:
parent
9f4581aacb
commit
3cd85f8073
141
codes/models/archs/ResGen_arch.py
Normal file
141
codes/models/archs/ResGen_arch.py
Normal file
|
@ -0,0 +1,141 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
__all__ = ['FixupResNet', 'fixup_resnet18', 'fixup_resnet34', 'fixup_resnet50', 'fixup_resnet101', 'fixup_resnet152']
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
|
||||
def conv5x5(in_planes, out_planes, stride=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
|
||||
padding=2, bias=False)
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class FixupBasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, conv_create=conv3x3):
|
||||
super(FixupBasicBlock, self).__init__()
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.bias1a = nn.Parameter(torch.zeros(1))
|
||||
self.conv1 = conv_create(inplanes, planes, stride)
|
||||
self.bias1b = nn.Parameter(torch.zeros(1))
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
||||
self.bias2a = nn.Parameter(torch.zeros(1))
|
||||
self.conv2 = conv_create(planes, planes)
|
||||
self.scale = nn.Parameter(torch.ones(1))
|
||||
self.bias2b = nn.Parameter(torch.zeros(1))
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x + self.bias1a)
|
||||
out = self.lrelu(out + self.bias1b)
|
||||
|
||||
out = self.conv2(out + self.bias2a)
|
||||
out = out * self.scale + self.bias2b
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x + self.bias1a)
|
||||
|
||||
out += identity
|
||||
out = self.lrelu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class FixupResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, num_filters=64):
|
||||
super(FixupResNet, self).__init__()
|
||||
self.num_layers = sum(layers) + layers[-1] # The last layer is applied twice to achieve 4x upsampling.
|
||||
self.inplanes = num_filters
|
||||
# Part 1 - Process raw input image. Most denoising should appear here and this should be the most complicated
|
||||
# part of the block.
|
||||
self.conv1 = nn.Conv2d(3, num_filters, kernel_size=5, stride=1, padding=2,
|
||||
bias=False)
|
||||
self.bias1 = nn.Parameter(torch.zeros(1))
|
||||
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
self.layer1 = self._make_layer(block, num_filters, layers[0], stride=1)
|
||||
self.skip1 = nn.Conv2d(num_filters, 3, kernel_size=5, stride=1, padding=2, bias=False)
|
||||
self.skip1_bias = nn.Parameter(torch.zeros(1))
|
||||
|
||||
# Part 2 - This is the upsampler core. It consists of a normal multiplicative conv followed by several residual
|
||||
# convs which are intended to repair artifacts caused by 2x interpolation.
|
||||
# This core layer should by itself accomplish 2x super-resolution. We use it in repeat to do the
|
||||
# requested SR.
|
||||
nf2 = int(num_filters/4)
|
||||
# This part isn't repeated. It de-filters the output from the previous step to fit the filter size used in the
|
||||
# upsampler-conv.
|
||||
self.upsampler_conv = nn.Conv2d(num_filters, nf2, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.uc_bias = nn.Parameter(torch.zeros(1))
|
||||
self.inplanes = nf2
|
||||
|
||||
# This is the repeated part.
|
||||
self.layer2 = self._make_layer(block, int(nf2), layers[1], stride=1, conv_type=conv5x5)
|
||||
self.skip2 = nn.Conv2d(nf2, 3, kernel_size=5, stride=1, padding=2, bias=False)
|
||||
self.skip2_bias = nn.Parameter(torch.zeros(1))
|
||||
|
||||
self.final_defilter = nn.Conv2d(nf2, 3, kernel_size=5, stride=1, padding=2, bias=True)
|
||||
self.bias2 = nn.Parameter(torch.zeros(1))
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, FixupBasicBlock):
|
||||
nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5))
|
||||
nn.init.constant_(m.conv2.weight, 0)
|
||||
if m.downsample is not None:
|
||||
nn.init.normal_(m.downsample.weight, mean=0, std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
|
||||
'''
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.constant_(m.weight, 0)
|
||||
nn.init.constant_(m.bias, 0)'''
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, conv_type=conv3x3):
|
||||
defilter = None
|
||||
if self.inplanes != planes * block.expansion:
|
||||
defilter = conv1x1(self.inplanes, planes * block.expansion, stride)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, defilter))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, conv_create=conv_type))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.lrelu(x + self.bias1)
|
||||
x = self.layer1(x)
|
||||
skip_lo = self.skip1(x) + self.skip1_bias
|
||||
|
||||
x = self.lrelu(self.upsampler_conv(x) + self.uc_bias)
|
||||
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
||||
x = self.layer2(x)
|
||||
skip_med = self.skip2(x) + self.skip2_bias
|
||||
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
||||
x = self.layer2(x)
|
||||
x = self.final_defilter(x) + self.bias2
|
||||
return x, skip_med, skip_lo
|
||||
|
||||
def fixup_resnet34(**kwargs):
|
||||
"""Constructs a Fixup-ResNet-34 model.
|
||||
"""
|
||||
model = FixupResNet(FixupBasicBlock, [2, 28], **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
__all__ = ['FixupResNet', 'fixup_resnet34']
|
|
@ -10,6 +10,7 @@ import models.archs.RRDBNetXL_arch as RRDBNetXL_arch
|
|||
import models.archs.HighToLowResNet as HighToLowResNet
|
||||
import models.archs.FlatProcessorNet_arch as FlatProcessorNet_arch
|
||||
import models.archs.arch_util as arch_utils
|
||||
import models.archs.ResGen_arch as ResGen_arch
|
||||
import math
|
||||
|
||||
# Generator
|
||||
|
@ -32,6 +33,9 @@ def define_G(opt):
|
|||
netG = RRDBNetXL_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
nf=opt_net['nf'], nb_lo=opt_net['nblo'], nb_med=opt_net['nbmed'], nb_hi=opt_net['nbhi'],
|
||||
interpolation_scale_factor=scale_per_step)
|
||||
elif which_model == 'ResGen':
|
||||
netG = ResGen_arch.fixup_resnet34(num_filters=opt_net['nf'])
|
||||
|
||||
# image corruption
|
||||
elif which_model == 'HighToLowResNet':
|
||||
netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
|
||||
|
|
|
@ -17,7 +17,7 @@ datasets:
|
|||
|
||||
use_shuffle: true
|
||||
n_workers: 16 # per GPU
|
||||
batch_size: 32
|
||||
batch_size: 16
|
||||
target_size: 128
|
||||
use_flip: true
|
||||
use_rot: true
|
||||
|
@ -30,15 +30,11 @@ datasets:
|
|||
|
||||
#### network structures
|
||||
network_G:
|
||||
which_model_G: RRDBNet
|
||||
in_nc: 3
|
||||
out_nc: 3
|
||||
nf: 64
|
||||
nb: 23
|
||||
which_model_G: ResGen
|
||||
nf: 256
|
||||
network_D:
|
||||
which_model_D: discriminator_resnet
|
||||
in_nc: 3
|
||||
nf: 64
|
||||
which_model_D: discriminator_resnet_passthrough
|
||||
nf: 42
|
||||
|
||||
#### path
|
||||
path:
|
||||
|
@ -62,6 +58,7 @@ train:
|
|||
warmup_iter: -1 # no warm up
|
||||
lr_steps: [50000, 100000, 200000, 300000]
|
||||
lr_gamma: 0.5
|
||||
mega_batch_factor: 1
|
||||
|
||||
pixel_criterion: l1
|
||||
pixel_weight: !!float 1e-2
|
||||
|
|
|
@ -16,8 +16,8 @@ datasets:
|
|||
dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted
|
||||
doCrop: false
|
||||
use_shuffle: true
|
||||
n_workers: 12 # per GPU
|
||||
batch_size: 24
|
||||
n_workers: 10 # per GPU
|
||||
batch_size: 16
|
||||
target_size: 256
|
||||
color: RGB
|
||||
val:
|
||||
|
@ -28,16 +28,10 @@ datasets:
|
|||
|
||||
#### network structures
|
||||
network_G:
|
||||
which_model_G: RRDBNetXL
|
||||
in_nc: 3
|
||||
out_nc: 3
|
||||
nf: 64
|
||||
nblo: 18
|
||||
nbmed: 8
|
||||
nbhi: 6
|
||||
which_model_G: ResGen
|
||||
nf: 256
|
||||
network_D:
|
||||
which_model_D: discriminator_resnet_passthrough
|
||||
in_nc: 3
|
||||
nf: 42
|
||||
|
||||
#### path
|
||||
|
@ -49,11 +43,11 @@ path:
|
|||
|
||||
#### training settings: learning rate scheme, loss
|
||||
train:
|
||||
lr_G: !!float 2e-4
|
||||
lr_G: !!float 1e-4
|
||||
weight_decay_G: 0
|
||||
beta1_G: 0.9
|
||||
beta2_G: 0.99
|
||||
lr_D: !!float 4e-4
|
||||
lr_D: !!float 2e-4
|
||||
weight_decay_D: 0
|
||||
beta1_D: 0.9
|
||||
beta2_D: 0.99
|
||||
|
@ -63,7 +57,7 @@ train:
|
|||
warmup_iter: -1 # no warm up
|
||||
lr_steps: [20000, 40000, 50000, 60000]
|
||||
lr_gamma: 0.5
|
||||
mega_batch_factor: 3
|
||||
mega_batch_factor: 2
|
||||
|
||||
pixel_criterion: l1
|
||||
pixel_weight: !!float 1e-2
|
||||
|
|
83
codes/options/train/train_ESRGAN_res.yml
Normal file
83
codes/options/train/train_ESRGAN_res.yml
Normal file
|
@ -0,0 +1,83 @@
|
|||
#### general settings
|
||||
name: esrgan_res
|
||||
use_tb_logger: true
|
||||
model: srgan
|
||||
distortion: sr
|
||||
scale: 4
|
||||
gpu_ids: [0]
|
||||
amp_opt_level: O1
|
||||
|
||||
#### datasets
|
||||
datasets:
|
||||
train:
|
||||
name: DIV2K
|
||||
mode: LQGT
|
||||
dataroot_GT: E:/4k6k/datasets/div2k/DIV2K800_sub
|
||||
dataroot_LQ: E:/4k6k/datasets/div2k/DIV2K800_sub_bicLRx4
|
||||
|
||||
use_shuffle: true
|
||||
n_workers: 10 # per GPU
|
||||
batch_size: 24
|
||||
target_size: 128
|
||||
use_flip: true
|
||||
use_rot: true
|
||||
color: RGB
|
||||
val:
|
||||
name: div2kval
|
||||
mode: LQGT
|
||||
dataroot_GT: E:/4k6k/datasets/div2k/div2k_valid_hr
|
||||
dataroot_LQ: E:/4k6k/datasets/div2k/div2k_valid_lr_bicubic
|
||||
|
||||
#### network structures
|
||||
network_G:
|
||||
which_model_G: ResGen
|
||||
nf: 256
|
||||
network_D:
|
||||
which_model_D: discriminator_resnet_passthrough
|
||||
nf: 42
|
||||
|
||||
#### path
|
||||
path:
|
||||
#pretrain_model_G: ../experiments/blacked_fix_and_upconv_xl_part1/models/3000_G.pth
|
||||
#pretrain_model_D: ~
|
||||
strict_load: true
|
||||
resume_state: ~
|
||||
|
||||
#### training settings: learning rate scheme, loss
|
||||
train:
|
||||
lr_G: !!float 1e-4
|
||||
weight_decay_G: 0
|
||||
beta1_G: 0.9
|
||||
beta2_G: 0.99
|
||||
lr_D: !!float 1e-4
|
||||
weight_decay_D: 0
|
||||
beta1_D: 0.9
|
||||
beta2_D: 0.99
|
||||
lr_scheme: MultiStepLR
|
||||
|
||||
niter: 400000
|
||||
warmup_iter: -1 # no warm up
|
||||
lr_steps: [20000, 40000, 50000, 60000]
|
||||
lr_gamma: 0.5
|
||||
mega_batch_factor: 2
|
||||
|
||||
pixel_criterion: l1
|
||||
pixel_weight: !!float 1e-2
|
||||
feature_criterion: l1
|
||||
feature_weight: 1
|
||||
feature_weight_decay: 1
|
||||
feature_weight_decay_steps: 500
|
||||
feature_weight_minimum: 1
|
||||
gan_type: gan # gan | ragan
|
||||
gan_weight: !!float 1e-2
|
||||
|
||||
D_update_ratio: 1
|
||||
D_init_iters: -1
|
||||
|
||||
manual_seed: 10
|
||||
val_freq: !!float 5e2
|
||||
|
||||
#### logger
|
||||
logger:
|
||||
print_freq: 50
|
||||
save_checkpoint_freq: !!float 5e2
|
|
@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
|
|||
def main():
|
||||
#### options
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_ESRGAN_blacked_xl.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_ESRGAN_res.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||
help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
|
@ -147,7 +147,7 @@ def main():
|
|||
current_step = resume_state['iter']
|
||||
model.resume_training(resume_state) # handle optimizers and schedulers
|
||||
else:
|
||||
current_step = -1
|
||||
current_step = 0
|
||||
start_epoch = 0
|
||||
|
||||
#### training
|
||||
|
|
Loading…
Reference in New Issue
Block a user