TecoGAN losses work
This commit is contained in:
parent
29bf78d791
commit
c93dd623d7
|
@ -26,10 +26,10 @@ def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_
|
||||||
triplet = input_list[:, index:index+3]
|
triplet = input_list[:, index:index+3]
|
||||||
# Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
|
# Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
first_flow = flow_gen(torch.stack([lr_imgs[:,1], lr_imgs[:,0]], dim=2))
|
first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2).float())
|
||||||
first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic')
|
#first_flow = F.interpolate(first_flow, scale_factor=scale, mode='bicubic')
|
||||||
last_flow = flow_gen(torch.stack([lr_imgs[:,1], lr_imgs[:,2]], dim=2))
|
last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2).float())
|
||||||
last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic')
|
#last_flow = F.interpolate(last_flow, scale_factor=scale, mode='bicubic')
|
||||||
flow_triplet = [resampler(triplet[:,0].float(), first_flow.float()),
|
flow_triplet = [resampler(triplet[:,0].float(), first_flow.float()),
|
||||||
triplet[:,1],
|
triplet[:,1],
|
||||||
resampler(triplet[:,2].float(), last_flow.float())]
|
resampler(triplet[:,2].float(), last_flow.float())]
|
||||||
|
@ -61,25 +61,29 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
|
self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
|
||||||
self.scale = opt['scale']
|
self.scale = opt['scale']
|
||||||
self.resample = Resample2d()
|
self.resample = Resample2d()
|
||||||
|
self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'.
|
||||||
|
self.do_backwards = opt['do_backwards'] if 'do_backwards' in opt.keys() else True
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
gen = self.env['generators'][self.opt['generator']]
|
gen = self.env['generators'][self.opt['generator']]
|
||||||
flow = self.env['generators'][self.flow]
|
flow = self.env['generators'][self.flow]
|
||||||
results = []
|
results = []
|
||||||
|
first_inputs = extract_params_from_state(self.first_inputs, state)
|
||||||
inputs = extract_params_from_state(self.input, state)
|
inputs = extract_params_from_state(self.input, state)
|
||||||
if not isinstance(inputs, list):
|
if not isinstance(inputs, list):
|
||||||
inputs = [inputs]
|
inputs = [inputs]
|
||||||
recurrent_input = torch.zeros_like(inputs[self.input_lq_index][:,0])
|
|
||||||
|
|
||||||
# Go forward in the sequence first.
|
# Go forward in the sequence first.
|
||||||
first_step = True
|
first_step = True
|
||||||
b, f, c, h, w = inputs[self.input_lq_index].shape
|
b, f, c, h, w = inputs[self.input_lq_index].shape
|
||||||
debug_index = 0
|
debug_index = 0
|
||||||
for i in range(f):
|
for i in range(f):
|
||||||
input = extract_inputs_index(inputs, i)
|
|
||||||
if first_step:
|
if first_step:
|
||||||
|
input = extract_inputs_index(first_inputs, i)
|
||||||
|
recurrent_input = input[self.input_lq_index]
|
||||||
first_step = False
|
first_step = False
|
||||||
else:
|
else:
|
||||||
|
input = extract_inputs_index(inputs, i)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic')
|
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic')
|
||||||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
||||||
|
@ -87,7 +91,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
# Resample does not work in FP16.
|
# Resample does not work in FP16.
|
||||||
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
||||||
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
|
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
|
||||||
if self.env['step'] % 20 == 0:
|
if self.env['step'] % 50 == 0:
|
||||||
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
|
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
|
||||||
debug_index += 1
|
debug_index += 1
|
||||||
gen_out = gen(*input)
|
gen_out = gen(*input)
|
||||||
|
@ -95,21 +99,22 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
results.append(recurrent_input)
|
results.append(recurrent_input)
|
||||||
|
|
||||||
# Now go backwards, skipping the last element (it's already stored in recurrent_input)
|
# Now go backwards, skipping the last element (it's already stored in recurrent_input)
|
||||||
it = reversed(range(f - 1))
|
if self.do_backwards:
|
||||||
for i in it:
|
it = reversed(range(f - 1))
|
||||||
input = extract_inputs_index(inputs, i)
|
for i in it:
|
||||||
with torch.no_grad():
|
input = extract_inputs_index(inputs, i)
|
||||||
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
|
with torch.no_grad():
|
||||||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
|
||||||
flowfield = flow(flow_input)
|
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
||||||
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
flowfield = flow(flow_input)
|
||||||
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
|
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
||||||
if self.env['step'] % 20 == 0:
|
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
|
||||||
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
|
if self.env['step'] % 50 == 0:
|
||||||
debug_index += 1
|
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
|
||||||
gen_out = gen(*input)
|
debug_index += 1
|
||||||
recurrent_input = gen_out[self.output_hq_index]
|
gen_out = gen(*input)
|
||||||
results.append(recurrent_input)
|
recurrent_input = gen_out[self.output_hq_index]
|
||||||
|
results.append(recurrent_input)
|
||||||
|
|
||||||
return {self.output: results}
|
return {self.output: results}
|
||||||
|
|
||||||
|
@ -122,7 +127,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
|
|
||||||
# This is the temporal discriminator loss from TecoGAN.
|
# This is the temporal discriminator loss from TecoGAN.
|
||||||
#
|
#
|
||||||
# It has a strict contact for 'real' and 'fake' inputs:
|
# It has a strict contract for 'real' and 'fake' inputs:
|
||||||
# 'real' - Must be a list of arbitrary images (len>3) drawn from the dataset
|
# 'real' - Must be a list of arbitrary images (len>3) drawn from the dataset
|
||||||
# 'fake' - The output of the RecurrentImageGeneratorSequenceInjector for the same set of images.
|
# 'fake' - The output of the RecurrentImageGeneratorSequenceInjector for the same set of images.
|
||||||
#
|
#
|
||||||
|
@ -161,7 +166,7 @@ class TecoGanLoss(ConfigurableLoss):
|
||||||
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
self.metrics.append(("d_fake", torch.mean(d_fake)))
|
||||||
self.metrics.append(("d_real", torch.mean(d_real)))
|
self.metrics.append(("d_real", torch.mean(d_real)))
|
||||||
|
|
||||||
if self.for_generator and self.env['step'] % 20 == 0:
|
if self.for_generator and self.env['step'] % 50 == 0:
|
||||||
self.produce_teco_visual_debugs(fake_sext, 'fake', i)
|
self.produce_teco_visual_debugs(fake_sext, 'fake', i)
|
||||||
self.produce_teco_visual_debugs(real_sext, 'real', i)
|
self.produce_teco_visual_debugs(real_sext, 'real', i)
|
||||||
|
|
||||||
|
@ -205,7 +210,7 @@ class PingPongLoss(ConfigurableLoss):
|
||||||
late = fake[-i]
|
late = fake[-i]
|
||||||
l_total += self.criterion(early, late)
|
l_total += self.criterion(early, late)
|
||||||
|
|
||||||
if self.env['step'] % 20 == 0:
|
if self.env['step'] % 50 == 0:
|
||||||
self.produce_teco_visual_debugs(fake)
|
self.produce_teco_visual_debugs(fake)
|
||||||
|
|
||||||
return l_total
|
return l_total
|
||||||
|
|
Loading…
Reference in New Issue
Block a user