From 1271237d894e84a3f7303ae4edd5f3f352d44d6e Mon Sep 17 00:00:00 2001 From: a-One-Fan Date: Thu, 13 Jul 2023 10:16:57 +0300 Subject: [PATCH] Workaround for WSL VRAM leaks --- tortoise/utils/device.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py index b4d51ea..8b31ecc 100755 --- a/tortoise/utils/device.py +++ b/tortoise/utils/device.py @@ -1,6 +1,7 @@ import torch import psutil import importlib +import os DEVICE_OVERRIDE = None DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)] @@ -8,6 +9,8 @@ DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)] from inspect import currentframe, getframeinfo import gc +is_WSL = 'wsl' in os.popen("uname -r").read().tolower() + def xpu_get_mem(device=0): total_memory = ipex.xpu.get_device_properties(device).total_memory return total_memory, total_memory - torch.xpu.memory_allocated(device) @@ -19,10 +22,11 @@ def do_gc(): except Exception as e: pass - try: - torch.xpu.empty_cache() - except Exception as e: - pass + if not is_WSL: + try: + torch.xpu.empty_cache() + except Exception as e: + pass def print_stats(collect=False): cf = currentframe().f_back