Fixes to teco losses and translational losses

This commit is contained in:
James Betker 2020-11-19 11:35:05 -07:00
parent b2a05465fc
commit d7877d0a36
4 changed files with 22 additions and 8 deletions

View File

@@ -388,6 +388,8 @@ class TranslationInvarianceLoss(ConfigurableLoss):
trans_output = net(*input)
else:
trans_output = net(*input)
if not isinstance(trans_output, list) and not isinstance(trans_output, tuple):
trans_output = [trans_output]
if self.gen_output_to_use is not None:
fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]

View File

@@ -1,5 +1,6 @@
from torch.cuda.amp import autocast
from models.archs.stylegan.stylegan2 import gradient_penalty
from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
from models.steps.injectors import Injector
@@ -243,6 +244,7 @@ class TecoGanLoss(ConfigurableLoss):
self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
self.ff = opt['fast_forward'] if 'fast_forward' in opt.keys() else False
self.noise = opt['noise'] if 'noise' in opt.keys() else 0
self.gradient_penalty = opt['gradient_penalty'] if 'gradient_penalty' in opt.keys() else False
def forward(self, _, state):
if self.ff:
@@ -264,10 +266,15 @@ class TecoGanLoss(ConfigurableLoss):
# Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
for i in range(sequence_len - 2):
real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
if self.gradient_penalty:
real_sext.requires_grad_()
fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
l_step = self.compute_loss(real_sext, fake_sext)
l_step, d_real = self.compute_loss(real_sext, fake_sext)
if l_step > self.min_loss:
l_total += l_step
l_total = l_total + l_step
elif self.gradient_penalty:
gp = gradient_penalty(real_sext, d_real)
l_total = l_total + gp
return l_total
@@ -283,19 +290,24 @@ class TecoGanLoss(ConfigurableLoss):
# Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
combined_real_sext = create_all_discriminator_sextuplets(real, lr, self.scale, sequence_len - 2, flow_gen,
self.resampler, self.margin)
if self.gradient_penalty:
combined_real_sext.requires_grad_()
combined_fake_sext = create_all_discriminator_sextuplets(fake, lr, self.scale, sequence_len - 2, flow_gen,
self.resampler, self.margin)
l_total = self.compute_loss(combined_real_sext, combined_fake_sext)
l_total, d_real = self.compute_loss(combined_real_sext, combined_fake_sext)
if l_total < self.min_loss:
l_total = 0
elif self.gradient_penalty:
gp = gradient_penalty(combined_real_sext, d_real)
l_total = l_total + gp
return l_total
def compute_loss(self, real_sext, fake_sext):
fp16 = self.env['opt']['fp16']
net = self.env['discriminators'][self.opt['discriminator']]
if self.noise != 0:
real_sext += torch.randn_like(real_sext) * self.noise
fake_sext += torch.randn_like(fake_sext) * self.noise
real_sext = real_sext + torch.rand_like(real_sext) * self.noise
fake_sext = fake_sext + torch.rand_like(fake_sext) * self.noise
with autocast(enabled=fp16):
d_fake = net(fake_sext)
d_real = net(real_sext)
@@ -322,7 +334,7 @@ class TecoGanLoss(ConfigurableLoss):
else:
raise NotImplementedError
return l_step
return l_step, d_real
def produce_teco_visual_debugs(self, sext, lbl, it):
if self.env['rank'] > 0:

View File

@@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_normal.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()

View File

@@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_grad_penalty.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srg2_grad_penalty.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()