Updates to be able to train flownet2 in ExtensibleTrainer

Only supports basic losses for now, though.
James Betker 2020-10-24 11:56:39 -06:00
parent 1dbcbfbac8
commit 9c3d059ef0
7 changed files with 61 additions and 5 deletions


@@ -35,7 +35,7 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
         # Convert to torch tensor
         hq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hs), (0, 3, 1, 2)))).float()
         hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float()
-        hq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(hms))).unsqueeze(dim=1)
+        hq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(hms))).squeeze().unsqueeze(dim=1)
         hq_ref = torch.cat([hq_ref, hq_mask], dim=1)
         lq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(ls), (0, 3, 1, 2)))).float()
         lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float()
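The added squeeze() presumably guards against mask stacks that carry an extra singleton dimension: without it, unsqueeze(dim=1) yields a 5D tensor and the following torch.cat with the 4D hq_ref fails. A minimal sketch of the shape fix, assuming masks load as (H, W, 1) arrays (all shapes here are hypothetical):

import numpy as np
import torch

# Hypothetical mask stack: four single-channel masks loaded as (H, W, 1) arrays.
hms = [np.ones((64, 64, 1), dtype=np.float32) for _ in range(4)]
stacked = torch.from_numpy(np.ascontiguousarray(np.stack(hms)))  # (4, 64, 64, 1)

old = stacked.unsqueeze(dim=1)            # (4, 1, 64, 64, 1): 5D, breaks torch.cat with NCHW tensors
new = stacked.squeeze().unsqueeze(dim=1)  # (4, 1, 64, 64): matches hq_ref's NCHW layout
print(old.shape, new.shape)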


@@ -101,6 +101,11 @@ class BaseModel():
         if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
             network = network.module
         load_net = torch.load(load_path)
+
+        # Support loading torch.save()s for whole models as well as just state_dicts.
+        if 'state_dict' in load_net:
+            load_net = load_net['state_dict']
+
         load_net_clean = OrderedDict()  # remove unnecessary 'module.'
         for k, v in load_net.items():
             if k.startswith('module.'):
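With this branch, load_network accepts both a raw state_dict and a torch.save()d checkpoint dict that wraps one under a 'state_dict' key, which is the format the flownet2 branch of define_G below also expects. A minimal sketch of the two formats converging (the paths and the Linear module are illustrative only):

import torch

model = torch.nn.Linear(4, 4)
torch.save(model.state_dict(), '/tmp/raw.pth')                      # raw state_dict
torch.save({'state_dict': model.state_dict()}, '/tmp/wrapped.pth')  # wrapped checkpoint dict

for path in ('/tmp/raw.pth', '/tmp/wrapped.pth'):
    load_net = torch.load(path)
    if 'state_dict' in load_net:  # the new unwrapping step
        load_net = load_net['state_dict']
    model.load_state_dict(load_net)  # both formats now load identically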


@@ -87,10 +87,12 @@ def define_G(opt, net_key='network_G', scale=None):
         netG = chained.MultifacetedChainedEmbeddingGen(depth=opt_net['depth'], scale=scale)
     elif which_model == "flownet2":
         from models.flownet2.models import FlowNet2
-        ld = torch.load(opt_net['load_path'])
-        args = munch.Munch({'fp16': False, 'rgb_max': 1.0})
+        ld = 'load_path' in opt_net.keys()
+        args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
         netG = FlowNet2(args)
-        netG.load_state_dict(ld['state_dict'])
+        if ld:
+            sd = torch.load(opt_net['load_path'])
+            netG.load_state_dict(sd['state_dict'])
     elif which_model == "backbone_encoder":
         netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
     elif which_model == "backbone_encoder_no_ref":
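load_path is now optional for flownet2: with a checkpoint, the network is built with checkpoint=False and the weights are loaded; without one, checkpoint=True is passed through args, presumably enabling activation checkpointing for from-scratch training (see the possible_checkpoint helper added below). A sketch of the two configurations; the option dicts and checkpoint path are hypothetical, and this only runs inside the repo:

import munch
import torch
from models.flownet2.models import FlowNet2

opt_pretrained = {'which_model_G': 'flownet2', 'load_path': '../experiments/flownet2_checkpoint.pth'}
opt_scratch = {'which_model_G': 'flownet2'}  # no load_path: checkpoint=True, train from scratch

for opt_net in (opt_pretrained, opt_scratch):
    ld = 'load_path' in opt_net.keys()
    args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
    netG = FlowNet2(args)
    if ld:
        netG.load_state_dict(torch.load(opt_net['load_path'])['state_dict'])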


@@ -14,6 +14,9 @@ def create_injector(opt_inject, env):
     elif 'progressive_' in type:
         from models.steps.progressive_zoom import create_progressive_zoom_injector
         return create_progressive_zoom_injector(opt_inject, env)
+    elif 'stereoscopic_' in type:
+        from models.steps.stereoscopic import create_stereoscopic_injector
+        return create_stereoscopic_injector(opt_inject, env)
     elif type == 'generator':
         return ImageGeneratorInjector(opt_inject, env)
     elif type == 'discriminator':
@@ -42,6 +45,8 @@ def create_injector(opt_inject, env):
         return ConstantInjector(opt_inject, env)
     elif type == 'fft':
         return ImageFftInjector(opt_inject, env)
+    elif type == 'extract_indices':
+        return IndicesExtractor(opt_inject, env)
     else:
         raise NotImplementedError
@@ -299,3 +304,17 @@ class ImageFftInjector(Injector):
         im = torch.irfft(fftim, signal_ndim=2, normalized=True)
         return {self.output: im}
+
+
+class IndicesExtractor(Injector):
+    def __init__(self, opt, env):
+        super(IndicesExtractor, self).__init__(opt, env)
+        self.dim = opt['dim']
+        assert self.dim == 1  # Honestly not sure how to support an abstract dim here, so just add yours when needed.
+
+    def forward(self, state):
+        results = {}
+        for i, o in enumerate(self.output):
+            if self.dim == 1:
+                results[o] = state[self.input][:, i]
+        return results
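IndicesExtractor fans a tensor out along dim 1 into one named state entry per configured output, e.g. splitting a stacked frame-pair batch into individual frames. A standalone sketch of the equivalent indexing, with hypothetical shapes and output names:

import torch

# Hypothetical state: two stacked video frames per batch element, shape (B, 2, C, H, W).
state = {'frames': torch.randn(8, 2, 3, 64, 64)}

# What an 'extract_indices' injector configured with in='frames',
# out=['frame_0', 'frame_1'], dim=1 would produce:
outputs = ['frame_0', 'frame_1']
results = {o: state['frames'][:, i] for i, o in enumerate(outputs)}
print(results['frame_0'].shape)  # torch.Size([8, 3, 64, 64])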


@@ -0,0 +1,22 @@
+import torch
+from torch.cuda.amp import autocast
+from models.flownet2.networks.resample2d_package.resample2d import Resample2d
+from models.steps.injectors import Injector
+
+
+def create_stereoscopic_injector(opt, env):
+    type = opt['type']
+    if type == 'stereoscopic_resample':
+        return ResampleInjector(opt, env)
+    return None
+
+
+class ResampleInjector(Injector):
+    def __init__(self, opt, env):
+        super(ResampleInjector, self).__init__(opt, env)
+        self.resample = Resample2d()
+        self.flow = opt['flowfield']
+
+    def forward(self, state):
+        with autocast(enabled=False):
+            return {self.output: self.resample(state[self.input], state[self.flow])}
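ResampleInjector warps self.input with the dense flow field found in state[self.flow], using flownet2's CUDA Resample2d op; autocast is disabled because the custom op presumably cannot handle fp16 inputs. For intuition only, roughly equivalent warping can be sketched with torch.nn.functional.grid_sample; this is an approximation, not the kernel this commit uses:

import torch
import torch.nn.functional as F

def resample_with_flow(img, flow):
    # img: (B, C, H, W); flow: (B, 2, H, W) in pixels, channel 0 = x offset, channel 1 = y offset.
    b, _, h, w = img.shape
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    base = torch.stack((xs, ys)).float().unsqueeze(0).to(img.device)  # identity sampling grid, (1, 2, H, W)
    coords = base + flow
    # Normalize to [-1, 1] and move the coordinate dim last, as grid_sample expects.
    coords[:, 0] = 2 * coords[:, 0] / (w - 1) - 1
    coords[:, 1] = 2 * coords[:, 1] / (h - 1) - 1
    return F.grid_sample(img, coords.permute(0, 2, 3, 1), align_corners=True)

warped = resample_with_flow(torch.randn(1, 3, 32, 32), torch.zeros(1, 2, 32, 32))  # zero flow: identity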


@@ -278,7 +278,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_imgset_multifaceted_chained.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_3dflow_vr_flownet.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)


@@ -51,6 +51,14 @@ def checkpoint(fn, *args):
     else:
         return fn(*args)
 
+
+# A fancy alternative to if <flag> checkpoint() else <call>
+def possible_checkpoint(enabled, fn, *args):
+    opt_en = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
+    if enabled and opt_en:
+        return torch.utils.checkpoint.checkpoint(fn, *args)
+    else:
+        return fn(*args)
+
 def get_timestamp():
     return datetime.now().strftime('%y%m%d-%H%M%S')
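possible_checkpoint lets a caller opt into torch.utils.checkpoint per call site while still honoring the global 'checkpointing_enabled' option (note it relies on util's module-level loaded_options being populated). A hypothetical consumer; the import path and module are illustrative only:

import torch.nn as nn
from utils.util import possible_checkpoint  # hypothetical import path for this helper

class Block(nn.Module):
    def __init__(self, channels, use_checkpointing):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.use_checkpointing = use_checkpointing

    def forward(self, x):
        # Recompute activations in the backward pass only when both this module's
        # flag and the global option allow it.
        return possible_checkpoint(self.use_checkpointing, self.conv, x)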