Fixes to teco losses and translational losses

parent b2a05465fc
commit d7877d0a36
@@ -388,6 +388,8 @@ class TranslationInvarianceLoss(ConfigurableLoss):
                 trans_output = net(*input)
         else:
             trans_output = net(*input)
+        if not isinstance(trans_output, list) and not isinstance(trans_output, tuple):
+            trans_output = [trans_output]
 
         if self.gen_output_to_use is not None:
             fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
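The two added lines guard the indexing that follows: some generators return a bare tensor while others return a tuple of outputs, and `trans_output[self.gen_output_to_use]` only makes sense against a sequence. A minimal sketch of the same normalization pattern (the tensors here are hypothetical stand-ins, not from the commit):

# Minimal sketch of the list-normalization pattern above; inputs are hypothetical.
import torch

def normalize_outputs(out):
    # Wrap a bare tensor so callers can always index positionally.
    if not isinstance(out, (list, tuple)):
        out = [out]
    return out

single = torch.zeros(1, 3, 64, 64)           # generator returning one tensor
multi = (single, torch.zeros(1, 3, 32, 32))  # generator returning several
assert normalize_outputs(single)[0] is single
assert normalize_outputs(multi)[1].shape[-1] == 32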
@@ -1,5 +1,6 @@
 from torch.cuda.amp import autocast
 
+from models.archs.stylegan.stylegan2 import gradient_penalty
 from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
 from models.flownet2.networks.resample2d_package.resample2d import Resample2d
 from models.steps.injectors import Injector
@@ -243,6 +244,7 @@ class TecoGanLoss(ConfigurableLoss):
         self.margin = opt['margin'] # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
         self.ff = opt['fast_forward'] if 'fast_forward' in opt.keys() else False
         self.noise = opt['noise'] if 'noise' in opt.keys() else 0
+        self.gradient_penalty = opt['gradient_penalty'] if 'gradient_penalty' in opt.keys() else False
 
     def forward(self, _, state):
         if self.ff:
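The new `gradient_penalty` flag follows the same read-with-default idiom as `fast_forward` and `noise`. As a sketch, the loss options dict reaching this constructor would carry it like this (all values are illustrative, not from the commit):

# Illustrative options dict for TecoGanLoss; values are assumptions.
opt = {
    'margin': 8,               # crop margin, per the tecogan paper comment above
    'fast_forward': False,     # the 'ff' path
    'noise': 0.0,              # instance-noise magnitude used in compute_loss
    'gradient_penalty': True,  # new flag: enables the penalty branch in forward()
}
use_gp = opt['gradient_penalty'] if 'gradient_penalty' in opt.keys() else False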
@@ -264,10 +266,15 @@ class TecoGanLoss(ConfigurableLoss):
         # Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
         for i in range(sequence_len - 2):
             real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
+            if self.gradient_penalty:
+                real_sext.requires_grad_()
             fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
-            l_step = self.compute_loss(real_sext, fake_sext)
+            l_step, d_real = self.compute_loss(real_sext, fake_sext)
             if l_step > self.min_loss:
-                l_total += l_step
+                l_total = l_total + l_step
+            elif self.gradient_penalty:
+                gp = gradient_penalty(real_sext, d_real)
+                l_total = l_total + gp
 
         return l_total
 
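Two details in this hunk are easy to miss. `real_sext.requires_grad_()` is what lets autograd differentiate the discriminator's real-branch output with respect to its input, which is exactly what the penalty needs, and `compute_loss` now returns `d_real` so that output can be reused without a second forward pass. Switching `l_total += l_step` to the out-of-place form also avoids an in-place mutation on a value that may already be a graph tensor. The imported `gradient_penalty` is not shown in this diff; given the `(input, output)` call signature, a plausible sketch is the common WGAN-GP-style penalty applied to real samples (the weight and exact form are assumptions):

# Hedged sketch of a gradient penalty matching the call sites above:
# gradient_penalty(real_sext, d_real). Weight and details are assumptions.
import torch
from torch.autograd import grad as torch_grad

def gradient_penalty_sketch(images, output, weight=10):
    batch_size = images.shape[0]
    # d(output)/d(images); create_graph=True keeps the penalty differentiable
    # so it can itself be backpropagated into the discriminator weights.
    gradients = torch_grad(outputs=output, inputs=images,
                           grad_outputs=torch.ones_like(output),
                           create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients = gradients.reshape(batch_size, -1)
    # Penalize deviation of the per-sample gradient norm from 1.
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()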
@@ -283,19 +290,24 @@ class TecoGanLoss(ConfigurableLoss):
         # Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
         combined_real_sext = create_all_discriminator_sextuplets(real, lr, self.scale, sequence_len - 2, flow_gen,
                                                                  self.resampler, self.margin)
+        if self.gradient_penalty:
+            combined_real_sext.requires_grad_()
         combined_fake_sext = create_all_discriminator_sextuplets(fake, lr, self.scale, sequence_len - 2, flow_gen,
                                                                  self.resampler, self.margin)
-        l_total = self.compute_loss(combined_real_sext, combined_fake_sext)
+        l_total, d_real = self.compute_loss(combined_real_sext, combined_fake_sext)
         if l_total < self.min_loss:
             l_total = 0
+        elif self.gradient_penalty:
+            gp = gradient_penalty(combined_real_sext, d_real)
+            l_total = l_total + gp
         return l_total
 
     def compute_loss(self, real_sext, fake_sext):
         fp16 = self.env['opt']['fp16']
         net = self.env['discriminators'][self.opt['discriminator']]
         if self.noise != 0:
-            real_sext += torch.randn_like(real_sext) * self.noise
-            fake_sext += torch.randn_like(fake_sext) * self.noise
+            real_sext = real_sext + torch.rand_like(real_sext) * self.noise
+            fake_sext = fake_sext + torch.rand_like(fake_sext) * self.noise
         with autocast(enabled=fp16):
             d_fake = net(fake_sext)
             d_real = net(real_sext)
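The noise lines change in two ways: `randn_like` becomes `rand_like` (uniform rather than Gaussian noise), and the in-place `+=` becomes an out-of-place add. The second change is load-bearing once `requires_grad_()` has been called on the real sextuplet, because PyTorch refuses in-place operations on a leaf tensor that requires grad. A self-contained demonstration:

# Why the out-of-place add matters after requires_grad_(): a leaf tensor that
# requires grad cannot be modified in place.
import torch

x = torch.zeros(2, 3)
x.requires_grad_()
try:
    x += torch.rand_like(x) * 0.1   # RuntimeError: leaf Variable ... in-place operation
except RuntimeError as err:
    print('in-place add rejected:', err)

y = x + torch.rand_like(x) * 0.1    # fine: creates a new node in the autograd graph
print(y.requires_grad)              # True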
@@ -322,7 +334,7 @@ class TecoGanLoss(ConfigurableLoss):
         else:
             raise NotImplementedError
 
-        return l_step
+        return l_step, d_real
 
     def produce_teco_visual_debugs(self, sext, lbl, it):
         if self.env['rank'] > 0:
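This return-signature change is the other half of the `d_real` plumbing: both call sites in `forward` unpack `l_step, d_real = self.compute_loss(...)` and, when the penalty is enabled, hand `d_real` to `gradient_penalty` together with the `requires_grad_()`-enabled real sextuplet, so no second discriminator pass is needed.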
@@ -291,7 +291,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_normal.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
@@ -291,7 +291,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_grad_penalty.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srg2_grad_penalty.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()