Recursively detach all outputs, even if they are nested in data structures

This commit is contained in:
James Betker 2020-09-19 21:47:34 -06:00
parent fe82785ba5
commit e9a39bfa14
2 changed files with 20 additions and 3 deletions

View File

@ -7,6 +7,7 @@ from apex import amp
from collections import OrderedDict from collections import OrderedDict
from .injectors import create_injector from .injectors import create_injector
from models.novograd import NovoGrad from models.novograd import NovoGrad
from utils.util import recursively_detach
logger = logging.getLogger('base') logger = logging.getLogger('base')
@ -147,9 +148,7 @@ class ConfigurableStep(Module):
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
# we must release the gradients. # we must release the gradients.
for k, v in new_state.items(): new_state = recursively_detach(new_state)
if isinstance(v, torch.Tensor):
new_state[k] = v.detach()
return new_state return new_state
# Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps() # Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()

View File

@ -342,3 +342,21 @@ class ProgressBar(object):
sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format( sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
self.completed, int(elapsed + 0.5), fps)) self.completed, int(elapsed + 0.5), fps))
sys.stdout.flush() sys.stdout.flush()
# Recursively detaches all tensors in a tree of lists, dicts and tuples and returns the same structure.
def recursively_detach(v):
if isinstance(v, torch.Tensor):
return v.detach()
elif isinstance(v, list) or isinstance(v, tuple):
out = [recursively_detach(i) for i in v]
if isinstance(v, tuple):
return tuple(out)
return out
elif isinstance(v, dict):
out = {}
for k, t in v.items():
out[k] = recursively_detach(t)
return out
else:
raise ValueError("Unsupported type")