Upgrade tecogan_losses for speed

parent ac3da0c5a6
commit 4dc16d5889
@@ -31,8 +31,6 @@ def create_injector(opt_inject, env):
         return GreyInjector(opt_inject, env)
     elif type == 'interpolate':
         return InterpolateInjector(opt_inject, env)
-    elif type == 'imageflow':
-        return ImageFlowInjector(opt_inject, env)
     elif type == 'image_patch':
         return ImagePatchInjector(opt_inject, env)
     elif type == 'concatenate':
@@ -26,23 +26,6 @@ def create_teco_injector(opt, env):
         return FlowAdjustment(opt, env)
     return None

-def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin):
-    # Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
-    with autocast(enabled=False):
-        triplet = input_list[:, index:index+3].float()
-        first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2))
-        last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2))
-        flow_triplet = [resampler(triplet[:,0], first_flow),
-                        triplet[:,1],
-                        resampler(triplet[:,2], last_flow)]
-        flow_triplet = torch.stack(flow_triplet, dim=1)
-        combined = torch.cat([triplet, flow_triplet], dim=1)
-        b, f, c, h, w = combined.shape
-        combined = combined.view(b, 3*6, h, w) # 3*6 is essentially an assertion here.
-    # Apply margin
-    return combined[:, :, margin:-margin, margin:-margin]
-
-
 def extract_inputs_index(inputs, i):
     res = []
     for input in inputs:
@@ -152,9 +135,14 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
                 results[out_key].append(gen_out[i])
             recurrent_input = gen_out[self.output_hq_index]

+        final_results = {}
+        # Include 'hq_batched' here - because why not... Don't really need a separate injector for this.
+        b, s, c, h, w = state['hq'].shape
+        final_results['hq_batched'] = state['hq'].view(b*s, c, h, w)
         for k, v in results.items():
-            results[k] = torch.stack(v, dim=1)
-        return results
+            final_results[k] = torch.stack(v, dim=1)
+            final_results[k + "_batched"] = torch.cat(v[:s], dim=0)  # Only include the original sequence - this output is generally used to compare against HQ.
+        return final_results

     def produce_teco_visual_debugs(self, gen_input, gen_recurrent, it):
         if self.env['rank'] > 0:
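For context on the new "_batched" outputs above: they fold the sequence dimension into the batch dimension so that downstream losses can compare against 'hq' without a per-frame loop. A minimal sketch of the shape arithmetic, with illustrative sizes that are not taken from the repository:

import torch

# Illustrative sizes only: batch, sequence, channels, height, width.
b, s, c, h, w = 2, 5, 3, 16, 16

# 'hq_batched' folds the sequence dim of a (b, s, c, h, w) tensor into the batch dim.
hq = torch.randn(b, s, c, h, w)
hq_batched = hq.view(b * s, c, h, w)
print(hq_batched.shape)                    # torch.Size([10, 3, 16, 16])

# '<key>' stacks the list of per-step outputs into a new sequence dim,
# while '<key>_batched' concatenates them along the batch dim.
per_step = [torch.randn(b, c, h, w) for _ in range(s)]
stacked = torch.stack(per_step, dim=1)     # (b, s, c, h, w)
batched = torch.cat(per_step[:s], dim=0)   # (b*s, c, h, w)
print(stacked.shape, batched.shape)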
@@ -183,6 +171,47 @@ class FlowAdjustment(Injector):
         return {self.output: self.resample(state[self.flowed], flowfield)}


+def create_teco_discriminator_sextuplet(input_list, lr_imgs, scale, index, flow_gen, resampler, margin):
+    # Flow is interpreted from the LR images so that the generator cannot learn to manipulate it.
+    with autocast(enabled=False):
+        triplet = input_list[:, index:index+3].float()
+        first_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,0]], dim=2))
+        last_flow = flow_gen(torch.stack([triplet[:,1], triplet[:,2]], dim=2))
+        flow_triplet = [resampler(triplet[:,0], first_flow),
+                        triplet[:,1],
+                        resampler(triplet[:,2], last_flow)]
+        flow_triplet = torch.stack(flow_triplet, dim=1)
+        combined = torch.cat([triplet, flow_triplet], dim=1)
+        b, f, c, h, w = combined.shape
+        combined = combined.view(b, 3*6, h, w) # 3*6 is essentially an assertion here.
+    # Apply margin
+    return combined[:, :, margin:-margin, margin:-margin]
+
+
+def create_all_discriminator_sextuplets(input_list, lr_imgs, scale, total, flow_gen, resampler, margin):
+    # Combine everything and feed it into the flow network at once for better efficiency.
+    batch_sz = input_list.shape[0]
+    flux_doubles_forward = [torch.stack([input_list[:,i], input_list[:,i+1]], dim=2) for i in range(1, total+1)]
+    flux_doubles_backward = [torch.stack([input_list[:,i], input_list[:,i-1]], dim=2) for i in range(1, total+1)]
+    flows_forward = flow_gen(torch.cat(flux_doubles_forward, dim=0))
+    flows_backward = flow_gen(torch.cat(flux_doubles_backward, dim=0))
+    sexts = []
+    for i in range(total):
+        flow_forward = flows_forward[batch_sz*i:batch_sz*(i+1)]
+        flow_backward = flows_backward[batch_sz*i:batch_sz*(i+1)]
+        mid = input_list[:,i+1]
+        sext = torch.stack([input_list[:,i], mid, input_list[:,i+2],
+                            resampler(mid, flow_backward),
+                            mid,
+                            resampler(mid, flow_forward)], dim=1)
+        # Apply margin
+        b, f, c, h, w = sext.shape
+        sext = sext.view(b, 3*6, h, w) # f*c = 6*3
+        sext = sext[:, :, margin:-margin, margin:-margin]
+        sexts.append(sext)
+    return torch.cat(sexts, dim=0)
+
+
 # This is the temporal discriminator loss from TecoGAN.
 #
 # It has a strict contract for 'real' and 'fake' inputs:
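The new create_all_discriminator_sextuplets follows a standard batching trick: build every frame pair up front, concatenate them along the batch dimension, run the flow network once, and slice the result back into per-window chunks. A minimal sketch of that pattern with a placeholder flow function (the real flow generator and its output format are not shown here):

import torch

batch_sz, total, c, h, w = 4, 3, 3, 32, 32
frames = torch.randn(batch_sz, total + 2, c, h, w)   # a short image sequence


def flow_gen(pair):
    # Placeholder for the real flow network: it just collapses the stacked
    # frame-pair dimension; the actual model would predict an optical-flow field.
    return pair.mean(dim=2)


# Build every frame pair, then run one large forward pass on the concatenation...
pairs = [torch.stack([frames[:, i], frames[:, i + 1]], dim=2) for i in range(1, total + 1)]
flows = flow_gen(torch.cat(pairs, dim=0))            # shape (batch_sz * total, c, h, w)

# ...and slice the combined output back into per-window chunks, as the loop in the diff does.
per_window = [flows[batch_sz * i:batch_sz * (i + 1)] for i in range(total)]
assert all(f.shape == (batch_sz, c, h, w) for f in per_window)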
@@ -208,48 +237,85 @@ class TecoGanLoss(ConfigurableLoss):
         self.for_generator = opt['for_generator']
         self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
         self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
+        self.ff = opt['fast_forward'] if 'fast_forward' in opt.keys() else False

     def forward(self, _, state):
-        fp16 = self.env['opt']['fp16']
-        net = self.env['discriminators'][self.opt['discriminator']]
+        if self.ff:
+            return self.fast_forward(state)
+        else:
+            return self.lowmem_forward(state)
+
+
+    # Computes the discriminator loss one recursive step at a time, which has a lower memory overhead but is
+    # slower.
+    def lowmem_forward(self, state):
         flow_gen = self.env['generators'][self.image_flow_generator]
         real = state[self.opt['real']]
         fake = state[self.opt['fake']]
         sequence_len = real.shape[1]
         lr = state[self.opt['lr_inputs']]
         l_total = 0

+        # Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
         for i in range(sequence_len - 2):
             real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
             fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
-            with autocast(enabled=fp16):
-                d_fake = net(fake_sext)
-                d_real = net(real_sext)
-            self.metrics.append(("d_fake", torch.mean(d_fake)))
-            self.metrics.append(("d_real", torch.mean(d_real)))
-
-            if self.for_generator and self.env['step'] % 50 == 0:
-                self.produce_teco_visual_debugs(fake_sext, 'fake', i)
-                self.produce_teco_visual_debugs(real_sext, 'real', i)
-
-            if self.opt['gan_type'] in ['gan', 'pixgan']:
-                l_fake = self.criterion(d_fake, self.for_generator)
-                if not self.for_generator:
-                    l_real = self.criterion(d_real, True)
-                else:
-                    l_real = 0
-                l_step = l_fake + l_real
-            elif self.opt['gan_type'] == 'ragan':
-                d_fake_diff = d_fake - torch.mean(d_real)
-                self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
-                l_step = (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
-                          self.criterion(d_fake_diff, self.for_generator))
-            else:
-                raise NotImplementedError
+            l_step = self.compute_loss(real_sext, fake_sext)
             if l_step > self.min_loss:
                 l_total += l_step

         return l_total
+
+    # Computes the discriminator loss by dogpiling all of the sextuplets into the batch dimension and doing one massive
+    # forward() on the discriminators. High memory but faster.
+    def fast_forward(self, state):
+        flow_gen = self.env['generators'][self.image_flow_generator]
+        real = state[self.opt['real']]
+        fake = state[self.opt['fake']]
+        sequence_len = real.shape[1]
+        lr = state[self.opt['lr_inputs']]
+
+        # Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
+        combined_real_sext = create_all_discriminator_sextuplets(real, lr, self.scale, sequence_len - 2, flow_gen,
+                                                                 self.resampler, self.margin)
+        combined_fake_sext = create_all_discriminator_sextuplets(fake, lr, self.scale, sequence_len - 2, flow_gen,
+                                                                 self.resampler, self.margin)
+        l_total = self.compute_loss(combined_real_sext, combined_fake_sext)
+        if l_total < self.min_loss:
+            l_total = 0
+        return l_total
+
+    def compute_loss(self, real_sext, fake_sext):
+        fp16 = self.env['opt']['fp16']
+        net = self.env['discriminators'][self.opt['discriminator']]
+        with autocast(enabled=fp16):
+            d_fake = net(fake_sext)
+            d_real = net(real_sext)
+
+        self.metrics.append(("d_fake", torch.mean(d_fake)))
+        self.metrics.append(("d_real", torch.mean(d_real)))
+
+        if self.for_generator and self.env['step'] % 50 == 0:
+            self.produce_teco_visual_debugs(fake_sext, 'fake', 0)
+            self.produce_teco_visual_debugs(real_sext, 'real', 0)
+
+        if self.opt['gan_type'] in ['gan', 'pixgan']:
+            l_fake = self.criterion(d_fake, self.for_generator)
+            if not self.for_generator:
+                l_real = self.criterion(d_real, True)
+            else:
+                l_real = 0
+            l_step = l_fake + l_real
+        elif self.opt['gan_type'] == 'ragan':
+            d_fake_diff = d_fake - torch.mean(d_real)
+            self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
+            l_step = (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
+                      self.criterion(d_fake_diff, self.for_generator))
+        else:
+            raise NotImplementedError
+
+        return l_step
+
     def produce_teco_visual_debugs(self, sext, lbl, it):
         if self.env['rank'] > 0:
             return
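The fast path above is opt-in via the 'fast_forward' key, so existing configurations keep the per-step behaviour. A hedged sketch of the options TecoGanLoss reads in this diff, written as the Python dict the class receives; the values and the discriminator/tensor names are illustrative placeholders, not taken from a real config:

# Keys referenced by TecoGanLoss in this diff; values are placeholders.
tecogan_loss_opt = {
    'discriminator': 'teco_disc',   # name registered in env['discriminators']
    'real': 'hq',                   # state key holding the ground-truth sequence
    'fake': 'gen_hq',               # state key holding the generated sequence
    'lr_inputs': 'lq',              # state key holding the LR inputs used for flow estimation
    'for_generator': True,
    'min_loss': 0,
    'margin': 8,                    # border (in pixels) trimmed before the discriminator sees each sextuplet
    'gan_type': 'ragan',            # 'gan', 'pixgan' or 'ragan', per the branches in compute_loss()
    'fast_forward': True,           # new in this commit: batch every sextuplet into one discriminator forward()
}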
@@ -291,4 +357,3 @@ class PingPongLoss(ConfigurableLoss):
             img = imglist[:, i]
             torchvision.utils.save_image(img.float(), osp.join(base_path, "%s.png" % (i, )))

-