Fix memory leak with recurrent loss

James Betker 2020-10-18 10:22:10 -06:00
parent 552e70a032
commit c709d38cd5
2 changed files with 35 additions and 18 deletions

View File

@@ -140,9 +140,10 @@ class GeneratorGanLoss(ConfigurableLoss):
         # This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
         # generators and discriminators by essentially having them skip steps while their counterparts "catch up".
         self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
-        self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
-        self.rb_ptr = 0
-        self.losses_computed = 0
+        if self.min_loss != 0:
+            self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
+            self.rb_ptr = 0
+            self.losses_computed = 0
 
     def forward(self, _, state):
         netD = self.env['discriminators'][self.opt['discriminator']]
@@ -172,12 +173,13 @@ class GeneratorGanLoss(ConfigurableLoss):
                     self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2
         else:
             raise NotImplementedError
-        self.loss_rotating_buffer[self.rb_ptr] = loss.item()
-        self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
-        if torch.mean(self.loss_rotating_buffer) < self.min_loss:
-            return 0
-        self.losses_computed += 1
-        self.metrics.append(("loss_counter", self.losses_computed))
+        if self.min_loss != 0:
+            self.loss_rotating_buffer[self.rb_ptr] = loss.item()
+            self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
+            if torch.mean(self.loss_rotating_buffer) < self.min_loss:
+                return 0
+            self.losses_computed += 1
+            self.metrics.append(("loss_counter", self.losses_computed))
         return loss
@@ -190,9 +192,10 @@ class DiscriminatorGanLoss(ConfigurableLoss):
         # This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
         # generators and discriminators by essentially having them skip steps while their counterparts "catch up".
         self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
-        self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
-        self.rb_ptr = 0
-        self.losses_computed = 0
+        if self.min_loss != 0:
+            self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
+            self.rb_ptr = 0
+            self.losses_computed = 0
 
     def forward(self, net, state):
         real = extract_params_from_state(self.opt['real'], state)
@@ -228,12 +231,13 @@ class DiscriminatorGanLoss(ConfigurableLoss):
                     self.criterion(d_fake_diff, False))
         else:
             raise NotImplementedError
-        self.loss_rotating_buffer[self.rb_ptr] = loss.item()
-        self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
-        if torch.mean(self.loss_rotating_buffer) < self.min_loss:
-            return 0
-        self.losses_computed += 1
-        self.metrics.append(("loss_counter", self.losses_computed))
+        if self.min_loss != 0:
+            self.loss_rotating_buffer[self.rb_ptr] = loss.item()
+            self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
+            if torch.mean(self.loss_rotating_buffer) < self.min_loss:
+                return 0
+            self.losses_computed += 1
+            self.metrics.append(("loss_counter", self.losses_computed))
         return loss
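
In both loss classes above, min_loss drives a "skip step" mechanism: a small rotating buffer holds the last few scalar loss values (detached via .item()), and whenever their running mean drops below min_loss the loss returns 0 so that step contributes no gradient. The commit gates the buffer, the pointer, and the loss_counter metric behind min_loss != 0, so none of that state is created or appended to when the option is disabled. A minimal standalone sketch of the same pattern (the class and parameter names here are illustrative, not part of the repository's API):

import torch

class MinLossGate:
    """Skips backprop when the recent mean loss falls below a floor."""
    def __init__(self, min_loss=0.0, buffer_size=10):
        self.min_loss = min_loss
        if self.min_loss != 0:
            # Store detached scalars only; keeping whole loss tensors here
            # would keep their autograd graphs alive.
            self.buffer = torch.zeros(buffer_size, requires_grad=False)
            self.ptr = 0

    def apply(self, loss):
        if self.min_loss == 0:
            return loss
        self.buffer[self.ptr] = loss.item()
        self.ptr = (self.ptr + 1) % self.buffer.shape[0]
        if torch.mean(self.buffer) < self.min_loss:
            return 0  # skip this step so the counterpart network can catch up
        return loss
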
@@ -397,6 +401,12 @@ class RecurrentLoss(ConfigurableLoss):
                 total_loss += self.loss(net, st)
         return total_loss
+
+    def extra_metrics(self):
+        return self.loss.extra_metrics()
+
+    def clear_metrics(self):
+        self.loss.clear_metrics()
 
 
 # Loss that pulls a tensor from dim 1 of the input and feeds it into a "sub" loss.
 class ForElementLoss(ConfigurableLoss):
@@ -414,3 +424,9 @@ class ForElementLoss(ConfigurableLoss):
         st['_real'] = state[self.opt['real']][:, self.index]
         st['_fake'] = state[self.opt['fake']][:, self.index]
         return self.loss(net, st)
+
+    def extra_metrics(self):
+        return self.loss.extra_metrics()
+
+    def clear_metrics(self):
+        self.loss.clear_metrics()
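
The new extra_metrics()/clear_metrics() passthroughs are the likely source of the "memory leak" named in the commit title: RecurrentLoss and ForElementLoss wrap another ConfigurableLoss whose forward() appends to its own metrics list, and if only the wrapper is ever cleared, the inner list grows by one entry per step. A rough sketch of the wrapper pattern (the ConfigurableLoss details below are simplified assumptions, not the exact base class):

class ConfigurableLoss:
    def __init__(self):
        self.metrics = []

    def extra_metrics(self):
        return self.metrics

    def clear_metrics(self):
        # Called by the training loop after metrics have been logged.
        self.metrics = []

class RecurrentLossSketch(ConfigurableLoss):
    def __init__(self, sub_loss):
        super().__init__()
        self.loss = sub_loss  # the inner loss accumulates its own metrics

    # Without these two passthroughs, self.loss.metrics is never cleared
    # and grows without bound over the course of training.
    def extra_metrics(self):
        return self.loss.extra_metrics()

    def clear_metrics(self):
        self.loss.clear_metrics()
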

View File

@@ -64,6 +64,7 @@ class ProgressiveGeneratorInjector(Injector):
         inputs = extract_params_from_state(self.input, state)
         lq_inputs = inputs[self.input_lq_index]
         hq_inputs = state[self.hq_key]
+        output = self.output
         if not isinstance(inputs, list):
             inputs = [inputs]
         if not isinstance(self.output, list):