From 74738489b92afdf34bbb5f5a84194febf8e9a61a Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 30 Oct 2020 09:59:54 -0600
Subject: [PATCH] Fixes and additional support for progressive zoom

---
 codes/data/multiscale_dataset.py        |  2 +-
 codes/models/ExtensibleTrainer.py       | 13 +++++++------
 codes/models/archs/RRDBNet_arch.py      |  2 ++
 codes/models/steps/progressive_zoom.py  | 20 +++++++++++++++-----
 codes/scripts/validate_data.py          |  7 ++++++-
 codes/train.py                          |  2 +-
 codes/utils/options.py                  |  7 -------
 7 files changed, 32 insertions(+), 21 deletions(-)

diff --git a/codes/data/multiscale_dataset.py b/codes/data/multiscale_dataset.py
index cc5aa78a..1a176367 100644
--- a/codes/data/multiscale_dataset.py
+++ b/codes/data/multiscale_dataset.py
@@ -22,7 +22,7 @@ class MultiScaleDataset(data.Dataset):
         self.num_scales = self.opt['num_scales']
         self.hq_size_cap = self.tile_size * 2 ** self.num_scales
         self.scale = self.opt['scale']
-        self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1])
+        self.paths_hq, self.sizes_hq = util.get_image_paths(self.data_type, opt['paths'], [1 for _ in opt['paths']])
         self.corruptor = ImageCorruptor(opt)
 
     # Selects the smallest dimension from the image and crops it randomly so the other dimension matches. The cropping
diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index d8485273..c28cfde7 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -267,13 +267,14 @@ class ExtensibleTrainer(BaseModel):
         with torch.no_grad():
             # This can happen one of two ways: Either a 'validation injector' is provided, in which case we run that.
             # Or, we run the entire chain of steps in "train" mode and use eval.output_state.
-            if 'injector' in self.opt['eval'].keys():
-                # Need to move from mega_batch mode to batch mode (remove chunks)
+            if 'injectors' in self.opt['eval'].keys():
                 state = {}
-                for k, v in self.dstate.items():
-                    state[k] = v[0]
-                inj = create_injector(self.opt['eval']['injector'], self.env)
-                state.update(inj(state))
+                for inj in self.opt['eval']['injectors'].values():
+                    # Need to move from mega_batch mode to batch mode (remove chunks)
+                    for k, v in self.dstate.items():
+                        state[k] = v[0]
+                    inj = create_injector(inj, self.env)
+                    state.update(inj(state))
             else:
                 # Iterate through the steps, performing them one at a time.
                 state = self.dstate
diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py
index 42a04770..e08a3353 100644
--- a/codes/models/archs/RRDBNet_arch.py
+++ b/codes/models/archs/RRDBNet_arch.py
@@ -188,6 +188,8 @@ class RRDBNet(nn.Module):
             if ref is None:
                 ref = torch.zeros_like(x_lg)
             x_lg = torch.cat([x_lg, ref], dim=1)
+        else:
+            x_lg = x
         feat = self.conv_first(x_lg)
         body_feat = self.conv_body(checkpoint_sequential(self.body, self.num_blocks // self.blocks_per_checkpoint, feat))
         feat = feat + body_feat
diff --git a/codes/models/steps/progressive_zoom.py b/codes/models/steps/progressive_zoom.py
index f4d047f6..2b27cf7b 100644
--- a/codes/models/steps/progressive_zoom.py
+++ b/codes/models/steps/progressive_zoom.py
@@ -27,8 +27,12 @@ class ProgressiveGeneratorInjector(Injector):
         self.hq_output_key = opt['hq_output']  # The key where HQ images corresponding with generated images are stored.
         self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
         self.output_hq_index = opt['output_hq_index']
-        self.recurrent_output_index = opt['recurrent_output_index']
-        self.recurrent_index = opt['recurrent_index']
+        if 'recurrent_output_index' in opt.keys():
+            self.recurrent_output_index = opt['recurrent_output_index']
+            self.recurrent_index = opt['recurrent_index']
+            self.recurrence = True
+        else:
+            self.recurrence = False
         self.depth = opt['depth']
         self.number_branches = opt['num_branches']  # Number of input branches to randomly choose for generation. This defines the output shape.
         self.multiscale_leaves = build_multiscale_patch_index_map(self.depth)
@@ -52,7 +56,8 @@ class ProgressiveGeneratorInjector(Injector):
 
     def feed_forward(self, gen, inputs, results, lq_input, recurrent_input):
         ff_input = inputs.copy()
         ff_input[self.input_lq_index] = lq_input
-        ff_input[self.recurrent_index] = recurrent_input
+        if self.recurrence:
+            ff_input[self.recurrent_index] = recurrent_input
         with autocast(enabled=self.env['opt']['fp16']):
             gen_out = gen(*ff_input)
@@ -61,7 +66,10 @@ class ProgressiveGeneratorInjector(Injector):
             gen_out = [gen_out]
         for i, out_key in enumerate(self.output):
             results[out_key].append(gen_out[i])
-        return gen_out[self.output_hq_index], gen_out[self.recurrent_output_index]
+        recurrent = None
+        if self.recurrence:
+            recurrent = gen_out[self.recurrent_output_index]
+        return gen_out[self.output_hq_index], recurrent
 
     def forward(self, state):
         gen = self.env['generators'][self.gen_key]
@@ -73,6 +81,7 @@ class ProgressiveGeneratorInjector(Injector):
             inputs = [inputs]
         if not isinstance(self.output, list):
             output = [self.output]
+            self.output = output
         results = {}  # A list of outputs produced by feeding each progressive lq input into the generator.
         results_hq = []
         for out_key in output:
@@ -91,7 +100,8 @@ class ProgressiveGeneratorInjector(Injector):
                 for link in chain:  # Remember, `link` is a MultiscaleTreeNode.
                     top = int(link.top * h)
                     left = int(link.left * w)
-                    recurrent = torch.nn.functional.interpolate(recurrent[:, :, top:top+h//2, left:left+w//2], scale_factor=2, mode="nearest")
+                    if recurrent is not None:
+                        recurrent = torch.nn.functional.interpolate(recurrent[:, :, top:top+h//2, left:left+w//2], scale_factor=2, mode="nearest")
                     if self.feed_gen_output_into_input:
                         top *= 2
                         left *= 2
diff --git a/codes/scripts/validate_data.py b/codes/scripts/validate_data.py
index c2fb28cd..e88d4f54 100644
--- a/codes/scripts/validate_data.py
+++ b/codes/scripts/validate_data.py
@@ -13,7 +13,7 @@ from skimage import io
 def main():
     #### options
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_imgset_spsr_switched2_xlbatch.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../../options/train_prog_mi1_rrdb_6bypass.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
@@ -50,6 +50,7 @@ def main():
             len(train_set), train_size))
     assert train_loader is not None
 
+    '''
     tq_ldr = tqdm(train_set.get_paths())
     for path in tq_ldr:
         try:
@@ -58,6 +59,10 @@ def main():
         except Exception as e:
             print("Error with %s" % (path,))
             print(e)
+    '''
+    tq_ldr = tqdm(train_set)
+    for ds in tq_ldr:
+        pass
 
 
 if __name__ == '__main__':
diff --git a/codes/train.py b/codes/train.py
index eccc14ab..120356fe 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -265,7 +265,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_bypass.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_mi1_rrdb_6bypass.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/utils/options.py b/codes/utils/options.py
index bead66c4..64240ad8 100644
--- a/codes/utils/options.py
+++ b/codes/utils/options.py
@@ -23,8 +23,6 @@ def parse(opt_path, is_train=True):
     for phase, dataset in opt['datasets'].items():
         phase = phase.split('_')[0]
         dataset['phase'] = phase
-        if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
-            dataset['scale'] = scale
         is_lmdb = False
         ''' LMDB is not supported at this point with the mods I've been making.
         if dataset.get('dataroot_GT', None) is not None:
@@ -67,11 +65,6 @@ def parse(opt_path, is_train=True):
         opt['path']['results_root'] = results_root
         opt['path']['log'] = results_root
 
-    # network
-    if opt['distortion'] == 'sr' or opt['distortion'] == 'downsample':
-        if 'network_G' in opt.keys():
-            opt['network_G']['scale'] = scale
-
     return opt
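
Notes on this patch:

multiscale_dataset.py: util.get_image_paths appears to take a weights list
that is parallel to opt['paths'], so the literal [1] only covered configs
with a single dataroot. A minimal sketch of the intended shape (the paths
below are made up for illustration):

    # One weight entry per dataroot; both roots sampled with equal weight.
    paths = ['/data/set_a', '/data/set_b']
    weights = [1 for _ in paths]   # -> [1, 1]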
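ExtensibleTrainer.py: validation previously ran a single injector from
eval.injector; it now walks every entry under eval.injectors, re-copying the
first chunk of each dstate tensor before each run and merging each
injector's outputs into the shared state. A config exercising the new branch
might look roughly like the following, expressed as the equivalent Python
dict (the 'gen' label and the injector fields are illustrative, not copied
from a real options file):

    opt['eval'] = {
        'injectors': {
            # Entries run in insertion order; each one's outputs are
            # update()-ed into the shared `state` dict.
            'gen': {'type': 'generator', 'generator': 'generator',
                    'in': 'lq', 'out': 'gen'},
        },
    }

Because the de-chunking loop sits inside the injector loop, each injector
sees refreshed dataset keys while the outputs accumulated by earlier
injectors are preserved under their own keys.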
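progressive_zoom.py: recurrent_output_index and recurrent_index are now
optional, so ProgressiveGeneratorInjector can also drive purely feed-forward
generators; with recurrence off, feed_forward returns None in the recurrent
slot and the nearest-neighbor re-zoom of the recurrent map is skipped. A
sketch of the two opt shapes this enables (values are placeholders, and
unrelated required keys such as 'hq_output' are omitted):

    # Recurrent generator: both recurrent keys present, recurrence enabled.
    rec_opt = {'output_hq_index': 0, 'recurrent_output_index': 1,
               'recurrent_index': 1, 'depth': 3, 'num_branches': 2}

    # Feed-forward generator: omit the recurrent keys; self.recurrence is
    # then False and the recurrent plumbing is bypassed end to end.
    ff_opt = {'output_hq_index': 0, 'depth': 3, 'num_branches': 2}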