This is a simpler ResNet-based generator that performs mutations
on an input, interspersed with interpolate-upsampling. It is a
two-part generator:
1) A component that "fixes" LQ images with a long string of ResNet
blocks. This component is intended to remove compression artifacts
and other noise from an LQ image.
2) A component that can double the image size. The idea is that this
component is trained so that it works at most reasonable
resolutions, so that it can be applied repeatedly to its own output to
perform multiple upsamples.
The motivation here is to simplify what is being done inside of RRDB.
I don't believe the complexity inside of that network is justified.
92 lines
4.5 KiB
Python
92 lines
4.5 KiB
Python
import torch
|
|
import models.archs.SRResNet_arch as SRResNet_arch
|
|
import models.archs.discriminator_vgg_arch as SRGAN_arch
|
|
import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch
|
|
import models.archs.DiscriminatorResnet_arch_passthrough as DiscriminatorResnet_arch_passthrough
|
|
import models.archs.FlatProcessorNetNew_arch as FlatProcessorNetNew_arch
|
|
import models.archs.RRDBNet_arch as RRDBNet_arch
|
|
import models.archs.RRDBNetXL_arch as RRDBNetXL_arch
|
|
#import models.archs.EDVR_arch as EDVR_arch
|
|
import models.archs.HighToLowResNet as HighToLowResNet
|
|
import models.archs.FlatProcessorNet_arch as FlatProcessorNet_arch
|
|
import models.archs.arch_util as arch_utils
|
|
import models.archs.ResGen_arch as ResGen_arch
|
|
import math
|
|
|
|
# Generator
|
|
def define_G(opt):
    """Build the generator network selected by ``opt['network_G']``.

    Args:
        opt: Option mapping. ``opt['network_G']['which_model_G']`` names the
            architecture; the remaining ``opt['network_G']`` keys read depend
            on the chosen branch. ``opt['scale']`` is always read.

    Returns:
        The constructed generator network.

    Raises:
        NotImplementedError: If ``which_model_G`` is not one of the
            architecture names handled below.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']
    scale = opt['scale']

    # image restoration
    if which_model == 'MSRResNet':
        # NOTE(review): this branch reads the scale from opt_net['scale'],
        # while the RRDB branches use opt['scale'] - confirm that is intended.
        netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'], nb=opt_net['nb'],
                                       upscale=opt_net['scale'])
    elif which_model == 'RRDBNet':
        # RRDB does scaling in two steps, so take the sqrt of the scale we
        # actually want to achieve and feed it to RRDB.
        scale_per_step = math.sqrt(scale)
        netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                    nf=opt_net['nf'], nb=opt_net['nb'],
                                    interpolation_scale_factor=scale_per_step)
    elif which_model == 'RRDBNetXL':
        # Same two-step scaling convention as the RRDBNet branch.
        scale_per_step = math.sqrt(scale)
        netG = RRDBNetXL_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                      nf=opt_net['nf'], nb_lo=opt_net['nblo'],
                                      nb_med=opt_net['nbmed'], nb_hi=opt_net['nbhi'],
                                      interpolation_scale_factor=scale_per_step)
    elif which_model == 'ResGen':
        netG = ResGen_arch.fixup_resnet34(num_filters=opt_net['nf'])

    # image corruption
    elif which_model == 'HighToLowResNet':
        netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                               nf=opt_net['nf'], nb=opt_net['nb'],
                                               downscale=opt_net['scale'])
    elif which_model == 'FlatProcessorNet':
        # The older FlatProcessorNet_arch.FlatProcessorNet constructor is no
        # longer used here; this name now maps to the fixup-resnet variant.
        netG = FlatProcessorNetNew_arch.fixup_resnet34(num_filters=opt_net['nf'])

    # video restoration
    elif which_model == 'EDVR':
        # The module-level EDVR import is commented out, so the name would be
        # undefined here. Import on demand so that only configurations that
        # actually select EDVR need the module to be importable.
        import models.archs.EDVR_arch as EDVR_arch
        netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
                              groups=opt_net['groups'], front_RBs=opt_net['front_RBs'],
                              back_RBs=opt_net['back_RBs'], center=opt_net['center'],
                              predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'],
                              w_TSA=opt_net['w_TSA'])

    else:
        # '{}' rather than '{:s}': the latter raises TypeError for non-string
        # values such as None, masking the intended NotImplementedError.
        raise NotImplementedError('Generator model [{}] not recognized'.format(which_model))

    return netG
|
|
|
|
|
|
# Discriminator
|
|
def define_D(opt):
    """Build the discriminator network selected by ``opt['network_D']``.

    Args:
        opt: Option mapping. Reads ``opt['datasets']['train']['target_size']``
            and ``opt['network_D']`` (``'which_model_D'`` plus the keys used
            by the chosen branch).

    Returns:
        The constructed discriminator network.

    Raises:
        NotImplementedError: If ``which_model_D`` is not one of the
            discriminator names handled below.
    """
    target_size = opt['datasets']['train']['target_size']
    d_opts = opt['network_D']
    model_name = d_opts['which_model_D']

    if model_name == 'discriminator_vgg_128':
        # The VGG discriminator receives the target size relative to 128.
        return SRGAN_arch.Discriminator_VGG_128(
            in_nc=d_opts['in_nc'],
            nf=d_opts['nf'],
            input_img_factor=target_size / 128,
        )

    if model_name == 'discriminator_resnet':
        return DiscriminatorResnet_arch.fixup_resnet34(
            num_filters=d_opts['nf'],
            num_classes=1,
            input_img_size=target_size,
        )

    if model_name == 'discriminator_resnet_passthrough':
        return DiscriminatorResnet_arch_passthrough.fixup_resnet34(
            num_filters=d_opts['nf'],
            num_classes=1,
            input_img_size=target_size,
        )

    raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(model_name))
|
|
|
|
|
|
# Define network used for perceptual loss
|
|
def define_F(opt, use_bn=False):
    """Build the VGG feature extractor used for the perceptual loss.

    Args:
        opt: Option mapping; ``opt['gpu_ids']`` selects the device
            ('cuda' when truthy, otherwise 'cpu').
        use_bn: Passed through to the extractor; also selects which feature
            layer index is used (49 when true, 34 otherwise).

    Returns:
        The feature extractor, after ``eval()`` has been called on it.
    """
    target_device = torch.device('cuda' if opt['gpu_ids'] else 'cpu')

    # PyTorch pretrained VGG19-54, before ReLU.
    layer_index = 49 if use_bn else 34

    netF = SRGAN_arch.VGGFeatureExtractor(
        feature_layer=layer_index,
        use_bn=use_bn,
        use_input_norm=True,
        device=target_device,
    )
    # The extractor is not trained, so switch it to evaluation mode.
    netF.eval()
    return netF
|