Workaround for WSL VRAM leaks

This commit is contained in:
a-One-Fan 2023-07-13 10:16:57 +03:00
parent 8618922a33
commit 1271237d89

View File

@ -1,6 +1,7 @@
import torch
import psutil
import importlib
import os
DEVICE_OVERRIDE = None
DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
@ -8,6 +9,8 @@ DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
from inspect import currentframe, getframeinfo
import gc
is_WSL = 'wsl' in os.popen("uname -r").read().tolower()
def xpu_get_mem(device=0):
total_memory = ipex.xpu.get_device_properties(device).total_memory
return total_memory, total_memory - torch.xpu.memory_allocated(device)
@ -19,10 +22,11 @@ def do_gc():
except Exception as e:
pass
try:
torch.xpu.empty_cache()
except Exception as e:
pass
if not is_WSL:
try:
torch.xpu.empty_cache()
except Exception as e:
pass
def print_stats(collect=False):
cf = currentframe().f_back