import functools
import logging

import torch
import torch.nn as nn
from torch.nn import init

import models.SPSR_modules.architecture as arch

logger = logging.getLogger('base')


####################
# initialize
####################


def _has_tensor(m, name):
    """Return True if module ``m`` has a non-None attribute called ``name``.

    The initializers below select modules by substring of their class name, so
    they can also match modules that do not own the tensor they want to touch:
    a ``BatchNorm2d`` built with ``affine=False`` (``weight``/``bias`` are None),
    a conv layer built with ``bias=False``, or any class that merely has "Conv"
    in its name. Such tensors are skipped instead of raising AttributeError.
    """
    return getattr(m, name, None) is not None


def _apply_init(m, conv_linear_weight_init, bn_weight_init):
    """Initialize a single module in place, dispatching on its class name.

    Args:
        m: module visited by ``nn.Module.apply``.
        conv_linear_weight_init: callable applied in place to the weight tensor
            of a module whose class name contains 'Conv' or 'Linear'; the bias
            of such a module is zeroed.
        bn_weight_init: callable applied in place to the weight tensor of a
            module whose class name contains 'BatchNorm2d'; its bias is set to 0.

    Modules whose class name matches none of those substrings are untouched.
    """
    classname = m.__class__.__name__
    # Substring match: also covers Conv1d/Conv2d/ConvTranspose2d, etc.
    if 'Conv' in classname or 'Linear' in classname:
        if _has_tensor(m, 'weight'):
            conv_linear_weight_init(m.weight.data)
        if _has_tensor(m, 'bias'):
            m.bias.data.zero_()
    elif 'BatchNorm2d' in classname:
        if _has_tensor(m, 'weight'):
            bn_weight_init(m.weight.data)
        if _has_tensor(m, 'bias'):
            init.constant_(m.bias.data, 0.0)


def weights_init_normal(m, std=0.02):
    """Normal init for use with ``nn.Module.apply``.

    Conv/Linear weights are drawn from N(0, std) and BatchNorm2d weights from
    N(1, std); all matched biases are set to 0.

    Args:
        m: module to initialize in place.
        std: standard deviation of both normal distributions.
    """
    _apply_init(
        m,
        lambda w: init.normal_(w, 0.0, std),
        # BN also uses a normal init, but centred on 1.0.
        lambda w: init.normal_(w, 1.0, std),
    )


def weights_init_kaiming(m, scale=1):
    """Kaiming-normal init for use with ``nn.Module.apply``.

    Conv/Linear weights get ``kaiming_normal_(a=0, mode='fan_in')`` and are then
    multiplied by ``scale``; BatchNorm2d weights are set to 1. All matched
    biases are set to 0.

    Args:
        m: module to initialize in place.
        scale: multiplier applied to Conv/Linear weights after Kaiming init.
    """
    def _kaiming_scaled(w):
        init.kaiming_normal_(w, a=0, mode='fan_in')
        w.mul_(scale)

    _apply_init(m, _kaiming_scaled, lambda w: init.constant_(w, 1.0))


def weights_init_orthogonal(m):
    """Orthogonal init (gain=1) for use with ``nn.Module.apply``.

    Conv/Linear weights are orthogonally initialized; BatchNorm2d weights are
    set to 1. All matched biases are set to 0.

    Args:
        m: module to initialize in place.
    """
    _apply_init(
        m,
        lambda w: init.orthogonal_(w, gain=1),
        lambda w: init.constant_(w, 1.0),
    )


def init_weights(net, init_type='kaiming', scale=1, std=0.02):
    """Initialize every submodule of ``net`` in place.

    Args:
        net: network whose modules are visited via ``net.apply``.
        init_type: one of 'normal', 'kaiming' or 'orthogonal'.
        scale: weight multiplier, used only by 'kaiming'.
        std: standard deviation, used only by 'normal'.

    Raises:
        NotImplementedError: if ``init_type`` is not recognized.
    """
    if init_type == 'normal':
        net.apply(functools.partial(weights_init_normal, std=std))
    elif init_type == 'kaiming':
        net.apply(functools.partial(weights_init_kaiming, scale=scale))
    elif init_type == 'orthogonal':
        net.apply(weights_init_orthogonal)
    else:
        raise NotImplementedError(
            'initialization method [{}] not implemented'.format(init_type))


####################
# define network
####################


# Generator
def define_G(opt, device=None):
    """Build the generator described by ``opt['network_G']``.

    Reads ``which_model_G``, ``in_nc``, ``out_nc``, ``nf``, ``nb``, ``gc``,
    ``scale``, ``norm_type`` and ``mode`` from ``opt['network_G']``. When
    ``opt['is_train']`` is truthy, the network is Kaiming-initialized with its
    Conv/Linear weights scaled by 0.1.

    Args:
        opt: option mapping.
        device: unused; accepted for call-site compatibility.

    Returns:
        The constructed generator module.

    Raises:
        NotImplementedError: if ``which_model_G`` is not 'spsr_net'.
    """
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    if which_model != 'spsr_net':
        raise NotImplementedError(
            'Generator model [{}] not recognized'.format(which_model))

    netG = arch.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                        nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'],
                        upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
                        act_type='leakyrelu', mode=opt_net['mode'],
                        upsample_mode='upconv')
    if opt['is_train']:
        init_weights(netG, init_type='kaiming', scale=0.1)
    return netG


def _define_discriminator(opt):
    """Build and Kaiming-initialize (scale=1) the net in ``opt['network_D']``.

    The VGG variants read ``in_nc``, ``nf``, ``norm_type``, ``mode`` and
    ``act_type`` from ``opt['network_D']``; the spectral-norm variant takes no
    options.

    Raises:
        NotImplementedError: if ``which_model_D`` is not recognized.
    """
    opt_net = opt['network_D']
    which_model = opt_net['which_model_D']

    # Resolve the class lazily so only the selected arch attribute is touched.
    if which_model == 'discriminator_vgg_128':
        vgg_cls = arch.Discriminator_VGG_128
    elif which_model == 'discriminator_vgg_96':
        vgg_cls = arch.Discriminator_VGG_96
    elif which_model == 'discriminator_vgg_192':
        vgg_cls = arch.Discriminator_VGG_192
    else:
        vgg_cls = None

    if vgg_cls is not None:
        netD = vgg_cls(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'],
                       norm_type=opt_net['norm_type'], mode=opt_net['mode'],
                       act_type=opt_net['act_type'])
    elif which_model == 'discriminator_vgg_128_SN':
        netD = arch.Discriminator_VGG_128_SN()
    else:
        raise NotImplementedError(
            'Discriminator model [{}] not recognized'.format(which_model))

    init_weights(netD, init_type='kaiming', scale=1)
    return netD


# Discriminator
def define_D(opt):
    """Build the discriminator described by ``opt['network_D']``.

    Raises:
        NotImplementedError: if ``which_model_D`` is not recognized.
    """
    return _define_discriminator(opt)


def define_D_grad(opt):
    """Build the gradient-branch discriminator.

    This reads the same ``opt['network_D']`` section as ``define_D``, so both
    discriminators are configured identically.

    Raises:
        NotImplementedError: if ``which_model_D`` is not recognized.
    """
    return _define_discriminator(opt)


def define_F(opt, use_bn=False, device=None):
    """Build a VGG feature extractor and put it in eval mode.

    Args:
        opt: unused; accepted for call-site compatibility.
        use_bn: select the batch-norm VGG variant (feature layer 49 instead
            of 34).
        device: device passed to ``VGGFeatureExtractor``. Defaults to CUDA
            when available and CPU otherwise, instead of unconditionally CUDA.

    Returns:
        The feature-extractor module, in eval mode.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # pytorch pretrained VGG19-54, before ReLU.
    feature_layer = 49 if use_bn else 34
    netF = arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
                                    use_input_norm=True, device=device)
    netF.eval()
    return netF