Updates to be able to train flownet2 in ExtensibleTrainer
Only supports basic losses for now, though.
This commit is contained in:
parent
1dbcbfbac8
commit
9c3d059ef0
|
@ -35,7 +35,7 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
|
||||||
# Convert to torch tensor
|
# Convert to torch tensor
|
||||||
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hs), (0, 3, 1, 2)))).float()
|
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hs), (0, 3, 1, 2)))).float()
|
||||||
hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float()
|
hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float()
|
||||||
hq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(hms))).unsqueeze(dim=1)
|
hq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(hms))).squeeze().unsqueeze(dim=1)
|
||||||
hq_ref = torch.cat([hq_ref, hq_mask], dim=1)
|
hq_ref = torch.cat([hq_ref, hq_mask], dim=1)
|
||||||
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(ls), (0, 3, 1, 2)))).float()
|
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(ls), (0, 3, 1, 2)))).float()
|
||||||
lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float()
|
lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float()
|
||||||
|
|
|
@ -101,6 +101,11 @@ class BaseModel():
|
||||||
if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
|
if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
|
||||||
network = network.module
|
network = network.module
|
||||||
load_net = torch.load(load_path)
|
load_net = torch.load(load_path)
|
||||||
|
|
||||||
|
# Support loading torch.save()s for whole models as well as just state_dicts.
|
||||||
|
if 'state_dict' in load_net:
|
||||||
|
load_net = load_net['state_dict']
|
||||||
|
|
||||||
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
||||||
for k, v in load_net.items():
|
for k, v in load_net.items():
|
||||||
if k.startswith('module.'):
|
if k.startswith('module.'):
|
||||||
|
|
|
@ -87,10 +87,12 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
netG = chained.MultifacetedChainedEmbeddingGen(depth=opt_net['depth'], scale=scale)
|
netG = chained.MultifacetedChainedEmbeddingGen(depth=opt_net['depth'], scale=scale)
|
||||||
elif which_model == "flownet2":
|
elif which_model == "flownet2":
|
||||||
from models.flownet2.models import FlowNet2
|
from models.flownet2.models import FlowNet2
|
||||||
ld = torch.load(opt_net['load_path'])
|
ld = 'load_path' in opt_net.keys()
|
||||||
args = munch.Munch({'fp16': False, 'rgb_max': 1.0})
|
args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
|
||||||
netG = FlowNet2(args)
|
netG = FlowNet2(args)
|
||||||
netG.load_state_dict(ld['state_dict'])
|
if ld:
|
||||||
|
sd = torch.load(opt_net['load_path'])
|
||||||
|
netG.load_state_dict(sd['state_dict'])
|
||||||
elif which_model == "backbone_encoder":
|
elif which_model == "backbone_encoder":
|
||||||
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
|
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
|
||||||
elif which_model == "backbone_encoder_no_ref":
|
elif which_model == "backbone_encoder_no_ref":
|
||||||
|
|
|
@ -14,6 +14,9 @@ def create_injector(opt_inject, env):
|
||||||
elif 'progressive_' in type:
|
elif 'progressive_' in type:
|
||||||
from models.steps.progressive_zoom import create_progressive_zoom_injector
|
from models.steps.progressive_zoom import create_progressive_zoom_injector
|
||||||
return create_progressive_zoom_injector(opt_inject, env)
|
return create_progressive_zoom_injector(opt_inject, env)
|
||||||
|
elif 'stereoscopic_' in type:
|
||||||
|
from models.steps.stereoscopic import create_stereoscopic_injector
|
||||||
|
return create_stereoscopic_injector(opt_inject, env)
|
||||||
elif type == 'generator':
|
elif type == 'generator':
|
||||||
return ImageGeneratorInjector(opt_inject, env)
|
return ImageGeneratorInjector(opt_inject, env)
|
||||||
elif type == 'discriminator':
|
elif type == 'discriminator':
|
||||||
|
@ -42,6 +45,8 @@ def create_injector(opt_inject, env):
|
||||||
return ConstantInjector(opt_inject, env)
|
return ConstantInjector(opt_inject, env)
|
||||||
elif type == 'fft':
|
elif type == 'fft':
|
||||||
return ImageFftInjector(opt_inject, env)
|
return ImageFftInjector(opt_inject, env)
|
||||||
|
elif type == 'extract_indices':
|
||||||
|
return IndicesExtractor(opt_inject, env)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -299,3 +304,17 @@ class ImageFftInjector(Injector):
|
||||||
im = torch.irfft(fftim, signal_ndim=2, normalized=True)
|
im = torch.irfft(fftim, signal_ndim=2, normalized=True)
|
||||||
return {self.output: im}
|
return {self.output: im}
|
||||||
|
|
||||||
|
|
||||||
|
class IndicesExtractor(Injector):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super(IndicesExtractor, self).__init__(opt, env)
|
||||||
|
self.dim = opt['dim']
|
||||||
|
assert self.dim == 1 # Honestly not sure how to support an abstract dim here, so just add yours when needed.
|
||||||
|
|
||||||
|
def forward(self, state):
|
||||||
|
results = {}
|
||||||
|
for i, o in enumerate(self.output):
|
||||||
|
if self.dim == 1:
|
||||||
|
results[o] = state[self.input][:, i]
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
22
codes/models/steps/stereoscopic.py
Normal file
22
codes/models/steps/stereoscopic.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
import torch
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
|
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
|
||||||
|
from models.steps.injectors import Injector
|
||||||
|
|
||||||
|
|
||||||
|
def create_stereoscopic_injector(opt, env):
|
||||||
|
type = opt['type']
|
||||||
|
if type == 'stereoscopic_resample':
|
||||||
|
return ResampleInjector(opt, env)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class ResampleInjector(Injector):
|
||||||
|
def __init__(self, opt, env):
|
||||||
|
super(ResampleInjector, self).__init__(opt, env)
|
||||||
|
self.resample = Resample2d()
|
||||||
|
self.flow = opt['flowfield']
|
||||||
|
|
||||||
|
def forward(self, state):
|
||||||
|
with autocast(enabled=False):
|
||||||
|
return {self.output: self.resample(state[self.input], state[self.flow])}
|
|
@ -278,7 +278,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_imgset_multifaceted_chained.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_3dflow_vr_flownet.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
opt = option.parse(args.opt, is_train=True)
|
opt = option.parse(args.opt, is_train=True)
|
||||||
|
|
|
@ -51,6 +51,14 @@ def checkpoint(fn, *args):
|
||||||
else:
|
else:
|
||||||
return fn(*args)
|
return fn(*args)
|
||||||
|
|
||||||
|
# A fancy alternative to if <flag> checkpoint() else <call>
|
||||||
|
def possible_checkpoint(enabled, fn, *args):
|
||||||
|
opt_en = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
|
||||||
|
if enabled and opt_en:
|
||||||
|
return torch.utils.checkpoint.checkpoint(fn, *args)
|
||||||
|
else:
|
||||||
|
return fn(*args)
|
||||||
|
|
||||||
def get_timestamp():
|
def get_timestamp():
|
||||||
return datetime.now().strftime('%y%m%d-%H%M%S')
|
return datetime.now().strftime('%y%m%d-%H%M%S')
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user