forked from mrq/tortoise-tts
Workaround for WSL VRAM leaks
This commit is contained in:
parent
8618922a33
commit
1271237d89
|
@ -1,6 +1,7 @@
|
|||
import torch
|
||||
import psutil
|
||||
import importlib
|
||||
import os
|
||||
|
||||
DEVICE_OVERRIDE = None
|
||||
DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
|
||||
|
@ -8,6 +9,8 @@ DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
|
|||
from inspect import currentframe, getframeinfo
|
||||
import gc
|
||||
|
||||
is_WSL = 'wsl' in os.popen("uname -r").read().tolower()
|
||||
|
||||
def xpu_get_mem(device=0):
|
||||
total_memory = ipex.xpu.get_device_properties(device).total_memory
|
||||
return total_memory, total_memory - torch.xpu.memory_allocated(device)
|
||||
|
@ -19,10 +22,11 @@ def do_gc():
|
|||
except Exception as e:
|
||||
pass
|
||||
|
||||
try:
|
||||
torch.xpu.empty_cache()
|
||||
except Exception as e:
|
||||
pass
|
||||
if not is_WSL:
|
||||
try:
|
||||
torch.xpu.empty_cache()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def print_stats(collect=False):
|
||||
cf = currentframe().f_back
|
||||
|
|
Loading…
Reference in New Issue
Block a user