Updates to be able to train flownet2 in ExtensibleTrainer
Only supports basic losses for now, though.
This commit is contained in:
parent
1dbcbfbac8
commit
9c3d059ef0
|
@ -35,7 +35,7 @@ class PairedFrameDataset(BaseUnsupervisedImageDataset):
|
|||
# Convert to torch tensor
|
||||
hq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hs), (0, 3, 1, 2)))).float()
|
||||
hq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(hrs), (0, 3, 1, 2)))).float()
|
||||
hq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(hms))).unsqueeze(dim=1)
|
||||
hq_mask = torch.from_numpy(np.ascontiguousarray(np.stack(hms))).squeeze().unsqueeze(dim=1)
|
||||
hq_ref = torch.cat([hq_ref, hq_mask], dim=1)
|
||||
lq = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(ls), (0, 3, 1, 2)))).float()
|
||||
lq_ref = torch.from_numpy(np.ascontiguousarray(np.transpose(np.stack(lrs), (0, 3, 1, 2)))).float()
|
||||
|
|
|
@ -101,6 +101,11 @@ class BaseModel():
|
|||
if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
|
||||
network = network.module
|
||||
load_net = torch.load(load_path)
|
||||
|
||||
# Support loading torch.save()s for whole models as well as just state_dicts.
|
||||
if 'state_dict' in load_net:
|
||||
load_net = load_net['state_dict']
|
||||
|
||||
load_net_clean = OrderedDict() # remove unnecessary 'module.'
|
||||
for k, v in load_net.items():
|
||||
if k.startswith('module.'):
|
||||
|
|
|
@ -87,10 +87,12 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
netG = chained.MultifacetedChainedEmbeddingGen(depth=opt_net['depth'], scale=scale)
|
||||
elif which_model == "flownet2":
|
||||
from models.flownet2.models import FlowNet2
|
||||
ld = torch.load(opt_net['load_path'])
|
||||
args = munch.Munch({'fp16': False, 'rgb_max': 1.0})
|
||||
ld = 'load_path' in opt_net.keys()
|
||||
args = munch.Munch({'fp16': False, 'rgb_max': 1.0, 'checkpoint': not ld})
|
||||
netG = FlowNet2(args)
|
||||
netG.load_state_dict(ld['state_dict'])
|
||||
if ld:
|
||||
sd = torch.load(opt_net['load_path'])
|
||||
netG.load_state_dict(sd['state_dict'])
|
||||
elif which_model == "backbone_encoder":
|
||||
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
|
||||
elif which_model == "backbone_encoder_no_ref":
|
||||
|
|
|
@ -14,6 +14,9 @@ def create_injector(opt_inject, env):
|
|||
elif 'progressive_' in type:
|
||||
from models.steps.progressive_zoom import create_progressive_zoom_injector
|
||||
return create_progressive_zoom_injector(opt_inject, env)
|
||||
elif 'stereoscopic_' in type:
|
||||
from models.steps.stereoscopic import create_stereoscopic_injector
|
||||
return create_stereoscopic_injector(opt_inject, env)
|
||||
elif type == 'generator':
|
||||
return ImageGeneratorInjector(opt_inject, env)
|
||||
elif type == 'discriminator':
|
||||
|
@ -42,6 +45,8 @@ def create_injector(opt_inject, env):
|
|||
return ConstantInjector(opt_inject, env)
|
||||
elif type == 'fft':
|
||||
return ImageFftInjector(opt_inject, env)
|
||||
elif type == 'extract_indices':
|
||||
return IndicesExtractor(opt_inject, env)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -299,3 +304,17 @@ class ImageFftInjector(Injector):
|
|||
im = torch.irfft(fftim, signal_ndim=2, normalized=True)
|
||||
return {self.output: im}
|
||||
|
||||
|
||||
class IndicesExtractor(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(IndicesExtractor, self).__init__(opt, env)
|
||||
self.dim = opt['dim']
|
||||
assert self.dim == 1 # Honestly not sure how to support an abstract dim here, so just add yours when needed.
|
||||
|
||||
def forward(self, state):
|
||||
results = {}
|
||||
for i, o in enumerate(self.output):
|
||||
if self.dim == 1:
|
||||
results[o] = state[self.input][:, i]
|
||||
return results
|
||||
|
||||
|
|
22
codes/models/steps/stereoscopic.py
Normal file
22
codes/models/steps/stereoscopic.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
import torch
|
||||
from torch.cuda.amp import autocast
|
||||
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
|
||||
from models.steps.injectors import Injector
|
||||
|
||||
|
||||
def create_stereoscopic_injector(opt, env):
|
||||
type = opt['type']
|
||||
if type == 'stereoscopic_resample':
|
||||
return ResampleInjector(opt, env)
|
||||
return None
|
||||
|
||||
|
||||
class ResampleInjector(Injector):
|
||||
def __init__(self, opt, env):
|
||||
super(ResampleInjector, self).__init__(opt, env)
|
||||
self.resample = Resample2d()
|
||||
self.flow = opt['flowfield']
|
||||
|
||||
def forward(self, state):
|
||||
with autocast(enabled=False):
|
||||
return {self.output: self.resample(state[self.input], state[self.flow])}
|
|
@ -278,7 +278,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_prog_imgset_multifaceted_chained.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_3dflow_vr_flownet.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
args = parser.parse_args()
|
||||
opt = option.parse(args.opt, is_train=True)
|
||||
|
|
|
@ -51,6 +51,14 @@ def checkpoint(fn, *args):
|
|||
else:
|
||||
return fn(*args)
|
||||
|
||||
# A fancy alternative to if <flag> checkpoint() else <call>
|
||||
def possible_checkpoint(enabled, fn, *args):
|
||||
opt_en = loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in loaded_options.keys() else True
|
||||
if enabled and opt_en:
|
||||
return torch.utils.checkpoint.checkpoint(fn, *args)
|
||||
else:
|
||||
return fn(*args)
|
||||
|
||||
def get_timestamp():
|
||||
return datetime.now().strftime('%y%m%d-%H%M%S')
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user