Make grad_penalty available to classical discs
This commit is contained in:
parent
8a19c9ae15
commit
6b679e2b51
|
@ -145,7 +145,7 @@ class RRDBNet(nn.Module):
|
|||
body_block=RRDB,
|
||||
blocks_per_checkpoint=4,
|
||||
scale=4,
|
||||
additive_mode="not_additive" # Options: "not_additive", "additive", "additive_enforced"
|
||||
additive_mode="not_additive" # Options: "not", "additive", "additive_enforced"
|
||||
):
|
||||
super(RRDBNet, self).__init__()
|
||||
self.num_blocks = num_blocks
|
||||
|
|
|
@ -47,7 +47,7 @@ def define_G(opt, net_key='network_G', scale=None):
|
|||
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], additive_mode=additive_mode)
|
||||
elif which_model == 'RRDBNetBypass':
|
||||
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not_additive'
|
||||
additive_mode = opt_net['additive_mode'] if 'additive_mode' in opt_net.keys() else 'not'
|
||||
netG = RRDBNet_arch.RRDBNet(in_channels=opt_net['in_nc'], out_channels=opt_net['out_nc'],
|
||||
mid_channels=opt_net['nf'], num_blocks=opt_net['nb'], body_block=RRDBNet_arch.RRDBWithBypass,
|
||||
blocks_per_checkpoint=opt_net['blocks_per_checkpoint'], scale=opt_net['scale'],
|
||||
|
|
|
@ -192,8 +192,8 @@ class GeneratorGanLoss(ConfigurableLoss):
|
|||
nfake = []
|
||||
for i, t in enumerate(real):
|
||||
if isinstance(t, torch.Tensor):
|
||||
nreal.append(t + torch.randn_like(t) * self.noise)
|
||||
nfake.append(fake[i] + torch.randn_like(t) * self.noise)
|
||||
nreal.append(t + torch.rand_like(t) * self.noise)
|
||||
nfake.append(fake[i] + torch.rand_like(t) * self.noise)
|
||||
else:
|
||||
nreal.append(t)
|
||||
nfake.append(fake[i])
|
||||
|
@ -234,6 +234,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
# This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
|
||||
# generators and discriminators by essentially having them skip steps while their counterparts "catch up".
|
||||
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
|
||||
self.gradient_penalty = opt['gradient_penalty'] if 'gradient_penalty' in opt.keys() else False
|
||||
if self.min_loss != 0:
|
||||
assert not self.env['dist'] # distributed training does not support 'min_loss' - it can result in backward() desync by design.
|
||||
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
|
||||
|
@ -243,6 +244,8 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
def forward(self, net, state):
|
||||
real = extract_params_from_state(self.opt['real'], state)
|
||||
real = [r.detach() for r in real]
|
||||
if self.gradient_penalty:
|
||||
[r.requires_grad_() for r in real]
|
||||
fake = extract_params_from_state(self.opt['fake'], state)
|
||||
fake = [f.detach() for f in fake]
|
||||
if self.noise:
|
||||
|
@ -250,8 +253,8 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
nfake = []
|
||||
for i, t in enumerate(real):
|
||||
if isinstance(t, torch.Tensor):
|
||||
nreal.append(t + torch.randn_like(t) * self.noise)
|
||||
nfake.append(fake[i] + torch.randn_like(t) * self.noise)
|
||||
nreal.append(t + torch.rand_like(t) * self.noise)
|
||||
nfake.append(fake[i] + torch.rand_like(t) * self.noise)
|
||||
else:
|
||||
nreal.append(t)
|
||||
nfake.append(fake[i])
|
||||
|
@ -282,6 +285,16 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
if torch.mean(self.loss_rotating_buffer) < self.min_loss:
|
||||
return 0
|
||||
self.losses_computed += 1
|
||||
|
||||
if self.gradient_penalty:
|
||||
# Apply gradient penalty. TODO: migrate this elsewhere.
|
||||
from models.archs.stylegan.stylegan2 import gradient_penalty
|
||||
assert len(real) == 1 # Grad penalty doesn't currently support multi-input discriminators.
|
||||
gp = gradient_penalty(real[0], d_real)
|
||||
self.metrics.append(("gradient_penalty", gp.clone().detach()))
|
||||
loss = loss + gp
|
||||
self.metrics.append(("gradient_penalty", gp))
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
|
|
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_corrected_disc.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_normal.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -221,7 +221,7 @@ class Trainer:
|
|||
img_dir = os.path.join(opt['path']['val_images'], img_name)
|
||||
util.mkdir(img_dir)
|
||||
|
||||
self.model.feed_data(val_data)
|
||||
self.model.feed_data(val_data, self.current_step)
|
||||
self.model.test()
|
||||
|
||||
visuals = self.model.get_current_visuals()
|
||||
|
@ -291,7 +291,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_stylegan2_for_sr_gen_real.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_mi1_rrdb4x_6bl_grad_penalty.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user