Fix checkpoint recursion

This commit is contained in:
James Betker 2020-10-03 12:52:50 -06:00
parent 9b4ed82093
commit 35731502c3

View File

@ -47,7 +47,7 @@ def OrderedYaml():
def checkpoint(fn, *args): def checkpoint(fn, *args):
enabled = options.loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in options.loaded_options.keys() else True enabled = options.loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in options.loaded_options.keys() else True
if enabled: if enabled:
return checkpoint(fn, *args) return torch.utils.checkpoint.checkpoint(fn, *args)
else: else:
return fn(*args) return fn(*args)
@ -366,4 +366,4 @@ def recursively_detach(v):
out = {} out = {}
for k, t in v.items(): for k, t in v.items():
out[k] = recursively_detach(t) out[k] = recursively_detach(t)
return out return out