James Betker 8ab595e427 Add FlatProcessorNet
After doing some thinking and reading on the subject, it occurred to me that
I was treating the generator like a discriminator by focusing the network
complexity at the feature levels. It makes far more sense to process each conv
level equally for the generator, hence the FlatProcessorNet in this commit. This
network borrows some of the residual pass-through logic from RRDB, which makes
the gradient path exceptionally short for pretty much all model parameters, and
it can be trained in O1 optimization mode without overflows again.
2020-04-28 11:49:21 -06:00

73 lines
3.2 KiB

import torch
import models.archs.SRResNet_arch as SRResNet_arch
import models.archs.discriminator_vgg_arch as SRGAN_arch
import models.archs.RRDBNet_arch as RRDBNet_arch
import models.archs.EDVR_arch as EDVR_arch
import models.archs.HighToLowResNet as HighToLowResNet
import models.archs.FlatProcessorNet_arch as FlatProcessorNet_arch
import math
# Generator
def define_G(opt):
# Build and return the generator network named by opt['network_G']['which_model_G'].
#
# Reads:
#   opt['network_G'] - per-model constructor arguments (in_nc, out_nc, nf, nb, scale, ...).
#   opt['scale']     - overall scale; only the RRDBNet branch uses this top-level value.
# Returns the constructed network object (netG).
#
# NOTE(review): this copy of the function has lost its indentation and appears to be
# missing several lines (see the notes further down). Restore it from version control
# before relying on it - as shown here it is not valid Python.
opt_net = opt['network_G']
which_model = opt_net['which_model_G']
scale = opt['scale']
# image restoration
if which_model == 'MSRResNet':
# Note: the upscale factor comes from opt_net['scale'], not the top-level `scale` read above.
netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
elif which_model == 'RRDBNet':
# RRDB does scaling in two steps, so take the sqrt of the scale we actually want to achieve and feed it to RRDB.
# math.sqrt returns a float, so the factor passed on is a float even for perfect squares.
scale_per_step = math.sqrt(scale)
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], interpolation_scale_factor=scale_per_step)
# image corruption
elif which_model == 'HighToLowResNet':
# The downscale factor comes from opt_net['scale'], not the top-level `scale`.
netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], downscale=opt_net['scale'])
elif which_model == 'FlatProcessorNet':
netG = FlatProcessorNet_arch.FlatProcessorNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], downscale=opt_net['scale'], reduce_anneal_blocks=opt_net['ra_blocks'],
# NOTE(review): the call above ends on a trailing comma and is never closed - its
# remaining argument(s) and the closing parenthesis are missing from this copy.
# TODO: restore them from the original file.
# video restoration
elif which_model == 'EDVR':
netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
groups=opt_net['groups'], front_RBs=opt_net['front_RBs'],
back_RBs=opt_net['back_RBs'], center=opt_net['center'],
predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'],
# NOTE(review): this call is also left open after a trailing comma - the rest of its
# arguments and the closing parenthesis are missing. TODO: restore them.
# NOTE(review): no `else:` precedes the raise below; judging by its message it is
# presumably the fallback for unrecognized model names - confirm and restore the `else:`.
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG
# Discriminator
def define_D(opt):
    """Build the discriminator network selected by the options dict.

    Args:
        opt: Options mapping. Reads ``opt['datasets']['train']['target_size']``
            and ``opt['network_D']`` (keys ``which_model_D``, ``in_nc``, ``nf``).

    Returns:
        The constructed discriminator network.

    Raises:
        NotImplementedError: If ``which_model_D`` is not a recognized model name.
    """
    img_sz = opt['datasets']['train']['target_size']
    opt_net = opt['network_D']
    which_model = opt_net['which_model_D']

    if which_model == 'discriminator_vgg_128':
        # The input size is passed to the network as a ratio relative to 128.
        netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                input_img_factor=img_sz / 128)
    else:
        # Only unrecognized names may reach the raise; without this `else` the
        # function would raise even for the supported model.
        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
    return netD
# Define network used for perceptual loss
def define_F(opt, use_bn=False):
    """Build the VGG feature extractor used for the perceptual loss.

    Args:
        opt: Options mapping. Reads ``opt['gpu_ids']``; a truthy value selects
            the ``'cuda'`` device, otherwise ``'cpu'``.
        use_bn: Forwarded to the feature extractor; also selects which feature
            layer index is used (49 when true, 34 otherwise).

    Returns:
        The feature extractor, already switched to eval mode.
    """
    gpu_ids = opt['gpu_ids']
    device = torch.device('cuda' if gpu_ids else 'cpu')
    # PyTorch pretrained VGG19-54, before ReLU.
    # The layer index depends on use_bn; the two assignments must be mutually
    # exclusive, otherwise the use_bn value is always overwritten with 34.
    feature_layer = 49 if use_bn else 34
    netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
                                          use_input_norm=True, device=device)
    netF.eval()  # No need to train
    return netF