From 9c3d059ef0ec1ac396b5e0645920b4f23ac22e9a Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 24 Oct 2020 11:56:39 -0600 Subject: [PATCH] Updates to be able to train flownet2 in ExtensibleTrainer Only supports basic losses for now, though. --- codes/data/paired_frame_dataset.py | 2 +- codes/models/base_model.py | 5 +++++ codes/models/networks.py | 8 +++++--- codes/models/steps/injectors.py | 19 +++++++++++++++++++ codes/models/steps/stereoscopic.py | 22 ++++++++++++++++++++++ codes/train.py | 2 +- codes/utils/util.py | 8 ++++++++ 7 files changed, 61 insertions(+), 5 deletions(-) create mode 100644 codes/models/steps/stereoscopic.py diff --git a/codes/data/paired_frame_dataset.py b/codes/data/paired_frame_dataset.py index 2c1be9ac..16c09ba2 100644 --- a/codes/data/paired_frame_dataset.py +++ b/codes/data/paired_frame_dataset.py @@ -35,7 +35,7 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset): # Convert to torch tensor hq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hs), (0, 3, 1, 2)))).float() hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float() - hq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(hms))).unsqueeze(dim=1) + hq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(hms))).squeeze().unsqueeze(dim=1) hq_ref = torch.cat([hq_ref, hq_mask], dim=1) lq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(ls), (0, 3, 1, 2)))).float() lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float() diff --git a/codes/models/base_model.py b/codes/models/base_model.py index 04b6d9e3..ea08aecc 100644 --- a/codes/models/base_model.py +++ b/codes/models/base_model.py @@ -101,6 +101,11 @@ class BaseModel(): if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): network = network.module load_net = torch.load(load_path) + + # Support loading torch.save()s for whole models as well as just state_dicts. + if 'state_dict' in load_net: + load_net = load_net['state_dict'] + load_net_clean = OrderedDict() # remove unnecessary 'module.' for k, v in load_net.items(): if k.startswith('module.'): diff --git a/codes/models/networks.py b/codes/models/networks.py index e10064a3..8423fb7f 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -87,10 +87,12 @@ def define_G(opt, net_key='network_G', scale=None): netG = chained.MultifacetedChainedEmbeddingGen(depth=opt_net['depth'], scale=scale) elif which_model == "flownet2": from models.flownet2.models import FlowNet2 - ld = torch.load(opt_net['load_path']) - args = munch.Munch({'fp16': False, 'rgb_max': 1.0}) + ld = 'load_path' in opt_net.keys() + args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld}) netG = FlowNet2(args) - netG.load_state_dict(ld['state_dict']) + if ld: + sd = torch.load(opt_net['load_path']) + netG.load_state_dict(sd['state_dict']) elif which_model == "backbone_encoder": netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet']) elif which_model == "backbone_encoder_no_ref": diff --git a/codes/models/steps/injectors.py b/codes/models/steps/injectors.py index 7e6e2fd9..f718de9f 100644 --- a/codes/models/steps/injectors.py +++ b/codes/models/steps/injectors.py @@ -14,6 +14,9 @@ def create_injector(opt_inject, env): elif 'progressive_' in type: from models.steps.progressive_zoom import create_progressive_zoom_injector return create_progressive_zoom_injector(opt_inject, env) + elif 'stereoscopic_' in type: + from models.steps.stereoscopic import create_stereoscopic_injector + return create_stereoscopic_injector(opt_inject, env) elif type == 'generator': return ImageGeneratorInjector(opt_inject, env) elif type == 'discriminator': @@ -42,6 +45,8 @@ def create_injector(opt_inject, env): return ConstantInjector(opt_inject, env) elif type == 'fft': return ImageFftInjector(opt_inject, env) + elif type == 'extract_indices': + return IndicesExtractor(opt_inject, env) else: raise NotImplementedError @@ -299,3 +304,17 @@ class ImageFftInjector(Injector): im = torch.irfft(fftim, signal_ndim=2, normalized=True) return {self.output: im} + +class IndicesExtractor(Injector): + def __init__(self, opt, env): + super(IndicesExtractor, self).__init__(opt, env) + self.dim = opt['dim'] + assert self.dim == 1 # Honestly not sure how to support an abstract dim here, so just add yours when needed. + + def forward(self, state): + results = {} + for i, o in enumerate(self.output): + if self.dim == 1: + results[o] = state[self.input][:, i] + return results + diff --git a/codes/models/steps/stereoscopic.py b/codes/models/steps/stereoscopic.py new file mode 100644 index 00000000..f4f4f32d --- /dev/null +++ b/codes/models/steps/stereoscopic.py @@ -0,0 +1,22 @@ +import torch +from torch.cuda.amp import autocast +from models.flownet2.networks.resample2d_package.resample2d import Resample2d +from models.steps.injectors import Injector + + +def create_stereoscopic_injector(opt, env): + type = opt['type'] + if type == 'stereoscopic_resample': + return ResampleInjector(opt, env) + return None + + +class ResampleInjector(Injector): + def __init__(self, opt, env): + super(ResampleInjector, self).__init__(opt, env) + self.resample = Resample2d() + self.flow = opt['flowfield'] + + def forward(self, state): + with autocast(enabled=False): + return {self.output: self.resample(state[self.input], state[self.flow])} \ No newline at end of file diff --git a/codes/train.py b/codes/train.py index c7232118..536e1575 100644 --- a/codes/train.py +++ b/codes/train.py @@ -278,7 +278,7 @@ class Trainer: if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_imgset_multifaceted_chained.yml') + parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_3dflow_vr_flownet.yml') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') args = parser.parse_args() opt = option.parse(args.opt, is_train=True) diff --git a/codes/utils/util.py b/codes/utils/util.py index c526fecb..ea271b2e 100644 --- a/codes/utils/util.py +++ b/codes/utils/util.py @@ -51,6 +51,14 @@ def checkpoint(fn, *args): else: return fn(*args) +# A fancy alternative to if checkpoint() else +def possible_checkpoint(enabled, fn, *args): + opt_en = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True + if enabled and opt_en: + return torch.utils.checkpoint.checkpoint(fn, *args) + else: + return fn(*args) + def get_timestamp(): return datetime.now().strftime('%y%m%d-%H%M%S')