17 lines
681 B
Python
17 lines
681 B
Python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from bitsandbytes.cextension import COMPILED_WITH_CUDA

# Re-export the optimizer classes at package level only when
# COMPILED_WITH_CUDA is truthy; when it is falsy none of the names below are
# bound in this namespace.
# NOTE(review): the flag name suggests these modules need the CUDA build of
# the native extension -- confirm against bitsandbytes.cextension.
if COMPILED_WITH_CUDA:
    from .adam import Adam, Adam8bit, Adam32bit
    from .adamw import AdamW, AdamW8bit, AdamW32bit
    from .sgd import SGD, SGD8bit, SGD32bit
    from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
    from .lamb import LAMB, LAMB8bit, LAMB32bit
    from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
    from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
    from .optimizer import GlobalOptimManager