From 00be48670b5ba358e86aa5781131e9920d8f4def Mon Sep 17 00:00:00 2001 From: mrq Date: Thu, 9 Mar 2023 02:06:44 +0000 Subject: [PATCH] i am very smart --- tortoise/utils/device.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py index e4b5e6d..f8960ec 100755 --- a/tortoise/utils/device.py +++ b/tortoise/utils/device.py @@ -3,6 +3,7 @@ import psutil import importlib DEVICE_OVERRIDE = None +DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)] def has_dml(): loader = importlib.find_loader('torch_directml') @@ -60,14 +61,9 @@ def get_device_vram( name=get_device_name() ): def get_device_batch_size(name=None): vram = get_device_vram(name) - # I'll need to rework this better - # simply adding more tiers clearly is not a good way to go about it - if vram > 14: - return 16 - elif vram > 10: - return 8 - elif vram > 7: - return 4 + for k, v in DEVICE_BATCH_SIZE_MAP: + if vram > k: + return v return 1 def get_device_count(name=get_device_name()):