Add a min_loss that is DDP compatible
parent f133243ac8
commit 515905e904
@@ -26,6 +26,7 @@ class ConfigurableStep(Module):
         self.optimizers = None
         self.scaler = GradScaler(enabled=self.opt['fp16'])
         self.grads_generated = False
+        self.min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else 0

         self.injectors = []
         if 'injectors' in self.step_opt.keys():
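A hedged sketch of the new option's semantics (the dict and threshold value below are illustrative, not taken from the repo's configs). Note that with no 'min_total_loss' key the threshold defaults to 0, so ordinary positive losses never trigger a reset:

    # Illustrative only: mirrors the lookup added above.
    opt_step = {'min_total_loss': 0.1}  # hypothetical config value
    min_total_loss = opt_step['min_total_loss'] if 'min_total_loss' in opt_step.keys() else 0

    for total_loss in [0.5, 0.05, 0.3]:
        reset_required = total_loss < min_total_loss
        print(total_loss, 'skipped' if reset_required else 'applied')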
@@ -162,11 +163,22 @@ class ConfigurableStep(Module):
         # In some cases, the loss could not be set (e.g. all losses have 'after')
         if isinstance(total_loss, torch.Tensor):
             self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
+            reset_required = total_loss < self.min_total_loss
+
             # Scale the loss down by the accumulation factor.
             total_loss = total_loss / self.env['mega_batch_factor']

             # Get dem grads!
             self.scaler.scale(total_loss).backward()
+
+            if reset_required:
+                # You might be scratching your head at this. Why would you zero grad as opposed to not doing a
+                # backwards? Because DDP uses the backward() pass as a synchronization point and there is not a good
+                # way to simply bypass backward. If you want a more efficient way to specify a min_loss, use or
+                # implement it at the loss level.
+                self.training_net.zero_grad()
+                self.loss_accumulator.increment_metric("%s_skipped_steps" % (self.get_training_network_name(),))
+
             self.grads_generated = True

         # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
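Why zero_grad() instead of skipping backward(): DistributedDataParallel performs its gradient all-reduce inside backward(), so every rank must reach that call or the ranks desynchronize. A minimal single-process sketch of the skip mechanism (the network, optimizer, and threshold are illustrative; a real run would wrap the net in DistributedDataParallel):

    import torch

    net = torch.nn.Linear(4, 1)
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    min_total_loss = 0.5  # hypothetical threshold

    x = torch.randn(8, 4)
    total_loss = net(x).pow(2).mean()
    reset_required = total_loss < min_total_loss

    total_loss.backward()   # always runs; under DDP this is the sync point
    if reset_required:
        net.zero_grad()     # discard the grads rather than skipping backward()
    opt.step()              # with zeroed grads this leaves the weights unchanged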
@@ -5,6 +5,7 @@ class LossAccumulator:
     def __init__(self, buffer_sz=50):
         self.buffer_sz = buffer_sz
         self.buffers = {}
+        self.counters = {}

     def add_loss(self, name, tensor):
         if name not in self.buffers.keys():
@@ -18,6 +19,12 @@ class LossAccumulator:
         filled = i+1 >= self.buffer_sz or filled
         self.buffers[name] = ((i+1) % self.buffer_sz, buf, filled)

+    def increment_metric(self, name):
+        if name not in self.counters.keys():
+            self.counters[name] = 1
+        else:
+            self.counters[name] += 1
+
     def as_dict(self):
         result = {}
         for k, v in self.buffers.items():
@@ -26,4 +33,6 @@ class LossAccumulator:
                 result["loss_" + k] = torch.mean(buf)
             else:
                 result["loss_" + k] = torch.mean(buf[:i])
+        for k, v in self.counters.items():
+            result[k] = v
         return result
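A hedged usage sketch of the counter API added above (assumes LossAccumulator is in scope; the metric names and loss values are illustrative). increment_metric() tallies integer events and as_dict() now reports them alongside the smoothed losses, which is how the "%s_skipped_steps" metric from the first file surfaces in logging:

    import torch

    acc = LossAccumulator(buffer_sz=4)
    acc.add_loss('gen_total', torch.tensor(0.7))
    acc.increment_metric('gen_skipped_steps')
    acc.increment_metric('gen_skipped_steps')
    print(acc.as_dict())
    # -> {'loss_gen_total': tensor(0.7000), 'gen_skipped_steps': 2}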