Tecogan losses work

This commit is contained in:
James Betker 2020-10-07 23:11:58 -06:00
parent 29bf78d791
commit c93dd623d7

View File

@ -26,10 +26,10 @@ def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_
triplet = input_list[:, index:index+3] triplet = input_list[:, index:index+3]
# Flow is interpreted from the LR images so that the generator cannot learn to manipulate it. # Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
with torch.no_grad(): with torch.no_grad():
first_flow = flow_gen(torch.stack([lr_imgs[:,1], lr_imgs[:,0]], dim=2)) first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2).float())
first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic') #first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic')
last_flow = flow_gen(torch.stack([lr_imgs[:,1], lr_imgs[:,2]], dim=2)) last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2).float())
last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic') #last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic')
flow_triplet = [resampler(triplet[:,0].float(), first_flow.float()), flow_triplet = [resampler(triplet[:,0].float(), first_flow.float()),
triplet[:,1], triplet[:,1],
resampler(triplet[:,2].float(), last_flow.float())] resampler(triplet[:,2].float(), last_flow.float())]
@ -61,25 +61,29 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0 self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
self.scale = opt['scale'] self.scale = opt['scale']
self.resample = Resample2d() self.resample = Resample2d()
self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'.
self.do_backwards = opt['do_backwards'] if 'do_backwards' in opt.keys() else True
def forward(self, state): def forward(self, state):
gen = self.env['generators'][self.opt['generator']] gen = self.env['generators'][self.opt['generator']]
flow = self.env['generators'][self.flow] flow = self.env['generators'][self.flow]
results = [] results = []
first_inputs = extract_params_from_state(self.first_inputs, state)
inputs = extract_params_from_state(self.input, state) inputs = extract_params_from_state(self.input, state)
if not isinstance(inputs, list): if not isinstance(inputs, list):
inputs = [inputs] inputs = [inputs]
recurrent_input = torch.zeros_like(inputs[self.input_lq_index][:,0])
# Go forward in the sequence first. # Go forward in the sequence first.
first_step = True first_step = True
b, f, c, h, w = inputs[self.input_lq_index].shape b, f, c, h, w = inputs[self.input_lq_index].shape
debug_index = 0 debug_index = 0
for i in range(f): for i in range(f):
input = extract_inputs_index(inputs, i)
if first_step: if first_step:
input = extract_inputs_index(first_inputs, i)
recurrent_input = input[self.input_lq_index]
first_step = False first_step = False
else: else:
input = extract_inputs_index(inputs, i)
with torch.no_grad(): with torch.no_grad():
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic') reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic')
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2) flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
@ -87,7 +91,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
# Resample does not work in FP16. # Resample does not work in FP16.
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1) input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
if self.env['step'] % 20 == 0: if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
debug_index += 1 debug_index += 1
gen_out = gen(*input) gen_out = gen(*input)
@ -95,21 +99,22 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
results.append(recurrent_input) results.append(recurrent_input)
# Now go backwards, skipping the last element (it's already stored in recurrent_input) # Now go backwards, skipping the last element (it's already stored in recurrent_input)
it = reversed(range(f - 1)) if self.do_backwards:
for i in it: it = reversed(range(f - 1))
input = extract_inputs_index(inputs, i) for i in it:
with torch.no_grad(): input = extract_inputs_index(inputs, i)
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic') with torch.no_grad():
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2) reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
flowfield = flow(flow_input) flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float()) flowfield = flow(flow_input)
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1) recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
if self.env['step'] % 20 == 0: input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index) if self.env['step'] % 50 == 0:
debug_index += 1 self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
gen_out = gen(*input) debug_index += 1
recurrent_input = gen_out[self.output_hq_index] gen_out = gen(*input)
results.append(recurrent_input) recurrent_input = gen_out[self.output_hq_index]
results.append(recurrent_input)
return {self.output: results} return {self.output: results}
@ -122,7 +127,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
# This is the temporal discriminator loss from TecoGAN. # This is the temporal discriminator loss from TecoGAN.
# #
# It has a strict contact for 'real' and 'fake' inputs: # It has a strict contract for 'real' and 'fake' inputs:
# 'real' - Must be a list of arbitrary images (len>3) drawn from the dataset # 'real' - Must be a list of arbitrary images (len>3) drawn from the dataset
# 'fake' - The output of the RecurrentImageGeneratorSequenceInjector for the same set of images. # 'fake' - The output of the RecurrentImageGeneratorSequenceInjector for the same set of images.
# #
@ -161,7 +166,7 @@ class TecoGanLoss(ConfigurableLoss):
self.metrics.append(("d_fake", torch.mean(d_fake))) self.metrics.append(("d_fake", torch.mean(d_fake)))
self.metrics.append(("d_real", torch.mean(d_real))) self.metrics.append(("d_real", torch.mean(d_real)))
if self.for_generator and self.env['step'] % 20 == 0: if self.for_generator and self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(fake_sext, 'fake', i) self.produce_teco_visual_debugs(fake_sext, 'fake', i)
self.produce_teco_visual_debugs(real_sext, 'real', i) self.produce_teco_visual_debugs(real_sext, 'real', i)
@ -205,7 +210,7 @@ class PingPongLoss(ConfigurableLoss):
late = fake[-i] late = fake[-i]
l_total += self.criterion(early, late) l_total += self.criterion(early, late)
if self.env['step'] % 20 == 0: if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(fake) self.produce_teco_visual_debugs(fake)
return l_total return l_total