forked from mrq/DL-Art-School
8ab595e427
After doing some thinking and reading on the subject, it occurred to me that I was treating the generator like a discriminator by focusing the network complexity at the feature levels. It makes far more sense to process each conv level equally in the generator, hence the FlatProcessorNet in this commit. This network borrows some of the residual pass-through logic from RRDB, which makes the gradient path exceptionally short for pretty much all model parameters, and it can once again be trained in O1 optimization mode without overflows.
73 lines
3.2 KiB
Python
73 lines
3.2 KiB
Python
import torch
|
|
import models.archs.SRResNet_arch as SRResNet_arch
|
|
import models.archs.discriminator_vgg_arch as SRGAN_arch
|
|
import models.archs.RRDBNet_arch as RRDBNet_arch
|
|
import models.archs.EDVR_arch as EDVR_arch
|
|
import models.archs.HighToLowResNet as HighToLowResNet
|
|
import models.archs.FlatProcessorNet_arch as FlatProcessorNet_arch
|
|
import math
|
|
|
|
# Generator
|
|
def define_G(opt):
    """Build the generator network named by ``opt['network_G']['which_model_G']``.

    Args:
        opt: Options mapping. ``opt['network_G']`` holds the model name and the
            per-model constructor settings read below; ``opt['scale']`` is the
            overall scale factor (only used by the RRDBNet branch).

    Returns:
        The constructed generator network.

    Raises:
        NotImplementedError: If the model name is not one of the names
            handled below.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']
    scale = opt['scale']

    # image restoration
    if which_model == 'MSRResNet':
        netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
    elif which_model == 'RRDBNet':
        # RRDB does scaling in two steps, so take the sqrt of the scale we actually want to
        # achieve and feed it to RRDB.
        scale_per_step = math.sqrt(scale)
        netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                    nf=opt_net['nf'], nb=opt_net['nb'],
                                    interpolation_scale_factor=scale_per_step)
    # image corruption
    elif which_model == 'HighToLowResNet':
        netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                               nf=opt_net['nf'], nb=opt_net['nb'],
                                               downscale=opt_net['scale'])
    elif which_model == 'FlatProcessorNet':
        netG = FlatProcessorNet_arch.FlatProcessorNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                                      nf=opt_net['nf'], downscale=opt_net['scale'],
                                                      reduce_anneal_blocks=opt_net['ra_blocks'],
                                                      assembler_blocks=opt_net['assembler_blocks'])
    # video restoration
    elif which_model == 'EDVR':
        netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
                              groups=opt_net['groups'], front_RBs=opt_net['front_RBs'],
                              back_RBs=opt_net['back_RBs'], center=opt_net['center'],
                              predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'],
                              w_TSA=opt_net['w_TSA'])
    else:
        # Use '{}' rather than '{:s}': the 's' format code raises TypeError for
        # non-string values such as None, which would hide this error behind an
        # unrelated one. For string names the message is unchanged.
        raise NotImplementedError('Generator model [{}] not recognized'.format(which_model))

    return netG
|
|
|
|
|
|
# Discriminator
|
|
def define_D(opt):
    """Build the discriminator network named by ``opt['network_D']['which_model_D']``.

    Args:
        opt: Options mapping. Reads ``opt['datasets']['train']['target_size']``
            and the ``opt['network_D']`` settings (``which_model_D``, ``in_nc``,
            ``nf``).

    Returns:
        The constructed discriminator network.

    Raises:
        NotImplementedError: If the model name is not recognized.
    """
    img_sz = opt['datasets']['train']['target_size']
    opt_net = opt['network_D']
    which_model = opt_net['which_model_D']

    if which_model == 'discriminator_vgg_128':
        # The discriminator is told how large the training target is relative to 128.
        netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                input_img_factor=img_sz / 128)
    else:
        # Use '{}' rather than '{:s}': the 's' format code raises TypeError for
        # non-string values such as None, which would hide this error behind an
        # unrelated one. For string names the message is unchanged.
        raise NotImplementedError('Discriminator model [{}] not recognized'.format(which_model))
    return netD
|
|
|
|
|
|
# Define network used for perceptual loss
|
|
def define_F(opt, use_bn=False):
    """Build the feature-extraction network used for the perceptual loss.

    Args:
        opt: Options mapping; only ``opt['gpu_ids']`` is read. A truthy value
            selects the ``'cuda'`` device, otherwise ``'cpu'``.
        use_bn: Passed through to the extractor; also selects which feature
            layer index is used (49 when true, 34 otherwise).

    Returns:
        The feature extractor, after ``eval()`` has been called on it.
    """
    target_device = torch.device('cuda' if opt['gpu_ids'] else 'cpu')

    # PyTorch pretrained VGG19-54, before ReLU.
    layer_index = 49 if use_bn else 34

    extractor = SRGAN_arch.VGGFeatureExtractor(
        feature_layer=layer_index,
        use_bn=use_bn,
        use_input_norm=True,
        device=target_device,
    )
    extractor.eval()  # No need to train
    return extractor
|