forked from mrq/DL-Art-School
3cd85f8073
This is a simpler resnet-based generator which performs mutations on an input interspersed with interpolate-upsampling. It is a two part generator: 1) A component that "fixes" LQ images with a long string of resnet blocks. This component is intended to remove compression artifacts and other noise from a LQ image. 2) A component that can double the image size. The idea is that this component be trained so that it can work at most reasonable resolutions, such that it can be repeatedly applied to itself to perform multiple upsamples. The motivation here is to simplify what is being done inside of RRDB. I don't believe the complexity inside of that network is justified.
92 lines
4.5 KiB
Python
92 lines
4.5 KiB
Python
import torch
|
|
import models.archs.SRResNet_arch as SRResNet_arch
|
|
import models.archs.discriminator_vgg_arch as SRGAN_arch
|
|
import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch
|
|
import models.archs.DiscriminatorResnet_arch_passthrough as DiscriminatorResnet_arch_passthrough
|
|
import models.archs.FlatProcessorNetNew_arch as FlatProcessorNetNew_arch
|
|
import models.archs.RRDBNet_arch as RRDBNet_arch
|
|
import models.archs.RRDBNetXL_arch as RRDBNetXL_arch
|
|
#import models.archs.EDVR_arch as EDVR_arch
|
|
import models.archs.HighToLowResNet as HighToLowResNet
|
|
import models.archs.FlatProcessorNet_arch as FlatProcessorNet_arch
|
|
import models.archs.arch_util as arch_utils
|
|
import models.archs.ResGen_arch as ResGen_arch
|
|
import math
|
|
|
|
# Generator
|
|
def define_G(opt):
    """Build the generator network named by ``opt['network_G']['which_model_G']``.

    Args:
        opt: Options mapping. ``opt['network_G']`` supplies the architecture
            name and the per-architecture hyper-parameters indexed below;
            ``opt['scale']`` supplies the overall scale used by the RRDB
            variants.

    Returns:
        The constructed generator object.

    Raises:
        NotImplementedError: If the architecture name matches none of the
            branches below.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']
    scale = opt['scale']

    # image restoration
    if which_model == 'MSRResNet':
        netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
    elif which_model == 'RRDBNet':
        # RRDB does scaling in two steps, so take the sqrt of the scale we actually want to
        # achieve and feed it to RRDB.
        scale_per_step = math.sqrt(scale)
        netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                    nf=opt_net['nf'], nb=opt_net['nb'],
                                    interpolation_scale_factor=scale_per_step)
    elif which_model == 'RRDBNetXL':
        # Same two-step scaling convention as the RRDBNet branch.
        scale_per_step = math.sqrt(scale)
        netG = RRDBNetXL_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                      nf=opt_net['nf'], nb_lo=opt_net['nblo'],
                                      nb_med=opt_net['nbmed'], nb_hi=opt_net['nbhi'],
                                      interpolation_scale_factor=scale_per_step)
    elif which_model == 'ResGen':
        netG = ResGen_arch.fixup_resnet34(num_filters=opt_net['nf'])

    # image corruption
    elif which_model == 'HighToLowResNet':
        netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                               nf=opt_net['nf'], nb=opt_net['nb'],
                                               downscale=opt_net['scale'])
    elif which_model == 'FlatProcessorNet':
        # The 'FlatProcessorNet' name is served by the FlatProcessorNetNew_arch
        # implementation; the older FlatProcessorNet_arch constructor is no
        # longer called from here.
        netG = FlatProcessorNetNew_arch.fixup_resnet34(num_filters=opt_net['nf'])

    # video restoration
    elif which_model == 'EDVR':
        # Imported lazily: no module-level import binds EDVR_arch in this
        # file, so without this the branch would fail with a NameError.
        # Keeping the import local also means other models never load it.
        import models.archs.EDVR_arch as EDVR_arch
        netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
                              groups=opt_net['groups'], front_RBs=opt_net['front_RBs'],
                              back_RBs=opt_net['back_RBs'], center=opt_net['center'],
                              predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'],
                              w_TSA=opt_net['w_TSA'])

    else:
        # '{}' rather than '{:s}': the latter raises TypeError for a non-string
        # name (e.g. None), hiding the intended NotImplementedError.
        raise NotImplementedError('Generator model [{}] not recognized'.format(which_model))

    return netG
|
|
|
|
|
|
# Discriminator
|
|
def define_D(opt):
    """Build the discriminator named by ``opt['network_D']['which_model_D']``.

    Args:
        opt: Options mapping. ``opt['datasets']['train']['target_size']`` is
            read as the image size handed to the discriminator, and
            ``opt['network_D']`` supplies the architecture name and its
            hyper-parameters.

    Returns:
        The constructed discriminator object.

    Raises:
        NotImplementedError: If the architecture name matches none of the
            branches below.
    """
    img_sz = opt['datasets']['train']['target_size']
    opt_net = opt['network_D']
    which_model = opt_net['which_model_D']

    if which_model == 'discriminator_vgg_128':
        # This constructor takes the image size as a ratio relative to 128.
        netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                input_img_factor=img_sz / 128)
    elif which_model == 'discriminator_resnet':
        netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1,
                                                       input_img_size=img_sz)
    elif which_model == 'discriminator_resnet_passthrough':
        netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'],
                                                                   num_classes=1,
                                                                   input_img_size=img_sz)
    else:
        # '{}' rather than '{:s}': the latter raises TypeError for a non-string
        # name (e.g. None), hiding the intended NotImplementedError.
        raise NotImplementedError('Discriminator model [{}] not recognized'.format(which_model))
    return netD
|
|
|
|
|
|
# Define network used for perceptual loss
|
|
def define_F(opt, use_bn=False):
    """Build the VGG feature extractor used for the perceptual loss.

    Args:
        opt: Options mapping; only ``opt['gpu_ids']`` is read. A truthy value
            selects the ``'cuda'`` device, anything else selects ``'cpu'``.
        use_bn: Forwarded to the extractor and also selects the feature
            layer index (49 when true, 34 otherwise).

    Returns:
        The feature extractor, already switched to evaluation mode.
    """
    device_name = 'cuda' if opt['gpu_ids'] else 'cpu'
    target_device = torch.device(device_name)

    # PyTorch pretrained VGG19-54, before ReLU.
    layer_index = 49 if use_bn else 34

    extractor = SRGAN_arch.VGGFeatureExtractor(
        feature_layer=layer_index,
        use_bn=use_bn,
        use_input_norm=True,
        device=target_device,
    )
    # The extractor is not trained, so put it in eval mode.
    extractor.eval()
    return extractor
|