Allow steps to specify which state keys to carry forward (reducing memory utilization)

This commit is contained in:
James Betker 2022-01-24 11:01:27 -07:00
parent 62475005e4
commit e420df479f

View File

@@ -277,6 +277,14 @@ class ConfigurableStep(Module):
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
# we must release the gradients.
new_state = recursively_detach(new_state)
# Prune state outputs that are not actually needed.
if 'step_outputs' in self.step_opt.keys():
nst = {}
for k in self.step_opt['step_outputs']:
nst[k] = new_state[k]
new_state = nst
return new_state
# Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()