Enable disjoint feature networks

This is done by pre-training a feature net that predicts the features
of HR images from LR images. Then use the original feature network
and this new one in tandem to work only on LR/Gen images.
This commit is contained in:
James Betker 2020-07-31 16:29:47 -06:00
parent 6e086d0c20
commit eb11a08d1c
5 changed files with 72 additions and 43 deletions

View File

@ -15,10 +15,9 @@ class LQGTDataset(data.Dataset):
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs. Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
If only GT images are provided, generate LQ images on-the-fly. If only GT images are provided, generate LQ images on-the-fly.
""" """
def get_lq_path(self, i): def get_lq_path(self, i):
which_lq = random.randint(0, len(self.paths_LQ)-1) which_lq = random.randint(0, len(self.paths_LQ)-1)
return self.paths_LQ[which_lq][i] return self.paths_LQ[which_lq][i % len(self.paths_LQ[which_lq])]
def __init__(self, opt): def __init__(self, opt):
super(LQGTDataset, self).__init__() super(LQGTDataset, self).__init__()
@ -53,11 +52,6 @@ class LQGTDataset(data.Dataset):
print('loaded %i images for use in training GAN only.' % (self.sizes_GAN,)) print('loaded %i images for use in training GAN only.' % (self.sizes_GAN,))
assert self.paths_GT, 'Error: GT path is empty.' assert self.paths_GT, 'Error: GT path is empty.'
if self.paths_LQ and self.paths_GT:
assert len(self.paths_LQ[0]) == len(
self.paths_GT
), 'GT and LQ datasets have different number of images - {}, {}.'.format(
len(self.paths_LQ[0]), len(self.paths_GT))
self.random_scale_list = [1] self.random_scale_list = [1]
def _init_lmdb(self): def _init_lmdb(self):
@ -85,7 +79,7 @@ class LQGTDataset(data.Dataset):
GT_size = self.opt['target_size'] GT_size = self.opt['target_size']
# get GT image # get GT image
GT_path = self.paths_GT[index] GT_path = self.paths_GT[index % len(self.paths_GT)]
resolution = [int(s) for s in self.sizes_GT[index].split('_') resolution = [int(s) for s in self.sizes_GT[index].split('_')
] if self.data_type == 'lmdb' else None ] if self.data_type == 'lmdb' else None
img_GT = util.read_img(self.GT_env, GT_path, resolution) img_GT = util.read_img(self.GT_env, GT_path, resolution)

View File

@ -42,6 +42,7 @@ class SRGANModel(BaseModel):
else: else:
self.netC = None self.netC = None
self.mega_batch_factor = 1 self.mega_batch_factor = 1
self.disjoint_data = False
# define losses, optimizer and scheduler # define losses, optimizer and scheduler
if self.is_train: if self.is_train:
@ -101,16 +102,28 @@ class SRGANModel(BaseModel):
self.cri_fea = None self.cri_fea = None
if self.cri_fea: # load VGG perceptual loss if self.cri_fea: # load VGG perceptual loss
self.netF = networks.define_F(opt, use_bn=False).to(self.device) self.netF = networks.define_F(opt, use_bn=False).to(self.device)
self.lr_netF = None
if 'lr_fea_path' in train_opt.keys():
self.lr_netF = networks.define_F(opt, use_bn=False, load_path=train_opt['lr_fea_path']).to(self.device)
self.disjoint_data = True
if opt['dist']: if opt['dist']:
pass # do not need to use DistributedDataParallel for netF pass # do not need to use DistributedDataParallel for netF
else: else:
self.netF = DataParallel(self.netF) self.netF = DataParallel(self.netF)
if self.lr_netF:
self.lr_netF = DataParallel(self.lr_netF)
# You can feed in a list of frozen pre-trained discriminators. These are treated the same as feature losses. # You can feed in a list of frozen pre-trained discriminators. These are treated the same as feature losses.
self.fixed_disc_nets = [] self.fixed_disc_nets = []
if 'fixed_discriminators' in opt.keys(): if 'fixed_discriminators' in opt.keys():
for opt_fdisc in opt['fixed_discriminators'].keys(): for opt_fdisc in opt['fixed_discriminators'].keys():
self.fixed_disc_nets.append(networks.define_fixed_D(opt['fixed_discriminators'][opt_fdisc]).to(self.device)) netFD = networks.define_fixed_D(opt['fixed_discriminators'][opt_fdisc]).to(self.device)
if opt['dist']:
pass # do not need to use DistributedDataParallel for netF
else:
netFD = DataParallel(netFD)
self.fixed_disc_nets.append(netFD)
# GD gan loss # GD gan loss
self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
@ -330,6 +343,9 @@ class SRGANModel(BaseModel):
l_g_fdpl = self.cri_fdpl(fea_GenOut, pix) l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
l_g_total += l_g_fdpl * self.fdpl_weight l_g_total += l_g_fdpl * self.fdpl_weight
if self.cri_fea and not using_gan_img: # feature loss if self.cri_fea and not using_gan_img: # feature loss
if self.lr_netF is not None:
real_fea = self.lr_netF(var_L, interpolate_factor=self.opt['scale'])
else:
real_fea = self.netF(pix).detach() real_fea = self.netF(pix).detach()
fake_fea = self.netF(fea_GenOut) fake_fea = self.netF(fea_GenOut)
fea_w = self.l_fea_sched.get_weight_for_step(step) fea_w = self.l_fea_sched.get_weight_for_step(step)
@ -346,7 +362,7 @@ class SRGANModel(BaseModel):
# equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically, # equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
# it should target this # it should target this
l_g_fix_disc = 0 l_g_fix_disc = torch.zeros(1, requires_grad=False).squeeze()
for fixed_disc in self.fixed_disc_nets: for fixed_disc in self.fixed_disc_nets:
weight = fixed_disc.fdisc_weight weight = fixed_disc.fdisc_weight
real_fea = fixed_disc(pix).detach() real_fea = fixed_disc(pix).detach()
@ -439,6 +455,7 @@ class SRGANModel(BaseModel):
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward() l_d_fake_scaled.backward()
if 'pixgan' in self.opt['train']['gan_type']: if 'pixgan' in self.opt['train']['gan_type']:
if not self.disjoint_data:
# randomly determine portions of the image to swap to keep the discriminator honest. # randomly determine portions of the image to swap to keep the discriminator honest.
pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters() pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction) disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)

View File

@ -26,8 +26,10 @@ class VGGFeatureExtractor(nn.Module):
for k, v in self.features.named_parameters(): for k, v in self.features.named_parameters():
v.requires_grad = False v.requires_grad = False
def forward(self, x): def forward(self, x, interpolate_factor=1):
# Assume input range is [0, 1] if interpolate_factor > 1:
x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic')
if self.use_input_norm: if self.use_input_norm:
x = (x - self.mean) / self.std x = (x - self.mean) / self.std
output = self.features(x) output = self.features(x)

View File

@ -168,7 +168,7 @@ def define_fixed_D(opt):
# Define network used for perceptual loss # Define network used for perceptual loss
def define_F(opt, use_bn=False, for_training=False): def define_F(opt, use_bn=False, for_training=False, load_path=None):
gpu_ids = opt['gpu_ids'] gpu_ids = opt['gpu_ids']
device = torch.device('cuda' if gpu_ids else 'cpu') device = torch.device('cuda' if gpu_ids else 'cpu')
if 'which_model_F' not in opt['train'].keys() or opt['train']['which_model_F'] == 'vgg': if 'which_model_F' not in opt['train'].keys() or opt['train']['which_model_F'] == 'vgg':
@ -186,5 +186,21 @@ def define_F(opt, use_bn=False, for_training=False):
elif opt['train']['which_model_F'] == 'wide_resnet': elif opt['train']['which_model_F'] == 'wide_resnet':
netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True, device=device) netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True, device=device)
netF.eval() # No need to train if load_path:
# Load the model parameters:
load_net = torch.load(load_path)
load_net_clean = OrderedDict() # remove unnecessary 'module.'
for k, v in load_net.items():
if k.startswith('module.'):
load_net_clean[k[7:]] = v
else:
load_net_clean[k] = v
netF.load_state_dict(load_net_clean)
# Put into eval mode, freeze the parameters and set the 'weight' field.
netF.eval()
for k, v in netF.named_parameters():
v.requires_grad = False
netF.fdisc_weight = opt['weight']
return netF return netF

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main(): def main():
#### options #### options
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_feature_net.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg4_lr_feat.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher') help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)