DL-Art-School/codes/trainer/networks.py
2020-12-24 23:02:10 -07:00

207 lines
9.3 KiB
Python

import functools
import importlib
import logging
import pkgutil
import sys
from collections import OrderedDict
from inspect import isfunction, getmembers
import torch
import torchvision
import models.discriminator_vgg_arch as SRGAN_arch
import models.feature_arch as feature_arch
import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
from models.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
# Module-level logger shared by this file; it is looked up under the name 'base'.
logger = logging.getLogger('base')
class RegisteredModelNameError(Exception):
    """Raised when a function offered for model registration is badly named.

    Args:
        name_error: The offending function name; it is embedded in the
            exception message.
    """

    def __init__(self, name_error):
        message = (
            'Registered DLAS modules must start with `register_`. '
            f'Incorrect registration: {name_error}'
        )
        super().__init__(message)
def register_model(func):
    """Decorator that shows DLAS how to build a nn.Module from an opt dict.

    Decorated functions must be named ``register_<name>``, where ``<name>`` is
    the (non-empty) name used in configuration files to reference the model.
    They are expected to take the configuration options for building the
    module and to return a ``torch.nn.Module`` for the model being defined.

    The decorator does not wrap ``func``; it tags it with two attributes that
    the model scanner looks for and returns the same function object:

    - ``_dlas_model_name``: the part of the function name after ``register_``.
    - ``_dlas_registered_model``: always ``True``.

    Args:
        func: The builder function to register.

    Returns:
        ``func`` itself, with the two attributes above set.

    Raises:
        RegisteredModelNameError: If the function name does not start with
            ``register_`` or has nothing after that prefix.
    """
    prefix = "register_"
    func_name = func.__name__
    model_name = func_name[len(prefix):] if func_name.startswith(prefix) else ""
    # Validate with an explicit raise rather than `assert`: asserts are removed
    # under `python -O`, which would let an empty model name be registered.
    if not model_name:
        raise RegisteredModelNameError(func_name)
    func._dlas_model_name = model_name
    func._dlas_registered_model = True
    return func
def find_registered_model_fns(base_path='models'):
    """Collect every registered model-builder function found under a package.

    Each module directly inside ``base_path`` is imported (its dotted name is
    ``base_path`` with ``/`` replaced by ``.`` plus the module name, so
    ``base_path`` must be a relative path that doubles as an importable
    package path). Every function in the module carrying a
    ``_dlas_registered_model`` attribute is recorded under its
    ``_dlas_model_name``. Sub-packages are scanned recursively, except those
    named in the exclusion list.

    Args:
        base_path: Slash-separated path of the package directory to scan.

    Returns:
        A dict mapping registered model names to their builder functions.
        When two functions share a name, the one found later wins.
    """
    excluded_packages = ('flownet2',)
    found_fns = {}
    # iter_modules lists only the direct children of base_path. The recursion
    # into sub-packages is done explicitly below, so walk_packages is not
    # wanted here: it would additionally try to import every sub-package by
    # its bare name (excluded ones included), which can pull in an unrelated
    # top-level package that merely shares the name.
    for mod in pkgutil.iter_modules([base_path]):
        if mod.ispkg:
            if mod.name not in excluded_packages:
                found_fns.update(find_registered_model_fns(f'{base_path}/{mod.name}'))
            continue
        mod_name = f'{base_path}/{mod.name}'.replace('/', '.')
        module = importlib.import_module(mod_name)
        for _, candidate in getmembers(module, isfunction):
            if hasattr(candidate, '_dlas_registered_model'):
                found_fns[candidate._dlas_model_name] = candidate
    return found_fns
class CreateModelError(Exception):
    """Raised when a requested model name has no registered builder.

    Args:
        name: The model name that was looked up.
        available: The model names that are registered; rendered into the
            message as-is.
    """

    def __init__(self, name, available):
        message = (
            f'Could not find the specified model name: {name}. Tip: If your model is in a'
            f' subdirectory, that directory must contain an __init__.py to be scanned. Available models:'
            f'{available}'
        )
        super().__init__(message)
def create_model(opt, opt_net, scale=None):
    """Build the model named in ``opt_net`` using its registered builder.

    The model name is read from ``opt_net['which_model']``; for backwards
    compatibility ``'which_model_G'`` and then ``'which_model_D'`` are tried
    when the earlier key is missing or falsy.

    Args:
        opt: The full options object; forwarded to the builder as its second
            argument.
        opt_net: The options for this network; forwarded to the builder as its
            first argument.
        scale: Unused by this function; kept so existing callers keep working.

    Returns:
        Whatever the registered builder returns.

    Raises:
        CreateModelError: If no registered builder matches the model name.
    """
    # Use .get so a config lacking 'which_model' falls through to the legacy
    # keys instead of raising KeyError on the first lookup.
    which_model = (
        opt_net.get('which_model')
        or opt_net.get('which_model_G')
        or opt_net.get('which_model_D')
    )
    registered_fns = find_registered_model_fns()
    if which_model not in registered_fns:
        raise CreateModelError(which_model, list(registered_fns))
    return registered_fns[which_model](opt_net, opt)
class GradDiscWrapper(torch.nn.Module):
    """Thin ``nn.Module`` shell around a discriminator.

    The wrapped module is stored as ``self.m`` and ``forward`` delegates to it
    unchanged. Construction logs an info message.
    """

    def __init__(self, m):
        super().__init__()
        logger.info("Wrapping a discriminator..")
        self.m = m

    def forward(self, x):
        """Return the wrapped module's output for ``x``."""
        wrapped = self.m
        return wrapped(x)
def define_D_net(opt_net, img_sz=None, wrap=False):
    """Build the discriminator selected by ``opt_net['which_model_D']``.

    Args:
        opt_net: Mapping of discriminator options. ``'which_model_D'`` selects
            the architecture; the remaining keys read depend on the branch
            taken (e.g. ``'in_nc'``, ``'nf'``).
        img_sz: Input image size. Overridden by ``opt_net['image_size']`` when
            that key is present. Several branches compute ``img_sz / 128``, so
            it must be a number for those architectures.
        wrap: When true, the result is wrapped in ``GradDiscWrapper``. This is
            only applied in the ``'discriminator_vgg_128_gn'`` branch and is
            ignored by every other architecture.

    Returns:
        The constructed discriminator module.

    Raises:
        NotImplementedError: If ``which_model_D`` names no known architecture
            (including when it is ``None``).
    """
    which_model = opt_net['which_model_D']
    if 'image_size' in opt_net:
        img_sz = opt_net['image_size']

    if which_model == 'discriminator_vgg_128':
        netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                input_img_factor=img_sz / 128,
                                                extra_conv=opt_net['extra_conv'])
    elif which_model == 'discriminator_vgg_128_gn':
        extra_conv = opt_net['extra_conv'] if 'extra_conv' in opt_net.keys() else False
        netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                   input_img_factor=img_sz / 128, extra_conv=extra_conv)
        if wrap:
            netD = GradDiscWrapper(netD)
    elif which_model == 'discriminator_vgg_128_gn_checkpointed':
        netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                   input_img_factor=img_sz / 128, do_checkpointing=True)
    elif which_model == 'stylegan_vgg':
        netD = StyleGanDiscriminator(128)
    elif which_model == 'discriminator_resnet':
        netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1,
                                                       input_img_size=img_sz)
    elif which_model == 'discriminator_resnet_50':
        netD = DiscriminatorResnet_arch.fixup_resnet50(num_filters=opt_net['nf'], num_classes=1,
                                                       input_img_size=img_sz)
    elif which_model == 'resnext':
        netD = torchvision.models.resnext50_32x4d(norm_layer=functools.partial(torch.nn.GroupNorm, 8))
        #state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', progress=True)
        #netD.load_state_dict(state_dict, strict=False)
        # Replace the classifier head with a single-output linear layer.
        netD.fc = torch.nn.Linear(512 * 4, 1)
    elif which_model == 'discriminator_pix':
        netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
    elif which_model == "discriminator_unet":
        netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
    elif which_model == "discriminator_unet_fea":
        netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                    feature_mode=opt_net['feature_mode'])
    elif which_model == "discriminator_switched":
        netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                 initial_temp=opt_net['initial_temp'],
                                                 final_temperature_step=opt_net['final_temperature_step'])
    elif which_model == "cross_compare_vgg128":
        netD = SRGAN_arch.CrossCompareDiscriminator(
            in_nc=opt_net['in_nc'],
            ref_channels=opt_net['ref_channels'] if 'ref_channels' in opt_net.keys() else 3,
            nf=opt_net['nf'],
            scale=opt_net['scale'])
    elif which_model == "discriminator_refvgg":
        netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                 input_img_factor=img_sz / 128)
    elif which_model == "psnr_approximator":
        netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
    elif which_model == "stylegan2_discriminator":
        # NOTE(review): `stylegan2` is not imported in the visible part of this
        # file, so this branch raises NameError unless it is imported elsewhere.
        attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
        disc = stylegan2.StyleGan2Discriminator(image_size=opt_net['image_size'],
                                                input_filters=opt_net['in_nc'], attn_layers=attn)
        netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'],
                                            types=opt_net['augmentation_types'],
                                            prob=opt_net['augmentation_probability'])
    elif which_model == "rrdb_disc":
        # NOTE(review): `RRDBNet_arch` is not imported in the visible part of
        # this file, so this branch raises NameError unless it is imported elsewhere.
        netD = RRDBNet_arch.RRDBDiscriminator(opt_net['in_nc'], opt_net['nf'], opt_net['nb'],
                                              blocks_per_checkpoint=3)
    else:
        # '{}' rather than '{:s}': the 's' format spec raises TypeError for None
        # (a missing model name), which would hide the intended error below.
        raise NotImplementedError('Discriminator model [{}] not recognized'.format(which_model))
    return netD
def define_D(opt, wrap=False):
    """Build the discriminator described by the top-level options.

    Args:
        opt: Full options mapping. Reads the training target size from
            ``opt['datasets']['train']['target_size']`` and the discriminator
            options from ``opt['network_D']``.
        wrap: Forwarded to ``define_D_net``.

    Returns:
        The discriminator returned by ``define_D_net``.
    """
    target_size = opt['datasets']['train']['target_size']
    return define_D_net(opt['network_D'], target_size, wrap=wrap)
def define_fixed_D(opt):
    """Build a discriminator, load saved weights into it and freeze it.

    Note that this will not work with "old" VGG-style discriminators with dense
    blocks until the img_size parameter is added (no image size is passed to
    ``define_D_net`` here).

    Args:
        opt: Options mapping. It is handed to ``define_D_net`` as the network
            options and must also provide ``'pretrained_path'`` (checkpoint to
            load) and ``'weight'`` (stored on the result as ``fdisc_weight``).

    Returns:
        The discriminator in eval mode, with every parameter's
        ``requires_grad`` set to False and ``fdisc_weight`` attached.
    """
    fixed_disc = define_D_net(opt)

    # torch.load unpickles the file, so only point this at trusted checkpoints.
    saved_state = torch.load(opt['pretrained_path'])
    prefix = 'module.'
    # Drop the unnecessary leading 'module.' from any key that carries it.
    cleaned_state = OrderedDict(
        (key[7:] if key.startswith(prefix) else key, value)
        for key, value in saved_state.items()
    )
    fixed_disc.load_state_dict(cleaned_state)

    # Evaluation mode, frozen parameters, then attach the 'weight' field.
    fixed_disc.eval()
    for _, param in fixed_disc.named_parameters():
        param.requires_grad = False
    fixed_disc.fdisc_weight = opt['weight']
    return fixed_disc
def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None):
    """Define the feature network used for perceptual loss.

    Args:
        which_model: ``'vgg'`` or ``'wide_resnet'``.
        use_bn: For ``'vgg'``, forwarded to the extractor and used to pick the
            default feature layer.
        for_training: For ``'vgg'``, selects the trainable extractor class.
            When false, the returned network is put in eval mode and frozen.
        load_path: Optional checkpoint whose state dict is loaded into the
            network (a leading ``'module.'`` is stripped from keys).
        feature_layers: For ``'vgg'``, the layers to extract; defaults to
            ``[49]`` with ``use_bn`` and ``[34]`` without.

    Returns:
        The constructed feature network.

    Raises:
        NotImplementedError: If ``which_model`` is not recognised.
    """
    if which_model == 'vgg':
        # PyTorch pretrained VGG19-54, before ReLU.
        if feature_layers is None:
            feature_layers = [49] if use_bn else [34]
        extractor_cls = (feature_arch.TrainableVGGFeatureExtractor if for_training
                         else feature_arch.VGGFeatureExtractor)
        feature_net = extractor_cls(feature_layers=feature_layers, use_bn=use_bn,
                                    use_input_norm=True)
    elif which_model == 'wide_resnet':
        feature_net = feature_arch.WideResnetFeatureExtractor(use_input_norm=True)
    else:
        raise NotImplementedError

    if load_path:
        # torch.load unpickles the file, so only point this at trusted checkpoints.
        saved_state = torch.load(load_path)
        prefix = 'module.'
        # Drop the unnecessary leading 'module.' from any key that carries it.
        cleaned_state = OrderedDict(
            (key[7:] if key.startswith(prefix) else key, value)
            for key, value in saved_state.items()
        )
        feature_net.load_state_dict(cleaned_state)

    if not for_training:
        # Put into eval mode and freeze the parameters.
        feature_net.eval()
        for _, param in feature_net.named_parameters():
            param.requires_grad = False

    return feature_net