diff --git a/codes/utils/util.py b/codes/utils/util.py index d2077701..90e4cfdd 100644 --- a/codes/utils/util.py +++ b/codes/utils/util.py @@ -22,7 +22,13 @@ from shutil import get_terminal_size import scp import paramiko from torch.utils.checkpoint import checkpoint -from torch._six import inf + +try: + # 1.13.1 + from torch._six import inf +except Exception as e: + # 2.0 + from torch import inf import yaml