Make gradient penalty available to classical discriminators

This commit is contained in:
James Betker 2020-11-17 18:31:40 -07:00
parent 8a19c9ae15
commit 6b679e2b51
5 changed files with 23 additions and 10 deletions

View File

@ -145,7 +145,7 @@ class RRDBNet(nn.Module):
body_block=RRDB, body_block=RRDB,
blocks_per_checkpoint=4, blocks_per_checkpoint=4,
scale=4, scale=4,
additive_mode="not_additive" # Options: "not_additive", "additive", "additive_enforced" additive_mode="not_additive" # Options: "not", "additive", "additive_enforced"
): ):
super(RRDBNet, self).__init__() super(RRDBNet, self).__init__()
self.num_blocks = num_blocks self.num_blocks = num_blocks

View File

@ -47,7 +47,7 @@ def define_G(opt, net_key='network_G', scale=None):
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode) mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode)
elif which_model == 'RRDBNetBypass': elif which_model == 'RRDBNetBypass':
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not_additive' additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'], netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass, mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass,
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'], blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'],

View File

@ -192,8 +192,8 @@ class GeneratorGanLoss(ConfigurableLoss):
nfake = [] nfake = []
for i, t in enumerate(real): for i, t in enumerate(real):
if isinstance(t, torch.Tensor): if isinstance(t, torch.Tensor):
nreal.append(t + torch.randn_like(t) * self.noise) nreal.append(t + torch.rand_like(t) * self.noise)
nfake.append(fake[i] + torch.randn_like(t) * self.noise) nfake.append(fake[i] + torch.rand_like(t) * self.noise)
else: else:
nreal.append(t) nreal.append(t)
nfake.append(fake[i]) nfake.append(fake[i])
@ -234,6 +234,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
# This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance # This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
# generators and discriminators by essentially having them skip steps while their counterparts "catch up". # generators and discriminators by essentially having them skip steps while their counterparts "catch up".
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0 self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
self.gradient_penalty = opt['gradient_penalty'] if 'gradient_penalty' in opt.keys() else False
if self.min_loss != 0: if self.min_loss != 0:
assert not self.env['dist'] # distributed training does not support 'min_loss' - it can result in backward() desync by design. assert not self.env['dist'] # distributed training does not support 'min_loss' - it can result in backward() desync by design.
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False) self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
@ -243,6 +244,8 @@ class DiscriminatorGanLoss(ConfigurableLoss):
def forward(self, net, state): def forward(self, net, state):
real = extract_params_from_state(self.opt['real'], state) real = extract_params_from_state(self.opt['real'], state)
real = [r.detach() for r in real] real = [r.detach() for r in real]
if self.gradient_penalty:
[r.requires_grad_() for r in real]
fake = extract_params_from_state(self.opt['fake'], state) fake = extract_params_from_state(self.opt['fake'], state)
fake = [f.detach() for f in fake] fake = [f.detach() for f in fake]
if self.noise: if self.noise:
@ -250,8 +253,8 @@ class DiscriminatorGanLoss(ConfigurableLoss):
nfake = [] nfake = []
for i, t in enumerate(real): for i, t in enumerate(real):
if isinstance(t, torch.Tensor): if isinstance(t, torch.Tensor):
nreal.append(t + torch.randn_like(t) * self.noise) nreal.append(t + torch.rand_like(t) * self.noise)
nfake.append(fake[i] + torch.randn_like(t) * self.noise) nfake.append(fake[i] + torch.rand_like(t) * self.noise)
else: else:
nreal.append(t) nreal.append(t)
nfake.append(fake[i]) nfake.append(fake[i])
@ -282,6 +285,16 @@ class DiscriminatorGanLoss(ConfigurableLoss):
if torch.mean(self.loss_rotating_buffer) < self.min_loss: if torch.mean(self.loss_rotating_buffer) < self.min_loss:
return 0 return 0
self.losses_computed += 1 self.losses_computed += 1
if self.gradient_penalty:
# Apply gradient penalty. TODO: migrate this elsewhere.
from models.archs.stylegan.stylegan2 import gradient_penalty
assert len(real) == 1 # Grad penalty doesn't currently support multi-input discriminators.
gp = gradient_penalty(real[0], d_real)
self.metrics.append(("gradient_penalty", gp.clone().detach()))
loss = loss + gp
self.metrics.append(("gradient_penalty", gp))
return loss return loss

View File

@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_corrected_disc.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_normal.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()

View File

@ -221,7 +221,7 @@ class Trainer:
img_dir = os.path.join(opt['path']['val_images'], img_name) img_dir = os.path.join(opt['path']['val_images'], img_name)
util.mkdir(img_dir) util.mkdir(img_dir)
self.model.feed_data(val_data) self.model.feed_data(val_data, self.current_step)
self.model.test() self.model.test()
visuals = self.model.get_current_visuals() visuals = self.model.get_current_visuals()
@ -291,7 +291,7 @@ class Trainer:
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr_gen_real.yml') parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_grad_penalty.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args() args = parser.parse_args()