Fix checkpoint recursion

This commit is contained in:
James Betker 2020-10-03 12:52:50 -06:00
parent 9b4ed82093
commit 35731502c3

View File

@ -47,7 +47,7 @@ def OrderedYaml():
def checkpoint(fn, *args): def checkpoint(fn, *args):
enabled = options.loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in options.loaded_options.keys() else True enabled = options.loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in options.loaded_options.keys() else True
if enabled: if enabled:
return checkpoint(fn, *args) return torch.utils.checkpoint.checkpoint(fn, *args)
else: else:
return fn(*args) return fn(*args)