Fix recursive checkpoint

This commit is contained in:
James Betker 2020-10-03 16:15:52 -06:00
parent 3cbb9ecd45
commit c896939523

View File

@ -47,7 +47,7 @@ def OrderedYaml():
def checkpoint(fn, *args):
enabled = options.loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in options.loaded_options.keys() else True
if enabled:
return checkpoint(fn, *args)
return torch.utils.checkpoint.checkpoint(fn, *args)
else:
return fn(*args)