forked from mrq/DL-Art-School
Attempt to fix syncing multiple times when doing gradient accumulation
This commit is contained in:
parent
1cd75dfd33
commit
9cfe840872
|
@ -231,7 +231,7 @@ class ExtensibleTrainer(BaseModel):
|
|||
# Now do a forward and backward pass for each gradient accumulation step.
|
||||
new_states = {}
|
||||
for m in range(self.batch_factor):
|
||||
ns = s.do_forward_backward(state, m, step_num, train=train_step)
|
||||
ns = s.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor))
|
||||
for k, v in ns.items():
|
||||
if k not in new_states.keys():
|
||||
new_states[k] = [v]
|
||||
|
|
|
@ -142,7 +142,7 @@ class ConfigurableStep(Module):
|
|||
# Performs all forward and backward passes for this step given an input state. All input states are lists of
|
||||
# chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
|
||||
# steps might use. These tensors are automatically detached and accumulated into chunks.
|
||||
def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True):
|
||||
def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True, no_ddp_sync=False):
|
||||
local_state = {} # <-- Will store the entire local state to be passed to injectors & losses.
|
||||
new_state = {} # <-- Will store state values created by this step for returning to ExtensibleTrainer.
|
||||
for k, v in state.items():
|
||||
|
@ -169,7 +169,12 @@ class ConfigurableStep(Module):
|
|||
continue
|
||||
if 'no_accum' in inj.opt.keys() and grad_accum_step > 0:
|
||||
continue
|
||||
injected = inj(local_state)
|
||||
training_net = self.get_network_for_name(self.step_opt['training'])
|
||||
if no_ddp_sync and hasattr(training_net, 'no_sync'):
|
||||
with training_net.no_sync():
|
||||
injected = inj(local_state)
|
||||
else:
|
||||
injected = inj(local_state)
|
||||
local_state.update(injected)
|
||||
new_state.update(injected)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user