Fix memory leak with recurrent loss

This commit is contained in:
James Betker 2020-10-18 10:22:10 -06:00
parent 552e70a032
commit c709d38cd5
2 changed files with 35 additions and 18 deletions

View File

@ -140,6 +140,7 @@ class GeneratorGanLoss(ConfigurableLoss):
# This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
# generators and discriminators by essentially having them skip steps while their counterparts "catch up".
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
if self.min_loss != 0:
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
self.rb_ptr = 0
self.losses_computed = 0
@ -172,6 +173,7 @@ class GeneratorGanLoss(ConfigurableLoss):
self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2
else:
raise NotImplementedError
if self.min_loss != 0:
self.loss_rotating_buffer[self.rb_ptr] = loss.item()
self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
if torch.mean(self.loss_rotating_buffer) < self.min_loss:
@ -190,6 +192,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
# This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
# generators and discriminators by essentially having them skip steps while their counterparts "catch up".
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
if self.min_loss != 0:
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
self.rb_ptr = 0
self.losses_computed = 0
@ -228,6 +231,7 @@ class DiscriminatorGanLoss(ConfigurableLoss):
self.criterion(d_fake_diff, False))
else:
raise NotImplementedError
if self.min_loss != 0:
self.loss_rotating_buffer[self.rb_ptr] = loss.item()
self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
if torch.mean(self.loss_rotating_buffer) < self.min_loss:
@ -397,6 +401,12 @@ class RecurrentLoss(ConfigurableLoss):
total_loss += self.loss(net, st)
return total_loss
def extra_metrics(self):
return self.loss.extra_metrics()
def clear_metrics(self):
self.loss.clear_metrics()
# Loss that pulls a tensor from dim 1 of the input and feeds it into a "sub" loss.
class ForElementLoss(ConfigurableLoss):
@ -414,3 +424,9 @@ class ForElementLoss(ConfigurableLoss):
st['_real'] = state[self.opt['real']][:, self.index]
st['_fake'] = state[self.opt['fake']][:, self.index]
return self.loss(net, st)
def extra_metrics(self):
return self.loss.extra_metrics()
def clear_metrics(self):
self.loss.clear_metrics()

View File

@ -64,6 +64,7 @@ class ProgressiveGeneratorInjector(Injector):
inputs = extract_params_from_state(self.input, state)
lq_inputs = inputs[self.input_lq_index]
hq_inputs = state[self.hq_key]
output = self.output
if not isinstance(inputs, list):
inputs = [inputs]
if not isinstance(self.output, list):