diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..bf1b51ab --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "flownet2"] + path = flownet2 + url = https://github.com/NVIDIA/flownet2-pytorch.git +[submodule "codes/models/flownet2"] + path = codes/models/flownet2 + url = https://github.com/NVIDIA/flownet2-pytorch.git diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 663d60f9..02e669b1 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -29,7 +29,8 @@ class ExtensibleTrainer(BaseModel): self.env = {'device': self.device, 'rank': self.rank, 'opt': opt, - 'step': 0} + 'step': 0, + 'base_path': os.path.join(opt['path']['models'])} self.mega_batch_factor = 1 if self.is_train: diff --git a/codes/models/flownet2 b/codes/models/flownet2 new file mode 160000 index 00000000..2e9e010c --- /dev/null +++ b/codes/models/flownet2 @@ -0,0 +1 @@ +Subproject commit 2e9e010c98931bc7cef3eb063b195f1e0ab470ba diff --git a/codes/models/networks.py b/codes/models/networks.py index f346711a..530e57da 100644 --- a/codes/models/networks.py +++ b/codes/models/networks.py @@ -14,6 +14,8 @@ import models.archs.rcan as rcan from collections import OrderedDict import torchvision import functools +from models.flownet2.models import FlowNet2 + logger = logging.getLogger('base') @@ -94,6 +96,10 @@ def define_G(opt, net_key='network_G', scale=None): xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 netG = ssg.SSGDeep(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) + elif which_model == "flownet2": + args_dict = {} + args = munchify(args_dict) + netG = FlowNet2(args) elif which_model == "backbone_encoder": netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet']) elif which_model == "backbone_encoder_no_ref": diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index 40d66f9b..a184f51f 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -3,6 +3,9 @@ from models.layers.resample2d_package.resample2d import Resample2d from models.steps.recurrent import RecurrentController from models.steps.injectors import Injector import torch +import os +import os.path as osp +import torchvision def create_teco_loss(opt, env): type = opt['type'] @@ -114,7 +117,6 @@ class TecoGanDiscriminatorLoss(ConfigurableLoss): class TecoGanGeneratorLoss(ConfigurableLoss): def __init__(self, opt, env): super(TecoGanGeneratorLoss, self).__init__(opt, env) - self.opt = opt self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device']) # TecoGAN parameters self.image_flow_generator = opt['image_flow_generator'] @@ -130,6 +132,10 @@ class TecoGanGeneratorLoss(ConfigurableLoss): fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler) d_fake = net(fake_sext) + if self.env['step'] % 100 == 0: + self.produce_teco_visual_debugs(fake_sext, 'fake', i) + self.produce_teco_visual_debugs(real_sext, 'real', i) + if self.opt['gan_type'] in ['gan', 'pixgan']: self.metrics.append(("d_fake", torch.mean(d_fake))) l_fake = self.criterion(d_fake, True) @@ -142,8 +148,16 @@ class TecoGanGeneratorLoss(ConfigurableLoss): self.criterion(d_fake_diff, True)) else: raise NotImplementedError + return l_total + def produce_teco_visual_debugs(self, sext, lbl, it): + base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_sext", str(self.env['step']), lbl) + os.makedirs(base_path, exist_ok=True) + lbls = ['first', 'second', 'third', 'first_flow', 'second_flow', 'third_flow'] + for i in range(6): + torchvision.utils.save_image(sext[:, i*3:(i+1)*3-1, :, :], osp.join(base_path, "%s_%s.png" % (lbls[i], it))) + # This loss doesn't have a real entry - only fakes are used. class PingPongLoss(ConfigurableLoss): @@ -159,4 +173,15 @@ class PingPongLoss(ConfigurableLoss): early = fake[i] late = fake[-i] l_total += self.criterion(early, late) - return l_total \ No newline at end of file + + if self.env['step'] % 100 == 0: + self.produce_teco_visual_debugs(fake) + + return l_total + + def produce_teco_visual_debugs(self, imglist): + base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_pingpong", str(self.env['step'])) + os.makedirs(base_path, exist_ok=True) + assert isinstance(imglist, list) + for i, img in enumerate(imglist): + torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, ))) \ No newline at end of file