diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py
index 42c9163e..eeb858f8 100644
--- a/codes/trainer/ExtensibleTrainer.py
+++ b/codes/trainer/ExtensibleTrainer.py
@@ -231,7 +231,7 @@ class ExtensibleTrainer(BaseModel):
             # Now do a forward and backward pass for each gradient accumulation step.
             new_states = {}
             for m in range(self.batch_factor):
-                ns = s.do_forward_backward(state, m, step_num, train=train_step)
+                ns = s.do_forward_backward(state, m, step_num, train=train_step, no_ddp_sync=(m+1 < self.batch_factor))
                 for k, v in ns.items():
                     if k not in new_states.keys():
                         new_states[k] = [v]
diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index 760d6187..c5064c8c 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -142,7 +142,7 @@ class ConfigurableStep(Module):
     # Performs all forward and backward passes for this step given an input state. All input states are lists of
     # chunked tensors. Use grad_accum_step to dereference these steps. Should return a dict of tensors that later
     # steps might use. These tensors are automatically detached and accumulated into chunks.
-    def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True):
+    def do_forward_backward(self, state, grad_accum_step, amp_loss_id, train=True, no_ddp_sync=False):
         local_state = {}  # <-- Will store the entire local state to be passed to injectors & losses.
         new_state = {}  # <-- Will store state values created by this step for returning to ExtensibleTrainer.
         for k, v in state.items():
@@ -169,7 +169,12 @@ class ConfigurableStep(Module):
                 continue
             if 'no_accum' in inj.opt.keys() and grad_accum_step > 0:
                 continue
-            injected = inj(local_state)
+            training_net = self.get_network_for_name(self.step_opt['training'])
+            if no_ddp_sync and hasattr(training_net, 'no_sync'):
+                with training_net.no_sync():
+                    injected = inj(local_state)
+            else:
+                injected = inj(local_state)
             local_state.update(injected)
             new_state.update(injected)
 
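For context, the change above applies a standard PyTorch pattern: when gradients are being accumulated over several micro-batches, DistributedDataParallel's all-reduce can be skipped on every micro-batch except the last one via the no_sync() context manager, so gradients are only synchronized once per optimizer step. Below is a minimal sketch of that pattern in isolation, not the project's code; accumulate_gradients, batches, and loss_fn are illustrative names introduced here, and the hasattr fallback mirrors the check in the diff so plain (non-DDP) modules also work.

import contextlib

import torch.nn as nn


def accumulate_gradients(model: nn.Module, batches, optimizer, loss_fn):
    """Accumulate gradients over `batches`, syncing DDP grads only on the last one."""
    optimizer.zero_grad()
    num_batches = len(batches)
    for i, (inputs, targets) in enumerate(batches):
        is_last = (i + 1 == num_batches)
        # DistributedDataParallel exposes no_sync(); use a no-op context on the
        # final micro-batch (so the all-reduce fires) or when the model is not
        # wrapped in DDP at all.
        if is_last or not hasattr(model, 'no_sync'):
            sync_ctx = contextlib.nullcontext()
        else:
            sync_ctx = model.no_sync()
        with sync_ctx:
            loss = loss_fn(model(inputs), targets) / num_batches
            loss.backward()  # grads accumulate locally; all-reduce only happens outside no_sync()
    optimizer.step()

The diff wires the same idea through ExtensibleTrainer: no_ddp_sync is true for all but the final grad-accumulation step (m+1 < self.batch_factor), and ConfigurableStep enters training_net.no_sync() around the injector forward/backward only when the training network actually provides it.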