diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py index ae93132..cf899e0 100755 --- a/tortoise/utils/device.py +++ b/tortoise/utils/device.py @@ -47,19 +47,19 @@ def get_device(verbose=False): return torch.device(name) -def get_device_batch_size(): +def get_device_vram( name=get_device_name() ): available = 1 - name = get_device_name() - if name == "dml": - # there's nothing publically accessible in the DML API that exposes this - # there's a method to get currently used RAM statistics... as tiles - available = 1 - elif name == "cuda": + if name == "cuda": _, available = torch.cuda.mem_get_info() elif name == "cpu": available = psutil.virtual_memory()[4] + return available + +def get_device_batch_size(name=None): + available = get_device_vram(name) + vram = available / (1024 ** 3) # I'll need to rework this better # simply adding more tiers clearly is not a good way to go about it