Fixes to teco losses and translational losses
This commit is contained in:
parent
b2a05465fc
commit
d7877d0a36
|
@ -388,6 +388,8 @@ class TranslationInvarianceLoss(ConfigurableLoss):
|
|||
trans_output = net(*input)
|
||||
else:
|
||||
trans_output = net(*input)
|
||||
if not isinstance(trans_output, list) and not isinstance(trans_output, tuple):
|
||||
trans_output = [trans_output]
|
||||
|
||||
if self.gen_output_to_use is not None:
|
||||
fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from torch.cuda.amp import autocast
|
||||
|
||||
from models.archs.stylegan.stylegan2 import gradient_penalty
|
||||
from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
|
||||
from models.flownet2.networks.resample2d_package.resample2d import Resample2d
|
||||
from models.steps.injectors import Injector
|
||||
|
@ -243,6 +244,7 @@ class TecoGanLoss(ConfigurableLoss):
|
|||
self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
|
||||
self.ff = opt['fast_forward'] if 'fast_forward' in opt.keys() else False
|
||||
self.noise = opt['noise'] if 'noise' in opt.keys() else 0
|
||||
self.gradient_penalty = opt['gradient_penalty'] if 'gradient_penalty' in opt.keys() else False
|
||||
|
||||
def forward(self, _, state):
|
||||
if self.ff:
|
||||
|
@ -264,10 +266,15 @@ class TecoGanLoss(ConfigurableLoss):
|
|||
# Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
|
||||
for i in range(sequence_len - 2):
|
||||
real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
|
||||
if self.gradient_penalty:
|
||||
real_sext.requires_grad_()
|
||||
fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
|
||||
l_step = self.compute_loss(real_sext, fake_sext)
|
||||
l_step, d_real = self.compute_loss(real_sext, fake_sext)
|
||||
if l_step > self.min_loss:
|
||||
l_total += l_step
|
||||
l_total = l_total + l_step
|
||||
elif self.gradient_penalty:
|
||||
gp = gradient_penalty(real_sext, d_real)
|
||||
l_total = l_total + gp
|
||||
|
||||
return l_total
|
||||
|
||||
|
@ -283,19 +290,24 @@ class TecoGanLoss(ConfigurableLoss):
|
|||
# Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
|
||||
combined_real_sext = create_all_discriminator_sextuplets(real, lr, self.scale, sequence_len - 2, flow_gen,
|
||||
self.resampler, self.margin)
|
||||
if self.gradient_penalty:
|
||||
combined_real_sext.requires_grad_()
|
||||
combined_fake_sext = create_all_discriminator_sextuplets(fake, lr, self.scale, sequence_len - 2, flow_gen,
|
||||
self.resampler, self.margin)
|
||||
l_total = self.compute_loss(combined_real_sext, combined_fake_sext)
|
||||
l_total, d_real = self.compute_loss(combined_real_sext, combined_fake_sext)
|
||||
if l_total < self.min_loss:
|
||||
l_total = 0
|
||||
elif self.gradient_penalty:
|
||||
gp = gradient_penalty(combined_real_sext, d_real)
|
||||
l_total = l_total + gp
|
||||
return l_total
|
||||
|
||||
def compute_loss(self, real_sext, fake_sext):
|
||||
fp16 = self.env['opt']['fp16']
|
||||
net = self.env['discriminators'][self.opt['discriminator']]
|
||||
if self.noise != 0:
|
||||
real_sext += torch.randn_like(real_sext) * self.noise
|
||||
fake_sext += torch.randn_like(fake_sext) * self.noise
|
||||
real_sext = real_sext + torch.rand_like(real_sext) * self.noise
|
||||
fake_sext = fake_sext + torch.rand_like(fake_sext) * self.noise
|
||||
with autocast(enabled=fp16):
|
||||
d_fake = net(fake_sext)
|
||||
d_real = net(real_sext)
|
||||
|
@ -322,7 +334,7 @@ class TecoGanLoss(ConfigurableLoss):
|
|||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return l_step
|
||||
return l_step, d_real
|
||||
|
||||
def produce_teco_visual_debugs(self, sext, lbl, it):
|
||||
if self.env['rank'] > 0:
|
||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_normal.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_grad_penalty.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srg2_grad_penalty.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user