Remove unused imports, fix NotImplementedError

This commit is contained in:
Max Ryabinin 2022-06-30 18:14:20 +03:00
parent 4e60e7dc62
commit 33efe4a09f
9 changed files with 14 additions and 26 deletions

View File

@ -2,13 +2,13 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ctypes as ct
import os
import random
import math
import ctypes as ct
from typing import Tuple
import torch
from torch import Tensor
from typing import Tuple
lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
name2qmap = {}

View File

@ -7,7 +7,6 @@ import torch
from typing import Optional
from torch import Tensor
from torch.nn.parameter import Parameter
import torch.nn.functional as F
from bitsandbytes.optim import GlobalOptimManager

View File

@ -2,11 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from bitsandbytes.optim.optimizer import Optimizer1State
torch.optim.Adagrad
class Adagrad(Optimizer1State):
def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):

View File

@ -2,9 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from bitsandbytes.optim.optimizer import Optimizer2State
import bitsandbytes.functional as F
class AdamW(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,

View File

@ -12,7 +12,7 @@ class LARS(Optimizer1State):
weight_decay=0, nesterov=False, optim_bits=32, args=None,
min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
if momentum == 0:
raise NotImplementError(f'LARS without momentum is not supported!')
raise NotImplementedError(f'LARS without momentum is not supported!')
super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
@ -21,7 +21,7 @@ class LARS8bit(Optimizer1State):
weight_decay=0, nesterov=False, args=None,
min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
if momentum == 0:
raise NotImplementError(f'LARS without momentum is not supported!')
raise NotImplementedError(f'LARS without momentum is not supported!')
super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
@ -30,7 +30,7 @@ class LARS32bit(Optimizer1State):
weight_decay=0, nesterov=False, args=None,
min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
if momentum == 0:
raise NotImplementError(f'LARS without momentum is not supported!')
raise NotImplementedError(f'LARS without momentum is not supported!')
super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)

View File

@ -2,16 +2,15 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from bitsandbytes.optim.optimizer import Optimizer1State
class RMSprop(Optimizer1State):
def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, optim_bits=32, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
if alpha == 0:
raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!')
raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
if centered:
raise NotImplementError(f'Centered RMSprop is not supported!')
raise NotImplementedError(f'Centered RMSprop is not supported!')
super(RMSprop, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
@ -19,9 +18,9 @@ class RMSprop8bit(Optimizer1State):
def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
if alpha == 0:
raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!')
raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
if centered:
raise NotImplementError(f'Centered RMSprop is not supported!')
raise NotImplementedError(f'Centered RMSprop is not supported!')
super(RMSprop8bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
@ -30,7 +29,7 @@ class RMSprop32bit(Optimizer1State):
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
if alpha == 0:
raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!')
raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
if centered:
raise NotImplementError(f'Centered RMSprop is not supported!')
super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,

View File

@ -9,7 +9,7 @@ class SGD(Optimizer1State):
weight_decay=0, nesterov=False, optim_bits=32, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
if momentum == 0:
raise NotImplementError(f'SGD without momentum is not supported!')
raise NotImplementedError(f'SGD without momentum is not supported!')
super(SGD, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
@ -18,7 +18,7 @@ class SGD8bit(Optimizer1State):
weight_decay=0, nesterov=False, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
if momentum == 0:
raise NotImplementError(f'SGD without momentum is not supported!')
raise NotImplementedError(f'SGD without momentum is not supported!')
super(SGD8bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
@ -27,6 +27,6 @@ class SGD32bit(Optimizer1State):
weight_decay=0, nesterov=False, args=None,
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
if momentum == 0:
raise NotImplementError(f'SGD without momentum is not supported!')
raise NotImplementedError(f'SGD without momentum is not supported!')
super(SGD32bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)

View File

@ -6,10 +6,6 @@ import pytest
import torch
import bitsandbytes as bnb
from itertools import product
from bitsandbytes import functional as F
@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding'])
def test_embeddings(embcls):

View File

@ -7,7 +7,6 @@ import time
import shutil
import uuid
import pytest
import ctypes
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F