forked from mrq/tortoise-tts
expose VRAM easily
This commit is contained in:
parent
3dd5cad324
commit
6410df569b
|
@ -47,19 +47,19 @@ def get_device(verbose=False):
|
||||||
|
|
||||||
return torch.device(name)
|
return torch.device(name)
|
||||||
|
|
||||||
def get_device_batch_size():
|
def get_device_vram( name=get_device_name() ):
|
||||||
available = 1
|
available = 1
|
||||||
name = get_device_name()
|
|
||||||
|
|
||||||
if name == "dml":
|
if name == "cuda":
|
||||||
# there's nothing publically accessible in the DML API that exposes this
|
|
||||||
# there's a method to get currently used RAM statistics... as tiles
|
|
||||||
available = 1
|
|
||||||
elif name == "cuda":
|
|
||||||
_, available = torch.cuda.mem_get_info()
|
_, available = torch.cuda.mem_get_info()
|
||||||
elif name == "cpu":
|
elif name == "cpu":
|
||||||
available = psutil.virtual_memory()[4]
|
available = psutil.virtual_memory()[4]
|
||||||
|
|
||||||
|
return available
|
||||||
|
|
||||||
|
def get_device_batch_size(name=None):
|
||||||
|
available = get_device_vram(name)
|
||||||
|
|
||||||
vram = available / (1024 ** 3)
|
vram = available / (1024 ** 3)
|
||||||
# I'll need to rework this better
|
# I'll need to rework this better
|
||||||
# simply adding more tiers clearly is not a good way to go about it
|
# simply adding more tiers clearly is not a good way to go about it
|
||||||
|
|
Loading…
Reference in New Issue
Block a user