Add more sophisticated mechanism for balancing GAN losses
parent 39865ca3df
commit dd9d7b27ac
@@ -125,6 +125,12 @@ class GeneratorGanLoss(ConfigurableLoss):
         self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
         self.noise = None if 'noise' not in opt.keys() else opt['noise']
         self.detach_real = opt['detach_real'] if 'detach_real' in opt.keys() else True
+        # This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
+        # generators and discriminators by essentially having them skip steps while their counterparts "catch up".
+        self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
+        self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
+        self.rb_ptr = 0
+        self.losses_computed = 0

     def forward(self, _, state):
         netD = self.env['discriminators'][self.opt['discriminator']]
@@ -144,16 +150,23 @@ class GeneratorGanLoss(ConfigurableLoss):
             fake = nfake
         if self.opt['gan_type'] in ['gan', 'pixgan', 'pixgan_fea']:
             pred_g_fake = netD(*fake)
-            return self.criterion(pred_g_fake, True)
+            loss = self.criterion(pred_g_fake, True)
         elif self.opt['gan_type'] == 'ragan':
             pred_d_real = netD(*real)
             if self.detach_real:
                 pred_d_real = pred_d_real.detach()
             pred_g_fake = netD(*fake)
-            return (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
-                    self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2
+            loss = (self.criterion(pred_d_real - torch.mean(pred_g_fake), False) +
+                    self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2
         else:
             raise NotImplementedError
+        self.loss_rotating_buffer[self.rb_ptr] = loss.item()
+        self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
+        if torch.mean(self.loss_rotating_buffer) < self.min_loss:
+            return 0
+        self.losses_computed += 1
+        self.metrics.append(("loss_counter", self.losses_computed))
+        return loss


 class DiscriminatorGanLoss(ConfigurableLoss):
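Taken together, the two hunks above implement a simple skip gate: every computed generator loss is written into a ten-slot ring buffer, and once the mean over that buffer falls below min_loss the loss function returns 0, so the generator contributes no gradient for that step while the discriminator keeps training. A minimal standalone sketch of the same logic follows; the LossGate class and its apply method are illustrative names only, not part of this commit:

    import torch

    class LossGate:
        # Standalone rendition of the gate; the commit inlines this logic
        # directly into GeneratorGanLoss and DiscriminatorGanLoss.
        def __init__(self, min_loss=0, buffer_size=10):
            self.min_loss = min_loss
            self.loss_rotating_buffer = torch.zeros(buffer_size, requires_grad=False)
            self.rb_ptr = 0
            self.losses_computed = 0

        def apply(self, loss):
            # Record the newest loss value in the ring buffer.
            self.loss_rotating_buffer[self.rb_ptr] = loss.item()
            self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
            # If the running mean of recent losses is under the floor, return 0
            # so this step contributes no gradient and the counterpart catches up.
            if torch.mean(self.loss_rotating_buffer) < self.min_loss:
                return 0
            self.losses_computed += 1
            return loss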
@@ -162,6 +175,12 @@ class DiscriminatorGanLoss(ConfigurableLoss):
         self.opt = opt
         self.criterion = GANLoss(opt['gan_type'], 1.0, 0.0).to(env['device'])
         self.noise = None if 'noise' not in opt.keys() else opt['noise']
+        # This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
+        # generators and discriminators by essentially having them skip steps while their counterparts "catch up".
+        self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
+        self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
+        self.rb_ptr = 0
+        self.losses_computed = 0

     def forward(self, net, state):
         self.metrics = []
@@ -190,14 +209,21 @@ class DiscriminatorGanLoss(ConfigurableLoss):
             l_real = self.criterion(d_real, True)
             l_fake = self.criterion(d_fake, False)
             l_total = l_real + l_fake
-            return l_total
+            loss = l_total
         elif self.opt['gan_type'] == 'ragan':
             d_fake_diff = d_fake - torch.mean(d_real)
             self.metrics.append(("d_fake_diff", torch.mean(d_fake_diff)))
-            return (self.criterion(d_real - torch.mean(d_fake), True) +
-                    self.criterion(d_fake_diff, False))
+            loss = (self.criterion(d_real - torch.mean(d_fake), True) +
+                    self.criterion(d_fake_diff, False))
         else:
             raise NotImplementedError
+        self.loss_rotating_buffer[self.rb_ptr] = loss.item()
+        self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
+        if torch.mean(self.loss_rotating_buffer) < self.min_loss:
+            return 0
+        self.losses_computed += 1
+        self.metrics.append(("loss_counter", self.losses_computed))
+        return loss


 # Computes a loss created by comparing the output of a generator to the output from the same generator when fed an
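Both losses read the threshold from their options dict and default to 0, which leaves the gate inert for the usual non-negative GAN criteria. Assuming the existing options format, enabling it might look like the following; only min_loss is new in this commit, and the discriminator name is a made-up placeholder:

    # Hypothetical options for GeneratorGanLoss, for illustration only.
    gen_gan_opt = {
        'gan_type': 'ragan',
        'discriminator': 'base_discriminator',  # placeholder name, not from the diff
        'min_loss': 0.05,  # skip backprop while the 10-step mean loss is below this
    }

Note that the ring buffer starts zero-filled, so with a positive min_loss the running mean sits below the floor for the first few iterations and those steps are gated off until the buffer warms up.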
@@ -10,7 +10,11 @@ class LossAccumulator:
         if name not in self.buffers.keys():
             self.buffers[name] = (0, torch.zeros(self.buffer_sz))
         i, buf = self.buffers[name]
-        buf[i] = tensor.detach().cpu()
+        # Can take tensors or just plain python numbers.
+        if isinstance(tensor, torch.Tensor):
+            buf[i] = tensor.detach().cpu()
+        else:
+            buf[i] = tensor
         self.buffers[name] = ((i+1) % self.buffer_sz, buf)

     def as_dict(self):
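This change dovetails with the gate above: a skipped GAN loss now reports the plain Python int 0 rather than a tensor, and calling .detach() on that would fail. A self-contained sketch of the two storage paths (the buffer and values here are made up for illustration):

    import torch

    buf = torch.zeros(10)
    for i, value in enumerate([torch.tensor(0.42), 0]):
        if isinstance(value, torch.Tensor):
            # Normal step: detach from the graph and move to cpu before storing.
            buf[i] = value.detach().cpu()
        else:
            # Gated-off step: the plain number is stored directly.
            buf[i] = value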