disable BNB for inferencing by default because I'm pretty sure it makes zero difference (can be force-enabled with env vars if you're relying on this for some reason)

remotes/1710274000886183304/main
mrq 2023-04-29 00:38:18 +07:00
parent f025470d60
commit 2f7d9ab932
2 changed files with 9 additions and 5 deletions

@ -82,7 +82,9 @@ def get_device_vram( name=get_device_name() ):
return available / (1024 ** 3)
def get_device_batch_size(name=None):
name = get_device_name()
if not name:
name = get_device_name()
vram = get_device_vram(name)
if vram > 14:

@ -22,17 +22,19 @@ import os
USE_STABLE_EMBEDDING = False
try:
import bitsandbytes as bnb
OVERRIDE_LINEAR = False
OVERRIDE_EMBEDDING = True
OVERRIDE_ADAM = True
OVERRIDE_ADAMW = True
OVERRIDE_EMBEDDING = False
OVERRIDE_ADAM = False
OVERRIDE_ADAMW = False
USE_STABLE_EMBEDDING = os.environ.get('BITSANDBYTES_USE_STABLE_EMBEDDING', '1' if USE_STABLE_EMBEDDING else '0') == '1'
OVERRIDE_LINEAR = os.environ.get('BITSANDBYTES_OVERRIDE_LINEAR', '1' if OVERRIDE_LINEAR else '0') == '1'
OVERRIDE_EMBEDDING = os.environ.get('BITSANDBYTES_OVERRIDE_EMBEDDING', '1' if OVERRIDE_EMBEDDING else '0') == '1'
OVERRIDE_ADAM = os.environ.get('BITSANDBYTES_OVERRIDE_ADAM', '1' if OVERRIDE_ADAM else '0') == '1'
OVERRIDE_ADAMW = os.environ.get('BITSANDBYTES_OVERRIDE_ADAMW', '1' if OVERRIDE_ADAMW else '0') == '1'
if OVERRIDE_LINEAR or OVERRIDE_EMBEDDING or OVERRIDE_ADAM or OVERRIDE_ADAMW:
import bitsandbytes as bnb
except Exception as e:
OVERRIDE_LINEAR = False
OVERRIDE_EMBEDDING = False