forked from mrq/DL-Art-School
Loss accumulator fix
This commit is contained in:
parent
567b4d50a4
commit
4e44fcd655
|
@ -155,12 +155,13 @@ class ConfigurableStep(Module):
|
|||
self.loss_accumulator.add_loss(loss_name, l)
|
||||
for n, v in loss.extra_metrics():
|
||||
self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
|
||||
self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
|
||||
# Scale the loss down by the accumulation factor.
|
||||
total_loss = total_loss / self.env['mega_batch_factor']
|
||||
|
||||
|
||||
# In some cases, the loss could not be set (e.g. all losses have 'after'
|
||||
if isinstance(loss, torch.Tensor):
|
||||
self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
|
||||
# Scale the loss down by the accumulation factor.
|
||||
total_loss = total_loss / self.env['mega_batch_factor']
|
||||
|
||||
# Get dem grads!
|
||||
if self.env['amp']:
|
||||
with amp.scale_loss(total_loss, self.optimizers, amp_loss_id) as scaled_loss:
|
||||
|
|
Loading…
Reference in New Issue
Block a user