Upgrade tecogan_losses for speed

This commit is contained in:
James Betker 2020-10-27 22:40:15 -06:00
parent ac3da0c5a6
commit 4dc16d5889
2 changed files with 111 additions and 48 deletions

View File

@ -31,8 +31,6 @@ def create_injector(opt_inject, env):
return GreyInjector(opt_inject, env)
elif type == 'interpolate':
return InterpolateInjector(opt_inject, env)
elif type == 'imageflow':
return ImageFlowInjector(opt_inject, env)
elif type == 'image_patch':
return ImagePatchInjector(opt_inject, env)
elif type == 'concatenate':

View File

@ -26,23 +26,6 @@ def create_teco_injector(opt, env):
return FlowAdjustment(opt, env)
return None
def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin):
# Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
with autocast(enabled=False):
triplet = input_list[:, index:index+3].float()
first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2))
last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2))
flow_triplet = [resampler(triplet[:,0], first_flow),
triplet[:,1],
resampler(triplet[:,2], last_flow)]
flow_triplet = torch.stack(flow_triplet, dim=1)
combined = torch.cat([triplet, flow_triplet], dim=1)
b, f, c, h, w = combined.shape
combined = combined.view(b, 3*6, h, w) # 3*6 is essentially an assertion here.
# Apply margin
return combined[:, :, margin:-margin, margin:-margin]
def extract_inputs_index(inputs, i):
res = []
for input in inputs:
@ -152,9 +135,14 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
results[out_key].append(gen_out[i])
recurrent_input = gen_out[self.output_hq_index]
final_results = {}
# Include 'hq_batched' here - because why not... Don't really need a separate injector for this.
b, s, c, h, w = state['hq'].shape
final_results['hq_batched'] = state['hq'].view(b*s, c, h, w)
for k, v in results.items():
results[k] = torch.stack(v, dim=1)
return results
final_results[k] = torch.stack(v, dim=1)
final_results[k + "_batched"] = torch.cat(v[:s], dim=0) # Only include the original sequence - this output is generally used to compare against HQ.
return final_results
def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it):
if self.env['rank'] > 0:
@ -183,6 +171,47 @@ class FlowAdjustment(Injector):
return {self.output: self.resample(state[self.flowed], flowfield)}
def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin):
# Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
with autocast(enabled=False):
triplet = input_list[:, index:index+3].float()
first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2))
last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2))
flow_triplet = [resampler(triplet[:,0], first_flow),
triplet[:,1],
resampler(triplet[:,2], last_flow)]
flow_triplet = torch.stack(flow_triplet, dim=1)
combined = torch.cat([triplet, flow_triplet], dim=1)
b, f, c, h, w = combined.shape
combined = combined.view(b, 3*6, h, w) # 3*6 is essentially an assertion here.
# Apply margin
return combined[:, :, margin:-margin, margin:-margin]
def create_all_discriminator_sextuplets(input_list, lr_imgs, scale, total, flow_gen, resampler, margin):
# Combine everything and feed it into the flow network at once for better efficiency.
batch_sz = input_list.shape[0]
flux_doubles_forward = [torch.stack([input_list[:,i], input_list[:,i+1]], dim=2) for i in range(1, total+1)]
flux_doubles_backward = [torch.stack([input_list[:,i], input_list[:,i-1]], dim=2) for i in range(1, total+1)]
flows_forward = flow_gen(torch.cat(flux_doubles_forward, dim=0))
flows_backward = flow_gen(torch.cat(flux_doubles_backward, dim=0))
sexts = []
for i in range(total):
flow_forward = flows_forward[batch_sz*i:batch_sz*(i+1)]
flow_backward = flows_backward[batch_sz*i:batch_sz*(i+1)]
mid = input_list[:,i+1]
sext = torch.stack([input_list[:,i], mid, input_list[:,i+2],
resampler(mid, flow_backward),
mid,
resampler(mid, flow_forward)], dim=1)
# Apply margin
b, f, c, h, w = sext.shape
sext = sext.view(b, 3*6, h, w) # f*c = 6*3
sext = sext[:, :, margin:-margin, margin:-margin]
sexts.append(sext)
return torch.cat(sexts, dim=0)
# This is the temporal discriminator loss from TecoGAN.
#
# It has a strict contract for 'real' and 'fake' inputs:
@ -208,48 +237,85 @@ class TecoGanLoss(ConfigurableLoss):
self.for_generator = opt['for_generator']
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
self.ff = opt['fast_forward'] if 'fast_forward' in opt.keys() else False
def forward(self, _, state):
fp16 = self.env['opt']['fp16']
net = self.env['discriminators'][self.opt['discriminator']]
if self.ff:
return self.fast_forward(state)
else:
return self.lowmem_forward(state)
# Computes the discriminator loss one recursive step at a time, which has a lower memory overhead but is
# slower.
def lowmem_forward(self, state):
flow_gen = self.env['generators'][self.image_flow_generator]
real = state[self.opt['real']]
fake = state[self.opt['fake']]
sequence_len = real.shape[1]
lr = state[self.opt['lr_inputs']]
l_total = 0
# Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
for i in range(sequence_len - 2):
real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
with autocast(enabled=fp16):
d_fake = net(fake_sext)
d_real = net(real_sext)
self.metrics.append(("d_fake", torch.mean(d_fake)))
self.metrics.append(("d_real", torch.mean(d_real)))
if self.for_generator and self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(fake_sext, 'fake', i)
self.produce_teco_visual_debugs(real_sext, 'real', i)
if self.opt['gan_type'] in ['gan', 'pixgan']:
l_fake = self.criterion(d_fake, self.for_generator)
if not self.for_generator:
l_real = self.criterion(d_real, True)
else:
l_real = 0
l_step = l_fake + l_real
elif self.opt['gan_type'] == 'ragan':
d_fake_diff = d_fake - torch.mean(d_real)
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
l_step = (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
self.criterion(d_fake_diff, self.for_generator))
else:
raise NotImplementedError
l_step = self.compute_loss(real_sext, fake_sext)
if l_step > self.min_loss:
l_total += l_step
return l_total
# Computes the discriminator loss by dogpiling all of the sextuplets into the batch dimension and doing one massive
# forward() on the discriminators. High memory but faster.
def fast_forward(self, state):
flow_gen = self.env['generators'][self.image_flow_generator]
real = state[self.opt['real']]
fake = state[self.opt['fake']]
sequence_len = real.shape[1]
lr = state[self.opt['lr_inputs']]
# Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
combined_real_sext = create_all_discriminator_sextuplets(real, lr, self.scale, sequence_len - 2, flow_gen,
self.resampler, self.margin)
combined_fake_sext = create_all_discriminator_sextuplets(fake, lr, self.scale, sequence_len - 2, flow_gen,
self.resampler, self.margin)
l_total = self.compute_loss(combined_real_sext, combined_fake_sext)
if l_total < self.min_loss:
l_total = 0
return l_total
def compute_loss(self, real_sext, fake_sext):
fp16 = self.env['opt']['fp16']
net = self.env['discriminators'][self.opt['discriminator']]
with autocast(enabled=fp16):
d_fake = net(fake_sext)
d_real = net(real_sext)
self.metrics.append(("d_fake", torch.mean(d_fake)))
self.metrics.append(("d_real", torch.mean(d_real)))
if self.for_generator and self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(fake_sext, 'fake', 0)
self.produce_teco_visual_debugs(real_sext, 'real', 0)
if self.opt['gan_type'] in ['gan', 'pixgan']:
l_fake = self.criterion(d_fake, self.for_generator)
if not self.for_generator:
l_real = self.criterion(d_real, True)
else:
l_real = 0
l_step = l_fake + l_real
elif self.opt['gan_type'] == 'ragan':
d_fake_diff = d_fake - torch.mean(d_real)
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
l_step = (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
self.criterion(d_fake_diff, self.for_generator))
else:
raise NotImplementedError
return l_step
def produce_teco_visual_debugs(self, sext, lbl, it):
if self.env['rank'] > 0:
return
@ -291,4 +357,3 @@ class PingPongLoss(ConfigurableLoss):
img = imglist[:, i]
torchvision.utils.save_image(img.float(), osp.join(base_path, "%s.png" % (i, )))