From f857eb00a8f32b655058db64898683f666c1bb7e Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 26 Oct 2020 11:09:55 -0600 Subject: [PATCH 1/2] Allow tecogan losses to compute at 32px --- codes/models/steps/tecogan_losses.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index db301c7a..b460a15d 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -97,9 +97,15 @@ class RecurrentImageGeneratorSequenceInjector(Injector): else: input = extract_inputs_index(inputs, i) with torch.no_grad() and autocast(enabled=False): - reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic') - flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2).float() - flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic') + # This is a hack to workaround the fact that flownet2 cannot operate at resolutions < 64px. An assumption is + # made here that if you are operating at 4x scale, your inputs are 32px x 32px + if self.scale >= 4: + flow_input = F.interpolate(input[self.input_lq_index], scale_factor=self.scale//2, mode='bicubic') + else: + flow_input = input[self.input_lq_index] + reduced_recurrent = F.interpolate(recurrent_input, scale_factor=.5, mode='bicubic') + flow_input = torch.stack([flow_input, reduced_recurrent], dim=2).float() + flowfield = F.interpolate(flow(flow_input), scale_factor=2, mode='bicubic') recurrent_input = self.resample(recurrent_input.float(), flowfield) input[self.recurrent_index] = recurrent_input if self.env['step'] % 50 == 0: @@ -122,9 +128,15 @@ class RecurrentImageGeneratorSequenceInjector(Injector): input = extract_inputs_index(inputs, i) with torch.no_grad(): with autocast(enabled=False): - reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic') - flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2).float() - flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic') + # This is a hack to workaround the fact that flownet2 cannot operate at resolutions < 64px. An assumption is + # made here that if you are operating at 4x scale, your inputs are 32px x 32px + if self.scale >= 4: + flow_input = F.interpolate(input[self.input_lq_index], scale_factor=self.scale//2, mode='bicubic') + else: + flow_input = input[self.input_lq_index] + reduced_recurrent = F.interpolate(recurrent_input, scale_factor=.5, mode='bicubic') + flow_input = torch.stack([flow_input, reduced_recurrent], dim=2).float() + flowfield = F.interpolate(flow(flow_input), scale_factor=2, mode='bicubic') recurrent_input = self.resample(recurrent_input.float(), flowfield) input[self.recurrent_index] = recurrent_input if self.env['step'] % 50 == 0: From b2f803588b71c56bb3533eb25a96796e374ceeb0 Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 26 Oct 2020 11:10:22 -0600 Subject: [PATCH 2/2] Fix multi_modal_train.py --- codes/multi_modal_train.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/codes/multi_modal_train.py b/codes/multi_modal_train.py index 3b4cf1cb..23d1c379 100644 --- a/codes/multi_modal_train.py +++ b/codes/multi_modal_train.py @@ -42,14 +42,11 @@ def main(master_opt, launcher): if __name__ == '__main__': parser = argparse.ArgumentParser() - #parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structured_trans_invariance.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structured_trans_invariance.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() Loader, Dumper = OrderedYaml() with open(args.opt, mode='r') as f: opt = yaml.load(f, Loader=Loader) - opt = { - 'trainer_options': ['../options/teco.yml', '../options/exd.yml'] - } - main(opt, args.launcher) \ No newline at end of file + main(opt, args.launcher)