Recursively detach all outputs, even if they are nested in data structures
parent fe82785ba5
commit e9a39bfa14
@@ -7,6 +7,7 @@ from apex import amp
 from collections import OrderedDict
 from .injectors import create_injector
 from models.novograd import NovoGrad
+from utils.util import recursively_detach
 
 logger = logging.getLogger('base')
 
@@ -147,9 +148,7 @@ class ConfigurableStep(Module):
 
         # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
         # we must release the gradients.
-        for k, v in new_state.items():
-            if isinstance(v, torch.Tensor):
-                new_state[k] = v.detach()
+        new_state = recursively_detach(new_state)
         return new_state
 
     # Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
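The loop removed above only detaches tensors stored directly as values of new_state; tensors nested inside lists, tuples, or dicts would keep their autograd history and hold the graph alive after the step. A minimal sketch of that gap, using a made-up state layout rather than anything from this repo:

import torch

# Hypothetical state returned by a step: one flat tensor, one tensor nested in a list.
new_state = {'loss': torch.randn(1, requires_grad=True),
             'hidden': [torch.randn(4, requires_grad=True)]}

# The old top-level loop misses the nested tensor.
for k, v in new_state.items():
    if isinstance(v, torch.Tensor):
        new_state[k] = v.detach()

print(new_state['loss'].requires_grad)       # False
print(new_state['hidden'][0].requires_grad)  # True -- still attached to the graph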
@@ -342,3 +342,21 @@ class ProgressBar(object):
         sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
             self.completed, int(elapsed + 0.5), fps))
         sys.stdout.flush()
+
+
+# Recursively detaches all tensors in a tree of lists, dicts and tuples and returns the same structure.
+def recursively_detach(v):
+    if isinstance(v, torch.Tensor):
+        return v.detach()
+    elif isinstance(v, list) or isinstance(v, tuple):
+        out = [recursively_detach(i) for i in v]
+        if isinstance(v, tuple):
+            return tuple(out)
+        return out
+    elif isinstance(v, dict):
+        out = {}
+        for k, t in v.items():
+            out[k] = recursively_detach(t)
+        return out
+    else:
+        raise ValueError("Unsupported type")
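For reference, a short usage sketch of the new helper on a nested structure (the example data is invented, not taken from the repo):

import torch
from utils.util import recursively_detach

state = {'out': torch.randn(2, requires_grad=True),
         'aux': [torch.randn(3, requires_grad=True),
                 (torch.randn(1, requires_grad=True),)]}

detached = recursively_detach(state)
print(detached['out'].requires_grad)        # False
print(detached['aux'][1][0].requires_grad)  # False -- nested lists/tuples/dicts are handled too

Note that leaves which are not tensors, lists, tuples, or dicts hit the ValueError branch, so the structure passed in should contain only those types.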