diff --git a/codes/torch_intermediary/__init__.py b/codes/torch_intermediary/__init__.py index 0fb0cbca..5e050226 100644 --- a/codes/torch_intermediary/__init__.py +++ b/codes/torch_intermediary/__init__.py @@ -11,11 +11,25 @@ from torch.optim.adam import Adam from torch.optim.adamw import AdamW """ +""" OVERRIDE_LINEAR = False OVERRIDE_EMBEDDING = False OVERRIDE_ADAM = False # True OVERRIDE_ADAMW = False # True +""" + USE_STABLE_EMBEDDING = True +try: + import bitsandbytes as bnb + OVERRIDE_LINEAR = False + OVERRIDE_EMBEDDING = False + OVERRIDE_ADAM = True + OVERRIDE_ADAMW = True +except Exception as e: + OVERRIDE_LINEAR = False + OVERRIDE_EMBEDDING = False + OVERRIDE_ADAM = False + OVERRIDE_ADAMW = False if OVERRIDE_LINEAR: from bitsandbytes.nn import Linear8bitLt as Linear