Fix imports

This commit is contained in:
Max Ryabinin 2022-07-01 17:25:44 +03:00
parent 8258b4364a
commit e4cf33f2a3
3 changed files with 4 additions and 3 deletions

View File

@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from .nn import modules from .nn import modules
from cextension import COMPILED_WITH_CUDA from .cextension import COMPILED_WITH_CUDA
if COMPILED_WITH_CUDA: if COMPILED_WITH_CUDA:
from .optim import adam from .optim import adam

View File

@ -9,7 +9,7 @@ from typing import Tuple
import torch import torch
from torch import Tensor from torch import Tensor
from cextension import lib, COMPILED_WITH_CUDA from .cextension import lib, COMPILED_WITH_CUDA
name2qmap = {} name2qmap = {}

View File

@ -13,4 +13,5 @@ if COMPILED_WITH_CUDA:
from .lamb import LAMB, LAMB8bit, LAMB32bit from .lamb import LAMB, LAMB8bit, LAMB32bit
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .optimizer import GlobalOptimManager from .optimizer import GlobalOptimManager