From cc36c0997c8711889ef8028002fc9e41abd5c5f0 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 7 Mar 2023 15:43:09 +0000 Subject: [PATCH] didn't get a chance to commit this this morning --- tortoise/utils/device.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py index 3ab52e2..c2f7b5c 100755 --- a/tortoise/utils/device.py +++ b/tortoise/utils/device.py @@ -16,7 +16,7 @@ def set_device_name(name): global DEVICE_OVERRIDE DEVICE_OVERRIDE = name -def get_device_name(): +def get_device_name(attempt_gc=True): global DEVICE_OVERRIDE if DEVICE_OVERRIDE is not None and DEVICE_OVERRIDE != "": return DEVICE_OVERRIDE @@ -25,6 +25,8 @@ def get_device_name(): if torch.cuda.is_available(): name = 'cuda' + if attempt_gc: + torch.cuda.empty_cache() # may have performance implications elif has_dml(): name = 'dml' @@ -58,12 +60,16 @@ def get_device_batch_size(): elif name == "cpu": available = psutil.virtual_memory()[4] - availableGb = available / (1024 ** 3) - if availableGb > 14: + vram = available / (1024 ** 3) + if vram > 18: + return 32 + if vram > 16: + return 24 + if vram > 14: return 16 - elif availableGb > 10: + elif vram > 10: return 8 - elif availableGb > 7: + elif vram > 7: return 4 return 1