diff --git a/codes/utils/util.py b/codes/utils/util.py index cf904807..c3ca12b8 100644 --- a/codes/utils/util.py +++ b/codes/utils/util.py @@ -47,7 +47,7 @@ def OrderedYaml(): def checkpoint(fn, *args): enabled = options.loaded_options['checkpointing_enabled'] if 'checkpointing_enabled' in options.loaded_options.keys() else True if enabled: - return checkpoint(fn, *args) + return torch.utils.checkpoint.checkpoint(fn, *args) else: return fn(*args) @@ -366,4 +366,4 @@ def recursively_detach(v): out = {} for k, t in v.items(): out[k] = recursively_detach(t) - return out \ No newline at end of file + return out