forked from mrq/DL-Art-School
Fix memory leak with recurrent loss
This commit is contained in:
parent
552e70a032
commit
c709d38cd5
|
@ -140,9 +140,10 @@ class GeneratorGanLoss(ConfigurableLoss):
|
|||
# This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
|
||||
# generators and discriminators by essentially having them skip steps while their counterparts "catch up".
|
||||
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
|
||||
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
|
||||
self.rb_ptr = 0
|
||||
self.losses_computed = 0
|
||||
if self.min_loss != 0:
|
||||
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
|
||||
self.rb_ptr = 0
|
||||
self.losses_computed = 0
|
||||
|
||||
def forward(self, _, state):
|
||||
netD = self.env['discriminators'][self.opt['discriminator']]
|
||||
|
@ -172,12 +173,13 @@ class GeneratorGanLoss(ConfigurableLoss):
|
|||
self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.loss_rotating_buffer[self.rb_ptr] = loss.item()
|
||||
self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
|
||||
if torch.mean(self.loss_rotating_buffer) < self.min_loss:
|
||||
return 0
|
||||
self.losses_computed += 1
|
||||
self.metrics.append(("loss_counter", self.losses_computed))
|
||||
if self.min_loss != 0:
|
||||
self.loss_rotating_buffer[self.rb_ptr] = loss.item()
|
||||
self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
|
||||
if torch.mean(self.loss_rotating_buffer) < self.min_loss:
|
||||
return 0
|
||||
self.losses_computed += 1
|
||||
self.metrics.append(("loss_counter", self.losses_computed))
|
||||
return loss
|
||||
|
||||
|
||||
|
@ -190,9 +192,10 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
# This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
|
||||
# generators and discriminators by essentially having them skip steps while their counterparts "catch up".
|
||||
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
|
||||
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
|
||||
self.rb_ptr = 0
|
||||
self.losses_computed = 0
|
||||
if self.min_loss != 0:
|
||||
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
|
||||
self.rb_ptr = 0
|
||||
self.losses_computed = 0
|
||||
|
||||
def forward(self, net, state):
|
||||
real = extract_params_from_state(self.opt['real'], state)
|
||||
|
@ -228,12 +231,13 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
|||
self.criterion(d_fake_diff, False))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.loss_rotating_buffer[self.rb_ptr] = loss.item()
|
||||
self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
|
||||
if torch.mean(self.loss_rotating_buffer) < self.min_loss:
|
||||
return 0
|
||||
self.losses_computed += 1
|
||||
self.metrics.append(("loss_counter", self.losses_computed))
|
||||
if self.min_loss != 0:
|
||||
self.loss_rotating_buffer[self.rb_ptr] = loss.item()
|
||||
self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
|
||||
if torch.mean(self.loss_rotating_buffer) < self.min_loss:
|
||||
return 0
|
||||
self.losses_computed += 1
|
||||
self.metrics.append(("loss_counter", self.losses_computed))
|
||||
return loss
|
||||
|
||||
|
||||
|
@ -397,6 +401,12 @@ class RecurrentLoss(ConfigurableLoss):
|
|||
total_loss += self.loss(net, st)
|
||||
return total_loss
|
||||
|
||||
def extra_metrics(self):
|
||||
return self.loss.extra_metrics()
|
||||
|
||||
def clear_metrics(self):
|
||||
self.loss.clear_metrics()
|
||||
|
||||
|
||||
# Loss that pulls a tensor from dim 1 of the input and feeds it into a "sub" loss.
|
||||
class ForElementLoss(ConfigurableLoss):
|
||||
|
@ -414,3 +424,9 @@ class ForElementLoss(ConfigurableLoss):
|
|||
st['_real'] = state[self.opt['real']][:, self.index]
|
||||
st['_fake'] = state[self.opt['fake']][:, self.index]
|
||||
return self.loss(net, st)
|
||||
|
||||
def extra_metrics(self):
|
||||
return self.loss.extra_metrics()
|
||||
|
||||
def clear_metrics(self):
|
||||
self.loss.clear_metrics()
|
||||
|
|
|
@ -64,6 +64,7 @@ class ProgressiveGeneratorInjector(Injector):
|
|||
inputs = extract_params_from_state(self.input, state)
|
||||
lq_inputs = inputs[self.input_lq_index]
|
||||
hq_inputs = state[self.hq_key]
|
||||
output = self.output
|
||||
if not isinstance(inputs, list):
|
||||
inputs = [inputs]
|
||||
if not isinstance(self.output, list):
|
||||
|
|
Loading…
Reference in New Issue
Block a user