diff --git a/codes/models/archs/ResGen_arch.py b/codes/models/archs/ResGen_arch.py
new file mode 100644
index 00000000..1d908a36
--- /dev/null
+++ b/codes/models/archs/ResGen_arch.py
@@ -0,0 +1,139 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import torch.nn.functional as F
+
+
+__all__ = ['FixupResNet', 'fixup_resnet34']
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+
+def conv5x5(in_planes, out_planes, stride=1):
+    """5x5 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
+                     padding=2, bias=False)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class FixupBasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, conv_create=conv3x3):
+        super(FixupBasicBlock, self).__init__()
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.bias1a = nn.Parameter(torch.zeros(1))
+        self.conv1 = conv_create(inplanes, planes, stride)
+        self.bias1b = nn.Parameter(torch.zeros(1))
+        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
+        self.bias2a = nn.Parameter(torch.zeros(1))
+        self.conv2 = conv_create(planes, planes)
+        self.scale = nn.Parameter(torch.ones(1))
+        self.bias2b = nn.Parameter(torch.zeros(1))
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x + self.bias1a)
+        out = self.lrelu(out + self.bias1b)
+
+        out = self.conv2(out + self.bias2a)
+        out = out * self.scale + self.bias2b
+
+        if self.downsample is not None:
+            identity = self.downsample(x + self.bias1a)
+
+        out += identity
+        out = self.lrelu(out)
+
+        return out
+
+
+class FixupResNet(nn.Module):
+
+    def __init__(self, block, layers, num_filters=64):
+        super(FixupResNet, self).__init__()
+        self.num_layers = sum(layers) + layers[-1]  # The last layer is applied twice to achieve 4x upsampling.
+        self.inplanes = num_filters
+        # Part 1 - Process the raw input image. Most denoising should happen here; this is the
+        # most complex part of the network.
+        self.conv1 = nn.Conv2d(3, num_filters, kernel_size=5, stride=1, padding=2,
+                               bias=False)
+        self.bias1 = nn.Parameter(torch.zeros(1))
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        self.layer1 = self._make_layer(block, num_filters, layers[0], stride=1)
+        self.skip1 = nn.Conv2d(num_filters, 3, kernel_size=5, stride=1, padding=2, bias=False)
+        self.skip1_bias = nn.Parameter(torch.zeros(1))
+
+        # Part 2 - The upsampler core: a plain conv followed by several residual convs intended
+        # to repair artifacts introduced by 2x nearest interpolation. Applied once, this core
+        # performs 2x super-resolution; it is applied repeatedly to reach the requested scale.
+        nf2 = num_filters // 4
+        # This part isn't repeated. It reduces the filter count of the previous step's output to
+        # the filter count used by the upsampler core.
+        self.upsampler_conv = nn.Conv2d(num_filters, nf2, kernel_size=3, stride=1, padding=1, bias=False)
+        self.uc_bias = nn.Parameter(torch.zeros(1))
+        self.inplanes = nf2
+
+        # This is the repeated part.
+        self.layer2 = self._make_layer(block, nf2, layers[1], stride=1, conv_type=conv5x5)
+        self.skip2 = nn.Conv2d(nf2, 3, kernel_size=5, stride=1, padding=2, bias=False)
+        self.skip2_bias = nn.Parameter(torch.zeros(1))
+
+        self.final_defilter = nn.Conv2d(nf2, 3, kernel_size=5, stride=1, padding=2, bias=True)
+        self.bias2 = nn.Parameter(torch.zeros(1))
+
+        # Fixup initialization: scale the first conv of each block by num_layers^(-1/2) and
+        # zero-initialize the second conv, so the network trains without batch normalization.
+        for m in self.modules():
+            if isinstance(m, FixupBasicBlock):
+                nn.init.normal_(m.conv1.weight, mean=0,
+                                std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5))
+                nn.init.constant_(m.conv2.weight, 0)
+                if m.downsample is not None:
+                    nn.init.normal_(m.downsample.weight, mean=0,
+                                    std=np.sqrt(2 / (m.downsample.weight.shape[0] * np.prod(m.downsample.weight.shape[2:]))))
+
+    def _make_layer(self, block, planes, blocks, stride=1, conv_type=conv3x3):
+        defilter = None
+        if self.inplanes != planes * block.expansion:
+            defilter = conv1x1(self.inplanes, planes * block.expansion, stride)
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, defilter))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, conv_create=conv_type))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.lrelu(x + self.bias1)
+        x = self.layer1(x)
+        skip_lo = self.skip1(x) + self.skip1_bias
+
+        x = self.lrelu(self.upsampler_conv(x) + self.uc_bias)
+        x = F.interpolate(x, scale_factor=2, mode='nearest')
+        x = self.layer2(x)
+        skip_med = self.skip2(x) + self.skip2_bias
+        x = F.interpolate(x, scale_factor=2, mode='nearest')
+        x = self.layer2(x)
+        x = self.final_defilter(x) + self.bias2
+        return x, skip_med, skip_lo
+
+
+def fixup_resnet34(**kwargs):
+    """Constructs a Fixup-ResNet-34 style generator."""
+    model = FixupResNet(FixupBasicBlock, [2, 28], **kwargs)
+    return model
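For reviewers, a quick smoke test (illustrative only, not part of this diff) confirms the three outputs and their scales. It assumes the working directory is codes/ so the models package resolves, and uses a smaller num_filters than the configs below to keep it cheap:

# Illustrative shape check for the generator above; assumes cwd is codes/.
import torch
from models.archs.ResGen_arch import fixup_resnet34

net = fixup_resnet34(num_filters=64)  # configs below use nf: 256; the core runs at num_filters / 4
lr_img = torch.randn(1, 3, 32, 32)    # fake low-resolution input
out, skip_med, skip_lo = net(lr_img)
print(out.shape)       # -> torch.Size([1, 3, 128, 128]), the 4x super-resolved output
print(skip_med.shape)  # -> torch.Size([1, 3, 64, 64]), the 2x skip after the first core pass
print(skip_lo.shape)   # -> torch.Size([1, 3, 32, 32]), the 1x skip from part 1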
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 0e66bb61..c23a69c0 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -10,6 +10,7 @@ import models.archs.RRDBNetXL_arch as RRDBNetXL_arch
 import models.archs.HighToLowResNet as HighToLowResNet
 import models.archs.FlatProcessorNet_arch as FlatProcessorNet_arch
 import models.archs.arch_util as arch_utils
+import models.archs.ResGen_arch as ResGen_arch
 import math
 
 # Generator
@@ -32,6 +33,9 @@ def define_G(opt):
         netG = RRDBNetXL_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'], nb_lo=opt_net['nblo'], nb_med=opt_net['nbmed'],
                                       nb_hi=opt_net['nbhi'], interpolation_scale_factor=scale_per_step)
+    elif which_model == 'ResGen':
+        netG = ResGen_arch.fixup_resnet34(num_filters=opt_net['nf'])
+
     # image corruption
     elif which_model == 'HighToLowResNet':
         netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
diff --git a/codes/options/train/train_ESRGAN.yml b/codes/options/train/train_ESRGAN.yml
index 39865580..8a825b2b 100644
--- a/codes/options/train/train_ESRGAN.yml
+++ b/codes/options/train/train_ESRGAN.yml
@@ -17,7 +17,7 @@ datasets:
 
     use_shuffle: true
     n_workers: 16  # per GPU
-    batch_size: 32
+    batch_size: 16
     target_size: 128
    use_flip: true
     use_rot: true
@@ -30,15 +30,11 @@ datasets:
 
 #### network structures
 network_G:
-  which_model_G: RRDBNet
-  in_nc: 3
-  out_nc: 3
-  nf: 64
-  nb: 23
+  which_model_G: ResGen
+  nf: 256
 network_D:
-  which_model_D: discriminator_resnet
-  in_nc: 3
-  nf: 64
+  which_model_D: discriminator_resnet_passthrough
+  nf: 42
 
 #### path
 path:
@@ -62,6 +58,7 @@ train:
   warmup_iter: -1  # no warm up
   lr_steps: [50000, 100000, 200000, 300000]
   lr_gamma: 0.5
+  mega_batch_factor: 1
 
   pixel_criterion: l1
   pixel_weight: !!float 1e-2
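A note on mega_batch_factor, which these configs set or adjust: its implementation lives in the model code, which this diff does not touch. The sketch below shows gradient accumulation as this option is assumed to behave, splitting each batch into chunks to trade step time for GPU memory; model, loss_fn, and optimizer are hypothetical stand-ins, not the repo's actual API:

# Hypothetical sketch of what mega_batch_factor is assumed to do (not the repo's code).
import torch

def accumulated_step(model, loss_fn, optimizer, lq, gt, mega_batch_factor=2):
    optimizer.zero_grad()
    for lq_chunk, gt_chunk in zip(torch.chunk(lq, mega_batch_factor),
                                  torch.chunk(gt, mega_batch_factor)):
        # Divide so the accumulated gradient matches a single full-batch step.
        loss = loss_fn(model(lq_chunk), gt_chunk) / mega_batch_factor
        loss.backward()
    optimizer.step()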
diff --git a/codes/options/train/train_ESRGAN_blacked_xl.yml b/codes/options/train/train_ESRGAN_blacked_xl.yml
index b4526337..39467615 100644
--- a/codes/options/train/train_ESRGAN_blacked_xl.yml
+++ b/codes/options/train/train_ESRGAN_blacked_xl.yml
@@ -16,8 +16,8 @@ datasets:
     dataroot_LQ: K:\4k6k\4k_closeup\lr_corrupted
     doCrop: false
     use_shuffle: true
-    n_workers: 12  # per GPU
-    batch_size: 24
+    n_workers: 10  # per GPU
+    batch_size: 16
     target_size: 256
     color: RGB
   val:
@@ -28,16 +28,10 @@ datasets:
 #### network structures
 network_G:
-  which_model_G: RRDBNetXL
-  in_nc: 3
-  out_nc: 3
-  nf: 64
-  nblo: 18
-  nbmed: 8
-  nbhi: 6
+  which_model_G: ResGen
+  nf: 256
 network_D:
   which_model_D: discriminator_resnet_passthrough
-  in_nc: 3
   nf: 42
 
 #### path
 path:
@@ -49,11 +43,11 @@ path:
 
 #### training settings: learning rate scheme, loss
 train:
-  lr_G: !!float 2e-4
+  lr_G: !!float 1e-4
   weight_decay_G: 0
   beta1_G: 0.9
   beta2_G: 0.99
-  lr_D: !!float 4e-4
+  lr_D: !!float 2e-4
   weight_decay_D: 0
   beta1_D: 0.9
   beta2_D: 0.99
@@ -63,7 +57,7 @@ train:
   warmup_iter: -1  # no warm up
   lr_steps: [20000, 40000, 50000, 60000]
   lr_gamma: 0.5
-  mega_batch_factor: 3
+  mega_batch_factor: 2
 
   pixel_criterion: l1
   pixel_weight: !!float 1e-2
diff --git a/codes/options/train/train_ESRGAN_res.yml b/codes/options/train/train_ESRGAN_res.yml
new file mode 100644
index 00000000..80f4249b
--- /dev/null
+++ b/codes/options/train/train_ESRGAN_res.yml
@@ -0,0 +1,83 @@
+#### general settings
+name: esrgan_res
+use_tb_logger: true
+model: srgan
+distortion: sr
+scale: 4
+gpu_ids: [0]
+amp_opt_level: O1
+
+#### datasets
+datasets:
+  train:
+    name: DIV2K
+    mode: LQGT
+    dataroot_GT: E:/4k6k/datasets/div2k/DIV2K800_sub
+    dataroot_LQ: E:/4k6k/datasets/div2k/DIV2K800_sub_bicLRx4
+
+    use_shuffle: true
+    n_workers: 10  # per GPU
+    batch_size: 24
+    target_size: 128
+    use_flip: true
+    use_rot: true
+    color: RGB
+  val:
+    name: div2kval
+    mode: LQGT
+    dataroot_GT: E:/4k6k/datasets/div2k/div2k_valid_hr
+    dataroot_LQ: E:/4k6k/datasets/div2k/div2k_valid_lr_bicubic
+
+#### network structures
+network_G:
+  which_model_G: ResGen
+  nf: 256
+network_D:
+  which_model_D: discriminator_resnet_passthrough
+  nf: 42
+
+#### path
+path:
+  #pretrain_model_G: ../experiments/blacked_fix_and_upconv_xl_part1/models/3000_G.pth
+  #pretrain_model_D: ~
+  strict_load: true
+  resume_state: ~
+
+#### training settings: learning rate scheme, loss
+train:
+  lr_G: !!float 1e-4
+  weight_decay_G: 0
+  beta1_G: 0.9
+  beta2_G: 0.99
+  lr_D: !!float 1e-4
+  weight_decay_D: 0
+  beta1_D: 0.9
+  beta2_D: 0.99
+  lr_scheme: MultiStepLR
+
+  niter: 400000
+  warmup_iter: -1  # no warm up
+  lr_steps: [20000, 40000, 50000, 60000]
+  lr_gamma: 0.5
+  mega_batch_factor: 2
+
+  pixel_criterion: l1
+  pixel_weight: !!float 1e-2
+  feature_criterion: l1
+  feature_weight: 1
+  feature_weight_decay: 1
+  feature_weight_decay_steps: 500
+  feature_weight_minimum: 1
+  gan_type: gan  # gan | ragan
+  gan_weight: !!float 1e-2
+
+  D_update_ratio: 1
+  D_init_iters: -1
+
+  manual_seed: 10
+  val_freq: !!float 5e2
+
+#### logger
+logger:
+  print_freq: 50
+  save_checkpoint_freq: !!float 5e2
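For reference, a sketch of how this config block reaches the new architecture through the define_G branch added above; opt_net stands in for the parsed network_G section:

# Sketch mirroring the networks.py hunk; opt_net represents the parsed YAML above.
import models.archs.ResGen_arch as ResGen_arch

opt_net = {'which_model_G': 'ResGen', 'nf': 256}
if opt_net['which_model_G'] == 'ResGen':
    netG = ResGen_arch.fixup_resnet34(num_filters=opt_net['nf'])
    # nf: 256 yields a 64-filter upsampler core, since the core runs at num_filters / 4.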
diff --git a/codes/train.py b/codes/train.py
index 18691a4d..ba2c2e45 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_ESRGAN_blacked_xl.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='options/train/train_ESRGAN_res.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -147,7 +147,7 @@ def main():
         current_step = resume_state['iter']
         model.resume_training(resume_state)  # handle optimizers and schedulers
     else:
-        current_step = -1
+        current_step = 0
         start_epoch = 0
 
     #### training
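With the default -opt now pointing at the new config, training can be launched from the codes/ directory with either of:

python train.py
python train.py -opt options/train/train_ESRGAN_res.yml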