diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py
index abdefd46..02908e3e 100644
--- a/codes/models/steps/losses.py
+++ b/codes/models/steps/losses.py
@@ -388,6 +388,8 @@ class TranslationInvarianceLoss(ConfigurableLoss):
             trans_output = net(*input)
         else:
             trans_output = net(*input)
+        if not isinstance(trans_output, list) and not isinstance(trans_output, tuple):
+            trans_output = [trans_output]
         if self.gen_output_to_use is not None:
             fake_shared_output = trans_output[self.gen_output_to_use][:, :, hl:hh, wl:wh]
         else:
diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py
index 91fd08bd..16e82413 100644
--- a/codes/models/steps/tecogan_losses.py
+++ b/codes/models/steps/tecogan_losses.py
@@ -1,5 +1,6 @@
 from torch.cuda.amp import autocast

+from models.archs.stylegan.stylegan2 import gradient_penalty
 from models.steps.losses import ConfigurableLoss, GANLoss, extract_params_from_state, get_basic_criterion_for_name
 from models.flownet2.networks.resample2d_package.resample2d import Resample2d
 from models.steps.injectors import Injector
@@ -243,6 +244,7 @@ class TecoGanLoss(ConfigurableLoss):
         self.margin = opt['margin']  # Per the tecogan paper, the GAN loss only pays attention to an inner part of the image with the margin removed, to get rid of artifacts resulting from flow errors.
         self.ff = opt['fast_forward'] if 'fast_forward' in opt.keys() else False
         self.noise = opt['noise'] if 'noise' in opt.keys() else 0
+        self.gradient_penalty = opt['gradient_penalty'] if 'gradient_penalty' in opt.keys() else False

     def forward(self, _, state):
         if self.ff:
@@ -264,10 +266,15 @@ class TecoGanLoss(ConfigurableLoss):
         # Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
         for i in range(sequence_len - 2):
             real_sext = create_teco_discriminator_sextuplet(real, lr, self.scale, i, flow_gen, self.resampler, self.margin)
+            if self.gradient_penalty:
+                real_sext.requires_grad_()
             fake_sext = create_teco_discriminator_sextuplet(fake, lr, self.scale, i, flow_gen, self.resampler, self.margin)
-            l_step = self.compute_loss(real_sext, fake_sext)
+            l_step, d_real = self.compute_loss(real_sext, fake_sext)
             if l_step > self.min_loss:
-                l_total += l_step
+                l_total = l_total + l_step
+            elif self.gradient_penalty:
+                gp = gradient_penalty(real_sext, d_real)
+                l_total = l_total + gp

         return l_total

@@ -283,19 +290,24 @@ class TecoGanLoss(ConfigurableLoss):

         # Create a list of all the discriminator inputs, which will be reduced into the batch dim for efficient computation.
         combined_real_sext = create_all_discriminator_sextuplets(real, lr, self.scale, sequence_len - 2, flow_gen, self.resampler, self.margin)
+        if self.gradient_penalty:
+            combined_real_sext.requires_grad_()
         combined_fake_sext = create_all_discriminator_sextuplets(fake, lr, self.scale, sequence_len - 2, flow_gen, self.resampler, self.margin)
-        l_total = self.compute_loss(combined_real_sext, combined_fake_sext)
+        l_total, d_real = self.compute_loss(combined_real_sext, combined_fake_sext)
         if l_total < self.min_loss:
             l_total = 0
+        elif self.gradient_penalty:
+            gp = gradient_penalty(combined_real_sext, d_real)
+            l_total = l_total + gp

         return l_total

     def compute_loss(self, real_sext, fake_sext):
         fp16 = self.env['opt']['fp16']
         net = self.env['discriminators'][self.opt['discriminator']]
         if self.noise != 0:
-            real_sext += torch.randn_like(real_sext) * self.noise
-            fake_sext += torch.randn_like(fake_sext) * self.noise
+            real_sext = real_sext + torch.rand_like(real_sext) * self.noise
+            fake_sext = fake_sext + torch.rand_like(fake_sext) * self.noise
         with autocast(enabled=fp16):
             d_fake = net(fake_sext)
             d_real = net(real_sext)
@@ -322,7 +334,7 @@ class TecoGanLoss(ConfigurableLoss):
         else:
             raise NotImplementedError

-        return l_step
+        return l_step, d_real

     def produce_teco_visual_debugs(self, sext, lbl, it):
         if self.env['rank'] > 0:
diff --git a/codes/train.py b/codes/train.py
index 93cc98f9..ecef64c3 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -291,7 +291,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_normal.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_teco_vix_srg2_classic_proper_disc.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/train2.py b/codes/train2.py
index 6ba64db4..4eff3b79 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -291,7 +291,7 @@ class Trainer:

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_grad_penalty.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_srg2_grad_penalty.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
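For context on the new import: in lucidrains-style StyleGAN2 ports such as the one this repo vendors in models/archs/stylegan/stylegan2.py, gradient_penalty is a WGAN-GP-style unit-gradient-norm penalty evaluated at real samples, which is why the patch calls real_sext.requires_grad_() before the discriminator forward pass. The sketch below shows what gradient_penalty(real_sext, d_real) plausibly computes; the exact body and the default weight are assumptions, not copied from the repo.

    import torch
    from torch.autograd import grad as torch_grad

    def gradient_penalty(images, output, weight=10):
        # Gradients of the discriminator output w.r.t. its (real) input. Without
        # the patch's real_sext.requires_grad_() call there would be no autograd
        # path from `output` back to `images` and this call would fail.
        gradients = torch_grad(outputs=output, inputs=images,
                               grad_outputs=torch.ones_like(output),
                               create_graph=True, retain_graph=True, only_inputs=True)[0]
        # Penalize the squared deviation of each sample's gradient norm from 1.
        gradients = gradients.reshape(images.shape[0], -1)
        return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

create_graph=True is what lets the penalty itself be backpropagated through during the discriminator update.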
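The switch from in-place += to out-of-place additions in compute_loss follows from the same change: once real_sext.requires_grad_() is set, the sextuplet is a leaf tensor that requires grad, and PyTorch rejects in-place operations on such leaves. (The same hunk also swaps Gaussian torch.randn_like noise for uniform torch.rand_like noise.) A small standalone repro of the autograd restriction, illustrative only and not taken from the repo:

    import torch

    x = torch.randn(4, 3).requires_grad_()  # stands in for real_sext after the patch
    try:
        x += torch.rand_like(x) * 0.1       # in-place add on a leaf that requires grad
    except RuntimeError as err:
        print(err)  # "a leaf Variable that requires grad is being used in an in-place operation."
    x = x + torch.rand_like(x) * 0.1        # out-of-place add builds a new graph node instead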