Enable disjoint feature networks

This is done by pre-training a feature net that predicts the features
of HR images from LR images. Then use the original feature network
and this new one in tandem to work only on LR/Gen images.
This commit is contained in:
James Betker 2020-07-31 16:29:47 -06:00
parent 6e086d0c20
commit eb11a08d1c
5 changed files with 72 additions and 43 deletions

View File

@ -15,10 +15,9 @@ class LQGTDataset(data.Dataset):
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
If only GT images are provided, generate LQ images on-the-fly.
"""
def get_lq_path(self, i):
which_lq = random.randint(0, len(self.paths_LQ)-1)
return self.paths_LQ[which_lq][i]
return self.paths_LQ[which_lq][i % len(self.paths_LQ[which_lq])]
def __init__(self, opt):
super(LQGTDataset, self).__init__()
@ -53,11 +52,6 @@ class LQGTDataset(data.Dataset):
print('loaded %i images for use in training GAN only.' % (self.sizes_GAN,))
assert self.paths_GT, 'Error: GT path is empty.'
if self.paths_LQ and self.paths_GT:
assert len(self.paths_LQ[0]) == len(
self.paths_GT
), 'GT and LQ datasets have different number of images - {}, {}.'.format(
len(self.paths_LQ[0]), len(self.paths_GT))
self.random_scale_list = [1]
def _init_lmdb(self):
@ -85,7 +79,7 @@ class LQGTDataset(data.Dataset):
GT_size = self.opt['target_size']
# get GT image
GT_path = self.paths_GT[index]
GT_path = self.paths_GT[index % len(self.paths_GT)]
resolution = [int(s) for s in self.sizes_GT[index].split('_')
] if self.data_type == 'lmdb' else None
img_GT = util.read_img(self.GT_env, GT_path, resolution)

View File

@ -42,6 +42,7 @@ class SRGANModel(BaseModel):
else:
self.netC = None
self.mega_batch_factor = 1
self.disjoint_data = False
# define losses, optimizer and scheduler
if self.is_train:
@ -101,16 +102,28 @@ class SRGANModel(BaseModel):
self.cri_fea = None
if self.cri_fea: # load VGG perceptual loss
self.netF = networks.define_F(opt, use_bn=False).to(self.device)
self.lr_netF = None
if 'lr_fea_path' in train_opt.keys():
self.lr_netF = networks.define_F(opt, use_bn=False, load_path=train_opt['lr_fea_path']).to(self.device)
self.disjoint_data = True
if opt['dist']:
pass # do not need to use DistributedDataParallel for netF
else:
self.netF = DataParallel(self.netF)
if self.lr_netF:
self.lr_netF = DataParallel(self.lr_netF)
# You can feed in a list of frozen pre-trained discriminators. These are treated the same as feature losses.
self.fixed_disc_nets = []
if 'fixed_discriminators' in opt.keys():
for opt_fdisc in opt['fixed_discriminators'].keys():
self.fixed_disc_nets.append(networks.define_fixed_D(opt['fixed_discriminators'][opt_fdisc]).to(self.device))
netFD = networks.define_fixed_D(opt['fixed_discriminators'][opt_fdisc]).to(self.device)
if opt['dist']:
pass # do not need to use DistributedDataParallel for netF
else:
netFD = DataParallel(netFD)
self.fixed_disc_nets.append(netFD)
# GD gan loss
self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
@ -330,6 +343,9 @@ class SRGANModel(BaseModel):
l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
l_g_total += l_g_fdpl * self.fdpl_weight
if self.cri_fea and not using_gan_img: # feature loss
if self.lr_netF is not None:
real_fea = self.lr_netF(var_L, interpolate_factor=self.opt['scale'])
else:
real_fea = self.netF(pix).detach()
fake_fea = self.netF(fea_GenOut)
fea_w = self.l_fea_sched.get_weight_for_step(step)
@ -346,7 +362,7 @@ class SRGANModel(BaseModel):
# equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
# it should target this
l_g_fix_disc = 0
l_g_fix_disc = torch.zeros(1, requires_grad=False).squeeze()
for fixed_disc in self.fixed_disc_nets:
weight = fixed_disc.fdisc_weight
real_fea = fixed_disc(pix).detach()
@ -439,6 +455,7 @@ class SRGANModel(BaseModel):
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward()
if 'pixgan' in self.opt['train']['gan_type']:
if not self.disjoint_data:
# randomly determine portions of the image to swap to keep the discriminator honest.
pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)

View File

@ -26,8 +26,10 @@ class VGGFeatureExtractor(nn.Module):
for k, v in self.features.named_parameters():
v.requires_grad = False
def forward(self, x):
# Assume input range is [0, 1]
def forward(self, x, interpolate_factor=1):
if interpolate_factor > 1:
x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic')
if self.use_input_norm:
x = (x - self.mean) / self.std
output = self.features(x)

View File

@ -168,7 +168,7 @@ def define_fixed_D(opt):
# Define network used for perceptual loss
def define_F(opt, use_bn=False, for_training=False):
def define_F(opt, use_bn=False, for_training=False, load_path=None):
gpu_ids = opt['gpu_ids']
device = torch.device('cuda' if gpu_ids else 'cpu')
if 'which_model_F' not in opt['train'].keys() or opt['train']['which_model_F'] == 'vgg':
@ -186,5 +186,21 @@ def define_F(opt, use_bn=False, for_training=False):
elif opt['train']['which_model_F'] == 'wide_resnet':
netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True, device=device)
netF.eval() # No need to train
if load_path:
# Load the model parameters:
load_net = torch.load(load_path)
load_net_clean = OrderedDict() # remove unnecessary 'module.'
for k, v in load_net.items():
if k.startswith('module.'):
load_net_clean[k[7:]] = v
else:
load_net_clean[k] = v
netF.load_state_dict(load_net_clean)
# Put into eval mode, freeze the parameters and set the 'weight' field.
netF.eval()
for k, v in netF.named_parameters():
v.requires_grad = False
netF.fdisc_weight = opt['weight']
return netF

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_feature_net.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg4_lr_feat.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)