Tecogan losses work
This commit is contained in:
parent
29bf78d791
commit
c93dd623d7
|
@ -26,10 +26,10 @@ def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_
|
|||
triplet = input_list[:, index:index+3]
|
||||
# Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
|
||||
with torch.no_grad():
|
||||
first_flow = flow_gen(torch.stack([lr_imgs[:,1], lr_imgs[:,0]], dim=2))
|
||||
first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic')
|
||||
last_flow = flow_gen(torch.stack([lr_imgs[:,1], lr_imgs[:,2]], dim=2))
|
||||
last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic')
|
||||
first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2).float())
|
||||
#first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic')
|
||||
last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2).float())
|
||||
#last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic')
|
||||
flow_triplet = [resampler(triplet[:,0].float(), first_flow.float()),
|
||||
triplet[:,1],
|
||||
resampler(triplet[:,2].float(), last_flow.float())]
|
||||
|
@ -61,25 +61,29 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
|
||||
self.scale = opt['scale']
|
||||
self.resample = Resample2d()
|
||||
self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'.
|
||||
self.do_backwards = opt['do_backwards'] if 'do_backwards' in opt.keys() else True
|
||||
|
||||
def forward(self, state):
|
||||
gen = self.env['generators'][self.opt['generator']]
|
||||
flow = self.env['generators'][self.flow]
|
||||
results = []
|
||||
first_inputs = extract_params_from_state(self.first_inputs, state)
|
||||
inputs = extract_params_from_state(self.input, state)
|
||||
if not isinstance(inputs, list):
|
||||
inputs = [inputs]
|
||||
recurrent_input = torch.zeros_like(inputs[self.input_lq_index][:,0])
|
||||
|
||||
# Go forward in the sequence first.
|
||||
first_step = True
|
||||
b, f, c, h, w = inputs[self.input_lq_index].shape
|
||||
debug_index = 0
|
||||
for i in range(f):
|
||||
input = extract_inputs_index(inputs, i)
|
||||
if first_step:
|
||||
input = extract_inputs_index(first_inputs, i)
|
||||
recurrent_input = input[self.input_lq_index]
|
||||
first_step = False
|
||||
else:
|
||||
input = extract_inputs_index(inputs, i)
|
||||
with torch.no_grad():
|
||||
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic')
|
||||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
||||
|
@ -87,7 +91,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
# Resample does not work in FP16.
|
||||
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
||||
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
|
||||
if self.env['step'] % 20 == 0:
|
||||
if self.env['step'] % 50 == 0:
|
||||
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
|
||||
debug_index += 1
|
||||
gen_out = gen(*input)
|
||||
|
@ -95,21 +99,22 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
results.append(recurrent_input)
|
||||
|
||||
# Now go backwards, skipping the last element (it's already stored in recurrent_input)
|
||||
it = reversed(range(f - 1))
|
||||
for i in it:
|
||||
input = extract_inputs_index(inputs, i)
|
||||
with torch.no_grad():
|
||||
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
|
||||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
||||
flowfield = flow(flow_input)
|
||||
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
||||
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
|
||||
if self.env['step'] % 20 == 0:
|
||||
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
|
||||
debug_index += 1
|
||||
gen_out = gen(*input)
|
||||
recurrent_input = gen_out[self.output_hq_index]
|
||||
results.append(recurrent_input)
|
||||
if self.do_backwards:
|
||||
it = reversed(range(f - 1))
|
||||
for i in it:
|
||||
input = extract_inputs_index(inputs, i)
|
||||
with torch.no_grad():
|
||||
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
|
||||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
||||
flowfield = flow(flow_input)
|
||||
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
||||
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
|
||||
if self.env['step'] % 50 == 0:
|
||||
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
|
||||
debug_index += 1
|
||||
gen_out = gen(*input)
|
||||
recurrent_input = gen_out[self.output_hq_index]
|
||||
results.append(recurrent_input)
|
||||
|
||||
return {self.output: results}
|
||||
|
||||
|
@ -122,7 +127,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
|
||||
# This is the temporal discriminator loss from TecoGAN.
|
||||
#
|
||||
# It has a strict contact for 'real' and 'fake' inputs:
|
||||
# It has a strict contract for 'real' and 'fake' inputs:
|
||||
# 'real' - Must be a list of arbitrary images (len>3) drawn from the dataset
|
||||
# 'fake' - The output of the RecurrentImageGeneratorSequenceInjector for the same set of images.
|
||||
#
|
||||
|
@ -161,7 +166,7 @@ class TecoGanLoss(ConfigurableLoss):
|
|||
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
||||
self.metrics.append(("d_real", torch.mean(d_real)))
|
||||
|
||||
if self.for_generator and self.env['step'] % 20 == 0:
|
||||
if self.for_generator and self.env['step'] % 50 == 0:
|
||||
self.produce_teco_visual_debugs(fake_sext, 'fake', i)
|
||||
self.produce_teco_visual_debugs(real_sext, 'real', i)
|
||||
|
||||
|
@ -205,7 +210,7 @@ class PingPongLoss(ConfigurableLoss):
|
|||
late = fake[-i]
|
||||
l_total += self.criterion(early, late)
|
||||
|
||||
if self.env['step'] % 20 == 0:
|
||||
if self.env['step'] % 50 == 0:
|
||||
self.produce_teco_visual_debugs(fake)
|
||||
|
||||
return l_total
|
||||
|
|
Loading…
Reference in New Issue
Block a user