diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py
index abd8c6c8..9fecdff3 100644
--- a/codes/trainer/steps.py
+++ b/codes/trainer/steps.py
@@ -277,6 +277,14 @@ class ConfigurableStep(Module):
         # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
         # we must release the gradients.
         new_state = recursively_detach(new_state)
+
+        # Prune state outputs that are not actually needed so unused detached
+        # tensors can be garbage-collected as soon as the step returns.
+        if 'step_outputs' in self.step_opt:
+            new_state = {
+                k: new_state[k] for k in self.step_opt['step_outputs']
+            }
+
         return new_state
 
     # Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()