From 6b679e2b513e4fab348fb5ca81ca135043cb9724 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Tue, 17 Nov 2020 18:31:40 -0700
Subject: [PATCH] Make grad_penalty available to classical discs

---
 codes/models/archs/RRDBNet_arch.py |  2 +-
 codes/models/networks.py           |  2 +-
 codes/models/steps/losses.py       | 21 +++++++++++++++++----
 codes/train.py                     |  2 +-
 codes/train2.py                    |  6 +++---
 5 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/codes/models/archs/RRDBNet_arch.py b/codes/models/archs/RRDBNet_arch.py
index d86493f5..a7fec89d 100644
--- a/codes/models/archs/RRDBNet_arch.py
+++ b/codes/models/archs/RRDBNet_arch.py
@@ -145,7 +145,7 @@ class RRDBNet(nn.Module):
                  body_block=RRDB,
                  blocks_per_checkpoint=4,
                  scale=4,
-                 additive_mode="not_additive"  # Options: "not_additive", "additive", "additive_enforced"
+                 additive_mode="not_additive"  # Options: "not", "additive", "additive_enforced"
                  ):
         super(RRDBNet, self).__init__()
         self.num_blocks = num_blocks
diff --git a/codes/models/networks.py b/codes/models/networks.py
index 117e2806..c49b40e2 100644
--- a/codes/models/networks.py
+++ b/codes/models/networks.py
@@ -47,7 +47,7 @@ def define_G(opt, net_key='network_G', scale=None):
         netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
                                     mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode)
     elif which_model == 'RRDBNetBypass':
-        additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not_additive'
+        additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
         netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
                                     mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass,
                                     blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'],
diff --git a/codes/models/steps/losses.py b/codes/models/steps/losses.py
index 8364b1a6..abdefd46 100644
--- a/codes/models/steps/losses.py
+++ b/codes/models/steps/losses.py
@@ -192,8 +192,8 @@ class GeneratorGanLoss(ConfigurableLoss):
             nfake = []
             for i, t in enumerate(real):
                 if isinstance(t, torch.Tensor):
-                    nreal.append(t + torch.randn_like(t) * self.noise)
-                    nfake.append(fake[i] + torch.randn_like(t) * self.noise)
+                    nreal.append(t + torch.rand_like(t) * self.noise)
+                    nfake.append(fake[i] + torch.rand_like(t) * self.noise)
                 else:
                     nreal.append(t)
                     nfake.append(fake[i])
@@ -234,6 +234,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
         # This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
         # generators and discriminators by essentially having them skip steps while their counterparts "catch up".
         self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
+        self.gradient_penalty = opt['gradient_penalty'] if 'gradient_penalty' in opt.keys() else False
         if self.min_loss != 0:
             assert not self.env['dist']  # distributed training does not support 'min_loss' - it can result in backward() desync by design.
         self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
@@ -243,6 +244,8 @@ class DiscriminatorGanLoss(ConfigurableLoss):
     def forward(self, net, state):
         real = extract_params_from_state(self.opt['real'], state)
         real = [r.detach() for r in real]
+        if self.gradient_penalty:
+            [r.requires_grad_() for r in real]
         fake = extract_params_from_state(self.opt['fake'], state)
         fake = [f.detach() for f in fake]
         if self.noise:
@@ -250,8 +253,8 @@ class DiscriminatorGanLoss(ConfigurableLoss):
             nfake = []
             for i, t in enumerate(real):
                 if isinstance(t, torch.Tensor):
-                    nreal.append(t + torch.randn_like(t) * self.noise)
-                    nfake.append(fake[i] + torch.randn_like(t) * self.noise)
+                    nreal.append(t + torch.rand_like(t) * self.noise)
+                    nfake.append(fake[i] + torch.rand_like(t) * self.noise)
                 else:
                     nreal.append(t)
                     nfake.append(fake[i])
@@ -282,6 +285,16 @@ class DiscriminatorGanLoss(ConfigurableLoss):
             if torch.mean(self.loss_rotating_buffer) < self.min_loss:
                 return 0
         self.losses_computed += 1
+
+        if self.gradient_penalty:
+            # Apply gradient penalty. TODO: migrate this elsewhere.
+            from models.archs.stylegan.stylegan2 import gradient_penalty
+            assert len(real) == 1  # Grad penalty doesn't currently support multi-input discriminators.
+            gp = gradient_penalty(real[0], d_real)
+            self.metrics.append(("gradient_penalty", gp.clone().detach()))
+            loss = loss + gp
+            self.metrics.append(("gradient_penalty", gp))
+
         return loss
diff --git a/codes/train.py b/codes/train.py
index 08005c7d..93cc98f9 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -291,7 +291,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_corrected_disc.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_normal.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/train2.py b/codes/train2.py
index 5f964370..6ba64db4 100644
--- a/codes/train2.py
+++ b/codes/train2.py
@@ -221,7 +221,7 @@ class Trainer:
                     img_dir = os.path.join(opt['path']['val_images'], img_name)
                     util.mkdir(img_dir)
 
-                    self.model.feed_data(val_data)
+                    self.model.feed_data(val_data, self.current_step)
                     self.model.test()
 
                     visuals = self.model.get_current_visuals()
@@ -291,7 +291,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr_gen_real.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_grad_penalty.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
     trainer = Trainer()
 
-    #### distributed training settings
+#### distributed training settings
     if args.launcher == 'none':  # disabled distributed training
         opt['dist'] = False
         trainer.rank = -1
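
Note: the gradient_penalty() helper imported above lives in models/archs/stylegan/stylegan2.py and is not shown in this patch; the requires_grad_() call added to DiscriminatorGanLoss.forward() is what allows gradients of the discriminator output to be taken with respect to the real inputs. As a rough orientation only, a common formulation of such a penalty is sketched below. The function name gradient_penalty_sketch, the weight=10 default, and the exact reduction are illustrative assumptions, not the repository's actual implementation.

import torch
from torch.autograd import grad

def gradient_penalty_sketch(real_images, d_real_outputs, weight=10):
    # Assumed illustration, not the repo's gradient_penalty().
    # real_images must carry requires_grad=True, which is why the patched
    # DiscriminatorGanLoss.forward() calls requires_grad_() on them.
    batch_size = real_images.shape[0]
    # Differentiate the discriminator outputs with respect to the real inputs.
    gradients = grad(outputs=d_real_outputs, inputs=real_images,
                     grad_outputs=torch.ones_like(d_real_outputs),
                     create_graph=True, retain_graph=True)[0]
    gradients = gradients.reshape(batch_size, -1)
    # Penalize per-sample gradient norms that deviate from 1.
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

Since the option is read as opt['gradient_penalty'], enabling it presumably amounts to setting gradient_penalty: true in the discriminator loss section of the training YAML (for example the train_exd_mi1_rrdb4x_6bl_grad_penalty.yml option file referenced in train2.py); the exact placement in the config is an assumption based on how 'min_loss' and 'noise' are read from the same opt dict.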