DL-Art-School/codes/models/SPSR_networks.py

162 lines
5.6 KiB
Python
Raw Normal View History

import functools
import logging
import torch
import torch.nn as nn
from torch.nn import init
import models.SPSR_modules.architecture as arch
# Module logger registered under the shared name 'base' (not referenced by the
# functions visible in this file).
logger = logging.getLogger('base')
####################
# initialize
####################
def weights_init_normal(m, std=0.02):
    """Initialise a single module from a normal distribution (use with ``net.apply``).

    Dispatch is by class name:

    * names containing ``'Conv'`` or ``'Linear'``: weight ~ N(0, std), bias zeroed
      (if the layer has one);
    * names containing ``'BatchNorm2d'``: weight ~ N(1, std), bias zeroed -- only
      for affine layers. A ``BatchNorm2d(affine=False)`` has ``weight`` and
      ``bias`` set to ``None``, so there is nothing to initialise.

    Any other module is left untouched.

    Args:
        m: the module currently visited by ``nn.Module.apply``.
        std: standard deviation of the normal draws.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        init.normal_(m.weight.data, 0.0, std)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm2d' in classname:
        # Without this guard a non-affine BN layer raises AttributeError on
        # ``None.data`` (weights_init_kaiming applies the same check).
        if m.affine:
            init.normal_(m.weight.data, 1.0, std)  # BN also uses norm
            init.constant_(m.bias.data, 0.0)
def weights_init_kaiming(m, scale=1):
    """Apply scaled Kaiming-normal initialisation to one module (for ``net.apply``).

    Conv*/Linear* modules (matched by class-name substring) get a
    ``kaiming_normal_`` weight (``a=0``, ``mode='fan_in'``) multiplied by
    ``scale``, and their bias, if present, is zeroed. Affine BatchNorm2d layers
    are reset to weight 1 and bias 0. Everything else is left alone.

    Args:
        m: the module currently visited by ``nn.Module.apply``.
        scale: multiplier applied to the Kaiming-initialised weights.
    """
    kind = type(m).__name__
    if 'Conv' in kind or 'Linear' in kind:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        m.weight.data.mul_(scale)
        if m.bias is not None:
            m.bias.data.zero_()
        return
    # Non-affine BatchNorm has weight/bias == None, hence the affine check.
    if 'BatchNorm2d' in kind and m.affine != False:
        init.constant_(m.weight.data, 1.0)
        init.constant_(m.bias.data, 0.0)
def weights_init_orthogonal(m):
    """Apply orthogonal initialisation to one module (for ``net.apply``).

    Conv*/Linear* modules (matched by class-name substring) get an orthogonal
    weight with ``gain=1`` and a zeroed bias, if present. Affine BatchNorm2d
    layers are reset to weight 1 and bias 0. A ``BatchNorm2d(affine=False)``
    has no weight/bias (both ``None``) and is skipped. Other modules are left
    untouched.

    Args:
        m: the module currently visited by ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        init.orthogonal_(m.weight.data, gain=1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm2d' in classname:
        # Guard against non-affine BN, whose weight/bias are None.
        if m.affine:
            init.constant_(m.weight.data, 1.0)
            init.constant_(m.bias.data, 0.0)
def init_weights(net, init_type='kaiming', scale=1, std=0.02):
    """Initialise every submodule of ``net`` with the chosen scheme.

    Args:
        net: module whose ``apply`` is called with the per-module initialiser.
        init_type: one of ``'normal'``, ``'kaiming'`` or ``'orthogonal'``.
        scale: weight multiplier; only used by ``'kaiming'``.
        std: standard deviation; only used by ``'normal'``.

    Raises:
        NotImplementedError: if ``init_type`` is not one of the names above.
    """
    if init_type == 'normal':
        initializer = functools.partial(weights_init_normal, std=std)
    elif init_type == 'kaiming':
        initializer = functools.partial(weights_init_kaiming, scale=scale)
    elif init_type == 'orthogonal':
        initializer = weights_init_orthogonal
    else:
        raise NotImplementedError(
            'initialization method [{:s}] not implemented'.format(init_type))
    net.apply(initializer)
####################
# define network
####################
# Generator
def define_G(opt, device=None):
    """Build the generator described by ``opt['network_G']``.

    Only ``which_model_G == 'spsr_net'`` is supported. The network always uses
    ``act_type='leakyrelu'`` and ``upsample_mode='upconv'``; the remaining
    constructor arguments come from ``opt['network_G']``. When
    ``opt['is_train']`` is truthy, the weights are re-initialised with Kaiming
    init scaled by 0.1.

    Args:
        opt: options dict with ``'network_G'`` and ``'is_train'`` entries.
        device: accepted for interface compatibility; not used here.

    Raises:
        NotImplementedError: for any other ``which_model_G``.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']
    if which_model != 'spsr_net':
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))

    netG = arch.SPSRNet(
        in_nc=opt_net['in_nc'],
        out_nc=opt_net['out_nc'],
        nf=opt_net['nf'],
        nb=opt_net['nb'],
        gc=opt_net['gc'],
        upscale=opt_net['scale'],
        norm_type=opt_net['norm_type'],
        act_type='leakyrelu',
        mode=opt_net['mode'],
        upsample_mode='upconv',
    )
    if opt['is_train']:
        init_weights(netG, init_type='kaiming', scale=0.1)
    return netG
# Discriminator
def define_D(opt):
    """Build the discriminator selected by ``opt['network_D']['which_model_D']``.

    Supported names: ``'discriminator_vgg_128'``, ``'discriminator_vgg_96'``
    and ``'discriminator_vgg_192'`` (all configured from the ``in_nc``, ``nf``,
    ``norm_type``, ``mode`` and ``act_type`` option keys), plus
    ``'discriminator_vgg_128_SN'``, which is built with no arguments. The
    result is always initialised with Kaiming init (scale 1).

    Raises:
        NotImplementedError: for an unrecognised ``which_model_D``.
    """
    opt_net = opt['network_D']
    which_model = opt_net['which_model_D']

    def vgg_kwargs():
        # Shared constructor arguments for the plain VGG-style discriminators.
        return dict(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
                    norm_type=opt_net['norm_type'], mode=opt_net['mode'],
                    act_type=opt_net['act_type'])

    if which_model == 'discriminator_vgg_128':
        netD = arch.Discriminator_VGG_128(**vgg_kwargs())
    elif which_model == 'discriminator_vgg_96':
        netD = arch.Discriminator_VGG_96(**vgg_kwargs())
    elif which_model == 'discriminator_vgg_192':
        netD = arch.Discriminator_VGG_192(**vgg_kwargs())
    elif which_model == 'discriminator_vgg_128_SN':
        netD = arch.Discriminator_VGG_128_SN()
    else:
        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))

    init_weights(netD, init_type='kaiming', scale=1)
    return netD
def define_D_grad(opt):
    """Build the discriminator for the gradient branch.

    It reads the same ``opt['network_D']`` section, supports the same
    ``which_model_D`` names and applies the same Kaiming initialisation as
    ``define_D``. It therefore delegates to ``define_D`` instead of keeping a
    second copy of the selection logic that could drift. If the gradient
    discriminator ever needs its own configuration section, split it out here.

    Raises:
        NotImplementedError: for an unrecognised ``which_model_D``.
    """
    return define_D(opt)
def define_F(opt, use_bn=False, device=None):
    """Build a VGG19 feature extractor and put it in eval mode.

    Features are taken at layer index 49 of the batch-norm VGG19 when
    ``use_bn`` is true, otherwise at index 34 (pretrained VGG19-54, before
    ReLU). Input normalisation is enabled.

    Args:
        opt: options dict; accepted for interface compatibility, not read here.
        use_bn: select the batch-norm VGG19 variant.
        device: target device (``torch.device`` or string such as ``'cpu'``).
            Defaults to CUDA, which was previously the only supported device,
            so existing callers keep their behaviour.
    """
    device = torch.device('cuda') if device is None else torch.device(device)
    # pytorch pretrained VGG19-54, before ReLU.
    feature_layer = 49 if use_bn else 34
    netF = arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
                                    use_input_norm=True, device=device)
    netF.eval()
    return netF