From 2f7d9ab93265ad2d3005d9f67872d3c2c88a76ab Mon Sep 17 00:00:00 2001 From: mrq Date: Sat, 29 Apr 2023 00:38:18 +0000 Subject: [PATCH] disable BNB for inferencing by default because I'm pretty sure it makes zero differences (can be force enabled with env vars if you'r erelying on this for some reason) --- tortoise/utils/device.py | 4 +++- tortoise/utils/torch_intermediary.py | 10 ++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tortoise/utils/device.py b/tortoise/utils/device.py index 8e065b8..df7cbb5 100755 --- a/tortoise/utils/device.py +++ b/tortoise/utils/device.py @@ -82,7 +82,9 @@ def get_device_vram( name=get_device_name() ): return available / (1024 ** 3) def get_device_batch_size(name=None): - name = get_device_name() + if not name: + name = get_device_name() + vram = get_device_vram(name) if vram > 14: diff --git a/tortoise/utils/torch_intermediary.py b/tortoise/utils/torch_intermediary.py index fa509dc..6bc773f 100644 --- a/tortoise/utils/torch_intermediary.py +++ b/tortoise/utils/torch_intermediary.py @@ -22,17 +22,19 @@ import os USE_STABLE_EMBEDDING = False try: - import bitsandbytes as bnb OVERRIDE_LINEAR = False - OVERRIDE_EMBEDDING = True - OVERRIDE_ADAM = True - OVERRIDE_ADAMW = True + OVERRIDE_EMBEDDING = False + OVERRIDE_ADAM = False + OVERRIDE_ADAMW = False USE_STABLE_EMBEDDING = os.environ.get('BITSANDBYTES_USE_STABLE_EMBEDDING', '1' if USE_STABLE_EMBEDDING else '0') == '1' OVERRIDE_LINEAR = os.environ.get('BITSANDBYTES_OVERRIDE_LINEAR', '1' if OVERRIDE_LINEAR else '0') == '1' OVERRIDE_EMBEDDING = os.environ.get('BITSANDBYTES_OVERRIDE_EMBEDDING', '1' if OVERRIDE_EMBEDDING else '0') == '1' OVERRIDE_ADAM = os.environ.get('BITSANDBYTES_OVERRIDE_ADAM', '1' if OVERRIDE_ADAM else '0') == '1' OVERRIDE_ADAMW = os.environ.get('BITSANDBYTES_OVERRIDE_ADAMW', '1' if OVERRIDE_ADAMW else '0') == '1' + + if OVERRIDE_LINEAR or OVERRIDE_EMBEDDING or OVERRIDE_ADAM or OVERRIDE_ADAMW: + import bitsandbytes as bnb except Exception as e: OVERRIDE_LINEAR = False OVERRIDE_EMBEDDING = False