forked from mrq/DL-Art-School
162 lines
5.6 KiB
Python
162 lines
5.6 KiB
Python
|
import functools
|
||
|
import logging
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from torch.nn import init
|
||
|
|
||
|
import models.SPSR_modules.architecture as arch
|
||
|
# Module-level logger; logging.getLogger hands back the same 'base' logger
# object to every caller that requests that name.
logger = logging.getLogger(name='base')
|
||
|
####################
|
||
|
# initialize
|
||
|
####################
|
||
|
|
||
|
|
||
|
def weights_init_normal(m, std=0.02):
    """Initialise one module's parameters from a normal distribution.

    Meant to be passed to ``nn.Module.apply``. Modules are matched by a
    substring of their class name:

    * name contains ``'Conv'`` or ``'Linear'``: weight ~ N(0, std), bias zeroed.
    * name contains ``'BatchNorm2d'``: weight ~ N(1, std), bias set to 0.

    A matched module that has no ``weight``/``bias`` tensor is skipped for
    that tensor instead of raising ``AttributeError``. Examples are
    ``BatchNorm2d(affine=False)`` (parameters are None) and a container class
    whose name happens to contain "Conv".

    Args:
        m: module to initialise; its parameters are modified in place.
        std: standard deviation of the normal distribution.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname or 'Linear' in classname:
        if weight is not None:
            init.normal_(weight.data, 0.0, std)
        if bias is not None:
            bias.data.zero_()
    elif 'BatchNorm2d' in classname:
        if weight is not None:
            init.normal_(weight.data, 1.0, std)  # BN also uses norm
        if bias is not None:
            init.constant_(bias.data, 0.0)
|
||
|
|
||
|
|
||
|
def weights_init_kaiming(m, scale=1):
    """Initialise one module's parameters with scaled Kaiming-normal init.

    Meant to be passed to ``nn.Module.apply``. Modules are matched by a
    substring of their class name:

    * name contains ``'Conv'`` or ``'Linear'``: weight drawn with
      ``kaiming_normal_(a=0, mode='fan_in')`` and then multiplied by
      ``scale``; bias zeroed.
    * name contains ``'BatchNorm2d'``: weight set to 1, bias set to 0.

    A matched module that has no ``weight``/``bias`` tensor is skipped for
    that tensor instead of raising ``AttributeError``. Examples are
    ``BatchNorm2d(affine=False)`` and a container class whose name contains
    "Conv".

    Args:
        m: module to initialise; its parameters are modified in place.
        scale: multiplier applied to Conv/Linear weights after Kaiming init.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname or 'Linear' in classname:
        if weight is not None:
            init.kaiming_normal_(weight.data, a=0, mode='fan_in')
            weight.data *= scale
        if bias is not None:
            bias.data.zero_()
    elif 'BatchNorm2d' in classname:
        # Non-affine BN layers have weight/bias set to None: nothing to set.
        if weight is not None:
            init.constant_(weight.data, 1.0)
        if bias is not None:
            init.constant_(bias.data, 0.0)
|
||
|
|
||
|
|
||
|
def weights_init_orthogonal(m):
    """Initialise one module's parameters with orthogonal init.

    Meant to be passed to ``nn.Module.apply``. Modules are matched by a
    substring of their class name:

    * name contains ``'Conv'`` or ``'Linear'``: weight set with
      ``orthogonal_(gain=1)``; bias zeroed.
    * name contains ``'BatchNorm2d'``: weight set to 1, bias set to 0.

    A matched module that has no ``weight``/``bias`` tensor is skipped for
    that tensor instead of raising ``AttributeError``. Examples are
    ``BatchNorm2d(affine=False)`` and a container class whose name contains
    "Conv".

    Args:
        m: module to initialise; its parameters are modified in place.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname or 'Linear' in classname:
        if weight is not None:
            init.orthogonal_(weight.data, gain=1)
        if bias is not None:
            bias.data.zero_()
    elif 'BatchNorm2d' in classname:
        if weight is not None:
            init.constant_(weight.data, 1.0)
        if bias is not None:
            init.constant_(bias.data, 0.0)
|
||
|
|
||
|
|
||
|
def init_weights(net, init_type='kaiming', scale=1, std=0.02):
    """Apply a weight-initialisation scheme to every submodule of ``net``.

    Args:
        net: module whose ``apply`` is used to visit itself and all submodules.
        init_type: one of ``'normal'``, ``'kaiming'`` or ``'orthogonal'``.
        scale: multiplier for Kaiming-initialised weights (``'kaiming'`` only).
        std: standard deviation of the normal init (``'normal'`` only).

    Raises:
        NotImplementedError: for any other ``init_type``.
    """
    if init_type == 'normal':
        net.apply(functools.partial(weights_init_normal, std=std))
    elif init_type == 'kaiming':
        net.apply(functools.partial(weights_init_kaiming, scale=scale))
    elif init_type == 'orthogonal':
        net.apply(weights_init_orthogonal)
    else:
        # '{}' rather than '{:s}': a non-string init_type (e.g. None) must
        # still produce this NotImplementedError, not a TypeError from format.
        raise NotImplementedError('initialization method [{}] not implemented'.format(init_type))
|
||
|
|
||
|
|
||
|
####################
|
||
|
# define network
|
||
|
####################
|
||
|
|
||
|
|
||
|
# Generator
|
||
|
def define_G(opt, device=None):
    """Build the generator described by ``opt['network_G']``.

    Only ``which_model_G == 'spsr_net'`` is supported. It builds
    ``arch.SPSRNet`` with ``act_type='leakyrelu'`` and
    ``upsample_mode='upconv'`` fixed, and takes the other hyper-parameters
    from the config.

    Args:
        opt: options dict. Reads ``opt['network_G']`` and ``opt['is_train']``.
        device: accepted for interface compatibility; not used here.

    Returns:
        The constructed generator module. When ``opt['is_train']`` is truthy,
        it has been Kaiming-initialised with weights scaled by 0.1.

    Raises:
        NotImplementedError: if ``which_model_G`` is not ``'spsr_net'``.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    if which_model != 'spsr_net':
        # '{}' rather than '{:s}' so a non-string value (e.g. None from an
        # empty config entry) still raises NotImplementedError, not TypeError.
        raise NotImplementedError('Generator model [{}] not recognized'.format(which_model))

    netG = arch.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                        nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'],
                        norm_type=opt_net['norm_type'], act_type='leakyrelu',
                        mode=opt_net['mode'], upsample_mode='upconv')

    if opt['is_train']:
        init_weights(netG, init_type='kaiming', scale=0.1)

    return netG
|
||
|
|
||
|
|
||
|
|
||
|
# Discriminator
|
||
|
def define_D(opt):
    """Build the discriminator described by ``opt['network_D']``.

    Supported ``which_model_D`` values:

    * ``'discriminator_vgg_128'``, ``'discriminator_vgg_96'``,
      ``'discriminator_vgg_192'``: VGG-style discriminators configured from
      the ``in_nc``, ``nf``, ``norm_type``, ``mode`` and ``act_type`` keys.
    * ``'discriminator_vgg_128_SN'``: constructed with no arguments.

    The network is always Kaiming-initialised with ``scale=1``.

    Args:
        opt: options dict; only ``opt['network_D']`` is read.

    Returns:
        The constructed discriminator module.

    Raises:
        NotImplementedError: if ``which_model_D`` is not one of the above.
    """
    opt_net = opt['network_D']
    which_model = opt_net['which_model_D']

    def build_vgg(disc_cls):
        # The plain VGG variants share one constructor signature. The config
        # keys are only read once such a variant has actually been selected.
        return disc_cls(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
                        norm_type=opt_net['norm_type'], mode=opt_net['mode'],
                        act_type=opt_net['act_type'])

    if which_model == 'discriminator_vgg_128':
        netD = build_vgg(arch.Discriminator_VGG_128)
    elif which_model == 'discriminator_vgg_96':
        netD = build_vgg(arch.Discriminator_VGG_96)
    elif which_model == 'discriminator_vgg_192':
        netD = build_vgg(arch.Discriminator_VGG_192)
    elif which_model == 'discriminator_vgg_128_SN':
        netD = arch.Discriminator_VGG_128_SN()
    else:
        # '{}' rather than '{:s}' so a non-string value (e.g. None) still
        # raises NotImplementedError instead of a TypeError from str.format.
        raise NotImplementedError('Discriminator model [{}] not recognized'.format(which_model))

    init_weights(netD, init_type='kaiming', scale=1)

    return netD
|
||
|
|
||
|
def define_D_grad(opt):
    """Build the second discriminator (presumably for the gradient branch, per its name).

    It reads the same ``opt['network_D']`` section, supports the same
    ``which_model_D`` values, and applies the same Kaiming init as
    ``define_D``, so it delegates instead of duplicating that logic. Give it
    its own implementation here if the two discriminators ever need
    different configurations.

    Args:
        opt: options dict; only ``opt['network_D']`` is read.

    Returns:
        The constructed discriminator module.

    Raises:
        NotImplementedError: if ``which_model_D`` is not recognised.
    """
    return define_D(opt)
|
||
|
|
||
|
|
||
|
def define_F(opt, use_bn=False, device=None):
    """Build the pretrained VGG19 feature extractor and put it in eval mode.

    Args:
        opt: accepted for interface compatibility; not read here.
        use_bn: passed through to ``VGGFeatureExtractor`` (presumably it
            selects the batch-norm VGG19 variant). It also picks the feature
            layer: 49 with BN, 34 without.
        device: device handed to ``VGGFeatureExtractor``. When None, CUDA is
            used if available (the previous hard-coded choice) and the CPU
            otherwise, so the extractor can also be built on CPU-only hosts.

    Returns:
        The ``VGGFeatureExtractor`` module, already switched to eval mode.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # pytorch pretrained VGG19-54, before ReLU.
    feature_layer = 49 if use_bn else 34
    netF = arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
                                    use_input_norm=True, device=device)
    netF.eval()
    return netF
|