Alter teco_losses to feed the recurrent input in as a separate input
parent 0d30d18a3d
commit 3a5b23b9f7
@@ -62,6 +62,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
         self.flow = opt['flow_network']
         self.input_lq_index = opt['input_lq_index'] if 'input_lq_index' in opt.keys() else 0
         self.output_hq_index = opt['output_hq_index'] if 'output_hq_index' in opt.keys() else 0
+        self.recurrent_index = opt['recurrent_index']
         self.scale = opt['scale']
         self.resample = Resample2d()
         self.first_inputs = opt['first_inputs'] if 'first_inputs' in opt.keys() else opt['in']  # Use this to specify inputs that will be used in the first teco iteration, the rest will use 'in'.
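Note: with this change the injector's options must carry a `recurrent_index` telling it which slot of the generator's input list holds the recurrent state. A minimal sketch of the relevant option keys follows; only keys read by the code above are shown, and all values are purely illustrative:

    # Illustrative injector options; values are made up, keys match what __init__ reads above.
    injector_opt = {
        'flow_network': 'flownet',      # name of the flow-estimation generator
        'input_lq_index': 0,            # slot holding the LQ frame (defaults to 0)
        'output_hq_index': 0,           # slot of the generator output to recurse on
        'recurrent_index': 1,           # NEW: slot that receives the recurrent tensor
        'scale': 4,                     # upsampling factor between LQ and HQ
        'first_inputs': ['lq_first'],   # inputs used only on the first teco iteration
        'in': ['lq'],
    }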
@@ -83,17 +84,17 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
         for i in range(f):
             if first_step:
                 input = extract_inputs_index(first_inputs, i)
-                recurrent_input = input[self.input_lq_index]
+                recurrent_input = torch.zeros_like(input[self.recurrent_index])
                 first_step = False
             else:
                 input = extract_inputs_index(inputs, i)
                 with torch.no_grad():
                     reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1/self.scale, mode='bicubic')
                     flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
-                    flowfield = flow(flow_input)
+                    flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
                     # Resample does not work in FP16.
                     recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
-            input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
+            input[self.recurrent_index] = recurrent_input
             if self.env['step'] % 50 == 0:
                 self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
                 debug_index += 1
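The substantive change in this hunk: the warped previous output is no longer concatenated onto the LQ frame's channels but written into its own input slot, and on the first frame that slot is seeded with zeros. A stripped-down sketch of the per-frame handling, with assumed shapes and a hypothetical `gen` call shown only as a comment:

    import torch

    recurrent_index, input_lq_index = 1, 0
    lq = torch.randn(1, 3, 32, 32)        # current LQ frame
    template = torch.zeros(1, 3, 32, 32)  # whatever the dataset supplies in the recurrent slot
    input = [lq, template]

    # First frame: seed the recurrent slot with zeros of the expected shape.
    recurrent_input = torch.zeros_like(input[recurrent_index])

    # Later frames would replace recurrent_input with the flow-warped previous
    # output; either way it now occupies its own slot instead of being
    # torch.cat()'d onto the LQ channels.
    input[recurrent_index] = recurrent_input
    # gen(*input)  # hypothetical: the generator now takes (lq, recurrent) separately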
@@ -111,9 +112,10 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
             with torch.no_grad():
                 reduced_recurrent = F.interpolate(recurrent_input, scale_factor=1 / self.scale, mode='bicubic')
                 flow_input = torch.stack([input[self.input_lq_index], reduced_recurrent], dim=2)
-                flowfield = flow(flow_input)
+                flowfield = F.interpolate(flow(flow_input), scale_factor=self.scale, mode='bicubic')
                 recurrent_input = self.resample(reduced_recurrent.float(), flowfield.float())
-            input[self.input_lq_index] = torch.cat([input[self.input_lq_index], recurrent_input], dim=1)
+            input[self.recurrent_index
+                ] = recurrent_input
             if self.env['step'] % 50 == 0:
                 self.produce_teco_visual_debugs(input[self.input_lq_index], debug_index)
                 debug_index += 1
@@ -126,7 +128,7 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
         return {self.output: results}

     def produce_teco_visual_debugs(self, gen_input, it):
-        if dist.get_rank() > 0:
+        if self.env['rank'] > 0:
             return
         base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
         os.makedirs(base_path, exist_ok=True)
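The debug helpers now read the process rank from the training environment instead of calling `torch.distributed.get_rank()`, which raises when no process group has been initialized (e.g. single-GPU runs). A small guard in that spirit; the `'rank'` key comes from this codebase's env dict, while the fallback logic is an assumption for illustration:

    import torch.distributed as dist

    def current_rank(env):
        # Prefer the rank the training loop recorded; only query
        # torch.distributed when a process group actually exists.
        if 'rank' in env:
            return env['rank']
        return dist.get_rank() if dist.is_initialized() else 0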
@@ -174,6 +176,7 @@ class TecoGanLoss(ConfigurableLoss):
         self.image_flow_generator = opt['image_flow_generator']
         self.resampler = Resample2d()
         self.for_generator = opt['for_generator']
+        self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
         self.margin = opt['margin']  # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.

     def forward(self, _, state):
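`min_loss` is read with a default of 0, so configs that do not set it keep the previous behavior of accumulating every step's loss. A sketch of how the loss options might look; only the keys that appear in the surrounding code are shown, and all values are illustrative:

    # Illustrative TecoGanLoss options; values are made up.
    teco_gan_opt = {
        'gan_type': 'ragan',
        'image_flow_generator': 'flownet',
        'for_generator': True,    # True for the generator step, False for the discriminator step
        'margin': 8,              # pixels cropped from each edge before the discriminator sees the image
        'min_loss': .05,          # NEW: per-step losses at or below this value are ignored
    }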
@@ -202,19 +205,21 @@ class TecoGanLoss(ConfigurableLoss):
                     l_real = self.criterion(d_real, True)
                 else:
                     l_real = 0
-                l_total += l_fake + l_real
+                l_step = l_fake + l_real
             elif self.opt['gan_type'] == 'ragan':
                 d_fake_diff = d_fake - torch.mean(d_real)
                 self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
-                l_total += (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
+                l_step = (self.criterion(d_real - torch.mean(d_fake), not self.for_generator) +
                             self.criterion(d_fake_diff, self.for_generator))
             else:
                 raise NotImplementedError
+            if l_step > self.min_loss:
+                l_total += l_step

         return l_total

     def produce_teco_visual_debugs(self, sext, lbl, it):
-        if dist.get_rank() > 0:
+        if self.env['rank'] > 0:
             return
         base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
         os.makedirs(base_path, exist_ok=True)
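The accumulation change reduced to its essence: each timestep's GAN loss is computed into `l_step`, and it only contributes to the total when it clears the `min_loss` floor. A toy illustration with made-up numbers:

    min_loss = .05
    per_step_losses = [.25, .01, .125, .03]   # hypothetical per-frame GAN losses

    l_total = 0
    for l_step in per_step_losses:
        if l_step > min_loss:   # steps at or below the floor are dropped
            l_total += l_step
    print(l_total)              # 0.375 -- only the first and third steps count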
@@ -244,7 +249,7 @@ class PingPongLoss(ConfigurableLoss):
         return l_total

     def produce_teco_visual_debugs(self, imglist):
-        if dist.get_rank() > 0:
+        if self.env['rank'] > 0:
             return
         base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
         os.makedirs(base_path, exist_ok=True)