Make grad_penalty available to classical discs
This commit is contained in:
parent
8a19c9ae15
commit
6b679e2b51
|
@ -145,7 +145,7 @@ class RRDBNet(nn.Module):
|
||||||
body_block=RRDB,
|
body_block=RRDB,
|
||||||
blocks_per_checkpoint=4,
|
blocks_per_checkpoint=4,
|
||||||
scale=4,
|
scale=4,
|
||||||
additive_mode="not_additive" # Options: "not_additive", "additive", "additive_enforced"
|
additive_mode="not_additive" # Options: "not", "additive", "additive_enforced"
|
||||||
):
|
):
|
||||||
super(RRDBNet, self).__init__()
|
super(RRDBNet, self).__init__()
|
||||||
self.num_blocks = num_blocks
|
self.num_blocks = num_blocks
|
||||||
|
|
|
@ -47,7 +47,7 @@ def define_G(opt, net_key='network_G', scale=None):
|
||||||
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode)
|
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode)
|
||||||
elif which_model == 'RRDBNetBypass':
|
elif which_model == 'RRDBNetBypass':
|
||||||
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not_additive'
|
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
|
||||||
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass,
|
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass,
|
||||||
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'],
|
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'],
|
||||||
|
|
|
@ -192,8 +192,8 @@ class GeneratorGanLoss(ConfigurableLoss):
|
||||||
nfake = []
|
nfake = []
|
||||||
for i, t in enumerate(real):
|
for i, t in enumerate(real):
|
||||||
if isinstance(t, torch.Tensor):
|
if isinstance(t, torch.Tensor):
|
||||||
nreal.append(t + torch.randn_like(t) * self.noise)
|
nreal.append(t + torch.rand_like(t) * self.noise)
|
||||||
nfake.append(fake[i] + torch.randn_like(t) * self.noise)
|
nfake.append(fake[i] + torch.rand_like(t) * self.noise)
|
||||||
else:
|
else:
|
||||||
nreal.append(t)
|
nreal.append(t)
|
||||||
nfake.append(fake[i])
|
nfake.append(fake[i])
|
||||||
|
@ -234,6 +234,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
||||||
# This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
|
# This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
|
||||||
# generators and discriminators by essentially having them skip steps while their counterparts "catch up".
|
# generators and discriminators by essentially having them skip steps while their counterparts "catch up".
|
||||||
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
|
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
|
||||||
|
self.gradient_penalty = opt['gradient_penalty'] if 'gradient_penalty' in opt.keys() else False
|
||||||
if self.min_loss != 0:
|
if self.min_loss != 0:
|
||||||
assert not self.env['dist'] # distributed training does not support 'min_loss' - it can result in backward() desync by design.
|
assert not self.env['dist'] # distributed training does not support 'min_loss' - it can result in backward() desync by design.
|
||||||
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
|
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
|
||||||
|
@ -243,6 +244,8 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
||||||
def forward(self, net, state):
|
def forward(self, net, state):
|
||||||
real = extract_params_from_state(self.opt['real'], state)
|
real = extract_params_from_state(self.opt['real'], state)
|
||||||
real = [r.detach() for r in real]
|
real = [r.detach() for r in real]
|
||||||
|
if self.gradient_penalty:
|
||||||
|
[r.requires_grad_() for r in real]
|
||||||
fake = extract_params_from_state(self.opt['fake'], state)
|
fake = extract_params_from_state(self.opt['fake'], state)
|
||||||
fake = [f.detach() for f in fake]
|
fake = [f.detach() for f in fake]
|
||||||
if self.noise:
|
if self.noise:
|
||||||
|
@ -250,8 +253,8 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
||||||
nfake = []
|
nfake = []
|
||||||
for i, t in enumerate(real):
|
for i, t in enumerate(real):
|
||||||
if isinstance(t, torch.Tensor):
|
if isinstance(t, torch.Tensor):
|
||||||
nreal.append(t + torch.randn_like(t) * self.noise)
|
nreal.append(t + torch.rand_like(t) * self.noise)
|
||||||
nfake.append(fake[i] + torch.randn_like(t) * self.noise)
|
nfake.append(fake[i] + torch.rand_like(t) * self.noise)
|
||||||
else:
|
else:
|
||||||
nreal.append(t)
|
nreal.append(t)
|
||||||
nfake.append(fake[i])
|
nfake.append(fake[i])
|
||||||
|
@ -282,6 +285,16 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
||||||
if torch.mean(self.loss_rotating_buffer) < self.min_loss:
|
if torch.mean(self.loss_rotating_buffer) < self.min_loss:
|
||||||
return 0
|
return 0
|
||||||
self.losses_computed += 1
|
self.losses_computed += 1
|
||||||
|
|
||||||
|
if self.gradient_penalty:
|
||||||
|
# Apply gradient penalty. TODO: migrate this elsewhere.
|
||||||
|
from models.archs.stylegan.stylegan2 import gradient_penalty
|
||||||
|
assert len(real) == 1 # Grad penalty doesn't currently support multi-input discriminators.
|
||||||
|
gp = gradient_penalty(real[0], d_real)
|
||||||
|
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
||||||
|
loss = loss + gp
|
||||||
|
self.metrics.append(("gradient_penalty", gp))
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_corrected_disc.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_normal.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -221,7 +221,7 @@ class Trainer:
|
||||||
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
||||||
util.mkdir(img_dir)
|
util.mkdir(img_dir)
|
||||||
|
|
||||||
self.model.feed_data(val_data)
|
self.model.feed_data(val_data, self.current_step)
|
||||||
self.model.test()
|
self.model.test()
|
||||||
|
|
||||||
visuals = self.model.get_current_visuals()
|
visuals = self.model.get_current_visuals()
|
||||||
|
@ -291,7 +291,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr_gen_real.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_grad_penalty.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user