From e420df479ffb70b7cacb9ad128d816f35ea5717c Mon Sep 17 00:00:00 2001 From: James Betker Date: Mon, 24 Jan 2022 11:01:27 -0700 Subject: [PATCH] Allow steps to specify which state keys to carry forward (reducing memory utilization) --- codes/trainer/steps.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py index abd8c6c8..9fecdff3 100644 --- a/codes/trainer/steps.py +++ b/codes/trainer/steps.py @@ -277,6 +277,14 @@ class ConfigurableStep(Module): # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step # we must release the gradients. new_state = recursively_detach(new_state) + + # Prune state outputs that are not actually needed. + if 'step_outputs' in self.step_opt.keys(): + nst = {} + for k in self.step_opt['step_outputs']: + nst[k] = new_state[k] + new_state = nst + return new_state # Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()