diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 7ca017d..6e5b6ac 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,8 +3,13 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .autograd._functions import (MatmulLtState, bmm_cublas, matmul, - matmul_cublas, mm_cublas) +from .autograd._functions import ( + MatmulLtState, + bmm_cublas, + matmul, + matmul_cublas, + mm_cublas, +) from .cextension import COMPILED_WITH_CUDA from .nn import modules diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index a08b560..b56b2ee 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -111,7 +111,9 @@ class MatMul8bit(torch.autograd.Function): qgrad_output, S1 = F.vectorwise_quant( grad_output, dim=dims, quant_type=quant_type ) - qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type) + qA, S2 = F.vectorwise_quant( + A, dim=dims, quant_type=quant_type + ) igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output) grad_B = F.vectorwise_mm_dequant( igrad_B, @@ -146,7 +148,11 @@ class MatMul8bit(torch.autograd.Function): qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type) igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim)) grad_A = F.vectorwise_mm_dequant( - igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type + igrad_A, + S1, + S3.permute(permute_dim), + grad_output.dtype, + quant_type, ) return grad_A, grad_B, None, None, None @@ -211,7 +217,9 @@ class MatMul8bitLt(torch.autograd.Function): # 1. Quantize A if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold) + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( + A, threshold=state.threshold + ) if state.threshold > 0.0 and coo_tensorA is not None: if state.has_fp16_weights: @@ -225,7 +233,9 @@ class MatMul8bitLt(torch.autograd.Function): if state.CxB is None: # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions # we also need to convert it to the turing/ampere format - state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + state.CxB, state.SB = F.transform( + state.CB, to_order=formatB + ) # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half() # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None: # # generate outlier index and subB @@ -259,7 +269,13 @@ class MatMul8bitLt(torch.autograd.Function): if (state.is_training and not has_grad) or state.CxB is None: state.reset_grads() - CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B) + ( + CB, + state.CBt, + state.SCB, + state.SCBt, + coo_tensorB, + ) = F.double_quant(B) state.CxB, state.SB = F.transform(CB, to_order=formatB) else: has_grad = False @@ -277,7 +293,10 @@ class MatMul8bitLt(torch.autograd.Function): # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) state.subB = ( - (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half() + (outliers * state.SCB.view(-1, 1) / 127.0) + .t() + .contiguous() + .half() ) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 @@ -325,10 +344,14 @@ class MatMul8bitLt(torch.autograd.Function): SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state - assert state.has_fp16_weights, "Backprop only supported for fp16 weights." + assert ( + state.has_fp16_weights + ), "Backprop only supported for fp16 weights." if len(grad_output.shape) == 3: - grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous() + grad_output = grad_output.view( + -1, grad_output.shape[-1] + ).contiguous() grad_A = grad_B = None @@ -359,7 +382,11 @@ matmul = MatMul8bitLt.apply def matmul( - A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0 + A: tensor, + B: tensor, + out: tensor = None, + state: MatmulLtState = None, + threshold=0.0, ): state = state or MatmulLtState() if threshold > 0.0: diff --git a/bitsandbytes/cuda_setup.py b/bitsandbytes/cuda_setup.py index 8cc2c03..6e37606 100644 --- a/bitsandbytes/cuda_setup.py +++ b/bitsandbytes/cuda_setup.py @@ -1,7 +1,7 @@ """ -build is dependent on -- compute capability - - dependent on GPU family +extract factors the build is dependent on: +[X] compute capability + [ ] TODO: Q - What if we have multiple GPUs of different makes? - CUDA version - Software: - CPU-only: only CPU quantization functions (no optimizer, no matrix multipl) @@ -19,6 +19,8 @@ evaluation: """ import ctypes +import shlex +import subprocess from os import environ as env from pathlib import Path from typing import Set, Union @@ -26,10 +28,31 @@ from typing import Set, Union from .utils import print_err, warn_of_missing_prerequisite +def execute_and_return(command_string: str) -> Tuple[str, str]: + def _decode(subprocess_err_out_tuple): + return tuple( + to_decode.decode("UTF-8").strip() + for to_decode in subprocess_err_out_tuple + ) + + def execute_and_return_decoded_std_streams(command_string): + return _decode( + subprocess.Popen( + shlex.split(command_string), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ).communicate() + ) + + std_out, std_err = execute_and_return_decoded_std_streams() + return std_out, std_err + + def check_cuda_result(cuda, result_val): if result_val != 0: + # TODO: undefined name 'error_str' cuda.cuGetErrorString(result_val, ctypes.byref(error_str)) - print(f"Count not initialize CUDA - failure!") + print("Count not initialize CUDA - failure!") raise Exception("CUDA exception!") return result_val @@ -53,7 +76,9 @@ def get_compute_capability(): result = ctypes.c_int() device = ctypes.c_int() + # TODO: local variable 'context' is assigned to but never used context = ctypes.c_void_p() + # TODO: local variable 'error_str' is assigned to but never used error_str = ctypes.c_char_p() result = check_cuda_result(cuda, cuda.cuInit(0)) @@ -61,7 +86,9 @@ def get_compute_capability(): result = check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus))) ccs = [] for i in range(nGpus.value): - result = check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i)) + result = check_cuda_result( + cuda, cuda.cuDeviceGet(ctypes.byref(device), i) + ) result = check_cuda_result( cuda, cuda.cuDeviceComputeCapability( @@ -114,11 +141,15 @@ def get_cuda_runtime_lib_path( } - non_existent_directories if len(cuda_runtime_libs) > 1: - err_msg = f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.." + err_msg = ( + f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.." + ) raise FileNotFoundError(err_msg) elif len(cuda_runtime_libs) < 1: - err_msg = f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.." + err_msg = ( + f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.." + ) raise FileNotFoundError(err_msg) single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs)) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 2e86958..236ef39 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -17,14 +17,29 @@ if COMPILED_WITH_CUDA: """C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = {} str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16) - str2optimizer32bit["momentum"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16) - str2optimizer32bit["rmsprop"] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16) - str2optimizer32bit["adagrad"] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16) - str2optimizer32bit["lars"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16) + str2optimizer32bit["momentum"] = ( + lib.cmomentum32bit_g32, + lib.cmomentum32bit_g16, + ) + str2optimizer32bit["rmsprop"] = ( + lib.crmsprop32bit_g32, + lib.crmsprop32bit_g16, + ) + str2optimizer32bit["adagrad"] = ( + lib.cadagrad32bit_g32, + lib.cadagrad32bit_g16, + ) + str2optimizer32bit["lars"] = ( + lib.cmomentum32bit_g32, + lib.cmomentum32bit_g16, + ) str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16) str2optimizer8bit = {} - str2optimizer8bit["adam"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16) + str2optimizer8bit["adam"] = ( + lib.cadam_static_8bit_g32, + lib.cadam_static_8bit_g16, + ) str2optimizer8bit["momentum"] = ( lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16, @@ -33,7 +48,10 @@ if COMPILED_WITH_CUDA: lib.crmsprop_static_8bit_g32, lib.crmsprop_static_8bit_g16, ) - str2optimizer8bit["lamb"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16) + str2optimizer8bit["lamb"] = ( + lib.cadam_static_8bit_g32, + lib.cadam_static_8bit_g16, + ) str2optimizer8bit["lars"] = ( lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16, @@ -137,7 +155,9 @@ def create_dynamic_map(signed=True, n=7): if not signed: additional_items = 2 * additional_items for i in range(n): - fraction_items = 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1 + fraction_items = ( + 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1 + ) boundaries = torch.linspace(0.1, 1, fraction_items) means = (boundaries[:-1] + boundaries[1:]) / 2.0 data += ((10 ** (-(n - 1) + i)) * means).tolist() @@ -272,7 +292,13 @@ def get_transform_buffer( def nvidia_transform( - A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None + A, + to_order, + from_order="row", + out=None, + transpose=False, + state=None, + ld=None, ): if state is None: state = (A.shape, from_order) @@ -352,7 +378,11 @@ def estimate_quantiles( def quantize_blockwise( - A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None + A: Tensor, + code: Tensor = None, + absmax: Tensor = None, + rand=None, + out: Tensor = None, ) -> Tensor: """ Quantize tensor A in blocks of size 4096 values. @@ -629,7 +659,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: """ if out is None: out = torch.zeros_like(A, dtype=torch.float32) - lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) + lib.cdequantize( + get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()) + ) return out @@ -1005,7 +1037,9 @@ def histogram_scatter_add_2d( ) -def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): +def check_matmul( + A, B, out, transposed_A, transposed_B, expected_type=torch.int8 +): if not torch.cuda.is_initialized(): torch.cuda.init() if A.dtype != expected_type or B.dtype != expected_type: @@ -1097,7 +1131,11 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 def igemm( - A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False + A: Tensor, + B: Tensor, + out: Tensor = None, + transposed_A=False, + transposed_B=False, ): sout = check_matmul(A, B, out, transposed_A, transposed_B) if out is None: @@ -1193,7 +1231,11 @@ def igemm( def batched_igemm( - A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False + A: Tensor, + B: Tensor, + out: Tensor = None, + transposed_A=False, + transposed_B=False, ): if not len(A.shape) == 3 or not len(B.shape) == 3: raise ValueError( @@ -1392,9 +1434,13 @@ def mm_dequant( if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device) if new_row_stats is None: - new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) + new_row_stats = torch.empty( + out_shape[0], dtype=torch.float32, device=A.device + ) if new_col_stats is None: - new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) + new_col_stats = torch.empty( + out_shape[1], dtype=torch.float32, device=A.device + ) assert ( new_row_stats.shape[0] == row_stats.shape[0] ), f"{new_row_stats.shape} vs {row_stats.shape}" @@ -1440,13 +1486,13 @@ def get_colrow_absmax( col_tiles = (cols + 255) // 256 tiled_rows = ((rows + 15) // 16) * 16 if row_stats is None: - row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_( - -50000.0 - ) + row_stats = torch.empty( + (rows,), dtype=torch.float32, device=device + ).fill_(-50000.0) if col_stats is None: - col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_( - -50000.0 - ) + col_stats = torch.empty( + (cols,), dtype=torch.float32, device=device + ).fill_(-50000.0) if nnz_block_ptr is None and threshold > 0.0: nnz_block_ptr = torch.zeros( @@ -1462,7 +1508,13 @@ def get_colrow_absmax( prev_device = pre_call(A.device) lib.cget_col_row_stats( - ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols + ptrA, + ptrRowStats, + ptrColStats, + ptrNnzrows, + ct.c_float(threshold), + rows, + cols, ) post_call(prev_device) @@ -1526,7 +1578,9 @@ class CSCSparseTensor(object): def coo2csr(cooA): values, counts = torch.unique(cooA.rowidx, return_counts=True) values.add_(1) - rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device) + rowptr = torch.zeros( + (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device + ) rowptr.scatter_(index=values.long(), src=counts.int(), dim=0) rowptr.cumsum_(0) return CSRSparseTensor( @@ -1540,10 +1594,14 @@ def coo2csc(cooA): values = cooA.values[col2rowidx] colvalues, counts = torch.unique(val, return_counts=True) colvalues.add_(1) - colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device) + colptr = torch.zeros( + (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device + ) colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0) colptr.cumsum_(0) - return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values) + return CSCSparseTensor( + cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values + ) def coo_zeros(rows, cols, nnz, device, dtype=torch.half): @@ -1568,7 +1626,9 @@ def double_quant( rows = A.shape[0] if row_stats is None or col_stats is None: - row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) + row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( + A, threshold=threshold + ) if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) @@ -1663,7 +1723,13 @@ def get_special_format_str(): def transform( - A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None + A, + to_order, + from_order="row", + out=None, + transpose=False, + state=None, + ld=None, ): if state is None: state = (A.shape, from_order) @@ -1716,7 +1782,9 @@ def transform( def spmm_coo(cooA, B, out=None): if out is None: - out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype) + out = torch.empty( + (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype + ) nnz = cooA.nnz assert cooA.rowidx.numel() == nnz assert cooA.colidx.numel() == nnz @@ -1982,7 +2050,9 @@ def extract_outliers(A, SA, idx): assert formatA in ["col_turing", "col_ampere"] assert A.device.type == "cuda" - out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) + out = torch.zeros( + (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device + ) idx_size = ct.c_int32(idx.numel()) rows = ct.c_int32(shapeA[0]) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 9ce3ac8..454dba5 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -2,8 +2,19 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Set, - Tuple, TypeVar, Union, overload) +from typing import ( + Any, + Callable, + Dict, + Iterator, + Mapping, + Optional, + Set, + Tuple, + TypeVar, + Union, + overload, +) import torch import torch.nn.functional as F @@ -131,7 +142,12 @@ class Embedding(torch.nn.Embedding): class Int8Params(torch.nn.Parameter): def __new__( - cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None + cls, + data=None, + requires_grad=True, + has_fp16_weights=False, + CB=None, + SCB=None, ): cls.has_fp16_weights = has_fp16_weights cls.CB = None @@ -186,7 +202,9 @@ class Int8Params(torch.nn.Parameter): return self.cuda(device) else: new_param = Int8Params( - super().to(device=device, dtype=dtype, non_blocking=non_blocking), + super().to( + device=device, dtype=dtype, non_blocking=non_blocking + ), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights, ) @@ -206,7 +224,9 @@ class Linear8bitLt(nn.Linear): threshold=0.0, index=None, ): - super(Linear8bitLt, self).__init__(input_features, output_features, bias) + super(Linear8bitLt, self).__init__( + input_features, output_features, bias + ) self.state = bnb.MatmulLtState() self.index = index @@ -215,7 +235,9 @@ class Linear8bitLt(nn.Linear): if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True - self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights) + self.weight = Int8Params( + self.weight.data, has_fp16_weights=has_fp16_weights + ) def init_8bit_state(self): self.state.CB = self.weight.CB diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py index 43e3973..7e2f566 100644 --- a/bitsandbytes/optim/adagrad.py +++ b/bitsandbytes/optim/adagrad.py @@ -23,7 +23,9 @@ class Adagrad(Optimizer1State): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: @@ -63,7 +65,9 @@ class Adagrad8bit(Optimizer1State): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: @@ -104,7 +108,9 @@ class Adagrad32bit(Optimizer1State): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index 5cfaa28..3634971 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -140,7 +140,11 @@ class AnalysisAdam(torch.optim.Optimizer): savedir=None, ): defaults = dict( - lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, ) super(AnalysisAdam, self).__init__(params, defaults) self.analysis = bnb_analysis @@ -198,7 +202,9 @@ class AnalysisAdam(torch.optim.Optimizer): state["relerrors"] = torch.zeros( (256, 256), device=p_data_fp32.device ) - state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device) + state["counts"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) @@ -214,7 +220,9 @@ class AnalysisAdam(torch.optim.Optimizer): beta1, beta2 = group["betas"] bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] - step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + step_size = ( + group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + ) e = state["abserrors"] rele = state["relerrors"] counts = state["counts"] @@ -235,7 +243,10 @@ class AnalysisAdam(torch.optim.Optimizer): denom = exp_avg_sq.sqrt().add_(group["eps"]) update_fp32 = exp_avg / denom - if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000: + if ( + p_data_fp32.numel() <= 8192 + or p_data_fp32.numel() > 50000 * 1000 + ): # embedding layer or too small p_data_fp32 += -step_size * update_fp32 else: @@ -274,7 +285,9 @@ class AnalysisAdam(torch.optim.Optimizer): # 3. dequantize # Error will be calculated automatically! else: - raise ValueError(f"Invalid analysis value: {self.analysis}!") + raise ValueError( + f"Invalid analysis value: {self.analysis}!" + ) denom = state2.sqrt().add_(group["eps"]) update_8bit = state1 / denom @@ -296,7 +309,9 @@ class AnalysisAdam(torch.optim.Optimizer): if self.savedir != "" and state["step"] % 100 == 0: if not os.path.exists(self.savedir): os.makedirs(self.savedir) - shapestr = "_".join([str(dim) for dim in p_data_fp32.shape]) + shapestr = "_".join( + [str(dim) for dim in p_data_fp32.shape] + ) pathe = os.path.join( self.savedir, f"{p_id}_{shapestr}_abserr.pkl" ) diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py index c6cf5c6..8a89fb0 100644 --- a/bitsandbytes/optim/lars.py +++ b/bitsandbytes/optim/lars.py @@ -24,7 +24,9 @@ class LARS(Optimizer1State): max_unorm=0.02, ): if momentum == 0: - raise NotImplementedError(f"LARS without momentum is not supported!") + raise NotImplementedError( + f"LARS without momentum is not supported!" + ) super(LARS, self).__init__( "lars", params, @@ -56,7 +58,9 @@ class LARS8bit(Optimizer1State): max_unorm=0.02, ): if momentum == 0: - raise NotImplementedError(f"LARS without momentum is not supported!") + raise NotImplementedError( + f"LARS without momentum is not supported!" + ) super(LARS8bit, self).__init__( "lars", params, @@ -88,7 +92,9 @@ class LARS32bit(Optimizer1State): max_unorm=0.02, ): if momentum == 0: - raise NotImplementedError(f"LARS without momentum is not supported!") + raise NotImplementedError( + f"LARS without momentum is not supported!" + ) super(LARS32bit, self).__init__( "lars", params, @@ -121,7 +127,9 @@ class PytorchLARS(Optimizer): if momentum < 0.0: raise ValueError("Invalid momentum value: {}".format(momentum)) if weight_decay < 0.0: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) defaults = dict( lr=lr, @@ -132,7 +140,9 @@ class PytorchLARS(Optimizer): max_unorm=max_unorm, ) if nesterov and (momentum <= 0 or dampening != 0): - raise ValueError("Nesterov momentum requires a momentum and zero dampening") + raise ValueError( + "Nesterov momentum requires a momentum and zero dampening" + ) super(PytorchLARS, self).__init__(params, defaults) def __setstate__(self, state): diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index b942e34..4fb30cd 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -46,9 +46,13 @@ class GlobalOptimManager(object): for group_index, group in enumerate(param_groups): for p_index, p in enumerate(group["params"]): if id(p) in self.pid2config: - self.index2config[(group_index, p_index)] = self.pid2config[id(p)] + self.index2config[(group_index, p_index)] = self.pid2config[ + id(p) + ] - def override_config(self, parameters, key=None, value=None, key_value_dict=None): + def override_config( + self, parameters, key=None, value=None, key_value_dict=None + ): """ Overrides initial optimizer config for specific parameters. @@ -136,7 +140,8 @@ class Optimizer8bit(torch.optim.Optimizer): if len(groups) != len(saved_groups): raise ValueError( - "loaded state dict has a different number of " "parameter groups" + "loaded state dict has a different number of " + "parameter groups" ) param_lens = (len(g["params"]) for g in groups) saved_lens = (len(g["params"]) for g in saved_groups) @@ -192,7 +197,9 @@ class Optimizer8bit(torch.optim.Optimizer): new_group["params"] = group["params"] return new_group - param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + param_groups = [ + update_group(g, ng) for g, ng in zip(groups, saved_groups) + ] self.__setstate__({"state": state, "param_groups": param_groups}) def to_gpu(self): @@ -222,9 +229,9 @@ class Optimizer8bit(torch.optim.Optimizer): # found the matching parameter # init override self.mng.pid2config[id(p)] = config - self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[ - id(p) - ] + self.mng.index2config[ + (gindex, pindex) + ] = self.mng.pid2config[id(p)] found = True @torch.no_grad() @@ -280,7 +287,9 @@ class Optimizer8bit(torch.optim.Optimizer): raise NotImplementedError(f"init_state method needs to be overidden") def update_step(self, group, p, gindex, pindex): - raise NotImplementedError(f"The update_step method needs to be overidden") + raise NotImplementedError( + f"The update_step method needs to be overidden" + ) class Optimizer2State(Optimizer8bit): @@ -310,9 +319,13 @@ class Optimizer2State(Optimizer8bit): betas = [float(b) for b in betas] for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: - raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") + raise ValueError( + f"Invalid beta parameter at index {i}: {betas[i]}" + ) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(Optimizer2State, self).__init__(params, defaults, optim_bits) @@ -351,7 +364,9 @@ class Optimizer2State(Optimizer8bit): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): + if dtype == torch.float32 or ( + dtype == torch.uint8 and p.numel() < 4096 + ): state["state1"] = torch.zeros_like( p, memory_format=torch.preserve_format, @@ -368,8 +383,12 @@ class Optimizer2State(Optimizer8bit): if state["step"] == 0: if "dynamic" not in self.name2qmap: self.fill_qmap() - self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) - self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device) + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( + p.device + ) + self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to( + p.device + ) state["state1"] = torch.zeros_like( p, @@ -399,11 +418,15 @@ class Optimizer2State(Optimizer8bit): (blocks,), dtype=torch.float32, device=p.device ) else: - state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) state["new_max1"] = torch.zeros( (1,), dtype=torch.float32, device=p.device ) - state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["max2"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) state["new_max2"] = torch.zeros( (1,), dtype=torch.float32, device=p.device ) @@ -470,7 +493,9 @@ class Optimizer2State(Optimizer8bit): state["new_max2"], config["weight_decay"], gnorm_scale=gnorm_scale, - unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + unorm_vec=state["unorm_vec"] + if config["max_unorm"] > 0.0 + else None, max_unorm=config["max_unorm"], ) @@ -522,9 +547,13 @@ class Optimizer1State(Optimizer8bit): raise ValueError("Invalid epsilon value: {}".format(eps)) for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: - raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") + raise ValueError( + f"Invalid beta parameter at index {i}: {betas[i]}" + ) if not 0.0 <= weight_decay: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(Optimizer1State, self).__init__(params, defaults, optim_bits) @@ -563,7 +592,9 @@ class Optimizer1State(Optimizer8bit): state = self.state[p] state["step"] = 0 - if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): + if dtype == torch.float32 or ( + dtype == torch.uint8 and p.numel() < 4096 + ): state["state1"] = torch.zeros_like( p, memory_format=torch.preserve_format, @@ -574,7 +605,9 @@ class Optimizer1State(Optimizer8bit): if state["step"] == 0: if "dynamic" not in self.name2qmap: self.fill_qmap() - self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( + p.device + ) state["state1"] = torch.zeros_like( p, @@ -593,7 +626,9 @@ class Optimizer1State(Optimizer8bit): (blocks,), dtype=torch.float32, device=p.device ) else: - state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) state["new_max1"] = torch.zeros( (1,), dtype=torch.float32, device=p.device ) diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py index 679f783..7ddb12c 100644 --- a/bitsandbytes/optim/rmsprop.py +++ b/bitsandbytes/optim/rmsprop.py @@ -22,7 +22,9 @@ class RMSprop(Optimizer1State): block_wise=True, ): if alpha == 0: - raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!") + raise NotImplementedError( + f"RMSprop with alpha==0.0 is not supported!" + ) if centered: raise NotImplementedError(f"Centered RMSprop is not supported!") super(RMSprop, self).__init__( @@ -56,7 +58,9 @@ class RMSprop8bit(Optimizer1State): block_wise=True, ): if alpha == 0: - raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!") + raise NotImplementedError( + f"RMSprop with alpha==0.0 is not supported!" + ) if centered: raise NotImplementedError(f"Centered RMSprop is not supported!") super(RMSprop8bit, self).__init__( @@ -91,7 +95,9 @@ class RMSprop32bit(Optimizer1State): ): if alpha == 0: - raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!") + raise NotImplementedError( + f"RMSprop with alpha==0.0 is not supported!" + ) if centered: raise NotImplementedError(f"Centered RMSprop is not supported!") super(RMSprop32bit, self).__init__( diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 6797407..8a9fc0e 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -1,6 +1,6 @@ +import sys import shlex import subprocess -import sys def execute_and_return(command_string: str) -> Tuple[str, str]: diff --git a/quicktest.py b/quicktest.py index 29d045d..0fcda64 100644 --- a/quicktest.py +++ b/quicktest.py @@ -14,23 +14,31 @@ def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb): torch.int8 ) elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to( - torch.int8 - ) - B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) + A = torch.randint( + -128, 127, size=(dim1, dim2, dim3), device="cuda" + ).to(torch.int8) + B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to( + torch.int8 + ) C1 = torch.matmul(A.float(), B.t().float()) A2, SA = F.transform(A, "col32") B2, SB = F.transform(B, "colx") if dims == 2: C2, SC = F.transform( - torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"), + torch.zeros( + A.shape[0], B.shape[0], dtype=torch.int32, device="cuda" + ), "col32", ) else: C2, SC = F.transform( torch.zeros( - A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device="cuda" + A.shape[0], + A.shape[1], + B.shape[0], + dtype=torch.int32, + device="cuda", ), "col32", ) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 9cd01a9..fc7a0e1 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -18,9 +18,13 @@ req_grad_str = ["FF", "TF", "TT", "FT"] transpose = [(False, False), (False, True), (True, True), (True, False)] str_transpose = ["FF", "FT", "TT", "TF"] dtype = [torch.float32, torch.float16] -values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)) +values = list( + product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose) +) str_values = list( - product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose) + product( + dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose + ) ) names = [ "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format( @@ -31,7 +35,9 @@ names = [ @pytest.mark.parametrize( - "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names + "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", + values, + ids=names, ) def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): dim2 = dim2 - (dim2 % 16) @@ -79,7 +85,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch = torch.nn.functional.mse_loss( + out_torch, target + ).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -87,25 +95,35 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_allclose( + gradA1, gradA2, atol=0.015, rtol=0.1 + ) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) assert (idx == 0).sum().item() < n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() < n * 0.02 - torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3) + torch.testing.assert_allclose( + gradB1, gradB2, atol=0.18, rtol=0.3 + ) # batched matrix multiply if funcs[0] in [torch.bmm, torch.matmul]: A = torch.randn( - size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0] + size=(dim1, dim2, dim3), + device="cuda", + requires_grad=req_grad[0], ) B = torch.randn( - size=(dim1, dim3, dim4), device="cuda", requires_grad=req_grad[1] + size=(dim1, dim3, dim4), + device="cuda", + requires_grad=req_grad[1], ) target = torch.randn( - size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1] + size=(dim1, dim2, dim4), + device="cuda", + requires_grad=req_grad[1], ) torch.nn.init.xavier_uniform_(B) @@ -115,7 +133,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): n = out_bnb.numel() idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) assert (idx == 0).sum().item() < n * 0.01 - torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2) + torch.testing.assert_allclose( + out_bnb, out_torch, atol=0.027, rtol=0.2 + ) if any(req_grad): out_bnb.data.copy_(out_torch) @@ -127,7 +147,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch = torch.nn.functional.mse_loss( + out_torch, target + ).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -135,7 +157,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_allclose( + gradA1, gradA2, atol=0.015, rtol=0.1 + ) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) @@ -146,12 +170,16 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): if funcs[0] in [torch.matmul]: dim1 = dim1 - (dim1 % 16) A = torch.randn( - size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0] + size=(dim1, dim2, dim3), + device="cuda", + requires_grad=req_grad[0], ) dimB = (dim4, dim3) if transpose[1] else (dim3, dim4) B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1]) target = torch.randn( - size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1] + size=(dim1, dim2, dim4), + device="cuda", + requires_grad=req_grad[1], ) torch.nn.init.xavier_uniform_(B) @@ -178,7 +206,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch = torch.nn.functional.mse_loss( + out_torch, target + ).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -186,7 +216,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_allclose( + gradA1, gradA2, atol=0.015, rtol=0.1 + ) if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) @@ -258,7 +290,16 @@ names = [ ids=names, ) def test_matmullt( - dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights + dim1, + dim2, + dim3, + dim4, + funcs, + dtype, + req_grad, + transpose, + decomp, + has_fp16_weights, ): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) @@ -278,7 +319,10 @@ def test_matmullt( size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype ) target = torch.randn( - size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype + size=(dim2, dim4), + device="cuda", + requires_grad=req_grad[1], + dtype=dtype, ) torch.nn.init.xavier_uniform_(B) B2 = B.clone() @@ -317,14 +361,18 @@ def test_matmullt( if any(req_grad): out_bnb.data.copy_(out_torch) torch.cuda.synchronize() - loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb = torch.nn.functional.mse_loss( + out_bnb, target + ).mean() loss_bnb.backward() gradA1 = A.grad gradB1 = B.grad A.grad = None B.grad = None - loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean() + loss_torch = torch.nn.functional.mse_loss( + out_torch, target + ).mean() loss_torch.backward() gradA2 = A.grad gradB2 = B.grad @@ -332,7 +380,9 @@ def test_matmullt( B.grad = None if req_grad[0]: - torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1) + torch.testing.assert_allclose( + gradA1, gradA2, atol=0.015, rtol=0.1 + ) if req_grad[1]: n = gradB1.numel() assert torch.abs(gradB1).sum() > 0.0 @@ -341,4 +391,6 @@ def test_matmullt( assert (idx == 0).sum().item() < n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() < n * 0.02 - torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3) + torch.testing.assert_allclose( + gradB1, gradB2, atol=0.18, rtol=0.3 + ) diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index d45354f..5a58be4 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -3,8 +3,12 @@ from typing import List, NamedTuple import pytest -from bitsandbytes.cuda_setup import (CUDA_RUNTIME_LIB, evaluate_cuda_setup, - get_cuda_runtime_lib_path, tokenize_paths) +from bitsandbytes.cuda_setup import ( + CUDA_RUNTIME_LIB, + evaluate_cuda_setup, + get_cuda_runtime_lib_path, + tokenize_paths, +) class InputAndExpectedOutput(NamedTuple): @@ -13,11 +17,26 @@ class InputAndExpectedOutput(NamedTuple): HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [ - (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"), - (f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"), - (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:", f"dir/with/{CUDA_RUNTIME_LIB}"), - (f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"), - (f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir", f"dir/with/{CUDA_RUNTIME_LIB}"), + ( + f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", + f"dir/with/{CUDA_RUNTIME_LIB}", + ), + ( + f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", + f"dir/with/{CUDA_RUNTIME_LIB}", + ), + ( + f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:", + f"dir/with/{CUDA_RUNTIME_LIB}", + ), + ( + f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}", + f"dir/with/{CUDA_RUNTIME_LIB}", + ), + ( + f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir", + f"dir/with/{CUDA_RUNTIME_LIB}", + ), ( f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so", f"dir/with/{CUDA_RUNTIME_LIB}", diff --git a/tests/test_functional.py b/tests/test_functional.py index 11cd198..ab7d672 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -86,7 +86,9 @@ def teardown(): pass -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"]) +@pytest.mark.parametrize( + "dtype", [torch.float32, torch.float16], ids=["float", "half"] +) def test_estimate_quantiles(dtype): A = torch.rand(1024, 1024, device="cuda") A = A.to(dtype) @@ -190,7 +192,9 @@ def test_dynamic_blockwise_stochastic_quantization(): ) -@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"]) +@pytest.mark.parametrize( + "gtype", [torch.float32, torch.float16], ids=["float", "half"] +) def test_percentile_clipping(gtype): gnorm_vec1 = torch.zeros(100, device="cuda") gnorm_vec2 = torch.zeros(100, device="cuda") @@ -270,7 +274,13 @@ def mean(xx): dim1 = [1024 * 2] dim2 = [1024 * 16] methods = [ - (lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant) + ( + lambda x, dim: quant(x), + lambda x, dim: quant(x), + dequant, + dequant, + mm_dequant, + ) ] methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant)) # methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant)) @@ -279,11 +289,14 @@ batched = [False, True] values = list(product(dim1, dim2, methods, batched)) values_names = list(product(dim1, dim2, method_names, batched)) names = [ - "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals) for vals in values_names + "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals) + for vals in values_names ] -@pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names) +@pytest.mark.parametrize( + "dim1, dim2, quant_methods, batched", values, ids=names +) def test_approx_igemm(dim1, dim2, quant_methods, batched): dim1 = dim1 - (dim1 % 32) dim2 = dim2 - (dim2 % 32) @@ -339,14 +352,18 @@ names = [ ] -@pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names) +@pytest.mark.parametrize( + "hidden_dim, batch_dim, transpose, seq_dim", values, ids=names +) def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): hidden_dim = hidden_dim - (hidden_dim % 32) batch_dim = batch_dim - (batch_dim % 16) seq_dim = seq_dim - (seq_dim % 16) for i in range(k): shapeA = ( - (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim) + (batch_dim, hidden_dim) + if not transpose[0] + else (hidden_dim, batch_dim) ) shapeB = ( (32 * random.randint(1, 4), hidden_dim) @@ -394,7 +411,9 @@ seq_dim = torch.randint(32, 512, size=(n,)).tolist() hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist() batch_dim = torch.randint(2, 16, size=(n,)).tolist() values = list(product(seq_dim, hidden_dim, batch_dim)) -names = ["seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values] +names = [ + "seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values +] @pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names) @@ -406,11 +425,13 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): A = torch.randint( -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda" ).to(torch.int8) - B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to( - torch.int8 - ) + B = torch.randint( + -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda" + ).to(torch.int8) out2 = torch.einsum("bsi, bso->io", A.float(), B.float()) - iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device) + iout = torch.empty( + A.shape[2], B.shape[2], dtype=torch.int32, device=A.device + ) out = F.igemm(A, B, out=iout) torch.testing.assert_allclose(out.float(), out2) @@ -428,7 +449,9 @@ names = [ ] -@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names) +@pytest.mark.parametrize( + "seq_dim, hidden_dim, batch_dim, transpose", values, ids=names +) def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): def min_max(x): maxA = torch.amax(x, dim=2, keepdim=True) @@ -444,7 +467,9 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): errs2 = [] relerrs2 = [] for i in range(k): - A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda") + A = torch.normal( + 0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda" + ) if transpose: B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda") else: @@ -504,7 +529,8 @@ dim4 = torch.randint(32, 256, size=(n,)).tolist() transpose = [(False, False), (True, False), (False, True), (True, True)] values = list(product(dim1, dim2, dim3, dim4, transpose)) names = [ - "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals) for vals in values + "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals) + for vals in values ] @@ -529,7 +555,9 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float()) out = F.igemm(A.permute([0, 2, 1]), B) elif transpose[0] and transpose[1]: - out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()) + out2 = torch.bmm( + A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float() + ) out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) torch.testing.assert_allclose(out.float(), out2.float()) @@ -563,7 +591,9 @@ a_order = ["row"] out_order = ["col", "row", "col32"] transpose = [False] dims = [2, 3] -values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)) +values = list( + product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose) +) names = [ "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format( @@ -574,9 +604,13 @@ names = [ @pytest.mark.parametrize( - "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names + "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", + values, + ids=names, ) -def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): +def test_nvidia_transform( + dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose +): if dims == 3 and out_order != "col32": return if dtype == torch.int32 and out_order != "col32": @@ -586,7 +620,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans if dims == 2: A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype) elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype) + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to( + dtype + ) out, S = F.nvidia_transform(A, to_order=orderOut) @@ -598,7 +634,11 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans if dims == 2: n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32))) elif dims == 3: - n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32))) + n = ( + A.shape[0] + * A.shape[1] + * (A.shape[2] + (32 - (A.shape[2] % 32))) + ) assert out.numel() == n elif orderOut == "col_turing": # 32 col 8 row tiles @@ -613,7 +653,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans j = col coltile = (col // 32) + (1 if col % 32 != 0 else 0) - rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile + rowtile = ( + (row // 8) + (1 if row % 8 != 0 else 0) + ) * total_coltile offset = 32 * 8 * (rowtile + coltile) col2 = col % 32 row2 = (row % 8) * 32 @@ -624,7 +666,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) if orderOut == "col32": - out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S) + out2, S = F.nvidia_transform( + out, from_order=orderOut, to_order="row", state=S + ) torch.testing.assert_allclose(A, out2) @@ -657,10 +701,12 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): torch.int8 ) elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to( - torch.int8 - ) - B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) + A = torch.randint( + -128, 127, size=(dim1, dim2, dim3), device="cuda" + ).to(torch.int8) + B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to( + torch.int8 + ) C1 = torch.matmul(A.float(), B.t().float()) A2, SA = F.transform(A, "col32") @@ -670,7 +716,9 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): torch.testing.assert_allclose(C1, C3.float()) # transpose - B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to( + torch.int8 + ) C1 = torch.matmul(A.float(), B.float()) B2t, SBt = F.transform(B, "col_turing", transpose=True) @@ -688,7 +736,8 @@ dims = (2,) # ldb = list(range(256, 1*1024, 256)) values = list(product(dim1, dim2, dim3, dim4, dims)) names = [ - "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals) for vals in values + "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals) + for vals in values ] @@ -699,7 +748,9 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): if dims == 2: A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half() elif dims == 3: - A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half() + A = torch.normal( + 0, 0.5, size=(dim1, dim2, dim3), device="cuda" + ).half() B = torch.randn((dim4, dim3), device="cuda").half() torch.nn.init.xavier_uniform_(B) C1 = torch.matmul(A, B.t()) @@ -742,7 +793,9 @@ values = [ # values = list(product(batch, seq, model, hidden)) -names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values] +names = [ + "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values +] @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) @@ -909,7 +962,9 @@ dims = (2,) # ldb = list(range(256, 1*1024, 256)) formatB = ["col_turing", "col_ampere"] values = list(product(dim1, dim4, dims, formatB)) -names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values] +names = [ + "dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values +] @pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names) @@ -992,7 +1047,9 @@ def test_colrow_absmax(dim1, dim2, dims): torch.testing.assert_allclose(row_stats1_trunc, row_stats2) torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2) - row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0) + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( + A, threshold=0.0 + ) torch.testing.assert_allclose(col_stats1, col_stats2) torch.testing.assert_allclose(row_stats1, row_stats2) @@ -1023,8 +1080,12 @@ def test_double_quant(dim1, dim2): torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0) n = CAt.numel() - num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item() - num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item() + num_not_close_rows = ( + (torch.isclose(CA, out_row1, atol=1) == 0).sum().item() + ) + num_not_close_cols = ( + (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item() + ) # allow for 1:500 error due to rounding differences min_error = 1 / 500 @@ -1123,7 +1184,9 @@ def test_igemmlt_row_scale(dim1, dim4, inner): c = 10.0 * inner * scale row_scale = torch.ones_like(maxA) / c - outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) + outC32, SC = F.igemmlt( + A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale + ) C3, S = F.nvidia_transform(outC32, "row", state=SC) maxval = torch.abs(C3).max() if maxval == 127: @@ -1204,7 +1267,9 @@ def test_row_scale_bench(dim1, dim4, inner): torch.cuda.synchronize() t0 = time.time() for i in range(k): - outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) + outC32, SC = F.igemmlt( + A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale + ) torch.cuda.synchronize() print("row-wise", time.time() - t0) @@ -1230,7 +1295,9 @@ a_order = ["row"] out_order = ["col32", "col_turing", "col_ampere"] transpose = [False, True] dims = [2] -values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)) +values = list( + product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose) +) names = [ "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format( *vals @@ -1240,14 +1307,20 @@ names = [ @pytest.mark.parametrize( - "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names + "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", + values, + ids=names, ) def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): for i in range(k): if dims == 2: - A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype) + A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to( + dtype + ) elif dims == 3: - A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype) + A = torch.randint( + 10, 99, size=(dim1, dim2, dim3), device="cuda" + ).to(dtype) A.view(-1)[-1] = -1 if transpose: @@ -1282,7 +1355,9 @@ names = [ ] -@pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names) +@pytest.mark.parametrize( + "dim1, dim2, dtype, orderA, orderOut", values, ids=names +) def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut): for i in range(1): A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype) @@ -1332,17 +1407,23 @@ def test_coo_double_quant(dim1, dim2): idx = torch.abs(A) >= threshold CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant( + A, threshold=threshold + ) if coo_tensor is not None: A1 = A * idx A2 = torch.zeros_like(A) - A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values + A2[ + coo_tensor.rowidx.long(), coo_tensor.colidx.long() + ] = coo_tensor.values torch.testing.assert_allclose(A1, A2) A1 = A * (idx == 0) A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() - torch.testing.assert_allclose(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) + torch.testing.assert_allclose( + A * (idx == 0), A2, rtol=0.05, atol=1.5e-2 + ) n = 2 @@ -1454,7 +1535,9 @@ def test_integrated_sparse_decomp(dim1, dim2): out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant( + A, threshold=threshold + ) C32A, SA = F.transform(CA, "col32") out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) @@ -1494,7 +1577,9 @@ dim2 = [12288] dtype = [torch.float16] out_function = ["zeros", "ones"] values = list(product(dim1, dim2, dtype, out_function)) -names = ["dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values] +names = [ + "dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values +] @pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names) @@ -1536,7 +1621,9 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): std = out1.std() out1 /= std out2 /= std - assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count) + assert_all_approx_close( + out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count + ) # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count) idx_col = torch.randint(0, A2.shape[-1], size=(15,)) @@ -1734,7 +1821,9 @@ values.append((batch_size, seqdim, 768, 4 * 768)) # values.append((batch_size, seqdim, 4096, 4*4096)) # values.append((batch_size, seqdim, 5140, 4*5140)) # values.append((batch_size, seqdim, 12288, 4*12288)) -names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values] +names = [ + "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values +] @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) diff --git a/tests/test_modules.py b/tests/test_modules.py index 6b8d641..7faadb8 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -48,7 +48,9 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): class LinearFunction(torch.autograd.Function): @staticmethod def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0): - round_func = LinearFunction.round_stoachastic if stochastic else torch.round + round_func = ( + LinearFunction.round_stoachastic if stochastic else torch.round + ) norm = math.sqrt(math.pi) / math.sqrt(2.0) # std = torch.abs(x).mean()*norm std = torch.std(x) @@ -116,7 +118,9 @@ class LinearFunction(torch.autograd.Function): return x.to(dtype) def get_8bit_linear(x, stochastic=False): - round_func = LinearFunction.round_stoachastic if stochastic else torch.round + round_func = ( + LinearFunction.round_stoachastic if stochastic else torch.round + ) max1 = torch.abs(x).max() x = x / max1 * 127 x = round_func(x) / 127 * max1 @@ -125,7 +129,9 @@ class LinearFunction(torch.autograd.Function): @staticmethod def get_8bit_vector_wise(x, dim, stochastic=False): - round_func = LinearFunction.round_stoachastic if stochastic else torch.round + round_func = ( + LinearFunction.round_stoachastic if stochastic else torch.round + ) max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) max1[max1 == 0] = 1.0 x = (x * 127) / max1 @@ -209,7 +215,9 @@ class LinearFunction(torch.autograd.Function): weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2) outputq = bnb.functional.igemm(x8, weight8.t()) - output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type) + output = LinearFunction.dequant( + outputq, S1, S2, x.dtype, args.quant_type + ) # if torch.rand(1) < 0.01: # output32 = torch.matmul(x, weight.t()) # err = torch.abs(output-output32).float() @@ -261,7 +269,9 @@ class LinearFunction(torch.autograd.Function): grad_weight8, S1, S2, grad_output.dtype, args.quant_type ) - grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2) + grad_output8, S1 = LinearFunction.quant( + grad_output, args.quant_type, dim=2 + ) weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0) grad_input8 = bnb.functional.igemm(grad_output8, weight8) grad_input = LinearFunction.dequant( @@ -338,8 +348,12 @@ def test_linear8bit(): loss2.backward() loss3.backward() - assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2) - assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2) + assert_all_approx_close( + l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2 + ) + assert_all_approx_close( + l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2 + ) assert_all_approx_close( l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2 ) @@ -388,7 +402,9 @@ def test_linear8bitlt_accumulated_gradient(): l1 = torch.nn.Sequential( *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)] ) - l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]) + l2 = torch.nn.Sequential( + *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)] + ) l2[0].weight = torch.nn.Parameter(l1[0].weight.clone()) l2[0].bias = torch.nn.Parameter(l1[0].bias.clone()) l2[1].weight = torch.nn.Parameter(l1[1].weight.clone()) @@ -462,7 +478,11 @@ def test_linear8bitlt_no_fp16_weights(threshold): if threshold > 0: assert mlp.fc2.state.idx is not None - mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half() + mlp = ( + MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) + .cuda() + .half() + ) assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 @@ -475,7 +495,11 @@ def test_linear8bitlt_no_fp16_weights(threshold): if threshold > 0: assert mlp.fc2.state.idx is not None - mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda() + mlp = ( + MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) + .half() + .cuda() + ) for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() @@ -488,7 +512,11 @@ def test_linear8bitlt_no_fp16_weights(threshold): assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 - mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to("cuda") + mlp = ( + MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) + .half() + .to("cuda") + ) for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() diff --git a/tests/test_optim.py b/tests/test_optim.py index b84425e..8e12761 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -103,20 +103,26 @@ str2statenames["adam8bit_blockwise"] = [ ("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2"), ] -str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] +str2statenames["momentum8bit"] = [ + ("momentum_buffer", "state1", "qmap1", "max1") +] str2statenames["momentum8bit_blockwise"] = [ ("momentum_buffer", "state1", "qmap1", "absmax1") ] str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")] -str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")] +str2statenames["rmsprop8bit_blockwise"] = [ + ("square_avg", "state1", "qmap1", "absmax1") +] dim1 = [1024] dim2 = [32, 1024, 4097, 1] gtype = [torch.float32, torch.float16] optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lamb"] values = list(product(dim1, dim2, gtype, optimizer_names)) -names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values] +names = [ + "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values +] @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @@ -203,9 +209,13 @@ def test_global_config(dim1, dim2, gtype): eps = 1e-8 bnb.optim.GlobalOptimManager.get_instance().initialize() - bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8) + bnb.optim.GlobalOptimManager.get_instance().override_config( + p3, "optim_bits", 8 + ) - bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) + bnb.optim.GlobalOptimManager.get_instance().register_parameters( + [p1, p2, p3] + ) p1 = p1.cuda() p2 = p2.cuda() p3 = p3.cuda() @@ -245,7 +255,9 @@ optimizer_names = [ "rmsprop8bit_blockwise", ] values = list(product(dim1, dim2, gtype, optimizer_names)) -names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values] +names = [ + "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values +] @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @@ -329,8 +341,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) rm_path(path) - torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2]) - torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap]) + torch.testing.assert_allclose( + raws1cpy, bnb_optimizer.state[p2][name2] + ) + torch.testing.assert_allclose( + qmap1, bnb_optimizer.state[p2][qmap] + ) if "blockwise" in optim_name: s1 = F.dequantize_blockwise( @@ -349,12 +365,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): num_not_close = ( torch.isclose( - torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol + torch_optimizer.state[p1][name1], + s1, + atol=atol, + rtol=rtol, ) == 0 ) assert num_not_close.sum().item() < 20 - torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol) + torch.testing.assert_allclose( + p1, p2.float(), atol=patol, rtol=prtol + ) # the parameters diverge quickly. Here we keep them close # together so we can test against the Adam error @@ -375,7 +396,10 @@ dim2 = [32, 1024, 4097] gtype = [torch.float32] optim_bits = [32, 8] values = list(product(dim1, dim2, gtype, optim_bits)) -names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) for vals in values] +names = [ + "dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) + for vals in values +] @pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names) @@ -391,7 +415,12 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): p2 = p1.clone() adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits) adam2 = bnb.optim.Adam( - [p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5 + [p2], + lr, + (beta1, beta2), + eps, + optim_bits=optim_bits, + percentile_clipping=5, ) gnorm_vec = torch.zeros(100).cuda() @@ -399,7 +428,9 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): for i in range(50): step += 1 - g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i) + g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + ( + 0.01 * i + ) g2 = g1.clone() p2.grad = g2 @@ -430,10 +461,16 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): elif optim_bits == 8: torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3) torch.testing.assert_allclose( - adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3 + adam1.state[p1]["state1"], + adam2.state[p2]["state1"], + atol=2, + rtol=1e-3, ) torch.testing.assert_allclose( - adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, rtol=1e-3 + adam1.state[p1]["state2"], + adam2.state[p2]["state2"], + atol=2, + rtol=1e-3, ) adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"]) adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"]) @@ -463,7 +500,9 @@ gtype = [torch.float32, torch.float16] # optimizer_names = ['lars_apex', 'lars8bit'] optimizer_names = ["adam8bit_blockwise"] values = list(product(dim1, dim2, gtype, optimizer_names)) -names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values] +names = [ + "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values +] @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)