2020-10-17 14:40:28 +00:00
|
|
|
import functools
|
|
|
|
import logging
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
2020-10-07 15:03:30 +00:00
|
|
|
import munch
|
2019-08-23 13:42:47 +00:00
|
|
|
import torch
|
2020-10-17 14:40:28 +00:00
|
|
|
import torchvision
|
2020-09-07 23:01:48 +00:00
|
|
|
from munch import munchify
|
2020-11-15 18:32:35 +00:00
|
|
|
import models.archs.stylegan.stylegan2 as stylegan2
|
|
|
|
import models.archs.stylegan.stylegan2_unet_disc as stylegan2_unet
|
2020-10-17 14:40:28 +00:00
|
|
|
|
2020-11-10 23:06:54 +00:00
|
|
|
import models.archs.fixup_resnet.DiscriminatorResnet_arch as DiscriminatorResnet_arch
|
2019-08-23 13:42:47 +00:00
|
|
|
import models.archs.RRDBNet_arch as RRDBNet_arch
|
2020-08-02 18:55:08 +00:00
|
|
|
import models.archs.SPSR_arch as spsr
|
2020-10-17 14:40:28 +00:00
|
|
|
import models.archs.SRResNet_arch as SRResNet_arch
|
|
|
|
import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
|
|
|
|
import models.archs.discriminator_vgg_arch as SRGAN_arch
|
|
|
|
import models.archs.feature_arch as feature_arch
|
2020-10-12 16:20:55 +00:00
|
|
|
import models.archs.panet.panet as panet
|
2020-10-17 14:40:28 +00:00
|
|
|
import models.archs.rcan as rcan
|
2020-11-01 02:55:23 +00:00
|
|
|
from models.archs import srg2_classic
|
2020-11-10 23:06:54 +00:00
|
|
|
from models.archs.biggan.biggan_discriminator import BigGanDiscriminator
|
|
|
|
from models.archs.stylegan.Discriminator_StyleGAN import StyleGanDiscriminator
|
2020-11-05 01:07:48 +00:00
|
|
|
from models.archs.pyramid_arch import BasicResamplingFlowNet
|
2020-11-10 23:06:54 +00:00
|
|
|
from models.archs.rrdb_with_adain_latent import AdaRRDBNet, LinearLatentEstimator
|
2020-11-08 03:38:56 +00:00
|
|
|
from models.archs.rrdb_with_latent import LatentEstimator, RRDBNetWithLatent, LatentEstimator2
|
2020-10-28 02:59:55 +00:00
|
|
|
from models.archs.teco_resgen import TecoGen
|
2019-08-23 13:42:47 +00:00
|
|
|
|
2020-08-26 00:14:45 +00:00
|
|
|
# Module-level logger. It is looked up by the fixed name 'base' rather than __name__,
# presumably so this module shares the project's main training logger — confirm.
logger = logging.getLogger('base')
|
|
|
|
|
2019-08-23 13:42:47 +00:00
|
|
|
# Generator
|
2020-08-22 14:24:34 +00:00
|
|
|
def define_G(opt, net_key='network_G', scale=None):
    """Build the generator network described by ``opt``.

    Args:
        opt: Options mapping. When ``net_key`` is not None the generator config is
            read from ``opt[net_key]``; otherwise ``opt`` itself is the generator config.
        net_key: Key of the generator sub-config inside ``opt``, or None (see above).
        scale: Upsampling factor; defaults to ``opt['scale']``. Only the two
            switched-residual generators ("ConfigurableSwitchedResidualGenerator2" and
            "srg2classic") read this value; the other architectures use ``opt_net['scale']``.

    Returns:
        The constructed generator module.

    Raises:
        NotImplementedError: if ``opt_net['which_model_G']`` is not a known model name.
    """
    opt_net = opt[net_key] if net_key is not None else opt
    if scale is None:
        scale = opt['scale']

    which_model = opt_net['which_model_G']

    # image restoration
    if which_model == 'MSRResNet':
        netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
    elif which_model == 'RRDBNet':
        additive_mode = opt_net.get('additive_mode', 'not_additive')
        netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
                                    mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
                                    additive_mode=additive_mode)
    elif which_model == 'RRDBNetBypass':
        additive_mode = opt_net.get('additive_mode', 'not')
        netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
                                    mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
                                    body_block=RRDBNet_arch.RRDBWithBypass,
                                    blocks_per_checkpoint=opt_net['blocks_per_checkpoint'],
                                    scale=opt_net['scale'], additive_mode=additive_mode)
    elif which_model == 'rcan':
        # args: n_resgroups, n_resblocks, res_scale, reduction, scale, n_feats
        # Extra RCAN args are added to a copy so the caller's config is not mutated.
        args_obj = munchify(dict(opt_net, rgb_range=255, n_colors=3))
        netG = rcan.RCAN(args_obj)
    elif which_model == 'panet':
        # args: n_resblocks, res_scale, scale, n_feats
        # Extra PANET args are added to a copy so the caller's config is not mutated.
        args_obj = munchify(dict(opt_net, rgb_range=255, n_colors=3))
        netG = panet.PANET(args_obj)
    elif which_model == "ConfigurableSwitchedResidualGenerator2":
        netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(
            switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
            switch_reductions=opt_net['switch_reductions'],
            switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
            trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
            transformation_filters=opt_net['transformation_filters'], attention_norm=opt_net['attention_norm'],
            initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
            heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
            upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'],
            for_video=opt_net['for_video'])
    elif which_model == "srg2classic":
        netG = srg2_classic.ConfigurableSwitchedResidualGenerator2(
            switch_depth=opt_net['switch_depth'], switch_filters=opt_net['switch_filters'],
            switch_reductions=opt_net['switch_reductions'],
            switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
            trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
            transformation_filters=opt_net['transformation_filters'],
            initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
            heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
            upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
    elif which_model == 'spsr':
        netG = spsr.SPSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                            nb=opt_net['nb'], upscale=opt_net['scale'])
    elif which_model == 'spsr_net_improved':
        netG = spsr.SPSRNetSimplified(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
                                      nb=opt_net['nb'], upscale=opt_net['scale'])
    elif which_model == "spsr_switched":
        netG = spsr.SwitchedSpsr(in_nc=3, nf=opt_net['nf'], upscale=opt_net['scale'],
                                 init_temperature=opt_net['temperature'])
    elif which_model == "spsr7":
        recurrent = opt_net.get('recurrent', False)
        xforms = opt_net.get('num_transforms', 8)
        netG = spsr.Spsr7(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
                          multiplexer_reductions=opt_net.get('multiplexer_reductions', 3),
                          init_temperature=opt_net.get('temperature', 10), recurrent=recurrent)
    elif which_model == "flownet2":
        # Imported lazily so the flownet2 package is only required when this model is used.
        from models.flownet2.models import FlowNet2
        ld = 'load_path' in opt_net
        # 'checkpoint' is enabled only when no explicit load_path is given.
        args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
        netG = FlowNet2(args)
        if ld:
            # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
            # load_state_dict copies the values into netG's own tensors either way.
            sd = torch.load(opt_net['load_path'], map_location='cpu')
            netG.load_state_dict(sd['state_dict'])
    elif which_model == "backbone_encoder":
        netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
    elif which_model == "backbone_encoder_no_ref":
        netG = SwitchedGen_arch.BackboneEncoderNoRef(pretrained_backbone=opt_net['pretrained_spinenet'])
    elif which_model == "backbone_encoder_no_head":
        netG = SwitchedGen_arch.BackboneSpinenetNoHead()
    elif which_model == "backbone_resnet":
        netG = SwitchedGen_arch.BackboneResnet()
    elif which_model == "tecogen":
        netG = TecoGen(opt_net['nf'], opt_net['scale'])
    elif which_model == "basic_resampling_flow_predictor":
        netG = BasicResamplingFlowNet(opt_net['nf'], resample_scale=opt_net['resample_scale'])
    elif which_model == "rrdb_with_latent":
        netG = RRDBNetWithLatent(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
                                 mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
                                 blocks_per_checkpoint=opt_net['blocks_per_checkpoint'],
                                 scale=opt_net['scale'],
                                 bottom_latent_only=opt_net['bottom_latent_only'])
    elif which_model == "adarrdb":
        netG = AdaRRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
                          mid_channels=opt_net['nf'], num_blocks=opt_net['nb'],
                          blocks_per_checkpoint=opt_net['blocks_per_checkpoint'],
                          scale=opt_net['scale'])
    elif which_model == "latent_estimator":
        if opt_net['version'] == 2:
            netG = LatentEstimator2(in_nc=3, nf=opt_net['nf'])
        else:
            overwrite = [1, 2] if opt_net['only_base_level'] else []
            netG = LatentEstimator(in_nc=3, nf=opt_net['nf'], overwrite_levels=overwrite)
    elif which_model == "linear_latent_estimator":
        netG = LinearLatentEstimator(in_nc=3, nf=opt_net['nf'])
    elif which_model == 'stylegan2':
        is_structured = opt_net.get('structured', False)
        attn = opt_net.get('attn_layers', [])
        netG = stylegan2.StyleGan2GeneratorWithLatent(image_size=opt_net['image_size'],
                                                      latent_dim=opt_net['latent_dim'],
                                                      style_depth=opt_net['style_depth'],
                                                      structure_input=is_structured,
                                                      attn_layers=attn)
    else:
        # '{}' rather than '{:s}': a missing/non-string model name must still produce
        # NotImplementedError instead of a TypeError from str.format.
        raise NotImplementedError('Generator model [{}] not recognized'.format(which_model))
    return netG
|
|
|
|
|
|
|
|
|
2020-08-25 23:58:20 +00:00
|
|
|
class GradDiscWrapper(torch.nn.Module):
    """Module wrapper that simply delegates its forward pass to a wrapped discriminator.

    The wrapped module is held in ``self.m``, so it is registered as a child module
    of this wrapper.
    """

    def __init__(self, m):
        super().__init__()
        logger.info("Wrapping a discriminator..")
        self.m = m

    def forward(self, x):
        """Run the wrapped module on ``x`` and return its output unchanged."""
        wrapped = self.m
        return wrapped(x)
|
2020-08-25 23:58:20 +00:00
|
|
|
|
|
|
|
def define_D_net(opt_net, img_sz=None, wrap=False):
    """Build the discriminator described by the network config ``opt_net``.

    Args:
        opt_net: Discriminator config; ``opt_net['which_model_D']`` selects the architecture.
        img_sz: Training image size. Several VGG-style models receive it as
            ``input_img_factor=img_sz / 128``. If ``opt_net`` has an ``'image_size'`` key,
            that value overrides this argument.
        wrap: If True, wrap the result in ``GradDiscWrapper``. Only the
            ``'discriminator_vgg_128_gn'`` model honours this flag.

    Returns:
        The constructed discriminator module.

    Raises:
        NotImplementedError: if ``which_model_D`` is not a known model name.
    """
    which_model = opt_net['which_model_D']
    img_sz = opt_net.get('image_size', img_sz)

    if which_model == 'discriminator_vgg_128':
        netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                input_img_factor=img_sz / 128,
                                                extra_conv=opt_net['extra_conv'])
    elif which_model == 'discriminator_vgg_128_gn':
        netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                   input_img_factor=img_sz / 128)
        if wrap:
            netD = GradDiscWrapper(netD)
    elif which_model == 'discriminator_vgg_128_gn_checkpointed':
        netD = SRGAN_arch.Discriminator_VGG_128_GN(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                   input_img_factor=img_sz / 128, do_checkpointing=True)
    elif which_model == 'stylegan_vgg':
        netD = StyleGanDiscriminator(128)
    elif which_model == 'discriminator_resnet':
        netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1,
                                                       input_img_size=img_sz)
    elif which_model == 'discriminator_resnet_50':
        netD = DiscriminatorResnet_arch.fixup_resnet50(num_filters=opt_net['nf'], num_classes=1,
                                                       input_img_size=img_sz)
    elif which_model == 'resnext':
        netD = torchvision.models.resnext50_32x4d(norm_layer=functools.partial(torch.nn.GroupNorm, 8))
        # Optional pretrained initialization (currently disabled):
        #state_dict = torch.hub.load_state_dict_from_url('https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', progress=True)
        #netD.load_state_dict(state_dict, strict=False)
        # Replace the classifier head with a single-logit output.
        netD.fc = torch.nn.Linear(512 * 4, 1)
    elif which_model == 'biggan_resnet':
        netD = BigGanDiscriminator(D_activation=torch.nn.LeakyReLU(negative_slope=.2))
    elif which_model == 'discriminator_pix':
        netD = SRGAN_arch.Discriminator_VGG_PixLoss(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
    elif which_model == "discriminator_unet":
        netD = SRGAN_arch.Discriminator_UNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
    elif which_model == "discriminator_unet_fea":
        netD = SRGAN_arch.Discriminator_UNet_FeaOut(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                    feature_mode=opt_net['feature_mode'])
    elif which_model == "discriminator_switched":
        netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                 initial_temp=opt_net['initial_temp'],
                                                 final_temperature_step=opt_net['final_temperature_step'])
    elif which_model == "cross_compare_vgg128":
        netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'],
                                                    ref_channels=opt_net.get('ref_channels', 3),
                                                    nf=opt_net['nf'], scale=opt_net['scale'])
    elif which_model == "discriminator_refvgg":
        netD = SRGAN_arch.RefDiscriminatorVgg128(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                 input_img_factor=img_sz / 128)
    elif which_model == "psnr_approximator":
        netD = SRGAN_arch.PsnrApproximator(nf=opt_net['nf'], input_img_factor=img_sz / 128)
    elif which_model == "pyramid_disc":
        netD = SRGAN_arch.PyramidDiscriminator(in_nc=3, nf=opt_net['nf'])
    elif which_model == "stylegan2_discriminator":
        attn = opt_net.get('attn_layers', [])
        disc = stylegan2.StyleGan2Discriminator(image_size=opt_net['image_size'],
                                                input_filters=opt_net['in_nc'], attn_layers=attn)
        netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'],
                                            prob=opt_net['augmentation_probability'])
    elif which_model == "stylegan2_unet":
        disc = stylegan2_unet.StyleGan2UnetDiscriminator(image_size=opt_net['image_size'],
                                                         input_filters=opt_net['in_nc'])
        netD = stylegan2.StyleGan2Augmentor(disc, opt_net['image_size'], types=opt_net['augmentation_types'],
                                            prob=opt_net['augmentation_probability'])
    else:
        # '{}' rather than '{:s}': a missing/non-string model name must still produce
        # NotImplementedError instead of a TypeError from str.format.
        raise NotImplementedError('Discriminator model [{}] not recognized'.format(which_model))
    return netD
|
|
|
|
|
2020-07-31 20:59:54 +00:00
|
|
|
# Discriminator
|
2020-08-25 23:58:20 +00:00
|
|
|
def define_D(opt, wrap=False):
    """Build the training discriminator configured under ``opt['network_D']``.

    The image size handed to ``define_D_net`` is the train dataset's ``target_size``;
    ``wrap`` is forwarded unchanged.
    """
    target_size = opt['datasets']['train']['target_size']
    return define_D_net(opt['network_D'], target_size, wrap=wrap)
|
2020-07-31 20:59:54 +00:00
|
|
|
|
|
|
|
def define_fixed_D(opt):
    """Build a pretrained, frozen discriminator.

    Args:
        opt: Discriminator config passed straight to ``define_D_net`` (no ``img_sz`` is
            given, so models that need one rely on ``opt['image_size']``). Also reads
            ``opt['pretrained_path']`` (checkpoint to load) and ``opt['weight']``.

    Returns:
        The discriminator in eval mode with every parameter's ``requires_grad`` set to
        False, and with ``fdisc_weight`` set to ``opt['weight']``.
    """
    # Note that this will not work with "old" VGG-style discriminators with dense blocks until the img_size parameter is added.
    net = define_D_net(opt)

    # Load the model parameters. map_location='cpu' lets checkpoints saved on a GPU load on
    # CPU-only hosts; load_state_dict copies the values into net's own tensors either way.
    load_net = torch.load(opt['pretrained_path'], map_location='cpu')
    # Remove an unnecessary leading 'module.' from checkpoint keys.
    prefix = 'module.'
    load_net_clean = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in load_net.items())
    net.load_state_dict(load_net_clean)

    # Put into eval mode, freeze the parameters and set the 'fdisc_weight' field.
    net.eval()
    for p in net.parameters():
        p.requires_grad = False
    net.fdisc_weight = opt['weight']

    return net
|
|
|
|
|
2019-08-23 13:42:47 +00:00
|
|
|
|
|
|
|
# Define network used for perceptual loss
|
2020-09-23 17:56:36 +00:00
|
|
|
def define_F(which_model='vgg', use_bn=False, for_training=False, load_path=None, feature_layers=None):
    """Build the feature-extraction network used for perceptual losses.

    Args:
        which_model: ``'vgg'`` or ``'wide_resnet'``.
        use_bn: For ``'vgg'``, passed to the extractor; also selects the default
            feature layer (49 with batch norm, 34 without).
        for_training: For ``'vgg'``, build ``TrainableVGGFeatureExtractor`` instead of
            ``VGGFeatureExtractor``. For any model, when False the network is put in eval
            mode and all parameters are frozen.
        load_path: Optional checkpoint path. Its state dict is loaded after a leading
            ``'module.'`` is stripped from the keys.
        feature_layers: For ``'vgg'``, the layer indices to extract; see ``use_bn`` for
            the default.

    Returns:
        The feature-extraction module.

    Raises:
        NotImplementedError: if ``which_model`` is not recognized.
    """
    if which_model == 'vgg':
        # PyTorch pretrained VGG19-54, before ReLU.
        if feature_layers is None:
            feature_layers = [49] if use_bn else [34]
        extractor_cls = (feature_arch.TrainableVGGFeatureExtractor if for_training
                         else feature_arch.VGGFeatureExtractor)
        netF = extractor_cls(feature_layers=feature_layers, use_bn=use_bn, use_input_norm=True)
    elif which_model == 'wide_resnet':
        netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True)
    else:
        raise NotImplementedError('Feature model [{}] not recognized'.format(which_model))

    if load_path:
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict copies the values into netF's own tensors either way.
        load_net = torch.load(load_path, map_location='cpu')
        # Remove an unnecessary leading 'module.' from checkpoint keys.
        prefix = 'module.'
        load_net_clean = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in load_net.items())
        netF.load_state_dict(load_net_clean)

    if not for_training:
        # Put into eval mode and freeze the parameters.
        netF.eval()
        for p in netF.parameters():
            p.requires_grad = False

    return netF
|