diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py index b4d51ea..8b31ecc 100755 --- a/tortoise/utils/device.py +++ b/tortoise/utils/device.py @@ -1,6 +1,7 @@ import torch import psutil import importlib +import os DEVICE_OVERRIDE = None DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)] @@ -8,6 +9,8 @@ DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)] from inspect import currentframe, getframeinfo import gc +is_WSL = 'wsl' in os.popen("uname -r").read().tolower() + def xpu_get_mem(device=0): total_memory = ipex.xpu.get_device_properties(device).total_memory return total_memory, total_memory - torch.xpu.memory_allocated(device) @@ -19,10 +22,11 @@ def do_gc(): except Exception as e: pass - try: - torch.xpu.empty_cache() - except Exception as e: - pass + if not is_WSL: + try: + torch.xpu.empty_cache() + except Exception as e: + pass def print_stats(collect=False): cf = currentframe().f_back