More tecogan fixes
This commit is contained in:
parent
a62a5dbb5f
commit
c352c8bce4
|
@ -484,7 +484,7 @@ class StackedSwitchGenerator5Layer(nn.Module):
|
|||
prefix = "amap_%i_a%i_%%i.png"
|
||||
[save_attention_to_image_rgb(output_path, self.attentions[i], self.nf, prefix % (step, i), step,
|
||||
output_mag=False) for i in range(len(self.attentions))]
|
||||
torchvision.utils.save_image(self.lr, os.path.join(experiments_path, "attention_maps",
|
||||
torchvision.utils.save_image(self.lr[:,:3], os.path.join(experiments_path, "attention_maps",
|
||||
"amap_%i_base_image.png" % (step,)))
|
||||
|
||||
def get_debug_values(self, step, net_name):
|
||||
|
|
|
@ -22,21 +22,23 @@ def create_teco_injector(opt, env):
|
|||
return RecurrentImageGeneratorSequenceInjector(opt, env)
|
||||
return None
|
||||
|
||||
def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin=0):
    """Build the 18-channel input consumed by the TecoGAN temporal discriminator.

    Three consecutive frames are taken from `input_list` starting at `index`.
    The outer two are warped with `resampler` using flow estimated from
    `lr_imgs`, and the three original frames plus the three
    (warped, centre, warped) frames are folded into the channel dimension.

    Args:
        input_list: 5-D tensor indexed as (batch, frame, channel, h, w). The
            final `view` requires 3 channels per frame.
        lr_imgs: 5-D tensor of low-resolution frames, indexed the same way.
        scale: Factor passed to `F.interpolate` to upsample the estimated flow.
        index: Position of the first of the three frames inside `input_list`.
        flow_gen: Callable that receives two frames stacked along dim 2 and
            returns a flow field.
        resampler: Callable `(image, flow) -> warped image`.
        margin: Number of border pixels removed from every side of the result.
            `0` (the default) or `None` leaves the result uncropped.

    Returns:
        Tensor of shape (batch, 18, h - 2 * margin, w - 2 * margin).

    Raises:
        ValueError: If `margin` is negative.
    """
    if margin is None:
        margin = 0
    if margin < 0:
        raise ValueError("margin must be non-negative, got %r" % (margin,))

    triplet = input_list[:, index:index+3]
    # Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
    # NOTE(review): the flow always comes from LR frames 0..2 while `triplet`
    # starts at `index`; confirm whether lr_imgs[:, index:index+3] was intended.
    # NOTE(review): no_grad is scoped to flow estimation only, so the warped
    # frames still carry gradient - confirm against the original indentation.
    with torch.no_grad():
        first_flow = flow_gen(torch.stack([lr_imgs[:,1], lr_imgs[:,0]], dim=2))
        first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic')
        last_flow = flow_gen(torch.stack([lr_imgs[:,1], lr_imgs[:,2]], dim=2))
        last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic')
    # The resampler is fed float32 explicitly (it is called the same way
    # elsewhere because it does not work in FP16).
    flow_triplet = [resampler(triplet[:,0].float(), first_flow.float()),
                    triplet[:,1],
                    resampler(triplet[:,2].float(), last_flow.float())]
    flow_triplet = torch.stack(flow_triplet, dim=1)
    combined = torch.cat([triplet, flow_triplet], dim=1)
    b, f, c, h, w = combined.shape
    combined = combined.view(b, 3*6, h, w)  # 3*6 is essentially an assertion here.
    # Apply margin. A zero margin must bypass the slice: `0:-0` is `0:0`,
    # which would yield an empty tensor.
    if margin == 0:
        return combined
    return combined[:, :, margin:-margin, margin:-margin]
|
||||
|
||||
|
||||
def extract_inputs_index(inputs, i):
|
||||
|
@ -72,6 +74,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
# Go forward in the sequence first.
|
||||
first_step = True
|
||||
b, f, c, h, w = inputs[self.input_lq_index].shape
|
||||
debug_index = 0
|
||||
for i in range(f):
|
||||
input = extract_inputs_index(inputs, i)
|
||||
if first_step:
|
||||
|
@ -84,6 +87,9 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
# Resample does not work in FP16.
|
||||
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
||||
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
|
||||
if self.env['step'] % 20 == 0:
|
||||
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
|
||||
debug_index += 1
|
||||
gen_out = gen(*input)
|
||||
recurrent_input = gen_out[self.output_hq_index]
|
||||
results.append(recurrent_input)
|
||||
|
@ -98,12 +104,21 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
flowfield = flow(flow_input)
|
||||
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
||||
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
|
||||
if self.env['step'] % 20 == 0:
|
||||
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
|
||||
debug_index += 1
|
||||
gen_out = gen(*input)
|
||||
recurrent_input = gen_out[self.output_hq_index]
|
||||
results.append(recurrent_input)
|
||||
|
||||
return {self.output: results}
|
||||
|
||||
def produce_teco_visual_debugs(self, gen_input, it):
    """Save the generator input for one recurrence step as two debug images.

    The first three channels of `gen_input` are written to "<it>_img.png" and
    the remaining channels to "<it>_recurrent.png". Both go into a directory
    named after the current step, under
    "<base_path>/../visual_dbg/teco_geninput".

    Args:
        gen_input: Tensor whose dim 1 holds the channels that get split.
        it: Value used as the file-name prefix for this step's images.
    """
    step_dir = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
    os.makedirs(step_dir, exist_ok=True)
    # Pair each file-name template with the channel slice it stores.
    outputs = (
        ("%s_img.png", gen_input[:, :3]),
        ("%s_recurrent.png", gen_input[:, 3:]),
    )
    for template, channels in outputs:
        torchvision.utils.save_image(channels, osp.join(step_dir, template % (it,)))
|
||||
|
||||
|
||||
# This is the temporal discriminator loss from TecoGAN.
|
||||
#
|
||||
|
@ -128,6 +143,7 @@ class TecoGanLoss(ConfigurableLoss):
|
|||
self.image_flow_generator = opt['image_flow_generator']
|
||||
self.resampler = Resample2d()
|
||||
self.for_generator = opt['for_generator']
|
||||
self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
|
||||
|
||||
def forward(self, _, state):
|
||||
net = self.env['discriminators'][self.opt['discriminator']]
|
||||
|
@ -138,16 +154,18 @@ class TecoGanLoss(ConfigurableLoss):
|
|||
lr = state[self.opt['lr_inputs']]
|
||||
l_total = 0
|
||||
for i in range(sequence_len - 2):
|
||||
real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler)
|
||||
fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler)
|
||||
real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
|
||||
fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
|
||||
d_fake = net(fake_sext)
|
||||
d_real = net(real_sext)
|
||||
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
||||
self.metrics.append(("d_real", torch.mean(d_real)))
|
||||
|
||||
if self.for_generator and self.env['step'] % 100 == 0:
|
||||
if self.for_generator and self.env['step'] % 20 == 0:
|
||||
self.produce_teco_visual_debugs(fake_sext, 'fake', i)
|
||||
self.produce_teco_visual_debugs(real_sext, 'real', i)
|
||||
|
||||
if self.opt['gan_type'] in ['gan', 'pixgan']:
|
||||
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
||||
l_fake = self.criterion(d_fake, self.for_generator)
|
||||
if not self.for_generator:
|
||||
l_real = self.criterion(d_real, True)
|
||||
|
@ -155,7 +173,6 @@ class TecoGanLoss(ConfigurableLoss):
|
|||
l_real = 0
|
||||
l_total += l_fake + l_real
|
||||
elif self.opt['gan_type'] == 'ragan':
|
||||
d_real = net(real_sext)
|
||||
d_fake_diff = d_fake - torch.mean(d_real)
|
||||
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
|
||||
l_total += (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
|
||||
|
@ -166,11 +183,11 @@ class TecoGanLoss(ConfigurableLoss):
|
|||
return l_total
|
||||
|
||||
def produce_teco_visual_debugs(self, sext, lbl, it):
    """Save the six 3-channel images packed in a discriminator sextuplet.

    Files are written to
    "<base_path>/../visual_dbg/teco_sext/<step>/<lbl>/<it>_<name>.png".

    Args:
        sext: Tensor whose dim 1 holds 18 channels, read as six consecutive
            groups of three.
        lbl: Sub-directory name (the caller passes 'fake' or 'real').
        it: Value used as the file-name prefix.
    """
    base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
    os.makedirs(base_path, exist_ok=True)
    lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
    for i, name in enumerate(lbls):
        # Take all three channels of group i. The end bound must be (i+1)*3,
        # since (i+1)*3-1 would drop the third channel.
        torchvision.utils.save_image(sext[:, i*3:(i+1)*3, :, :], osp.join(base_path, "%s_%s.png" % (it, name)))
|
||||
|
||||
|
||||
# This loss doesn't have a real entry - only fakes are used.
|
||||
|
@ -188,13 +205,13 @@ class PingPongLoss(ConfigurableLoss):
|
|||
late = fake[-i]
|
||||
l_total += self.criterion(early, late)
|
||||
|
||||
if self.env['step'] % 100 == 0:
|
||||
if self.env['step'] % 20 == 0:
|
||||
self.produce_teco_visual_debugs(fake)
|
||||
|
||||
return l_total
|
||||
|
||||
def produce_teco_visual_debugs(self, imglist):
|
||||
base_path = osp.join(self.env['base_path'], "visual_dbg", "teco_pingpong", str(self.env['step']))
|
||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
assert isinstance(imglist, list)
|
||||
for i, img in enumerate(imglist):
|
||||
|
|
Loading…
Reference in New Issue
Block a user