DL-Art-School/codes/models/networks.py
2020-04-22 00:37:41 -06:00

63 lines
2.5 KiB
Python

import torch
import models.archs.SRResNet_arch as SRResNet_arch
import models.archs.discriminator_vgg_arch as SRGAN_arch
import models.archs.RRDBNet_arch as RRDBNet_arch
import models.archs.EDVR_arch as EDVR_arch
import math
# Generator
def define_G(opt):
opt_net = opt['network_G']
which_model = opt_net['which_model_G']
scale = opt['scale']
# image restoration
if which_model == 'MSRResNet':
netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
elif which_model == 'RRDBNet':
# RRDB does scaling in two steps, so take the sqrt of the scale we actually want to achieve and feed it to RRDB.
scale_per_step = math.sqrt(scale)
netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
nf=opt_net['nf'], nb=opt_net['nb'], interpolation_scale_factor=scale_per_step)
# video restoration
elif which_model == 'EDVR':
netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
groups=opt_net['groups'], front_RBs=opt_net['front_RBs'],
back_RBs=opt_net['back_RBs'], center=opt_net['center'],
predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'],
w_TSA=opt_net['w_TSA'])
else:
raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
return netG
# Discriminator
def define_D(opt):
img_sz = opt['datasets']['train']['target_size']
opt_net = opt['network_D']
which_model = opt_net['which_model_D']
if which_model == 'discriminator_vgg_128':
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
else:
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD
# Define network used for perceptual loss
def define_F(opt, use_bn=False):
gpu_ids = opt['gpu_ids']
device = torch.device('cuda' if gpu_ids else 'cpu')
# PyTorch pretrained VGG19-54, before ReLU.
if use_bn:
feature_layer = 49
else:
feature_layer = 34
netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
use_input_norm=True, device=device)
netF.eval() # No need to train
return netF