Fixes to teco losses and translational losses

This commit is contained in:
James Betker 2020-11-19 11:35:05 -07:00
parent b2a05465fc
commit d7877d0a36
4 changed files with 22 additions and 8 deletions

View File

@@ -388,6 +388,8 @@ class TranslationInvarianceLoss(ConfigurableLoss):
                 trans_output = net(*input)
         else:
             trans_output = net(*input)
+        if not isinstance(trans_output, list) and not isinstance(trans_output, tuple):
+            trans_output = [trans_output]
         if self.gen_output_to_use is not None:
             fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
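Note on the hunk above: generators in this codebase can return either a bare tensor or a list/tuple of outputs, and the added isinstance check normalizes both cases so the gen_output_to_use indexing stays uniform. A minimal sketch of the pattern, using a stand-in nn.Conv2d in place of the real generator:

import torch
import torch.nn as nn

net = nn.Conv2d(3, 3, 3, padding=1)  # stand-in generator that returns a bare tensor
trans_output = net(torch.randn(1, 3, 16, 16))

# Wrap single-tensor returns so indexed access works either way.
if not isinstance(trans_output, list) and not isinstance(trans_output, tuple):
    trans_output = [trans_output]
shared = trans_output[0]  # safe whether net returned a tensor or a sequence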

View File

@@ -1,5 +1,6 @@
 from torch.cuda.amp import autocast
+from models.archs.stylegan.stylegan2 import gradient_penalty
 from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
 from models.flownet2.networks.resample2d_package.resample2d import Resample2d
 from models.steps.injectors import Injector
@@ -243,6 +244,7 @@ class TecoGanLoss(ConfigurableLoss):
         self.margin = opt['margin']  # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
         self.ff = opt['fast_forward'] if 'fast_forward' in opt.keys() else False
         self.noise = opt['noise'] if 'noise' in opt.keys() else 0
+        self.gradient_penalty = opt['gradient_penalty'] if 'gradient_penalty' in opt.keys() else False

     def forward(self, _, state):
         if self.ff:
@@ -264,10 +266,15 @@ class TecoGanLoss(ConfigurableLoss):
         # Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
         for i in range(sequence_len - 2):
             real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
+            if self.gradient_penalty:
+                real_sext.requires_grad_()
             fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
-            l_step = self.compute_loss(real_sext, fake_sext)
+            l_step, d_real = self.compute_loss(real_sext, fake_sext)
             if l_step > self.min_loss:
-                l_total += l_step
+                l_total = l_total + l_step
+            elif self.gradient_penalty:
+                gp = gradient_penalty(real_sext, d_real)
+                l_total = l_total + gp

         return l_total
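The gradient_penalty function imported from models.archs.stylegan.stylegan2 is not shown in this diff. For orientation, here is a minimal sketch consistent with the gradient_penalty(real_sext, d_real) call signature, assuming the common stylegan2-pytorch formulation (the weight of 10 is an assumption, not a value taken from this commit). It penalizes the discriminator's gradient with respect to the real inputs, which is why real_sext.requires_grad_() is set before the discriminator forward pass:

import torch
from torch.autograd import grad as torch_grad

def gradient_penalty(images, output, weight=10):
    # Hypothetical sketch; the actual implementation lives in models.archs.stylegan.stylegan2.
    gradients = torch_grad(outputs=output, inputs=images,
                           grad_outputs=torch.ones_like(output),
                           create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients = gradients.reshape(images.shape[0], -1)
    # Penalize deviation of the per-sample gradient norm from 1 (WGAN-GP style).
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()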
@@ -283,19 +290,24 @@ class TecoGanLoss(ConfigurableLoss):
         # Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
         combined_real_sext = create_all_discriminator_sextuplets(real, lr, self.scale, sequence_len - 2, flow_gen,
                                                                  self.resampler, self.margin)
+        if self.gradient_penalty:
+            combined_real_sext.requires_grad_()
         combined_fake_sext = create_all_discriminator_sextuplets(fake, lr, self.scale, sequence_len - 2, flow_gen,
                                                                  self.resampler, self.margin)
-        l_total = self.compute_loss(combined_real_sext, combined_fake_sext)
+        l_total, d_real = self.compute_loss(combined_real_sext, combined_fake_sext)
         if l_total < self.min_loss:
             l_total = 0
+        elif self.gradient_penalty:
+            gp = gradient_penalty(combined_real_sext, d_real)
+            l_total = l_total + gp

         return l_total

     def compute_loss(self, real_sext, fake_sext):
         fp16 = self.env['opt']['fp16']
         net = self.env['discriminators'][self.opt['discriminator']]
         if self.noise != 0:
-            real_sext += torch.randn_like(real_sext) * self.noise
-            fake_sext += torch.randn_like(fake_sext) * self.noise
+            real_sext = real_sext + torch.rand_like(real_sext) * self.noise
+            fake_sext = fake_sext + torch.rand_like(fake_sext) * self.noise
         with autocast(enabled=fp16):
             d_fake = net(fake_sext)
             d_real = net(real_sext)
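One detail worth calling out in the compute_loss hunk above: the noise injection switches from in-place += to out-of-place addition (and from torch.randn_like to torch.rand_like, i.e. from Gaussian to uniform noise). The out-of-place form is required once gradient penalty is enabled, because real_sext is then a leaf tensor with requires_grad set, and PyTorch forbids in-place operations on such leaves. A quick demonstration:

import torch

real_sext = torch.zeros(2, 3, 8, 8)
real_sext.requires_grad_()

try:
    real_sext += torch.rand_like(real_sext) * 0.1  # in-place op on a leaf that requires grad
except RuntimeError as e:
    print(e)  # "a leaf Variable that requires grad is being used in an in-place operation."

# The out-of-place form builds a new graph node instead and works fine.
real_sext = real_sext + torch.rand_like(real_sext) * 0.1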
@@ -322,7 +334,7 @@ class TecoGanLoss(ConfigurableLoss):
         else:
             raise NotImplementedError
-        return l_step
+        return l_step, d_real

     def produce_teco_visual_debugs(self, sext, lbl, it):
         if self.env['rank'] > 0:

View File

@@ -291,7 +291,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_normal.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()

View File

@@ -291,7 +291,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_grad_penalty.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srg2_grad_penalty.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()