Attempt to fix syncing multiple times when doing gradient accumulation

James Betker 2021-06-13 14:30:30 -06:00
parent 1cd75dfd33
commit 9cfe840872
2 changed files with 8 additions and 3 deletions
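
The change routes gradient accumulation through DistributedDataParallel's no_sync() context manager, so gradients are all-reduced across ranks only once per accumulation window instead of once per micro-batch. A minimal sketch of that PyTorch pattern, independent of this codebase (the model, data, and optimizer below are placeholders):

import contextlib
import torch.nn.functional as F

def train_accumulated(ddp_model, micro_batches, optimizer, accum_steps):
    # Sketch of the DDP gradient-accumulation pattern this commit adopts.
    # ddp_model is assumed to be wrapped in torch.nn.parallel.DistributedDataParallel.
    optimizer.zero_grad()
    for i, (x, y) in enumerate(micro_batches):
        is_last = (i + 1) % accum_steps == 0
        # no_sync() defers the cross-rank gradient all-reduce; only the last
        # micro-batch in each window pays the synchronization cost.
        ctx = contextlib.nullcontext() if is_last else ddp_model.no_sync()
        with ctx:
            F.mse_loss(ddp_model(x), y).backward()
        if is_last:
            optimizer.step()
            optimizer.zero_grad()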

@@ -231,7 +231,7 @@ class ExtensibleTrainer(BaseModel):
             # Now do a forward and backward pass for each gradient accumulation step.
             new_states = {}
             for m in range(self.batch_factor):
-                ns = s.do_forward_backward(state, m, step_num, train=train_step)
+                ns = s.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor))
                 for k, v in ns.items():
                     if k not in new_states.keys():
                         new_states[k] = [v]
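
Note that the flag passed above is True for every accumulation step except the final one, so DDP synchronizes only on the last micro-batch of each window. For instance, with a hypothetical self.batch_factor of 4:

for m in range(4):              # self.batch_factor == 4
    print(m, m + 1 < 4)         # no_ddp_sync: 0 True, 1 True, 2 True, 3 False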


@@ -142,7 +142,7 @@ class ConfigurableStep(Module):
     # Performs all forward and backward passes for this step given an input state. All input states are lists of
     # chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
     # steps might use. These tensors are automatically detached and accumulated into chunks.
-    def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True):
+    def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True, no_ddp_sync=False):
         local_state = {} # <-- Will store the entire local state to be passed to injectors & losses.
         new_state = {} # <-- Will store state values created by this step for returning to ExtensibleTrainer.
         for k, v in state.items():
@@ -169,7 +169,12 @@
                 continue
             if 'no_accum' in inj.opt.keys() and grad_accum_step > 0:
                 continue
-            injected = inj(local_state)
+            training_net = self.get_network_for_name(self.step_opt['training'])
+            if no_ddp_sync and hasattr(training_net, 'no_sync'):
+                with training_net.no_sync():
+                    injected = inj(local_state)
+            else:
+                injected = inj(local_state)
             local_state.update(injected)
             new_state.update(injected)
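
The hasattr guard keeps single-GPU training unchanged: a plain nn.Module has no no_sync() attribute, so only a DistributedDataParallel-wrapped network takes the deferred-sync branch. A quick illustration (the DDP lines are commented out because wrapping requires an initialized process group):

import torch.nn as nn

net = nn.Linear(8, 8)
print(hasattr(net, 'no_sync'))        # False -> falls through to the plain forward path
# from torch.nn.parallel import DistributedDataParallel as DDP
# ddp_net = DDP(net)                  # requires torch.distributed.init_process_group() first
# print(hasattr(ddp_net, 'no_sync'))  # True  -> forward can run under no_sync()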