forked from mrq/tortoise-tts
i am very smart
This commit is contained in:
parent
bbeee40ab3
commit
00be48670b
|
@ -3,6 +3,7 @@ import psutil
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
DEVICE_OVERRIDE = None
|
DEVICE_OVERRIDE = None
|
||||||
|
DEVICE_BATCH_SIZE_MAP = [(14, 16), (10,8), (7,4)]
|
||||||
|
|
||||||
def has_dml():
|
def has_dml():
|
||||||
loader = importlib.find_loader('torch_directml')
|
loader = importlib.find_loader('torch_directml')
|
||||||
|
@ -60,14 +61,9 @@ def get_device_vram( name=get_device_name() ):
|
||||||
def get_device_batch_size(name=None):
|
def get_device_batch_size(name=None):
|
||||||
vram = get_device_vram(name)
|
vram = get_device_vram(name)
|
||||||
|
|
||||||
# I'll need to rework this better
|
for k, v in DEVICE_BATCH_SIZE_MAP:
|
||||||
# simply adding more tiers clearly is not a good way to go about it
|
if vram > k:
|
||||||
if vram > 14:
|
return v
|
||||||
return 16
|
|
||||||
elif vram > 10:
|
|
||||||
return 8
|
|
||||||
elif vram > 7:
|
|
||||||
return 4
|
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
def get_device_count(name=get_device_name()):
|
def get_device_count(name=get_device_name()):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user