Alter teco_losses to feed a recurrent input in as separate

This commit is contained in:
James Betker 2020-10-10 20:21:09 -06:00
parent 0d30d18a3d
commit 3a5b23b9f7

View File

@ -62,6 +62,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
self.flow = opt['flow_network']
self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
self.recurrent_index = opt['recurrent_index']
self.scale = opt['scale']
self.resample = Resample2d()
self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'.
@ -83,17 +84,17 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
for i in range(f):
if first_step:
input = extract_inputs_index(first_inputs, i)
recurrent_input = input[self.input_lq_index]
recurrent_input = torch.zeros_like(input[self.recurrent_index])
first_step = False
else:
input = extract_inputs_index(inputs, i)
with torch.no_grad():
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic')
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
flowfield = flow(flow_input)
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
# Resample does not work in FP16.
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
input[self.recurrent_index] = recurrent_input
if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
debug_index += 1
@ -111,9 +112,10 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
with torch.no_grad():
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
flowfield = flow(flow_input)
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
input[self.recurrent_index
] = recurrent_input
if self.env['step'] % 50 == 0:
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
debug_index += 1
@ -126,7 +128,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
return {self.output: results}
def produce_teco_visual_debugs(self, gen_input, it):
if dist.get_rank() > 0:
if self.env['rank'] > 0:
return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
os.makedirs(base_path, exist_ok=True)
@ -174,6 +176,7 @@ class TecoGanLoss(ConfigurableLoss):
self.image_flow_generator = opt['image_flow_generator']
self.resampler = Resample2d()
self.for_generator = opt['for_generator']
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
def forward(self, _, state):
@ -202,19 +205,21 @@ class TecoGanLoss(ConfigurableLoss):
l_real = self.criterion(d_real, True)
else:
l_real = 0
l_total += l_fake + l_real
l_step = l_fake + l_real
elif self.opt['gan_type'] == 'ragan':
d_fake_diff = d_fake - torch.mean(d_real)
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
l_total += (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
l_step = (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
self.criterion(d_fake_diff, self.for_generator))
else:
raise NotImplementedError
if l_step > self.min_loss:
l_total += l_step
return l_total
def produce_teco_visual_debugs(self, sext, lbl, it):
if dist.get_rank() > 0:
if self.env['rank'] > 0:
return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
os.makedirs(base_path, exist_ok=True)
@ -244,7 +249,7 @@ class PingPongLoss(ConfigurableLoss):
return l_total
def produce_teco_visual_debugs(self, imglist):
if dist.get_rank() > 0:
if self.env['rank'] > 0:
return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
os.makedirs(base_path, exist_ok=True)