DL-Art-School/codes/models/steps/tecogan_losses.py
2020-10-27 22:48:23 -06:00

364 lines
18 KiB
Python

from torch.cuda.amp import autocast
from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
from models.steps.injectors import Injector
import torch
import torch.nn.functional as F
import os
import os.path as osp
import torchvision
import torch.distributed as dist
def create_teco_loss(opt, env):
type = opt['type']
if type == 'teco_gan':
return TecoGanLoss(opt, env)
elif type == "teco_pingpong":
return PingPongLoss(opt, env)
return None
def create_teco_injector(opt, env):
type = opt['type']
if type == 'teco_recurrent_generated_sequence_injector':
return RecurrentImageGeneratorSequenceInjector(opt, env)
elif type == 'teco_flow_adjustment':
return FlowAdjustment(opt, env)
return None
def extract_inputs_index(inputs, i):
res = []
for input in inputs:
if isinstance(input, torch.Tensor):
res.append(input[:, i])
else:
res.append(input)
return res
# Uses a generator to synthesize a sequence of images from [in] and injects the results into a list [out]
# Images are fed in sequentially forward and back, resulting in len([out])=2*len([in])-1 (last element is not repeated).
# All computation is done with torch.no_grad().
class RecurrentImageGeneratorSequenceInjector(Injector):
def __init__(self, opt, env):
super(RecurrentImageGeneratorSequenceInjector, self).__init__(opt, env)
self.flow = opt['flow_network']
self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
self.recurrent_index = opt['recurrent_index']
self.scale = opt['scale']
self.resample = Resample2d()
self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'.
self.do_backwards = opt['do_backwards'] if 'do_backwards' in opt.keys() else True
self.hq_recurrent = opt['hq_recurrent'] if 'hq_recurrent' in opt.keys() else False # When True, recurrent_index is not touched for the first iteration, allowing you to specify what is fed in. When False, zeros are fed into the recurrent index.
def forward(self, state):
gen = self.env['generators'][self.opt['generator']]
flow = self.env['generators'][self.flow]
first_inputs = extract_params_from_state(self.first_inputs, state)
inputs = extract_params_from_state(self.input, state)
if not isinstance(inputs, list):
inputs = [inputs]
if not isinstance(self.output, list):
self.output = [self.output]
results = {}
for out_key in self.output:
results[out_key] = []
# Go forward in the sequence first.
first_step = True
b, f, c, h, w = inputs[self.input_lq_index].shape
debug_index = 0
for i in range(f):
if first_step:
input = extract_inputs_index(first_inputs, i)
if self.hq_recurrent:
recurrent_input = input[self.recurrent_index]
else:
recurrent_input = torch.zeros_like(input[self.recurrent_index])
first_step = False
else:
input = extract_inputs_index(inputs, i)
with torch.no_grad() and autocast(enabled=False):
# This is a hack to workaround the fact that flownet2 cannot operate at resolutions < 64px. An assumption is
# made here that if you are operating at 4x scale, your inputs are 32px x 32px
if self.scale >= 4:
flow_input = F.interpolate(input[self.input_lq_index], scale_factor=self.scale//2, mode='bicubic')
else:
flow_input = input[self.input_lq_index]
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=.5, mode='bicubic')
flow_input = torch.stack([flow_input, reduced_recurrent], dim=2).float()
flowfield = F.interpolate(flow(flow_input), scale_factor=2, mode='bicubic')
recurrent_input = self.resample(recurrent_input.float(), flowfield)
input[self.recurrent_index] = recurrent_input
if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
debug_index += 1
with autocast(enabled=self.env['opt']['fp16']):
gen_out = gen(*input)
if isinstance(gen_out, torch.Tensor):
gen_out = [gen_out]
for i, out_key in enumerate(self.output):
results[out_key].append(gen_out[i])
recurrent_input = gen_out[self.output_hq_index]
# Now go backwards, skipping the last element (it's already stored in recurrent_input)
if self.do_backwards:
it = reversed(range(f - 1))
for i in it:
input = extract_inputs_index(inputs, i)
with torch.no_grad():
with autocast(enabled=False):
# This is a hack to workaround the fact that flownet2 cannot operate at resolutions < 64px. An assumption is
# made here that if you are operating at 4x scale, your inputs are 32px x 32px
if self.scale >= 4:
flow_input = F.interpolate(input[self.input_lq_index], scale_factor=self.scale//2, mode='bicubic')
else:
flow_input = input[self.input_lq_index]
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=.5, mode='bicubic')
flow_input = torch.stack([flow_input, reduced_recurrent], dim=2).float()
flowfield = F.interpolate(flow(flow_input), scale_factor=2, mode='bicubic')
recurrent_input = self.resample(recurrent_input.float(), flowfield)
input[self.recurrent_index] = recurrent_input
if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], input[self.recurrent_index], debug_index)
debug_index += 1
with autocast(enabled=self.env['opt']['fp16']):
gen_out = gen(*input)
if isinstance(gen_out, torch.Tensor):
gen_out = [gen_out]
for i, out_key in enumerate(self.output):
results[out_key].append(gen_out[i])
recurrent_input = gen_out[self.output_hq_index]
final_results = {}
# Include 'hq_batched' here - because why not... Don't really need a separate injector for this.
b, s, c, h, w = state['hq'].shape
final_results['hq_batched'] = state['hq'].view(b*s, c, h, w)
for k, v in results.items():
final_results[k] = torch.stack(v, dim=1)
final_results[k + "_batched"] = torch.cat(v[:s], dim=0) # Only include the original sequence - this output is generally used to compare against HQ.
return final_results
def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it):
if self.env['rank'] > 0:
return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
os.makedirs(base_path, exist_ok=True)
torchvision.utils.save_image(gen_input.float(), osp.join(base_path, "%s_img.png" % (it,)))
torchvision.utils.save_image(gen_recurrent.float(), osp.join(base_path, "%s_recurrent.png" % (it,)))
class FlowAdjustment(Injector):
def __init__(self, opt, env):
super(FlowAdjustment, self).__init__(opt, env)
self.resample = Resample2d()
self.flow = opt['flow_network']
self.flow_target = opt['flow_target']
self.flowed = opt['flowed']
def forward(self, state):
with autocast(enabled=False):
flow = self.env['generators'][self.flow]
flow_target = state[self.flow_target]
flowed = F.interpolate(state[self.flowed], size=flow_target.shape[2:], mode='bicubic')
flow_input = torch.stack([flow_target, flowed], dim=2).float()
flowfield = F.interpolate(flow(flow_input), size=state[self.flowed].shape[2:], mode='bicubic')
return {self.output: self.resample(state[self.flowed], flowfield)}
def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin):
# Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
with autocast(enabled=False):
triplet = input_list[:, index:index+3].float()
first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2))
last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2))
flow_triplet = [resampler(triplet[:,0], first_flow),
triplet[:,1],
resampler(triplet[:,2], last_flow)]
flow_triplet = torch.stack(flow_triplet, dim=1)
combined = torch.cat([triplet, flow_triplet], dim=1)
b, f, c, h, w = combined.shape
combined = combined.view(b, 3*6, h, w) # 3*6 is essentially an assertion here.
# Apply margin
return combined[:, :, margin:-margin, margin:-margin]
def create_all_discriminator_sextuplets(input_list, lr_imgs, scale, total, flow_gen, resampler, margin):
# Combine everything and feed it into the flow network at once for better efficiency.
batch_sz = input_list.shape[0]
flux_doubles_forward = [torch.stack([input_list[:,i], input_list[:,i+1]], dim=2) for i in range(1, total+1)]
flux_doubles_backward = [torch.stack([input_list[:,i], input_list[:,i-1]], dim=2) for i in range(1, total+1)]
flows_forward = flow_gen(torch.cat(flux_doubles_forward, dim=0))
flows_backward = flow_gen(torch.cat(flux_doubles_backward, dim=0))
sexts = []
for i in range(total):
flow_forward = flows_forward[batch_sz*i:batch_sz*(i+1)]
flow_backward = flows_backward[batch_sz*i:batch_sz*(i+1)]
mid = input_list[:,i+1]
sext = torch.stack([input_list[:,i], mid, input_list[:,i+2],
resampler(mid, flow_backward),
mid,
resampler(mid, flow_forward)], dim=1)
# Apply margin
b, f, c, h, w = sext.shape
sext = sext.view(b, 3*6, h, w) # f*c = 6*3
sext = sext[:, :, margin:-margin, margin:-margin]
sexts.append(sext)
return torch.cat(sexts, dim=0)
# This is the temporal discriminator loss from TecoGAN.
#
# It has a strict contract for 'real' and 'fake' inputs:
# 'real' - Must be a list of arbitrary images (len>3) drawn from the dataset
# 'fake' - The output of the RecurrentImageGeneratorSequenceInjector for the same set of images.
#
# This loss does the following:
# 1) Picks an image triplet, starting with the first '3' elements in 'real' and 'fake'.
# 2) Uses the image flow generator (specified with 'image_flow_generator') to create detached flow fields for the first and last images in the above sequence.
# 3) Warps the first and last images according to the flow field.
# 4) Composes the three base image and the 2 warped images and middle image into a tensor concatenated at the filter dimension for both real and fake, resulting in a bx18xhxw shape tensor.
# 5) Feeds the catted real and fake image sets into the discriminator, computes a loss, and backward().
# 6) Repeat from (1) until all triplets from the real sequence have been exhausted.
class TecoGanLoss(ConfigurableLoss):
def __init__(self, opt, env):
super(TecoGanLoss, self).__init__(opt, env)
self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
# TecoGAN parameters
self.scale = opt['scale']
self.lr_inputs = opt['lr_inputs']
self.image_flow_generator = opt['image_flow_generator']
self.resampler = Resample2d()
self.for_generator = opt['for_generator']
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
self.ff = opt['fast_forward'] if 'fast_forward' in opt.keys() else False
self.noise = opt['noise'] if 'noise' in opt.keys() else 0
def forward(self, _, state):
if self.ff:
return self.fast_forward(state)
else:
return self.lowmem_forward(state)
# Computes the discriminator loss one recursive step at a time, which has a lower memory overhead but is
# slower.
def lowmem_forward(self, state):
flow_gen = self.env['generators'][self.image_flow_generator]
real = state[self.opt['real']]
fake = state[self.opt['fake']]
sequence_len = real.shape[1]
lr = state[self.opt['lr_inputs']]
l_total = 0
# Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
for i in range(sequence_len - 2):
real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
l_step = self.compute_loss(real_sext, fake_sext)
if l_step > self.min_loss:
l_total += l_step
return l_total
# Computes the discriminator loss by dogpiling all of the sextuplets into the batch dimension and doing one massive
# forward() on the discriminators. High memory but faster.
def fast_forward(self, state):
flow_gen = self.env['generators'][self.image_flow_generator]
real = state[self.opt['real']]
fake = state[self.opt['fake']]
sequence_len = real.shape[1]
lr = state[self.opt['lr_inputs']]
# Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
combined_real_sext = create_all_discriminator_sextuplets(real, lr, self.scale, sequence_len - 2, flow_gen,
self.resampler, self.margin)
combined_fake_sext = create_all_discriminator_sextuplets(fake, lr, self.scale, sequence_len - 2, flow_gen,
self.resampler, self.margin)
l_total = self.compute_loss(combined_real_sext, combined_fake_sext)
if l_total < self.min_loss:
l_total = 0
return l_total
def compute_loss(self, real_sext, fake_sext):
fp16 = self.env['opt']['fp16']
net = self.env['discriminators'][self.opt['discriminator']]
if self.noise != 0:
real_sext += torch.randn_like(real_sext) * self.noise
fake_sext += torch.randn_like(fake_sext) * self.noise
with autocast(enabled=fp16):
d_fake = net(fake_sext)
d_real = net(real_sext)
self.metrics.append(("d_fake", torch.mean(d_fake)))
self.metrics.append(("d_real", torch.mean(d_real)))
if self.for_generator and self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(fake_sext, 'fake', 0)
self.produce_teco_visual_debugs(real_sext, 'real', 0)
if self.opt['gan_type'] in ['gan', 'pixgan']:
l_fake = self.criterion(d_fake, self.for_generator)
if not self.for_generator:
l_real = self.criterion(d_real, True)
else:
l_real = 0
l_step = l_fake + l_real
elif self.opt['gan_type'] == 'ragan':
d_fake_diff = d_fake - torch.mean(d_real)
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
l_step = (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
self.criterion(d_fake_diff, self.for_generator))
else:
raise NotImplementedError
return l_step
def produce_teco_visual_debugs(self, sext, lbl, it):
if self.env['rank'] > 0:
return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
os.makedirs(base_path, exist_ok=True)
lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
for i in range(6):
torchvision.utils.save_image(sext[:, i*3:(i+1)*3, :, :].float(), osp.join(base_path, "%s_%s.png" % (it, lbls[i])))
# This loss doesn't have a real entry - only fakes are used.
class PingPongLoss(ConfigurableLoss):
def __init__(self, opt, env):
super(PingPongLoss, self).__init__(opt, env)
self.opt = opt
self.criterion = get_basic_criterion_for_name(opt['criterion'], env['device'])
def forward(self, _, state):
fake = state[self.opt['fake']]
l_total = 0
img_count = fake.shape[1]
for i in range((img_count - 1) // 2):
early = fake[:, i]
late = fake[:, -i]
l_total += self.criterion(early, late)
if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(fake)
return l_total
def produce_teco_visual_debugs(self, imglist):
if self.env['rank'] > 0:
return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
os.makedirs(base_path, exist_ok=True)
cnt = imglist.shape[1]
for i in range(cnt):
img = imglist[:, i]
torchvision.utils.save_image(img.float(), osp.join(base_path, "%s.png" % (i, )))