From e9a39bfa143dff0ffed5ac12649a24883680afc9 Mon Sep 17 00:00:00 2001 From: James Betker Date: Sat, 19 Sep 2020 21:47:34 -0600 Subject: [PATCH] Recursively detach all outputs, even if they are nested in data structures --- codes/models/steps/steps.py | 5 ++--- codes/utils/util.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py index f2226d69..c8c7de38 100644 --- a/codes/models/steps/steps.py +++ b/codes/models/steps/steps.py @@ -7,6 +7,7 @@ from apex import amp from collections import OrderedDict from .injectors import create_injector from models.novograd import NovoGrad +from utils.util import recursively_detach logger = logging.getLogger('base') @@ -147,9 +148,7 @@ class ConfigurableStep(Module): # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step # we must release the gradients. - for k, v in new_state.items(): - if isinstance(v, torch.Tensor): - new_state[k] = v.detach() + new_state = recursively_detach(new_state) return new_state # Performs the optimizer step after all gradient accumulation is completed. Default implementation simply steps() diff --git a/codes/utils/util.py b/codes/utils/util.py index fe3d39a8..17d1341b 100644 --- a/codes/utils/util.py +++ b/codes/utils/util.py @@ -342,3 +342,21 @@ class ProgressBar(object): sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format( self.completed, int(elapsed + 0.5), fps)) sys.stdout.flush() + + +# Recursively detaches all tensors in a tree of lists, dicts and tuples and returns the same structure. 
def recursively_detach(v):
    """Return a copy of ``v`` in which every tensor has been detached.

    Lists, tuples and dicts are walked recursively and rebuilt; each
    ``torch.Tensor`` encountered is replaced by ``tensor.detach()``. The
    containers passed in are never modified.

    Args:
        v: A tensor, an arbitrarily nested list/tuple/dict, or any other
            value.

    Returns:
        * the detached tensor when ``v`` is a ``torch.Tensor``;
        * a new list for a list, a new tuple for a tuple (namedtuples keep
          their class), a new plain dict for a dict;
        * ``v`` itself for anything else (``None``, numbers, strings, ...).
    """
    if isinstance(v, torch.Tensor):
        return v.detach()

    if isinstance(v, (list, tuple)):
        items = [recursively_detach(i) for i in v]
        if isinstance(v, tuple):
            # namedtuples must be rebuilt through their own class, otherwise
            # the result silently loses its field names.
            if hasattr(v, '_fields') and hasattr(type(v), '_make'):
                return type(v)._make(items)
            return tuple(items)
        return items

    if isinstance(v, dict):
        return {k: recursively_detach(t) for k, t in v.items()}

    # Non-tensor leaves carry no autograd graph, so there is nothing to
    # release. Pass them through untouched instead of raising, so that a state
    # holding scalars, strings or None alongside tensors can still be detached.
    return v