More tecogan fixes

This commit is contained in:
James Betker 2020-10-07 12:41:17 -06:00
parent a62a5dbb5f
commit c352c8bce4
2 changed files with 34 additions and 17 deletions

View File

@ -484,7 +484,7 @@ class StackedSwitchGenerator5Layer(nn.Module):
prefix = "amap_%i_a%i_%%i.png"
[save_attention_to_image_rgb(output_path, self.attentions[i], self.nf, prefix % (step, i), step,
output_mag=False) for i in range(len(self.attentions))]
torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps",
torchvision.utils.save_image(self.lr[:,:3], os.path.join(experiments_path, "attention_maps",
"amap_%i_base_image.png" % (step,)))
def get_debug_values(self, step, net_name):

View File

@ -22,21 +22,23 @@ def create_teco_injector(opt, env):
return RecurrentImageGeneratorSequenceInjector(opt, env)
return None
def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler):
def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin):
triplet = input_list[:, index:index+3]
# Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
with torch.no_grad():
first_flow = flow_gen(torch.stack([lr_imgs[:,0], lr_imgs[:,1]], dim=2))
first_flow = flow_gen(torch.stack([lr_imgs[:,1], lr_imgs[:,0]], dim=2))
first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic')
last_flow = flow_gen(torch.stack([lr_imgs[:,2], lr_imgs[:,1]], dim=2))
last_flow = flow_gen(torch.stack([lr_imgs[:,1], lr_imgs[:,2]], dim=2))
last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic')
flow_triplet = [resampler(triplet[:,0].float(), first_flow.float()),
triplet[:,1],
resampler(triplet[:,2].float(), last_flow.float())]
flow_triplet = torch.stack(flow_triplet, dim=2)
combined = torch.cat([triplet, flow_triplet], dim=2)
flow_triplet = torch.stack(flow_triplet, dim=1)
combined = torch.cat([triplet, flow_triplet], dim=1)
b, f, c, h, w = combined.shape
return combined.view(b, 3*6, h, w) # 3*6 is essentially an assertion here.
combined = combined.view(b, 3*6, h, w) # 3*6 is essentially an assertion here.
# Apply margin
return combined[:, :, margin:-margin, margin:-margin]
def extract_inputs_index(inputs, i):
@ -72,6 +74,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
# Go forward in the sequence first.
first_step = True
b, f, c, h, w = inputs[self.input_lq_index].shape
debug_index = 0
for i in range(f):
input = extract_inputs_index(inputs, i)
if first_step:
@ -84,6 +87,9 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
# Resample does not work in FP16.
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
if self.env['step'] % 20 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
debug_index += 1
gen_out = gen(*input)
recurrent_input = gen_out[self.output_hq_index]
results.append(recurrent_input)
@ -98,12 +104,21 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
flowfield = flow(flow_input)
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
if self.env['step'] % 20 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
debug_index += 1
gen_out = gen(*input)
recurrent_input = gen_out[self.output_hq_index]
results.append(recurrent_input)
return {self.output: results}
def produce_teco_visual_debugs(self, gen_input, it):
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
os.makedirs(base_path, exist_ok=True)
torchvision.utils.save_image(gen_input[:, :3], osp.join(base_path, "%s_img.png" % (it,)))
torchvision.utils.save_image(gen_input[:, 3:], osp.join(base_path, "%s_recurrent.png" % (it,)))
# This is the temporal discriminator loss from TecoGAN.
#
@ -128,6 +143,7 @@ class TecoGanLoss(ConfigurableLoss):
self.image_flow_generator = opt['image_flow_generator']
self.resampler = Resample2d()
self.for_generator = opt['for_generator']
self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
def forward(self, _, state):
net = self.env['discriminators'][self.opt['discriminator']]
@ -138,16 +154,18 @@ class TecoGanLoss(ConfigurableLoss):
lr = state[self.opt['lr_inputs']]
l_total = 0
for i in range(sequence_len - 2):
real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler)
fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler)
real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
d_fake = net(fake_sext)
d_real = net(real_sext)
self.metrics.append(("d_fake", torch.mean(d_fake)))
self.metrics.append(("d_real", torch.mean(d_real)))
if self.for_generator and self.env['step'] % 100 == 0:
if self.for_generator and self.env['step'] % 20 == 0:
self.produce_teco_visual_debugs(fake_sext, 'fake', i)
self.produce_teco_visual_debugs(real_sext, 'real', i)
if self.opt['gan_type'] in ['gan', 'pixgan']:
self.metrics.append(("d_fake", torch.mean(d_fake)))
l_fake = self.criterion(d_fake, self.for_generator)
if not self.for_generator:
l_real = self.criterion(d_real, True)
@ -155,7 +173,6 @@ class TecoGanLoss(ConfigurableLoss):
l_real = 0
l_total += l_fake + l_real
elif self.opt['gan_type'] == 'ragan':
d_real = net(real_sext)
d_fake_diff = d_fake - torch.mean(d_real)
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
l_total += (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
@ -166,11 +183,11 @@ class TecoGanLoss(ConfigurableLoss):
return l_total
def produce_teco_visual_debugs(self, sext, lbl, it):
base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_sext", str(self.env['step']), lbl)
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
os.makedirs(base_path, exist_ok=True)
lbls = ['first', 'second', 'third', 'first_flow', 'second_flow', 'third_flow']
lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
for i in range(6):
torchvision.utils.save_image(sext[:, i*3:(i+1)*3-1, :, :], osp.join(base_path, "%s_%s.png" % (lbls[i], it)))
torchvision.utils.save_image(sext[:, i*3:(i+1)*3, :, :], osp.join(base_path, "%s_%s.png" % (it, lbls[i])))
# This loss doesn't have a real entry - only fakes are used.
@ -188,13 +205,13 @@ class PingPongLoss(ConfigurableLoss):
late = fake[-i]
l_total += self.criterion(early, late)
if self.env['step'] % 100 == 0:
if self.env['step'] % 20 == 0:
self.produce_teco_visual_debugs(fake)
return l_total
def produce_teco_visual_debugs(self, imglist):
base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_pingpong", str(self.env['step']))
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
os.makedirs(base_path, exist_ok=True)
assert isinstance(imglist, list)
for i, img in enumerate(imglist):