diff --git a/codes/torch_intermediary/__init__.py b/codes/torch_intermediary/__init__.py index d220cd3b..fa509dc7 100644 --- a/codes/torch_intermediary/__init__.py +++ b/codes/torch_intermediary/__init__.py @@ -18,6 +18,8 @@ OVERRIDE_ADAM = False # True OVERRIDE_ADAMW = False # True """ +import os + USE_STABLE_EMBEDDING = False try: import bitsandbytes as bnb @@ -25,6 +27,12 @@ try: OVERRIDE_EMBEDDING = True OVERRIDE_ADAM = True OVERRIDE_ADAMW = True + + USE_STABLE_EMBEDDING = os.environ.get('BITSANDBYTES_USE_STABLE_EMBEDDING', '1' if USE_STABLE_EMBEDDING else '0') == '1' + OVERRIDE_LINEAR = os.environ.get('BITSANDBYTES_OVERRIDE_LINEAR', '1' if OVERRIDE_LINEAR else '0') == '1' + OVERRIDE_EMBEDDING = os.environ.get('BITSANDBYTES_OVERRIDE_EMBEDDING', '1' if OVERRIDE_EMBEDDING else '0') == '1' + OVERRIDE_ADAM = os.environ.get('BITSANDBYTES_OVERRIDE_ADAM', '1' if OVERRIDE_ADAM else '0') == '1' + OVERRIDE_ADAMW = os.environ.get('BITSANDBYTES_OVERRIDE_ADAMW', '1' if OVERRIDE_ADAMW else '0') == '1' except Exception as e: OVERRIDE_LINEAR = False OVERRIDE_EMBEDDING = False