forked from mrq/DL-Art-School
Fix memory leak with recurrent loss
This commit is contained in:
parent
552e70a032
commit
c709d38cd5
|
@ -140,9 +140,10 @@ class GeneratorGanLoss(ConfigurableLoss):
|
||||||
# This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
|
# This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
|
||||||
# generators and discriminators by essentially having them skip steps while their counterparts "catch up".
|
# generators and discriminators by essentially having them skip steps while their counterparts "catch up".
|
||||||
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
|
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
|
||||||
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
|
if self.min_loss != 0:
|
||||||
self.rb_ptr = 0
|
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
|
||||||
self.losses_computed = 0
|
self.rb_ptr = 0
|
||||||
|
self.losses_computed = 0
|
||||||
|
|
||||||
def forward(self, _, state):
|
def forward(self, _, state):
|
||||||
netD = self.env['discriminators'][self.opt['discriminator']]
|
netD = self.env['discriminators'][self.opt['discriminator']]
|
||||||
|
@ -172,12 +173,13 @@ class GeneratorGanLoss(ConfigurableLoss):
|
||||||
self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
self.criterion(pred_g_fake - torch.mean(pred_d_real), True)) / 2
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
self.loss_rotating_buffer[self.rb_ptr] = loss.item()
|
if self.min_loss != 0:
|
||||||
self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
|
self.loss_rotating_buffer[self.rb_ptr] = loss.item()
|
||||||
if torch.mean(self.loss_rotating_buffer) < self.min_loss:
|
self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
|
||||||
return 0
|
if torch.mean(self.loss_rotating_buffer) < self.min_loss:
|
||||||
self.losses_computed += 1
|
return 0
|
||||||
self.metrics.append(("loss_counter", self.losses_computed))
|
self.losses_computed += 1
|
||||||
|
self.metrics.append(("loss_counter", self.losses_computed))
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
@ -190,9 +192,10 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
||||||
# This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
|
# This is a mechanism to prevent backpropagation for a GAN loss if it goes too low. This can be used to balance
|
||||||
# generators and discriminators by essentially having them skip steps while their counterparts "catch up".
|
# generators and discriminators by essentially having them skip steps while their counterparts "catch up".
|
||||||
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
|
self.min_loss = opt['min_loss'] if 'min_loss' in opt.keys() else 0
|
||||||
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
|
if self.min_loss != 0:
|
||||||
self.rb_ptr = 0
|
self.loss_rotating_buffer = torch.zeros(10, requires_grad=False)
|
||||||
self.losses_computed = 0
|
self.rb_ptr = 0
|
||||||
|
self.losses_computed = 0
|
||||||
|
|
||||||
def forward(self, net, state):
|
def forward(self, net, state):
|
||||||
real = extract_params_from_state(self.opt['real'], state)
|
real = extract_params_from_state(self.opt['real'], state)
|
||||||
|
@ -228,12 +231,13 @@ class DiscriminatorGanLoss(ConfigurableLoss):
|
||||||
self.criterion(d_fake_diff, False))
|
self.criterion(d_fake_diff, False))
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
self.loss_rotating_buffer[self.rb_ptr] = loss.item()
|
if self.min_loss != 0:
|
||||||
self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
|
self.loss_rotating_buffer[self.rb_ptr] = loss.item()
|
||||||
if torch.mean(self.loss_rotating_buffer) < self.min_loss:
|
self.rb_ptr = (self.rb_ptr + 1) % self.loss_rotating_buffer.shape[0]
|
||||||
return 0
|
if torch.mean(self.loss_rotating_buffer) < self.min_loss:
|
||||||
self.losses_computed += 1
|
return 0
|
||||||
self.metrics.append(("loss_counter", self.losses_computed))
|
self.losses_computed += 1
|
||||||
|
self.metrics.append(("loss_counter", self.losses_computed))
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
@ -397,6 +401,12 @@ class RecurrentLoss(ConfigurableLoss):
|
||||||
total_loss += self.loss(net, st)
|
total_loss += self.loss(net, st)
|
||||||
return total_loss
|
return total_loss
|
||||||
|
|
||||||
|
def extra_metrics(self):
|
||||||
|
return self.loss.extra_metrics()
|
||||||
|
|
||||||
|
def clear_metrics(self):
|
||||||
|
self.loss.clear_metrics()
|
||||||
|
|
||||||
|
|
||||||
# Loss that pulls a tensor from dim 1 of the input and feeds it into a "sub" loss.
|
# Loss that pulls a tensor from dim 1 of the input and feeds it into a "sub" loss.
|
||||||
class ForElementLoss(ConfigurableLoss):
|
class ForElementLoss(ConfigurableLoss):
|
||||||
|
@ -414,3 +424,9 @@ class ForElementLoss(ConfigurableLoss):
|
||||||
st['_real'] = state[self.opt['real']][:, self.index]
|
st['_real'] = state[self.opt['real']][:, self.index]
|
||||||
st['_fake'] = state[self.opt['fake']][:, self.index]
|
st['_fake'] = state[self.opt['fake']][:, self.index]
|
||||||
return self.loss(net, st)
|
return self.loss(net, st)
|
||||||
|
|
||||||
|
def extra_metrics(self):
|
||||||
|
return self.loss.extra_metrics()
|
||||||
|
|
||||||
|
def clear_metrics(self):
|
||||||
|
self.loss.clear_metrics()
|
||||||
|
|
|
@ -64,6 +64,7 @@ class ProgressiveGeneratorInjector(Injector):
|
||||||
inputs = extract_params_from_state(self.input, state)
|
inputs = extract_params_from_state(self.input, state)
|
||||||
lq_inputs = inputs[self.input_lq_index]
|
lq_inputs = inputs[self.input_lq_index]
|
||||||
hq_inputs = state[self.hq_key]
|
hq_inputs = state[self.hq_key]
|
||||||
|
output = self.output
|
||||||
if not isinstance(inputs, list):
|
if not isinstance(inputs, list):
|
||||||
inputs = [inputs]
|
inputs = [inputs]
|
||||||
if not isinstance(self.output, list):
|
if not isinstance(self.output, list):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user