207 lines
9.3 KiB
Python
207 lines
9.3 KiB
Python
import functools
|
|
import importlib
|
|
import logging
|
|
import pkgutil
|
|
import sys
|
|
from collections import OrderedDict
|
|
from inspect import isfunction, getmembers
|
|
|
|
import torch
|
|
import torchvision
|
|
|
|
import models.discriminator_vgg_arch as SRGAN_arch
|
|
import models.feature_arch as feature_arch
|
|
import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
|
|
from models.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
|
|
|
|
# Module-level logger; deliberately uses the shared logger name 'base' rather than __name__.
logger = logging.getLogger(name="base")
|
|
|
|
|
|
class RegisteredModelNameError(Exception):
    """Raised when a function passed to @register_model is not named `register_<name>`."""

    def __init__(self, name_error):
        super().__init__(f'Registered DLAS modules must start with `register_`. Incorrect registration: {name_error}')


# Decorator that allows API clients to show DLAS how to build a nn.Module from an opt dict.
# Functions with this decorator must follow a specific naming format:
#   `register_<name>`, where <name> is the (non-empty) name used in configuration files to reference this model.
# create_model() calls them with two positional arguments:
#   - opt_net: the dict of configuration options for the network being built.
#   - opt: the full configuration dict.
# They should return:
#   - A torch.nn.Module object for the model being defined.
def register_model(func):
    """Mark ``func`` as a DLAS model builder and record the model name it provides.

    Sets ``func._dlas_model_name`` to the part of ``func.__name__`` after ``register_`` and
    ``func._dlas_registered_model`` to True, then returns ``func`` itself so this can be
    used as a decorator.

    Raises:
        RegisteredModelNameError: if the function name does not start with ``register_``,
            or nothing follows the prefix.
    """
    prefix = 'register_'
    name = func.__name__
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if not name.startswith(prefix) or len(name) == len(prefix):
        raise RegisteredModelNameError(name)
    func._dlas_model_name = name[len(prefix):]
    func._dlas_registered_model = True
    return func
|
|
|
|
|
|
def find_registered_model_fns(base_path='models'):
    """Import every module under ``base_path`` and collect functions marked by @register_model.

    ``base_path`` is a '/'-separated directory path, resolved relative to the current working
    directory, that is also importable as a dotted package path (e.g. 'models' -> ``models``).
    Subpackages are scanned recursively, except those named in ``excluded_packages``.

    Returns:
        dict mapping each registered model name (``_dlas_model_name``) to its builder function.
    """
    excluded_packages = ('flownet2',)
    found_fns = {}
    # iter_modules lists only the direct children of base_path; subpackages are recursed into
    # manually below. (walk_packages would additionally try to import each subpackage by its bare,
    # unqualified name, which can import an unrelated top-level module or an excluded package.)
    for mod in pkgutil.iter_modules([base_path]):
        if mod.ispkg:
            if mod.name not in excluded_packages:
                found_fns.update(find_registered_model_fns(f'{base_path}/{mod.name}'))
            continue
        module = importlib.import_module(f'{base_path}/{mod.name}'.replace('/', '.'))
        for _, fn in getmembers(module, isfunction):
            if hasattr(fn, '_dlas_registered_model'):
                found_fns[fn._dlas_model_name] = fn
    return found_fns
|
|
|
|
|
|
class CreateModelError(Exception):
    """Raised when no registered builder matches the requested model name.

    The message names the missing model, reminds the user that subdirectories need an
    ``__init__.py`` to be scanned, and lists the names that are available.
    """

    def __init__(self, name, available):
        tip = ('Tip: If your model is in a subdirectory, that directory must contain an __init__.py '
               'to be scanned.')
        message = f'Could not find the specified model name: {name}. {tip} Available models:{available}'
        super().__init__(message)
|
|
|
|
|
|
def create_model(opt, opt_net, scale=None):
    """Build the network described by ``opt_net`` using its matching @register_model builder.

    Args:
        opt: the full configuration dict; forwarded to the builder as its second argument.
        opt_net: the network's configuration dict. The model name is read from 'which_model',
            falling back to the legacy keys 'which_model_G' and then 'which_model_D' when the
            newer key is missing or empty.
        scale: unused; kept so existing callers that pass it keep working.

    Returns:
        Whatever the registered builder returns (by convention a torch.nn.Module).

    Raises:
        CreateModelError: if no registered builder matches the resolved model name.
    """
    # .get() (rather than []) so plain dicts without the newer 'which_model' key still reach the
    # backwards-compatibility fallbacks instead of raising KeyError.
    which_model = (opt_net.get('which_model')
                   or opt_net.get('which_model_G')
                   or opt_net.get('which_model_D'))

    registered_fns = find_registered_model_fns()
    if which_model not in registered_fns:
        raise CreateModelError(which_model, list(registered_fns))
    return registered_fns[which_model](opt_net, opt)
|
|
|
|
|
|
class GradDiscWrapper(torch.nn.Module):
    """Thin nn.Module wrapper whose forward pass simply delegates to the wrapped module ``m``."""

    def __init__(self, m):
        super().__init__()
        logger.info("Wrapping a discriminator..")
        # When ``m`` is an nn.Module, this assignment registers it as a submodule named 'm'.
        self.m = m

    def forward(self, x):
        """Return the wrapped module's output for ``x``."""
        wrapped = self.m
        return wrapped(x)
|
|
|
|
def define_D_net(opt_net, img_sz=None, wrap=False):
    """Construct a discriminator from its network config dict.

    Args:
        opt_net: discriminator config. 'which_model_D' selects the architecture; which other keys
            are read ('in_nc', 'nf', 'extra_conv', 'image_size', ...) depends on the branch taken.
        img_sz: input image size. Replaced by opt_net['image_size'] when that key is present.
            Required by the branches that pass ``img_sz / 128`` as ``input_img_factor``.
        wrap: if True, wrap the result in GradDiscWrapper. Only the 'discriminator_vgg_128_gn'
            branch honors this flag.

    Returns:
        The constructed discriminator module.

    Raises:
        ValueError: if a branch that needs an image size is selected and none was supplied.
        NotImplementedError: if 'which_model_D' is missing or not recognized.
    """
    which_model = opt_net.get('which_model_D')
    img_sz = opt_net.get('image_size', img_sz)

    def img_factor():
        # input_img_factor is expressed relative to 128 (img_sz / 128); fail clearly instead of with
        # "unsupported operand type(s) for /: 'NoneType' and 'int'" when no size was given.
        if img_sz is None:
            raise ValueError(f"Discriminator model [{which_model}] needs an image size: pass img_sz or "
                             "set 'image_size' in the network options.")
        return img_sz / 128

    if which_model == 'discriminator_vgg_128':
        factor = img_factor()
        # 'extra_conv' defaults to False, consistent with the _gn variant below.
        netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=factor,
                                                extra_conv=opt_net.get('extra_conv', False))
    elif which_model == 'discriminator_vgg_128_gn':
        factor = img_factor()
        netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                   input_img_factor=factor,
                                                   extra_conv=opt_net.get('extra_conv', False))
        if wrap:
            netD = GradDiscWrapper(netD)
    elif which_model == 'discriminator_vgg_128_gn_checkpointed':
        factor = img_factor()
        netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                   input_img_factor=factor, do_checkpointing=True)
    elif which_model == 'stylegan_vgg':
        netD = StyleGanDiscriminator(128)
    elif which_model == 'discriminator_resnet':
        netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
    elif which_model == 'discriminator_resnet_50':
        netD = DiscriminatorResnet_arch.fixup_resnet50(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
    elif which_model == 'resnext':
        # Randomly initialized (no pretrained checkpoint is loaded), using GroupNorm(8, ...) for normalization.
        netD = torchvision.models.resnext50_32x4d(norm_layer=functools.partial(torch.nn.GroupNorm, 8))
        # Replace the classifier head with a single-output linear layer.
        netD.fc = torch.nn.Linear(512 * 4, 1)
    elif which_model == 'discriminator_pix':
        netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
    elif which_model == "discriminator_unet":
        netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
    elif which_model == "discriminator_unet_fea":
        netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                    feature_mode=opt_net['feature_mode'])
    elif which_model == "discriminator_switched":
        netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                 initial_temp=opt_net['initial_temp'],
                                                 final_temperature_step=opt_net['final_temperature_step'])
    elif which_model == "cross_compare_vgg128":
        netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net.get('ref_channels', 3),
                                                    nf=opt_net['nf'], scale=opt_net['scale'])
    elif which_model == "discriminator_refvgg":
        factor = img_factor()
        netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=factor)
    elif which_model == "psnr_approximator":
        factor = img_factor()
        netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=factor)
    elif which_model == "stylegan2_discriminator":
        # NOTE(review): `stylegan2` is not imported anywhere in this file, so this branch raises NameError
        # as written. TODO: import the module providing StyleGan2Discriminator / StyleGan2Augmentor.
        attn = opt_net.get('attn_layers', [])
        disc = stylegan2.StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'],
                                                attn_layers=attn)
        netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'],
                                            prob=opt_net['augmentation_probability'])
    elif which_model == "rrdb_disc":
        # NOTE(review): `RRDBNet_arch` is likewise not imported in this file. TODO: add the import.
        netD = RRDBNet_arch.RRDBDiscriminator(opt_net['in_nc'], opt_net['nf'], opt_net['nb'], blocks_per_checkpoint=3)
    else:
        # Plain {} formatting: '{:s}'.format(None) raised TypeError and masked this error.
        raise NotImplementedError(f'Discriminator model [{which_model}] not recognized')
    return netD
|
|
|
|
# Discriminator
def define_D(opt, wrap=False):
    """Build the training discriminator from the top-level options dict.

    Reads the network config from opt['network_D'] and uses the training dataset's
    opt['datasets']['train']['target_size'] as the image size; ``wrap`` is passed through.
    """
    target_size = opt['datasets']['train']['target_size']
    return define_D_net(opt['network_D'], img_sz=target_size, wrap=wrap)
|
|
|
|
def define_fixed_D(opt):
    """Build a frozen discriminator initialized from a saved state dict.

    ``opt`` doubles as the network config passed to define_D_net (so it carries 'which_model_D'
    and whatever keys that model reads) and additionally supplies:
        - 'pretrained_path': path of the state dict to load.
        - 'weight': stored on the returned network as ``fdisc_weight``.

    Returns:
        The network in eval mode, with requires_grad set to False on every parameter.
    """
    # Note that this will not work with "old" VGG-style discriminators with dense blocks until the img_size
    # parameter is added.
    net = define_D_net(opt)

    # map_location='cpu': load_state_dict copies the values into the freshly built network, so there is
    # no need to deserialize onto the device the checkpoint was saved from (which fails on CPU-only hosts).
    load_net = torch.load(opt['pretrained_path'], map_location='cpu')
    # Strip the unnecessary 'module.' key prefix (presumably left by a DataParallel-wrapped save).
    prefix = 'module.'
    load_net_clean = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in load_net.items())
    net.load_state_dict(load_net_clean)

    # Put into eval mode, freeze the parameters and set the 'weight' field.
    net.eval()
    net.requires_grad_(False)
    net.fdisc_weight = opt['weight']
    return net
|
|
|
|
|
|
# Define network used for perceptual loss
def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None):
    """Build a feature-extraction network, e.g. for a perceptual loss.

    Args:
        which_model: 'vgg' or 'wide_resnet'.
        use_bn: VGG only; forwarded to the extractor and selects the default feature layer.
        for_training: VGG only: True selects TrainableVGGFeatureExtractor. For every model, when
            False the network is put in eval mode with requires_grad disabled on all parameters.
        load_path: optional path of a state dict to load into the network.
        feature_layers: VGG only; layer indices to extract. Defaults to [49] when use_bn, else [34].

    Returns:
        The feature-extraction network.

    Raises:
        NotImplementedError: if ``which_model`` is not recognized.
    """
    if which_model == 'vgg':
        # PyTorch pretrained VGG19-54, before ReLU.
        if feature_layers is None:
            feature_layers = [49] if use_bn else [34]
        if for_training:
            netF = feature_arch.TrainableVGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
                                                             use_input_norm=True)
        else:
            netF = feature_arch.VGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
                                                    use_input_norm=True)
    elif which_model == 'wide_resnet':
        netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True)
    else:
        raise NotImplementedError(f'Feature network [{which_model}] not recognized')

    if load_path:
        # map_location='cpu': load_state_dict copies the values into netF, so there is no need to
        # deserialize onto the device the checkpoint was saved from (which fails on CPU-only hosts).
        load_net = torch.load(load_path, map_location='cpu')
        # Strip the unnecessary 'module.' key prefix (presumably left by a DataParallel-wrapped save).
        prefix = 'module.'
        load_net_clean = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in load_net.items())
        netF.load_state_dict(load_net_clean)

    if not for_training:
        # Not being trained: put into eval mode and freeze all parameters.
        netF.eval()
        netF.requires_grad_(False)

    return netF
|