forked from mrq/DL-Art-School
Tecogan work
It's training! There are probably still plenty of bugs, though.
This commit is contained in:
parent
e9d7371a61
commit
1c44d395af
|
@ -1,3 +1,4 @@
|
||||||
|
import munch
|
||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
from munch import munchify
|
from munch import munchify
|
||||||
|
@ -14,7 +15,6 @@ import models.archs.rcan as rcan
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import torchvision
|
import torchvision
|
||||||
import functools
|
import functools
|
||||||
from models.flownet2.models import FlowNet2
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
||||||
|
@ -86,20 +86,24 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
|
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
|
||||||
elif which_model == 'stacked_switches':
|
elif which_model == 'stacked_switches':
|
||||||
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
|
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
|
||||||
netG = ssg.StackedSwitchGenerator(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
|
in_nc = opt_net['in_nc'] if 'in_nc' in opt_net.keys() else 3
|
||||||
|
netG = ssg.StackedSwitchGenerator(in_nc=in_nc, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
|
||||||
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
|
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
|
||||||
elif which_model == 'stacked_switches_5lyr':
|
elif which_model == 'stacked_switches_5lyr':
|
||||||
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
|
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
|
||||||
netG = ssg.StackedSwitchGenerator5Layer(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
|
in_nc = opt_net['in_nc'] if 'in_nc' in opt_net.keys() else 3
|
||||||
|
netG = ssg.StackedSwitchGenerator5Layer(in_nc=in_nc, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
|
||||||
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
|
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
|
||||||
elif which_model == 'ssg_deep':
|
elif which_model == 'ssg_deep':
|
||||||
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
|
xforms = opt_net['num_transforms'] if 'num_transforms' in opt_net.keys() else 8
|
||||||
netG = ssg.SSGDeep(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
|
netG = ssg.SSGDeep(in_nc=3, out_nc=3, nf=opt_net['nf'], xforms=xforms, upscale=opt_net['scale'],
|
||||||
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
|
init_temperature=opt_net['temperature'] if 'temperature' in opt_net.keys() else 10)
|
||||||
elif which_model == "flownet2":
|
elif which_model == "flownet2":
|
||||||
args_dict = {}
|
from models.flownet2.models import FlowNet2
|
||||||
args = munchify(args_dict)
|
ld = torch.load(opt_net['load_path'])
|
||||||
|
args = munch.Munch({'fp16': False, 'rgb_max': 1.0})
|
||||||
netG = FlowNet2(args)
|
netG = FlowNet2(args)
|
||||||
|
netG.load_state_dict(ld['state_dict'])
|
||||||
elif which_model == "backbone_encoder":
|
elif which_model == "backbone_encoder":
|
||||||
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
|
netG = SwitchedGen_arch.BackboneEncoder(pretrained_backbone=opt_net['pretrained_spinenet'])
|
||||||
elif which_model == "backbone_encoder_no_ref":
|
elif which_model == "backbone_encoder_no_ref":
|
||||||
|
|
|
@ -28,6 +28,8 @@ def create_loss(opt_loss, env):
|
||||||
return TranslationInvarianceLoss(opt_loss, env)
|
return TranslationInvarianceLoss(opt_loss, env)
|
||||||
elif type == 'recursive':
|
elif type == 'recursive':
|
||||||
return RecursiveInvarianceLoss(opt_loss, env)
|
return RecursiveInvarianceLoss(opt_loss, env)
|
||||||
|
elif type == 'recurrent':
|
||||||
|
return RecurrentLoss(opt_loss, env)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -372,3 +374,25 @@ class RecursiveInvarianceLoss(ConfigurableLoss):
|
||||||
else:
|
else:
|
||||||
return self.criterion(compare_real, compare_fake)
|
return self.criterion(compare_real, compare_fake)
|
||||||
|
|
||||||
|
|
||||||
|
# Loss that pulls tensors from dim 1 of the input and repeatedly feeds them into the
|
||||||
|
# 'subtype' loss.
|
||||||
|
class RecurrentLoss(ConfigurableLoss):
    """Applies a wrapped 'subtype' loss once per sequence step and sums the results.

    The wrapped loss is built through create_loss() from a copy of this loss's
    options in which 'type' is replaced by opt['subtype'] and the 'fake'/'real'
    keys are redirected to the private state slots '_fake' and '_real'.
    """

    def __init__(self, opt, env):
        super(RecurrentLoss, self).__init__(opt, env)
        # Work on a copy so the caller's option dict is left untouched.
        sub_opt = opt.copy()
        sub_opt['type'] = opt['subtype']
        sub_opt['fake'] = '_fake'
        sub_opt['real'] = '_real'
        # NOTE(review): self.env (and self.opt, used in forward) are presumably
        # assigned by ConfigurableLoss.__init__ - confirm against the base class.
        self.loss = create_loss(sub_opt, self.env)

    def forward(self, net, state):
        """Return the sum of the wrapped loss over every step of the 'real' sequence.

        'real' is sliced along dim 1 (real[:, i]); 'fake' is indexed along its
        first axis (fake[i]). Each pair is written into a copy of the state
        under '_real'/'_fake' before the wrapped loss is invoked.
        """
        accumulated = 0
        step_state = state.copy()
        real_seq = state[self.opt['real']]
        num_steps = real_seq.shape[1]
        for step in range(num_steps):
            step_state['_real'] = real_seq[:, step]
            step_state['_fake'] = state[self.opt['fake']][step]
            accumulated += self.loss(net, step_state)
        return accumulated
|
||||||
|
|
||||||
|
|
|
@ -1,29 +1,52 @@
|
||||||
from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state
|
from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
|
||||||
from models.layers.resample2d_package.resample2d import Resample2d
|
from models.layers.resample2d_package.resample2d import Resample2d
|
||||||
from models.steps.recurrent import RecurrentController
|
from models.steps.recurrent import RecurrentController
|
||||||
from models.steps.injectors import Injector
|
from models.steps.injectors import Injector
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import torchvision
|
import torchvision
|
||||||
|
|
||||||
def create_teco_loss(opt, env):
|
def create_teco_loss(opt, env):
|
||||||
type = opt['type']
|
type = opt['type']
|
||||||
if type == 'teco_generator_gan':
|
if type == 'teco_gan':
|
||||||
return TecoGanGeneratorLoss(opt, env)
|
return TecoGanLoss(opt, env)
|
||||||
elif type == 'teco_discriminator_gan':
|
|
||||||
return TecoGanDiscriminatorLoss(opt, env)
|
|
||||||
elif type == "teco_pingpong":
|
elif type == "teco_pingpong":
|
||||||
return PingPongLoss(opt, env)
|
return PingPongLoss(opt, env)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def create_teco_discriminator_sextuplet(input_list, index, flow_gen, resampler):
|
def create_teco_injector(opt, env):
|
||||||
triplet = input_list[index:index+3]
|
type = opt['type']
|
||||||
first_flow = flow_gen(triplet[0], triplet[1])
|
if type == 'teco_recurrent_generated_sequence_injector':
|
||||||
last_flow = flow_gen(triplet[2], triplet[1])
|
return RecurrentImageGeneratorSequenceInjector(opt, env)
|
||||||
flow_triplet = [resampler(triplet[0], first_flow), triplet[1], resampler(triplet[2], last_flow)]
|
return None
|
||||||
return torch.cat(triplet + flow_triplet, dim=1)
|
|
||||||
|
|
||||||
|
def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler):
    """Build the 18-channel discriminator input for the frame triplet starting at `index`.

    The result concatenates, along the channel axis, the three raw frames
    followed by the same three frames with the outer two warped towards the
    middle frame.

    Args:
        input_list: image sequence tensor indexed as (batch, frame, channel, h, w).
            The final view to 3*6 channels requires 3-channel frames.
        lr_imgs: low-resolution sequence with the same (batch, frame, ...) layout;
            frames index..index+2 are used to estimate the flow.
        scale: factor used to upsample the LR flow field to the size of `input_list`.
        index: first frame of the triplet.
        flow_gen: callable taking two frames stacked on dim 2 and returning a flow field.
        resampler: callable (image, flow) -> warped image.

    Returns:
        Tensor of shape (batch, 18, h, w).
    """
    triplet = input_list[:, index:index + 3]
    # The flow must come from the LR frames belonging to *this* triplet, not
    # always from the first three frames of the sequence.
    lr_triplet = lr_imgs[:, index:index + 3]

    # Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
    # NOTE(review): only the flow estimation is placed under no_grad here; the
    # original indentation was not recoverable - confirm the intended extent.
    with torch.no_grad():
        first_flow = flow_gen(torch.stack([lr_triplet[:, 0], lr_triplet[:, 1]], dim=2))
        first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic')
        last_flow = flow_gen(torch.stack([lr_triplet[:, 2], lr_triplet[:, 1]], dim=2))
        last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic')
    # NOTE(review): the upsampled flow values are not multiplied by `scale`;
    # confirm whether the resampler expects displacements in HR pixels.

    # The resampler is fed float32 tensors (it is cast explicitly at every call site).
    flow_triplet = [resampler(triplet[:, 0].float(), first_flow.float()),
                    triplet[:, 1],
                    resampler(triplet[:, 2].float(), last_flow.float())]
    # Stack/concatenate on the frame axis (dim 1) so the warped frames share the
    # (batch, frame, channel, h, w) layout of `triplet`.
    flow_triplet = torch.stack(flow_triplet, dim=1)
    combined = torch.cat([triplet, flow_triplet], dim=1)
    b, f, c, h, w = combined.shape
    return combined.view(b, 3*6, h, w)  # 3*6 is essentially an assertion here.
|
||||||
|
|
||||||
|
|
||||||
|
def extract_inputs_index(inputs, i):
    """Select step `i` from every tensor in `inputs`, passing other items through.

    Args:
        inputs: iterable of arguments; torch.Tensor entries are indexed as
            entry[:, i], anything else is returned unchanged.
        i: index along dim 1 to extract from each tensor.

    Returns:
        A new list with one element per entry of `inputs`, in the same order.
    """
    return [item[:, i] if isinstance(item, torch.Tensor) else item
            for item in inputs]
|
||||||
|
|
||||||
# Uses a generator to synthesize a sequence of images from [in] and injects the results into a list [out]
|
# Uses a generator to synthesize a sequence of images from [in] and injects the results into a list [out]
|
||||||
# Images are fed in sequentially forward and back, resulting in len([out])=2*len([in])-1 (last element is not repeated).
|
# Images are fed in sequentially forward and back, resulting in len([out])=2*len([in])-1 (last element is not repeated).
|
||||||
|
@ -32,32 +55,51 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
|
super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
|
||||||
self.flow = opt['flow_network']
|
self.flow = opt['flow_network']
|
||||||
|
self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
|
||||||
|
self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
|
||||||
|
self.scale = opt['scale']
|
||||||
self.resample = Resample2d()
|
self.resample = Resample2d()
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
gen = self.env['generators'][self.opt['generator']]
|
gen = self.env['generators'][self.opt['generator']]
|
||||||
flow = self.env['generators'][self.flow]
|
flow = self.env['generators'][self.flow]
|
||||||
results = []
|
results = []
|
||||||
recurrent_input = torch.zeros_like(state[self.input][0])
|
inputs = extract_params_from_state(self.input, state)
|
||||||
|
if not isinstance(inputs, list):
|
||||||
|
inputs = [inputs]
|
||||||
|
recurrent_input = torch.zeros_like(inputs[self.input_lq_index][:,0])
|
||||||
|
|
||||||
# Go forward in the sequence first.
|
# Go forward in the sequence first.
|
||||||
first_step = True
|
first_step = True
|
||||||
for input in state[self.input]:
|
b, f, c, h, w = inputs[self.input_lq_index].shape
|
||||||
|
for i in range(f):
|
||||||
|
input = extract_inputs_index(inputs, i)
|
||||||
if first_step:
|
if first_step:
|
||||||
first_step = False
|
first_step = False
|
||||||
else:
|
else:
|
||||||
flowfield = flow(recurrent_input, input)
|
with torch.no_grad():
|
||||||
recurrent_input = self.resample(recurrent_input, flowfield)
|
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic')
|
||||||
recurrent_input = gen(input, recurrent_input)
|
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
||||||
|
flowfield = flow(flow_input)
|
||||||
|
# Resample does not work in FP16.
|
||||||
|
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
||||||
|
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
|
||||||
|
gen_out = gen(*input)
|
||||||
|
recurrent_input = gen_out[self.output_hq_index]
|
||||||
results.append(recurrent_input)
|
results.append(recurrent_input)
|
||||||
recurrent_input = self.flow()
|
|
||||||
|
|
||||||
# Now go backwards, skipping the last element (it's already stored in recurrent_input)
|
# Now go backwards, skipping the last element (it's already stored in recurrent_input)
|
||||||
it = reversed(range(len(results) - 1))
|
it = reversed(range(f - 1))
|
||||||
for i in it:
|
for i in it:
|
||||||
flowfield = flow(recurrent_input, results[i])
|
input = extract_inputs_index(inputs, i)
|
||||||
recurrent_input = self.resample(recurrent_input, flowfield)
|
with torch.no_grad():
|
||||||
recurrent_input = gen(results[i], recurrent_input)
|
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
|
||||||
|
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
||||||
|
flowfield = flow(flow_input)
|
||||||
|
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
||||||
|
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
|
||||||
|
gen_out = gen(*input)
|
||||||
|
recurrent_input = gen_out[self.output_hq_index]
|
||||||
results.append(recurrent_input)
|
results.append(recurrent_input)
|
||||||
|
|
||||||
return {self.output: results}
|
return {self.output: results}
|
||||||
|
@ -76,76 +118,48 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
# 4) Composes the three base image and the 2 warped images and middle image into a tensor concatenated at the filter dimension for both real and fake, resulting in a bx18xhxw shape tensor.
|
# 4) Composes the three base image and the 2 warped images and middle image into a tensor concatenated at the filter dimension for both real and fake, resulting in a bx18xhxw shape tensor.
|
||||||
# 5) Feeds the catted real and fake image sets into the discriminator, computes a loss, and backward().
|
# 5) Feeds the catted real and fake image sets into the discriminator, computes a loss, and backward().
|
||||||
# 6) Repeat from (1) until all triplets from the real sequence have been exhausted.
|
# 6) Repeat from (1) until all triplets from the real sequence have been exhausted.
|
||||||
class TecoGanDiscriminatorLoss(ConfigurableLoss):
|
class TecoGanLoss(ConfigurableLoss):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
super(TecoGanDiscriminatorLoss, self).__init__(opt, env)
|
super(TecoGanLoss, self).__init__(opt, env)
|
||||||
self.opt = opt
|
|
||||||
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
|
||||||
self.noise = None if 'noise' not in opt.keys() else opt['noise']
|
|
||||||
self.image_flow_generator = opt['image_flow_generator']
|
|
||||||
self.resampler = Resample2d()
|
|
||||||
|
|
||||||
def forward(self, net, state):
|
|
||||||
self.metrics = []
|
|
||||||
flow_gen = self.env['generators'][self.image_flow_generator]
|
|
||||||
real = state[self.opt['real']]
|
|
||||||
fake = state[self.opt['fake']]
|
|
||||||
l_total = 0
|
|
||||||
for i in range(len(real) - 2):
|
|
||||||
real_sext = create_teco_discriminator_sextuplet(real, i, flow_gen, self.resampler)
|
|
||||||
fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler)
|
|
||||||
|
|
||||||
d_real = net(real_sext)
|
|
||||||
d_fake = net(fake_sext)
|
|
||||||
|
|
||||||
if self.opt['gan_type'] in ['gan', 'pixgan']:
|
|
||||||
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
|
||||||
self.metrics.append(("d_real", torch.mean(d_real)))
|
|
||||||
l_real = self.criterion(d_real, True)
|
|
||||||
l_fake = self.criterion(d_fake, False)
|
|
||||||
l_total += l_real + l_fake
|
|
||||||
elif self.opt['gan_type'] == 'ragan':
|
|
||||||
d_fake_diff = d_fake - torch.mean(d_real)
|
|
||||||
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
|
|
||||||
l_total += (self.criterion(d_real - torch.mean(d_fake), True) +
|
|
||||||
self.criterion(d_fake_diff, False))
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
return l_total
|
|
||||||
|
|
||||||
|
|
||||||
class TecoGanGeneratorLoss(ConfigurableLoss):
|
|
||||||
def __init__(self, opt, env):
|
|
||||||
super(TecoGanGeneratorLoss, self).__init__(opt, env)
|
|
||||||
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
||||||
# TecoGAN parameters
|
# TecoGAN parameters
|
||||||
|
self.scale = opt['scale']
|
||||||
|
self.lr_inputs = opt['lr_inputs']
|
||||||
self.image_flow_generator = opt['image_flow_generator']
|
self.image_flow_generator = opt['image_flow_generator']
|
||||||
self.resampler = Resample2d()
|
self.resampler = Resample2d()
|
||||||
|
self.for_generator = opt['for_generator']
|
||||||
|
|
||||||
def forward(self, _, state):
|
def forward(self, _, state):
|
||||||
|
net = self.env['discriminators'][self.opt['discriminator']]
|
||||||
flow_gen = self.env['generators'][self.image_flow_generator]
|
flow_gen = self.env['generators'][self.image_flow_generator]
|
||||||
real = state[self.opt['real']]
|
real = state[self.opt['real']]
|
||||||
fake = state[self.opt['fake']]
|
fake = torch.stack(state[self.opt['fake']], dim=1)
|
||||||
|
sequence_len = real.shape[1]
|
||||||
|
lr = state[self.opt['lr_inputs']]
|
||||||
l_total = 0
|
l_total = 0
|
||||||
for i in range(len(real) - 2):
|
for i in range(sequence_len - 2):
|
||||||
real_sext = create_teco_discriminator_sextuplet(real, i, flow_gen, self.resampler)
|
real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler)
|
||||||
fake_sext = create_teco_discriminator_sextuplet(fake, i, flow_gen, self.resampler)
|
fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler)
|
||||||
d_fake = net(fake_sext)
|
d_fake = net(fake_sext)
|
||||||
|
|
||||||
if self.env['step'] % 100 == 0:
|
if self.for_generator and self.env['step'] % 100 == 0:
|
||||||
self.produce_teco_visual_debugs(fake_sext, 'fake', i)
|
self.produce_teco_visual_debugs(fake_sext, 'fake', i)
|
||||||
self.produce_teco_visual_debugs(real_sext, 'real', i)
|
self.produce_teco_visual_debugs(real_sext, 'real', i)
|
||||||
|
|
||||||
if self.opt['gan_type'] in ['gan', 'pixgan']:
|
if self.opt['gan_type'] in ['gan', 'pixgan']:
|
||||||
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
||||||
l_fake = self.criterion(d_fake, True)
|
l_fake = self.criterion(d_fake, self.for_generator)
|
||||||
l_total += l_fake
|
if not self.for_generator:
|
||||||
|
l_real = self.criterion(d_real, True)
|
||||||
|
else:
|
||||||
|
l_real = 0
|
||||||
|
l_total += l_fake + l_real
|
||||||
elif self.opt['gan_type'] == 'ragan':
|
elif self.opt['gan_type'] == 'ragan':
|
||||||
d_real = net(real_sext)
|
d_real = net(real_sext)
|
||||||
d_fake_diff = d_fake - torch.mean(d_real)
|
d_fake_diff = d_fake - torch.mean(d_real)
|
||||||
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
|
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
|
||||||
l_total += (self.criterion(d_real - torch.mean(d_fake), False) +
|
l_total += (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
|
||||||
self.criterion(d_fake_diff, True))
|
self.criterion(d_fake_diff, self.for_generator))
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@ -164,12 +178,12 @@ class PingPongLoss(ConfigurableLoss):
|
||||||
def __init__(self, opt, env):
|
def __init__(self, opt, env):
|
||||||
super(PingPongLoss, self).__init__(opt, env)
|
super(PingPongLoss, self).__init__(opt, env)
|
||||||
self.opt = opt
|
self.opt = opt
|
||||||
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
|
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
|
||||||
|
|
||||||
def forward(self, _, state):
|
def forward(self, _, state):
|
||||||
fake = state[self.opt['fake']]
|
fake = state[self.opt['fake']]
|
||||||
l_total = 0
|
l_total = 0
|
||||||
for i in range((len(fake) - 1) / 2):
|
for i in range((len(fake) - 1) // 2):
|
||||||
early = fake[i]
|
early = fake[i]
|
||||||
late = fake[-i]
|
late = fake[-i]
|
||||||
l_total += self.criterion(early, late)
|
l_total += self.criterion(early, late)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user