tortoise-tts/tortoise/utils/device.py

103 lines
3.2 KiB
Python
Raw Normal View History

2023-02-09 01:53:25 +00:00
import torch
2023-02-09 20:42:38 +00:00
import psutil
import importlib
2023-02-09 01:53:25 +00:00
# Device name forced via set_device_name(). get_device_name() returns it
# verbatim unless it is None or "", in which case the device is auto-detected.
DEVICE_OVERRIDE = None
2023-02-09 01:53:25 +00:00
def has_dml():
    """Report whether the DirectML backend (``torch_directml``) is usable.

    Returns:
        False when the ``torch_directml`` package cannot be found; otherwise
        the result of ``torch_directml.is_available()``.
    """
    # importlib.find_loader() was deprecated in Python 3.4 and removed in 3.12
    # (calling it there raises AttributeError). importlib.util.find_spec() is
    # the supported way to probe for a module without importing it.
    import importlib.util

    if importlib.util.find_spec('torch_directml') is None:
        return False
    import torch_directml
    return torch_directml.is_available()
2023-02-09 01:53:25 +00:00
def set_device_name(name):
    """Store ``name`` as the module-wide device override.

    get_device_name() returns this value verbatim on later calls, unless it
    is None or the empty string; either of those re-enables auto-detection.
    """
    global DEVICE_OVERRIDE
    DEVICE_OVERRIDE = name
def get_device_name(attempt_gc=True):
    """Return the name of the device backend to run on.

    Resolution order:
      1. The override stored by set_device_name(), if it is neither None
         nor the empty string, returned unchanged.
      2. 'cuda' when torch reports CUDA is available. If ``attempt_gc`` is
         true, torch.cuda.empty_cache() is called before returning.
      3. 'dml' when has_dml() reports DirectML is usable.
      4. 'cpu' otherwise.
    """
    override = DEVICE_OVERRIDE
    if override is not None and override != "":
        return override

    if not torch.cuda.is_available():
        return 'dml' if has_dml() else 'cpu'

    if attempt_gc:
        torch.cuda.empty_cache()  # may have performance implications
    return 'cuda'
def get_device(verbose=False):
    """Return a device handle for the backend chosen by get_device_name().

    Args:
        verbose: when true, print which backend was selected.

    Returns:
        ``torch_directml.device()`` for the 'dml' backend, otherwise
        ``torch.device(<name>)``.
    """
    device_name = get_device_name()

    if verbose:
        message = (
            "No hardware acceleration is available, falling back to CPU..."
            if device_name == 'cpu'
            else f"Hardware acceleration found: {device_name}"
        )
        print(message)

    if device_name != "dml":
        return torch.device(device_name)

    # Imported lazily: torch_directml is only present on DirectML setups.
    import torch_directml
    return torch_directml.device()
def get_device_batch_size():
    """Pick a batch size from a heuristic on the active device's memory.

    Memory figure used, by device name (from get_device_name()):
      * 'cuda': the second element of torch.cuda.mem_get_info(), i.e. the
        device's total memory.
      * 'cpu': psutil's ``available`` system memory.
      * 'dml' or any other name: nothing is queried; a 1-byte placeholder
        is used, which always yields the smallest batch size.

    Returns:
        32, 24, 16, 8 or 4 when the memory exceeds 18, 16, 14, 10 or 7 GiB
        respectively (first match wins), otherwise 1.
    """
    name = get_device_name()

    # Bytes. Placeholder for devices whose memory cannot be queried: there is
    # nothing publicly accessible in the DML API that exposes this (only
    # currently-used RAM statistics, as tiles).
    available = 1
    if name == "cuda":
        # mem_get_info() returns (free, total); the total is used, presumably
        # so memory already held by loaded models doesn't shrink the batch
        # size -- TODO confirm intent.
        _, available = torch.cuda.mem_get_info()
    elif name == "cpu":
        # Use the named field: positional index 4 of this namedtuple is
        # `free`, which per psutil docs excludes reclaimable cache and
        # understates what processes can actually use.
        available = psutil.virtual_memory().available

    gib = available / (1024 ** 3)
    # (exclusive lower bound in GiB, batch size), largest first.
    for threshold, batch_size in ((18, 32), (16, 24), (14, 16), (10, 8), (7, 4)):
        if gib > threshold:
            return batch_size
    return 1
def get_device_count(name=None):
    """Return how many devices of backend ``name`` are visible.

    Args:
        name: backend name such as "cuda" or "dml". When omitted (None), the
            backend reported by get_device_name() at call time is used.

    Returns:
        torch.cuda.device_count() for "cuda", torch_directml.device_count()
        for "dml", and 1 for any other name.
    """
    if name is None:
        # Resolved per call rather than in the signature. A default of
        # get_device_name() is evaluated once at import time: it ignores
        # later set_device_name() overrides, and it calls
        # torch.cuda.empty_cache() as an import side effect. Merely counting
        # devices should not flush the cache, hence attempt_gc=False.
        name = get_device_name(attempt_gc=False)
    if name == "cuda":
        return torch.cuda.device_count()
    if name == "dml":
        import torch_directml
        return torch_directml.device_count()
    return 1
2023-02-09 20:42:38 +00:00
# When DirectML is usable, monkey-patch torch at import time so that cumsum,
# repeat_interleave, multinomial and Tensor.new run on CPU, with the result
# moved back to the device of the first argument. Presumably these ops are
# unsupported or unreliable on torch-directml -- TODO confirm which ones.
# NOTE(review): the wrappers do not check the tensor's device, so every tensor
# is round-tripped through CPU, including CUDA tensors on a machine that also
# has DirectML. multinomial then presumably draws from the CPU RNG. Consider
# gating the fallback on the tensor's device type.
if has_dml():
    # Capture the original implementations before reassigning them below;
    # the wrappers call these saved references.
    _cumsum = torch.cumsum
    _repeat_interleave = torch.repeat_interleave
    _multinomial = torch.multinomial
    _Tensor_new = torch.Tensor.new
    _Tensor_cumsum = torch.Tensor.cumsum
    _Tensor_repeat_interleave = torch.Tensor.repeat_interleave
    _Tensor_multinomial = torch.Tensor.multinomial
    # Only the first argument (`input` / `self`) is moved to CPU. Any other
    # tensor arguments in *args/**kwargs are passed through unchanged.
    torch.cumsum = lambda input, *args, **kwargs: ( _cumsum(input.to("cpu"), *args, **kwargs).to(input.device) )
    torch.repeat_interleave = lambda input, *args, **kwargs: ( _repeat_interleave(input.to("cpu"), *args, **kwargs).to(input.device) )
    torch.multinomial = lambda input, *args, **kwargs: ( _multinomial(input.to("cpu"), *args, **kwargs).to(input.device) )
    # Method-level equivalents, so `tensor.cumsum(...)` etc. are rerouted too.
    torch.Tensor.new = lambda self, *args, **kwargs: ( _Tensor_new(self.to("cpu"), *args, **kwargs).to(self.device) )
    torch.Tensor.cumsum = lambda self, *args, **kwargs: ( _Tensor_cumsum(self.to("cpu"), *args, **kwargs).to(self.device) )
    torch.Tensor.repeat_interleave = lambda self, *args, **kwargs: ( _Tensor_repeat_interleave(self.to("cpu"), *args, **kwargs).to(self.device) )
    torch.Tensor.multinomial = lambda self, *args, **kwargs: ( _Tensor_multinomial(self.to("cpu"), *args, **kwargs).to(self.device) )