diff --git a/codes/torch_intermediary/__init__.py b/codes/torch_intermediary/__init__.py index 70556a2c..d220cd3b 100644 --- a/codes/torch_intermediary/__init__.py +++ b/codes/torch_intermediary/__init__.py @@ -18,11 +18,11 @@ OVERRIDE_ADAM = False # True OVERRIDE_ADAMW = False # True """ -USE_STABLE_EMBEDDING = True +USE_STABLE_EMBEDDING = False try: import bitsandbytes as bnb OVERRIDE_LINEAR = False - OVERRIDE_EMBEDDING = False + OVERRIDE_EMBEDDING = True OVERRIDE_ADAM = True OVERRIDE_ADAMW = True except Exception as e: @@ -40,7 +40,7 @@ if OVERRIDE_EMBEDDING: if USE_STABLE_EMBEDDING: from bitsandbytes.nn import StableEmbedding as Embedding else: - from bitsandbytes.nn import Embedding as Embedding + from bitsandbytes.nn.modules import Embedding as Embedding else: from torch.nn import Embedding