working Embedding override

This commit is contained in:
mrq 2023-02-23 07:28:27 +00:00
parent 94aefa3e4c
commit 1433b7c0ea

View File

@ -18,11 +18,11 @@ OVERRIDE_ADAM = False # True
OVERRIDE_ADAMW = False # True
"""
USE_STABLE_EMBEDDING = True
USE_STABLE_EMBEDDING = False
try:
import bitsandbytes as bnb
OVERRIDE_LINEAR = False
OVERRIDE_EMBEDDING = False
OVERRIDE_EMBEDDING = True
OVERRIDE_ADAM = True
OVERRIDE_ADAMW = True
except Exception as e:
@ -40,7 +40,7 @@ if OVERRIDE_EMBEDDING:
if USE_STABLE_EMBEDDING:
from bitsandbytes.nn import StableEmbedding as Embedding
else:
from bitsandbytes.nn import Embedding as Embedding
from bitsandbytes.nn.modules import Embedding as Embedding
else:
from torch.nn import Embedding