From 00be48670b5ba358e86aa5781131e9920d8f4def Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Thu, 9 Mar 2023 02:06:44 +0000
Subject: [PATCH] i am very smart

---
 tortoise/utils/device.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py
index e4b5e6d..f8960ec 100755
--- a/tortoise/utils/device.py
+++ b/tortoise/utils/device.py
@@ -3,6 +3,7 @@ import psutil
 import importlib
 
 DEVICE_OVERRIDE = None
+DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
 
 def has_dml():
     loader = importlib.find_loader('torch_directml')
@@ -60,14 +61,9 @@ def get_device_vram( name=get_device_name() ):
 def get_device_batch_size(name=None):
     vram = get_device_vram(name)
 
-    # I'll need to rework this better
-    # simply adding more tiers clearly is not a good way to go about it
-    if vram > 14:
-        return 16
-    elif vram > 10:
-        return 8
-    elif vram > 7:
-        return 4
+    for k, v in DEVICE_BATCH_SIZE_MAP:
+        if vram > k:
+            return v
     return 1
 
 def get_device_count(name=get_device_name()):