Tecogan losses work

This commit is contained in:
James Betker 2020-10-07 23:11:58 -06:00
parent 29bf78d791
commit c93dd623d7

View File

@ -26,10 +26,10 @@ def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_
triplet = input_list[:, index:index+3]
# Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
with torch.no_grad():
first_flow = flow_gen(torch.stack([lr_imgs[:,1], lr_imgs[:,0]], dim=2))
first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic')
last_flow = flow_gen(torch.stack([lr_imgs[:,1], lr_imgs[:,2]], dim=2))
last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic')
first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2).float())
#first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic')
last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2).float())
#last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic')
flow_triplet = [resampler(triplet[:,0].float(), first_flow.float()),
triplet[:,1],
resampler(triplet[:,2].float(), last_flow.float())]
@ -61,25 +61,29 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
self.scale = opt['scale']
self.resample = Resample2d()
self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'.
self.do_backwards = opt['do_backwards'] if 'do_backwards' in opt.keys() else True
def forward(self, state):
gen = self.env['generators'][self.opt['generator']]
flow = self.env['generators'][self.flow]
results = []
first_inputs = extract_params_from_state(self.first_inputs, state)
inputs = extract_params_from_state(self.input, state)
if not isinstance(inputs, list):
inputs = [inputs]
recurrent_input = torch.zeros_like(inputs[self.input_lq_index][:,0])
# Go forward in the sequence first.
first_step = True
b, f, c, h, w = inputs[self.input_lq_index].shape
debug_index = 0
for i in range(f):
input = extract_inputs_index(inputs, i)
if first_step:
input = extract_inputs_index(first_inputs, i)
recurrent_input = input[self.input_lq_index]
first_step = False
else:
input = extract_inputs_index(inputs, i)
with torch.no_grad():
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic')
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
@ -87,7 +91,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
# Resample does not work in FP16.
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
if self.env['step'] % 20 == 0:
if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
debug_index += 1
gen_out = gen(*input)
@ -95,21 +99,22 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
results.append(recurrent_input)
# Now go backwards, skipping the last element (it's already stored in recurrent_input)
it = reversed(range(f - 1))
for i in it:
input = extract_inputs_index(inputs, i)
with torch.no_grad():
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
flowfield = flow(flow_input)
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
if self.env['step'] % 20 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
debug_index += 1
gen_out = gen(*input)
recurrent_input = gen_out[self.output_hq_index]
results.append(recurrent_input)
if self.do_backwards:
it = reversed(range(f - 1))
for i in it:
input = extract_inputs_index(inputs, i)
with torch.no_grad():
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
flowfield = flow(flow_input)
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
debug_index += 1
gen_out = gen(*input)
recurrent_input = gen_out[self.output_hq_index]
results.append(recurrent_input)
return {self.output: results}
@ -122,7 +127,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
# This is the temporal discriminator loss from TecoGAN.
#
# It has a strict contact for 'real' and 'fake' inputs:
# It has a strict contract for 'real' and 'fake' inputs:
# 'real' - Must be a list of arbitrary images (len>3) drawn from the dataset
# 'fake' - The output of the RecurrentImageGeneratorSequenceInjector for the same set of images.
#
@ -161,7 +166,7 @@ class TecoGanLoss(ConfigurableLoss):
self.metrics.append(("d_fake", torch.mean(d_fake)))
self.metrics.append(("d_real", torch.mean(d_real)))
if self.for_generator and self.env['step'] % 20 == 0:
if self.for_generator and self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(fake_sext, 'fake', i)
self.produce_teco_visual_debugs(real_sext, 'real', i)
@ -205,7 +210,7 @@ class PingPongLoss(ConfigurableLoss):
late = fake[-i]
l_total += self.criterion(early, late)
if self.env['step'] % 20 == 0:
if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(fake)
return l_total