From 515905e904443b29b836b12947bc71355e715a47 Mon Sep 17 00:00:00 2001 From: James Betker Date: Wed, 28 Oct 2020 15:46:59 -0600 Subject: [PATCH] Add a min_loss that is DDP compatible --- codes/models/steps/steps.py | 12 ++++++++++++ codes/utils/loss_accumulator.py | 9 +++++++++ 2 files changed, 21 insertions(+) diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index d925361e..cb92749a 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -26,6 +26,7 @@ class ConfigurableStep(Module): self.optimizers = None self.scaler = GradScaler(enabled=self.opt['fp16']) self.grads_generated = False + self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else 0 self.injectors = [] if 'injectors' in self.step_opt.keys(): @@ -162,11 +163,22 @@ class ConfigurableStep(Module): # In some cases, the loss could not be set (e.g. all losses have 'after') if isinstance(total_loss, torch.Tensor): self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss) + reset_required = total_loss < self.min_total_loss + # Scale the loss down by the accumulation factor. total_loss = total_loss / self.env['mega_batch_factor'] # Get dem grads! self.scaler.scale(total_loss).backward() + + if reset_required: + # You might be scratching your head at this. Why would you zero grad as opposed to not doing a + # backwards? Because DDP uses the backward() pass as a synchronization point and there is not a good + # way to simply bypass backward. If you want a more efficient way to specify a min_loss, use or + # implement it at the loss level. + self.training_net.zero_grad() + self.loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),)) + self.grads_generated = True # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step diff --git a/codes/utils/loss_accumulator.py b/codes/utils/loss_accumulator.py index 8965cf3d..bfc7e772 100644 --- a/codes/utils/loss_accumulator.py +++ b/codes/utils/loss_accumulator.py @@ -5,6 +5,7 @@ class LossAccumulator: def __init__(self, buffer_sz=50): self.buffer_sz = buffer_sz self.buffers = {} + self.counters = {} def add_loss(self, name, tensor): if name not in self.buffers.keys(): @@ -18,6 +19,12 @@ class LossAccumulator: filled = i+1 >= self.buffer_sz or filled self.buffers[name] = ((i+1) % self.buffer_sz, buf, filled) + def increment_metric(self, name): + if name not in self.counters.keys(): + self.counters[name] = 1 + else: + self.counters[name] += 1 + def as_dict(self): result = {} for k, v in self.buffers.items(): @@ -26,4 +33,6 @@ class LossAccumulator: result["loss_" + k] = torch.mean(buf) else: result["loss_" + k] = torch.mean(buf[:i]) + for k, v in self.counters.items(): + result[k] = v return result \ No newline at end of file