Recursively detach all outputs, even if they are nested in data structures
This commit is contained in:
parent
fe82785ba5
commit
e9a39bfa14
|
@ -7,6 +7,7 @@ from apex import amp
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from .injectors import create_injector
|
from .injectors import create_injector
|
||||||
from models.novograd import NovoGrad
|
from models.novograd import NovoGrad
|
||||||
|
from utils.util import recursively_detach
|
||||||
|
|
||||||
logger = logging.getLogger('base')
|
logger = logging.getLogger('base')
|
||||||
|
|
||||||
|
@ -147,9 +148,7 @@ class ConfigurableStep(Module):
|
||||||
|
|
||||||
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
|
# Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
|
||||||
# we must release the gradients.
|
# we must release the gradients.
|
||||||
for k, v in new_state.items():
|
new_state = recursively_detach(new_state)
|
||||||
if isinstance(v, torch.Tensor):
|
|
||||||
new_state[k] = v.detach()
|
|
||||||
return new_state
|
return new_state
|
||||||
|
|
||||||
# Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
|
# Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps()
|
||||||
|
|
|
@ -342,3 +342,21 @@ class ProgressBar(object):
|
||||||
sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
|
sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
|
||||||
self.completed, int(elapsed + 0.5), fps))
|
self.completed, int(elapsed + 0.5), fps))
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
|
||||||
|
# Recursively detaches all tensors in a tree of lists, dicts and tuples and returns the same structure.
|
||||||
|
def recursively_detach(v):
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
return v.detach()
|
||||||
|
elif isinstance(v, list) or isinstance(v, tuple):
|
||||||
|
out = [recursively_detach(i) for i in v]
|
||||||
|
if isinstance(v, tuple):
|
||||||
|
return tuple(out)
|
||||||
|
return out
|
||||||
|
elif isinstance(v, dict):
|
||||||
|
out = {}
|
||||||
|
for k, t in v.items():
|
||||||
|
out[k] = recursively_detach(t)
|
||||||
|
return out
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported type")
|
Loading…
Reference in New Issue
Block a user