didn't get a chance to commit this this morning
This commit is contained in:
parent
fffea7fc03
commit
cc36c0997c
|
@ -16,7 +16,7 @@ def set_device_name(name):
|
||||||
global DEVICE_OVERRIDE
|
global DEVICE_OVERRIDE
|
||||||
DEVICE_OVERRIDE = name
|
DEVICE_OVERRIDE = name
|
||||||
|
|
||||||
def get_device_name():
|
def get_device_name(attempt_gc=True):
|
||||||
global DEVICE_OVERRIDE
|
global DEVICE_OVERRIDE
|
||||||
if DEVICE_OVERRIDE is not None and DEVICE_OVERRIDE != "":
|
if DEVICE_OVERRIDE is not None and DEVICE_OVERRIDE != "":
|
||||||
return DEVICE_OVERRIDE
|
return DEVICE_OVERRIDE
|
||||||
|
@ -25,6 +25,8 @@ def get_device_name():
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
name = 'cuda'
|
name = 'cuda'
|
||||||
|
if attempt_gc:
|
||||||
|
torch.cuda.empty_cache() # may have performance implications
|
||||||
elif has_dml():
|
elif has_dml():
|
||||||
name = 'dml'
|
name = 'dml'
|
||||||
|
|
||||||
|
@ -58,12 +60,16 @@ def get_device_batch_size():
|
||||||
elif name == "cpu":
|
elif name == "cpu":
|
||||||
available = psutil.virtual_memory()[4]
|
available = psutil.virtual_memory()[4]
|
||||||
|
|
||||||
availableGb = available / (1024 ** 3)
|
vram = available / (1024 ** 3)
|
||||||
if availableGb > 14:
|
if vram > 18:
|
||||||
|
return 32
|
||||||
|
if vram > 16:
|
||||||
|
return 24
|
||||||
|
if vram > 14:
|
||||||
return 16
|
return 16
|
||||||
elif availableGb > 10:
|
elif vram > 10:
|
||||||
return 8
|
return 8
|
||||||
elif availableGb > 7:
|
elif vram > 7:
|
||||||
return 4
|
return 4
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user