Enable disjoint feature networks

This is done by pre-training a feature network that predicts the features
of HR images from their LR counterparts. The original feature network and
this new one are then used in tandem, so the feature loss operates only on
LR and generated (Gen) images.
This commit is contained in:
James Betker 2020-07-31 16:29:47 -06:00
parent 6e086d0c20
commit eb11a08d1c
5 changed files with 72 additions and 43 deletions

View File

@ -15,10 +15,9 @@ class LQGTDataset(data.Dataset):
Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs. Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
If only GT images are provided, generate LQ images on-the-fly. If only GT images are provided, generate LQ images on-the-fly.
""" """
def get_lq_path(self, i): def get_lq_path(self, i):
which_lq = random.randint(0, len(self.paths_LQ)-1) which_lq = random.randint(0, len(self.paths_LQ)-1)
return self.paths_LQ[which_lq][i] return self.paths_LQ[which_lq][i % len(self.paths_LQ[which_lq])]
def __init__(self, opt): def __init__(self, opt):
super(LQGTDataset, self).__init__() super(LQGTDataset, self).__init__()
@ -53,11 +52,6 @@ class LQGTDataset(data.Dataset):
print('loaded %i images for use in training GAN only.' % (self.sizes_GAN,)) print('loaded %i images for use in training GAN only.' % (self.sizes_GAN,))
assert self.paths_GT, 'Error: GT path is empty.' assert self.paths_GT, 'Error: GT path is empty.'
if self.paths_LQ and self.paths_GT:
assert len(self.paths_LQ[0]) == len(
self.paths_GT
), 'GT and LQ datasets have different number of images - {}, {}.'.format(
len(self.paths_LQ[0]), len(self.paths_GT))
self.random_scale_list = [1] self.random_scale_list = [1]
def _init_lmdb(self): def _init_lmdb(self):
@ -85,7 +79,7 @@ class LQGTDataset(data.Dataset):
GT_size = self.opt['target_size'] GT_size = self.opt['target_size']
# get GT image # get GT image
GT_path = self.paths_GT[index] GT_path = self.paths_GT[index % len(self.paths_GT)]
resolution = [int(s) for s in self.sizes_GT[index].split('_') resolution = [int(s) for s in self.sizes_GT[index].split('_')
] if self.data_type == 'lmdb' else None ] if self.data_type == 'lmdb' else None
img_GT = util.read_img(self.GT_env, GT_path, resolution) img_GT = util.read_img(self.GT_env, GT_path, resolution)

View File

@ -42,6 +42,7 @@ class SRGANModel(BaseModel):
else: else:
self.netC = None self.netC = None
self.mega_batch_factor = 1 self.mega_batch_factor = 1
self.disjoint_data = False
# define losses, optimizer and scheduler # define losses, optimizer and scheduler
if self.is_train: if self.is_train:
@ -101,16 +102,28 @@ class SRGANModel(BaseModel):
self.cri_fea = None self.cri_fea = None
if self.cri_fea: # load VGG perceptual loss if self.cri_fea: # load VGG perceptual loss
self.netF = networks.define_F(opt, use_bn=False).to(self.device) self.netF = networks.define_F(opt, use_bn=False).to(self.device)
self.lr_netF = None
if 'lr_fea_path' in train_opt.keys():
self.lr_netF = networks.define_F(opt, use_bn=False, load_path=train_opt['lr_fea_path']).to(self.device)
self.disjoint_data = True
if opt['dist']: if opt['dist']:
pass # do not need to use DistributedDataParallel for netF pass # do not need to use DistributedDataParallel for netF
else: else:
self.netF = DataParallel(self.netF) self.netF = DataParallel(self.netF)
if self.lr_netF:
self.lr_netF = DataParallel(self.lr_netF)
# You can feed in a list of frozen pre-trained discriminators. These are treated the same as feature losses. # You can feed in a list of frozen pre-trained discriminators. These are treated the same as feature losses.
self.fixed_disc_nets = [] self.fixed_disc_nets = []
if 'fixed_discriminators' in opt.keys(): if 'fixed_discriminators' in opt.keys():
for opt_fdisc in opt['fixed_discriminators'].keys(): for opt_fdisc in opt['fixed_discriminators'].keys():
self.fixed_disc_nets.append(networks.define_fixed_D(opt['fixed_discriminators'][opt_fdisc]).to(self.device)) netFD = networks.define_fixed_D(opt['fixed_discriminators'][opt_fdisc]).to(self.device)
if opt['dist']:
pass # do not need to use DistributedDataParallel for netF
else:
netFD = DataParallel(netFD)
self.fixed_disc_nets.append(netFD)
# GD gan loss # GD gan loss
self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
@ -330,7 +343,10 @@ class SRGANModel(BaseModel):
l_g_fdpl = self.cri_fdpl(fea_GenOut, pix) l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
l_g_total += l_g_fdpl * self.fdpl_weight l_g_total += l_g_fdpl * self.fdpl_weight
if self.cri_fea and not using_gan_img: # feature loss if self.cri_fea and not using_gan_img: # feature loss
real_fea = self.netF(pix).detach() if self.lr_netF is not None:
real_fea = self.lr_netF(var_L, interpolate_factor=self.opt['scale'])
else:
real_fea = self.netF(pix).detach()
fake_fea = self.netF(fea_GenOut) fake_fea = self.netF(fea_GenOut)
fea_w = self.l_fea_sched.get_weight_for_step(step) fea_w = self.l_fea_sched.get_weight_for_step(step)
l_g_fea = fea_w * self.cri_fea(fake_fea, real_fea) l_g_fea = fea_w * self.cri_fea(fake_fea, real_fea)
@ -346,7 +362,7 @@ class SRGANModel(BaseModel):
# equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically, # equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
# it should target this # it should target this
l_g_fix_disc = 0 l_g_fix_disc = torch.zeros(1, requires_grad=False).squeeze()
for fixed_disc in self.fixed_disc_nets: for fixed_disc in self.fixed_disc_nets:
weight = fixed_disc.fdisc_weight weight = fixed_disc.fdisc_weight
real_fea = fixed_disc(pix).detach() real_fea = fixed_disc(pix).detach()
@ -439,33 +455,34 @@ class SRGANModel(BaseModel):
with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled: with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
l_d_fake_scaled.backward() l_d_fake_scaled.backward()
if 'pixgan' in self.opt['train']['gan_type']: if 'pixgan' in self.opt['train']['gan_type']:
# randomly determine portions of the image to swap to keep the discriminator honest. if not self.disjoint_data:
pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters() # randomly determine portions of the image to swap to keep the discriminator honest.
disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction) pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
b, _, w, h = var_ref.shape disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device) b, _, w, h = var_ref.shape
fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device) real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
SWAP_MAX_DIM = w // 4 fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device)
SWAP_MIN_DIM = 16 SWAP_MAX_DIM = w // 4
assert SWAP_MAX_DIM > 0 SWAP_MIN_DIM = 16
if random.random() > .5: # Make this only happen half the time. Earlier experiments had it happen assert SWAP_MAX_DIM > 0
# more often and the model was "cheating" by using the presence of if random.random() > .5: # Make this only happen half the time. Earlier experiments had it happen
# easily discriminated fake swaps to count the entire generated image # more often and the model was "cheating" by using the presence of
# as fake. # easily discriminated fake swaps to count the entire generated image
random_swap_count = random.randint(0, 4) # as fake.
for i in range(random_swap_count): random_swap_count = random.randint(0, 4)
# Make the swap across fake_H and var_ref for i in range(random_swap_count):
swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM) # Make the swap across fake_H and var_ref
swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM) swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM)
if swap_x + swap_w > w: swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM)
swap_w = w - swap_x if swap_x + swap_w > w:
if swap_y + swap_h > h: swap_w = w - swap_x
swap_h = h - swap_y if swap_y + swap_h > h:
t = fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone() swap_h = h - swap_y
fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] t = fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0 var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0 real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
# Interpolate down to the dimensionality that the discriminator uses. # Interpolate down to the dimensionality that the discriminator uses.
real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear") real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear")

View File

@ -26,8 +26,10 @@ class VGGFeatureExtractor(nn.Module):
for k, v in self.features.named_parameters(): for k, v in self.features.named_parameters():
v.requires_grad = False v.requires_grad = False
def forward(self, x): def forward(self, x, interpolate_factor=1):
# Assume input range is [0, 1] if interpolate_factor > 1:
x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic')
if self.use_input_norm: if self.use_input_norm:
x = (x - self.mean) / self.std x = (x - self.mean) / self.std
output = self.features(x) output = self.features(x)

View File

@ -168,7 +168,7 @@ def define_fixed_D(opt):
# Define network used for perceptual loss # Define network used for perceptual loss
def define_F(opt, use_bn=False, for_training=False): def define_F(opt, use_bn=False, for_training=False, load_path=None):
gpu_ids = opt['gpu_ids'] gpu_ids = opt['gpu_ids']
device = torch.device('cuda' if gpu_ids else 'cpu') device = torch.device('cuda' if gpu_ids else 'cpu')
if 'which_model_F' not in opt['train'].keys() or opt['train']['which_model_F'] == 'vgg': if 'which_model_F' not in opt['train'].keys() or opt['train']['which_model_F'] == 'vgg':
@ -186,5 +186,21 @@ def define_F(opt, use_bn=False, for_training=False):
elif opt['train']['which_model_F'] == 'wide_resnet': elif opt['train']['which_model_F'] == 'wide_resnet':
netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True, device=device) netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True, device=device)
netF.eval() # No need to train if load_path:
# Load the model parameters:
load_net = torch.load(load_path)
load_net_clean = OrderedDict() # remove unnecessary 'module.'
for k, v in load_net.items():
if k.startswith('module.'):
load_net_clean[k[7:]] = v
else:
load_net_clean[k] = v
netF.load_state_dict(load_net_clean)
# Put into eval mode, freeze the parameters and set the 'weight' field.
netF.eval()
for k, v in netF.named_parameters():
v.requires_grad = False
netF.fdisc_weight = opt['weight']
return netF return netF

View File

@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
def main(): def main():
#### options #### options
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_feature_net.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg4_lr_feat.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher') help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)