DL-Art-School/codes/trainer/networks.py

261 lines
15 KiB
Python
Raw Normal View History

2020-10-17 14:40:28 +00:00
import functools
import logging
from collections import OrderedDict
import munch
2019-08-23 13:42:47 +00:00
import torch
2020-10-17 14:40:28 +00:00
import torchvision
2020-09-07 23:01:48 +00:00
from munch import munchify
2020-12-17 21:18:46 +00:00
import models.archs.stylegan.stylegan2_lucidrains as stylegan2
2020-10-17 14:40:28 +00:00
2020-11-10 23:06:54 +00:00
import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
2019-08-23 13:42:47 +00:00
import models.archs.RRDBNet_arch as RRDBNet_arch
2020-10-17 14:40:28 +00:00
import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
import models.archs.discriminator_vgg_arch as SRGAN_arch
import models.archs.feature_arch as feature_arch
from models.archs import srg2_classic
2020-11-10 23:06:54 +00:00
from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
from models.archs.tecogan.teco_resgen import TecoGen
from utils.util import opt_get
2019-08-23 13:42:47 +00:00
2020-08-26 00:14:45 +00:00
# Module-level logger fetched by the shared name 'base' (presumably configured by the training
# entry point elsewhere in the project -- TODO confirm).
logger = logging.getLogger('base')
2019-08-23 13:42:47 +00:00
# Generator
2020-11-20 06:47:24 +00:00
def define_G(opt, opt_net, scale=None):
    """Build the generator network described by ``opt_net``.

    Args:
        opt: The full options dict. ``opt['scale']`` is read when ``scale`` is not
            given, and the ``srflow_orig`` architecture receives it whole.
        opt_net: The generator option sub-dict. ``opt_net['which_model_G']`` selects
            the architecture; the remaining keys are architecture-specific.
        scale: Upsampling-factor override; defaults to ``opt['scale']``. Only the
            switched-residual generators consume it -- most other branches read
            ``opt_net['scale']`` instead. Nested BYOL subnets are built without it,
            so they fall back to ``opt['scale']``.

    Returns:
        The constructed generator module.

    Raises:
        NotImplementedError: If ``which_model_G`` is not recognized.
    """
    if scale is None:
        scale = opt['scale']

    which_model = opt_net['which_model_G']
    if 'RRDBNet' in which_model:
        # Substring match: every name containing 'RRDBNet' builds an RRDBNet; the exact
        # name only selects which body block it is assembled from.
        if which_model == 'RRDBNetBypass':
            block = RRDBNet_arch.RRDBWithBypass
        elif which_model == 'RRDBNetLambda':
            from models.archs.lambda_rrdb import LambdaRRDB
            block = LambdaRRDB
        else:
            block = RRDBNet_arch.RRDB
        netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
                                    mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
                                    additive_mode=opt_net.get('additive_mode', 'not'),
                                    output_mode=opt_net.get('output_mode', 'hq_only'),
                                    body_block=block, scale=opt_net['scale'],
                                    growth_channels=opt_net.get('gc', 32),
                                    initial_stride=opt_net.get('initial_stride', 1))
    elif which_model == "ConfigurableSwitchedResidualGenerator2":
        netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(
            switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
            switch_reductions=opt_net['switch_reductions'],
            switch_processing_layers=opt_net['switch_processing_layers'],
            trans_counts=opt_net['trans_counts'],
            trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
            transformation_filters=opt_net['transformation_filters'],
            attention_norm=opt_net['attention_norm'],
            initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
            heightened_temp_min=opt_net['heightened_temp_min'],
            heightened_final_step=opt_net['heightened_final_step'],
            upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'],
            for_video=opt_net['for_video'])
    elif which_model == "srg2classic":
        netG = srg2_classic.ConfigurableSwitchedResidualGenerator2(
            switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
            switch_reductions=opt_net['switch_reductions'],
            switch_processing_layers=opt_net['switch_processing_layers'],
            trans_counts=opt_net['trans_counts'],
            trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
            transformation_filters=opt_net['transformation_filters'],
            initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
            heightened_temp_min=opt_net['heightened_temp_min'],
            heightened_final_step=opt_net['heightened_final_step'],
            upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
    elif which_model == "flownet2":
        from models.archs.flownet2 import FlowNet2
        ld = 'load_path' in opt_net
        # FlowNet2's 'checkpoint' flag is enabled only when no explicit load_path is given
        # (its exact meaning is defined by FlowNet2 -- confirm there).
        args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
        netG = FlowNet2(args)
        if ld:
            sd = torch.load(opt_net['load_path'])
            netG.load_state_dict(sd['state_dict'])
    elif which_model == "backbone_encoder":
        netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
    elif which_model == "backbone_encoder_no_ref":
        netG = SwitchedGen_arch.BackboneEncoderNoRef(pretrained_backbone=opt_net['pretrained_spinenet'])
    elif which_model == "backbone_encoder_no_head":
        netG = SwitchedGen_arch.BackboneSpinenetNoHead()
    elif which_model == "backbone_resnet":
        netG = SwitchedGen_arch.BackboneResnet()
    elif which_model == "tecogen":
        netG = TecoGen(opt_net['nf'], opt_net['scale'])
    elif which_model == 'stylegan2':
        netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'],
                                                      latent_dim=opt_net['latent_dim'],
                                                      style_depth=opt_net['style_depth'],
                                                      structure_input=opt_net.get('structured', False),
                                                      attn_layers=opt_net.get('attn_layers', []))
    elif which_model == 'srflow_orig':
        from models.archs.srflow_orig import SRFlowNet_arch
        netG = SRFlowNet_arch.SRFlowNet(in_nc=3, out_nc=3, nf=opt_net['nf'], nb=opt_net['nb'],
                                        scale=opt_net['scale'], K=opt_net['K'], opt=opt)
    elif which_model == 'rrdb_latent_wrapper':
        from models.archs.srflow_orig.RRDBNet_arch import RRDBLatentWrapper
        netG = RRDBLatentWrapper(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                 nf=opt_net['nf'], nb=opt_net['nb'], with_bypass=opt_net['with_bypass'],
                                 blocks=opt_net['blocks_for_latent'], scale=opt_net['scale'],
                                 pretrain_rrdb_path=opt_net['pretrain_path'])
    elif which_model == 'rrdb_centipede':
        netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
                                    mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
                                    scale=opt_net['scale'], headless=True,
                                    output_mode=opt_net.get('output_mode', 'hq_only'))
    elif which_model == 'rrdb_srflow':
        from models.archs.srflow_orig.RRDBNet_arch import RRDBNet
        # Default the initial stride to 1, consistent with the 'RRDBNet*' branch above;
        # previously a missing 'initial_stride' key raised KeyError here.
        netG = RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                       nf=opt_net['nf'], nb=opt_net['nb'], scale=opt_net['scale'],
                       initial_conv_stride=opt_net.get('initial_stride', 1))
    elif which_model == 'igpt2':
        from models.archs.transformers.igpt.gpt2 import iGPT2
        netG = iGPT2(opt_net['embed_dim'], opt_net['num_heads'], opt_net['num_layers'],
                     opt_net['num_pixels'] ** 2, opt_net['num_vocab'],
                     centroids_file=opt_net['centroids_file'])
    elif which_model == 'byol':
        from models.archs.byol.byol_model_wrapper import BYOL
        # The wrapped network is built recursively from the nested 'subnet' options.
        subnet = define_G(opt, opt_net['subnet'])
        netG = BYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
                    structural_mlp=opt_get(opt_net, ['use_structural_mlp'], False))
    elif which_model == 'structural_byol':
        from models.archs.byol.byol_structural import StructuralBYOL
        subnet = define_G(opt, opt_net['subnet'])
        netG = StructuralBYOL(subnet, opt_net['image_size'], opt_net['hidden_layer'],
                              pretrained_state_dict=opt_get(opt_net, ["pretrained_path"]),
                              freeze_until=opt_get(opt_net, ['freeze_until'], 0))
    elif which_model == 'spinenet':
        from models.archs.spinenet_arch import SpineNet
        netG = SpineNet(str(opt_net['arch']), in_channels=3, use_input_norm=opt_net['use_input_norm'])
    elif which_model == 'spinenet_with_logits':
        from models.archs.spinenet_arch import SpinenetWithLogits
        netG = SpinenetWithLogits(str(opt_net['arch']), opt_net['output_to_attach'], opt_net['num_labels'],
                                  in_channels=3, use_input_norm=opt_net['use_input_norm'])
    else:
        raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
    return netG
2020-08-25 23:58:20 +00:00
class GradDiscWrapper(torch.nn.Module):
    """Thin pass-through module holding a discriminator as the submodule ``m``.

    Calling the wrapper simply calls the wrapped discriminator; construction logs
    an info message so the wrapping is visible in the training log.
    """

    def __init__(self, m):
        super().__init__()
        logger.info("Wrapping a discriminator..")
        # Registered as a submodule, so its parameters belong to the wrapper.
        self.m = m

    def forward(self, x):
        wrapped = self.m
        return wrapped(x)
2020-08-25 23:58:20 +00:00
def define_D_net(opt_net, img_sz=None, wrap=False):
    """Instantiate the discriminator selected by ``opt_net['which_model_D']``.

    Args:
        opt_net: Discriminator option dict. ``which_model_D`` picks the architecture;
            when an ``image_size`` key is present it overrides ``img_sz``.
        img_sz: Input image size. Several VGG-style models derive
            ``input_img_factor = img_sz / 128`` and therefore need it to be set
            (here or through ``opt_net['image_size']``).
        wrap: Wrap the result in ``GradDiscWrapper``. NOTE(review): only the
            ``discriminator_vgg_128_gn`` branch honours this flag; all other
            architectures ignore it.

    Returns:
        The discriminator module.

    Raises:
        NotImplementedError: If ``which_model_D`` is not recognized.
    """
    kind = opt_net['which_model_D']
    img_sz = opt_net.get('image_size', img_sz)

    if kind == 'discriminator_vgg_128':
        disc = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                input_img_factor=img_sz / 128,
                                                extra_conv=opt_net['extra_conv'])
    elif kind == 'discriminator_vgg_128_gn':
        disc = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                   input_img_factor=img_sz / 128,
                                                   extra_conv=opt_net.get('extra_conv', False))
        if wrap:
            disc = GradDiscWrapper(disc)
    elif kind == 'discriminator_vgg_128_gn_checkpointed':
        disc = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                   input_img_factor=img_sz / 128,
                                                   do_checkpointing=True)
    elif kind == 'stylegan_vgg':
        disc = StyleGanDiscriminator(128)
    elif kind == 'discriminator_resnet':
        disc = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1,
                                                       input_img_size=img_sz)
    elif kind == 'discriminator_resnet_50':
        disc = DiscriminatorResnet_arch.fixup_resnet50(num_filters=opt_net['nf'], num_classes=1,
                                                       input_img_size=img_sz)
    elif kind == 'resnext':
        disc = torchvision.models.resnext50_32x4d(norm_layer=functools.partial(torch.nn.GroupNorm, 8))
        # NOTE: ImageNet-pretrained weights (resnext50_32x4d-7cdf4587.pth) are not loaded here.
        # The classifier head is replaced by a single-logit linear layer over the
        # 512 * 4-wide final feature vector.
        disc.fc = torch.nn.Linear(512 * 4, 1)
    elif kind == 'discriminator_pix':
        disc = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
    elif kind == "discriminator_unet":
        disc = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
    elif kind == "discriminator_unet_fea":
        disc = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                    feature_mode=opt_net['feature_mode'])
    elif kind == "discriminator_switched":
        disc = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                 initial_temp=opt_net['initial_temp'],
                                                 final_temperature_step=opt_net['final_temperature_step'])
    elif kind == "cross_compare_vgg128":
        disc = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'],
                                                    ref_channels=opt_net.get('ref_channels', 3),
                                                    nf=opt_net['nf'], scale=opt_net['scale'])
    elif kind == "discriminator_refvgg":
        disc = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                 input_img_factor=img_sz / 128)
    elif kind == "psnr_approximator":
        disc = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
    elif kind == "stylegan2_discriminator":
        attention_layers = opt_net.get('attn_layers', [])
        inner = stylegan2.StyleGan2Discriminator(image_size=opt_net['image_size'],
                                                 input_filters=opt_net['in_nc'],
                                                 attn_layers=attention_layers)
        disc = stylegan2.StyleGan2Augmentor(inner, opt_net['image_size'],
                                            types=opt_net['augmentation_types'],
                                            prob=opt_net['augmentation_probability'])
    elif kind == "rrdb_disc":
        disc = RRDBNet_arch.RRDBDiscriminator(opt_net['in_nc'], opt_net['nf'], opt_net['nb'],
                                              blocks_per_checkpoint=3)
    else:
        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(kind))
    return disc
# Discriminator
2020-08-25 23:58:20 +00:00
def define_D(opt, wrap=False):
    """Build the trainable discriminator described by ``opt['network_D']``.

    The image size handed to the discriminator is the training dataset's
    ``target_size``; ``wrap`` is forwarded unchanged to ``define_D_net``.
    """
    target_size = opt['datasets']['train']['target_size']
    disc_opt = opt['network_D']
    return define_D_net(disc_opt, target_size, wrap=wrap)
def define_fixed_D(opt):
    """Build a frozen, pretrained discriminator used as a fixed loss network.

    Args:
        opt: Options for the fixed discriminator. It is passed directly to
            ``define_D_net`` as the network options, so it must contain
            ``which_model_D`` plus any architecture keys. It must also provide
            ``pretrained_path`` (checkpoint to load) and ``weight`` (stored on the
            returned module as ``fdisc_weight``).

    Returns:
        The discriminator in eval mode, with ``requires_grad=False`` on every
        parameter and ``fdisc_weight`` set to ``opt['weight']``.
    """
    # Note that this will not work with "old" VGG-style discriminators with dense blocks until the
    # img_size parameter is added (no img_sz is passed here; only opt['image_size'] can supply it).
    net = define_D_net(opt)

    # Load onto the CPU: the module was just built on the CPU and load_state_dict() copies values
    # into it, so this only avoids failures on CPU-only hosts / extra GPU allocations when the
    # checkpoint was saved from CUDA tensors.
    state_dict = torch.load(opt['pretrained_path'], map_location='cpu')
    # Remove the unnecessary 'module.' key prefix (as produced by e.g. DataParallel-wrapped models).
    prefix = 'module.'
    cleaned = OrderedDict((key[len(prefix):] if key.startswith(prefix) else key, value)
                          for key, value in state_dict.items())
    net.load_state_dict(cleaned)

    # Put into eval mode, freeze the parameters and set the 'weight' field.
    net.eval()
    for param in net.parameters():
        param.requires_grad = False
    net.fdisc_weight = opt['weight']
    return net
2019-08-23 13:42:47 +00:00
# Define network used for perceptual loss
def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None):
    """Build the feature-extraction network used for perceptual loss.

    Args:
        which_model: ``'vgg'`` or ``'wide_resnet'``.
        use_bn: VGG only -- use the batch-norm variant.
        for_training: VGG only -- build the trainable extractor. When false the
            network (of either kind) is put in eval mode and its parameters frozen.
        load_path: Optional checkpoint path; a leading ``'module.'`` is stripped
            from its keys before loading.
        feature_layers: VGG only -- layer indices to extract. Defaults to ``[49]``
            with batch norm and ``[34]`` without.

    Returns:
        The feature extractor module.

    Raises:
        NotImplementedError: If ``which_model`` is not recognized.
    """
    if which_model == 'vgg':
        # PyTorch pretrained VGG19-54, before ReLU.
        if feature_layers is None:
            feature_layers = [49] if use_bn else [34]
        extractor_cls = (feature_arch.TrainableVGGFeatureExtractor if for_training
                         else feature_arch.VGGFeatureExtractor)
        netF = extractor_cls(feature_layers=feature_layers, use_bn=use_bn, use_input_norm=True)
    elif which_model == 'wide_resnet':
        netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True)
    else:
        # '{}' rather than '{:s}' so a non-string model name still yields NotImplementedError.
        raise NotImplementedError('Feature network [{}] not recognized'.format(which_model))

    if load_path:
        # Load onto the CPU; load_state_dict() copies into the freshly built module anyway.
        state_dict = torch.load(load_path, map_location='cpu')
        # Remove the unnecessary 'module.' key prefix (as produced by e.g. DataParallel-wrapped models).
        prefix = 'module.'
        netF.load_state_dict(OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()))

    if not for_training:
        # Inference-only use: eval mode and no gradients through the feature network.
        netF.eval()
        for param in netF.parameters():
            param.requires_grad = False
    return netF