diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py index f8960ec..366a3af 100755 --- a/tortoise/utils/device.py +++ b/tortoise/utils/device.py @@ -61,9 +61,17 @@ def get_device_vram( name=get_device_name() ): def get_device_batch_size(name=None): vram = get_device_vram(name) + if vram > 14: + return 16 + elif vram > 10: + return 8 + elif vram > 7: + return 4 + """ for k, v in DEVICE_BATCH_SIZE_MAP: if vram > k: return v + """ return 1 def get_device_count(name=get_device_name()):