Simplify statements into equivalent, modern variants

Applied via pyupgrade --py37-plus. The changes include, for example, dropping explicit subclassing from object, replacing super(ThisClass, self) calls with plain super(), and converting old-style string formatting to f-strings.
Tom Aarsen 2022-10-27 13:14:13 +02:00
parent 1eec77d34c
commit 0b078403ee
17 changed files with 103 additions and 105 deletions
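As a quick illustration of the rewrite patterns in this commit, the hypothetical snippet below (the class and variable names are made up, not taken from the diff) shows each idiom before and after a pyupgrade --py37-plus pass:

# Before: Python 2 era spellings
class ExamplePooler(object):                     # explicit inheritance from object
    def __init__(self, rate):
        super(ExamplePooler, self).__init__()    # class and instance passed to super()
        self.msg = "rate: {0}".format(rate)      # positional str.format() call
        self.keys = set(["a", "b"])              # set() wrapped around a list literal

# After pyupgrade --py37-plus: equivalent modern forms
class ExamplePooler:                             # implicit object base class
    def __init__(self, rate):
        super().__init__()                       # zero-argument super()
        self.msg = f"rate: {rate}"               # f-string
        self.keys = {"a", "b"}                   # set literal

All of these rewrites are behavior-preserving on Python 3; the diff below is the same handful of patterns applied across the code base and its tests.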

View File

@@ -18,7 +18,7 @@ tensor = torch.Tensor
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
"""
-class GlobalOutlierPooler(object):
+class GlobalOutlierPooler:
_instance = None
def __init__(self):

View File

@@ -5,7 +5,7 @@ from pathlib import Path
from warnings import warn
-class CUDASetup(object):
+class CUDASetup:
_instance = None
def __init__(self):

View File

@@ -127,7 +127,7 @@ def evaluate_cuda_setup():
cuda_setup.add_log_entry("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!", is_warning=True)
return binary_name
-cuda_setup.add_log_entry((f"CUDA SETUP: CUDA runtime path found: {cudart_path}"))
+cuda_setup.add_log_entry(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")
cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda)
cuda_setup.add_log_entry(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")

View File

@@ -82,7 +82,7 @@ if COMPILED_WITH_CUDA:
)
-class CUBLAS_Context(object):
+class CUBLAS_Context:
_instance = None
def __init__(self):
@@ -112,7 +112,7 @@ class CUBLAS_Context(object):
return self.context[device.index]
-class Cusparse_Context(object):
+class Cusparse_Context:
_instance = None
def __init__(self):
@@ -1417,7 +1417,7 @@ def get_colrow_absmax(
return row_stats, col_stats, nnz_block_ptr
-class COOSparseTensor(object):
+class COOSparseTensor:
def __init__(self, rows, cols, nnz, rowidx, colidx, values):
assert rowidx.dtype == torch.int32
assert colidx.dtype == torch.int32
@@ -1434,7 +1434,7 @@ class COOSparseTensor(object):
self.values = values
-class CSRSparseTensor(object):
+class CSRSparseTensor:
def __init__(self, rows, cols, nnz, rowptr, colidx, values):
assert rowptr.dtype == torch.int32
assert colidx.dtype == torch.int32
@@ -1451,7 +1451,7 @@ class CSRSparseTensor(object):
self.values = values
-class CSCSparseTensor(object):
+class CSCSparseTensor:
def __init__(self, rows, cols, nnz, colptr, rowidx, values):
assert colptr.dtype == torch.int32
assert rowidx.dtype == torch.int32

View File

@@ -39,7 +39,7 @@ class StableEmbedding(torch.nn.Embedding):
sparse: bool = False,
_weight: Optional[Tensor] = None,
) -> None:
-super(StableEmbedding, self).__init__(
+super().__init__(
num_embeddings,
embedding_dim,
padding_idx,
@@ -96,7 +96,7 @@ class Embedding(torch.nn.Embedding):
sparse: bool = False,
_weight: Optional[Tensor] = None,
) -> None:
-super(Embedding, self).__init__(
+super().__init__(
num_embeddings,
embedding_dim,
padding_idx,
@@ -225,7 +225,7 @@ class Linear8bitLt(nn.Linear):
threshold=0.0,
index=None,
):
-super(Linear8bitLt, self).__init__(
+super().__init__(
input_features, output_features, bias
)
self.state = bnb.MatmulLtState()

View File

@@ -21,18 +21,18 @@ class Adagrad(Optimizer1State):
block_wise=True,
):
if not 0.0 <= lr:
-raise ValueError("Invalid learning rate: {}".format(lr))
+raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(
-"Invalid weight_decay value: {}".format(weight_decay)
+f"Invalid weight_decay value: {weight_decay}"
)
if not 0.0 <= eps:
-raise ValueError("Invalid epsilon value: {}".format(eps))
+raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0:
raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
raise ValueError("Lr Decay != 0.0 not supported!")
-super(Adagrad, self).__init__(
+super().__init__(
"adagrad",
params,
lr,
@@ -63,19 +63,19 @@ class Adagrad8bit(Optimizer1State):
block_wise=True,
):
if not 0.0 <= lr:
-raise ValueError("Invalid learning rate: {}".format(lr))
+raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(
-"Invalid weight_decay value: {}".format(weight_decay)
+f"Invalid weight_decay value: {weight_decay}"
)
if not 0.0 <= eps:
-raise ValueError("Invalid epsilon value: {}".format(eps))
+raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0:
raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
raise ValueError("Lr Decay != 0.0 not supported!")
assert block_wise
-super(Adagrad8bit, self).__init__(
+super().__init__(
"adagrad",
params,
lr,
@@ -106,18 +106,18 @@ class Adagrad32bit(Optimizer1State):
block_wise=True,
):
if not 0.0 <= lr:
-raise ValueError("Invalid learning rate: {}".format(lr))
+raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(
-"Invalid weight_decay value: {}".format(weight_decay)
+f"Invalid weight_decay value: {weight_decay}"
)
if not 0.0 <= eps:
-raise ValueError("Invalid epsilon value: {}".format(eps))
+raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0:
raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
raise ValueError("Lr Decay != 0.0 not supported!")
-super(Adagrad32bit, self).__init__(
+super().__init__(
"adagrad",
params,
lr,

View File

@@ -28,7 +28,7 @@ class Adam(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
-super(Adam, self).__init__(
+super().__init__(
"adam",
params,
lr,
@@ -57,7 +57,7 @@ class Adam8bit(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
-super(Adam8bit, self).__init__(
+super().__init__(
"adam",
params,
lr,
@@ -86,7 +86,7 @@ class Adam32bit(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
-super(Adam32bit, self).__init__(
+super().__init__(
"adam",
params,
lr,
@@ -146,7 +146,7 @@ class AnalysisAdam(torch.optim.Optimizer):
weight_decay=weight_decay,
amsgrad=amsgrad,
)
-super(AnalysisAdam, self).__init__(params, defaults)
+super().__init__(params, defaults)
self.analysis = bnb_analysis
self.savedir = savedir

View File

@@ -20,7 +20,7 @@ class AdamW(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
-super(AdamW, self).__init__(
+super().__init__(
"adam",
params,
lr,
@@ -49,7 +49,7 @@ class AdamW8bit(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
-super(AdamW8bit, self).__init__(
+super().__init__(
"adam",
params,
lr,
@@ -78,7 +78,7 @@ class AdamW32bit(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
-super(AdamW32bit, self).__init__(
+super().__init__(
"adam",
params,
lr,

View File

@@ -23,7 +23,7 @@ class LAMB(Optimizer2State):
block_wise=False,
max_unorm=1.0,
):
-super(LAMB, self).__init__(
+super().__init__(
"lamb",
params,
lr,
@@ -56,7 +56,7 @@ class LAMB8bit(Optimizer2State):
block_wise=False,
max_unorm=1.0,
):
-super(LAMB8bit, self).__init__(
+super().__init__(
"lamb",
params,
lr,
@@ -89,7 +89,7 @@ class LAMB32bit(Optimizer2State):
block_wise=False,
max_unorm=1.0,
):
-super(LAMB32bit, self).__init__(
+super().__init__(
"lamb",
params,
lr,

View File

@@ -27,7 +27,7 @@ class LARS(Optimizer1State):
raise NotImplementedError(
f"LARS without momentum is not supported!"
)
-super(LARS, self).__init__(
+super().__init__(
"lars",
params,
lr,
@@ -61,7 +61,7 @@ class LARS8bit(Optimizer1State):
raise NotImplementedError(
f"LARS without momentum is not supported!"
)
-super(LARS8bit, self).__init__(
+super().__init__(
"lars",
params,
lr,
@@ -95,7 +95,7 @@ class LARS32bit(Optimizer1State):
raise NotImplementedError(
f"LARS without momentum is not supported!"
)
-super(LARS32bit, self).__init__(
+super().__init__(
"lars",
params,
lr,
@@ -123,12 +123,12 @@ class PytorchLARS(Optimizer):
max_unorm=0.02,
):
if lr < 0.0:
-raise ValueError("Invalid learning rate: {}".format(lr))
+raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0:
-raise ValueError("Invalid momentum value: {}".format(momentum))
+raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0:
raise ValueError(
-"Invalid weight_decay value: {}".format(weight_decay)
+f"Invalid weight_decay value: {weight_decay}"
)
defaults = dict(
@@ -143,10 +143,10 @@ class PytorchLARS(Optimizer):
raise ValueError(
"Nesterov momentum requires a momentum and zero dampening"
)
-super(PytorchLARS, self).__init__(params, defaults)
+super().__init__(params, defaults)
def __setstate__(self, state):
-super(PytorchLARS, self).__setstate__(state)
+super().__setstate__(state)
for group in self.param_groups:
group.setdefault("nesterov", False)

View File

@@ -12,13 +12,13 @@ import torch
import bitsandbytes.functional as F
-class MockArgs(object):
+class MockArgs:
def __init__(self, initial_data):
for key in initial_data:
setattr(self, key, initial_data[key])
-class GlobalOptimManager(object):
+class GlobalOptimManager:
_instance = None
def __init__(self):
@@ -93,13 +93,12 @@ class GlobalOptimManager(object):
class Optimizer8bit(torch.optim.Optimizer):
def __init__(self, params, defaults, optim_bits=32):
-super(Optimizer8bit, self).__init__(params, defaults)
+super().__init__(params, defaults)
self.initialized = False
self.name2qmap = {}
self.mng = GlobalOptimManager.get_instance()
-self.non_castable_tensor_keys = set(
-[
+self.non_castable_tensor_keys = {
"qmap1",
"qmap2",
"max1",
@@ -112,8 +111,7 @@ class Optimizer8bit(torch.optim.Optimizer):
"absmax1",
"absmax2",
"unorm_vec",
-]
-)
+}
if optim_bits == 8:
self.fill_qmap()
@@ -123,7 +121,7 @@ class Optimizer8bit(torch.optim.Optimizer):
self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)
def __setstate__(self, state):
-super(Optimizer8bit, self).__setstate__(state)
+super().__setstate__(state)
def load_state_dict(self, state_dict):
r"""Loads the optimizer state.
@@ -155,8 +153,8 @@ class Optimizer8bit(torch.optim.Optimizer):
id_map = {
old_id: p
for old_id, p in zip(
-chain.from_iterable((g["params"] for g in saved_groups)),
-chain.from_iterable((g["params"] for g in groups)),
+chain.from_iterable(g["params"] for g in saved_groups),
+chain.from_iterable(g["params"] for g in groups),
)
}
@@ -310,9 +308,9 @@ class Optimizer2State(Optimizer8bit):
skip_zeros=False,
):
if not 0.0 <= lr:
-raise ValueError("Invalid learning rate: {}".format(lr))
+raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
-raise ValueError("Invalid epsilon value: {}".format(eps))
+raise ValueError(f"Invalid epsilon value: {eps}")
if isinstance(betas, str):
# format: '(beta1, beta2)'
betas = betas.replace("(", "").replace(")", "").strip().split(",")
@@ -324,10 +322,10 @@ class Optimizer2State(Optimizer8bit):
)
if not 0.0 <= weight_decay:
raise ValueError(
-"Invalid weight_decay value: {}".format(weight_decay)
+f"Invalid weight_decay value: {weight_decay}"
)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
-super(Optimizer2State, self).__init__(params, defaults, optim_bits)
+super().__init__(params, defaults, optim_bits)
if args is None:
args = {}
@@ -542,9 +540,9 @@ class Optimizer1State(Optimizer8bit):
skip_zeros=False,
):
if not 0.0 <= lr:
-raise ValueError("Invalid learning rate: {}".format(lr))
+raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
-raise ValueError("Invalid epsilon value: {}".format(eps))
+raise ValueError(f"Invalid epsilon value: {eps}")
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
raise ValueError(
@@ -552,10 +550,10 @@ class Optimizer1State(Optimizer8bit):
)
if not 0.0 <= weight_decay:
raise ValueError(
-"Invalid weight_decay value: {}".format(weight_decay)
+f"Invalid weight_decay value: {weight_decay}"
)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
-super(Optimizer1State, self).__init__(params, defaults, optim_bits)
+super().__init__(params, defaults, optim_bits)
if args is None:
args = {}

View File

@@ -27,7 +27,7 @@ class RMSprop(Optimizer1State):
)
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
-super(RMSprop, self).__init__(
+super().__init__(
"rmsprop",
params,
lr,
@@ -63,7 +63,7 @@ class RMSprop8bit(Optimizer1State):
)
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
-super(RMSprop8bit, self).__init__(
+super().__init__(
"rmsprop",
params,
lr,
@@ -100,7 +100,7 @@ class RMSprop32bit(Optimizer1State):
)
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
-super(RMSprop32bit, self).__init__(
+super().__init__(
"rmsprop",
params,
lr,

View File

@@ -22,7 +22,7 @@ class SGD(Optimizer1State):
):
if momentum == 0:
raise NotImplementedError(f"SGD without momentum is not supported!")
-super(SGD, self).__init__(
+super().__init__(
"momentum",
params,
lr,
@@ -53,7 +53,7 @@ class SGD8bit(Optimizer1State):
):
if momentum == 0:
raise NotImplementedError(f"SGD without momentum is not supported!")
-super(SGD8bit, self).__init__(
+super().__init__(
"momentum",
params,
lr,
@@ -84,7 +84,7 @@ class SGD32bit(Optimizer1State):
):
if momentum == 0:
raise NotImplementedError(f"SGD without momentum is not supported!")
-super(SGD32bit, self).__init__(
+super().__init__(
"momentum",
params,
lr,

View File

@@ -27,7 +27,7 @@ str_values = list(
)
)
names = [
-"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
+"dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(
*vals
)
for vals in str_values
@@ -286,7 +286,7 @@ str_values = list(
has_bias
)
)
-names = ["dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}_has_bias_{10}".format(*vals) for vals in str_values]
+names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values]
@pytest.mark.parametrize(

View File

@@ -26,7 +26,7 @@ def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0):
class FFN(torch.nn.Module):
def __init__(self, input_features, hidden_size, bias=True):
-super(FFN, self).__init__()
+super().__init__()
self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)
@@ -40,7 +40,7 @@ class FFN(torch.nn.Module):
return x
-class Timer(object):
+class Timer:
def __init__(self):
self.starts = {}
self.ends = {}
@@ -67,7 +67,7 @@ class Timer(object):
self.ends.pop(name)
if print_ms and name in self.agg:
-print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0))
+print(f"{name} took: {self.agg[name] / 1000.0:.5f}s")
return self.agg[name]
@@ -289,7 +289,7 @@ batched = [False, True]
values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched))
names = [
-"dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals)
+"dim1_{}_dim2_{}_quant_{}_batched_{}".format(*vals)
for vals in values_names
]
@@ -347,7 +347,7 @@ seq_dim = torch.randint(16, 256, size=(n,)).tolist()
transpose = [(False, False), (False, True), (True, False), (True, True)]
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
names = [
-"hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals)
+"hidden_dim_{}_batch_dim_{},transpose_{}_seq_dim_{}".format(*vals)
for vals in values
]
@@ -412,7 +412,7 @@ hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim))
names = [
-"seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values
+"seq_dim{}_hidden_dim{}_batch_dim{}".format(*vals) for vals in values
]
@@ -444,7 +444,7 @@ batch_dim = torch.randint(2, 16, size=(n,)).tolist()
transpose = [False, True]
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
names = [
-"seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals)
+"seq_dim={}_hidden_dim={}_batch_dim={}_transpose{}".format(*vals)
for vals in values
]
@@ -529,7 +529,7 @@ dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [
-"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals)
+"dim1_{}_dim2_{}_dim3_{}_dim4_{}_transpose_{}".format(*vals)
for vals in values
]
@@ -567,7 +567,7 @@ dim1 = torch.randint(1, 64, size=(n,)).tolist()
dim2 = torch.randint(32, 128, size=(n,)).tolist()
dim3 = torch.randint(32, 256, size=(n,)).tolist()
values = list(product(dim1, dim2, dim3))
-names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}_dim3_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
@@ -596,7 +596,7 @@ transpose = [False]
dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
-names = ["dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(*vals)for vals in values]
+names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values]
@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names)
@@ -678,7 +678,7 @@ ldb = [0]
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
names = [
-"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals)
+"dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}_ldb_{}".format(*vals)
for vals in values
]
@@ -726,7 +726,7 @@ dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims))
names = [
-"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals)
+"dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}".format(*vals)
for vals in values
]
@@ -784,7 +784,7 @@ values = [
# values = list(product(batch, seq, model, hidden))
names = [
-"batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
+"batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
]
@@ -952,7 +952,7 @@ dims = (2,)
formatB = ["col_turing", "col_ampere"]
has_bias = [True, False]
values = list(product(dim1, dim4, dims, formatB, has_bias))
-names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}_has_bias_{4}".format(*vals) for vals in values]
+names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
@@ -1002,7 +1002,7 @@ dim2 = [1 * 1024]
dims = (2,)
# ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dims))
-names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}_dims_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
@@ -1058,7 +1058,7 @@ dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
-names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
@@ -1105,7 +1105,7 @@ dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner))
-names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@@ -1149,7 +1149,7 @@ dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
values = list(zip(dim1, dim4, inner))
-names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@@ -1224,7 +1224,7 @@ inner = [12288 * 4, 4096 * 4]
dim4 = [12288, 4096]
values = list(zip(dim1, dim4, inner))
-names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
@@ -1290,7 +1290,7 @@ values = list(
product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)
names = [
-"dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
+"dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_{}".format(
*vals
)
for vals in values
@@ -1341,7 +1341,7 @@ a_order = ["col_turing"]
out_order = ["row"]
values = list(product(dim1, dim2, dtype, a_order, out_order))
names = [
-"dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals)
+"dim1_{}_dim2_{}_dtype_{}_orderA_{}_orderOut_{}".format(*vals)
for vals in values
]
@@ -1367,7 +1367,7 @@ dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
# dim2 = [5]
values = list(product(dim1, dim2))
-names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
@@ -1404,7 +1404,7 @@ dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
# dim2 = [11]
transposed_B = [False, True]
values = list(product(dim1, dim2, transposed_B))
-names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}_transposed_B_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
@@ -1485,7 +1485,7 @@ n = 2
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
values = list(product(dim1, dim2))
-names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
@@ -1550,7 +1550,7 @@ dtype = [torch.float16]
out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function))
names = [
-"dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values
+"dim1_{}_dim2_{}_dtype_{}_out_func_{}".format(*vals) for vals in values
]
@@ -1678,7 +1678,7 @@ dim2 = [2048]
# dim2 = [2]
dtype = [torch.int8]
values = list(product(dim1, dim2, dtype))
-names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}_dtype_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
@@ -1794,7 +1794,7 @@ values.append((batch_size, seqdim, 768, 4 * 768))
# values.append((batch_size, seqdim, 5140, 4*5140))
#values.append((batch_size, seqdim, 12288, 4*12288))
names = [
-"batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
+"batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
]

View File

@@ -7,7 +7,7 @@ from torch import nn
import bitsandbytes as bnb
-class MockArgs(object):
+class MockArgs:
def __init__(self, initial_data):
for key in initial_data:
setattr(self, key, initial_data[key])
@@ -15,7 +15,7 @@ class MockArgs(object):
class MLP8bit(torch.nn.Module):
def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
-super(MLP8bit, self).__init__()
+super().__init__()
self.fc1 = bnb.nn.Linear8bitLt(
dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
threshold=threshold
@@ -289,7 +289,7 @@ class LinearFunction(torch.autograd.Function):
class Linear8bit(nn.Module):
def __init__(self, input_features, output_features, bias=True, args=None):
-super(Linear8bit, self).__init__()
+super().__init__()
self.input_features = input_features
self.output_features = output_features
self.args = args
@@ -312,7 +312,7 @@ class Linear8bit(nn.Module):
threshold = [0.0, 3.0]
values = threshold
-names = ["threshold_{0}".format(vals) for vals in values]
+names = [f"threshold_{vals}" for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
@@ -378,7 +378,7 @@ def test_linear8bitlt_accumulated_gradient():
threshold = [0.0, 2.0]
values = threshold
-names = ["threshold_{0}".format(vals) for vals in values]
+names = [f"threshold_{vals}" for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)

View File

@@ -18,7 +18,7 @@ k = 20
def get_temp_dir():
-path = "/tmp/autoswap/{0}".format(str(uuid.uuid4()))
+path = f"/tmp/autoswap/{str(uuid.uuid4())}"
os.makedirs(path, exist_ok=True)
return path
@@ -116,7 +116,7 @@ gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
-"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
]
@@ -187,7 +187,7 @@ dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, gtype))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values]
+names = ["dim1_{}_dim2_{}_gtype_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
@@ -250,7 +250,7 @@ optimizer_names = [
]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
-"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
]
@@ -391,7 +391,7 @@ gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits))
names = [
-"dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals)
+"dim1_{}_dim2_{}_gtype_{}_optim_bits_{}".format(*vals)
for vals in values
]
@@ -495,7 +495,7 @@ gtype = [torch.float32, torch.float16]
optimizer_names = ["adam8bit_blockwise"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
-"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
]