2019-08-23 13:42:47 +00:00
|
|
|
import torch
|
|
|
|
import models.archs.SRResNet_arch as SRResNet_arch
|
|
|
|
import models.archs.discriminator_vgg_arch as SRGAN_arch
|
2020-04-29 05:00:29 +00:00
|
|
|
import models.archs.DiscriminatorResnet_arch as DiscriminatorResnet_arch
|
2020-05-04 20:01:43 +00:00
|
|
|
import models.archs.DiscriminatorResnet_arch_passthrough as DiscriminatorResnet_arch_passthrough
|
2020-05-01 01:17:30 +00:00
|
|
|
import models.archs.FlatProcessorNetNew_arch as FlatProcessorNetNew_arch
|
2019-08-23 13:42:47 +00:00
|
|
|
import models.archs.RRDBNet_arch as RRDBNet_arch
|
2020-04-24 06:00:46 +00:00
|
|
|
import models.archs.HighToLowResNet as HighToLowResNet
|
2020-06-29 03:21:57 +00:00
|
|
|
import models.archs.NestedSwitchGenerator as ng
|
2020-05-29 02:26:30 +00:00
|
|
|
import models.archs.feature_arch as feature_arch
|
2020-06-16 17:23:50 +00:00
|
|
|
import models.archs.SwitchedResidualGenerator_arch as SwitchedGen_arch
|
2020-06-07 00:29:25 +00:00
|
|
|
import functools
|
2019-08-23 13:42:47 +00:00
|
|
|
|
|
|
|
# Generator
|
2020-05-13 21:26:55 +00:00
|
|
|
def _initial_stride(opt_net):
    """Return the validated ``initial_stride`` from an RRDB-style generator config.

    Defaults to 1 when the key is absent. Only strides of 1 and 2 are accepted.

    Raises:
        ValueError: if the configured stride is neither 1 nor 2.
    """
    initial_stride = opt_net.get('initial_stride', 1)
    if initial_stride not in (1, 2):
        raise ValueError('initial_stride must be 1 or 2, got {!r}'.format(initial_stride))
    return initial_stride


def define_G(opt, net_key='network_G'):
    """Build the generator network described by ``opt[net_key]``.

    Args:
        opt: Full options dict. ``opt[net_key]`` holds the generator config
            (its ``which_model_G`` entry selects the architecture) and
            ``opt['scale']`` is the scale factor passed to most architectures.
        net_key: Key under which the generator config lives in ``opt``.

    Returns:
        The constructed generator network.

    Raises:
        NotImplementedError: if ``which_model_G`` names an unknown architecture.
        ValueError: if an RRDB variant is configured with an ``initial_stride``
            other than 1 or 2.
    """
    opt_net = opt[net_key]
    which_model = opt_net['which_model_G']
    scale = opt['scale']

    # image restoration
    if which_model == 'MSRResNet':
        # NOTE(review): this branch reads the scale from the network config
        # (opt_net['scale']) rather than opt['scale'] — confirm intended.
        netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                       nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
    elif which_model == 'RRDBNet':
        initial_stride = _initial_stride(opt_net)
        # A stride-2 first layer down-samples the input, so the generator must
        # upscale by an extra factor of initial_stride to reach the target scale.
        gen_scale = scale * initial_stride
        netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                    nf=opt_net['nf'], nb=opt_net['nb'], scale=gen_scale,
                                    initial_stride=initial_stride)
    elif which_model == 'AssistedRRDBNet':
        netG = RRDBNet_arch.AssistedRRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                            nf=opt_net['nf'], nb=opt_net['nb'], scale=scale)
    elif which_model == 'AttentiveRRDBNet':
        # Plain RRDBNet whose residual blocks are swapped for SwitchedRRDB.
        block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
                                    init_temperature=opt_net['temperature'],
                                    final_temperature_step=opt_net['temperature_final_step'])
        netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                    nf=opt_net['nf'], nb=opt_net['nb'], scale=scale,
                                    rrdb_block_f=block_f)
    elif which_model == 'LowDimRRDBNet':
        initial_stride = _initial_stride(opt_net)
        rrdb = functools.partial(RRDBNet_arch.LowDimRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
                                 dimensional_adjustment=opt_net['dim'])
        netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                    nf=opt_net['nf'], nb=opt_net['nb'], scale=scale * initial_stride,
                                    rrdb_block_f=rrdb, initial_stride=initial_stride)
    elif which_model == "LowDimRRDBWithMultiHeadSwitching":
        initial_stride = _initial_stride(opt_net)
        switcher = functools.partial(RRDBNet_arch.SwitchedMultiHeadRRDB, num_convs=opt_net['num_convs'],
                                     num_heads=opt_net['num_heads'],
                                     init_temperature=opt_net['temperature'],
                                     final_temperature_step=opt_net['temperature_final_step'])
        rrdb = functools.partial(RRDBNet_arch.LowDimRRDBWrapper, nf=opt_net['nf'], gc=opt_net['gc'],
                                 dimensional_adjustment=opt_net['dim'], partial_rrdb=switcher)
        netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                    nf=opt_net['nf'], nb=opt_net['nb'], scale=scale * initial_stride,
                                    rrdb_block_f=rrdb, initial_stride=initial_stride)
    elif which_model == 'PixRRDBNet':
        # Multi-head attention takes precedence when both flags are set.
        if opt_net.get('mhattention'):
            block_f = functools.partial(RRDBNet_arch.SwitchedMultiHeadRRDB, num_convs=8, num_heads=2,
                                        nf=opt_net['nf'], gc=opt_net['gc'],
                                        init_temperature=opt_net['temperature'],
                                        final_temperature_step=opt_net['temperature_final_step'])
        elif opt_net.get('attention'):
            block_f = functools.partial(RRDBNet_arch.SwitchedRRDB, nf=opt_net['nf'], gc=opt_net['gc'],
                                        init_temperature=opt_net['temperature'],
                                        final_temperature_step=opt_net['temperature_final_step'])
        else:
            # NOTE(review): presumably makes PixShuffleRRDB use its default block — confirm.
            block_f = None
        netG = RRDBNet_arch.PixShuffleRRDB(nf=opt_net['nf'], nb=opt_net['nb'], gc=opt_net['gc'],
                                           scale=scale, rrdb_block_f=block_f)
    elif which_model == "ConfigurableSwitchedResidualGenerator":
        netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator(
            switch_filters=opt_net['switch_filters'], switch_growths=opt_net['switch_growths'],
            switch_reductions=opt_net['switch_reductions'],
            switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
            trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
            trans_filters_mid=opt_net['trans_filters_mid'],
            initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
            heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
            upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
    elif which_model == "ConfigurableSwitchedResidualGenerator2":
        netG = SwitchedGen_arch.ConfigurableSwitchedResidualGenerator2(
            switch_filters=opt_net['switch_filters'], switch_growths=opt_net['switch_growths'],
            switch_reductions=opt_net['switch_reductions'],
            switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
            trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
            transformation_filters=opt_net['transformation_filters'],
            initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
            heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
            upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])
    elif which_model == "NestedSwitchGenerator":
        netG = ng.NestedSwitchedGenerator(
            switch_filters=opt_net['switch_filters'],
            switch_reductions=opt_net['switch_reductions'],
            switch_processing_layers=opt_net['switch_processing_layers'], trans_counts=opt_net['trans_counts'],
            trans_kernel_sizes=opt_net['trans_kernel_sizes'], trans_layers=opt_net['trans_layers'],
            transformation_filters=opt_net['transformation_filters'],
            initial_temp=opt_net['temperature'], final_temperature_step=opt_net['temperature_final_step'],
            heightened_temp_min=opt_net['heightened_temp_min'], heightened_final_step=opt_net['heightened_final_step'],
            upsample_factor=scale, add_scalable_noise_to_transforms=opt_net['add_noise'])

    # image corruption
    elif which_model == 'HighToLowResNet':
        netG = HighToLowResNet.HighToLowResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                                               nf=opt_net['nf'], nb=opt_net['nb'], downscale=opt_net['scale'])
    elif which_model == 'FlatProcessorNet':
        netG = FlatProcessorNetNew_arch.fixup_resnet34(num_filters=opt_net['nf'])
    else:
        # '{}' rather than '{:s}' so a non-string model name (e.g. None) still
        # produces this error instead of a TypeError from str.format.
        raise NotImplementedError('Generator model [{}] not recognized'.format(which_model))

    return netG
|
|
|
|
|
|
|
|
|
|
|
|
# Discriminator
|
|
|
|
def define_D(opt):
    """Build the discriminator network described by ``opt['network_D']``.

    Args:
        opt: Full options dict. ``opt['network_D']['which_model_D']`` selects the
            architecture and ``opt['datasets']['train']['target_size']`` supplies
            the training image size passed to the discriminator.

    Returns:
        The constructed discriminator network.

    Raises:
        NotImplementedError: if ``which_model_D`` names an unknown architecture.
    """
    img_sz = opt['datasets']['train']['target_size']
    opt_net = opt['network_D']
    which_model = opt_net['which_model_D']

    if which_model == 'discriminator_vgg_128':
        # input_img_factor is the training size divided (integer division) by 128.
        netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                input_img_factor=img_sz // 128,
                                                extra_conv=opt_net['extra_conv'])
    elif which_model == 'discriminator_resnet':
        netD = DiscriminatorResnet_arch.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1,
                                                       input_img_size=img_sz)
    elif which_model == 'discriminator_resnet_passthrough':
        netD = DiscriminatorResnet_arch_passthrough.fixup_resnet34(num_filters=opt_net['nf'], num_classes=1,
                                                                   input_img_size=img_sz,
                                                                   number_skips=opt_net['number_skips'],
                                                                   use_bn=True,
                                                                   disable_passthrough=opt_net['disable_passthrough'])
    else:
        # '{}' rather than '{:s}' so a non-string model name (e.g. None) still
        # produces this error instead of a TypeError from str.format.
        raise NotImplementedError('Discriminator model [{}] not recognized'.format(which_model))
    return netD
|
|
|
|
|
|
|
|
|
|
|
|
# Define network used for perceptual loss
|
|
|
|
def define_F(opt, use_bn=False):
    """Build the feature extractor used for perceptual loss, in eval mode.

    Args:
        opt: Full options dict. ``opt['gpu_ids']`` chooses the device (CUDA when
            truthy, otherwise CPU) and ``opt['train']['which_model_F']`` selects the
            extractor, defaulting to ``'vgg'`` when the key is absent.
        use_bn: For the VGG extractor, whether to use the batch-norm variant.

    Returns:
        The feature extractor after ``.eval()`` has been called on it.

    Raises:
        NotImplementedError: if ``which_model_F`` names an unknown extractor
            (previously this surfaced as an UnboundLocalError).
    """
    gpu_ids = opt['gpu_ids']
    device = torch.device('cuda' if gpu_ids else 'cpu')
    which_model = opt['train'].get('which_model_F', 'vgg')

    if which_model == 'vgg':
        # PyTorch pretrained VGG19-54, before ReLU. The cut index differs with
        # BN — presumably because BN layers shift the feature-stack indices; confirm.
        feature_layer = 49 if use_bn else 34
        netF = feature_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
                                                use_input_norm=True, device=device)
    elif which_model == 'wide_resnet':
        netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True, device=device)
    else:
        raise NotImplementedError('Feature model [{}] not recognized'.format(which_model))

    netF.eval()  # No need to train
    return netF
|