From 7bcedca771335f2d62339d04ee6cd094cd5717f0 Mon Sep 17 00:00:00 2001
From: mrq
Date: Thu, 23 Feb 2023 07:02:06 +0000
Subject: [PATCH] I guess I can't easily toggle it outside of here, but it works

---
 codes/torch_intermediary/__init__.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/codes/torch_intermediary/__init__.py b/codes/torch_intermediary/__init__.py
index 0fb0cbca..5e050226 100644
--- a/codes/torch_intermediary/__init__.py
+++ b/codes/torch_intermediary/__init__.py
@@ -11,11 +11,25 @@ from torch.optim.adam import Adam
 from torch.optim.adamw import AdamW
 """
 
+"""
 OVERRIDE_LINEAR = False
 OVERRIDE_EMBEDDING = False
 OVERRIDE_ADAM = False # True
 OVERRIDE_ADAMW = False # True
+"""
+
 USE_STABLE_EMBEDDING = True
+try:
+    import bitsandbytes as bnb
+    OVERRIDE_LINEAR = False
+    OVERRIDE_EMBEDDING = False
+    OVERRIDE_ADAM = True
+    OVERRIDE_ADAMW = True
+except Exception as e:
+    OVERRIDE_LINEAR = False
+    OVERRIDE_EMBEDDING = False
+    OVERRIDE_ADAM = False
+    OVERRIDE_ADAMW = False
 
 if OVERRIDE_LINEAR:
     from bitsandbytes.nn import Linear8bitLt as Linear
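
For review context, below is a minimal sketch (not part of the patch) of how an intermediary module like this typically consumes these flags: probe for bitsandbytes once, then re-export either the 8-bit classes or the stock torch ones under common names. The HAVE_BNB name and the exact fallback layout are assumptions made here for illustration; only the `if OVERRIDE_LINEAR:` branch is visible in the diff above.

    # Sketch: guard the bitsandbytes import once, then export a single
    # implementation per name so callers can simply do
    #   from torch_intermediary import Linear, Embedding, Adam, AdamW
    try:
        import bitsandbytes as bnb  # same guarded import as in the patch
        HAVE_BNB = True             # hypothetical flag, for illustration
    except Exception:
        HAVE_BNB = False

    if HAVE_BNB:
        # 8-bit optimizers shipped by bitsandbytes (bnb.optim), matching
        # OVERRIDE_ADAM = True / OVERRIDE_ADAMW = True in the patch
        from bitsandbytes.optim import Adam8bit as Adam
        from bitsandbytes.optim import AdamW8bit as AdamW
    else:
        # fall back to the stock torch implementations
        from torch.optim.adam import Adam
        from torch.optim.adamw import AdamW

    # mirrors OVERRIDE_LINEAR = False / OVERRIDE_EMBEDDING = False above:
    # layers stay on torch even when bitsandbytes imports cleanly
    from torch.nn import Linear, Embedding

With this layout the toggle lives entirely inside the module, which is what the commit subject is getting at: downstream code imports the common names and cannot flip the overrides from outside.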