Workaround for WSL VRAM leaks

This commit is contained in:
a-One-Fan 2023-07-13 10:16:57 +03:00
parent 8618922a33
commit 1271237d89

View File

@ -1,6 +1,7 @@
import torch import torch
import psutil import psutil
import importlib import importlib
import os
DEVICE_OVERRIDE = None DEVICE_OVERRIDE = None
DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)] DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
@ -8,6 +9,8 @@ DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
from inspect import currentframe, getframeinfo from inspect import currentframe, getframeinfo
import gc import gc
is_WSL = 'wsl' in os.popen("uname -r").read().tolower()
def xpu_get_mem(device=0): def xpu_get_mem(device=0):
total_memory = ipex.xpu.get_device_properties(device).total_memory total_memory = ipex.xpu.get_device_properties(device).total_memory
return total_memory, total_memory - torch.xpu.memory_allocated(device) return total_memory, total_memory - torch.xpu.memory_allocated(device)
@ -19,10 +22,11 @@ def do_gc():
except Exception as e: except Exception as e:
pass pass
try: if not is_WSL:
torch.xpu.empty_cache() try:
except Exception as e: torch.xpu.empty_cache()
pass except Exception as e:
pass
def print_stats(collect=False): def print_stats(collect=False):
cf = currentframe().f_back cf = currentframe().f_back