didn't get a chance to commit this this morning
This commit is contained in:
parent
fffea7fc03
commit
cc36c0997c
|
@ -16,7 +16,7 @@ def set_device_name(name):
|
|||
global DEVICE_OVERRIDE
|
||||
DEVICE_OVERRIDE = name
|
||||
|
||||
def get_device_name():
|
||||
def get_device_name(attempt_gc=True):
|
||||
global DEVICE_OVERRIDE
|
||||
if DEVICE_OVERRIDE is not None and DEVICE_OVERRIDE != "":
|
||||
return DEVICE_OVERRIDE
|
||||
|
@ -25,6 +25,8 @@ def get_device_name():
|
|||
|
||||
if torch.cuda.is_available():
|
||||
name = 'cuda'
|
||||
if attempt_gc:
|
||||
torch.cuda.empty_cache() # may have performance implications
|
||||
elif has_dml():
|
||||
name = 'dml'
|
||||
|
||||
|
@ -58,12 +60,16 @@ def get_device_batch_size():
|
|||
elif name == "cpu":
|
||||
available = psutil.virtual_memory()[4]
|
||||
|
||||
availableGb = available / (1024 ** 3)
|
||||
if availableGb > 14:
|
||||
vram = available / (1024 ** 3)
|
||||
if vram > 18:
|
||||
return 32
|
||||
if vram > 16:
|
||||
return 24
|
||||
if vram > 14:
|
||||
return 16
|
||||
elif availableGb > 10:
|
||||
elif vram > 10:
|
||||
return 8
|
||||
elif availableGb > 7:
|
||||
elif vram > 7:
|
||||
return 4
|
||||
return 1
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user