From eb11a08d1c272058c04b9b661bc4e1ff6cce99d8 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 31 Jul 2020 16:29:47 -0600
Subject: [PATCH] Enable disjoint feature networks

This is done by pre-training a feature net that predicts the features
of HR images from LR images. The original feature network and this new
one are then used in tandem, operating only on LR/generated images.
---
 codes/data/LQGT_dataset.py         | 10 +---
 codes/models/SRGAN_model.py        | 77 ++++++++++++++++++------------
 codes/models/archs/feature_arch.py |  6 ++-
 codes/models/networks.py           | 20 +++++++-
 codes/train.py                     |  2 +-
 5 files changed, 72 insertions(+), 43 deletions(-)

diff --git a/codes/data/LQGT_dataset.py b/codes/data/LQGT_dataset.py
index 8d0ec157..ca872e2e 100644
--- a/codes/data/LQGT_dataset.py
+++ b/codes/data/LQGT_dataset.py
@@ -15,10 +15,9 @@ class LQGTDataset(data.Dataset):
     Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
     If only GT images are provided, generate LQ images on-the-fly.
     """
-
     def get_lq_path(self, i):
         which_lq = random.randint(0, len(self.paths_LQ)-1)
-        return self.paths_LQ[which_lq][i]
+        return self.paths_LQ[which_lq][i % len(self.paths_LQ[which_lq])]
 
     def __init__(self, opt):
         super(LQGTDataset, self).__init__()
@@ -53,11 +52,6 @@ class LQGTDataset(data.Dataset):
             print('loaded %i images for use in training GAN only.' % (self.sizes_GAN,))
 
         assert self.paths_GT, 'Error: GT path is empty.'
-        if self.paths_LQ and self.paths_GT:
-            assert len(self.paths_LQ[0]) == len(
-                self.paths_GT
-            ), 'GT and LQ datasets have different number of images - {}, {}.'.format(
-                len(self.paths_LQ[0]), len(self.paths_GT))
         self.random_scale_list = [1]
 
     def _init_lmdb(self):
@@ -85,7 +79,7 @@ class LQGTDataset(data.Dataset):
         GT_size = self.opt['target_size']
 
         # get GT image
-        GT_path = self.paths_GT[index]
+        GT_path = self.paths_GT[index % len(self.paths_GT)]
         resolution = [int(s) for s in self.sizes_GT[index].split('_')
                       ] if self.data_type == 'lmdb' else None
         img_GT = util.read_img(self.GT_env, GT_path, resolution)
diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py
index 786181e4..114c3546 100644
--- a/codes/models/SRGAN_model.py
+++ b/codes/models/SRGAN_model.py
@@ -42,6 +42,7 @@ class SRGANModel(BaseModel):
         else:
             self.netC = None
         self.mega_batch_factor = 1
+        self.disjoint_data = False
 
         # define losses, optimizer and scheduler
         if self.is_train:
@@ -101,16 +102,28 @@
                 self.cri_fea = None
             if self.cri_fea:  # load VGG perceptual loss
                 self.netF = networks.define_F(opt, use_bn=False).to(self.device)
+                self.lr_netF = None
+                if 'lr_fea_path' in train_opt.keys():
+                    self.lr_netF = networks.define_F(opt, use_bn=False, load_path=train_opt['lr_fea_path']).to(self.device)
+                    self.disjoint_data = True
+
                 if opt['dist']:
                     pass  # do not need to use DistributedDataParallel for netF
                 else:
                     self.netF = DataParallel(self.netF)
+                    if self.lr_netF:
+                        self.lr_netF = DataParallel(self.lr_netF)
 
             # You can feed in a list of frozen pre-trained discriminators. These are treated the same as feature losses.
             self.fixed_disc_nets = []
             if 'fixed_discriminators' in opt.keys():
                 for opt_fdisc in opt['fixed_discriminators'].keys():
-                    self.fixed_disc_nets.append(networks.define_fixed_D(opt['fixed_discriminators'][opt_fdisc]).to(self.device))
+                    netFD = networks.define_fixed_D(opt['fixed_discriminators'][opt_fdisc]).to(self.device)
+                    if opt['dist']:
+                        pass  # do not need to use DistributedDataParallel for netF
+                    else:
+                        netFD = DataParallel(netFD)
+                    self.fixed_disc_nets.append(netFD)
 
             # GD gan loss
             self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
@@ -330,7 +343,10 @@ class SRGANModel(BaseModel):
                 l_g_fdpl = self.cri_fdpl(fea_GenOut, pix)
                 l_g_total += l_g_fdpl * self.fdpl_weight
             if self.cri_fea and not using_gan_img:  # feature loss
-                real_fea = self.netF(pix).detach()
+                if self.lr_netF is not None:
+                    real_fea = self.lr_netF(var_L, interpolate_factor=self.opt['scale'])
+                else:
+                    real_fea = self.netF(pix).detach()
                 fake_fea = self.netF(fea_GenOut)
                 fea_w = self.l_fea_sched.get_weight_for_step(step)
                 l_g_fea = fea_w * self.cri_fea(fake_fea, real_fea)
@@ -346,7 +362,7 @@ class SRGANModel(BaseModel):
                 # equal to this value. If I ever come up with an algorithm that tunes fea/gan weights automatically,
                 # it should target this
 
-                l_g_fix_disc = 0
+                l_g_fix_disc = torch.zeros(1, requires_grad=False).squeeze()
                 for fixed_disc in self.fixed_disc_nets:
                     weight = fixed_disc.fdisc_weight
                     real_fea = fixed_disc(pix).detach()
@@ -439,33 +455,34 @@ class SRGANModel(BaseModel):
                     with amp.scale_loss(l_d_fake, self.optimizer_D, loss_id=1) as l_d_fake_scaled:
                         l_d_fake_scaled.backward()
                 if 'pixgan' in self.opt['train']['gan_type']:
-                    # randomly determine portions of the image to swap to keep the discriminator honest.
-                    pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
-                    disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
-                    b, _, w, h = var_ref.shape
-                    real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
-                    fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device)
-                    SWAP_MAX_DIM = w // 4
-                    SWAP_MIN_DIM = 16
-                    assert SWAP_MAX_DIM > 0
-                    if random.random() > .5:  # Make this only happen half the time. Earlier experiments had it happen
-                                              # more often and the model was "cheating" by using the presence of
-                                              # easily discriminated fake swaps to count the entire generated image
-                                              # as fake.
-                        random_swap_count = random.randint(0, 4)
-                        for i in range(random_swap_count):
-                            # Make the swap across fake_H and var_ref
-                            swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM)
-                            swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM)
-                            if swap_x + swap_w > w:
-                                swap_w = w - swap_x
-                            if swap_y + swap_h > h:
-                                swap_h = h - swap_y
-                            t = fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
-                            fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
-                            var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
-                            real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
-                            fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
+                    if not self.disjoint_data:
+                        # randomly determine portions of the image to swap to keep the discriminator honest.
+                        pixdisc_channels, pixdisc_output_reduction = self.netD.module.pixgan_parameters()
+                        disc_output_shape = (var_ref.shape[0], pixdisc_channels, var_ref.shape[2] // pixdisc_output_reduction, var_ref.shape[3] // pixdisc_output_reduction)
+                        b, _, w, h = var_ref.shape
+                        real = torch.ones((b, pixdisc_channels, w, h), device=var_ref.device)
+                        fake = torch.zeros((b, pixdisc_channels, w, h), device=var_ref.device)
+                        SWAP_MAX_DIM = w // 4
+                        SWAP_MIN_DIM = 16
+                        assert SWAP_MAX_DIM > 0
+                        if random.random() > .5:  # Make this only happen half the time. Earlier experiments had it happen
+                                                  # more often and the model was "cheating" by using the presence of
+                                                  # easily discriminated fake swaps to count the entire generated image
+                                                  # as fake.
+                            random_swap_count = random.randint(0, 4)
+                            for i in range(random_swap_count):
+                                # Make the swap across fake_H and var_ref
+                                swap_x, swap_y = random.randint(0, w - SWAP_MIN_DIM), random.randint(0, h - SWAP_MIN_DIM)
+                                swap_w, swap_h = random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM), random.randint(SWAP_MIN_DIM, SWAP_MAX_DIM)
+                                if swap_x + swap_w > w:
+                                    swap_w = w - swap_x
+                                if swap_y + swap_h > h:
+                                    swap_h = h - swap_y
+                                t = fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)].clone()
+                                fake_H[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)]
+                                var_ref[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = t
+                                real[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 0.0
+                                fake[:, :, swap_x:(swap_x+swap_w), swap_y:(swap_y+swap_h)] = 1.0
 
                     # Interpolate down to the dimensionality that the discriminator uses.
                     real = F.interpolate(real, size=disc_output_shape[2:], mode="bilinear")
diff --git a/codes/models/archs/feature_arch.py b/codes/models/archs/feature_arch.py
index fbe87eb1..6d182231 100644
--- a/codes/models/archs/feature_arch.py
+++ b/codes/models/archs/feature_arch.py
@@ -26,8 +26,10 @@ class VGGFeatureExtractor(nn.Module):
         for k, v in self.features.named_parameters():
             v.requires_grad = False
 
-    def forward(self, x):
-        # Assume input range is [0, 1]
+    def forward(self, x, interpolate_factor=1):
+        if interpolate_factor > 1:
+            x = F.interpolate(x, scale_factor=interpolate_factor, mode='bicubic')
+
         if self.use_input_norm:
             x = (x - self.mean) / self.std
         output = self.features(x)
diff --git a/codes/models/networks.py b/codes/models/networks.py
index f8df741b..82d50910 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -168,7 +168,7 @@ def define_fixed_D(opt):
 
 
 # Define network used for perceptual loss
-def define_F(opt, use_bn=False, for_training=False):
+def define_F(opt, use_bn=False, for_training=False, load_path=None):
     gpu_ids = opt['gpu_ids']
     device = torch.device('cuda' if gpu_ids else 'cpu')
     if 'which_model_F' not in opt['train'].keys() or opt['train']['which_model_F'] == 'vgg':
@@ -186,5 +186,21 @@ def define_F(opt, use_bn=False, for_training=False):
     elif opt['train']['which_model_F'] == 'wide_resnet':
         netF = feature_arch.WideResnetFeatureExtractor(use_input_norm=True, device=device)
 
-    netF.eval()  # No need to train
+    if load_path:
+        # Load the model parameters:
+        load_net = torch.load(load_path)
+        load_net_clean = OrderedDict()  # remove unnecessary 'module.'
+        for k, v in load_net.items():
+            if k.startswith('module.'):
+                load_net_clean[k[7:]] = v
+            else:
+                load_net_clean[k] = v
+        netF.load_state_dict(load_net_clean)
+
+    # Put into eval mode, freeze the parameters and set the 'weight' field.
+    netF.eval()
+    for k, v in netF.named_parameters():
+        v.requires_grad = False
+    netF.fdisc_weight = opt['weight']
+
     return netF
diff --git a/codes/train.py b/codes/train.py
index 1f2cfb6b..e8e5ec53 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_feature_net.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_pixgan_srg4_lr_feat.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
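
Note on the new perceptual-loss path: the sketch below summarizes what the
patch computes once 'lr_fea_path' is set. It is illustrative only, not code
from this patch; the compute_feature_loss wrapper is a hypothetical name,
while netF, lr_netF, cri_fea, var_L, fea_GenOut and pix mirror the variables
in SRGAN_model.py (LR input, generator output and HR ground truth).

    def compute_feature_loss(netF, lr_netF, cri_fea, var_L, fea_GenOut, pix, scale):
        # Hypothetical standalone wrapper around the logic added in this patch.
        if lr_netF is not None:
            # Disjoint mode: "real" features are predicted from the LR input by
            # the pre-trained LR feature net, which upsamples var_L by `scale`
            # internally (see VGGFeatureExtractor.forward above).
            real_fea = lr_netF(var_L, interpolate_factor=scale)
        else:
            # Paired mode: standard perceptual loss against the HR ground truth.
            real_fea = netF(pix).detach()
        fake_fea = netF(fea_GenOut)  # features of the generator output
        return cri_fea(fake_fea, real_fea)

Because both feature stacks then only ever see LR-derived or generated
images, the LR and GT sets no longer need to be pixel-aligned (hence the
removed dataset-length assertion); this is also why the pixgan swap
augmentation is skipped when disjoint_data is set, since var_ref no longer
aligns with fake_H.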