|
|
|
@ -16,7 +16,7 @@ def set_device_name(name):
|
|
|
|
|
global DEVICE_OVERRIDE
|
|
|
|
|
DEVICE_OVERRIDE = name
|
|
|
|
|
|
|
|
|
|
def get_device_name():
|
|
|
|
|
def get_device_name(attempt_gc=True):
|
|
|
|
|
global DEVICE_OVERRIDE
|
|
|
|
|
if DEVICE_OVERRIDE is not None and DEVICE_OVERRIDE != "":
|
|
|
|
|
return DEVICE_OVERRIDE
|
|
|
|
@ -25,6 +25,8 @@ def get_device_name():
|
|
|
|
|
|
|
|
|
|
if torch.cuda.is_available():
|
|
|
|
|
name = 'cuda'
|
|
|
|
|
if attempt_gc:
|
|
|
|
|
torch.cuda.empty_cache() # may have performance implications
|
|
|
|
|
elif has_dml():
|
|
|
|
|
name = 'dml'
|
|
|
|
|
|
|
|
|
@ -58,12 +60,16 @@ def get_device_batch_size():
|
|
|
|
|
elif name == "cpu":
|
|
|
|
|
available = psutil.virtual_memory()[4]
|
|
|
|
|
|
|
|
|
|
availableGb = available / (1024 ** 3)
|
|
|
|
|
if availableGb > 14:
|
|
|
|
|
vram = available / (1024 ** 3)
|
|
|
|
|
if vram > 18:
|
|
|
|
|
return 32
|
|
|
|
|
if vram > 16:
|
|
|
|
|
return 24
|
|
|
|
|
if vram > 14:
|
|
|
|
|
return 16
|
|
|
|
|
elif availableGb > 10:
|
|
|
|
|
elif vram > 10:
|
|
|
|
|
return 8
|
|
|
|
|
elif availableGb > 7:
|
|
|
|
|
elif vram > 7:
|
|
|
|
|
return 4
|
|
|
|
|
return 1
|
|
|
|
|
|
|
|
|
|