diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py index 8e065b8..df7cbb5 100755 --- a/tortoise/utils/device.py +++ b/tortoise/utils/device.py @@ -82,7 +82,9 @@ def get_device_vram( name=get_device_name() ): return available / (1024 ** 3) def get_device_batch_size(name=None): - name = get_device_name() + if not name: + name = get_device_name() + vram = get_device_vram(name) if vram > 14: diff --git a/tortoise/utils/torch_intermediary.py b/tortoise/utils/torch_intermediary.py index fa509dc..6bc773f 100644 --- a/tortoise/utils/torch_intermediary.py +++ b/tortoise/utils/torch_intermediary.py @@ -22,17 +22,19 @@ import os USE_STABLE_EMBEDDING = False try: - import bitsandbytes as bnb OVERRIDE_LINEAR = False - OVERRIDE_EMBEDDING = True - OVERRIDE_ADAM = True - OVERRIDE_ADAMW = True + OVERRIDE_EMBEDDING = False + OVERRIDE_ADAM = False + OVERRIDE_ADAMW = False USE_STABLE_EMBEDDING = os.environ.get('BITSANDBYTES_USE_STABLE_EMBEDDING', '1' if USE_STABLE_EMBEDDING else '0') == '1' OVERRIDE_LINEAR = os.environ.get('BITSANDBYTES_OVERRIDE_LINEAR', '1' if OVERRIDE_LINEAR else '0') == '1' OVERRIDE_EMBEDDING = os.environ.get('BITSANDBYTES_OVERRIDE_EMBEDDING', '1' if OVERRIDE_EMBEDDING else '0') == '1' OVERRIDE_ADAM = os.environ.get('BITSANDBYTES_OVERRIDE_ADAM', '1' if OVERRIDE_ADAM else '0') == '1' OVERRIDE_ADAMW = os.environ.get('BITSANDBYTES_OVERRIDE_ADAMW', '1' if OVERRIDE_ADAMW else '0') == '1' + + if OVERRIDE_LINEAR or OVERRIDE_EMBEDDING or OVERRIDE_ADAM or OVERRIDE_ADAMW: + import bitsandbytes as bnb except Exception as e: OVERRIDE_LINEAR = False OVERRIDE_EMBEDDING = False