forked from mrq/tortoise-tts
Workaround for WSL VRAM leaks
This commit is contained in:
parent
8618922a33
commit
1271237d89
|
@ -1,6 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import psutil
|
import psutil
|
||||||
import importlib
|
import importlib
|
||||||
|
import os
|
||||||
|
|
||||||
DEVICE_OVERRIDE = None
|
DEVICE_OVERRIDE = None
|
||||||
DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
|
DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
|
||||||
|
@ -8,6 +9,8 @@ DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
|
||||||
from inspect import currentframe, getframeinfo
|
from inspect import currentframe, getframeinfo
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
|
is_WSL = 'wsl' in os.popen("uname -r").read().tolower()
|
||||||
|
|
||||||
def xpu_get_mem(device=0):
|
def xpu_get_mem(device=0):
|
||||||
total_memory = ipex.xpu.get_device_properties(device).total_memory
|
total_memory = ipex.xpu.get_device_properties(device).total_memory
|
||||||
return total_memory, total_memory - torch.xpu.memory_allocated(device)
|
return total_memory, total_memory - torch.xpu.memory_allocated(device)
|
||||||
|
@ -19,10 +22,11 @@ def do_gc():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
if not is_WSL:
|
||||||
torch.xpu.empty_cache()
|
try:
|
||||||
except Exception as e:
|
torch.xpu.empty_cache()
|
||||||
pass
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
def print_stats(collect=False):
|
def print_stats(collect=False):
|
||||||
cf = currentframe().f_back
|
cf = currentframe().f_back
|
||||||
|
|
Loading…
Reference in New Issue
Block a user