# Forked from mrq/DL-Art-School (207 lines, 9.3 KiB).
import functools
import importlib
import logging
import pkgutil
import sys
from collections import OrderedDict
from inspect import isfunction, getmembers
import torch
import torchvision
import models.discriminator_vgg_arch as SRGAN_arch
import models.feature_arch as feature_arch
import models.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
from models.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
# Module-level logger, fetched under the name 'base'.
logger = logging.getLogger('base')
class RegisteredModelNameError(Exception):
    """Raised when a function handed to `register_model` is not named `register_<name>`.

    Args:
        name_error: The offending function name; it is embedded in the error message.
    """

    def __init__(self, name_error):
        super().__init__(f'Registered DLAS modules must start with `register_`. Incorrect registration: {name_error}')
        # Keep the offending name available to handlers that want it programmatically.
        self.name_error = name_error
# Decorator that allows API clients to show DLAS how to build a nn.Module from an opt dict.
# Functions with this decorator should have a specific naming format:
# `register_<name>` where <name> is the name that will be used in configuration files to reference this model.
# Functions with this decorator are expected to take a single argument:
# - opt: A dict with the configuration options for building the module.
# They should return:
# - A torch.nn.Module object for the model being defined.
def register_model(func):
    """Mark `func` as a DLAS model builder and record the name it is registered under.

    The registered name is the function name with its `register_` prefix removed.

    Args:
        func: The builder function being decorated.

    Returns:
        The same function object, with `_dlas_model_name` and `_dlas_registered_model` set on it.

    Raises:
        RegisteredModelNameError: If the function name does not start with `register_`.
        AssertionError: If nothing follows the `register_` prefix.
    """
    prefix = "register_"
    fn_name = func.__name__
    # Only correctly named functions may be registered; everything else is rejected up front.
    if not fn_name.startswith(prefix):
        raise RegisteredModelNameError(fn_name)
    model_name = fn_name[len(prefix):]
    # A bare `register_` leaves no name to reference the model by.
    assert model_name
    func._dlas_model_name = model_name
    func._dlas_registered_model = True
    return func
def find_registered_model_fns(base_path='models'):
    """Collect every function under `base_path` that carries the `_dlas_registered_model` marker.

    `base_path` is used twice: as a filesystem path handed to `pkgutil.walk_packages`, and (with
    `/` replaced by `.`) as the dotted prefix of the modules that get imported. It therefore has
    to be a relative path that is also importable from the current `sys.path`.

    Args:
        base_path: Relative directory to scan, using `/` as the separator.

    Returns:
        A dict mapping each function's `_dlas_model_name` to the function itself.
    """
    # Packages that are never descended into.
    EXCLUSION_LIST = ['flownet2']
    found_fns = {}
    for mod in pkgutil.walk_packages([base_path]):
        if mod.ispkg:
            # Sub-packages are scanned by recursing with an extended path so that the dotted
            # module names built below keep their full prefix.
            if mod.name not in EXCLUSION_LIST:
                found_fns.update(find_registered_model_fns(f'{base_path}/{mod.name}'))
            continue
        mod_name = f'{base_path}/{mod.name}'.replace('/', '.')
        # The module must actually be imported before its members can be inspected.
        module = importlib.import_module(mod_name)
        for _, fn in getmembers(module, isfunction):
            if hasattr(fn, "_dlas_registered_model"):
                found_fns[fn._dlas_model_name] = fn
    return found_fns
class CreateModelError(Exception):
    """Raised when a requested model name has no registered builder.

    Args:
        name: The model name that was requested.
        available: The registered model names, included in the message as a hint.
    """

    def __init__(self, name, available):
        super().__init__(f'Could not find the specified model name: {name}. Tip: If your model is in a'
                         f' subdirectory, that directory must contain an __init__.py to be scanned. Available models:'
                         f' {available}')
def create_model(opt, opt_net, scale=None):
    """Build the model named in `opt_net` using the registered builder functions.

    Args:
        opt: The full options dict; passed to the builder as its second argument.
        opt_net: The network options; the model name is read from `which_model`, falling back
            to `which_model_G` and then `which_model_D`. Passed to the builder as its first argument.
        scale: Accepted but not used by this function.

    Returns:
        Whatever the selected builder function returns.

    Raises:
        CreateModelError: If no registered builder matches the requested name.
    """
    # `.get` keeps the backwards-compatibility fallbacks working when the newer key is absent
    # instead of failing with a KeyError on the first lookup.
    which_model = (opt_net.get('which_model')
                   or opt_net.get('which_model_G')
                   or opt_net.get('which_model_D'))
    registered_fns = find_registered_model_fns()
    if which_model not in registered_fns:
        raise CreateModelError(which_model, list(registered_fns))
    return registered_fns[which_model](opt_net, opt)
class GradDiscWrapper(torch.nn.Module):
    """Thin `torch.nn.Module` that holds another module and forwards its input to it.

    Args:
        m: The module to wrap; stored as `self.m`.
    """

    def __init__(self, m):
        super().__init__()
        logger.info("Wrapping a discriminator..")
        self.m = m

    def forward(self, x):
        """Return the wrapped module's output for `x`."""
        return self.m(x)
def define_D_net(opt_net, img_sz=None, wrap=False):
which_model = opt_net['which_model_D']
if 'image_size' in opt_net.keys():
img_sz = opt_net['image_size']
if which_model == 'discriminator_vgg_128':
netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, extra_conv=opt_net['extra_conv'])
elif which_model == 'discriminator_vgg_128_gn':
extra_conv = opt_net['extra_conv'] if 'extra_conv' in opt_net.keys() else False
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
input_img_factor=img_sz / 128, extra_conv=extra_conv)
if wrap:
netD = GradDiscWrapper(netD)
elif which_model == 'discriminator_vgg_128_gn_checkpointed':
netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128, do_checkpointing=True)
elif which_model == 'stylegan_vgg':
netD = StyleGanDiscriminator(128)
elif which_model == 'discriminator_resnet':
netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
elif which_model == 'discriminator_resnet_50':
netD = DiscriminatorResnet_arch.fixup_resnet50(num_filters=opt_net['nf'], num_classes=1, input_img_size=img_sz)
elif which_model == 'resnext':
netD = torchvision.models.resnext50_32x4d(norm_layer=functools.partial(torch.nn.GroupNorm, 8))
#state_dict = torch.hub.load_state_dict_from_url('', progress=True)
#netD.load_state_dict(state_dict, strict=False)
netD.fc = torch.nn.Linear(512 * 4, 1)
elif which_model == 'discriminator_pix':
netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
elif which_model == "discriminator_unet":
netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
elif which_model == "discriminator_unet_fea":
netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'], feature_mode=opt_net['feature_mode'])
elif which_model == "discriminator_switched":
netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'], initial_temp=opt_net['initial_temp'],
elif which_model == "cross_compare_vgg128":
netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'] if 'ref_channels' in opt_net.keys() else 3, nf=opt_net['nf'], scale=opt_net['scale'])
elif which_model == "discriminator_refvgg":
netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], input_img_factor=img_sz / 128)
elif which_model == "psnr_approximator":
netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
elif which_model == "stylegan2_discriminator":
attn = opt_net['attn_layers'] if 'attn_layers' in opt_net.keys() else []
disc = stylegan2.StyleGan2Discriminator(image_size=opt_net['image_size'], input_filters=opt_net['in_nc'], attn_layers=attn)
netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'], prob=opt_net['augmentation_probability'])
elif which_model == "rrdb_disc":
netD = RRDBNet_arch.RRDBDiscriminator(opt_net['in_nc'], opt_net['nf'], opt_net['nb'], blocks_per_checkpoint=3)
raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
return netD
# Discriminator
def define_D(opt, wrap=False):
    """Build the discriminator described by the full options dict.

    The image size comes from `opt['datasets']['train']['target_size']` and the network options
    from `opt['network_D']`; both are handed to `define_D_net` together with `wrap`.
    """
    target_size = opt['datasets']['train']['target_size']
    return define_D_net(opt['network_D'], target_size, wrap=wrap)
def define_fixed_D(opt):
    """Build a discriminator, load pretrained weights into it and freeze it.

    Args:
        opt: Options for the discriminator. Passed to `define_D_net` as its network options;
            `pretrained_path` and `weight` are also read from it.

    Returns:
        The discriminator in eval mode, with every parameter's `requires_grad` set to False and
        `fdisc_weight` set to `opt['weight']`.
    """
    # Note that this will not work with "old" VGG-style discriminators with dense blocks until the img_size parameter is added.
    net = define_D_net(opt)

    # Load the model parameters.
    # NOTE(review): torch.load unpickles the file - only point `pretrained_path` at trusted checkpoints.
    load_net = torch.load(opt['pretrained_path'])
    # Remove the unnecessary 'module.' prefix; all other keys are kept unchanged.
    prefix = 'module.'
    load_net_clean = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in load_net.items())
    net.load_state_dict(load_net_clean)

    # Put into eval mode, freeze the parameters and set the 'weight' field.
    net.eval()
    for param in net.parameters():
        param.requires_grad = False
    net.fdisc_weight = opt['weight']
    return net
# Define network used for perceptual loss
def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None):
    """Build the feature-extractor network used for perceptual loss.

    Args:
        which_model: `'vgg'` or `'wide_resnet'`.
        use_bn: For `'vgg'`, passed to the extractor and used to pick the default feature layer.
        for_training: For `'vgg'`, selects the trainable extractor class. When false, the
            returned network is put in eval mode and its parameters are frozen.
        load_path: Optional checkpoint path; when truthy its state dict is loaded into the network.
        feature_layers: For `'vgg'`, the layers to extract. Defaults to `[49]` with `use_bn`
            and `[34]` without.

    Returns:
        The constructed feature-extractor module.

    Raises:
        NotImplementedError: If `which_model` is not one of the supported names.
    """
    if which_model == 'vgg':
        # PyTorch pretrained VGG19-54, before ReLU.
        if feature_layers is None:
            feature_layers = [49] if use_bn else [34]
        # NOTE(review): the tail of both constructor calls was cut off in the source.
        # `use_input_norm=True` mirrors the wide_resnet branch below - confirm.
        if for_training:
            netF = feature_arch.TrainableVGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
                                                             use_input_norm=True)
        else:
            netF = feature_arch.VGGFeatureExtractor(feature_layers=feature_layers, use_bn=use_bn,
                                                    use_input_norm=True)
    elif which_model == 'wide_resnet':
        netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True)
    else:
        raise NotImplementedError

    if load_path:
        # Load the model parameters.
        # NOTE(review): torch.load unpickles the file - only pass trusted checkpoints as `load_path`.
        load_net = torch.load(load_path)
        # Remove the unnecessary 'module.' prefix; all other keys are kept unchanged.
        prefix = 'module.'
        load_net_clean = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in load_net.items())
        netF.load_state_dict(load_net_clean)

    if not for_training:
        # Put into eval mode and freeze the parameters.
        netF.eval()
        for param in netF.parameters():
            param.requires_grad = False
    return netF