Add a min_loss that is DDP-compatible

James Betker 2020-10-28 15:46:59 -06:00
parent f133243ac8
commit 515905e904
2 changed files with 21 additions and 0 deletions


@@ -26,6 +26,7 @@ class ConfigurableStep(Module):
         self.optimizers = None
         self.scaler = GradScaler(enabled=self.opt['fp16'])
         self.grads_generated = False
+        self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else 0
         self.injectors = []
         if 'injectors' in self.step_opt.keys():
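
The new option is read directly from the step's options dict, so enabling it is just a matter of adding the key. A minimal sketch of a step configuration (hypothetical values, not from this commit):

    opt_step = {
        'min_total_loss': 0.05,  # discard gradients for steps whose total loss falls below this
        # ... the rest of the step options ...
    }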
@@ -162,11 +163,22 @@ class ConfigurableStep(Module):
         # In some cases, the loss could not be set (e.g. all losses have 'after')
         if isinstance(total_loss, torch.Tensor):
             self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
+            reset_required = total_loss < self.min_total_loss
             # Scale the loss down by the accumulation factor.
             total_loss = total_loss / self.env['mega_batch_factor']
             # Get dem grads!
             self.scaler.scale(total_loss).backward()
+            if reset_required:
+                # You might be scratching your head at this. Why zero the gradients rather than just
+                # skip backward()? Because DDP uses the backward() pass as a synchronization point,
+                # and there is no good way to bypass it. If you want a more efficient way to enforce
+                # a min_loss, implement it at the loss level.
+                self.training_net.zero_grad()
+                self.loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))
             self.grads_generated = True
         # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
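
For context, here is the same pattern in isolation. This is a minimal sketch, not code from this repo: train_step, loss_fn, and batch are hypothetical names, and model is assumed to already be wrapped in DistributedDataParallel. The key point is that backward() runs unconditionally on every rank, because DDP's gradient all-reduce hooks fire during backward and a rank that skipped it would hang the others; a sub-threshold loss is handled by discarding the gradients afterward.

    def train_step(model, optimizer, loss_fn, batch, min_total_loss=0.0):
        out = model(batch['input'])
        loss = loss_fn(out, batch['target'])
        # Always call backward(): DDP synchronizes gradients here, so every
        # rank must participate even if this step will be discarded.
        loss.backward()
        if loss.item() < min_total_loss:
            # Throw the (already synchronized) gradients away instead of stepping.
            optimizer.zero_grad()
            return False  # step skipped
        optimizer.step()
        optimizer.zero_grad()
        return True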


@@ -5,6 +5,7 @@ class LossAccumulator:
     def __init__(self, buffer_sz=50):
         self.buffer_sz = buffer_sz
         self.buffers = {}
+        self.counters = {}
 
     def add_loss(self, name, tensor):
         if name not in self.buffers.keys():
@@ -18,6 +19,12 @@ class LossAccumulator:
             filled = i+1 >= self.buffer_sz or filled
         self.buffers[name] = ((i+1) % self.buffer_sz, buf, filled)
 
+    def increment_metric(self, name):
+        if name not in self.counters.keys():
+            self.counters[name] = 1
+        else:
+            self.counters[name] += 1
+
     def as_dict(self):
         result = {}
         for k, v in self.buffers.items():
@@ -26,4 +33,6 @@ class LossAccumulator:
                 result["loss_" + k] = torch.mean(buf)
             else:
                 result["loss_" + k] = torch.mean(buf[:i])
+        for k, v in self.counters.items():
+            result[k] = v
         return result
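
A quick usage sketch of the new counter path alongside the existing loss buffers. It assumes add_loss accepts scalar tensors, as the torch.mean() calls above suggest, and the import path is hypothetical:

    import torch
    from loss_accumulator import LossAccumulator  # hypothetical module path

    acc = LossAccumulator(buffer_sz=4)
    acc.add_loss("gen_total", torch.tensor(0.25))
    acc.add_loss("gen_total", torch.tensor(0.35))
    acc.increment_metric("gen_skipped_steps")
    acc.increment_metric("gen_skipped_steps")

    # Buffered losses come back under a "loss_" prefix as a rolling mean;
    # counters are reported verbatim under their own names.
    print(acc.as_dict())  # {'loss_gen_total': tensor(0.3000), 'gen_skipped_steps': 2}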