Allow steps to specify which state keys to carry forward (reducing memory utilization)
This commit is contained in:
parent
62475005e4
commit
e420df479f
|
@ -277,6 +277,14 @@ class ConfigurableStep(Module):
|
|||
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
|
||||
# we must release the gradients.
|
||||
new_state = recursively_detach(new_state)
|
||||
|
||||
# Prune state outputs that are not actually needed.
|
||||
if 'step_outputs' in self.step_opt.keys():
|
||||
nst = {}
|
||||
for k in self.step_opt['step_outputs']:
|
||||
nst[k] = new_state[k]
|
||||
new_state = nst
|
||||
|
||||
return new_state
|
||||
|
||||
# Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
|
||||
|
|
Loading…
Reference in New Issue
Block a user