forked from mrq/DL-Art-School
Integrate flownet2 into codebase, add teco visual debugs
This commit is contained in:
parent
e4b89a172f
commit
cffc596141
6
.gitmodules
vendored
Normal file
6
.gitmodules
vendored
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
[submodule "flownet2"]
|
||||||
|
path = flownet2
|
||||||
|
url = https://github.com/NVIDIA/flownet2-pytorch.git
|
||||||
|
[submodule "codes/models/flownet2"]
|
||||||
|
path = codes/models/flownet2
|
||||||
|
url = https://github.com/NVIDIA/flownet2-pytorch.git
|
|
@ -29,7 +29,8 @@ class ExtensibleTrainer(BaseModel):
|
||||||
self.env = {'device': self.device,
|
self.env = {'device': self.device,
|
||||||
'rank': self.rank,
|
'rank': self.rank,
|
||||||
'opt': opt,
|
'opt': opt,
|
||||||
'step': 0}
|
'step': 0,
|
||||||
|
'base_path': os.path.join(opt['path']['models'])}
|
||||||
|
|
||||||
self.mega_batch_factor = 1
|
self.mega_batch_factor = 1
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
|
|
1
codes/models/flownet2
Submodule
1
codes/models/flownet2
Submodule
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit 2e9e010c98931bc7cef3eb063b195f1e0ab470ba
|
|
@ -14,6 +14,8 @@ import models.archs.rcan as rcan
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import torchvision
|
import torchvision
|
||||||
import functools
|
import functools
|
||||||
|
from models.flownet2.models import FlowNet2
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
@ -90,6 +92,10 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
|
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
|
||||||
netG = ssg.SSGDeep(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
|
netG = ssg.SSGDeep(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
|
||||||
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
|
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
|
||||||
|
elif which_model == "flownet2":
|
||||||
|
args_dict = {}
|
||||||
|
args = munchify(args_dict)
|
||||||
|
netG = FlowNet2(args)
|
||||||
elif which_model == "backbone_encoder":
|
elif which_model == "backbone_encoder":
|
||||||
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
|
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
|
||||||
elif which_model == "backbone_encoder_no_ref":
|
elif which_model == "backbone_encoder_no_ref":
|
||||||
|
|
|
@ -3,6 +3,9 @@ from models.layers.resample2d_package.resample2d import Resample2d
|
||||||
from models.steps.recurrent import RecurrentController
|
from models.steps.recurrent import RecurrentController
|
||||||
from models.steps.injectors import Injector
|
from models.steps.injectors import Injector
|
||||||
import torch
|
import torch
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import torchvision
|
||||||
|
|
||||||
def create_teco_loss(opt, env):
|
def create_teco_loss(opt, env):
|
||||||
type = opt['type']
|
type = opt['type']
|
||||||
|
@ -114,7 +117,6 @@ class TecoGanDiscriminatorLoss(ConfigurableLoss):
|
||||||
class TecoGanGeneratorLoss(ConfigurableLoss):
|
class TecoGanGeneratorLoss(ConfigurableLoss):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
super(TecoGanGeneratorLoss, self).__init__(opt, env)
|
super(TecoGanGeneratorLoss, self).__init__(opt, env)
|
||||||
self.opt = opt
|
|
||||||
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
||||||
# TecoGAN parameters
|
# TecoGAN parameters
|
||||||
self.image_flow_generator = opt['image_flow_generator']
|
self.image_flow_generator = opt['image_flow_generator']
|
||||||
|
@ -130,6 +132,10 @@ class TecoGanGeneratorLoss(ConfigurableLoss):
|
||||||
fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler)
|
fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler)
|
||||||
d_fake = net(fake_sext)
|
d_fake = net(fake_sext)
|
||||||
|
|
||||||
|
if self.env['step'] % 100 == 0:
|
||||||
|
self.produce_teco_visual_debugs(fake_sext, 'fake', i)
|
||||||
|
self.produce_teco_visual_debugs(real_sext, 'real', i)
|
||||||
|
|
||||||
if self.opt['gan_type'] in ['gan', 'pixgan']:
|
if self.opt['gan_type'] in ['gan', 'pixgan']:
|
||||||
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
||||||
l_fake = self.criterion(d_fake, True)
|
l_fake = self.criterion(d_fake, True)
|
||||||
|
@ -142,8 +148,16 @@ class TecoGanGeneratorLoss(ConfigurableLoss):
|
||||||
self.criterion(d_fake_diff, True))
|
self.criterion(d_fake_diff, True))
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
return l_total
|
return l_total
|
||||||
|
|
||||||
|
def produce_teco_visual_debugs(self, sext, lbl, it):
|
||||||
|
base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_sext", str(self.env['step']), lbl)
|
||||||
|
os.makedirs(base_path, exist_ok=True)
|
||||||
|
lbls = ['first', 'second', 'third', 'first_flow', 'second_flow', 'third_flow']
|
||||||
|
for i in range(6):
|
||||||
|
torchvision.utils.save_image(sext[:, i*3:(i+1)*3-1, :, :], osp.join(base_path, "%s_%s.png" % (lbls[i], it)))
|
||||||
|
|
||||||
|
|
||||||
# This loss doesn't have a real entry - only fakes are used.
|
# This loss doesn't have a real entry - only fakes are used.
|
||||||
class PingPongLoss(ConfigurableLoss):
|
class PingPongLoss(ConfigurableLoss):
|
||||||
|
@ -159,4 +173,15 @@ class PingPongLoss(ConfigurableLoss):
|
||||||
early = fake[i]
|
early = fake[i]
|
||||||
late = fake[-i]
|
late = fake[-i]
|
||||||
l_total += self.criterion(early, late)
|
l_total += self.criterion(early, late)
|
||||||
return l_total
|
|
||||||
|
if self.env['step'] % 100 == 0:
|
||||||
|
self.produce_teco_visual_debugs(fake)
|
||||||
|
|
||||||
|
return l_total
|
||||||
|
|
||||||
|
def produce_teco_visual_debugs(self, imglist):
|
||||||
|
base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_pingpong", str(self.env['step']))
|
||||||
|
os.makedirs(base_path, exist_ok=True)
|
||||||
|
assert isinstance(imglist, list)
|
||||||
|
for i, img in enumerate(imglist):
|
||||||
|
torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))
|
Loading…
Reference in New Issue
Block a user