I guess I can't easily toggle it outside of here, but it works
This commit is contained in:
parent
0ef8ab6872
commit
7bcedca771
|
@ -11,11 +11,25 @@ from torch.optim.adam import Adam
|
||||||
from torch.optim.adamw import AdamW
|
from torch.optim.adamw import AdamW
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
OVERRIDE_LINEAR = False
|
OVERRIDE_LINEAR = False
|
||||||
OVERRIDE_EMBEDDING = False
|
OVERRIDE_EMBEDDING = False
|
||||||
OVERRIDE_ADAM = False # True
|
OVERRIDE_ADAM = False # True
|
||||||
OVERRIDE_ADAMW = False # True
|
OVERRIDE_ADAMW = False # True
|
||||||
|
"""
|
||||||
|
|
||||||
USE_STABLE_EMBEDDING = True
|
USE_STABLE_EMBEDDING = True
|
||||||
|
try:
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
OVERRIDE_LINEAR = False
|
||||||
|
OVERRIDE_EMBEDDING = False
|
||||||
|
OVERRIDE_ADAM = True
|
||||||
|
OVERRIDE_ADAMW = True
|
||||||
|
except Exception as e:
|
||||||
|
OVERRIDE_LINEAR = False
|
||||||
|
OVERRIDE_EMBEDDING = False
|
||||||
|
OVERRIDE_ADAM = False
|
||||||
|
OVERRIDE_ADAMW = False
|
||||||
|
|
||||||
if OVERRIDE_LINEAR:
|
if OVERRIDE_LINEAR:
|
||||||
from bitsandbytes.nn import Linear8bitLt as Linear
|
from bitsandbytes.nn import Linear8bitLt as Linear
|
||||||
|
|
Loading…
Reference in New Issue
Block a user