Integrate flownet2 into codebase, add teco visual debugs

This commit is contained in:
James Betker 2020-10-06 20:35:39 -06:00
parent e4b89a172f
commit cffc596141
5 changed files with 42 additions and 3 deletions

6
.gitmodules vendored Normal file
View File

@ -0,0 +1,6 @@
[submodule "flownet2"]
path = flownet2
url = https://github.com/NVIDIA/flownet2-pytorch.git
[submodule "codes/models/flownet2"]
path = codes/models/flownet2
url = https://github.com/NVIDIA/flownet2-pytorch.git

View File

@ -29,7 +29,8 @@ class ExtensibleTrainer(BaseModel):
self.env = {'device': self.device, self.env = {'device': self.device,
'rank': self.rank, 'rank': self.rank,
'opt': opt, 'opt': opt,
'step': 0} 'step': 0,
'base_path': os.path.join(opt['path']['models'])}
self.mega_batch_factor = 1 self.mega_batch_factor = 1
if self.is_train: if self.is_train:

1
codes/models/flownet2 Submodule

@ -0,0 +1 @@
Subproject commit 2e9e010c98931bc7cef3eb063b195f1e0ab470ba

View File

@ -14,6 +14,8 @@ import models.archs.rcan as rcan
from collections import OrderedDict from collections import OrderedDict
import torchvision import torchvision
import functools import functools
from models.flownet2.models import FlowNet2
logger = logging.getLogger('base') logger = logging.getLogger('base')
@ -90,6 +92,10 @@ def define_G(opt, net_key='network_G', scale=None):
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8 xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
netG = ssg.SSGDeep(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'], netG = ssg.SSGDeep(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10) init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
elif which_model == "flownet2":
args_dict = {}
args = munchify(args_dict)
netG = FlowNet2(args)
elif which_model == "backbone_encoder": elif which_model == "backbone_encoder":
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet']) netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
elif which_model == "backbone_encoder_no_ref": elif which_model == "backbone_encoder_no_ref":

View File

@ -3,6 +3,9 @@ from models.layers.resample2d_package.resample2d import Resample2d
from models.steps.recurrent import RecurrentController from models.steps.recurrent import RecurrentController
from models.steps.injectors import Injector from models.steps.injectors import Injector
import torch import torch
import os
import os.path as osp
import torchvision
def create_teco_loss(opt, env): def create_teco_loss(opt, env):
type = opt['type'] type = opt['type']
@ -114,7 +117,6 @@ class TecoGanDiscriminatorLoss(ConfigurableLoss):
class TecoGanGeneratorLoss(ConfigurableLoss): class TecoGanGeneratorLoss(ConfigurableLoss):
def __init__(self, opt, env): def __init__(self, opt, env):
super(TecoGanGeneratorLoss, self).__init__(opt, env) super(TecoGanGeneratorLoss, self).__init__(opt, env)
self.opt = opt
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device']) self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
# TecoGAN parameters # TecoGAN parameters
self.image_flow_generator = opt['image_flow_generator'] self.image_flow_generator = opt['image_flow_generator']
@ -130,6 +132,10 @@ class TecoGanGeneratorLoss(ConfigurableLoss):
fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler) fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler)
d_fake = net(fake_sext) d_fake = net(fake_sext)
if self.env['step'] % 100 == 0:
self.produce_teco_visual_debugs(fake_sext, 'fake', i)
self.produce_teco_visual_debugs(real_sext, 'real', i)
if self.opt['gan_type'] in ['gan', 'pixgan']: if self.opt['gan_type'] in ['gan', 'pixgan']:
self.metrics.append(("d_fake", torch.mean(d_fake))) self.metrics.append(("d_fake", torch.mean(d_fake)))
l_fake = self.criterion(d_fake, True) l_fake = self.criterion(d_fake, True)
@ -142,8 +148,16 @@ class TecoGanGeneratorLoss(ConfigurableLoss):
self.criterion(d_fake_diff, True)) self.criterion(d_fake_diff, True))
else: else:
raise NotImplementedError raise NotImplementedError
return l_total return l_total
def produce_teco_visual_debugs(self, sext, lbl, it):
base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_sext", str(self.env['step']), lbl)
os.makedirs(base_path, exist_ok=True)
lbls = ['first', 'second', 'third', 'first_flow', 'second_flow', 'third_flow']
for i in range(6):
torchvision.utils.save_image(sext[:, i*3:(i+1)*3-1, :, :], osp.join(base_path, "%s_%s.png" % (lbls[i], it)))
# This loss doesn't have a real entry - only fakes are used. # This loss doesn't have a real entry - only fakes are used.
class PingPongLoss(ConfigurableLoss): class PingPongLoss(ConfigurableLoss):
@ -159,4 +173,15 @@ class PingPongLoss(ConfigurableLoss):
early = fake[i] early = fake[i]
late = fake[-i] late = fake[-i]
l_total += self.criterion(early, late) l_total += self.criterion(early, late)
return l_total
if self.env['step'] % 100 == 0:
self.produce_teco_visual_debugs(fake)
return l_total
def produce_teco_visual_debugs(self, imglist):
base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_pingpong", str(self.env['step']))
os.makedirs(base_path, exist_ok=True)
assert isinstance(imglist, list)
for i, img in enumerate(imglist):
torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))