Merge remote-tracking branch 'origin/gan_lab' into gan_lab

This commit is contained in:
James Betker 2020-10-26 11:12:37 -06:00
commit 54accfa693
2 changed files with 20 additions and 11 deletions

View File

@ -97,9 +97,15 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
else: else:
input = extract_inputs_index(inputs, i) input = extract_inputs_index(inputs, i)
with torch.no_grad() and autocast(enabled=False): with torch.no_grad() and autocast(enabled=False):
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic') # This is a hack to workaround the fact that flownet2 cannot operate at resolutions < 64px. An assumption is
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2).float() # made here that if you are operating at 4x scale, your inputs are 32px x 32px
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic') if self.scale >= 4:
flow_input = F.interpolate(input[self.input_lq_index], scale_factor=self.scale//2, mode='bicubic')
else:
flow_input = input[self.input_lq_index]
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=.5, mode='bicubic')
flow_input = torch.stack([flow_input, reduced_recurrent], dim=2).float()
flowfield = F.interpolate(flow(flow_input), scale_factor=2, mode='bicubic')
recurrent_input = self.resample(recurrent_input.float(), flowfield) recurrent_input = self.resample(recurrent_input.float(), flowfield)
input[self.recurrent_index] = recurrent_input input[self.recurrent_index] = recurrent_input
if self.env['step'] % 50 == 0: if self.env['step'] % 50 == 0:
@ -122,9 +128,15 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
input = extract_inputs_index(inputs, i) input = extract_inputs_index(inputs, i)
with torch.no_grad(): with torch.no_grad():
with autocast(enabled=False): with autocast(enabled=False):
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic') # This is a hack to workaround the fact that flownet2 cannot operate at resolutions < 64px. An assumption is
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2).float() # made here that if you are operating at 4x scale, your inputs are 32px x 32px
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic') if self.scale >= 4:
flow_input = F.interpolate(input[self.input_lq_index], scale_factor=self.scale//2, mode='bicubic')
else:
flow_input = input[self.input_lq_index]
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=.5, mode='bicubic')
flow_input = torch.stack([flow_input, reduced_recurrent], dim=2).float()
flowfield = F.interpolate(flow(flow_input), scale_factor=2, mode='bicubic')
recurrent_input = self.resample(recurrent_input.float(), flowfield) recurrent_input = self.resample(recurrent_input.float(), flowfield)
input[self.recurrent_index] = recurrent_input input[self.recurrent_index] = recurrent_input
if self.env['step'] % 50 == 0: if self.env['step'] % 50 == 0:

View File

@ -42,14 +42,11 @@ def main(master_opt, launcher):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
#parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structured_trans_invariance.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structured_trans_invariance.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
args = parser.parse_args() args = parser.parse_args()
Loader, Dumper = OrderedYaml() Loader, Dumper = OrderedYaml()
with open(args.opt, mode='r') as f: with open(args.opt, mode='r') as f:
opt = yaml.load(f, Loader=Loader) opt = yaml.load(f, Loader=Loader)
opt = { main(opt, args.launcher)
'trainer_options': ['../options/teco.yml', '../options/exd.yml']
}
main(opt, args.launcher)