Compare commits
1 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
e650800447 |
|
@ -50,21 +50,32 @@ def get_device_batch_size():
|
||||||
name = get_device_name()
|
name = get_device_name()
|
||||||
|
|
||||||
if name == "dml":
|
if name == "dml":
|
||||||
# there's nothing publically accessible in the DML API that exposes this
|
# there's nothing publicly accessible in the DML API that exposes this
|
||||||
# there's a method to get currently used RAM statistics... as tiles
|
# there's a method to get currently used RAM statistics... as tiles
|
||||||
available = 1
|
available = 1
|
||||||
elif name == "cuda":
|
elif name == "cuda":
|
||||||
_, available = torch.cuda.mem_get_info()
|
_,available = torch.cuda.mem_get_info()
|
||||||
elif name == "cpu":
|
elif name == "cpu":
|
||||||
available = psutil.virtual_memory()[4]
|
available = psutil.virtual_memory()[4]
|
||||||
|
|
||||||
availableGb = available / (1024 ** 3)
|
availableGb = available / (1024 ** 3)
|
||||||
if availableGb > 14:
|
|
||||||
|
print(f"Total device memory available: {availableGb}")
|
||||||
|
if availableGb > 18:
|
||||||
|
print(f"Setting AutoRegressive Batch Size to: 32")
|
||||||
|
print(f"Damn. Nice GPU Dude.")
|
||||||
|
return 32
|
||||||
|
elif availableGb > 14:
|
||||||
|
print(f"Setting AutoRegressive Batch Size to: 16")
|
||||||
return 16
|
return 16
|
||||||
elif availableGb > 10:
|
elif availableGb > 10:
|
||||||
|
print(f"Setting AutoRegressive Batch Size to: 8")
|
||||||
return 8
|
return 8
|
||||||
elif availableGb > 7:
|
elif availableGb > 7:
|
||||||
|
print(f"Setting AutoRegressive Batch Size to: 4")
|
||||||
return 4
|
return 4
|
||||||
|
print(f"Setting AutoRegressive Batch Size to: 1")
|
||||||
|
print(f"Don't cry about it if it doesn't work.")
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
def get_device_count(name=get_device_name()):
|
def get_device_count(name=get_device_name()):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user