Alter teco_losses to feed a recurrent input in as a separate input

James Betker 2020-10-10 20:21:09 -06:00
parent 0d30d18a3d
commit 3a5b23b9f7


@@ -62,6 +62,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
         self.flow = opt['flow_network']
         self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
         self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
+        self.recurrent_index = opt['recurrent_index']
         self.scale = opt['scale']
         self.resample = Resample2d()
         self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in']  # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'.
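
For orientation, here is a minimal sketch of what this injector's options might look like after the change. The keys mirror the opt[...] lookups in the constructor above; the concrete values and input names are hypothetical, not taken from any real configuration.

    # Hypothetical options dict for RecurrentImageGeneratorSequenceInjector.
    # Every key corresponds to an opt[...] lookup in __init__; values are illustrative only.
    injector_opt = {
        'flow_network': 'flownet',            # assumed name of the flow network in the environment
        'input_lq_index': 0,                  # generator input slot holding the low-quality frame
        'output_hq_index': 0,                 # generator output slot holding the high-quality frame
        'recurrent_index': 1,                 # new: generator input slot that receives the recurrent frame
        'scale': 4,                           # LQ -> HQ upsampling factor
        'in': ['lq', 'recurrent'],            # per-iteration inputs (names assumed)
        'first_inputs': ['lq', 'recurrent'],  # inputs used only on the first teco iteration
    }

With 'recurrent_index' set, the previous frame's output is written into its own input slot rather than concatenated onto the LQ tensor, which is what the hunks below implement.
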
@@ -83,17 +84,17 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
         for i in range(f):
             if first_step:
                 input = extract_inputs_index(first_inputs, i)
-                recurrent_input = input[self.input_lq_index]
+                recurrent_input = torch.zeros_like(input[self.recurrent_index])
                 first_step = False
             else:
                 input = extract_inputs_index(inputs, i)
                 with torch.no_grad():
                     reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic')
                     flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
-                    flowfield = flow(flow_input)
+                    flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
                     # Resample does not work in FP16.
                     recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
-                input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
+                input[self.recurrent_index] = recurrent_input
             if self.env['step'] % 50 == 0:
                 self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
                 debug_index += 1
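
Taken out of context, the warping step this hunk modifies could be sketched as below. flow_net and resample are placeholders for the configured flow network and Resample2d (a flow-based warping module); on the first frame the recurrent input is simply zeros, per the zeros_like change above. This is an illustrative reading of the diff, not the project's actual code.

    import torch
    import torch.nn.functional as F

    def warp_recurrent(recurrent_hq, lq_frame, flow_net, resample, scale):
        # Bring the previous HQ output down to LQ resolution so the flow network
        # can compare it against the current LQ frame.
        reduced = F.interpolate(recurrent_hq, scale_factor=1 / scale, mode='bicubic')
        flow_in = torch.stack([lq_frame, reduced], dim=2)
        # New in this commit: the estimated flow field is upsampled by `scale`
        # before being used for resampling.
        flowfield = F.interpolate(flow_net(flow_in), scale_factor=scale, mode='bicubic')
        # Resample does not work in FP16, hence the .float() casts in the diff.
        return resample(reduced.float(), flowfield.float())
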
@@ -111,9 +112,10 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
             with torch.no_grad():
                 reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
                 flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
-                flowfield = flow(flow_input)
+                flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
                 recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
-            input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
+            input[self.recurrent_index] = recurrent_input
             if self.env['step'] % 50 == 0:
                 self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
                 debug_index += 1
@@ -126,7 +128,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
         return {self.output: results}

     def produce_teco_visual_debugs(self, gen_input, it):
-        if dist.get_rank() > 0:
+        if self.env['rank'] > 0:
             return
         base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
         os.makedirs(base_path, exist_ok=True)
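
This hunk (and the two matching ones further down) swaps the rank check from torch.distributed to the rank stored in the trainer's env dict, the same value used elsewhere in this class. The pattern, in isolation (function name and argument are hypothetical):

    def maybe_write_debug_images(env):
        # Only rank 0 writes visual debug output. Reading env['rank'] avoids calling
        # torch.distributed.get_rank(), which raises when no process group has been
        # initialized (e.g. single-GPU runs).
        if env['rank'] > 0:
            return
        # ... build base_path from env['base_path'] and save images, as in the method above.
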
@@ -174,6 +176,7 @@ class TecoGanLoss(ConfigurableLoss):
         self.image_flow_generator = opt['image_flow_generator']
         self.resampler = Resample2d()
         self.for_generator = opt['for_generator']
+        self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
         self.margin = opt['margin']  # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.

     def forward(self, _, state):
@@ -202,19 +205,21 @@ class TecoGanLoss(ConfigurableLoss):
                     l_real = self.criterion(d_real, True)
                 else:
                     l_real = 0
-                l_total += l_fake + l_real
+                l_step = l_fake + l_real
             elif self.opt['gan_type'] == 'ragan':
                 d_fake_diff = d_fake - torch.mean(d_real)
                 self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
-                l_total += (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
+                l_step = (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
                             self.criterion(d_fake_diff, self.for_generator))
             else:
                 raise NotImplementedError
+            if l_step > self.min_loss:
+                l_total += l_step

         return l_total

     def produce_teco_visual_debugs(self, sext, lbl, it):
-        if dist.get_rank() > 0:
+        if self.env['rank'] > 0:
             return
         base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
         os.makedirs(base_path, exist_ok=True)
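
A rough sketch of the new gating in TecoGanLoss: each sequence step's GAN loss is computed into l_step, and only steps whose loss exceeds min_loss (default 0, per __init__ above) are accumulated. The helper below is a simplified stand-in for the real forward loop, not the project's code.

    def accumulate_gan_loss(per_step_losses, min_loss=0):
        # per_step_losses: iterable of scalar losses, one per sequence step (hypothetical).
        l_total = 0
        for l_step in per_step_losses:
            # Mirrors the new check: steps at or below min_loss do not contribute.
            if l_step > min_loss:
                l_total += l_step
        return l_total
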
@@ -244,7 +249,7 @@ class PingPongLoss(ConfigurableLoss):
         return l_total

     def produce_teco_visual_debugs(self, imglist):
-        if dist.get_rank() > 0:
+        if self.env['rank'] > 0:
             return
         base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
         os.makedirs(base_path, exist_ok=True)