Merge pull request 'Update tortoise/utils/devices.py vram issue' (#44) from aJoe/tortoise-tts:main into main

Reviewed-on: #44
Committed by mrq on 2023-04-12 19:58:02 +00:00
Commit: f025470d60


tortoise/utils/devices.py
@@ -1,127 +1,128 @@
import torch
import psutil
import importlib

DEVICE_OVERRIDE = None
DEVICE_BATCH_SIZE_MAP = [(14, 16), (10, 8), (7, 4)]

from inspect import currentframe, getframeinfo
import gc
def do_gc():
    # collect Python garbage, then release cached CUDA memory;
    # ignore failures on systems without CUDA
    gc.collect()
    try:
        torch.cuda.empty_cache()
    except Exception:
        pass
def print_stats(collect=False):
    # report CUDA memory stats (in GiB), tagged with the caller's file and line
    cf = currentframe().f_back
    msg = f'{getframeinfo(cf).filename}:{cf.f_lineno}'
    if collect:
        do_gc()

    tot = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
    res = torch.cuda.memory_reserved(0) / (1024 ** 3)
    alloc = torch.cuda.memory_allocated(0) / (1024 ** 3)
    print("[{}] Total: {:.3f} | Reserved: {:.3f} | Allocated: {:.3f} | Free: {:.3f}".format( msg, tot, res, alloc, tot-res ))
def has_dml():
    # probe for torch_directml without importing it unconditionally
    loader = importlib.find_loader('torch_directml')
    if loader is None:
        return False

    import torch_directml
    return torch_directml.is_available()
def set_device_name(name):
    global DEVICE_OVERRIDE
    DEVICE_OVERRIDE = name

def get_device_name(attempt_gc=True):
    global DEVICE_OVERRIDE
    # an explicit override always wins
    if DEVICE_OVERRIDE is not None and DEVICE_OVERRIDE != "":
        return DEVICE_OVERRIDE

    name = 'cpu'
    if torch.cuda.is_available():
        name = 'cuda'
        if attempt_gc:
            torch.cuda.empty_cache() # may have performance implications
    elif has_dml():
        name = 'dml'

    return name
def get_device(verbose=False):
    name = get_device_name()

    if verbose:
        if name == 'cpu':
            print("No hardware acceleration is available, falling back to CPU...")
        else:
            print(f"Hardware acceleration found: {name}")

    if name == "dml":
        import torch_directml
        return torch_directml.device()

    return torch.device(name)
def get_device_vram( name=get_device_name() ):
    available = 1
    if name == "cuda":
        _, available = torch.cuda.mem_get_info()
    elif name == "cpu":
        available = psutil.virtual_memory()[4]  # index 4 is the 'free' field

    return available / (1024 ** 3)  # bytes -> GiB
def get_device_batch_size(name=None):
    name = get_device_name()  # added by this PR: resolve the active device at call time (note: this overrides any caller-supplied name)
    vram = get_device_vram(name)

    if vram > 14:
        return 16
    elif vram > 10:
        return 8
    elif vram > 7:
        return 4
    """
    for k, v in DEVICE_BATCH_SIZE_MAP:
        if vram > k:
            return v
    """
    return 1
def get_device_count(name=get_device_name()):
    # NB: the default argument is evaluated once, at import time
    if name == "cuda":
        return torch.cuda.device_count()
    if name == "dml":
        import torch_directml
        return torch_directml.device_count()

    return 1
if has_dml():
    # Ops that misbehave on the DirectML backend are wrapped so the input is
    # moved to the CPU, the original op runs there, and the result is moved
    # back to the tensor's original device.
    _cumsum = torch.cumsum
    _repeat_interleave = torch.repeat_interleave
    _multinomial = torch.multinomial

    _Tensor_new = torch.Tensor.new
    _Tensor_cumsum = torch.Tensor.cumsum
    _Tensor_repeat_interleave = torch.Tensor.repeat_interleave
    _Tensor_multinomial = torch.Tensor.multinomial

    torch.cumsum = lambda input, *args, **kwargs: ( _cumsum(input.to("cpu"), *args, **kwargs).to(input.device) )
    torch.repeat_interleave = lambda input, *args, **kwargs: ( _repeat_interleave(input.to("cpu"), *args, **kwargs).to(input.device) )
    torch.multinomial = lambda input, *args, **kwargs: ( _multinomial(input.to("cpu"), *args, **kwargs).to(input.device) )

    torch.Tensor.new = lambda self, *args, **kwargs: ( _Tensor_new(self.to("cpu"), *args, **kwargs).to(self.device) )
    torch.Tensor.cumsum = lambda self, *args, **kwargs: ( _Tensor_cumsum(self.to("cpu"), *args, **kwargs).to(self.device) )
    torch.Tensor.repeat_interleave = lambda self, *args, **kwargs: ( _Tensor_repeat_interleave(self.to("cpu"), *args, **kwargs).to(self.device) )
    torch.Tensor.multinomial = lambda self, *args, **kwargs: ( _Tensor_multinomial(self.to("cpu"), *args, **kwargs).to(self.device) )
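
The commented-out loop in get_device_batch_size appears intended to replace the if/elif chain with the DEVICE_BATCH_SIZE_MAP table. A minimal sketch, not part of the commit, showing the two are equivalent (assuming the map stays sorted by descending threshold and thresholds are free VRAM in GiB):

    # Sketch only: mirrors the commented-out DEVICE_BATCH_SIZE_MAP loop.
    DEVICE_BATCH_SIZE_MAP = [(14, 16), (10, 8), (7, 4)]  # (min free VRAM in GiB, batch size)

    def batch_size_for(vram):
        for threshold, batch_size in DEVICE_BATCH_SIZE_MAP:
            if vram > threshold:
                return batch_size
        return 1

    assert batch_size_for(24.0) == 16  # e.g. a 24 GiB card
    assert batch_size_for(8.0) == 4
    assert batch_size_for(6.0) == 1    # below every threshold -> conservative default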
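
For orientation, a hypothetical usage sketch of the helpers above; the import path is inferred from the file path in the PR title, and the printed values depend on local hardware:

    # Hypothetical usage; import path assumed from tortoise/utils/devices.py.
    from tortoise.utils.devices import get_device, get_device_name, get_device_batch_size

    device = get_device(verbose=True)      # torch.device or a DirectML device handle
    batch_size = get_device_batch_size()   # 16 / 8 / 4 based on free VRAM, else 1

    print(f"device={get_device_name()} batch_size={batch_size}")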