From 981d64413b0efb848436a6e3e50e996e99bff483 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 19 Oct 2020 11:01:56 -0600
Subject: [PATCH] Support validation over a custom injector

Also re-enable PSNR
---
 codes/models/ExtensibleTrainer.py | 33 ++++++++++++++++++++-----------
 codes/train.py                    | 32 +++++++++++-------------------
 2 files changed, 34 insertions(+), 31 deletions(-)

diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index b12eb7fe..8f4e2e9e 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -10,6 +10,7 @@ from torch.nn.parallel.distributed import DistributedDataParallel
 import models.lr_scheduler as lr_scheduler
 import models.networks as networks
 from models.base_model import BaseModel
+from models.steps.injectors import create_injector
 from models.steps.steps import ConfigurableStep
 from models.experiments.experiments import get_experiment_for_name
 import torchvision.utils as utils
@@ -155,7 +156,7 @@ class ExtensibleTrainer(BaseModel):
             o.zero_grad()
         torch.cuda.empty_cache()
 
-        self.lq = torch.chunk(data['LQ'].to(self.device), chunks=self.mega_batch_factor, dim=0)
+        self.lq = [t.to(self.device) for t in torch.chunk(data['LQ'], chunks=self.mega_batch_factor, dim=0)]
         if need_GT:
             self.hq = [t.to(self.device) for t in torch.chunk(data['GT'], chunks=self.mega_batch_factor, dim=0)]
             input_ref = data['ref'] if 'ref' in data.keys() else data['GT']
@@ -260,19 +261,29 @@ class ExtensibleTrainer(BaseModel):
             net.eval()
 
         with torch.no_grad():
-            # Iterate through the steps, performing them one at a time.
-            state = self.dstate
-            for step_num, s in enumerate(self.steps):
-                ns = s.do_forward_backward(state, 0, step_num, train=False)
-                for k, v in ns.items():
-                    state[k] = [v]
+            # This can happen one of two ways: Either a 'validation injector' is provided, in which case we run that.
+            # Or, we run the entire chain of steps in "train" mode and use eval.output_state.
+            if 'injector' in self.opt['eval'].keys():
+                # Need to move from mega_batch mode to batch mode (remove chunks)
+                state = {}
+                for k, v in self.dstate.items():
+                    state[k] = v[0]
+                inj = create_injector(self.opt['eval']['injector'], self.env)
+                state.update(inj(state))
+            else:
+                # Iterate through the steps, performing them one at a time.
+                state = self.dstate
+                for step_num, s in enumerate(self.steps):
+                    ns = s.do_forward_backward(state, 0, step_num, train=False)
+                    for k, v in ns.items():
+                        state[k] = [v]
 
             self.eval_state = {}
             for k, v in state.items():
-                self.eval_state[k] = [s.detach().cpu() if isinstance(s, torch.Tensor) else s for s in v]
-
-            # For backwards compatibility..
-            self.fake_H = self.eval_state[self.opt['eval']['output_state']][0].float().cpu()
+                if isinstance(v, list):
+                    self.eval_state[k] = [s.detach().cpu() if isinstance(s, torch.Tensor) else s for s in v]
+                else:
+                    self.eval_state[k] = [v.detach().cpu() if isinstance(v, torch.Tensor) else v]
 
         for net in self.netsG.values():
             net.train()
diff --git a/codes/train.py b/codes/train.py
index 0c7de24b..efc6e378 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -30,7 +30,7 @@ def init_dist(backend='nccl', **kwargs):
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_imgset_chained.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_spsr7.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
@@ -185,10 +185,6 @@ def main():
                 print("Data fetch: %f" % (time() - _t))
                 _t = time()
 
-            #tb_logger.add_graph(model.netsG['generator'].module, [train_data['LQ'].to('cuda'),
-            #                                                      train_data['lq_fullsize_ref'].float().to('cuda'),
-            #                                                      train_data['lq_center'].to('cuda')])
-
             current_step += 1
             if current_step > total_iters:
                 break
@@ -241,9 +237,6 @@ def main():
            #### validation
            if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
                if opt['model'] in ['sr', 'srgan', 'corruptgan', 'spsrgan', 'extensibletrainer'] and rank <= 0:  # image restoration validation
-                   model.force_restore_swapout()
-                   val_batch_sz = 1 if 'batch_size' not in opt['datasets']['val'].keys() else opt['datasets']['val']['batch_size']
-                   # does not support multi-GPU validation
                    avg_psnr = 0.
                    avg_fea_loss = 0.
                    idx = 0
@@ -263,23 +256,22 @@
                        if visuals is None: continue
 
+                       if colab_mode:
+                           colab_imgs_to_copy.append(save_img_path)
+
+                       # calculate PSNR
                        sr_img = util.tensor2img(visuals['rlt'][b])  # uint8
-                       #gt_img = util.tensor2img(visuals['GT'][b])  # uint8
+                       gt_img = util.tensor2img(visuals['GT'][b])  # uint8
+                       sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
+                       avg_psnr += util.calculate_psnr(sr_img, gt_img)
+
+                       # calculate fea loss
+                       avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
 
                        # Save SR images for reference
                        img_base_name = '{:s}_{:d}.png'.format(img_name, current_step)
                        save_img_path = os.path.join(img_dir, img_base_name)
                        util.save_img(sr_img, save_img_path)
 
-                       if colab_mode:
-                           colab_imgs_to_copy.append(save_img_path)
-
-                       # calculate PSNR (Naw - don't do that. PSNR sucks)
-                       #sr_img, gt_img = util.crop_border([sr_img, gt_img], opt['scale'])
-                       #avg_psnr += util.calculate_psnr(sr_img, gt_img)
-                       #pbar.update('Test {}'.format(img_name))
-
-                       # calculate fea loss
-                       avg_fea_loss += model.compute_fea_loss(visuals['rlt'][b], visuals['GT'][b])
 
                    if colab_mode:
                        util.copy_files_to_server(opt['ssh_server'], opt['ssh_username'], opt['ssh_password'],
@@ -293,7 +285,7 @@ def main():
                logger.info('# Validation # PSNR: {:.4e} Fea: {:.4e}'.format(avg_psnr, avg_fea_loss))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name'] and rank <= 0:
-                   #tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
+                   tb_logger.add_scalar('val_psnr', avg_psnr, current_step)
                    tb_logger.add_scalar('val_fea', avg_fea_loss, current_step)
 
            if rank <= 0:
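
Note on usage (not part of the patch itself): the new eval path only assumes that opt['eval'] carries an 'injector' definition, which create_injector() turns into a callable whose output dict is merged into the de-chunked state. As a rough sketch of the options this code would consume, the parsed 'eval' section might look like the Python fragment below; the key names inside 'injector' ('type', 'generator', 'in', 'out') and the state key names ('lq', 'gen') are illustrative assumptions, not values this diff defines.

    # Hypothetical parsed options fragment for ExtensibleTrainer -- the inner key
    # names are assumptions about how injectors are configured, not specified here.
    opt_eval = {
        'injector': {
            'type': 'generator',       # injector implementation chosen by create_injector()
            'generator': 'generator',  # network the injector should invoke
            'in': 'lq',                # state key fed to the network
            'out': 'gen',              # state key the injector writes back
        },
        'output_state': 'gen',         # key reported back as the validation result
    }

    # With such a config, the test() path sketched above roughly reduces to:
    #   state = {k: v[0] for k, v in self.dstate.items()}  # drop mega-batch chunking
    #   inj = create_injector(opt_eval['injector'], self.env)
    #   state.update(inj(state))
    # instead of replaying every configured training step during validation.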