forked from mrq/DL-Art-School
Integrate flownet2 into codebase, add teco visual debugs
This commit is contained in:
parent
e4b89a172f
commit
cffc596141
6
.gitmodules
vendored
Normal file
6
.gitmodules
vendored
Normal file
|
@ -0,0 +1,6 @@
|
|||
[submodule "flownet2"]
|
||||
path = flownet2
|
||||
url = https://github.com/NVIDIA/flownet2-pytorch.git
|
||||
[submodule "codes/models/flownet2"]
|
||||
path = codes/models/flownet2
|
||||
url = https://github.com/NVIDIA/flownet2-pytorch.git
|
|
@ -29,7 +29,8 @@ class ExtensibleTrainer(BaseModel):
|
|||
self.env = {'device': self.device,
|
||||
'rank': self.rank,
|
||||
'opt': opt,
|
||||
'step': 0}
|
||||
'step': 0,
|
||||
'base_path': os.path.join(opt['path']['models'])}
|
||||
|
||||
self.mega_batch_factor = 1
|
||||
if self.is_train:
|
||||
|
|
1
codes/models/flownet2
Submodule
1
codes/models/flownet2
Submodule
|
@ -0,0 +1 @@
|
|||
Subproject commit 2e9e010c98931bc7cef3eb063b195f1e0ab470ba
|
|
@ -14,6 +14,8 @@ import models.archs.rcan as rcan
|
|||
from collections import OrderedDict
|
||||
import torchvision
|
||||
import functools
|
||||
from models.flownet2.models import FlowNet2
|
||||
|
||||
|
||||
logger = logging.getLogger('base')
|
||||
|
||||
|
@ -90,6 +92,10 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
|
||||
netG = ssg.SSGDeep(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
|
||||
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
|
||||
elif which_model == "flownet2":
|
||||
args_dict = {}
|
||||
args = munchify(args_dict)
|
||||
netG = FlowNet2(args)
|
||||
elif which_model == "backbone_encoder":
|
||||
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
|
||||
elif which_model == "backbone_encoder_no_ref":
|
||||
|
|
|
@ -3,6 +3,9 @@ from models.layers.resample2d_package.resample2d import Resample2d
|
|||
from models.steps.recurrent import RecurrentController
|
||||
from models.steps.injectors import Injector
|
||||
import torch
|
||||
import os
|
||||
import os.path as osp
|
||||
import torchvision
|
||||
|
||||
def create_teco_loss(opt, env):
|
||||
type = opt['type']
|
||||
|
@ -114,7 +117,6 @@ class TecoGanDiscriminatorLoss(ConfigurableLoss):
|
|||
class TecoGanGeneratorLoss(ConfigurableLoss):
|
||||
def __init__(self, opt, env):
|
||||
super(TecoGanGeneratorLoss, self).__init__(opt, env)
|
||||
self.opt = opt
|
||||
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
||||
# TecoGAN parameters
|
||||
self.image_flow_generator = opt['image_flow_generator']
|
||||
|
@ -130,6 +132,10 @@ class TecoGanGeneratorLoss(ConfigurableLoss):
|
|||
fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler)
|
||||
d_fake = net(fake_sext)
|
||||
|
||||
if self.env['step'] % 100 == 0:
|
||||
self.produce_teco_visual_debugs(fake_sext, 'fake', i)
|
||||
self.produce_teco_visual_debugs(real_sext, 'real', i)
|
||||
|
||||
if self.opt['gan_type'] in ['gan', 'pixgan']:
|
||||
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
||||
l_fake = self.criterion(d_fake, True)
|
||||
|
@ -142,8 +148,16 @@ class TecoGanGeneratorLoss(ConfigurableLoss):
|
|||
self.criterion(d_fake_diff, True))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return l_total
|
||||
|
||||
def produce_teco_visual_debugs(self, sext, lbl, it):
|
||||
base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_sext", str(self.env['step']), lbl)
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
lbls = ['first', 'second', 'third', 'first_flow', 'second_flow', 'third_flow']
|
||||
for i in range(6):
|
||||
torchvision.utils.save_image(sext[:, i*3:(i+1)*3-1, :, :], osp.join(base_path, "%s_%s.png" % (lbls[i], it)))
|
||||
|
||||
|
||||
# This loss doesn't have a real entry - only fakes are used.
|
||||
class PingPongLoss(ConfigurableLoss):
|
||||
|
@ -159,4 +173,15 @@ class PingPongLoss(ConfigurableLoss):
|
|||
early = fake[i]
|
||||
late = fake[-i]
|
||||
l_total += self.criterion(early, late)
|
||||
return l_total
|
||||
|
||||
if self.env['step'] % 100 == 0:
|
||||
self.produce_teco_visual_debugs(fake)
|
||||
|
||||
return l_total
|
||||
|
||||
def produce_teco_visual_debugs(self, imglist):
|
||||
base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_pingpong", str(self.env['step']))
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
assert isinstance(imglist, list)
|
||||
for i, img in enumerate(imglist):
|
||||
torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))
|
Loading…
Reference in New Issue
Block a user