From 4b4d08bdecf1eebe2a8102c8cc6deb725373e458 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 31 Aug 2020 09:41:48 -0600 Subject: [PATCH] Enable testing in ExtensibleTrainer, fix it in SRGAN_model Also compute fea loss for this. --- codes/models/ExtensibleTrainer.py | 121 ++++++++++++++++-------------- codes/models/SRGAN_model.py | 3 + codes/models/networks.py | 2 +- codes/models/steps/steps.py | 7 +- codes/test.py | 17 +++-- codes/train.py | 2 +- codes/train2.py | 2 +- 7 files changed, 86 insertions(+), 68 deletions(-) diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 919a945a..aa6dd288 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -30,6 +30,11 @@ class ExtensibleTrainer(BaseModel): 'opt': opt, 'step': 0} + self.mega_batch_factor = 1 + if self.is_train: + self.mega_batch_factor = train_opt['mega_batch_factor'] + self.env['mega_batch_factor'] = self.mega_batch_factor + self.netsG = {} self.netsD = {} self.netF = networks.define_F().to(self.device) # Used to compute feature loss. @@ -49,71 +54,68 @@ class ExtensibleTrainer(BaseModel): step = ConfigurableStep(step, self.env) self.steps.append(step) + # The steps rely on the networks being placed in the env, so put them there. Even though they arent wrapped + # yet. + self.env['generators'] = self.netsG + self.env['discriminators'] = self.netsD + + # Define the optimizers from the steps + for s in self.steps: + s.define_optimizers() + self.optimizers.extend(s.get_optimizers()) + if self.is_train: - self.mega_batch_factor = train_opt['mega_batch_factor'] - if self.mega_batch_factor is None: - self.mega_batch_factor = 1 - self.env['mega_batch_factor'] = self.mega_batch_factor - - # The steps rely on the networks being placed in the env, so put them there. Even though they arent wrapped - # yet. - self.env['generators'] = self.netsG - self.env['discriminators'] = self.netsD - - # Define the optimizers from the steps - for s in self.steps: - s.define_optimizers() - self.optimizers.extend(s.get_optimizers()) - # Find the optimizers that are using the default scheduler, then build them. def_opt = [] for s in self.steps: def_opt.extend(s.get_optimizers_with_default_scheduler()) self.schedulers = lr_scheduler.get_scheduler_for_name(train_opt['default_lr_scheme'], def_opt, train_opt) + else: + self.schedulers = [] - # Initialize amp. - total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()] - amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps, - self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps'])) + # Initialize amp. + total_nets = [g for g in self.netsG.values()] + [d for d in self.netsD.values()] + amp_nets, amp_opts = amp.initialize(total_nets + [self.netF] + self.steps, + self.optimizers, opt_level=opt['amp_opt_level'], num_losses=len(opt['steps'])) - # Unwrap steps & netF - self.netF = amp_nets[len(total_nets)] - assert(len(self.steps) == len(amp_nets[len(total_nets)+1:])) - self.steps = amp_nets[len(total_nets)+1:] - amp_nets = amp_nets[:len(total_nets)] + # Unwrap steps & netF + self.netF = amp_nets[len(total_nets)] + assert(len(self.steps) == len(amp_nets[len(total_nets)+1:])) + self.steps = amp_nets[len(total_nets)+1:] + amp_nets = amp_nets[:len(total_nets)] - # DataParallel - dnets = [] - for anet in amp_nets: - if opt['dist']: - dnet = DistributedDataParallel(anet, - device_ids=[torch.cuda.current_device()], - find_unused_parameters=True) - else: - dnet = DataParallel(anet) - if self.is_train: - dnet.train() - else: - dnet.eval() - dnets.append(dnet) - if not opt['dist']: - self.netF = DataParallel(self.netF) + # DataParallel + dnets = [] + for anet in amp_nets: + if opt['dist']: + dnet = DistributedDataParallel(anet, + device_ids=[torch.cuda.current_device()], + find_unused_parameters=True) + else: + dnet = DataParallel(anet) + if self.is_train: + dnet.train() + else: + dnet.eval() + dnets.append(dnet) + if not opt['dist']: + self.netF = DataParallel(self.netF) - # Backpush the wrapped networks into the network dicts.. - self.networks = {} - found = 0 - for dnet in dnets: - for net_dict in [self.netsD, self.netsG]: - for k, v in net_dict.items(): - if v == dnet.module: - net_dict[k] = dnet - self.networks[k] = dnet - found += 1 - assert found == len(self.netsG) + len(self.netsD) + # Backpush the wrapped networks into the network dicts.. + self.networks = {} + found = 0 + for dnet in dnets: + for net_dict in [self.netsD, self.netsG]: + for k, v in net_dict.items(): + if v == dnet.module: + net_dict[k] = dnet + self.networks[k] = dnet + found += 1 + assert found == len(self.netsG) + len(self.netsD) - # Replace the env networks with the wrapped networks - self.env['generators'] = self.netsG - self.env['discriminators'] = self.netsD + # Replace the env networks with the wrapped networks + self.env['generators'] = self.netsG + self.env['discriminators'] = self.netsD self.print_network() # print network self.load() # load G and D if needed @@ -121,7 +123,12 @@ class ExtensibleTrainer(BaseModel): # Setting this to false triggers SRGAN to call the models update_model() function on the first iteration. self.updated = True - def feed_data(self, data): + def feed_data(self, data, need_GT=False): + self.eval_state = {} + for o in self.optimizers: + o.zero_grad() + torch.cuda.empty_cache() + self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0) self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)] input_ref = data['ref'] if 'ref' in data else data['GT'] @@ -206,7 +213,9 @@ class ExtensibleTrainer(BaseModel): for k, v in ns.items(): state[k] = [v] - self.eval_state = state + self.eval_state = {} + for k, v in state.items(): + self.eval_state[k] = [s.detach().cpu() if isinstance(s, torch.Tensor) else s for s in v] for net in self.netsG.values(): net.train() diff --git a/codes/models/SRGAN_model.py b/codes/models/SRGAN_model.py index d1f4b89a..4fa72fbc 100644 --- a/codes/models/SRGAN_model.py +++ b/codes/models/SRGAN_model.py @@ -352,6 +352,9 @@ class SRGANModel(BaseModel): self.gan_lq_img_use_prob = train_opt['gan_lowres_use_probability'] if train_opt['gan_lowres_use_probability'] else 0 self.img_debug_steps = opt['logger']['img_debug_steps'] if 'img_debug_steps' in opt['logger'].keys() else 50 + else: + self.netF = networks.define_F(use_bn=False).to(self.device) + self.cri_fea = nn.L1Loss().to(self.device) #self.print_network() # print network self.load() # load G and D if needed diff --git a/codes/models/networks.py b/codes/models/networks.py index 6ebb7935..31dbce05 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -181,7 +181,7 @@ def define_D_net(opt_net, img_sz=None, wrap=False): netD = SRGAN_arch.Discriminator_switched(in_nc=opt_net['in_nc'], nf=opt_net['nf'], initial_temp=opt_net['initial_temp'], final_temperature_step=opt_net['final_temperature_step']) elif which_model == "cross_compare_vgg128": - netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'], nf=opt_net['nf'], scale=opt_net['scale']) + netD = SRGAN_arch.CrossCompareDiscriminator(in_nc=opt_net['in_nc'], ref_channels=opt_net['ref_channels'] if 'ref_channels' in opt_net.keys() else 3, nf=opt_net['nf'], scale=opt_net['scale']) else: raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) return netD diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index 51bf8d8c..15331e5a 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -30,9 +30,10 @@ class ConfigurableStep(Module): losses = [] self.weights = {} - for loss_name, loss in self.step_opt['losses'].items(): - losses.append((loss_name, create_generator_loss(loss, env))) - self.weights[loss_name] = loss['weight'] + if 'losses' in self.step_opt.keys(): + for loss_name, loss in self.step_opt['losses'].items(): + losses.append((loss_name, create_generator_loss(loss, env))) + self.weights[loss_name] = loss['weight'] self.losses = OrderedDict(losses) # Subclasses should override this to define individual optimizers. They should all go into self.optimizers. diff --git a/codes/test.py b/codes/test.py index 7f3fdd11..7d418ceb 100644 --- a/codes/test.py +++ b/codes/test.py @@ -61,10 +61,8 @@ def forward_pass(model, output_dir, alteration_suffix=''): model.feed_data(data, need_GT=need_GT) model.test() - if isinstance(model.fake_GenOut[0], tuple): - visuals = model.fake_GenOut[0][0].detach().float().cpu() - else: - visuals = model.fake_GenOut[0].detach().float().cpu() + visuals = model.get_current_visuals()['rlt'].cpu() + fea_loss = 0 for i in range(visuals.shape[0]): img_path = data['GT_path'][i] if need_GT else data['LQ_path'][i] img_name = osp.splitext(osp.basename(img_path))[0] @@ -78,7 +76,10 @@ def forward_pass(model, output_dir, alteration_suffix=''): else: save_img_path = osp.join(output_dir, img_name + '.png') + fea_loss += model.compute_fea_loss(visuals[i], data['GT'][i]) + util.save_img(sr_img, save_img_path) + return fea_loss if __name__ == "__main__": @@ -87,7 +88,7 @@ if __name__ == "__main__": want_just_images = True srg_analyze = False parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='../options/analyze_srg.yml') + parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/srgan_compute_feature.yml') opt = option.parse(parser.parse_args().opt, is_train=False) opt = option.dict_to_nonedict(opt) @@ -108,6 +109,7 @@ if __name__ == "__main__": test_loaders.append(test_loader) model = create_model(opt) + fea_loss = 0 for test_loader in test_loaders: test_set_name = test_loader.dataset.opt['name'] logger.info('\nTesting [{:s}]...'.format(test_set_name)) @@ -143,4 +145,7 @@ if __name__ == "__main__": model_copy.load_state_dict(orig_model.state_dict()) model.netG = model_copy else: - forward_pass(model, dataset_dir, opt['name']) + fea_loss += forward_pass(model, dataset_dir, opt['name']) + + # log + logger.info('# Validation # Fea: {:.4e}'.format(fea_loss / len(test_loader))) diff --git a/codes/train.py b/codes/train.py index 1ea05fa8..82ae665f 100644 --- a/codes/train.py +++ b/codes/train.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_xlbatch_ragan.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/srgan_compute_feature.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) diff --git a/codes/train2.py b/codes/train2.py index 112cc83b..832d68bb 100644 --- a/codes/train2.py +++ b/codes/train2.py @@ -32,7 +32,7 @@ def init_dist(backend='nccl', **kwargs): def main(): #### options parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_spsr_switched2_fullimgref.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_mi1_spsr_switched2_fullimgref.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0)