More tecogan fixes

This commit is contained in:
James Betker 2020-10-07 12:41:17 -06:00
parent a62a5dbb5f
commit c352c8bce4
2 changed files with 34 additions and 17 deletions

View File

@ -484,7 +484,7 @@ class StackedSwitchGenerator5Layer(nn.Module):
prefix = "amap_%i_a%i_%%i.png" prefix = "amap_%i_a%i_%%i.png"
[save_attention_to_image_rgb(output_path, self.attentions[i], self.nf, prefix % (step, i), step, [save_attention_to_image_rgb(output_path, self.attentions[i], self.nf, prefix % (step, i), step,
output_mag=False) for i in range(len(self.attentions))] output_mag=False) for i in range(len(self.attentions))]
torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps", torchvision.utils.save_image(self.lr[:,:3], os.path.join(experiments_path, "attention_maps",
"amap_%i_base_image.png" % (step,))) "amap_%i_base_image.png" % (step,)))
def get_debug_values(self, step, net_name): def get_debug_values(self, step, net_name):

View File

@ -22,21 +22,23 @@ def create_teco_injector(opt, env):
return RecurrentImageGeneratorSequenceInjector(opt, env) return RecurrentImageGeneratorSequenceInjector(opt, env)
return None return None
def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin=0):
    """Build the 18-channel temporal-discriminator input for one frame triplet.

    Three consecutive frames starting at `index` are taken from `input_list`.
    The first and last of them are additionally warped towards the middle
    frame, and the three plain frames plus the three warped/middle frames are
    folded into the channel dimension.

    Args:
        input_list: Frame sequence indexed as [batch, frame, channel, h, w].
            The final `view` only succeeds when there are 3 channels per frame.
        lr_imgs: Low-resolution sequence indexed the same way along dims 0 and 1
            (assumed to be frame-aligned with `input_list` - TODO confirm).
        scale: Factor used to upsample the LR flow field to the resolution of
            `input_list` frames.
        index: Index of the first frame of the triplet.
        flow_gen: Callable that receives two frames stacked along dim 2 and
            returns a flow field.
        resampler: Callable `(image, flow) -> warped image`; invoked with float
            tensors.
        margin: Number of border pixels removed from every side of the result.
            `0` (or `None`) disables cropping.

    Returns:
        Tensor of shape [batch, 18, h - 2*margin, w - 2*margin].
    """
    triplet = input_list[:, index:index+3]
    # Use the LR frames that correspond to this triplet. Always using frames
    # 0..2 would pair later triplets with flow computed from unrelated frames.
    lr_triplet = lr_imgs[:, index:index+3]

    # Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
    with torch.no_grad():
        first_flow = flow_gen(torch.stack([lr_triplet[:, 1], lr_triplet[:, 0]], dim=2))
        first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic')
        last_flow = flow_gen(torch.stack([lr_triplet[:, 1], lr_triplet[:, 2]], dim=2))
        last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic')

    # NOTE(review): the warp itself is kept outside no_grad so gradients can
    # still reach the frames being warped - confirm this matches the intent.
    flow_triplet = [resampler(triplet[:, 0].float(), first_flow.float()),
                    triplet[:, 1],
                    resampler(triplet[:, 2].float(), last_flow.float())]
    flow_triplet = torch.stack(flow_triplet, dim=1)
    combined = torch.cat([triplet, flow_triplet], dim=1)
    b, f, c, h, w = combined.shape
    combined = combined.view(b, 3*6, h, w)  # 3*6 is essentially an assertion here.

    # Apply margin. A slice of `0:-0` is empty, so a zero margin must skip the crop.
    if not margin:
        return combined
    return combined[:, :, margin:-margin, margin:-margin]
def extract_inputs_index(inputs, i): def extract_inputs_index(inputs, i):
@ -72,6 +74,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
# Go forward in the sequence first. # Go forward in the sequence first.
first_step = True first_step = True
b, f, c, h, w = inputs[self.input_lq_index].shape b, f, c, h, w = inputs[self.input_lq_index].shape
debug_index = 0
for i in range(f): for i in range(f):
input = extract_inputs_index(inputs, i) input = extract_inputs_index(inputs, i)
if first_step: if first_step:
@ -84,6 +87,9 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
# Resample does not work in FP16. # Resample does not work in FP16.
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1) input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
if self.env['step'] % 20 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
debug_index += 1
gen_out = gen(*input) gen_out = gen(*input)
recurrent_input = gen_out[self.output_hq_index] recurrent_input = gen_out[self.output_hq_index]
results.append(recurrent_input) results.append(recurrent_input)
@ -98,12 +104,21 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
flowfield = flow(flow_input) flowfield = flow(flow_input)
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1) input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
if self.env['step'] % 20 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
debug_index += 1
gen_out = gen(*input) gen_out = gen(*input)
recurrent_input = gen_out[self.output_hq_index] recurrent_input = gen_out[self.output_hq_index]
results.append(recurrent_input) results.append(recurrent_input)
return {self.output: results} return {self.output: results}
def produce_teco_visual_debugs(self, gen_input, it):
    """Dump one generator input to disk as two debug images.

    Files are written under
    `<env base_path>/../visual_dbg/teco_geninput/<env step>/`, and the
    directory is created if it is missing.

    Args:
        gen_input: Tensor sliced along dim 1. Channels [:3] are saved as the
            image and channels [3:] as the recurrent part (presumably the
            LR frame and the warped previous output - confirm against caller).
        it: Value formatted into the file names to tell iterations apart.
    """
    out_dir = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
    os.makedirs(out_dir, exist_ok=True)
    # Pair each file name pattern with the channel slice it stores.
    dumps = (
        ("%s_img.png", gen_input[:, :3]),
        ("%s_recurrent.png", gen_input[:, 3:]),
    )
    for name_fmt, channels in dumps:
        torchvision.utils.save_image(channels, osp.join(out_dir, name_fmt % (it,)))
# This is the temporal discriminator loss from TecoGAN. # This is the temporal discriminator loss from TecoGAN.
# #
@ -128,6 +143,7 @@ class TecoGanLoss(ConfigurableLoss):
self.image_flow_generator = opt['image_flow_generator'] self.image_flow_generator = opt['image_flow_generator']
self.resampler = Resample2d() self.resampler = Resample2d()
self.for_generator = opt['for_generator'] self.for_generator = opt['for_generator']
self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
def forward(self, _, state): def forward(self, _, state):
net = self.env['discriminators'][self.opt['discriminator']] net = self.env['discriminators'][self.opt['discriminator']]
@ -138,16 +154,18 @@ class TecoGanLoss(ConfigurableLoss):
lr = state[self.opt['lr_inputs']] lr = state[self.opt['lr_inputs']]
l_total = 0 l_total = 0
for i in range(sequence_len - 2): for i in range(sequence_len - 2):
real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler) real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler) fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
d_fake = net(fake_sext) d_fake = net(fake_sext)
d_real = net(real_sext)
self.metrics.append(("d_fake", torch.mean(d_fake)))
self.metrics.append(("d_real", torch.mean(d_real)))
if self.for_generator and self.env['step'] % 100 == 0: if self.for_generator and self.env['step'] % 20 == 0:
self.produce_teco_visual_debugs(fake_sext, 'fake', i) self.produce_teco_visual_debugs(fake_sext, 'fake', i)
self.produce_teco_visual_debugs(real_sext, 'real', i) self.produce_teco_visual_debugs(real_sext, 'real', i)
if self.opt['gan_type'] in ['gan', 'pixgan']: if self.opt['gan_type'] in ['gan', 'pixgan']:
self.metrics.append(("d_fake", torch.mean(d_fake)))
l_fake = self.criterion(d_fake, self.for_generator) l_fake = self.criterion(d_fake, self.for_generator)
if not self.for_generator: if not self.for_generator:
l_real = self.criterion(d_real, True) l_real = self.criterion(d_real, True)
@ -155,7 +173,6 @@ class TecoGanLoss(ConfigurableLoss):
l_real = 0 l_real = 0
l_total += l_fake + l_real l_total += l_fake + l_real
elif self.opt['gan_type'] == 'ragan': elif self.opt['gan_type'] == 'ragan':
d_real = net(real_sext)
d_fake_diff = d_fake - torch.mean(d_real) d_fake_diff = d_fake - torch.mean(d_real)
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff))) self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
l_total += (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) + l_total += (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
@ -166,11 +183,11 @@ class TecoGanLoss(ConfigurableLoss):
return l_total return l_total
def produce_teco_visual_debugs(self, sext, lbl, it):
    """Save the six 3-channel slices of a discriminator sextuplet as images.

    Files are written under
    `<env base_path>/../visual_dbg/teco_sext/<env step>/<lbl>/` and are named
    `<it>_<slot>.png`, where slot is one of img_a..img_c / flow_a..flow_c.

    Args:
        sext: Tensor whose dim 1 holds at least 18 channels, read in
            consecutive groups of three.
        lbl: Sub-directory name for this dump (the caller passes 'fake' or 'real').
        it: Value formatted into the file names to tell iterations apart.
    """
    out_dir = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
    os.makedirs(out_dir, exist_ok=True)
    slot_names = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
    for slot, name in enumerate(slot_names):
        first_channel = slot * 3
        chunk = sext[:, first_channel:first_channel + 3, :, :]
        torchvision.utils.save_image(chunk, osp.join(out_dir, "%s_%s.png" % (it, name)))
# This loss doesn't have a real entry - only fakes are used. # This loss doesn't have a real entry - only fakes are used.
@ -188,13 +205,13 @@ class PingPongLoss(ConfigurableLoss):
late = fake[-i] late = fake[-i]
l_total += self.criterion(early, late) l_total += self.criterion(early, late)
if self.env['step'] % 100 == 0: if self.env['step'] % 20 == 0:
self.produce_teco_visual_debugs(fake) self.produce_teco_visual_debugs(fake)
return l_total return l_total
def produce_teco_visual_debugs(self, imglist): def produce_teco_visual_debugs(self, imglist):
base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_pingpong", str(self.env['step'])) base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
os.makedirs(base_path, exist_ok=True) os.makedirs(base_path, exist_ok=True)
assert isinstance(imglist, list) assert isinstance(imglist, list)
for i, img in enumerate(imglist): for i, img in enumerate(imglist):