Add a min_loss that is DDP compatible
This commit is contained in:
parent
f133243ac8
commit
515905e904
|
@ -26,6 +26,7 @@ class ConfigurableStep(Module):
|
|||
self.optimizers = None
|
||||
self.scaler = GradScaler(enabled=self.opt['fp16'])
|
||||
self.grads_generated = False
|
||||
self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else 0
|
||||
|
||||
self.injectors = []
|
||||
if 'injectors' in self.step_opt.keys():
|
||||
|
@ -162,11 +163,22 @@ class ConfigurableStep(Module):
|
|||
# In some cases, the loss could not be set (e.g. all losses have 'after')
|
||||
if isinstance(total_loss, torch.Tensor):
|
||||
self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
|
||||
reset_required = total_loss < self.min_total_loss
|
||||
|
||||
# Scale the loss down by the accumulation factor.
|
||||
total_loss = total_loss / self.env['mega_batch_factor']
|
||||
|
||||
# Get dem grads!
|
||||
self.scaler.scale(total_loss).backward()
|
||||
|
||||
if reset_required:
|
||||
# You might be scratching your head at this. Why would you zero grad as opposed to not doing a
|
||||
# backwards? Because DDP uses the backward() pass as a synchronization point and there is not a good
|
||||
# way to simply bypass backward. If you want a more efficient way to specify a min_loss, use or
|
||||
# implement it at the loss level.
|
||||
self.training_net.zero_grad()
|
||||
self.loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))
|
||||
|
||||
self.grads_generated = True
|
||||
|
||||
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
|
||||
|
|
|
@ -5,6 +5,7 @@ class LossAccumulator:
|
|||
def __init__(self, buffer_sz=50):
|
||||
self.buffer_sz = buffer_sz
|
||||
self.buffers = {}
|
||||
self.counters = {}
|
||||
|
||||
def add_loss(self, name, tensor):
|
||||
if name not in self.buffers.keys():
|
||||
|
@ -18,6 +19,12 @@ class LossAccumulator:
|
|||
filled = i+1 >= self.buffer_sz or filled
|
||||
self.buffers[name] = ((i+1) % self.buffer_sz, buf, filled)
|
||||
|
||||
def increment_metric(self, name):
|
||||
if name not in self.counters.keys():
|
||||
self.counters[name] = 1
|
||||
else:
|
||||
self.counters[name] += 1
|
||||
|
||||
def as_dict(self):
|
||||
result = {}
|
||||
for k, v in self.buffers.items():
|
||||
|
@ -26,4 +33,6 @@ class LossAccumulator:
|
|||
result["loss_" + k] = torch.mean(buf)
|
||||
else:
|
||||
result["loss_" + k] = torch.mean(buf[:i])
|
||||
for k, v in self.counters.items():
|
||||
result[k] = v
|
||||
return result
|
Loading…
Reference in New Issue
Block a user