Alter teco_losses to feed a recurrent input in as separate
This commit is contained in:
parent
0d30d18a3d
commit
3a5b23b9f7
|
@ -62,6 +62,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
self.flow = opt['flow_network']
|
||||
self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
|
||||
self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
|
||||
self.recurrent_index = opt['recurrent_index']
|
||||
self.scale = opt['scale']
|
||||
self.resample = Resample2d()
|
||||
self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in'] # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'.
|
||||
|
@ -83,17 +84,17 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
for i in range(f):
|
||||
if first_step:
|
||||
input = extract_inputs_index(first_inputs, i)
|
||||
recurrent_input = input[self.input_lq_index]
|
||||
recurrent_input = torch.zeros_like(input[self.recurrent_index])
|
||||
first_step = False
|
||||
else:
|
||||
input = extract_inputs_index(inputs, i)
|
||||
with torch.no_grad():
|
||||
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic')
|
||||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
||||
flowfield = flow(flow_input)
|
||||
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
|
||||
# Resample does not work in FP16.
|
||||
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
||||
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
|
||||
input[self.recurrent_index] = recurrent_input
|
||||
if self.env['step'] % 50 == 0:
|
||||
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
|
||||
debug_index += 1
|
||||
|
@ -111,9 +112,10 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
with torch.no_grad():
|
||||
reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
|
||||
flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
|
||||
flowfield = flow(flow_input)
|
||||
flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
|
||||
recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
|
||||
input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
|
||||
input[self.recurrent_index
|
||||
] = recurrent_input
|
||||
if self.env['step'] % 50 == 0:
|
||||
self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
|
||||
debug_index += 1
|
||||
|
@ -126,7 +128,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
|||
return {self.output: results}
|
||||
|
||||
def produce_teco_visual_debugs(self, gen_input, it):
|
||||
if dist.get_rank() > 0:
|
||||
if self.env['rank'] > 0:
|
||||
return
|
||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
|
@ -174,6 +176,7 @@ class TecoGanLoss(ConfigurableLoss):
|
|||
self.image_flow_generator = opt['image_flow_generator']
|
||||
self.resampler = Resample2d()
|
||||
self.for_generator = opt['for_generator']
|
||||
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
|
||||
self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
|
||||
|
||||
def forward(self, _, state):
|
||||
|
@ -202,19 +205,21 @@ class TecoGanLoss(ConfigurableLoss):
|
|||
l_real = self.criterion(d_real, True)
|
||||
else:
|
||||
l_real = 0
|
||||
l_total += l_fake + l_real
|
||||
l_step = l_fake + l_real
|
||||
elif self.opt['gan_type'] == 'ragan':
|
||||
d_fake_diff = d_fake - torch.mean(d_real)
|
||||
self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
|
||||
l_total += (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
|
||||
l_step = (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
|
||||
self.criterion(d_fake_diff, self.for_generator))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if l_step > self.min_loss:
|
||||
l_total += l_step
|
||||
|
||||
return l_total
|
||||
|
||||
def produce_teco_visual_debugs(self, sext, lbl, it):
|
||||
if dist.get_rank() > 0:
|
||||
if self.env['rank'] > 0:
|
||||
return
|
||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
|
@ -244,7 +249,7 @@ class PingPongLoss(ConfigurableLoss):
|
|||
return l_total
|
||||
|
||||
def produce_teco_visual_debugs(self, imglist):
|
||||
if dist.get_rank() > 0:
|
||||
if self.env['rank'] > 0:
|
||||
return
|
||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
|
|
Loading…
Reference in New Issue
Block a user