Merge remote-tracking branch 'origin/gan_lab' into gan_lab

This commit is contained in:
James Betker 2020-10-06 20:41:58 -06:00
commit 8a7e993aea
5 changed files with 42 additions and 3 deletions

6
.gitmodules vendored Normal file
View File

@ -0,0 +1,6 @@
[submodule "flownet2"]
path = flownet2
url = https://github.com/NVIDIA/flownet2-pytorch.git
[submodule "codes/models/flownet2"]
path = codes/models/flownet2
url = https://github.com/NVIDIA/flownet2-pytorch.git

View File

@ -29,7 +29,8 @@ class ExtensibleTrainer(BaseModel):
self.env = {'device': self.device,
'rank': self.rank,
'opt': opt,
'step': 0}
'step': 0,
'base_path': os.path.join(opt['path']['models'])}
self.mega_batch_factor = 1
if self.is_train:

1
codes/models/flownet2 Submodule

@ -0,0 +1 @@
Subproject commit 2e9e010c98931bc7cef3eb063b195f1e0ab470ba

View File

@ -14,6 +14,8 @@ import models.archs.rcan as rcan
from collections import OrderedDict
import torchvision
import functools
from models.flownet2.models import FlowNet2
logger = logging.getLogger('base')
@ -94,6 +96,10 @@ def define_G(opt, net_key='network_G', scale=None):
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
netG = ssg.SSGDeep(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
elif which_model == "flownet2":
args_dict = {}
args = munchify(args_dict)
netG = FlowNet2(args)
elif which_model == "backbone_encoder":
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
elif which_model == "backbone_encoder_no_ref":

View File

@ -3,6 +3,9 @@ from models.layers.resample2d_package.resample2d import Resample2d
from models.steps.recurrent import RecurrentController
from models.steps.injectors import Injector
import torch
import os
import os.path as osp
import torchvision
def create_teco_loss(opt, env):
type = opt['type']
@ -114,7 +117,6 @@ class TecoGanDiscriminatorLoss(ConfigurableLoss):
class TecoGanGeneratorLoss(ConfigurableLoss):
def __init__(self, opt, env):
super(TecoGanGeneratorLoss, self).__init__(opt, env)
self.opt = opt
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
# TecoGAN parameters
self.image_flow_generator = opt['image_flow_generator']
@ -130,6 +132,10 @@ class TecoGanGeneratorLoss(ConfigurableLoss):
fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler)
d_fake = net(fake_sext)
if self.env['step'] % 100 == 0:
self.produce_teco_visual_debugs(fake_sext, 'fake', i)
self.produce_teco_visual_debugs(real_sext, 'real', i)
if self.opt['gan_type'] in ['gan', 'pixgan']:
self.metrics.append(("d_fake", torch.mean(d_fake)))
l_fake = self.criterion(d_fake, True)
@ -142,8 +148,16 @@ class TecoGanGeneratorLoss(ConfigurableLoss):
self.criterion(d_fake_diff, True))
else:
raise NotImplementedError
return l_total
def produce_teco_visual_debugs(self, sext, lbl, it):
base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_sext", str(self.env['step']), lbl)
os.makedirs(base_path, exist_ok=True)
lbls = ['first', 'second', 'third', 'first_flow', 'second_flow', 'third_flow']
for i in range(6):
torchvision.utils.save_image(sext[:, i*3:(i+1)*3-1, :, :], osp.join(base_path, "%s_%s.png" % (lbls[i], it)))
# This loss doesn't have a real entry - only fakes are used.
class PingPongLoss(ConfigurableLoss):
@ -159,4 +173,15 @@ class PingPongLoss(ConfigurableLoss):
early = fake[i]
late = fake[-i]
l_total += self.criterion(early, late)
return l_total
if self.env['step'] % 100 == 0:
self.produce_teco_visual_debugs(fake)
return l_total
def produce_teco_visual_debugs(self, imglist):
base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_pingpong", str(self.env['step']))
os.makedirs(base_path, exist_ok=True)
assert isinstance(imglist, list)
for i, img in enumerate(imglist):
torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))