diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 3c3affa..7ca017d 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -1,16 +1,18 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .nn import modules -from .autograd._functions import mm_cublas, bmm_cublas, matmul_cublas, matmul, MatmulLtState +from .autograd._functions import (MatmulLtState, bmm_cublas, matmul, + matmul_cublas, mm_cublas) from .cextension import COMPILED_WITH_CUDA +from .nn import modules if COMPILED_WITH_CUDA: from .optim import adam -__pdoc__ = {'libbitsandbytes': False, - 'optim.optimizer.Optimizer8bit': False, - 'optim.optimizer.MockArgs': False - } +__pdoc__ = { + "libbitsandbytes": False, + "optim.optimizer.Optimizer8bit": False, + "optim.optimizer.MockArgs": False, +} diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index e641583..a08b560 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -1,21 +1,24 @@ +from dataclasses import dataclass + import torch + import bitsandbytes as bnb import bitsandbytes.functional as F -from dataclasses import dataclass - tensor = torch.Tensor -''' +""" This class pools outlier dimensions across layers. This is particularly important for small models where outlier features are less systematic and occur with low frequency. -''' +""" + + class GlobalOutlierPooler(object): _instance = None def __init__(self): - raise RuntimeError('Call get_instance() instead') + raise RuntimeError("Call get_instance() instead") def initialize(self): self.outliers = set() @@ -29,25 +32,29 @@ class GlobalOutlierPooler(object): return cls._instance def add_outliers(self, outlier_idx, feature_dim): - if self.model_dim is None: self.model_dim = feature_dim - if feature_dim != self.model_dim: return # we do not encode outliers for the 2nd FFN layer + if self.model_dim is None: + self.model_dim = feature_dim + if feature_dim != self.model_dim: + return # we do not encode outliers for the 2nd FFN layer self.outliers.update(outlier_idx.tolist()) def get_current_outlier_idx(self): return torch.Tensor(list(self.outliers)).to(torch.int64) -class MatMul8bit(torch.autograd.Function): +class MatMul8bit(torch.autograd.Function): @staticmethod - def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]): + def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]): if precision[0] != 8: with torch.no_grad(): output = torch.matmul(A, B) else: - if len(B.shape) == 2: dim = 0 - else: dim = 1 + if len(B.shape) == 2: + dim = 0 + else: + dim = 1 qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type) qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type) iout = F.igemm(qA, qB) @@ -84,21 +91,41 @@ class MatMul8bit(torch.autograd.Function): else: if len(B.shape) == 2 and len(A.shape) == 3: grad_output = grad_output.contiguous() - if not grad_output.is_contiguous(): grad_output.contiguous() - qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type) - if not A.is_contiguous(): A = A.contiguous() - qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type) + if not grad_output.is_contiguous(): + grad_output.contiguous() + qgrad_output, S1 = F.vectorwise_quant( + grad_output.view(-1, grad_output.shape[2]), + dim=0, + quant_type=quant_type, + ) + if not A.is_contiguous(): + A = A.contiguous() + qA, S2 = F.vectorwise_quant( + A.view(-1, A.shape[2]), dim=0, quant_type=quant_type + ) igrad_B = F.igemm(qA.t(), qgrad_output) - grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type) + grad_B = F.vectorwise_mm_dequant( + igrad_B, S2.t(), S1, grad_output.dtype, quant_type + ) else: - qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) + qgrad_output, S1 = F.vectorwise_quant( + grad_output, dim=dims, quant_type=quant_type + ) qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type) igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output) - grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type) + grad_B = F.vectorwise_mm_dequant( + igrad_B, + S2.permute(permute_dim), + S1, + grad_output.dtype, + quant_type, + ) if A.requires_grad: - if len(grad_output.shape) == 3: dims = [2] - else: dims = [1] + if len(grad_output.shape) == 3: + dims = [2] + else: + dims = [1] if len(B.shape) == 3: # bio -> boi @@ -113,10 +140,14 @@ class MatMul8bit(torch.autograd.Function): with torch.no_grad(): grad_A = torch.matmul(grad_output, B.permute(permute_dim)) else: - qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type) + qgrad_output, S1 = F.vectorwise_quant( + grad_output, dim=dims, quant_type=quant_type + ) qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type) igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim)) - grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type) + grad_A = F.vectorwise_mm_dequant( + igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type + ) return grad_A, grad_B, None, None, None @@ -125,6 +156,7 @@ mm_cublas = MatMul8bit.apply bmm_cublas = MatMul8bit.apply matmul_cublas = MatMul8bit.apply + @dataclass class MatmulLtState: CB = None @@ -159,7 +191,6 @@ class MatmulLtState: class MatMul8bitLt(torch.autograd.Function): - @staticmethod def forward(ctx, A, B, out=None, state=MatmulLtState()): # 1. Quantize A @@ -171,11 +202,15 @@ class MatMul8bitLt(torch.autograd.Function): requires_gradB = B.requires_grad formatB = state.formatB input_shape = A.shape - if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance() - assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!' + if state.outlier_pool is None: + state.outlier_pool = GlobalOutlierPooler.get_instance() + assert ( + A.dtype == torch.float16 + ), f"The input data type needs to be fp16 but {A.dtype} was found!" # 1. Quantize A - if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous() + if len(A.shape) == 3: + A = A.view(-1, A.shape[-1]).contiguous() CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold) if state.threshold > 0.0 and coo_tensorA is not None: @@ -191,8 +226,8 @@ class MatMul8bitLt(torch.autograd.Function): # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions # we also need to convert it to the turing/ampere format state.CxB, state.SB = F.transform(state.CB, to_order=formatB) - #state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half() - #if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None: + # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half() + # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None: # # generate outlier index and subB # outlier_idx = torch.unique(coo_tensorA.colidx).long() # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) @@ -203,24 +238,24 @@ class MatMul8bitLt(torch.autograd.Function): # state.idx = outlier_idx # state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half() - #if state.idx is not None: + # if state.idx is not None: # # extract outliers # CA[:, state.idx] = 0 # CAt[:, state.idx] = 0 # subA = A[:, state.idx] - #else: + # else: # subA = None else: if not state.has_fp16_weights and state.CxB is None: state.CxB, state.SB = F.transform(state.CB, to_order=formatB) subA = None - # 2. Quantize B if state.has_fp16_weights: - has_grad = (True if (getattr(B, 'grad', None) is not None) else False) + has_grad = True if (getattr(B, "grad", None) is not None) else False is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) - if is_transposed: B = B.contiguous() + if is_transposed: + B = B.contiguous() if (state.is_training and not has_grad) or state.CxB is None: state.reset_grads() @@ -234,14 +269,16 @@ class MatMul8bitLt(torch.autograd.Function): outlier_idx = torch.unique(coo_tensorA.colidx) state.idx = outlier_idx - #state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) - #if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: + # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) + # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: # # do not use pool for 2nd FFN layer # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) - #else: + # else: # state.idx = outlier_idx outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) - state.subB = (outliers*state.SCB.view(-1, 1)/127.0).t().contiguous().half() + state.subB = ( + (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half() + ) CA[:, state.idx.long()] = 0 CAt[:, state.idx.long()] = 0 subA = A[:, state.idx.long()] @@ -254,7 +291,7 @@ class MatMul8bitLt(torch.autograd.Function): output_shape = (input_shape[0], shapeB[0]) # 3. Matmul - C32A, SA = F.transform(CA, 'col32') + C32A, SA = F.transform(CA, "col32") out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) output = F.mm_dequant(out32, Sout32, SCA, state.SCB) @@ -277,7 +314,7 @@ class MatMul8bitLt(torch.autograd.Function): ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) - #clone_func = torch.clone if len(output_shape) == 3 else lambda x : x + # clone_func = torch.clone if len(output_shape) == 3 else lambda x : x clone_func = torch.clone return clone_func(output.view(output_shape)) @@ -288,7 +325,7 @@ class MatMul8bitLt(torch.autograd.Function): SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state - assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.' + assert state.has_fp16_weights, "Backprop only supported for fp16 weights." if len(grad_output.shape) == 3: grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous() @@ -298,18 +335,22 @@ class MatMul8bitLt(torch.autograd.Function): Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output) if req_gradB: CxAt, SAt = F.transform(CAt, formatB, transpose=True) - C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True) + C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) if state.threshold > 0.0 and subA is not None: grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: - C32grad, Sgrad = F.transform(Cgrad, 'col32') + C32grad, Sgrad = F.transform(Cgrad, "col32") if state.CxBt is None: - state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) + state.CxBt, state.SBt = F.transform( + state.CBt, to_order=formatB, transpose=True + ) gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) - grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) + grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view( + ctx.grad_shape + ) return grad_A, grad_B, None, None, None, None, None @@ -317,9 +358,10 @@ class MatMul8bitLt(torch.autograd.Function): matmul = MatMul8bitLt.apply -def matmul(A : tensor, B : tensor, out : tensor=None, state : MatmulLtState = None, threshold=0.0): +def matmul( + A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0 +): state = state or MatmulLtState() if threshold > 0.0: state.threshold = threshold return MatMul8bitLt.apply(A, B, out, state) - diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 4bc7bf7..bc11474 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -1,6 +1,7 @@ import ctypes as ct import os from warnings import warn + from bitsandbytes.cuda_setup import evaluate_cuda_setup @@ -8,17 +9,21 @@ class CUDALibrary_Singleton(object): _instance = None def __init__(self): - raise RuntimeError('Call get_instance() instead') + raise RuntimeError("Call get_instance() instead") def initialize(self): self.context = {} binary_name = evaluate_cuda_setup() - if not os.path.exists(os.path.dirname(__file__) + f'/{binary_name}'): - print(f'TODO: compile library for specific version: {binary_name}') - print('defaulting to libbitsandbytes.so') - self.lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so') + if not os.path.exists(os.path.dirname(__file__) + f"/{binary_name}"): + print(f"TODO: compile library for specific version: {binary_name}") + print("defaulting to libbitsandbytes.so") + self.lib = ct.cdll.LoadLibrary( + os.path.dirname(__file__) + "/libbitsandbytes.so" + ) else: - self.lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + f'/{binary_name}') + self.lib = ct.cdll.LoadLibrary( + os.path.dirname(__file__) + f"/{binary_name}" + ) @classmethod def get_instance(cls): @@ -35,6 +40,8 @@ try: lib.get_cusparse.restype = ct.c_void_p COMPILED_WITH_CUDA = True except AttributeError: - warn("The installed version of bitsandbytes was compiled without GPU support. " - "8-bit optimizers and GPU quantization are unavailable.") + warn( + "The installed version of bitsandbytes was compiled without GPU support. " + "8-bit optimizers and GPU quantization are unavailable." + ) COMPILED_WITH_CUDA = False diff --git a/bitsandbytes/cuda_setup.py b/bitsandbytes/cuda_setup.py index 5ed0c89..0dd53c5 100644 --- a/bitsandbytes/cuda_setup.py +++ b/bitsandbytes/cuda_setup.py @@ -18,31 +18,36 @@ evaluation: - based on that set the default path """ -from os import environ as env -from pathlib import Path -from typing import Set, Union -from .utils import warn_of_missing_prerequisite, print_err - import ctypes import shlex import subprocess +from os import environ as env +from pathlib import Path +from typing import Set, Union + +from .utils import print_err, warn_of_missing_prerequisite + def execute_and_return(strCMD): - proc = subprocess.Popen(shlex.split(strCMD), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + proc = subprocess.Popen( + shlex.split(strCMD), stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) out, err = proc.communicate() out, err = out.decode("UTF-8").strip(), err.decode("UTF-8").strip() return out, err + def check_cuda_result(cuda, result_val): if result_val != 0: cuda.cuGetErrorString(result_val, ctypes.byref(error_str)) print(f"Count not initialize CUDA - failure!") - raise Exception('CUDA exception!') + raise Exception("CUDA exception!") return result_val + # taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549 def get_compute_capability(): - libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll') + libnames = ("libcuda.so", "libcuda.dylib", "cuda.dll") for libname in libnames: try: cuda = ctypes.CDLL(libname) @@ -51,8 +56,7 @@ def get_compute_capability(): else: break else: - raise OSError("could not load any of: " + ' '.join(libnames)) - + raise OSError("could not load any of: " + " ".join(libnames)) nGpus = ctypes.c_int() cc_major = ctypes.c_int() @@ -69,39 +73,43 @@ def get_compute_capability(): ccs = [] for i in range(nGpus.value): result = check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i)) - result = check_cuda_result(cuda, cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device)) - ccs.append(f'{cc_major.value}.{cc_minor.value}') + result = check_cuda_result( + cuda, + cuda.cuDeviceComputeCapability( + ctypes.byref(cc_major), ctypes.byref(cc_minor), device + ), + ) + ccs.append(f"{cc_major.value}.{cc_minor.value}") - #TODO: handle different compute capabilities; for now, take the max + # TODO: handle different compute capabilities; for now, take the max ccs.sort() - return ccs[-1] + # return ccs[-1] + return ccs + CUDA_RUNTIME_LIB: str = "libcudart.so" + def tokenize_paths(paths: str) -> Set[Path]: - return { - Path(ld_path) for ld_path in paths.split(':') - if ld_path - } + return {Path(ld_path) for ld_path in paths.split(":") if ld_path} + def get_cuda_runtime_lib_path( # TODO: replace this with logic for all paths in env vars LD_LIBRARY_PATH: Union[str, None] = env.get("LD_LIBRARY_PATH") ) -> Union[Path, None]: - """ # TODO: add doc-string - """ + """# TODO: add doc-string""" if not LD_LIBRARY_PATH: warn_of_missing_prerequisite( - 'LD_LIBRARY_PATH is completely missing from environment!' + "LD_LIBRARY_PATH is completely missing from environment!" ) return None ld_library_paths: Set[Path] = tokenize_paths(LD_LIBRARY_PATH) - non_existent_directories: Set[Path] = { - path for path in ld_library_paths - if not path.exists() + non_existent_directories: Set[Path] = { + path for path in ld_library_paths if not path.exists() } if non_existent_directories: @@ -111,7 +119,8 @@ def get_cuda_runtime_lib_path( ) cuda_runtime_libs: Set[Path] = { - path / CUDA_RUNTIME_LIB for path in ld_library_paths + path / CUDA_RUNTIME_LIB + for path in ld_library_paths if (path / CUDA_RUNTIME_LIB).is_file() } - non_existent_directories @@ -126,26 +135,31 @@ def get_cuda_runtime_lib_path( single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs)) return single_cuda_runtime_lib_dir + def evaluate_cuda_setup(): cuda_path = get_cuda_runtime_lib_path() cc = get_compute_capability() - binary_name = 'libbitsandbytes_cpu.so' + binary_name = "libbitsandbytes_cpu.so" if not (has_gpu := bool(cc)): - print('WARNING: No GPU detected! Check our CUDA paths. Processing to load CPU-only library...') + print( + "WARNING: No GPU detected! Check our CUDA paths. Processing to load CPU-only library..." + ) return binary_name - has_cublaslt = cc in ['7.5', '8.0', '8.6'] + has_cublaslt = cc in ["7.5", "8.0", "8.6"] - # TODO: + # TODO: # (1) Model missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) # (2) Multiple CUDA versions installed cuda_home = str(Path(cuda_path).parent.parent) - ls_output, err = execute_and_return(f'{cuda_home}/bin/nvcc --version') - cuda_version = ls_output.split('\n')[3].split(',')[-1].strip().lower().replace('v', '') - major, minor, revision = cuda_version.split('.') - cuda_version_string = f'{major}{minor}' + ls_output, err = execute_and_return(f"{cuda_home}/bin/nvcc --version") + cuda_version = ( + ls_output.split("\n")[3].split(",")[-1].strip().lower().replace("v", "") + ) + major, minor, revision = cuda_version.split(".") + cuda_version_string = f"{major}{minor}" binary_name = f'libbitsandbytes_cuda{cuda_version_string}_{("cublaslt" if has_cublaslt else "")}.so' diff --git a/bitsandbytes/debug_cli.py b/bitsandbytes/debug_cli.py index 88307a6..4306bc0 100644 --- a/bitsandbytes/debug_cli.py +++ b/bitsandbytes/debug_cli.py @@ -1,6 +1,5 @@ import typer - cli = typer.Typer() diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index ac85f88..2e86958 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1,6 +1,6 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import ctypes as ct import random @@ -9,47 +9,68 @@ from typing import Tuple import torch from torch import Tensor -from .cextension import lib, COMPILED_WITH_CUDA +from .cextension import COMPILED_WITH_CUDA, lib name2qmap = {} if COMPILED_WITH_CUDA: - ''' C FUNCTIONS FOR OPTIMIZERS ''' + """C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = {} - str2optimizer32bit['adam'] = (lib.cadam32bit_g32, lib.cadam32bit_g16) - str2optimizer32bit['momentum'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16) - str2optimizer32bit['rmsprop'] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16) - str2optimizer32bit['adagrad'] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16) - str2optimizer32bit['lars'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16) - str2optimizer32bit['lamb'] = (lib.cadam32bit_g32, lib.cadam32bit_g16) + str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16) + str2optimizer32bit["momentum"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16) + str2optimizer32bit["rmsprop"] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16) + str2optimizer32bit["adagrad"] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16) + str2optimizer32bit["lars"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16) + str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16) str2optimizer8bit = {} - str2optimizer8bit['adam'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16) - str2optimizer8bit['momentum'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16) - str2optimizer8bit['rmsprop'] = (lib.crmsprop_static_8bit_g32, lib.crmsprop_static_8bit_g16) - str2optimizer8bit['lamb'] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16) - str2optimizer8bit['lars'] = (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16) + str2optimizer8bit["adam"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16) + str2optimizer8bit["momentum"] = ( + lib.cmomentum_static_8bit_g32, + lib.cmomentum_static_8bit_g16, + ) + str2optimizer8bit["rmsprop"] = ( + lib.crmsprop_static_8bit_g32, + lib.crmsprop_static_8bit_g16, + ) + str2optimizer8bit["lamb"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16) + str2optimizer8bit["lars"] = ( + lib.cmomentum_static_8bit_g32, + lib.cmomentum_static_8bit_g16, + ) str2optimizer8bit_blockwise = {} - str2optimizer8bit_blockwise['adam'] = (lib.cadam_8bit_blockwise_fp32, lib.cadam_8bit_blockwise_fp16) - str2optimizer8bit_blockwise['momentum'] = (lib.cmomentum_8bit_blockwise_fp32, lib.cmomentum_8bit_blockwise_fp16) - str2optimizer8bit_blockwise['rmsprop'] = (lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16) - str2optimizer8bit_blockwise['adagrad'] = (lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp16) + str2optimizer8bit_blockwise["adam"] = ( + lib.cadam_8bit_blockwise_fp32, + lib.cadam_8bit_blockwise_fp16, + ) + str2optimizer8bit_blockwise["momentum"] = ( + lib.cmomentum_8bit_blockwise_fp32, + lib.cmomentum_8bit_blockwise_fp16, + ) + str2optimizer8bit_blockwise["rmsprop"] = ( + lib.crmsprop_8bit_blockwise_fp32, + lib.crmsprop_8bit_blockwise_fp16, + ) + str2optimizer8bit_blockwise["adagrad"] = ( + lib.cadagrad_8bit_blockwise_fp32, + lib.cadagrad_8bit_blockwise_fp16, + ) class CUBLAS_Context(object): _instance = None def __init__(self): - raise RuntimeError('Call get_instance() instead') + raise RuntimeError("Call get_instance() instead") def initialize(self): self.context = {} - #prev_device = torch.cuda.current_device() - #for i in range(torch.cuda.device_count()): + # prev_device = torch.cuda.current_device() + # for i in range(torch.cuda.device_count()): # torch.cuda.set_device(torch.device('cuda', i)) # self.context.append(ct.c_void_p(lib.get_context())) - #torch.cuda.set_device(prev_device) + # torch.cuda.set_device(prev_device) @classmethod def get_instance(cls): @@ -66,11 +87,12 @@ class CUBLAS_Context(object): torch.cuda.set_device(prev_device) return self.context[device.index] + class Cusparse_Context(object): _instance = None def __init__(self): - raise RuntimeError('Call get_instance() instead') + raise RuntimeError("Call get_instance() instead") def initialize(self): self.context = ct.c_void_p(lib.get_cusparse()) @@ -82,14 +104,16 @@ class Cusparse_Context(object): cls._instance.initialize() return cls._instance + def create_linear_map(signed=True): if signed: return torch.linspace(-1.0, 1.0, 256) else: return torch.linspace(0.0, 1.0, 256) + def create_dynamic_map(signed=True, n=7): - ''' + """ Creates the dynamic quantiztion map. The dynamic data type is made up of a dynamic exponent and @@ -103,46 +127,54 @@ def create_dynamic_map(signed=True, n=7): For more details see (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561] - ''' + """ data = [] # these are additional items that come from the case # where all the exponent bits are zero and no # indicator bit is present - additional_items = 2**(7-n)-1 - if not signed: additional_items = 2*additional_items + additional_items = 2 ** (7 - n) - 1 + if not signed: + additional_items = 2 * additional_items for i in range(n): - fraction_items = 2**(i+7-n)+1 if signed else 2**(i+7-n+1)+1 + fraction_items = 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1 boundaries = torch.linspace(0.1, 1, fraction_items) - means = (boundaries[:-1]+boundaries[1:])/2.0 - data += ((10**(-(n-1)+i))*means).tolist() + means = (boundaries[:-1] + boundaries[1:]) / 2.0 + data += ((10 ** (-(n - 1) + i)) * means).tolist() if signed: - data += (-(10**(-(n-1)+i))*means).tolist() + data += (-(10 ** (-(n - 1) + i)) * means).tolist() if additional_items > 0: - boundaries = torch.linspace(0.1, 1, additional_items+1) - means = (boundaries[:-1]+boundaries[1:])/2.0 - data += ((10**(-(n-1)+i))*means).tolist() + boundaries = torch.linspace(0.1, 1, additional_items + 1) + means = (boundaries[:-1] + boundaries[1:]) / 2.0 + data += ((10 ** (-(n - 1) + i)) * means).tolist() if signed: - data += (-(10**(-(n-1)+i))*means).tolist() + data += (-(10 ** (-(n - 1) + i)) * means).tolist() data.append(0) data.append(1.0) data.sort() return Tensor(data) + def get_special_format_str(): major, minor = torch.cuda.get_device_capability() if major < 7: - print(f'Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!') + print( + f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!" + ) assert major >= 7 - if major == 7: return 'col_turing' - elif major == 8: return 'col_ampere' - else: return 'col_turing' + if major == 7: + return "col_turing" + elif major == 8: + return "col_ampere" + else: + return "col_turing" + def get_ptr(A: Tensor) -> ct.c_void_p: - ''' + """ Get the ctypes pointer from a PyTorch Tensor. Parameters @@ -153,31 +185,39 @@ def get_ptr(A: Tensor) -> ct.c_void_p: Returns ------- ctypes.c_void_p - ''' - if A is None: return None - else: return ct.c_void_p(A.data.storage().data_ptr()) + """ + if A is None: + return None + else: + return ct.c_void_p(A.data.storage().data_ptr()) + def pre_call(device): prev_device = torch.cuda.current_device() torch.cuda.set_device(device) return prev_device + def post_call(prev_device): torch.cuda.set_device(prev_device) + def get_transform_func(dtype, orderA, orderOut, transpose=False): name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}' if not hasattr(lib, name): print(name) - raise ValueError(f'Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}') + raise ValueError( + f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}" + ) else: return getattr(lib, name) + class GlobalData(object): _instance = None def __init__(self): - raise RuntimeError('Call get_instance() instead') + raise RuntimeError("Call get_instance() instead") def initialize(self): self.data = {} @@ -190,15 +230,17 @@ class GlobalData(object): return cls._instance -def get_transform_buffer(shape, dtype, device, to_order, from_order='row', transpose=False): - #init_func = torch.empty +def get_transform_buffer( + shape, dtype, device, to_order, from_order="row", transpose=False +): + # init_func = torch.empty init_func = torch.zeros dims = len(shape) if dims == 2: rows = shape[0] elif dims == 3: - rows = shape[0]*shape[1] + rows = shape[0] * shape[1] cols = shape[-1] state = (shape, to_order) @@ -209,30 +251,39 @@ def get_transform_buffer(shape, dtype, device, to_order, from_order='row', trans cols = tmp state = (shape[::-1], to_order) - if to_order == 'row' or to_order == 'col': + if to_order == "row" or to_order == "col": return init_func(shape, dtype=dtype, device=device), state - elif to_order == 'col32': + elif to_order == "col32": # blocks of 32 columns (padded) - cols = 32*((cols+31)//32) + cols = 32 * ((cols + 31) // 32) return init_func((rows, cols), dtype=dtype, device=device), state - elif to_order == 'col_turing': + elif to_order == "col_turing": # blocks of 32 columns and 8 rows - cols = 32*((cols+31)//32) - rows = 8*((rows+7)//8) + cols = 32 * ((cols + 31) // 32) + rows = 8 * ((rows + 7) // 8) return init_func((rows, cols), dtype=dtype, device=device), state - elif to_order == 'col_ampere': + elif to_order == "col_ampere": # blocks of 32 columns and 32 rows - cols = 32*((cols+31)//32) - rows = 32*((rows+31)//32) + cols = 32 * ((cols + 31) // 32) + rows = 32 * ((rows + 31) // 32) return init_func((rows, cols), dtype=dtype, device=device), state else: - raise NotImplementedError(f'To_order not supported: {to_order}') + raise NotImplementedError(f"To_order not supported: {to_order}") -def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): - if state is None: state = (A.shape, from_order) - else: from_order = state[1] - if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1]) - else: new_state = (state[1], to_order) + +def nvidia_transform( + A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None +): + if state is None: + state = (A.shape, from_order) + else: + from_order = state[1] + if out is None: + out, new_state = get_transform_buffer( + state[0], A.dtype, A.device, to_order, state[1] + ) + else: + new_state = (state[1], to_order) func = get_transform_func(A.dtype, from_order, to_order, transpose) shape = state[0] @@ -242,10 +293,10 @@ def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, s elif ld is not None: n = math.prod(shape) dim1 = math.prod([shape[i] for i in ld]) - dim2 = ct.c_int32(n//dim1) + dim2 = ct.c_int32(n // dim1) dim1 = ct.c_int32(dim1) else: - dim1 = ct.c_int32(shape[0]*shape[1]) + dim1 = ct.c_int32(shape[0] * shape[1]) dim2 = ct.c_int32(shape[2]) ptr = CUBLAS_Context.get_instance().get_context(A.device) @@ -253,11 +304,13 @@ def nvidia_transform(A, to_order, from_order='row', out=None, transpose=False, s ptrOut = get_ptr(out) func(ptr, get_ptr(A), get_ptr(out), dim1, dim2) - return out, new_state -def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tensor: - ''' + +def estimate_quantiles( + A: Tensor, out: Tensor = None, offset: float = 1 / 512 +) -> Tensor: + """ Estimates 256 equidistant quantiles on the input tensor eCDF. Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles @@ -282,18 +335,26 @@ def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tens ------- torch.Tensor: The 256 quantiles in float32 datatype. - ''' - if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) + """ + if out is None: + out = torch.zeros((256,), dtype=torch.float32, device=A.device) if A.dtype == torch.float32: - lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) + lib.cestimate_quantiles_fp32( + get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()) + ) elif A.dtype == torch.float16: - lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) + lib.cestimate_quantiles_fp16( + get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()) + ) else: - raise NotImplementedError(f'Not supported data type {A.dtype}') + raise NotImplementedError(f"Not supported data type {A.dtype}") return out -def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=None, out: Tensor=None) -> Tensor: - ''' + +def quantize_blockwise( + A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None +) -> Tensor: + """ Quantize tensor A in blocks of size 4096 values. Quantizes tensor A by dividing it into blocks of 4096 values. @@ -319,51 +380,96 @@ def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=N The 8-bit tensor. tuple(torch.Tensor, torch.Tensor): The quantization state to undo the quantization. - ''' + """ if code is None: - if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device) - code = name2qmap['dynamic'] + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] code = code.to(A.device) if absmax is None: n = A.numel() num_blocks = 4096 - blocks = n//num_blocks + blocks = n // num_blocks blocks += 1 if n % num_blocks > 0 else 0 absmax = torch.zeros((blocks,), device=A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.uint8) + if out is None: + out = torch.zeros_like(A, dtype=torch.uint8) - - if A.device.type != 'cpu': + if A.device.type != "cpu": if rand is not None: assert rand.numel() >= 1024 rand_offset = random.randint(0, 1023) if A.dtype == torch.float32: - lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel())) + lib.cquantize_blockwise_stochastic_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + get_ptr(rand), + ct.c_int32(rand_offset), + ct.c_int(A.numel()), + ) elif A.dtype == torch.float16: - lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel())) + lib.cquantize_blockwise_stochastic_fp16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + get_ptr(rand), + ct.c_int32(rand_offset), + ct.c_int(A.numel()), + ) else: - raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}') + raise ValueError( + f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" + ) else: if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel())) + lib.cquantize_blockwise_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(A.numel()), + ) elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel())) + lib.cquantize_blockwise_fp16( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(A.numel()), + ) else: - raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}') + raise ValueError( + f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" + ) else: # cpu assert rand is None - lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel())) + lib.cquantize_blockwise_cpu_fp32( + get_ptr(code), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(A.numel()), + ) return out, (absmax, code) -def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, - absmax: Tensor=None, code: Tensor=None, out: Tensor=None, - blocksize: int=4096) -> Tensor: - ''' + +def dequantize_blockwise( + A: Tensor, + quant_state: Tuple[Tensor, Tensor] = None, + absmax: Tensor = None, + code: Tensor = None, + out: Tensor = None, + blocksize: int = 4096, +) -> Tensor: + """ Dequantizes blockwise quantized values. Dequantizes the tensor A with maximum absolute values absmax in @@ -374,7 +480,7 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, A : torch.Tensor The input 8-bit tensor. quant_state : tuple(torch.Tensor, torch.Tensor) - Tuple of code and absmax values. + Tuple of code and absmax values. absmax : torch.Tensor The absmax values. code : torch.Tensor @@ -387,57 +493,94 @@ def dequantize_blockwise(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, ------- torch.Tensor: Dequantized tensor (default: float32) - ''' + """ assert quant_state is not None or absmax is not None if code is None and quant_state is None: - if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device) - code = name2qmap['dynamic'] + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] code = code.to(A.device) - if out is None: out = torch.zeros_like(A, dtype=torch.float32) - if quant_state is None: quant_state = (absmax, code) + if out is None: + out = torch.zeros_like(A, dtype=torch.float32) + if quant_state is None: + quant_state = (absmax, code) if blocksize not in [2048, 4096]: - raise ValueError(f'The blockwise of {blocksize} is not supported. Supported values: [2048 4096]') + raise ValueError( + f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]" + ) - if A.device.type != 'cpu': + if A.device.type != "cpu": if out.dtype == torch.float32: - lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp32( + get_ptr(quant_state[1]), + get_ptr(A), + get_ptr(quant_state[0]), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + ) elif out.dtype == torch.float16: - lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) + lib.cdequantize_blockwise_fp16( + get_ptr(quant_state[1]), + get_ptr(A), + get_ptr(quant_state[0]), + get_ptr(out), + ct.c_int(blocksize), + ct.c_int(A.numel()), + ) else: - raise ValueError(f'Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}') + raise ValueError( + f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" + ) else: - lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(A.numel())) - + lib.cdequantize_blockwise_cpu_fp32( + get_ptr(quant_state[1]), + get_ptr(A), + get_ptr(quant_state[0]), + get_ptr(out), + ct.c_int(A.numel()), + ) return out -def quantize(A: Tensor, code: Tensor=None, out: Tensor=None) -> Tensor: +def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: if code is None: - if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device) - code = name2qmap['dynamic'] + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] code = code.to(A.device) absmax = torch.abs(A).max() - inp = A/absmax + inp = A / absmax out = quantize_no_absmax(inp, code, out) return out, (absmax, code) -def dequantize(A: Tensor, quant_state: Tuple[Tensor, Tensor]=None, absmax: Tensor=None, code: Tensor=None, out: Tensor=None) -> Tensor: + +def dequantize( + A: Tensor, + quant_state: Tuple[Tensor, Tensor] = None, + absmax: Tensor = None, + code: Tensor = None, + out: Tensor = None, +) -> Tensor: assert quant_state is not None or absmax is not None if code is None and quant_state is None: - if 'dynamic' not in name2qmap: name2qmap['dynamic'] = create_dynamic_map().to(A.device) - code = name2qmap['dynamic'] + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] code = code.to(A.device) - if quant_state is None: quant_state = (absmax, code) + if quant_state is None: + quant_state = (absmax, code) out = dequantize_no_absmax(A, quant_state[1], out) - return out*quant_state[0] + return out * quant_state[0] -def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor: - ''' + +def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: + """ Quantizes input tensor to 8-bit. Quantizes the 32-bit input tensor `A` to the 8-bit output tensor @@ -456,13 +599,15 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor: ------- torch.Tensor: Quantized 8-bit tensor. - ''' - if out is None: out = torch.zeros_like(A, dtype=torch.uint8) + """ + if out is None: + out = torch.zeros_like(A, dtype=torch.uint8) lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) return out -def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor: - ''' + +def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: + """ Dequantizes the 8-bit tensor to 32-bit. Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via @@ -481,17 +626,31 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor=None) -> Tensor: ------- torch.Tensor: 32-bit output tensor. - ''' - if out is None: out = torch.zeros_like(A, dtype=torch.float32) + """ + if out is None: + out = torch.zeros_like(A, dtype=torch.float32) lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) return out -def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Tensor, - beta1: float, eps: float, step: int, lr: float, - state2: Tensor=None, beta2: float=0.0, - weight_decay: float=0.0, gnorm_scale: float=1.0, - unorm_vec: Tensor=None, max_unorm: float=0.0, skip_zeros=False) -> None: - ''' + +def optimizer_update_32bit( + optimizer_name: str, + g: Tensor, + p: Tensor, + state1: Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Tensor = None, + beta2: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Tensor = None, + max_unorm: float = 0.0, + skip_zeros=False, +) -> None: + """ Performs an inplace optimizer update with one or two optimizer states. Universal optimizer update for 32-bit state and 32/16-bit gradients/weights. @@ -528,33 +687,84 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten The maximum update norm relative to the weight norm. skip_zeros : bool Whether to skip zero-valued gradients or not (default: False). - ''' + """ param_norm = 0.0 if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) if optimizer_name not in str2optimizer32bit: - raise NotImplementedError(f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}') + raise NotImplementedError( + f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}' + ) if g.dtype == torch.float32 and state1.dtype == torch.float32: - str2optimizer32bit[optimizer_name][0](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm), - ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay), - ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel())) + str2optimizer32bit[optimizer_name][0]( + get_ptr(g), + get_ptr(p), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_float(weight_decay), + ct.c_int32(step), + ct.c_float(lr), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) elif g.dtype == torch.float16 and state1.dtype == torch.float32: - str2optimizer32bit[optimizer_name][1](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm), - ct.c_float(param_norm), ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), ct.c_float(weight_decay), - ct.c_int32(step), ct.c_float(lr), ct.c_float(gnorm_scale), ct.c_bool(skip_zeros), ct.c_int32(g.numel())) + str2optimizer32bit[optimizer_name][1]( + get_ptr(g), + get_ptr(p), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_float(weight_decay), + ct.c_int32(step), + ct.c_float(lr), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) else: - raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}') + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + ) -def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor, - beta1: float, beta2: float, eps: float, - step: int, lr: float, qmap1: Tensor, qmap2: Tensor, - max1: Tensor, max2: Tensor, new_max1: Tensor, new_max2: Tensor, - weight_decay: float=0.0, gnorm_scale: float=1.0, - unorm_vec: Tensor=None, max_unorm: float=0.0) -> None: - ''' + +def optimizer_update_8bit( + optimizer_name: str, + g: Tensor, + p: Tensor, + state1: Tensor, + state2: Tensor, + beta1: float, + beta2: float, + eps: float, + step: int, + lr: float, + qmap1: Tensor, + qmap2: Tensor, + max1: Tensor, + max2: Tensor, + new_max1: Tensor, + new_max2: Tensor, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Tensor = None, + max_unorm: float = 0.0, +) -> None: + """ Performs an inplace Adam update. Universal Adam update for 32/8-bit state and 32/16-bit gradients/weights. @@ -602,56 +812,135 @@ def optimizer_update_8bit(optimizer_name: str, g: Tensor, p: Tensor, state1: Ten The tensor for the update norm. max_unorm : float The maximum update norm relative to the weight norm. - ''' + """ param_norm = 0.0 if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) if g.dtype == torch.float32 and state1.dtype == torch.uint8: - str2optimizer8bit[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), - get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm), - ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), - ct.c_int32(step), ct.c_float(lr), - get_ptr(qmap1), get_ptr(qmap2), - get_ptr(max1), get_ptr(max2), get_ptr(new_max1), get_ptr(new_max2), - ct.c_float(weight_decay),ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + str2optimizer8bit[optimizer_name][0]( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(max1), + get_ptr(max2), + get_ptr(new_max1), + get_ptr(new_max2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_int32(g.numel()), + ) elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - str2optimizer8bit[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), - get_ptr(unorm_vec), ct.c_float(max_unorm), ct.c_float(param_norm), - ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), - ct.c_int32(step), ct.c_float(lr), - get_ptr(qmap1), get_ptr(qmap2), - get_ptr(max1), get_ptr(max2), get_ptr(new_max1), get_ptr(new_max2), - ct.c_float(weight_decay),ct.c_float(gnorm_scale), ct.c_int32(g.numel())) + str2optimizer8bit[optimizer_name][1]( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(max1), + get_ptr(max2), + get_ptr(new_max1), + get_ptr(new_max2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_int32(g.numel()), + ) else: - raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}') + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + ) -def optimizer_update_8bit_blockwise(optimizer_name: str, g: Tensor, p: Tensor, state1: Tensor, state2: Tensor, - beta1: float, beta2: float, eps: float, - step: int, lr: float, qmap1: Tensor, qmap2: Tensor, - absmax1: Tensor, absmax2: Tensor, weight_decay: float=0.0, gnorm_scale: float=1.0, - skip_zeros=False) -> None: - +def optimizer_update_8bit_blockwise( + optimizer_name: str, + g: Tensor, + p: Tensor, + state1: Tensor, + state2: Tensor, + beta1: float, + beta2: float, + eps: float, + step: int, + lr: float, + qmap1: Tensor, + qmap2: Tensor, + absmax1: Tensor, + absmax2: Tensor, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: if g.dtype == torch.float32 and state1.dtype == torch.uint8: - str2optimizer8bit_blockwise[optimizer_name][0](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), - ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), - ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2), - get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), ct.c_int32(g.numel())) + str2optimizer8bit_blockwise[optimizer_name][0]( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - str2optimizer8bit_blockwise[optimizer_name][1](get_ptr(p), get_ptr(g), get_ptr(state1), get_ptr(state2), - ct.c_float(beta1), ct.c_float(beta2), ct.c_float(eps), - ct.c_int32(step), ct.c_float(lr), get_ptr(qmap1), get_ptr(qmap2), - get_ptr(absmax1), get_ptr(absmax2), ct.c_float(weight_decay), ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), ct.c_int32(g.numel())) + str2optimizer8bit_blockwise[optimizer_name][1]( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) else: - raise ValueError(f'Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}') + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + ) -def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int=5): +def percentile_clipping( + grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5 +): """Applies percentile clipping grad: torch.Tensor @@ -663,11 +952,21 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: """ if grad.dtype == torch.float32: - lib.cpercentile_clipping_g32(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel())) + lib.cpercentile_clipping_g32( + get_ptr(grad), + get_ptr(gnorm_vec), + ct.c_int32(step), + ct.c_int32(grad.numel()), + ) elif grad.dtype == torch.float16: - lib.cpercentile_clipping_g16(get_ptr(grad), get_ptr(gnorm_vec), ct.c_int32(step), ct.c_int32(grad.numel())) + lib.cpercentile_clipping_g16( + get_ptr(grad), + get_ptr(gnorm_vec), + ct.c_int32(step), + ct.c_int32(grad.numel()), + ) else: - raise ValueError(f'Gradient type {grad.dtype} not supported!') + raise ValueError(f"Gradient type {grad.dtype} not supported!") current_gnorm = torch.sqrt(gnorm_vec[step % 100]) vals, idx = torch.sort(gnorm_vec) @@ -675,31 +974,44 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: gnorm_scale = 1.0 if current_gnorm > clip_value: - gnorm_scale = clip_value/current_gnorm + gnorm_scale = clip_value / current_gnorm return current_gnorm, clip_value, gnorm_scale -def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor): +def histogram_scatter_add_2d( + histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor +): assert len(histogram.shape) == 2 assert histogram.dtype == torch.float32 assert source.dtype == torch.float32 assert index1.dtype == torch.int32 assert index2.dtype == torch.int32 - assert histogram.device.type == 'cuda' - assert index1.device.type == 'cuda' - assert index2.device.type == 'cuda' - assert source.device.type == 'cuda' + assert histogram.device.type == "cuda" + assert index1.device.type == "cuda" + assert index2.device.type == "cuda" + assert source.device.type == "cuda" maxdim1 = ct.c_int32(histogram.shape[0]) n = ct.c_int32(index1.numel()) - lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) + lib.chistogram_scatter_add_2d( + get_ptr(histogram), + get_ptr(index1), + get_ptr(index2), + get_ptr(source), + maxdim1, + n, + ) + def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): - if not torch.cuda.is_initialized(): torch.cuda.init() + if not torch.cuda.is_initialized(): + torch.cuda.init() if A.dtype != expected_type or B.dtype != expected_type: - raise TypeError(f'Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}') + raise TypeError( + f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}" + ) sA = A.shape sB = B.shape @@ -709,64 +1021,101 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 correct = True if len(sA) == 2 and len(sB) == 2: - if not tA and not tB and A.shape[1] != B.shape[0]: correct = False - elif tA and not tB and A.shape[0] != B.shape[0]: correct = False - elif tA and tB and A.shape[0] != B.shape[1]: correct = False - elif not tA and tB and A.shape[1] != B.shape[1]: correct = False + if not tA and not tB and A.shape[1] != B.shape[0]: + correct = False + elif tA and not tB and A.shape[0] != B.shape[0]: + correct = False + elif tA and tB and A.shape[0] != B.shape[1]: + correct = False + elif not tA and tB and A.shape[1] != B.shape[1]: + correct = False elif len(sA) == 3 and len(sB) == 2: - if not tA and not tB and A.shape[2] != B.shape[0]: correct = False - elif tA and not tB and A.shape[1] != B.shape[0]: correct = False - elif tA and tB and A.shape[1] != B.shape[1]: correct = False - elif not tA and tB and A.shape[2] != B.shape[1]: correct = False + if not tA and not tB and A.shape[2] != B.shape[0]: + correct = False + elif tA and not tB and A.shape[1] != B.shape[0]: + correct = False + elif tA and tB and A.shape[1] != B.shape[1]: + correct = False + elif not tA and tB and A.shape[2] != B.shape[1]: + correct = False elif len(sA) == 3 and len(sB) == 3: - if not tA and not tB and A.shape[2] != B.shape[1]: correct = False - elif tA and not tB and A.shape[1] != B.shape[1]: correct = False - elif tA and tB and A.shape[1] != B.shape[2]: correct = False - elif not tA and tB and A.shape[2] != B.shape[2]: correct = False + if not tA and not tB and A.shape[2] != B.shape[1]: + correct = False + elif tA and not tB and A.shape[1] != B.shape[1]: + correct = False + elif tA and tB and A.shape[1] != B.shape[2]: + correct = False + elif not tA and tB and A.shape[2] != B.shape[2]: + correct = False if out is not None: sout = out.shape # special case common in backprop if not correct and len(sA) == 3 and len(sB) == 3: - if (sout[0] == sA[2] and sout[1] == sB[2] and - sA[0] == sB[0] and sA[1] == sB[1]): + if ( + sout[0] == sA[2] + and sout[1] == sB[2] + and sA[0] == sB[0] + and sA[1] == sB[1] + ): correct = True else: if len(sA) == 2 and len(sB) == 2: - if not tA and not tB: sout = (sA[0], sB[1]) - elif tA and tB: sout = (sA[1], sB[0]) - elif tA and not tB: sout = (sA[1], sB[1]) - elif not tA and tB: sout = (sA[0], sB[0]) + if not tA and not tB: + sout = (sA[0], sB[1]) + elif tA and tB: + sout = (sA[1], sB[0]) + elif tA and not tB: + sout = (sA[1], sB[1]) + elif not tA and tB: + sout = (sA[0], sB[0]) elif len(sA) == 3 and len(sB) == 2: - if not tA and not tB: sout = (sA[0], sA[1], sB[1]) - elif tA and tB: sout = (sA[0], sA[2], sB[0]) - elif tA and not tB: sout = (sA[0], sA[2], sB[1]) - elif not tA and tB: sout = (sA[0], sA[1], sB[0]) + if not tA and not tB: + sout = (sA[0], sA[1], sB[1]) + elif tA and tB: + sout = (sA[0], sA[2], sB[0]) + elif tA and not tB: + sout = (sA[0], sA[2], sB[1]) + elif not tA and tB: + sout = (sA[0], sA[1], sB[0]) elif len(sA) == 3 and len(sB) == 3: - if not tA and not tB: sout = (sA[0], sA[1], sB[2]) - elif tA and tB: sout = (sA[0], sA[2], sB[1]) - elif tA and not tB: sout = (sA[0], sA[2], sB[2]) - elif not tA and tB: sout = (sA[0], sA[1], sB[1]) - + if not tA and not tB: + sout = (sA[0], sA[1], sB[2]) + elif tA and tB: + sout = (sA[0], sA[2], sB[1]) + elif tA and not tB: + sout = (sA[0], sA[2], sB[2]) + elif not tA and tB: + sout = (sA[0], sA[1], sB[1]) if not correct: - raise ValueError(f'Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}.') + raise ValueError( + f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}." + ) return sout -def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False): + +def igemm( + A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False +): sout = check_matmul(A, B, out, transposed_A, transposed_B) - if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) + if out is None: + out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) if len(A.shape) == 3 and len(B.shape) == 3: if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]: return batched_igemm(A, B, out) sA = A.shape sB = B.shape - if transposed_A and len(sA) == 2: sA = (sA[1], sA[0]) - elif transposed_A and len(sA) == 3: sA = (sA[0], sA[2], sA[0]) - if transposed_B and len(sB) == 2: sB = (sB[1], sB[0]) - elif transposed_B and len(sB) == 3: sB = (sB[0], sB[2], sB[0]) + if transposed_A and len(sA) == 2: + sA = (sA[1], sA[0]) + elif transposed_A and len(sA) == 3: + sA = (sA[0], sA[2], sA[0]) + if transposed_B and len(sB) == 2: + sB = (sB[1], sB[0]) + elif transposed_B and len(sB) == 3: + sB = (sB[0], sB[2], sB[0]) # this is a mess: cuBLAS expect column major, but PyTorch is row major. # So to perform the matrix multiplication, we have to treat A, B, and C matrices # (transpose of row major is column major) @@ -777,23 +1126,28 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n] # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m] if len(sB) == 2: - if B.stride()[0] == B.shape[1]: transposed_B = False - elif B.stride()[1] == B.shape[0]: transposed_B = True + if B.stride()[0] == B.shape[1]: + transposed_B = False + elif B.stride()[1] == B.shape[0]: + transposed_B = True if len(A.shape) == 2: - if A.stride()[0] == A.shape[1]: transposed_A = False - elif A.stride()[1] == A.shape[0]: transposed_A = True + if A.stride()[0] == A.shape[1]: + transposed_A = False + elif A.stride()[1] == A.shape[0]: + transposed_A = True else: - if A.stride()[1] == A.shape[2]: transposed_A = False - elif A.stride()[2] == A.shape[1]: transposed_A = True + if A.stride()[1] == A.shape[2]: + transposed_A = False + elif A.stride()[2] == A.shape[1]: + transposed_A = True if len(sA) == 2: n = sA[0] ldb = A.stride()[1 if transposed_A else 0] elif len(sA) == 3 and len(sB) == 2: - n = sA[0]*sA[1] + n = sA[0] * sA[1] ldb = sA[2] - m = sB[1] k = sB[0] lda = B.stride()[(1 if transposed_B else 0)] @@ -802,34 +1156,52 @@ def igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed # special case assert len(sA) == 3 if not (sA[0] == sB[0] and sA[1] == sB[1]): - raise ValueError(f'Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}') + raise ValueError( + f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}" + ) transposed_A = True transposed_B = False m = sB[2] n = sA[2] - k = sB[0]*sB[1] + k = sB[0] * sB[1] lda = m ldb = sA[2] ldc = m - ptr = CUBLAS_Context.get_instance().get_context(A.device) # B^T @ A^T = C^T - # [km, nk -> mn] - lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), - get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc)) + # [km, nk -> mn] + lib.cigemm( + ptr, + ct.c_bool(transposed_B), + ct.c_bool(transposed_A), + ct.c_int32(m), + ct.c_int32(n), + ct.c_int32(k), + get_ptr(B), + get_ptr(A), + get_ptr(out), + ct.c_int32(lda), + ct.c_int32(ldb), + ct.c_int32(ldc), + ) return out -def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, transposed_B=False): +def batched_igemm( + A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False +): if not len(A.shape) == 3 or not len(B.shape) == 3: - raise ValueError(f'Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}') + raise ValueError( + f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}" + ) sout = check_matmul(A, B, out, transposed_A, transposed_B) - if out is None: out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) + if out is None: + out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) if B.is_contiguous(): lda = B.stride()[1] @@ -886,17 +1258,33 @@ def batched_igemm(A: Tensor, B: Tensor, out: Tensor=None, transposed_A=False, tr ldc = m - strideA = B.shape[1]*B.shape[2] - strideB = A.shape[1]*A.shape[2] - strideC = A.shape[1]*B.shape[2] + strideA = B.shape[1] * B.shape[2] + strideB = A.shape[1] * A.shape[2] + strideC = A.shape[1] * B.shape[2] ptr = CUBLAS_Context.get_instance().get_context(A.device) - lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), - get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc), - ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) + lib.cbatched_igemm( + ptr, + ct.c_bool(transposed_B), + ct.c_bool(transposed_A), + ct.c_int32(m), + ct.c_int32(n), + ct.c_int32(k), + get_ptr(B), + get_ptr(A), + get_ptr(out), + ct.c_int32(lda), + ct.c_int32(ldb), + ct.c_int32(ldc), + ct.c_long(strideA), + ct.c_long(strideB), + ct.c_long(strideC), + ct.c_uint32(num_batch), + ) return out + def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeA = SA[0] shapeB = SB[0] @@ -905,28 +1293,34 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): if dimsA == 2: m = shapeA[0] elif dimsA == 3: - m = shapeA[0]*shapeA[1] + m = shapeA[0] * shapeA[1] if dimsB == 2: rows = n = shapeB[0] elif dimsB == 3: - rows = n = shapeB[0]*shapeB[1] + rows = n = shapeB[0] * shapeB[1] if dimsA == 2 and out is None: - out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, 'col32', 'row') + out, Sout = get_transform_buffer( + (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row" + ) elif dimsA == 3 and out is None: - out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, 'col32', 'row') + out, Sout = get_transform_buffer( + (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row" + ) - assert dimsB != 3, 'len(B.shape)==3 not supported' - assert A.device.type == 'cuda' - assert B.device.type == 'cuda' + assert dimsB != 3, "len(B.shape)==3 not supported" + assert A.device.type == "cuda" + assert B.device.type == "cuda" assert A.dtype == torch.int8 assert B.dtype == torch.int8 assert out.dtype == dtype - assert SA[1] == 'col32' - assert SB[1] in ['col_turing', 'col_ampere'] - assert Sout[1] == 'col32' - assert shapeA[-1] == shapeB[-1], f'Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}' + assert SA[1] == "col32" + assert SB[1] in ["col_turing", "col_ampere"] + assert Sout[1] == "col32" + assert ( + shapeA[-1] == shapeB[-1] + ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}" formatB = SB[1] prev_device = A.device torch.cuda.set_device(A.device) @@ -937,53 +1331,76 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): ptrC = get_ptr(out) k = shapeA[-1] - lda = ct.c_int32(m*32) - if formatB == 'col_turing': + lda = ct.c_int32(m * 32) + if formatB == "col_turing": # turing: tiles with rows filled up to multiple of 8 rows by 32 columns # n = rows - ldb = ct.c_int32(((rows+7)//8)*8*32) + ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32) else: # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns # n = rows - ldb = ct.c_int32(((rows+31)//32)*32*32) + ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32) - ldc = ct.c_int32(m*32) + ldc = ct.c_int32(m * 32) m = ct.c_int32(m) n = ct.c_int32(n) k = ct.c_int32(k) has_error = 0 ptrRowScale = get_ptr(None) - if formatB == 'col_turing': + if formatB == "col_turing": if dtype == torch.int32: - has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_turing_32( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) else: - has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) - elif formatB == 'col_ampere': + has_error = lib.cigemmlt_turing_8( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) + elif formatB == "col_ampere": if dtype == torch.int32: - has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_ampere_32( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) else: - has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) + has_error = lib.cigemmlt_ampere_8( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) if has_error == 1: - raise Exception('cublasLt ran into an error!') + raise Exception("cublasLt ran into an error!") torch.cuda.set_device(prev_device) - return out, Sout -def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None): +def mm_dequant( + A, + quant_state, + row_stats, + col_stats, + out=None, + new_row_stats=None, + new_col_stats=None, +): assert A.dtype == torch.int32 out_shape = quant_state[0] - if len(out_shape) == 3: out_shape = (out_shape[0]*out_shape[1], out_shape[2]) + if len(out_shape) == 3: + out_shape = (out_shape[0] * out_shape[1], out_shape[2]) - if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device) - if new_row_stats is None: new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) - if new_col_stats is None: new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) - assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}" - assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}" + if out is None: + out = torch.empty(out_shape, dtype=torch.float16, device=A.device) + if new_row_stats is None: + new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) + if new_col_stats is None: + new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) + assert ( + new_row_stats.shape[0] == row_stats.shape[0] + ), f"{new_row_stats.shape} vs {row_stats.shape}" + assert ( + new_col_stats.shape[0] == col_stats.shape[0] + ), f"{new_col_stats.shape} vs {col_stats.shape}" ptrA = get_ptr(A) ptrOut = get_ptr(out) @@ -994,27 +1411,47 @@ def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=Non numRows = ct.c_int32(out_shape[0]) numCols = ct.c_int32(out_shape[1]) - lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, numRows, numCols) + lib.cdequant_mm_int32_fp16( + ptrA, + ptrRowStats, + ptrColStats, + ptrOut, + ptrNewRowStats, + ptrNewColStats, + numRows, + numCols, + ) return out -def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0): +def get_colrow_absmax( + A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 +): assert A.dtype == torch.float16 device = A.device cols = A.shape[-1] if len(A.shape) == 3: - rows = A.shape[0]*A.shape[1] + rows = A.shape[0] * A.shape[1] else: rows = A.shape[0] - col_tiles = (cols+255)//256 - tiled_rows = ((rows+15)//16)*16 - if row_stats is None: row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0) - if col_stats is None: col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0) + col_tiles = (cols + 255) // 256 + tiled_rows = ((rows + 15) // 16) * 16 + if row_stats is None: + row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_( + -50000.0 + ) + if col_stats is None: + col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_( + -50000.0 + ) - if nnz_block_ptr is None and threshold > 0.0: nnz_block_ptr = torch.zeros(((tiled_rows*col_tiles)+1,), dtype=torch.int32, device=device) + if nnz_block_ptr is None and threshold > 0.0: + nnz_block_ptr = torch.zeros( + ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device + ) ptrA = get_ptr(A) ptrRowStats = get_ptr(row_stats) @@ -1024,16 +1461,17 @@ def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, thr cols = ct.c_int32(cols) prev_device = pre_call(A.device) - lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols) + lib.cget_col_row_stats( + ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols + ) post_call(prev_device) - if threshold > 0.0: nnz_block_ptr.cumsum_(0) - return row_stats, col_stats, nnz_block_ptr + class COOSparseTensor(object): def __init__(self, rows, cols, nnz, rowidx, colidx, values): assert rowidx.dtype == torch.int32 @@ -1050,6 +1488,7 @@ class COOSparseTensor(object): self.colidx = colidx self.values = values + class CSRSparseTensor(object): def __init__(self, rows, cols, nnz, rowptr, colidx, values): assert rowptr.dtype == torch.int32 @@ -1057,7 +1496,7 @@ class CSRSparseTensor(object): assert values.dtype == torch.float16 assert values.numel() == nnz assert colidx.numel() == nnz - assert rowptr.numel() == rows+1 + assert rowptr.numel() == rows + 1 self.rows = rows self.cols = cols @@ -1066,6 +1505,7 @@ class CSRSparseTensor(object): self.colidx = colidx self.values = values + class CSCSparseTensor(object): def __init__(self, rows, cols, nnz, colptr, rowidx, values): assert colptr.dtype == torch.int32 @@ -1073,7 +1513,7 @@ class CSCSparseTensor(object): assert values.dtype == torch.float16 assert values.numel() == nnz assert rowidx.numel() == nnz - assert colptr.numel() == cols+1 + assert colptr.numel() == cols + 1 self.rows = rows self.cols = cols @@ -1082,13 +1522,17 @@ class CSCSparseTensor(object): self.rowidx = rowidx self.values = values + def coo2csr(cooA): values, counts = torch.unique(cooA.rowidx, return_counts=True) values.add_(1) - rowptr = torch.zeros((cooA.rows+1, ), dtype=torch.int32, device=cooA.rowidx.device) + rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device) rowptr.scatter_(index=values.long(), src=counts.int(), dim=0) rowptr.cumsum_(0) - return CSRSparseTensor(cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values) + return CSRSparseTensor( + cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values + ) + def coo2csc(cooA): val, col2rowidx = torch.sort(cooA.colidx) @@ -1096,11 +1540,12 @@ def coo2csc(cooA): values = cooA.values[col2rowidx] colvalues, counts = torch.unique(val, return_counts=True) colvalues.add_(1) - colptr = torch.zeros((cooA.cols+1, ), dtype=torch.int32, device=cooA.colidx.device) + colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device) colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0) colptr.cumsum_(0) return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values) + def coo_zeros(rows, cols, nnz, device, dtype=torch.half): rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device) colidx = torch.zeros((nnz,), dtype=torch.int32, device=device) @@ -1108,23 +1553,27 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) -def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): +def double_quant( + A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 +): device = A.device assert A.dtype == torch.half - assert device.type == 'cuda' + assert device.type == "cuda" prev_device = pre_call(A.device) cols = A.shape[-1] if len(A.shape) == 3: - rows = A.shape[0]*A.shape[1] + rows = A.shape[0] * A.shape[1] else: rows = A.shape[0] if row_stats is None or col_stats is None: row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) - if out_col is None: out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) - if out_row is None: out_row = torch.zeros(A.shape, device=device, dtype=torch.int8) + if out_col is None: + out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) + if out_row is None: + out_row = torch.zeros(A.shape, device=device, dtype=torch.int8) coo_tensor = None ptrA = get_ptr(A) @@ -1136,21 +1585,62 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, if threshold > 0.0: nnz = nnz_row_ptr[-1].item() if nnz > 0: - coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device) + coo_tensor = coo_zeros( + A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device + ) ptrRowIdx = get_ptr(coo_tensor.rowidx) ptrColIdx = get_ptr(coo_tensor.colidx) ptrVal = get_ptr(coo_tensor.values) ptrRowPtr = get_ptr(nnz_row_ptr) - lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, ptrRowIdx, ptrColIdx, ptrVal, ptrRowPtr, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols)) + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + ptrRowIdx, + ptrColIdx, + ptrVal, + ptrRowPtr, + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + ) val, idx = torch.sort(coo_tensor.rowidx) coo_tensor.rowidx = val coo_tensor.colidx = coo_tensor.colidx[idx] coo_tensor.values = coo_tensor.values[idx] else: - lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(0.0), ct.c_int32(rows), ct.c_int32(cols)) + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + None, + None, + None, + None, + ct.c_float(0.0), + ct.c_int32(rows), + ct.c_int32(cols), + ) else: - lib.cdouble_rowcol_quant(ptrA, ptrRowStats, ptrColStats, ptrOutCol, ptrOutRow, None, None, None, None, ct.c_float(threshold), ct.c_int32(rows), ct.c_int32(cols)) + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + None, + None, + None, + None, + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + ) post_call(prev_device) return out_row, out_col, row_stats, col_stats, coo_tensor @@ -1159,69 +1649,81 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, def get_special_format_str(): major, minor = torch.cuda.get_device_capability() if major < 7: - print(f'Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!') + print( + f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!" + ) assert major >= 7 - if major == 7: return 'col_turing' - elif major == 8: return 'col_ampere' - else: return 'col_turing' + if major == 7: + return "col_turing" + elif major == 8: + return "col_ampere" + else: + return "col_turing" - - -def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): - if state is None: state = (A.shape, from_order) - else: from_order = state[1] - if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) - else: new_state = (state[0], to_order) # (shape, order) +def transform( + A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None +): + if state is None: + state = (A.shape, from_order) + else: + from_order = state[1] + if out is None: + out, new_state = get_transform_buffer( + state[0], A.dtype, A.device, to_order, state[1], transpose + ) + else: + new_state = (state[0], to_order) # (shape, order) shape = state[0] if len(shape) == 2: dim1 = ct.c_int32(shape[0]) dim2 = ct.c_int32(shape[1]) else: - dim1 = ct.c_int32(shape[0]*shape[1]) + dim1 = ct.c_int32(shape[0] * shape[1]) dim2 = ct.c_int32(shape[2]) ptrA = get_ptr(A) ptrOut = get_ptr(out) - if to_order == 'col32': + if to_order == "col32": if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) else: lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == 'col_turing': + elif to_order == "col_turing": if transpose: lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2) else: lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == 'col_ampere': + elif to_order == "col_ampere": if transpose: lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2) else: lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == 'row': - if from_order == 'col_turing': + elif to_order == "row": + if from_order == "col_turing": lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2) - elif from_order == 'col_ampere': + elif from_order == "col_ampere": lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) else: - raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') - - - + raise NotImplementedError( + f"Transform function not implemented: From {from_order} to {to_order}" + ) return out, new_state + def spmm_coo(cooA, B, out=None): - if out is None: out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype) + if out is None: + out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype) nnz = cooA.nnz assert cooA.rowidx.numel() == nnz assert cooA.colidx.numel() == nnz assert cooA.values.numel() == nnz assert cooA.cols == B.shape[0] - transposed_B = (False if B.is_contiguous() else True) + transposed_B = False if B.is_contiguous() else True ldb = B.stride()[(1 if transposed_B else 0)] ldc = B.shape[1] @@ -1240,19 +1742,37 @@ def spmm_coo(cooA, B, out=None): cldb = ct.c_int32(ldb) cldc = ct.c_int32(ldc) - lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B)) + lib.cspmm_coo( + ptr, + ptrRowidx, + ptrColidx, + ptrValues, + cnnz, + crowsA, + ccolsA, + ccolsB, + cldb, + ptrB, + cldc, + ptrC, + ct.c_bool(transposed_B), + ) return out + def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): - if out is None: out = torch.zeros((cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype) + if out is None: + out = torch.zeros( + (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype + ) nnz = cooA.nnz assert cooA.rowidx.numel() == nnz assert cooA.colidx.numel() == nnz assert cooA.values.numel() == nnz - assert cooA.cols == B.shape[0], f'{cooA.cols} vs {B.shape}' + assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}" - transposed_B = (False if B.is_contiguous() else True) + transposed_B = False if B.is_contiguous() else True ldb = B.stride()[(1 if transposed_B else 0)] ldc = B.shape[1] @@ -1262,7 +1782,9 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): max_count, max_idx = torch.sort(counts, descending=True) max_idx = max_idx.int() max_count = max_count.int() - assert max_count[0] <= 32, f'Current max count per row is 8 but found {max_count[0]}.' + assert ( + max_count[0] <= 32 + ), f"Current max count per row is 8 but found {max_count[0]}." assert B.dtype in [torch.float16, torch.int8] ptrOffset = get_ptr(offset) ptrMaxCount = get_ptr(max_count) @@ -1282,134 +1804,183 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): ccolsB = ct.c_int32(B.shape[1]) cldb = ct.c_int32(ldb) cldc = ct.c_int32(ldc) - #print(cooA.rowidx[:64]) - #print(cooA.colidx[:64].sort()[0]) + # print(cooA.rowidx[:64]) + # print(cooA.colidx[:64].sort()[0]) if B.dtype == torch.float16: - lib.cspmm_coo_very_sparse_naive_fp16(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB) + lib.cspmm_coo_very_sparse_naive_fp16( + ptrMaxCount, + ptrMaxIdx, + ptrOffset, + ptrRowidx, + ptrColidx, + ptrValues, + ptrB, + ptrC, + ptrDequantStats, + cnnz_rows, + cnnz, + crowsA, + crowsB, + ccolsB, + ) elif B.dtype == torch.int8: - lib.cspmm_coo_very_sparse_naive_int8(ptrMaxCount, ptrMaxIdx, ptrOffset, ptrRowidx, ptrColidx, ptrValues, ptrB, ptrC, ptrDequantStats, cnnz_rows, cnnz, crowsA, crowsB, ccolsB) - #else: assertion error + lib.cspmm_coo_very_sparse_naive_int8( + ptrMaxCount, + ptrMaxIdx, + ptrOffset, + ptrRowidx, + ptrColidx, + ptrValues, + ptrB, + ptrC, + ptrDequantStats, + cnnz_rows, + cnnz, + crowsA, + crowsB, + ccolsB, + ) + # else: assertion error return out C = 127.0 -def vectorwise_quant(x, dim=1, quant_type='vector'): - if quant_type == 'linear': + +def vectorwise_quant(x, dim=1, quant_type="vector"): + if quant_type == "linear": max1 = torch.abs(x).max().float() - xq = torch.round(x/max1*127).to(torch.int8) + xq = torch.round(x / max1 * 127).to(torch.int8) return xq, max1 - elif quant_type in ['vector', 'row']: + elif quant_type in ["vector", "row"]: max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) - xq = torch.round(x*(C/max1)).to(torch.int8) + xq = torch.round(x * (C / max1)).to(torch.int8) return xq, max1 - elif quant_type == 'zeropoint': + elif quant_type == "zeropoint": dtype = x.dtype x = x.float() dyna = x.max() - x.min() - if dyna == 0: dyna = 1 - qx = 255./dyna + if dyna == 0: + dyna = 1 + qx = 255.0 / dyna minx = x.min() - zpx = torch.round(minx* qx) - x = torch.round(qx*x - zpx) + zpx + zpx = torch.round(minx * qx) + x = torch.round(qx * x - zpx) + zpx return x, qx - elif quant_type in ['vector-zeropoint', 'row-zeropoint']: + elif quant_type in ["vector-zeropoint", "row-zeropoint"]: dtype = x.dtype x = x.float() - dyna = (torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True)) - dyna[dyna==0] = 1 - qx = 255./dyna + dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin( + x, dim=dim, keepdim=True + ) + dyna[dyna == 0] = 1 + qx = 255.0 / dyna minx = torch.amin(x, dim=dim, keepdim=True) - zpx = torch.round(minx* qx) - x = torch.round(qx*x - zpx) + zpx + zpx = torch.round(minx * qx) + x = torch.round(qx * x - zpx) + zpx return x, qx - elif quant_type == 'truncated-vector': + elif quant_type == "truncated-vector": with torch.no_grad(): absx = torch.abs(x) max1 = torch.amax(absx, dim=dim, keepdim=True) - max1 = max1*0.7 - idx = (absx > max1.expand_as(absx)) + max1 = max1 * 0.7 + idx = absx > max1.expand_as(absx) sign = torch.sign(x[idx]) - x[idx] = max1.expand_as(absx)[idx]*sign - xq = torch.round(x/max1*C).to(torch.int8) + x[idx] = max1.expand_as(absx)[idx] * sign + xq = torch.round(x / max1 * C).to(torch.int8) return xq, max1 - else: return None + else: + return None -def vectorwise_dequant(xq, max1, quant_type='vector'): - if quant_type == 'vector': - x = (xq/C*max1).to(torch.float32) + +def vectorwise_dequant(xq, max1, quant_type="vector"): + if quant_type == "vector": + x = (xq / C * max1).to(torch.float32) return x - else: return None + else: + return None -def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type='vector'): - if quant_type == 'linear': - norm = S1*S2/(C*C) + +def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"): + if quant_type == "linear": + norm = S1 * S2 / (C * C) # double cast needed to prevent overflows - return (xq.float()*norm).to(dtype) - elif quant_type == 'zeropoint': - norm = 1.0/(S1*S2) - return (xq.float()*norm).to(dtype) - elif quant_type == 'row-zeropoint': - norm = 1.0/(S1*S2) + return (xq.float() * norm).to(dtype) + elif quant_type == "zeropoint": + norm = 1.0 / (S1 * S2) + return (xq.float() * norm).to(dtype) + elif quant_type == "row-zeropoint": + norm = 1.0 / (S1 * S2) x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0) + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) if len(S1.shape) == 2: x *= norm else: x *= norm return x.to(dtype) - elif quant_type == 'vector-zeropoint': + elif quant_type == "vector-zeropoint": x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0) + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) if len(S1.shape) == 2: - x *= 1.0/S1 + x *= 1.0 / S1 else: - x *= 1.0/S1 - x *= 1.0/S2.t() + x *= 1.0 / S1 + x *= 1.0 / S2.t() return x.to(dtype) - elif quant_type == 'row': + elif quant_type == "row": x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0) + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) if len(S1.shape) == 2: - x *= S1*S2/(C*C) + x *= S1 * S2 / (C * C) else: - x *= S1*S2/(C*C) + x *= S1 * S2 / (C * C) return x.to(dtype) - elif quant_type in ['truncated-vector', 'vector']: + elif quant_type in ["truncated-vector", "vector"]: x = xq.float() - if len(S1.shape) == 3 and len(x.shape) == 2: S1 = S1.squeeze(0) - if len(S2.shape) == 3 and len(x.shape) == 2: S2 = S2.squeeze(0) + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) if len(S1.shape) == 2: - x *= S1/C + x *= S1 / C else: - x *= S1/C - x *= S2/C + x *= S1 / C + x *= S2 / C return x.to(dtype) - else: return None + else: + return None def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): - offset = B.float().t().sum(0)*(SA[0]+SA[1]) + offset = B.float().t().sum(0) * (SA[0] + SA[1]) x = xq.float() - if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0) + if len(xq.shape) == 2 and len(SB.shape) == 3: + SB = SB.squeeze(0) if len(SB.shape) == 2: - x *= SB.t()/127 + x *= SB.t() / 127 else: - x *= SB/127 - x *= SA[1]/127 - x +=offset + x *= SB / 127 + x *= SA[1] / 127 + x += offset return x.to(dtype) + def extract_outliers(A, SA, idx): shapeA = SA[0] formatA = SA[1] - assert formatA in ['col_turing', 'col_ampere'] - assert A.device.type == 'cuda' + assert formatA in ["col_turing", "col_ampere"] + assert A.device.type == "cuda" out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) @@ -1420,13 +1991,9 @@ def extract_outliers(A, SA, idx): ptrIdx = get_ptr(idx) ptrOut = get_ptr(out) - if formatA == 'col_turing': + if formatA == "col_turing": lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) - elif formatA == 'col_ampere': + elif formatA == "col_ampere": lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) return out - - - - diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 03b4655..98d4aa0 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -1,5 +1,5 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .modules import StableEmbedding, Linear8bit, Linear8bitLt, Int8Params +from .modules import Int8Params, Linear8bit, Linear8bitLt, StableEmbedding diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 5013d0b..9ce3ac8 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -1,39 +1,59 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Set, + Tuple, TypeVar, Union, overload) + import torch -import bitsandbytes as bnb - -from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict - -from torch import Tensor, device, dtype -from torch import nn -from torch.nn.parameter import Parameter import torch.nn.functional as F +from torch import Tensor, device, dtype, nn +from torch.nn.parameter import Parameter +import bitsandbytes as bnb from bitsandbytes.optim import GlobalOptimManager -T = TypeVar('T', bound='torch.nn.Module') +T = TypeVar("T", bound="torch.nn.Module") + class StableEmbedding(torch.nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, - sparse: bool = False, _weight: Optional[Tensor] = None) -> None: - super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight) + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Tensor] = None, + ) -> None: + super(StableEmbedding, self).__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + ) self.norm = torch.nn.LayerNorm(embedding_dim) - GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32}) + GlobalOptimManager.get_instance().register_module_override( + self, "weight", {"optim_bits": 32} + ) def reset_parameters(self) -> None: torch.nn.init.xavier_uniform_(self.weight) self._fill_padding_idx_with_zero() - ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding + """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding to make the Layer compatible with Pytorch < 1.9. This means that if this changes in future PyTorch releases this need to change too which is cumbersome. However, with this we can ensure compatibility with previous PyTorch releases. - ''' + """ + def _fill_padding_idx_with_zero(self) -> None: if self.padding_idx is not None: with torch.no_grad(): @@ -41,29 +61,55 @@ class StableEmbedding(torch.nn.Embedding): def forward(self, input: Tensor) -> Tensor: emb = F.embedding( - input, self.weight, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) + input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) return self.norm(emb) class Embedding(torch.nn.Embedding): - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, - max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, - sparse: bool = False, _weight: Optional[Tensor] = None) -> None: - super(Embedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight) - GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32}) + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Tensor] = None, + ) -> None: + super(Embedding, self).__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + ) + GlobalOptimManager.get_instance().register_module_override( + self, "weight", {"optim_bits": 32} + ) def reset_parameters(self) -> None: torch.nn.init.xavier_uniform_(self.weight) self._fill_padding_idx_with_zero() - ''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding + """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding to make the Layer compatible with Pytorch < 1.9. This means that if this changes in future PyTorch releases this need to change too which is cumbersome. However, with this we can ensure compatibility with previous PyTorch releases. - ''' + """ + def _fill_padding_idx_with_zero(self) -> None: if self.padding_idx is not None: with torch.no_grad(): @@ -71,13 +117,22 @@ class Embedding(torch.nn.Embedding): def forward(self, input: Tensor) -> Tensor: emb = F.embedding( - input, self.weight, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse) + input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) return emb + class Int8Params(torch.nn.Parameter): - def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None): + def __new__( + cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None + ): cls.has_fp16_weights = has_fp16_weights cls.CB = None cls.SCB = None @@ -96,14 +151,18 @@ class Int8Params(torch.nn.Parameter): del CBt del SCBt self.data = CB - setattr(self, 'CB', CB) - setattr(self, 'SCB', SCB) + setattr(self, "CB", CB) + setattr(self, "SCB", SCB) return self @overload - def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., - non_blocking: bool = ...) -> T: + def to( + self: T, + device: Optional[Union[int, device]] = ..., + dtype: Optional[Union[dtype, str]] = ..., + non_blocking: bool = ..., + ) -> T: ... @overload @@ -115,23 +174,41 @@ class Int8Params(torch.nn.Parameter): ... def to(self, *args, **kwargs): - device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( + *args, **kwargs + ) - if device is not None and device.type == 'cuda' and self.data.device.type == 'cpu': return self.cuda(device) + if ( + device is not None + and device.type == "cuda" + and self.data.device.type == "cpu" + ): + return self.cuda(device) else: - new_param = Int8Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights) + new_param = Int8Params( + super().to(device=device, dtype=dtype, non_blocking=non_blocking), + requires_grad=self.requires_grad, + has_fp16_weights=self.has_fp16_weights, + ) new_param.CB = self.CB new_param.SCB = self.SCB return new_param - class Linear8bitLt(nn.Linear): - def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, threshold=0.0, index=None): + def __init__( + self, + input_features, + output_features, + bias=True, + has_fp16_weights=True, + threshold=0.0, + index=None, + ): super(Linear8bitLt, self).__init__(input_features, output_features, bias) self.state = bnb.MatmulLtState() - self.index=index + self.index = index self.state.threshold = threshold self.state.has_fp16_weights = has_fp16_weights @@ -149,9 +226,10 @@ class Linear8bitLt(nn.Linear): def forward(self, x): self.state.is_training = self.training - if self.weight.CB is not None: self.init_8bit_state() - #assert not self.state.has_fp16_weights - #if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None + if self.weight.CB is not None: + self.init_8bit_state() + # assert not self.state.has_fp16_weights + # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None out = bnb.matmul(x, self.weight, state=self.state) @@ -166,8 +244,18 @@ class Linear8bitLt(nn.Linear): return out + class Linear8bit(nn.Linear): - def __init__(self, input_features, output_features, bias=True, quant_type='vector', index=None, args=None, sparse_decomp=False): + def __init__( + self, + input_features, + output_features, + bias=True, + quant_type="vector", + index=None, + args=None, + sparse_decomp=False, + ): super(Linear8bit, self).__init__(input_features, output_features, bias) self.quant_type = quant_type self.index = index @@ -178,15 +266,24 @@ class Linear8bit(nn.Linear): self.iter += 1 if self.iter % self.args.clip_freq == 0: with torch.no_grad(): - maxval, maxidx = torch.topk(torch.abs(self.weight.flatten()), k=self.args.clip_idx) + maxval, maxidx = torch.topk( + torch.abs(self.weight.flatten()), k=self.args.clip_idx + ) if not dist.is_initialized() or dist.get_rank() == 0: - print('clip', maxval[-1].item()) + print("clip", maxval[-1].item()) self.weight.clip_(-maxval[-1], maxval[-1]) - if self.args is not None: - out = bnb.nn.functional.sparse_decomposed_linear8bit(x, self.weight, self.bias, qval=self.args.sparse_decomp_val, quant_type=self.args.quant_type) + out = bnb.nn.functional.sparse_decomposed_linear8bit( + x, + self.weight, + self.bias, + qval=self.args.sparse_decomp_val, + quant_type=self.args.quant_type, + ) else: - out = bnb.nn.functional.linear8bit(x, self.weight, self.bias, quant_type=self.args.quant_type) + out = bnb.nn.functional.linear8bit( + x, self.weight, self.bias, quant_type=self.args.quant_type + ) return out diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index 42b5bc0..a76d717 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -1,6 +1,6 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.cextension import COMPILED_WITH_CUDA diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py index 4f51250..43e3973 100644 --- a/bitsandbytes/optim/adagrad.py +++ b/bitsandbytes/optim/adagrad.py @@ -1,12 +1,25 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer1State + class Adagrad(Optimizer1State): - def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10, - optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + lr_decay=0, + weight_decay=0, + initial_accumulator_value=0, + eps=1e-10, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: @@ -14,15 +27,39 @@ class Adagrad(Optimizer1State): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: - raise ValueError('Initial accumulator value != 0.0 not supported!') + raise ValueError("Initial accumulator value != 0.0 not supported!") if lr_decay != 0.0: - raise ValueError('Lr Decay != 0.0 not supported!') - super(Adagrad, self).__init__('adagrad', params, lr, (0.0, 0.0), eps, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + raise ValueError("Lr Decay != 0.0 not supported!") + super(Adagrad, self).__init__( + "adagrad", + params, + lr, + (0.0, 0.0), + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class Adagrad8bit(Optimizer1State): - def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10, - optim_bits=8, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + lr_decay=0, + weight_decay=0, + initial_accumulator_value=0, + eps=1e-10, + optim_bits=8, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: @@ -30,16 +67,40 @@ class Adagrad8bit(Optimizer1State): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: - raise ValueError('Initial accumulator value != 0.0 not supported!') + raise ValueError("Initial accumulator value != 0.0 not supported!") if lr_decay != 0.0: - raise ValueError('Lr Decay != 0.0 not supported!') + raise ValueError("Lr Decay != 0.0 not supported!") assert block_wise - super(Adagrad8bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + super(Adagrad8bit, self).__init__( + "adagrad", + params, + lr, + (0.0, 0.0), + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class Adagrad32bit(Optimizer1State): - def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10, - optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + lr_decay=0, + weight_decay=0, + initial_accumulator_value=0, + eps=1e-10, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= weight_decay: @@ -47,8 +108,19 @@ class Adagrad32bit(Optimizer1State): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if initial_accumulator_value != 0.0: - raise ValueError('Initial accumulator value != 0.0 not supported!') + raise ValueError("Initial accumulator value != 0.0 not supported!") if lr_decay != 0.0: - raise ValueError('Lr Decay != 0.0 not supported!') - super(Adagrad32bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) + raise ValueError("Lr Decay != 0.0 not supported!") + super(Adagrad32bit, self).__init__( + "adagrad", + params, + lr, + (0.0, 0.0), + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index ed1b9f0..5cfaa28 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -1,6 +1,6 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import math @@ -8,29 +8,97 @@ import os import torch import torch.distributed as dist -from bitsandbytes.optim.optimizer import Optimizer2State + import bitsandbytes.functional as F +from bitsandbytes.optim.optimizer import Optimizer2State + class Adam(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(Adam, self).__init__('adam', params, lr, betas, eps, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(Adam, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class Adam8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(Adam8bit, self).__init__('adam', params, lr, betas, eps, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(Adam8bit, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class Adam32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(Adam32bit, self).__init__('adam', params, lr, betas, eps, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(Adam32bit, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) class AnalysisAdam(torch.optim.Optimizer): @@ -68,8 +136,8 @@ class AnalysisAdam(torch.optim.Optimizer): eps=1e-8, weight_decay=0, amsgrad=False, - bnb_analysis='dynamic-blockwise', - savedir=None + bnb_analysis="dynamic-blockwise", + savedir=None, ): defaults = dict( lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad @@ -124,9 +192,13 @@ class AnalysisAdam(torch.optim.Optimizer): state["exp_avg"] = torch.zeros_like(p_data_fp32) # Exponential moving average of squared gradient values state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) - state['abserrors'] = torch.zeros((256, 256), device=p_data_fp32.device) - state['relerrors'] = torch.zeros((256, 256), device=p_data_fp32.device) - state['counts'] = torch.zeros((256, 256), device=p_data_fp32.device) + state["abserrors"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) + state["relerrors"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) + state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) @@ -143,9 +215,9 @@ class AnalysisAdam(torch.optim.Optimizer): bias_correction1 = 1 - beta1 ** state["step"] bias_correction2 = 1 - beta2 ** state["step"] step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 - e = state['abserrors'] - rele = state['relerrors'] - counts = state['counts'] + e = state["abserrors"] + rele = state["relerrors"] + counts = state["counts"] if group["weight_decay"] != 0: p_data_fp32.add_( @@ -156,77 +228,84 @@ class AnalysisAdam(torch.optim.Optimizer): if amsgrad: max_exp_avg_sq = state["max_exp_avg_sq"] - # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) denom = exp_avg_sq.sqrt().add_(group["eps"]) - update_fp32 = exp_avg/denom + update_fp32 = exp_avg / denom - if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000*1000: + if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000: # embedding layer or too small - p_data_fp32 += -step_size*update_fp32 + p_data_fp32 += -step_size * update_fp32 else: - if self.analysis == 'dynamic-blockwise': + if self.analysis == "dynamic-blockwise": code1 = F.create_dynamic_map(signed=True).to(p.device) code2 = F.create_dynamic_map(signed=False).to(p.device) C1, S1 = F.quantize_blockwise(exp_avg, code=code1) state1 = F.dequantize_blockwise(C1, S1) C2, S2 = F.quantize_blockwise(exp_avg_sq, code=code2) state2 = F.dequantize_blockwise(C2, S2) - elif self.analysis == 'dynamic': + elif self.analysis == "dynamic": code1 = F.create_dynamic_map(signed=True).to(p.device) code2 = F.create_dynamic_map(signed=False).to(p.device) C1, S1 = F.quantize(exp_avg, code=code1) state1 = F.dequantize(C1, S1) C2, S2 = F.quantize(exp_avg_sq, code=code2) state2 = F.dequantize(C2, S2) - elif self.analysis == 'linear': + elif self.analysis == "linear": code1 = F.create_linear_map(signed=True).to(p.device) code2 = F.create_linear_map(signed=False).to(p.device) C1, S1 = F.quantize(exp_avg, code=code1) state1 = F.dequantize(C1, S1) C2, S2 = F.quantize(exp_avg_sq, code=code2) state2 = F.dequantize(C2, S2) - elif self.analysis == 'quantile': + elif self.analysis == "quantile": code1 = F.estimate_quantiles(exp_avg) code2 = F.estimate_quantiles(exp_avg_sq) C1 = F.quantize_no_absmax(exp_avg, code=code1) state1 = F.dequantize_no_absmax(C1, code1) C2 = F.quantize_no_absmax(exp_avg_sq, code=code2) state2 = F.dequantize_no_absmax(C2, code2) - elif self.analysis == 'my-quantization-routine': + elif self.analysis == "my-quantization-routine": pass # 1. get code # 2. quantize # 3. dequantize # Error will be calculated automatically! else: - raise ValueError(f'Invalid analysis value: {self.analysis}!') + raise ValueError(f"Invalid analysis value: {self.analysis}!") denom = state2.sqrt().add_(group["eps"]) - update_8bit = state1/denom + update_8bit = state1 / denom - abserr = torch.abs(update_8bit-update_fp32) - relerr = abserr/torch.abs(update_fp32+1e-6) + abserr = torch.abs(update_8bit - update_fp32) + relerr = abserr / torch.abs(update_fp32 + 1e-6) C1, C2 = C1.int(), C2.int() F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr) F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr) - F.histogram_scatter_add_2d(counts, C1.int(), C2.int(), torch.ones_like(abserr)) - - p_data_fp32 += -step_size*update_fp32 + F.histogram_scatter_add_2d( + counts, C1.int(), C2.int(), torch.ones_like(abserr) + ) + p_data_fp32 += -step_size * update_fp32 if not dist.is_initialized() or dist.get_rank() == 0: - if self.savedir != '' and state['step'] % 100 == 0: - if not os.path.exists(self.savedir): os.makedirs(self.savedir) - shapestr = '_'.join([str(dim) for dim in p_data_fp32.shape]) - pathe = os.path.join(self.savedir, f'{p_id}_{shapestr}_abserr.pkl') - pathrele = os.path.join(self.savedir, f'{p_id}_{shapestr}_relerr.pkl') - pathcounts = os.path.join(self.savedir, f'{p_id}_{shapestr}_counts.pkl') + if self.savedir != "" and state["step"] % 100 == 0: + if not os.path.exists(self.savedir): + os.makedirs(self.savedir) + shapestr = "_".join([str(dim) for dim in p_data_fp32.shape]) + pathe = os.path.join( + self.savedir, f"{p_id}_{shapestr}_abserr.pkl" + ) + pathrele = os.path.join( + self.savedir, f"{p_id}_{shapestr}_relerr.pkl" + ) + pathcounts = os.path.join( + self.savedir, f"{p_id}_{shapestr}_counts.pkl" + ) torch.save(e, pathe) torch.save(rele, pathrele) torch.save(counts, pathcounts) @@ -234,6 +313,4 @@ class AnalysisAdam(torch.optim.Optimizer): if p.data.dtype in {torch.float16, torch.bfloat16}: p.data.copy_(p_data_fp32) - - return loss diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index c4f0355..d0b3bde 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -1,27 +1,93 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer2State + class AdamW(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(AdamW, self).__init__('adam', params, lr, betas, eps, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(AdamW, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class AdamW8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=1e-2, amsgrad=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(AdamW8bit, self).__init__('adam', params, lr, betas, eps, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(AdamW8bit, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class AdamW32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=1e-2, amsgrad=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): - super(AdamW32bit, self).__init__('adam', params, lr, betas, eps, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) - + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=1e-2, + amsgrad=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super(AdamW32bit, self).__init__( + "adam", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) diff --git a/bitsandbytes/optim/lamb.py b/bitsandbytes/optim/lamb.py index 58cc13d..8f365f7 100644 --- a/bitsandbytes/optim/lamb.py +++ b/bitsandbytes/optim/lamb.py @@ -1,28 +1,105 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer2State + class LAMB(Optimizer2State): - def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, adam_w_mode=True, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0): - super(LAMB, self).__init__('lamb', params, lr, betas, eps, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0) + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + adam_w_mode=True, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=False, + max_unorm=1.0, + ): + super(LAMB, self).__init__( + "lamb", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + max_unorm=1.0, + ) + class LAMB8bit(Optimizer2State): - def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, adam_w_mode=True, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0): - super(LAMB8bit, self).__init__('lamb', params, lr, betas, eps, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0) + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + adam_w_mode=True, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=False, + max_unorm=1.0, + ): + super(LAMB8bit, self).__init__( + "lamb", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + max_unorm=1.0, + ) + class LAMB32bit(Optimizer2State): - def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0, amsgrad=False, adam_w_mode=True, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0): - super(LAMB32bit, self).__init__('lamb', params, lr, betas, eps, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0) - - + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + adam_w_mode=True, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=False, + max_unorm=1.0, + ): + super(LAMB32bit, self).__init__( + "lamb", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + max_unorm=1.0, + ) diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py index 912520d..c6cf5c6 100644 --- a/bitsandbytes/optim/lars.py +++ b/bitsandbytes/optim/lars.py @@ -1,43 +1,121 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import torch - from torch.optim import Optimizer + from bitsandbytes.optim.optimizer import Optimizer1State + class LARS(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + max_unorm=0.02, + ): if momentum == 0: - raise NotImplementedError(f'LARS without momentum is not supported!') - super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False) + raise NotImplementedError(f"LARS without momentum is not supported!") + super(LARS, self).__init__( + "lars", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + max_unorm=max_unorm, + block_wise=False, + ) + class LARS8bit(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, args=None, - min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + max_unorm=0.02, + ): if momentum == 0: - raise NotImplementedError(f'LARS without momentum is not supported!') - super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0, - weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False) + raise NotImplementedError(f"LARS without momentum is not supported!") + super(LARS8bit, self).__init__( + "lars", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + max_unorm=max_unorm, + block_wise=False, + ) + class LARS32bit(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, args=None, - min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + max_unorm=0.02, + ): if momentum == 0: - raise NotImplementedError(f'LARS without momentum is not supported!') - super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0, - weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False) + raise NotImplementedError(f"LARS without momentum is not supported!") + super(LARS32bit, self).__init__( + "lars", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + max_unorm=max_unorm, + block_wise=False, + ) class PytorchLARS(Optimizer): - def __init__(self, params, lr=0.01, momentum=0, dampening=0, - weight_decay=0, nesterov=False, max_unorm=0.02): + def __init__( + self, + params, + lr=0.01, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + max_unorm=0.02, + ): if lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) if momentum < 0.0: @@ -45,8 +123,14 @@ class PytorchLARS(Optimizer): if weight_decay < 0.0: raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - defaults = dict(lr=lr, momentum=momentum, dampening=dampening, - weight_decay=weight_decay, nesterov=nesterov, max_unorm=max_unorm) + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + max_unorm=max_unorm, + ) if nesterov and (momentum <= 0 or dampening != 0): raise ValueError("Nesterov momentum requires a momentum and zero dampening") super(PytorchLARS, self).__init__(params, defaults) @@ -54,7 +138,7 @@ class PytorchLARS(Optimizer): def __setstate__(self, state): super(PytorchLARS, self).__setstate__(state) for group in self.param_groups: - group.setdefault('nesterov', False) + group.setdefault("nesterov", False) @torch.no_grad() def step(self, closure=None): @@ -73,15 +157,16 @@ class PytorchLARS(Optimizer): params_with_grad = [] d_p_list = [] momentum_buffer_list = [] - weight_decay = group['weight_decay'] - momentum = group['momentum'] - dampening = group['dampening'] - nesterov = group['nesterov'] - max_unorm = group['max_unorm'] - lr = group['lr'] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + dampening = group["dampening"] + nesterov = group["nesterov"] + max_unorm = group["max_unorm"] + lr = group["lr"] - for p in group['params']: - if p.grad is None: continue + for p in group["params"]: + if p.grad is None: + continue state = self.state[p] d_p = p.grad @@ -89,16 +174,16 @@ class PytorchLARS(Optimizer): d_p = d_p.add(param, alpha=weight_decay) if momentum != 0: - buf = state.get('momentum_buffer', None) + buf = state.get("momentum_buffer", None) if buf is None: buf = torch.clone(d_p).detach() - state['momentum_buffer']= buf + state["momentum_buffer"] = buf else: buf.mul_(momentum).add_(d_p, alpha=1 - dampening) if nesterov: - update = d_p + buf*momentum + update = d_p + buf * momentum else: update = buf @@ -107,9 +192,9 @@ class PytorchLARS(Optimizer): assert p.dtype == torch.float32 pnorm = torch.norm(p.detach()) unorm = torch.norm(update) - if unorm > max_unorm*pnorm: - update_scale = max_unorm*pnorm/unorm + if unorm > max_unorm * pnorm: + update_scale = max_unorm * pnorm / unorm - p.add_(update, alpha=-lr*update_scale) + p.add_(update, alpha=-lr * update_scale) return loss diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 5a5bb1e..b942e34 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -1,13 +1,16 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import torch -import bitsandbytes.functional as F - +from collections import abc as container_abcs +from collections import defaultdict from copy import deepcopy from itertools import chain -from collections import defaultdict, abc as container_abcs + +import torch + +import bitsandbytes.functional as F + class MockArgs(object): def __init__(self, initial_data): @@ -19,7 +22,7 @@ class GlobalOptimManager(object): _instance = None def __init__(self): - raise RuntimeError('Call get_instance() instead') + raise RuntimeError("Call get_instance() instead") def initialize(self): self.pid2config = {} @@ -38,15 +41,15 @@ class GlobalOptimManager(object): def register_parameters(self, params): param_groups = list(params) if not isinstance(param_groups[0], dict): - param_groups = [{'params': param_groups}] + param_groups = [{"params": param_groups}] for group_index, group in enumerate(param_groups): - for p_index, p in enumerate(group['params']): + for p_index, p in enumerate(group["params"]): if id(p) in self.pid2config: self.index2config[(group_index, p_index)] = self.pid2config[id(p)] def override_config(self, parameters, key=None, value=None, key_value_dict=None): - ''' + """ Overrides initial optimizer config for specific parameters. The key-values of the optimizer config for the input parameters are overidden @@ -63,7 +66,7 @@ class GlobalOptimManager(object): The value for the hyperparamters. key_value_dict : dict A dictionary with multiple key-values to override. - ''' + """ self.uses_config_override = True if isinstance(parameters, torch.nn.Parameter): parameters = [parameters] @@ -75,16 +78,16 @@ class GlobalOptimManager(object): if key_value_dict is not None: for p in parameters: - if id(p) in self.pid2config:self.pid2config[id(p)].update(key_value_dict) - else: self.pid2config[id(p)] = key_value_dict + if id(p) in self.pid2config: + self.pid2config[id(p)].update(key_value_dict) + else: + self.pid2config[id(p)] = key_value_dict def register_module_override(self, module, param_name, config): self.module_weight_config_triple.append((module, param_name, config)) - class Optimizer8bit(torch.optim.Optimizer): - def __init__(self, params, defaults, optim_bits=32): super(Optimizer8bit, self).__init__(params, defaults) self.initialized = False @@ -92,23 +95,32 @@ class Optimizer8bit(torch.optim.Optimizer): self.mng = GlobalOptimManager.get_instance() self.non_castable_tensor_keys = set( - ['qmap1', 'qmap2', - 'max1', 'max2', - 'new_max1', 'new_max2', - 'state1', 'state2', - 'gnorm_vec', 'absmax1', 'absmax2', - 'unorm_vec']) + [ + "qmap1", + "qmap2", + "max1", + "max2", + "new_max1", + "new_max2", + "state1", + "state2", + "gnorm_vec", + "absmax1", + "absmax2", + "unorm_vec", + ] + ) - if optim_bits == 8: self.fill_qmap() + if optim_bits == 8: + self.fill_qmap() def fill_qmap(self): - self.name2qmap['dynamic'] = F.create_dynamic_map(signed=True) - self.name2qmap['udynamic'] = F.create_dynamic_map(signed=False) + self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True) + self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False) def __setstate__(self, state): super(Optimizer8bit, self).__setstate__(state) - def load_state_dict(self, state_dict): r"""Loads the optimizer state. @@ -120,21 +132,28 @@ class Optimizer8bit(torch.optim.Optimizer): state_dict = deepcopy(state_dict) # Validate the state_dict groups = self.param_groups - saved_groups = state_dict['param_groups'] + saved_groups = state_dict["param_groups"] if len(groups) != len(saved_groups): - raise ValueError("loaded state dict has a different number of " - "parameter groups") - param_lens = (len(g['params']) for g in groups) - saved_lens = (len(g['params']) for g in saved_groups) + raise ValueError( + "loaded state dict has a different number of " "parameter groups" + ) + param_lens = (len(g["params"]) for g in groups) + saved_lens = (len(g["params"]) for g in saved_groups) if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): - raise ValueError("loaded state dict contains a parameter group " - "that doesn't match the size of optimizer's group") + raise ValueError( + "loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group" + ) # Update the state - id_map = {old_id: p for old_id, p in - zip(chain.from_iterable((g['params'] for g in saved_groups)), - chain.from_iterable((g['params'] for g in groups)))} + id_map = { + old_id: p + for old_id, p in zip( + chain.from_iterable((g["params"] for g in saved_groups)), + chain.from_iterable((g["params"] for g in groups)), + ) + } def cast(param, value): r"""Make a deep copy of value, casting all tensors to device of param.""" @@ -161,7 +180,7 @@ class Optimizer8bit(torch.optim.Optimizer): # State that is not assigned to params is copied as is (needed for # backward compatibility). state = defaultdict(dict) - for k, v in state_dict['state'].items(): + for k, v in state_dict["state"].items(): if k in id_map: param = id_map[k] state[param] = cast(param, v) @@ -170,15 +189,15 @@ class Optimizer8bit(torch.optim.Optimizer): # Update parameter groups, setting their 'params' value def update_group(group, new_group): - new_group['params'] = group['params'] + new_group["params"] = group["params"] return new_group - param_groups = [ - update_group(g, ng) for g, ng in zip(groups, saved_groups)] - self.__setstate__({'state': state, 'param_groups': param_groups}) + + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({"state": state, "param_groups": param_groups}) def to_gpu(self): for gindex, group in enumerate(self.param_groups): - for pindex, p in enumerate(group['params']): + for pindex, p in enumerate(group["params"]): if p in self.state: values = self.state[p] for k, v in values.items(): @@ -189,17 +208,23 @@ class Optimizer8bit(torch.optim.Optimizer): for module, attr, config in self.mng.module_weight_config_triple: pmodule = getattr(module, attr) assert pmodule is not None - assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter) + assert isinstance(pmodule, torch.Tensor) or isinstance( + pmodule, torch.Parameter + ) found = False for gindex, group in enumerate(self.param_groups): - if found: break - for pindex, p in enumerate(group['params']): - if found: break + if found: + break + for pindex, p in enumerate(group["params"]): + if found: + break if id(p) == id(pmodule): # found the matching parameter # init override self.mng.pid2config[id(p)] = config - self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)] + self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[ + id(p) + ] found = True @torch.no_grad() @@ -219,11 +244,11 @@ class Optimizer8bit(torch.optim.Optimizer): if not self.initialized: self.check_overrides() - self.to_gpu() # needed for fairseq pure fp16 training + self.to_gpu() # needed for fairseq pure fp16 training self.initialized = True for gindex, group in enumerate(self.param_groups): - for pindex, p in enumerate(group['params']): + for pindex, p in enumerate(group["params"]): if p.grad is None: continue state = self.state[p] @@ -236,58 +261,70 @@ class Optimizer8bit(torch.optim.Optimizer): def get_config(self, gindex, pindex, group): config = {} - config['betas'] = group['betas'] - config['eps'] = group['eps'] - config['weight_decay'] = group['weight_decay'] - config['lr'] = group['lr'] - config['optim_bits'] = self.args.optim_bits - config['min_8bit_size'] = self.args.min_8bit_size - config['percentile_clipping'] = self.args.percentile_clipping - config['block_wise'] = self.args.block_wise - config['max_unorm'] = self.args.max_unorm - config['skip_zeros'] = self.args.skip_zeros + config["betas"] = group["betas"] + config["eps"] = group["eps"] + config["weight_decay"] = group["weight_decay"] + config["lr"] = group["lr"] + config["optim_bits"] = self.args.optim_bits + config["min_8bit_size"] = self.args.min_8bit_size + config["percentile_clipping"] = self.args.percentile_clipping + config["block_wise"] = self.args.block_wise + config["max_unorm"] = self.args.max_unorm + config["skip_zeros"] = self.args.skip_zeros if (gindex, pindex) in self.mng.index2config: config.update(self.mng.index2config[(gindex, pindex)]) return config def init_state(self, group, p, gindex, pindex): - raise NotImplementedError(f'init_state method needs to be overidden') + raise NotImplementedError(f"init_state method needs to be overidden") def update_step(self, group, p, gindex, pindex): - raise NotImplementedError(f'The update_step method needs to be overidden') + raise NotImplementedError(f"The update_step method needs to be overidden") + class Optimizer2State(Optimizer8bit): - def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, - weight_decay=0.0, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0, - skip_zeros=False): + def __init__( + self, + optimizer_name, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0.0, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + max_unorm=0.0, + skip_zeros=False, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if isinstance(betas, str): # format: '(beta1, beta2)' - betas = betas.replace('(', '').replace(')', '').strip().split(',') + betas = betas.replace("(", "").replace(")", "").strip().split(",") betas = [float(b) for b in betas] for i in range(len(betas)): if not 0.0 <= betas[i] < 1.0: raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") if not 0.0 <= weight_decay: raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(Optimizer2State, self).__init__(params, defaults, optim_bits) if args is None: args = {} - args['optim_bits'] = optim_bits - args['percentile_clipping'] = 100 - args['min_8bit_size'] = min_8bit_size - args['percentile_clipping'] = percentile_clipping - args['block_wise'] = block_wise - args['max_unorm'] = max_unorm - args['skip_zeros'] = skip_zeros + args["optim_bits"] = optim_bits + args["percentile_clipping"] = 100 + args["min_8bit_size"] = min_8bit_size + args["percentile_clipping"] = percentile_clipping + args["block_wise"] = block_wise + args["max_unorm"] = max_unorm + args["skip_zeros"] = skip_zeros self.args = MockArgs(args) else: @@ -299,50 +336,83 @@ class Optimizer2State(Optimizer8bit): def init_state(self, group, p, gindex, pindex): config = self.get_config(gindex, pindex, group) - if config['optim_bits'] == 32: + if config["optim_bits"] == 32: dtype = torch.float32 - elif config['optim_bits'] == 8: + elif config["optim_bits"] == 8: dtype = torch.uint8 - else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') + else: + raise NotImplementedError( + f'Amount of optimizer bits not supported: {config["optim_bits"]}' + ) - if p.numel() < config['min_8bit_size']: dtype = torch.float32 + if p.numel() < config["min_8bit_size"]: + dtype = torch.float32 state = self.state[p] - state['step'] = 0 + state["step"] = 0 if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): - state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device) - state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device) + state["state1"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.float32, + device=p.device, + ) + state["state2"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.float32, + device=p.device, + ) elif dtype == torch.uint8: - if state['step'] == 0: - if 'dynamic' not in self.name2qmap: self.fill_qmap() - self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device) - self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device) + if state["step"] == 0: + if "dynamic" not in self.name2qmap: + self.fill_qmap() + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) + self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device) - state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device) - state['qmap1'] = self.name2qmap['dynamic'] + state["state1"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.uint8, + device=p.device, + ) + state["qmap1"] = self.name2qmap["dynamic"] - state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device) - state['qmap2'] = self.name2qmap['udynamic'] + state["state2"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.uint8, + device=p.device, + ) + state["qmap2"] = self.name2qmap["udynamic"] - if config['block_wise']: + if config["block_wise"]: n = p.numel() - blocks = n//2048 + blocks = n // 2048 blocks += 1 if n % 2048 > 0 else 0 - state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) - state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) + state["absmax1"] = torch.zeros( + (blocks,), dtype=torch.float32, device=p.device + ) + state["absmax2"] = torch.zeros( + (blocks,), dtype=torch.float32, device=p.device + ) else: - state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state['max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state['new_max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["new_max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) + state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["new_max2"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) - if config['percentile_clipping'] < 100: - state['gnorm_vec'] = torch.zeros((100,), device=p.device) + if config["percentile_clipping"] < 100: + state["gnorm_vec"] = torch.zeros((100,), device=p.device) - if config['max_unorm'] > 0.0: - state['unorm_vec'] = torch.zeros((1,), device=p.device) + if config["max_unorm"] > 0.0: + state["unorm_vec"] = torch.zeros((1,), device=p.device) @torch.no_grad() def update_step(self, group, p, gindex, pindex): @@ -351,41 +421,101 @@ class Optimizer2State(Optimizer8bit): config = self.get_config(gindex, pindex, group) - state['step'] += 1 - step = state['step'] + state["step"] += 1 + step = state["step"] - if config['percentile_clipping'] < 100: - current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping']) + if config["percentile_clipping"] < 100: + current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( + grad, state["gnorm_vec"], step, config["percentile_clipping"] + ) else: gnorm_scale = 1.0 - if state['state1'].dtype == torch.float: - F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'], - state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale, - state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], skip_zeros=config['skip_zeros']) + if state["state1"].dtype == torch.float: + F.optimizer_update_32bit( + self.optimizer_name, + grad, + p, + state["state1"], + config["betas"][0], + config["eps"], + step, + config["lr"], + state["state2"], + config["betas"][1], + config["weight_decay"], + gnorm_scale, + state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + max_unorm=config["max_unorm"], + skip_zeros=config["skip_zeros"], + ) - elif state['state1'].dtype == torch.uint8 and not config['block_wise']: - F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1], - config['eps'], step, config['lr'], - state['qmap1'], state['qmap2'], state['max1'], state['max2'], state['new_max1'], state['new_max2'], - config['weight_decay'], gnorm_scale=gnorm_scale, - unorm_vec=state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm']) + elif state["state1"].dtype == torch.uint8 and not config["block_wise"]: + F.optimizer_update_8bit( + self.optimizer_name, + grad, + p, + state["state1"], + state["state2"], + config["betas"][0], + config["betas"][1], + config["eps"], + step, + config["lr"], + state["qmap1"], + state["qmap2"], + state["max1"], + state["max2"], + state["new_max1"], + state["new_max2"], + config["weight_decay"], + gnorm_scale=gnorm_scale, + unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + max_unorm=config["max_unorm"], + ) # swap maxes - state['max1'], state['new_max1'] = state['new_max1'], state['max1'] - state['max2'], state['new_max2'] = state['new_max2'], state['max2'] - elif state['state1'].dtype == torch.uint8 and config['block_wise']: - F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1], - config['eps'], step, config['lr'], - state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'], - config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros']) + state["max1"], state["new_max1"] = state["new_max1"], state["max1"] + state["max2"], state["new_max2"] = state["new_max2"], state["max2"] + elif state["state1"].dtype == torch.uint8 and config["block_wise"]: + F.optimizer_update_8bit_blockwise( + self.optimizer_name, + grad, + p, + state["state1"], + state["state2"], + config["betas"][0], + config["betas"][1], + config["eps"], + step, + config["lr"], + state["qmap1"], + state["qmap2"], + state["absmax1"], + state["absmax2"], + config["weight_decay"], + gnorm_scale=gnorm_scale, + skip_zeros=config["skip_zeros"], + ) class Optimizer1State(Optimizer8bit): - def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8, - weight_decay=0.0, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0, - skip_zeros=False): + def __init__( + self, + optimizer_name, + params, + lr=1e-3, + betas=(0.9, 0.0), + eps=1e-8, + weight_decay=0.0, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + max_unorm=0.0, + skip_zeros=False, + ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -395,19 +525,18 @@ class Optimizer1State(Optimizer8bit): raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") if not 0.0 <= weight_decay: raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(Optimizer1State, self).__init__(params, defaults, optim_bits) if args is None: args = {} - args['optim_bits'] = optim_bits - args['percentile_clipping'] = 100 - args['min_8bit_size'] = min_8bit_size - args['percentile_clipping'] = percentile_clipping - args['block_wise'] = block_wise - args['max_unorm'] = max_unorm - args['skip_zeros'] = skip_zeros + args["optim_bits"] = optim_bits + args["percentile_clipping"] = 100 + args["min_8bit_size"] = min_8bit_size + args["percentile_clipping"] = percentile_clipping + args["block_wise"] = block_wise + args["max_unorm"] = max_unorm + args["skip_zeros"] = skip_zeros self.args = MockArgs(args) else: @@ -419,43 +548,61 @@ class Optimizer1State(Optimizer8bit): def init_state(self, group, p, gindex, pindex): config = self.get_config(gindex, pindex, group) - if config['optim_bits'] == 32: + if config["optim_bits"] == 32: dtype = torch.float32 - elif config['optim_bits'] == 8: + elif config["optim_bits"] == 8: dtype = torch.uint8 - else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') + else: + raise NotImplementedError( + f'Amount of optimizer bits not supported: {config["optim_bits"]}' + ) - if p.numel() < config['min_8bit_size']: dtype = torch.float32 + if p.numel() < config["min_8bit_size"]: + dtype = torch.float32 state = self.state[p] - state['step'] = 0 + state["step"] = 0 if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): - state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device) + state["state1"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.float32, + device=p.device, + ) elif dtype == torch.uint8: - if state['step'] == 0: - if 'dynamic' not in self.name2qmap: self.fill_qmap() - self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device) + if state["step"] == 0: + if "dynamic" not in self.name2qmap: + self.fill_qmap() + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device) - state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device) - state['qmap1'] = self.name2qmap['dynamic'] + state["state1"] = torch.zeros_like( + p, + memory_format=torch.preserve_format, + dtype=torch.uint8, + device=p.device, + ) + state["qmap1"] = self.name2qmap["dynamic"] - if config['block_wise']: + if config["block_wise"]: n = p.numel() - blocks = n//2048 + blocks = n // 2048 blocks += 1 if n % 2048 > 0 else 0 - state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) + state["absmax1"] = torch.zeros( + (blocks,), dtype=torch.float32, device=p.device + ) else: - state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) - state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device) + state["new_max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) - if config['percentile_clipping'] < 100: - state['gnorm_vec'] = torch.zeros((100,), device=p.device) - - if config['max_unorm'] > 0.0: - state['unorm_vec'] = torch.zeros((1,), device=p.device) + if config["percentile_clipping"] < 100: + state["gnorm_vec"] = torch.zeros((100,), device=p.device) + if config["max_unorm"] > 0.0: + state["unorm_vec"] = torch.zeros((1,), device=p.device) @torch.no_grad() def update_step(self, group, p, gindex, pindex): @@ -464,29 +611,77 @@ class Optimizer1State(Optimizer8bit): config = self.get_config(gindex, pindex, group) - state['step'] += 1 - step = state['step'] + state["step"] += 1 + step = state["step"] - if config['percentile_clipping'] < 100: - current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping']) + if config["percentile_clipping"] < 100: + current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( + grad, state["gnorm_vec"], step, config["percentile_clipping"] + ) else: gnorm_scale = 1.0 - if state['state1'].dtype == torch.float: - F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'], - None, 0.0, config['weight_decay'], gnorm_scale, - state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], - skip_zeros=config['skip_zeros']) + if state["state1"].dtype == torch.float: + F.optimizer_update_32bit( + self.optimizer_name, + grad, + p, + state["state1"], + config["betas"][0], + config["eps"], + step, + config["lr"], + None, + 0.0, + config["weight_decay"], + gnorm_scale, + state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + max_unorm=config["max_unorm"], + skip_zeros=config["skip_zeros"], + ) - elif state['state1'].dtype == torch.uint8 and not config['block_wise']: - F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1], - config['eps'], step, config['lr'], state['qmap1'], None, state['max1'], None, state['new_max1'], None, - config['weight_decay'], gnorm_scale, - state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm']) + elif state["state1"].dtype == torch.uint8 and not config["block_wise"]: + F.optimizer_update_8bit( + self.optimizer_name, + grad, + p, + state["state1"], + None, + config["betas"][0], + config["betas"][1], + config["eps"], + step, + config["lr"], + state["qmap1"], + None, + state["max1"], + None, + state["new_max1"], + None, + config["weight_decay"], + gnorm_scale, + state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + max_unorm=config["max_unorm"], + ) - state['max1'], state['new_max1'] = state['new_max1'], state['max1'] - elif state['state1'].dtype == torch.uint8 and config['block_wise']: - F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1], - config['eps'], step, config['lr'], - state['qmap1'], None, state['absmax1'], None, - config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros']) + state["max1"], state["new_max1"] = state["new_max1"], state["max1"] + elif state["state1"].dtype == torch.uint8 and config["block_wise"]: + F.optimizer_update_8bit_blockwise( + self.optimizer_name, + grad, + p, + state["state1"], + None, + config["betas"][0], + config["betas"][1], + config["eps"], + step, + config["lr"], + state["qmap1"], + None, + state["absmax1"], + None, + config["weight_decay"], + gnorm_scale=gnorm_scale, + skip_zeros=config["skip_zeros"], + ) diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py index 0f1ffaa..679f783 100644 --- a/bitsandbytes/optim/rmsprop.py +++ b/bitsandbytes/optim/rmsprop.py @@ -1,36 +1,109 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer1State + class RMSprop(Optimizer1State): - def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if alpha == 0: - raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!') + raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!") if centered: - raise NotImplementedError(f'Centered RMSprop is not supported!') - super(RMSprop, self).__init__('rmsprop', params, lr, (alpha, momentum), eps, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"Centered RMSprop is not supported!") + super(RMSprop, self).__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class RMSprop8bit(Optimizer1State): - def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if alpha == 0: - raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!') + raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!") if centered: - raise NotImplementedError(f'Centered RMSprop is not supported!') - super(RMSprop8bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"Centered RMSprop is not supported!") + super(RMSprop8bit, self).__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class RMSprop32bit(Optimizer1State): - def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if alpha == 0: - raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!') + raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!") if centered: - raise NotImplementedError(f'Centered RMSprop is not supported!') - super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"Centered RMSprop is not supported!") + super(RMSprop32bit, self).__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) diff --git a/bitsandbytes/optim/sgd.py b/bitsandbytes/optim/sgd.py index 0529879..f7b8934 100644 --- a/bitsandbytes/optim/sgd.py +++ b/bitsandbytes/optim/sgd.py @@ -1,32 +1,99 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer1State + class SGD(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, optim_bits=32, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if momentum == 0: - raise NotImplementedError(f'SGD without momentum is not supported!') - super(SGD, self).__init__('momentum', params, lr, (momentum, dampening), 0.0, - weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"SGD without momentum is not supported!") + super(SGD, self).__init__( + "momentum", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class SGD8bit(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if momentum == 0: - raise NotImplementedError(f'SGD without momentum is not supported!') - super(SGD8bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0, - weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"SGD without momentum is not supported!") + super(SGD8bit, self).__init__( + "momentum", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + class SGD32bit(Optimizer1State): - def __init__(self, params, lr, momentum=0, dampening=0, - weight_decay=0, nesterov=False, args=None, - min_8bit_size=4096, percentile_clipping=100, block_wise=True): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): if momentum == 0: - raise NotImplementedError(f'SGD without momentum is not supported!') - super(SGD32bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0, - weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) + raise NotImplementedError(f"SGD without momentum is not supported!") + super(SGD32bit, self).__init__( + "momentum", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index a9eddf9..29b9c90 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -1,7 +1,9 @@ import sys + def print_err(s: str) -> None: print(s, file=sys.stderr) + def warn_of_missing_prerequisite(s: str) -> None: - print_err('WARNING, missing pre-requisite: ' + s) + print_err("WARNING, missing pre-requisite: " + s) diff --git a/quicktest.py b/quicktest.py index 2db6afa..29d045d 100644 --- a/quicktest.py +++ b/quicktest.py @@ -1,31 +1,45 @@ +from itertools import product + import torch + import bitsandbytes as bnb import bitsandbytes.functional as F -from itertools import product def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb): k = 25 for i in range(k): if dims == 2: - A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8) + A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to( + torch.int8 + ) elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8) - B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8) + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to( + torch.int8 + ) + B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) C1 = torch.matmul(A.float(), B.t().float()) - A2, SA = F.transform(A, 'col32') - B2, SB = F.transform(B, 'colx') + A2, SA = F.transform(A, "col32") + B2, SB = F.transform(B, "colx") if dims == 2: - C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32') + C2, SC = F.transform( + torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"), + "col32", + ) else: - C2, SC = F.transform(torch.zeros(A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device='cuda'), 'col32') + C2, SC = F.transform( + torch.zeros( + A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device="cuda" + ), + "col32", + ) F.igemmlt(A2, B2, C2, SA, SB, SC) - C3, S = F.transform(C2, 'row', state=SC) - #torch.testing.assert_allclose(C1, C3.float()) - #print(C1) - #print(C2) - #print(C3) + C3, S = F.transform(C2, "row", state=SC) + # torch.testing.assert_allclose(C1, C3.float()) + # print(C1) + # print(C2) + # print(C3) allclose = torch.allclose(C1, C3.float()) if allclose: print(C1) @@ -33,29 +47,29 @@ def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb): print(C3) ## transposed - #A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8) - #if dims == 2: + # A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8) + # if dims == 2: # B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8) # C1 = torch.matmul(A.float(), B.float().t()) - #elif dims == 3: + # elif dims == 3: # B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8) # C1 = torch.matmul(B.float(), A.t().float()) # C1 = C1.permute([2, 0, 1]) - #A2, SA = F.transform(A, 'col32') - #B2, SB = F.transform(B, 'colx') - #if dims == 2: + # A2, SA = F.transform(A, 'col32') + # B2, SB = F.transform(B, 'colx') + # if dims == 2: # C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32') - #else: + # else: # C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda') # state = (C2.shape, 'row', A.shape[0]) # C2, SC = F.transform(C2, 'col32', state=state) - #F.igemmlt(A2, B2, C2, SA, SB, SC) - #C3, S = F.transform(C2, 'row', state=SC, ld=[0]) - #torch.testing.assert_allclose(C1, C3.float()) + # F.igemmlt(A2, B2, C2, SA, SB, SC) + # C3, S = F.transform(C2, 'row', state=SC, ld=[0]) + # torch.testing.assert_allclose(C1, C3.float()) ## weight update - #if dims == 3: + # if dims == 3: # A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8) # B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8) # C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float()) @@ -73,18 +87,18 @@ dims = (2, 3) ldb = [0] n = 2 -dim1 = torch.randint(1,256, size=(n,)).tolist() -dim2 = torch.randint(32,512, size=(n,)).tolist() -dim3 = torch.randint(32,1024, size=(n,)).tolist() -dim4 = torch.randint(32,1024, size=(n,)).tolist() -values = list(product(dim1,dim2,dim3,dim4,dims, ldb)) +dim1 = torch.randint(1, 256, size=(n,)).tolist() +dim2 = torch.randint(32, 512, size=(n,)).tolist() +dim3 = torch.randint(32, 1024, size=(n,)).tolist() +dim4 = torch.randint(32, 1024, size=(n,)).tolist() +values = list(product(dim1, dim2, dim3, dim4, dims, ldb)) for ldb in range(32, 4096, 32): -#for ldb in [None]: + # for ldb in [None]: val = test_igemmlt(2, 2, 2, 2, 2, ldb) if val: print(val, ldb) else: - print('nope', ldb) -#for val in values: - #test_igemmlt(*val) + print("nope", ldb) +# for val in values: +# test_igemmlt(*val) diff --git a/setup.py b/setup.py index 965817d..3285ba1 100644 --- a/setup.py +++ b/setup.py @@ -1,19 +1,21 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import os import glob -from setuptools import setup, find_packages +import os +from setuptools import find_packages, setup -libs = list(glob.glob('./bitsandbytes/libbitsandbytes*.so')) +libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.so")) libs = [os.path.basename(p) for p in libs] -print('libs:', libs) +print("libs:", libs) + def read(fname): return open(os.path.join(os.path.dirname(__file__), fname)).read() + setup( name=f"bitsandbytes", version=f"0.31.0", @@ -27,11 +29,11 @@ setup( entry_points={ "console_scripts": ["debug_cuda = bitsandbytes.debug_cli:cli"], }, - package_data={'': libs}, - long_description=read('README.md'), - long_description_content_type='text/markdown', + package_data={"": libs}, + long_description=read("README.md"), + long_description_content_type="text/markdown", classifiers=[ "Development Status :: 4 - Beta", - 'Topic :: Scientific/Engineering :: Artificial Intelligence' + "Topic :: Scientific/Engineering :: Artificial Intelligence", ], ) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index d2b5d59..9cd01a9 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -1,27 +1,38 @@ -import pytest - -import torch -import bitsandbytes as bnb - from itertools import product +import pytest +import torch + +import bitsandbytes as bnb + n = 1 k = 25 -dim1 = torch.randint(16,64, size=(n,)).tolist() -dim2 = torch.randint(32,96, size=(n,)).tolist() -dim3 = torch.randint(32,96, size=(n,)).tolist() -dim4 = torch.randint(32,96, size=(n,)).tolist() +dim1 = torch.randint(16, 64, size=(n,)).tolist() +dim2 = torch.randint(32, 96, size=(n,)).tolist() +dim3 = torch.randint(32, 96, size=(n,)).tolist() +dim4 = torch.randint(32, 96, size=(n,)).tolist() funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)] -str_funcs = ['bmm', 'matmul'] +str_funcs = ["bmm", "matmul"] req_grad = [(False, False), (True, False), (True, True), (False, True)] -req_grad_str = ['FF', 'TF', 'TT', 'FT'] +req_grad_str = ["FF", "TF", "TT", "FT"] transpose = [(False, False), (False, True), (True, True), (True, False)] -str_transpose = ['FF', 'FT', 'TT', 'TF'] +str_transpose = ["FF", "FT", "TT", "TF"] dtype = [torch.float32, torch.float16] -values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose)) -str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose)) -names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values] -@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names) +values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)) +str_values = list( + product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose) +) +names = [ + "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format( + *vals + ) + for vals in str_values +] + + +@pytest.mark.parametrize( + "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names +) def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): dim2 = dim2 - (dim2 % 16) dim3 = dim3 - (dim3 % 16) @@ -32,9 +43,11 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): if funcs[0] in [torch.mm, torch.matmul]: dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) - A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0]) - B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1]) - target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1]) + A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0]) + B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1]) + target = torch.randn( + size=(dim2, dim4), device="cuda", requires_grad=req_grad[1] + ) torch.nn.init.xavier_uniform_(B) if not transpose[0] and not transpose[1]: @@ -52,9 +65,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): n = out_bnb.numel() idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) - assert (idx==0).sum().item() < n*0.0175 + assert (idx == 0).sum().item() < n * 0.0175 idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) - assert (idx==0).sum().item() < n*0.001 + assert (idx == 0).sum().item() < n * 0.001 if any(req_grad): out_bnb.data.copy_(out_torch) @@ -78,16 +91,22 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) - assert (idx==0).sum().item() < n*0.1 + assert (idx == 0).sum().item() < n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) - assert (idx==0).sum().item() < n*0.02 + assert (idx == 0).sum().item() < n * 0.02 torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3) # batched matrix multiply if funcs[0] in [torch.bmm, torch.matmul]: - A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0]) - B = torch.randn(size=(dim1, dim3, dim4), device='cuda', requires_grad=req_grad[1]) - target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1]) + A = torch.randn( + size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0] + ) + B = torch.randn( + size=(dim1, dim3, dim4), device="cuda", requires_grad=req_grad[1] + ) + target = torch.randn( + size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1] + ) torch.nn.init.xavier_uniform_(B) out_torch = funcs[0](A, B) @@ -95,7 +114,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): n = out_bnb.numel() idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) - assert (idx==0).sum().item() < n*0.01 + assert (idx == 0).sum().item() < n * 0.01 torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2) if any(req_grad): @@ -120,16 +139,20 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) - assert (idx==0).sum().item() < n*0.1 + assert (idx == 0).sum().item() < n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) - assert (idx==0).sum().item() < n*0.02 + assert (idx == 0).sum().item() < n * 0.02 if funcs[0] in [torch.matmul]: dim1 = dim1 - (dim1 % 16) - A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0]) + A = torch.randn( + size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0] + ) dimB = (dim4, dim3) if transpose[1] else (dim3, dim4) - B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1]) - target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1]) + B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1]) + target = torch.randn( + size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1] + ) torch.nn.init.xavier_uniform_(B) if transpose[1]: @@ -141,9 +164,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): n = out_bnb.numel() idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) - assert (idx==0).sum().item() < n*0.0175 + assert (idx == 0).sum().item() < n * 0.0175 idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) - assert (idx==0).sum().item() < n*0.001 + assert (idx == 0).sum().item() < n * 0.001 if any(req_grad): out_bnb.data.copy_(out_torch) @@ -167,51 +190,96 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): if req_grad[1]: n = gradB1.numel() idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) - assert (idx==0).sum().item() < n*0.1 + assert (idx == 0).sum().item() < n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) - assert (idx==0).sum().item() < n*0.02 + assert (idx == 0).sum().item() < n * 0.02 n = 1 k = 3 -dim1 = torch.randint(16,64, size=(n,)).tolist() -dim2 = torch.randint(32,96, size=(n,)).tolist() -dim3 = torch.randint(32,96, size=(n,)).tolist() -dim4 = torch.randint(32,96, size=(n,)).tolist() +dim1 = torch.randint(16, 64, size=(n,)).tolist() +dim2 = torch.randint(32, 96, size=(n,)).tolist() +dim3 = torch.randint(32, 96, size=(n,)).tolist() +dim4 = torch.randint(32, 96, size=(n,)).tolist() -#dim1 = (17,) -#dim2 = (7,) -#dim3 = (37,) -#dim4 = (23,) +# dim1 = (17,) +# dim2 = (7,) +# dim3 = (37,) +# dim4 = (23,) decomp = [0.0, 6.0] funcs = [(torch.matmul, bnb.matmul)] -str_funcs = ['matmul'] +str_funcs = ["matmul"] req_grad = [(False, False), (True, False), (True, True), (False, True)] -req_grad_str = ['FF', 'TF', 'TT', 'FT'] +req_grad_str = ["FF", "TF", "TT", "FT"] transpose = [(False, True), (False, False)] -str_transpose = ['NT', 'NN'] +str_transpose = ["NT", "NN"] dtype = [torch.float16] has_fp16_weights = [True, False] -values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose, decomp, has_fp16_weights)) -str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose, decomp, has_fp16_weights)) -names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}'.format(*vals) for vals in str_values] -@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights", values, ids=names) -def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights): +values = list( + product( + dim1, + dim2, + dim3, + dim4, + funcs, + dtype, + req_grad, + transpose, + decomp, + has_fp16_weights, + ) +) +str_values = list( + product( + dim1, + dim2, + dim3, + dim4, + str_funcs, + dtype, + req_grad_str, + str_transpose, + decomp, + has_fp16_weights, + ) +) +names = [ + "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}".format( + *vals + ) + for vals in str_values +] + + +@pytest.mark.parametrize( + "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights", + values, + ids=names, +) +def test_matmullt( + dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights +): dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) - outlier_dim = torch.randint(0, dimA[1], size=(dimA[1]//8,), device='cuda') + outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda") for i in range(k): # normal multiply if funcs[0] in [torch.mm, torch.matmul]: - A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0], dtype=dtype) + A = torch.randn( + size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype + ) if decomp == 6.0: with torch.no_grad(): A[:, outlier_dim] = 6.0 - B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1], dtype=dtype) - target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1], dtype=dtype) + B = torch.randn( + size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype + ) + target = torch.randn( + size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype + ) torch.nn.init.xavier_uniform_(B) B2 = B.clone() @@ -219,8 +287,15 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec state.threshold = decomp state.has_fp16_weights = has_fp16_weights if not has_fp16_weights: - if not transpose[0] and not transpose[1]: B2 = B2.t().contiguous() - state.CB, CBt, state.SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B2) + if not transpose[0] and not transpose[1]: + B2 = B2.t().contiguous() + ( + state.CB, + CBt, + state.SCB, + SCBt, + coo_tensorB, + ) = bnb.functional.double_quant(B2) B2 = state.CB if not transpose[0] and transpose[1]: @@ -231,12 +306,12 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec out_bnb = funcs[1](A, B2.t(), state=state) n = out_bnb.numel() - err = torch.abs(out_bnb-out_torch).mean().item() - #print(f'abs error {err:.4f}') + err = torch.abs(out_bnb - out_torch).mean().item() + # print(f'abs error {err:.4f}') idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) - assert (idx==0).sum().item() < n*0.0175 + assert (idx == 0).sum().item() < n * 0.0175 idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) - assert (idx==0).sum().item() < n*0.001 + assert (idx == 0).sum().item() < n * 0.001 if has_fp16_weights: if any(req_grad): @@ -263,8 +338,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec assert torch.abs(gradB1).sum() > 0.0 assert torch.abs(gradB2).sum() > 0.0 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) - assert (idx==0).sum().item() < n*0.1 + assert (idx == 0).sum().item() < n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) - assert (idx==0).sum().item() < n*0.02 + assert (idx == 0).sum().item() < n * 0.02 torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3) - diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 72aa3c7..d45354f 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,37 +1,45 @@ -import pytest import os +from typing import List, NamedTuple -from typing import List +import pytest -from bitsandbytes.cuda_setup import ( - CUDA_RUNTIME_LIB, - get_cuda_runtime_lib_path, - evaluate_cuda_setup, - tokenize_paths, -) +from bitsandbytes.cuda_setup import (CUDA_RUNTIME_LIB, evaluate_cuda_setup, + get_cuda_runtime_lib_path, tokenize_paths) -HAPPY_PATH__LD_LIB_TEST_PATHS: List[tuple[str,str]] = [ +class InputAndExpectedOutput(NamedTuple): + input: str + output: str + + +HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [ (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"), (f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"), (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:", f"dir/with/{CUDA_RUNTIME_LIB}"), (f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"), (f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir", f"dir/with/{CUDA_RUNTIME_LIB}"), - (f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so", f"dir/with/{CUDA_RUNTIME_LIB}"), + ( + f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so", + f"dir/with/{CUDA_RUNTIME_LIB}", + ), ] -@pytest.mark.parametrize( - "test_input, expected", - HAPPY_PATH__LD_LIB_TEST_PATHS -) +@pytest.fixture(params=HAPPY_PATH__LD_LIB_TEST_PATHS) +def happy_path_path_string(tmpdir, request): + for path in tokenize_paths(request.param): + test_dir.mkdir() + if CUDA_RUNTIME_LIB in path: + (test_input / CUDA_RUNTIME_LIB).touch() + + +@pytest.mark.parametrize("test_input, expected", HAPPY_PATH__LD_LIB_TEST_PATHS) def test_get_cuda_runtime_lib_path__happy_path( - tmp_path, test_input: str, expected: str + tmp_path, test_input: str, expected: str ): for path in tokenize_paths(test_input): - assert False == tmp_path / test_input - test_dir.mkdir() - (test_input / CUDA_RUNTIME_LIB).touch() + path.mkdir() + (path / CUDA_RUNTIME_LIB).touch() assert get_cuda_runtime_lib_path(test_input) == expected @@ -47,40 +55,33 @@ def test_get_cuda_runtime_lib_path__unhappy_path(tmp_path, test_input: str): (test_input / CUDA_RUNTIME_LIB).touch() with pytest.raises(FileNotFoundError) as err_info: get_cuda_runtime_lib_path(test_input) - assert all( - match in err_info - for match in {"duplicate", CUDA_RUNTIME_LIB} - ) + assert all(match in err_info for match in {"duplicate", CUDA_RUNTIME_LIB}) def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path): - existent_dir = tmp_path / 'a/b' + existent_dir = tmp_path / "a/b" existent_dir.mkdir() - non_existent_dir = tmp_path / 'c/d' # non-existent dir + non_existent_dir = tmp_path / "c/d" # non-existent dir test_input = ":".join([str(existent_dir), str(non_existent_dir)]) get_cuda_runtime_lib_path(test_input) std_err = capsys.readouterr().err - assert all( - match in std_err - for match in {"WARNING", "non-existent"} - ) + assert all(match in std_err for match in {"WARNING", "non-existent"}) + def test_full_system(): ## this only tests the cuda version and not compute capability - ld_path = os.environ['LD_LIBRARY_PATH'] - paths = ld_path.split(':') - version = '' + ld_path = os.environ["LD_LIBRARY_PATH"] + paths = ld_path.split(":") + version = "" for p in paths: - if 'cuda' in p: - idx = p.rfind('cuda-') - version = p[idx+5:idx+5+4].replace('/', '') + if "cuda" in p: + idx = p.rfind("cuda-") + version = p[idx + 5 : idx + 5 + 4].replace("/", "") version = float(version) break binary_name = evaluate_cuda_setup() - binary_name = binary_name.replace('libbitsandbytes_cuda', '') - assert binary_name.startswith(str(version).replace('.', '')) - - + binary_name = binary_name.replace("libbitsandbytes_cuda", "") + assert binary_name.startswith(str(version).replace(".", "")) diff --git a/tests/test_functional.py b/tests/test_functional.py index bfc3e28..11cd198 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,25 +1,29 @@ -import pytest import math import random import time -import torch -import bitsandbytes as bnb -import einops - from itertools import product +import einops +import pytest +import torch + +import bitsandbytes as bnb from bitsandbytes import functional as F -torch.set_printoptions(precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000) +torch.set_printoptions( + precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 +) k = 20 + def assert_all_approx_close(a, b, rtol, atol, count): idx = torch.isclose(a, b, rtol, atol) - sumval = (idx==0).sum().item() + sumval = (idx == 0).sum().item() if sumval > count: - print(f'Too many values not close: assert {sumval} < {count}') + print(f"Too many values not close: assert {sumval} < {count}") torch.testing.assert_allclose(a, b, rtol, atol) + class FFN(torch.nn.Module): def __init__(self, input_features, hidden_size, bias=True): super(FFN, self).__init__() @@ -35,13 +39,14 @@ class FFN(torch.nn.Module): x = self.fc2(x) return x + class Timer(object): def __init__(self): self.starts = {} self.ends = {} self.agg = {} - def tick(self, name='default'): + def tick(self, name="default"): if name not in self.starts: self.starts[name] = torch.cuda.Event(enable_timing=True) self.ends[name] = torch.cuda.Event(enable_timing=True) @@ -49,66 +54,70 @@ class Timer(object): else: ms = self.tock(name, evict=True, print_ms=False) - def tock(self, name='default', evict=True, print_ms=True): + def tock(self, name="default", evict=True, print_ms=True): if name in self.ends: self.ends[name].record() torch.cuda.synchronize() ms = self.starts[name].elapsed_time(self.ends[name]) - if name not in self.agg: self.agg[name] = 0.0 + if name not in self.agg: + self.agg[name] = 0.0 self.agg[name] += ms if evict: self.starts.pop(name) self.ends.pop(name) if print_ms and name in self.agg: - print('{0} took: {1:.5f}s'.format(name, self.agg[name]/1000.0)) + print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0)) return self.agg[name] def reset(self): - self.starts = {} + self.starts = {} self.ends = {} self.agg = {} - print('Resetting benchmark data') + print("Resetting benchmark data") + def setup(): pass + def teardown(): pass -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['float', 'half']) + +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"]) def test_estimate_quantiles(dtype): - A = torch.rand(1024, 1024, device='cuda') + A = torch.rand(1024, 1024, device="cuda") A = A.to(dtype) code = F.estimate_quantiles(A) - percs = torch.linspace(1/512, 511/512, 256, device=A.device) + percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device) torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2) - A = torch.randn(1024, 1024, device='cuda') + A = torch.randn(1024, 1024, device="cuda") A = A.to(dtype) code = F.estimate_quantiles(A) quantiles = torch.quantile(A.float(), percs) - diff = torch.abs(code-quantiles) + diff = torch.abs(code - quantiles) assert (diff > 5e-02).sum().item() == 0 def test_quantile_quantization(): for i in range(100): - A1 = torch.randn(1024, 1024, device='cuda') + A1 = torch.randn(1024, 1024, device="cuda") code = F.estimate_quantiles(A1) C = F.quantize_no_absmax(A1, code) A2 = F.dequantize_no_absmax(C, code) - diff = torch.abs(A1-A2).mean().item() + diff = torch.abs(A1 - A2).mean().item() assert diff < 0.0075 - A1 = torch.rand(1024, 1024, device='cuda') + A1 = torch.rand(1024, 1024, device="cuda") code = F.estimate_quantiles(A1) C = F.quantize_no_absmax(A1, code) A2 = F.dequantize_no_absmax(C, code) - diff = torch.abs(A1-A2).mean().item() + diff = torch.abs(A1 - A2).mean().item() torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0) assert diff < 0.001 @@ -117,22 +126,22 @@ def test_dynamic_quantization(): diffs = [] reldiffs = [] for i in range(100): - A1 = torch.randn(1024, 1024, device='cuda') + A1 = torch.randn(1024, 1024, device="cuda") C, S = F.quantize(A1) A2 = F.dequantize(C, S) - diff = torch.abs(A1-A2) - reldiff = diff/torch.abs(A1+1e-8) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) assert diff.mean().item() < 0.0135 - #print(sum(diffs)/len(diffs)) - #print(sum(reldiffs)/len(reldiffs)) + # print(sum(diffs)/len(diffs)) + # print(sum(reldiffs)/len(reldiffs)) for i in range(100): - A1 = torch.rand(1024, 1024, device='cuda') + A1 = torch.rand(1024, 1024, device="cuda") C, S = F.quantize(A1) A2 = F.dequantize(C, S) - diff = torch.abs(A1-A2).mean().item() + diff = torch.abs(A1 - A2).mean().item() torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) assert diff < 0.004 @@ -141,56 +150,60 @@ def test_dynamic_blockwise_quantization(): diffs = [] reldiffs = [] for i in range(100): - A1 = torch.randn(1024, 1024, device='cuda') + A1 = torch.randn(1024, 1024, device="cuda") C, S = F.quantize_blockwise(A1) A2 = F.dequantize_blockwise(C, S) - diff = torch.abs(A1-A2) - reldiff = diff/torch.abs(A1+1e-8) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) diffs.append(diff.mean().item()) reldiffs.append(reldiff.mean().item()) assert diffs[-1] < 0.011 - #print(sum(diffs)/len(diffs)) - #print(sum(reldiffs)/len(reldiffs)) + # print(sum(diffs)/len(diffs)) + # print(sum(reldiffs)/len(reldiffs)) diffs = [] for i in range(100): - A1 = torch.rand(1024, 1024, device='cuda') + A1 = torch.rand(1024, 1024, device="cuda") C, S = F.quantize_blockwise(A1) A2 = F.dequantize_blockwise(C, S) - diff = torch.abs(A1-A2).mean().item() + diff = torch.abs(A1 - A2).mean().item() assert diff < 0.0033 diffs.append(diff) torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) - #print(sum(diffs)/len(diffs)) + # print(sum(diffs)/len(diffs)) + def test_dynamic_blockwise_stochastic_quantization(): diffs = [] reldiffs = [] rand = torch.rand(1024).cuda() for i in range(100): - A1 = torch.randn(1024, 1024, device='cuda') + A1 = torch.randn(1024, 1024, device="cuda") C1, S1 = F.quantize_blockwise(A1, rand=rand) C2, S2 = F.quantize_blockwise(A1) # a maximunm distance of quantized values of 1 torch.testing.assert_allclose(C1, C2, atol=1, rtol=0) - fraction_smaller = (C1C2).float().sum()/C1.numel() - torch.testing.assert_allclose(fraction_larger, fraction_smaller, atol=0.01, rtol=0) + fraction_smaller = (C1 < C2).float().sum() / C1.numel() + fraction_larger = (C1 > C2).float().sum() / C1.numel() + torch.testing.assert_allclose( + fraction_larger, fraction_smaller, atol=0.01, rtol=0 + ) - -@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=['float', 'half']) +@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"]) def test_percentile_clipping(gtype): - gnorm_vec1 = torch.zeros(100, device='cuda') - gnorm_vec2 = torch.zeros(100, device='cuda') + gnorm_vec1 = torch.zeros(100, device="cuda") + gnorm_vec2 = torch.zeros(100, device="cuda") n = 4 step = 0 - percentile=5 + percentile = 5 for i in range(k): step += 1 - g = torch.randn(n, n, dtype=gtype, device='cuda') - gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile) - assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2/gnorm1 + g = torch.randn(n, n, dtype=gtype, device="cuda") + gnorm1, clip2, gnorm_scale = F.percentile_clipping( + g, gnorm_vec2, step, percentile=percentile + ) + assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1 gnorm2 = torch.norm(g.float()) if step == 1: @@ -208,74 +221,89 @@ def test_percentile_clipping(gtype): def quant(x): max1 = torch.abs(x).max() - x = torch.round(x/max1*127) + x = torch.round(x / max1 * 127) return max1, x.to(torch.int8) + def dequant(c, maxC): - return c.float()*(maxC/127) + return c.float() * (maxC / 127) + def mm_dequant(maxA, maxB, C): - return C.float()*(maxA/127)*(maxB/127) + return C.float() * (maxA / 127) * (maxB / 127) + def quant_multi(x, dim): max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) - max1[max1==0] = 1.0 - x = torch.round(x/max1*127) + max1[max1 == 0] = 1.0 + x = torch.round(x / max1 * 127) return max1, x.to(torch.int8) + def quant_multi_chunk(x, dim, chunk_size=32): - if dim==1: - x_chunked = einops.rearrange(x, '(c a) b -> c a b', c=chunk_size) - max1 = torch.amax(torch.abs(x_chunked), dim=dim+1, keepdim=True) + if dim == 1: + x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size) + max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True) max1 = torch.tile(max1, (1, 1, x.shape[1])) max1 = max1.view(x.shape) - elif dim==0: - x_chunked = einops.rearrange(x, 'a (b c) -> a b c', c=chunk_size) + elif dim == 0: + x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size) max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True) max1 = torch.tile(max1, (x.shape[0], 1, 1)) max1 = max1.view(x.shape) - max1[max1==0] = 1.0 - x = torch.round(x/max1*127) + max1[max1 == 0] = 1.0 + x = torch.round(x / max1 * 127) return max1, x.to(torch.int8) + def quant_minmax(A): minA = A.min() maxA = A.max() -def mean(xx): - return sum(xx)/float(len(xx)) -#dim1 = torch.randint(1,1024*4, size=(4,)).tolist() -#dim2 = torch.randint(1,1024*4, size=(4,)).tolist() -dim1 = [1024*2] -dim2 = [1024*16] -methods = [(lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant)] +def mean(xx): + return sum(xx) / float(len(xx)) + + +# dim1 = torch.randint(1,1024*4, size=(4,)).tolist() +# dim2 = torch.randint(1,1024*4, size=(4,)).tolist() +dim1 = [1024 * 2] +dim2 = [1024 * 16] +methods = [ + (lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant) +] methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant)) -#methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant)) -method_names = ['linear', 'vectorwise'] +# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant)) +method_names = ["linear", "vectorwise"] batched = [False, True] -values = list(product(dim1,dim2, methods, batched)) -values_names = list(product(dim1,dim2, method_names, batched)) -names = ['dim1_{0}_dim2_{1}_quant_{2}_batched_{3}'.format(*vals) for vals in values_names] +values = list(product(dim1, dim2, methods, batched)) +values_names = list(product(dim1, dim2, method_names, batched)) +names = [ + "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals) for vals in values_names +] + + @pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names) def test_approx_igemm(dim1, dim2, quant_methods, batched): dim1 = dim1 - (dim1 % 32) dim2 = dim2 - (dim2 % 32) errors = [] relerrors = [] - print('') + print("") for i in range(5): if batched: - A = torch.normal(0, 0.5, size=(32, dim1, dim2//32), device='cuda') - B = torch.normal(0, 0.5, size=(32, dim2//32, dim1), device='cuda') + A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda") + B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda") maxA, Ac = quant_methods[0](A, 2) maxB, Bc = quant_methods[1](B, 1) else: - A = torch.normal(0, 0.5, size=(dim1, dim2), device='cuda') - B = torch.normal(0, 0.5, size=(dim2, dim1), device='cuda') + A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda") + B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda") maxA, Ac = quant_methods[0](A, 1) maxB, Bc = quant_methods[1](B, 0) - torch.testing.assert_allclose(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05) + torch.testing.assert_allclose( + quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05 + ) if batched: out2 = torch.bmm(A, B) C = torch.bmm(Ac.float(), Bc.float()) @@ -284,43 +312,49 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): C = F.igemm(Ac, Bc) out = quant_methods[4](maxA, maxB, C) std = out2.std() - out/= std - out2/= std - err = torch.abs(out-out2) - relerr = err/torch.abs(out2) + out /= std + out2 /= std + err = torch.abs(out - out2) + relerr = err / torch.abs(out2) errors.append(err.mean().item()) relerrors.append(relerr.mean().item()) print(mean(errors)) print(mean(relerrors)) - - - - def test_stable_embedding(): layer = bnb.nn.StableEmbedding(1024, 1024) layer.reset_parameters() - n = 2 -hidden_dim = torch.randint(32,256, size=(n,)).tolist() -batch_dim = torch.randint(16,256, size=(n,)).tolist() -seq_dim = torch.randint(16,256, size=(n,)).tolist() +hidden_dim = torch.randint(32, 256, size=(n,)).tolist() +batch_dim = torch.randint(16, 256, size=(n,)).tolist() +seq_dim = torch.randint(16, 256, size=(n,)).tolist() transpose = [(False, False), (False, True), (True, False), (True, True)] -values = list(product(hidden_dim,batch_dim, transpose, seq_dim)) -names = ['hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}'.format(*vals) for vals in values] +values = list(product(hidden_dim, batch_dim, transpose, seq_dim)) +names = [ + "hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals) + for vals in values +] + + @pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names) def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): hidden_dim = hidden_dim - (hidden_dim % 32) batch_dim = batch_dim - (batch_dim % 16) seq_dim = seq_dim - (seq_dim % 16) for i in range(k): - shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim) - shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4))) - A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) - B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8) + shapeA = ( + (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim) + ) + shapeB = ( + (32 * random.randint(1, 4), hidden_dim) + if transpose[1] + else (hidden_dim, 32 * random.randint(1, 4)) + ) + A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) if not transpose[0] and not transpose[1]: out2 = torch.matmul(A.float(), B.float()) out = F.igemm(A, B) @@ -338,9 +372,13 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): for i in range(k): shapeA = (batch_dim, seq_dim, hidden_dim) - shapeB = ((32*random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32*random.randint(1, 4))) - A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) - B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8) + shapeB = ( + (32 * random.randint(1, 4), hidden_dim) + if transpose[1] + else (hidden_dim, 32 * random.randint(1, 4)) + ) + A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) if not transpose[0] and not transpose[1]: out2 = torch.matmul(A.float(), B.float()) out = F.igemm(A, B) @@ -352,40 +390,51 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): n = 3 -seq_dim = torch.randint(32,512, size=(n,)).tolist() -hidden_dim = torch.randint(32,1024*4, size=(n,)).tolist() -batch_dim = torch.randint(2,16, size=(n,)).tolist() -values = list(product(seq_dim,hidden_dim,batch_dim)) -names = ['seq_dim{0}_hidden_dim{1}_batch_dim{2}'.format(*vals) for vals in values] +seq_dim = torch.randint(32, 512, size=(n,)).tolist() +hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist() +batch_dim = torch.randint(2, 16, size=(n,)).tolist() +values = list(product(seq_dim, hidden_dim, batch_dim)) +names = ["seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names) def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): seq_dim = seq_dim - (seq_dim % 32) hidden_dim = hidden_dim - (hidden_dim % 32) batch_dim = batch_dim - (batch_dim % 2) for i in range(25): - A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device='cuda').to(torch.int8) - B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device='cuda').to(torch.int8) - out2 = torch.einsum('bsi, bso->io', A.float(), B.float()) + A = torch.randint( + -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda" + ).to(torch.int8) + B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to( + torch.int8 + ) + out2 = torch.einsum("bsi, bso->io", A.float(), B.float()) iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device) out = F.igemm(A, B, out=iout) torch.testing.assert_allclose(out.float(), out2) + n = 2 -seq_dim = torch.randint(32,512, size=(n,)).tolist() -hidden_dim = torch.randint(32,1024*4, size=(n,)).tolist() -batch_dim = torch.randint(2,16, size=(n,)).tolist() +seq_dim = torch.randint(32, 512, size=(n,)).tolist() +hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist() +batch_dim = torch.randint(2, 16, size=(n,)).tolist() transpose = [False, True] -values = list(product(seq_dim,hidden_dim,batch_dim, transpose)) -names = ['seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}'.format(*vals) for vals in values] +values = list(product(seq_dim, hidden_dim, batch_dim, transpose)) +names = [ + "seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals) + for vals in values +] + + @pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names) def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): - def min_max(x): maxA = torch.amax(x, dim=2, keepdim=True) minA = torch.amin(x, dim=2, keepdim=True) - scale = (maxA-minA)/2.0 - return (127*(x-minA-scale)/scale).to(torch.int8), minA, scale + scale = (maxA - minA) / 2.0 + return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale seq_dim = seq_dim - (seq_dim % 16) hidden_dim = hidden_dim - (hidden_dim % 16) @@ -395,30 +444,30 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): errs2 = [] relerrs2 = [] for i in range(k): - A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device='cuda') + A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda") if transpose: - B = torch.normal(0, 0.5, size=(256, hidden_dim), device='cuda') + B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda") else: - B = torch.normal(0, 0.5, size=(hidden_dim, 256), device='cuda') + B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda") Ac, minA, scale = min_max(A) if transpose: maxB, Bc = quant_multi(B, dim=(1 if transpose else 0)) out = F.igemm(Ac, Bc.t()) - out2 = torch.matmul(A,B.t()) - offset = B.t().sum(0)*(minA+scale) + out2 = torch.matmul(A, B.t()) + offset = B.t().sum(0) * (minA + scale) out = out.float() - out = (out*maxB.t()*scale/(127*127))+offset + out = (out * maxB.t() * scale / (127 * 127)) + offset maxA, Ac = quant_multi(A, dim=2) out3 = F.igemm(Ac, Bc.t()) out3 = mm_dequant(maxA, maxB.t(), out3) else: maxB, Bc = quant_multi(B, dim=0) - offset = B.sum(0)*(minA+scale) + offset = B.sum(0) * (minA + scale) out = F.igemm(Ac, Bc) - out2 = torch.matmul(A,B) + out2 = torch.matmul(A, B) out = out.float() - out = (out*maxB*scale/(127*127))+offset + out = (out * maxB * scale / (127 * 127)) + offset maxA, Ac = quant_multi(A, dim=2) out3 = F.igemm(Ac, Bc) @@ -429,31 +478,36 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): out /= std out3 /= std - err = torch.abs(out-out2) - relerr = err/(torch.abs(out2)+1e-7) + err = torch.abs(out - out2) + relerr = err / (torch.abs(out2) + 1e-7) - err2 = torch.abs(out3-out2) - relerr2 = err2/(torch.abs(out2)+1e-7) + err2 = torch.abs(out3 - out2) + relerr2 = err2 / (torch.abs(out2) + 1e-7) errs.append(err.mean().item()) relerrs.append(relerr.mean().item()) errs2.append(err2.mean().item()) relerrs2.append(relerr2.mean().item()) - #print(mean(errs)) - #print(mean(relerrs)) - #print(mean(errs2)) - #print(mean(relerrs2)) + # print(mean(errs)) + # print(mean(relerrs)) + # print(mean(errs2)) + # print(mean(relerrs2)) assert mean(errs) < 0.015 assert mean(relerrs) < 0.3 + n = 2 -dim1 = torch.randint(1,64, size=(n,)).tolist() -dim2 = torch.randint(32,128, size=(n,)).tolist() -dim3 = torch.randint(32,256, size=(n,)).tolist() -dim4 = torch.randint(32,256, size=(n,)).tolist() +dim1 = torch.randint(1, 64, size=(n,)).tolist() +dim2 = torch.randint(32, 128, size=(n,)).tolist() +dim3 = torch.randint(32, 256, size=(n,)).tolist() +dim4 = torch.randint(32, 256, size=(n,)).tolist() transpose = [(False, False), (True, False), (False, True), (True, True)] -values = list(product(dim1,dim2,dim3,dim4,transpose)) -names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}'.format(*vals) for vals in values] +values = list(product(dim1, dim2, dim3, dim4, transpose)) +names = [ + "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals) for vals in values +] + + @pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names) def test_ibmm(dim1, dim2, dim3, dim4, transpose): dim2 = dim2 - (dim2 % 16) @@ -462,8 +516,8 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): for i in range(k): shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3) shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4) - A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) - B = torch.randint(-128, 127, size=shapeB, device='cuda').to(torch.int8) + A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) if not transpose[0] and not transpose[1]: out2 = torch.bmm(A.float(), B.float()) @@ -479,146 +533,174 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) torch.testing.assert_allclose(out.float(), out2.float()) + n = 1 -dim1 = torch.randint(1,64, size=(n,)).tolist() -dim2 = torch.randint(32,128, size=(n,)).tolist() -dim3 = torch.randint(32,256, size=(n,)).tolist() -values = list(product(dim1,dim2,dim3)) -names = ['dim1_{0}_dim2_{1}_dim3_{2}'.format(*vals) for vals in values] +dim1 = torch.randint(1, 64, size=(n,)).tolist() +dim2 = torch.randint(32, 128, size=(n,)).tolist() +dim3 = torch.randint(32, 256, size=(n,)).tolist() +values = list(product(dim1, dim2, dim3)) +names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names) def test_vector_quant(dim1, dim2, dim3): dim2 = dim2 - (dim2 % 16) dim3 = dim3 - (dim3 % 16) for i in range(k): - A = torch.randn(size=(dim2, dim3), device='cuda') + A = torch.randn(size=(dim2, dim3), device="cuda") qA, SA = F.vectorwise_quant(A, dim=0) A1 = F.vectorwise_dequant(qA, SA) torch.testing.assert_allclose(A1, A, atol=0.01, rtol=0.1) - n = 2 -dim1 = torch.randint(2,256, size=(n,)).tolist() -dim2 = torch.randint(2,256, size=(n,)).tolist() -dim3 = torch.randint(2,256, size=(n,)).tolist() -#dim1, dim2 = (256,), (256,) +dim1 = torch.randint(2, 256, size=(n,)).tolist() +dim2 = torch.randint(2, 256, size=(n,)).tolist() +dim3 = torch.randint(2, 256, size=(n,)).tolist() +# dim1, dim2 = (256,), (256,) dtype = [torch.int8, torch.int32] -a_order = ['row'] -out_order = ['col', 'row', 'col32'] +a_order = ["row"] +out_order = ["col", "row", "col32"] transpose = [False] dims = [2, 3] -values = list(product(dim1,dim2,dim3, dims,dtype, a_order, out_order, transpose)) +values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)) -names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}'.format(*vals) for vals in values] -@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names) +names = [ + "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format( + *vals + ) + for vals in values +] + + +@pytest.mark.parametrize( + "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names +) def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): - if dims == 3 and out_order != 'col32': return - if dtype == torch.int32 and out_order != 'col32': return + if dims == 3 and out_order != "col32": + return + if dtype == torch.int32 and out_order != "col32": + return func = F.get_transform_func(dtype, orderA, orderOut, transpose) if dims == 2: - A = torch.randint(-128, 127, size=(dim1, dim2), device='cuda').to(dtype) + A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype) elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(dtype) + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype) out, S = F.nvidia_transform(A, to_order=orderOut) - if orderOut == 'row': + if orderOut == "row": torch.testing.assert_allclose(A.flatten(), out.flatten()) - elif orderOut == 'col': + elif orderOut == "col": torch.testing.assert_allclose(A.t().flatten(), out.flatten()) - elif orderOut == 'col32': + elif orderOut == "col32": if dims == 2: - n = A.shape[0]*(A.shape[1] + (32 - (A.shape[1]%32))) + n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32))) elif dims == 3: - n = A.shape[0]*A.shape[1]*(A.shape[2] + (32 - (A.shape[2]%32))) + n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32))) assert out.numel() == n - elif orderOut == 'col_turing': + elif orderOut == "col_turing": # 32 col 8 row tiles - n = (A.shape[0]+(8- A.shape[0]%8))*(A.shape[1] + (32 - (A.shape[1]%32))) + n = (A.shape[0] + (8 - A.shape[0] % 8)) * ( + A.shape[1] + (32 - (A.shape[1] % 32)) + ) assert out.numel() == n total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0) for row in range(A.shape[0]): for col in range(A.shape[1]): - i = row*A.shape[1] + i = row * A.shape[1] j = col coltile = (col // 32) + (1 if col % 32 != 0 else 0) - rowtile = ((row // 8) + (1 if row % 8 != 0 else 0))*total_coltile - offset = 32*8*(rowtile+coltile) + rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile + offset = 32 * 8 * (rowtile + coltile) col2 = col % 32 - row2 = (row%8)*32 + row2 = (row % 8) * 32 + assert A.flatten()[i + j] == A[row, col] + # assert A.flatten()[i+j] == out.flatten()[row2+col2] + # torch.testing.assert_allclose(A.flatten()[i+j], A[row, col]) + # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) - assert A.flatten()[i+j] == A[row, col] - #assert A.flatten()[i+j] == out.flatten()[row2+col2] - #torch.testing.assert_allclose(A.flatten()[i+j], A[row, col]) - #torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) - - if orderOut == 'col32': - out2, S = F.nvidia_transform(out, from_order=orderOut, to_order='row', state=S) + if orderOut == "col32": + out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S) torch.testing.assert_allclose(A, out2) n = 1 -dim1 = torch.randint(1,256, size=(n,)).tolist() -dim2 = torch.randint(32,512, size=(n,)).tolist() -dim3 = torch.randint(32,1024, size=(n,)).tolist() -dim4 = torch.randint(32,1024, size=(n,)).tolist() +dim1 = torch.randint(1, 256, size=(n,)).tolist() +dim2 = torch.randint(32, 512, size=(n,)).tolist() +dim3 = torch.randint(32, 1024, size=(n,)).tolist() +dim4 = torch.randint(32, 1024, size=(n,)).tolist() -#dim1 = [2] -#dim2 = [2] -#dim3 = [2] -#dim4 = [2] +# dim1 = [2] +# dim2 = [2] +# dim3 = [2] +# dim4 = [2] -dims = (2,3) +dims = (2, 3) ldb = [0] -#ldb = list(range(256, 1*1024, 256)) -values = list(product(dim1,dim2,dim3,dim4,dims, ldb)) -names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}'.format(*vals) for vals in values] +# ldb = list(range(256, 1*1024, 256)) +values = list(product(dim1, dim2, dim3, dim4, dims, ldb)) +names = [ + "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals) + for vals in values +] + + @pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names) def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): for i in range(k): if dims == 2: - A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8) + A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to( + torch.int8 + ) elif dims == 3: - A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8) - B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8) + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to( + torch.int8 + ) + B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) C1 = torch.matmul(A.float(), B.t().float()) - A2, SA = F.transform(A, 'col32') - B2, SB = F.transform(B, 'col_turing') + A2, SA = F.transform(A, "col32") + B2, SB = F.transform(B, "col_turing") C2, SC = F.igemmlt(A2, B2, SA, SB) - C3, S = F.nvidia_transform(C2, 'row', state=SC) + C3, S = F.nvidia_transform(C2, "row", state=SC) torch.testing.assert_allclose(C1, C3.float()) # transpose - B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8) + B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8) C1 = torch.matmul(A.float(), B.float()) - B2t, SBt = F.transform(B, 'col_turing', transpose=True) + B2t, SBt = F.transform(B, "col_turing", transpose=True) C2, SC = F.igemmlt(A2, B2t, SA, SBt) - C3, S = F.nvidia_transform(C2, 'row', state=SC) + C3, S = F.nvidia_transform(C2, "row", state=SC) torch.testing.assert_allclose(C1, C3.float()) + dim1 = [32] dim2 = [32] dim3 = [32] dim4 = [32] dims = (2,) -#ldb = list(range(256, 1*1024, 256)) -values = list(product(dim1,dim2,dim3,dim4,dims)) -names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}'.format(*vals) for vals in values] +# ldb = list(range(256, 1*1024, 256)) +values = list(product(dim1, dim2, dim3, dim4, dims)) +names = [ + "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals) for vals in values +] + + @pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names) def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): formatB = F.get_special_format_str() for i in range(k): if dims == 2: - A = torch.normal(0, 0.5, size=(dim1, dim3), device='cuda').half() + A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half() elif dims == 3: - A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device='cuda').half() - B = torch.randn((dim4, dim3), device='cuda').half() + A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half() + B = torch.randn((dim4, dim3), device="cuda").half() torch.nn.init.xavier_uniform_(B) C1 = torch.matmul(A, B.t()) C2 = bnb.matmul(A, B.t()) @@ -627,50 +709,56 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B) - C32A, SA = F.transform(CA, 'col32') + C32A, SA = F.transform(CA, "col32") CxB, SB = F.transform(CB, to_order=formatB) out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB) output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt) - #print('') - #print(output.flatten()[:10]) - #print(C1.flatten()[:10]) - #print(C2.flatten()[:10]) + # print('') + # print(output.flatten()[:10]) + # print(C1.flatten()[:10]) + # print(C2.flatten()[:10]) - - #torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) + # torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) # transpose - #B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8) - #C1 = torch.matmul(A.float(), B.float()) + # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8) + # C1 = torch.matmul(A.float(), B.float()) + + # B2t, SBt = F.transform2(B, 'col_turing', transpose=True) + # C2, SC = F.igemmlt(A2, B2t, SA, SBt) + # C3, S = F.transform(C2, 'row', state=SC) + # torch.testing.assert_allclose(C1, C3.float()) - #B2t, SBt = F.transform2(B, 'col_turing', transpose=True) - #C2, SC = F.igemmlt(A2, B2t, SA, SBt) - #C3, S = F.transform(C2, 'row', state=SC) - #torch.testing.assert_allclose(C1, C3.float()) batch_size = 2 seqdim = 512 -#values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)] -values = [(batch_size, seqdim, 4*1024, 3*4*1024),(batch_size, seqdim, 5120, 3*5120),(batch_size, seqdim, 12*1024, 4*12*1024)] +# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)] +values = [ + (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024), + (batch_size, seqdim, 5120, 3 * 5120), + (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024), +] + + +# values = list(product(batch, seq, model, hidden)) +names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values] -#values = list(product(batch, seq, model, hidden)) -names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values] @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) def test_bench_8bit_training(batch, seq, model, hidden): formatB = F.get_special_format_str() - A = torch.randn(batch, seq, model, device='cuda').half() - grad = torch.randn(batch, seq, model, device='cuda').half() - w1 = torch.randint(-128, 127, size=(hidden, model), device='cuda').half() - w2 = torch.randint(-128, 127, size=(model, hidden), device='cuda').half() - print('') + A = torch.randn(batch, seq, model, device="cuda").half() + grad = torch.randn(batch, seq, model, device="cuda").half() + w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half() + w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half() + print("") - #torch.cuda.synchronize() + # torch.cuda.synchronize() ## warmup - #for i in range(100): + # for i in range(100): # torch.matmul(A, w1.t()) - #torch.cuda.synchronize() + # torch.cuda.synchronize() dtype = torch.int8 A = A.view(-1, A.shape[-1]).contiguous() @@ -679,77 +767,77 @@ def test_bench_8bit_training(batch, seq, model, hidden): t0 = time.time() for i in range(k): - out1 = torch.matmul(A, w1.t()) # fc1 - #out2 = torch.matmul(out1, w2.t())# fc2 + out1 = torch.matmul(A, w1.t()) # fc1 + # out2 = torch.matmul(out1, w2.t())# fc2 - #d1 = torch.matmul(grad, w2) # delta1 - #d2 = torch.matmul(d1, w1) # delta2 + # d1 = torch.matmul(grad, w2) # delta1 + # d2 = torch.matmul(d1, w1) # delta2 - #grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2 - #grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1 + # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2 + # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1 torch.cuda.synchronize() t16 = time.time() - t0 print(t16) - #torch.cuda.empty_cache() + # torch.cuda.empty_cache() - #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) - #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) - #CTw1, Sw1 = F.transform2(Cw1, formatB) - #CTw2, Sw2 = F.transform2(Cw2, formatB) - #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) - #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) + # CTw1, Sw1 = F.transform2(Cw1, formatB) + # CTw2, Sw2 = F.transform2(Cw2, formatB) + # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) + # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) - #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) - #C32A, SA = F.transform2(CA, 'col32') + # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + # C32A, SA = F.transform2(CA, 'col32') ## fc1 - #out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype) + # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype) ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t) ## fc2 - #Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1) - #C32out1, Sout1 = F.transform2(Cout1, 'col32') - #out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype) + # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1) + # C32out1, Sout1 = F.transform2(Cout1, 'col32') + # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype) ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t) ## delta1 - #Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad) - #C32grad, Sgrad = F.transform2(Cgrad, 'col32') + # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad) + # C32grad, Sgrad = F.transform2(Cgrad, 'col32') ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype) ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2) ## delta2 - #Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1) - #C32d1, Sd1 = F.transform2(Cd1, 'col32') + # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1) + # C32d1, Sd1 = F.transform2(Cd1, 'col32') ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype) ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1) ## grad1 - #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True) - #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True) + # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True) + # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True) ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype) ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad) ## grad2 - #C32At, SAt = F.transform2(CAt, 'col32', transpose=True) - #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True) + # C32At, SAt = F.transform2(CAt, 'col32', transpose=True) + # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True) ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype) ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1) - #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) - #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) - #Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) - #CTw1, Sw1 = F.transform2(Cw1, formatB) - #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) - #CTw2, Sw2 = F.transform2(Cw2, formatB) - #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(k): + # CTw1, Sw1 = F.transform2(Cw1, formatB) + # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) + # CTw2, Sw2 = F.transform2(Cw2, formatB) + # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(k): # #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) # #CTw1, Sw1 = F.transform2(Cw1, formatB) # #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) @@ -802,74 +890,76 @@ def test_bench_8bit_training(batch, seq, model, hidden): # #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype) # #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t) - #torch.cuda.synchronize() - #t8 = time.time() - t0 - #print(t8) - - - + # torch.cuda.synchronize() + # t8 = time.time() - t0 + # print(t8) n = 2 -dim1 = torch.randint(64,256, size=(n,)).tolist() -dim4 = torch.randint(64,1024, size=(n,)).tolist() +dim1 = torch.randint(64, 256, size=(n,)).tolist() +dim4 = torch.randint(64, 1024, size=(n,)).tolist() -#dim1 = [2*1024] -#dim4 = [2*1024] +# dim1 = [2*1024] +# dim4 = [2*1024] -#dim1 = [4] -#dim4 = [4] +# dim1 = [4] +# dim4 = [4] dims = (2,) -#ldb = list(range(256, 1*1024, 256)) -formatB = ['col_turing', 'col_ampere'] -values = list(product(dim1,dim4,dims, formatB)) -names = ['dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}'.format(*vals) for vals in values] +# ldb = list(range(256, 1*1024, 256)) +formatB = ["col_turing", "col_ampere"] +values = list(product(dim1, dim4, dims, formatB)) +names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names) def test_dequant_mm(dim1, dim4, dims, formatB): inner = torch.randint(1, 128, size=(1,)).item() formatB = F.get_special_format_str() for i in range(k): - A = torch.randn(dim1, inner, device='cuda') - B = torch.randn(dim4, inner, device='cuda') + A = torch.randn(dim1, inner, device="cuda") + B = torch.randn(dim4, inner, device="cuda") C1 = torch.matmul(A.half(), B.t().half()) A1, maxA = F.vectorwise_quant(A, dim=1) B1, maxB = F.vectorwise_quant(B, dim=1) - A2, SA = F.nvidia_transform(A1, 'col32') + A2, SA = F.nvidia_transform(A1, "col32") B2, SB = F.nvidia_transform(B1, formatB) C2, SC = F.igemmlt(A2, B2, SA, SB) - C3, S = F.nvidia_transform(C2, 'row', state=SC) + C3, S = F.nvidia_transform(C2, "row", state=SC) C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) count = (torch.isclose(C1, C4, atol=0.01, rtol=0.1) == 0).sum().item() n = C1.numel() p = 0.06 - assert count/n < p, f'error in more than {p} of elements: {count}/{n}={count/n}' + assert ( + count / n < p + ), f"error in more than {p} of elements: {count}/{n}={count/n}" C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten()) torch.testing.assert_allclose(C5, C4) - #print(C2) - + # print(C2) n = 2 -dim1 = [1*1024] -dim2 = [1*1024] -#dim1 = torch.randint(1,4*1024, size=(n,)).tolist() -#dim2 = torch.randint(1,4*1024, size=(n,)).tolist() +dim1 = [1 * 1024] +dim2 = [1 * 1024] +# dim1 = torch.randint(1,4*1024, size=(n,)).tolist() +# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() dims = (2,) -#ldb = list(range(256, 1*1024, 256)) -values = list(product(dim1,dim2,dims)) -names = ['dim1_{0}_dim2_{1}_dims_{2}'.format(*vals) for vals in values] +# ldb = list(range(256, 1*1024, 256)) +values = list(product(dim1, dim2, dims)) +names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, dims", values, ids=names) def test_colrow_absmax(dim1, dim2, dims): for i in range(k): threshold = 3.0 - A = torch.randn(dim1, dim2, device='cuda').half() + A = torch.randn(dim1, dim2, device="cuda").half() A_truncated = A.clone() A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0 if dims == 2: @@ -880,11 +970,22 @@ def test_colrow_absmax(dim1, dim2, dims): else: assert False - row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=threshold) + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( + A, threshold=threshold + ) - A_blocked = einops.rearrange(torch.abs(A), '(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size', row_tiles=16, block_size=64*4) - nnz_rows1_counts = (torch.abs(A_blocked)>=threshold).sum(3).flatten() - nnz_block_ptr1 = torch.zeros(nnz_rows1_counts.shape[0]+1, dtype=nnz_rows1_counts.dtype, device=nnz_rows1_counts.device) + A_blocked = einops.rearrange( + torch.abs(A), + "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size", + row_tiles=16, + block_size=64 * 4, + ) + nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten() + nnz_block_ptr1 = torch.zeros( + nnz_rows1_counts.shape[0] + 1, + dtype=nnz_rows1_counts.dtype, + device=nnz_rows1_counts.device, + ) nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0) torch.testing.assert_allclose(col_stats1_trunc, col_stats2) @@ -898,19 +999,20 @@ def test_colrow_absmax(dim1, dim2, dims): assert nnz_block_ptr2 is None - n = 2 -#dim1 = [8*1024] -#dim2 = [4*1024] -dim1 = torch.randint(1,4*1024, size=(n,)).tolist() -dim2 = torch.randint(1,4*1024, size=(n,)).tolist() +# dim1 = [8*1024] +# dim2 = [4*1024] +dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist() + +values = list(product(dim1, dim2)) +names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values] + -values = list(product(dim1,dim2)) -names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2", values, ids=names) def test_double_quant(dim1, dim2): for i in range(k): - A = torch.randn(dim1, dim2, device='cuda').half() + A = torch.randn(dim1, dim2, device="cuda").half() out_col1, Scol = F.vectorwise_quant(A, dim=0) out_row1, Srow = F.vectorwise_quant(A, dim=1) @@ -920,18 +1022,21 @@ def test_double_quant(dim1, dim2): torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0) torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0) - n = CAt.numel() - num_not_close_rows = (torch.isclose(CA, out_row1, atol=1)==0).sum().item() - num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1)==0).sum().item() + num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item() + num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item() # allow for 1:500 error due to rounding differences - min_error = 1/500 - if num_not_close_cols > (min_error*n): - print(f'Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}') + min_error = 1 / 500 + if num_not_close_cols > (min_error * n): + print( + f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}" + ) assert False - if num_not_close_rows > (min_error*n): - print(f'Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}') + if num_not_close_rows > (min_error * n): + print( + f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}" + ) assert False torch.testing.assert_allclose(Srow.flatten(), statsA) @@ -939,21 +1044,23 @@ def test_double_quant(dim1, dim2): n = 4 -dim1 = torch.randint(1,4*1024, size=(n,)).tolist() -dim4 = torch.randint(1,4*1024, size=(n,)).tolist() -inner = torch.randint(1,4*1024, size=(n,)).tolist() +dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +inner = torch.randint(1, 4 * 1024, size=(n,)).tolist() dim1 = [6] dim4 = [4] inner = [8] values = list(zip(dim1, dim4, inner)) -names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values] +names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) def test_integrated_igemmlt(dim1, dim4, inner): for i in range(k): - A = torch.randn(dim1, inner, device='cuda').half() - B = torch.randn(dim4, inner, device='cuda').half() + A = torch.randn(dim1, inner, device="cuda").half() + B = torch.randn(dim4, inner, device="cuda").half() out1 = torch.matmul(A.half(), B.t().half()) @@ -967,30 +1074,32 @@ def test_integrated_igemmlt(dim1, dim4, inner): torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1) torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1) - A2, SA = F.nvidia_transform(C1a, 'col32') - B2, SB = F.nvidia_transform(C2a, 'col_turing') + A2, SA = F.nvidia_transform(C1a, "col32") + B2, SB = F.nvidia_transform(C2a, "col_turing") outC32, SC = F.igemmlt(A2, B2, SA, SB) out2 = F.mm_dequant(outC32, SC, stats1a, stats2a) - A2, SA = F.nvidia_transform(A1, 'col32') - B2, SB = F.nvidia_transform(B1, 'col_turing') + A2, SA = F.nvidia_transform(A1, "col32") + B2, SB = F.nvidia_transform(B1, "col_turing") C2, SC = F.igemmlt(A2, B2, SA, SB) - C3, S = F.nvidia_transform(C2, 'row', state=SC) + C3, S = F.nvidia_transform(C2, "row", state=SC) out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) - err1 = torch.abs(out1-out2).mean().item() - err2 = torch.abs(out1-out3).mean().item() - assert err2 <= err1*1.01 + err1 = torch.abs(out1 - out2).mean().item() + err2 = torch.abs(out1 - out3).mean().item() + assert err2 <= err1 * 1.01 n = 6 -dim1 = torch.randint(1,4*1024, size=(n,)).tolist() -dim4 = torch.randint(1,4*1024, size=(n,)).tolist() -inner = torch.randint(1,4*1024, size=(n,)).tolist() +dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +inner = torch.randint(1, 4 * 1024, size=(n,)).tolist() values = list(zip(dim1, dim4, inner)) -names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values] +names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) @pytest.mark.skip("Row scale has some bugs for ampere") def test_igemmlt_row_scale(dim1, dim4, inner): @@ -999,79 +1108,79 @@ def test_igemmlt_row_scale(dim1, dim4, inner): relerr1, relerr2 = [], [] scale = 1 for i in range(k): - A = torch.randn(dim1, inner, device='cuda').half() - B = torch.randn(dim4, inner, device='cuda').half() + A = torch.randn(dim1, inner, device="cuda").half() + B = torch.randn(dim4, inner, device="cuda").half() torch.nn.init.xavier_uniform_(B) C1 = torch.matmul(A, B.t()) out1 = torch.matmul(A.half(), B.t().half()) - C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A) - CB, absmaxB = F.vectorwise_quant(B, quant_type='linear') - A2, SA = F.nvidia_transform(C1a, 'col32') + CB, absmaxB = F.vectorwise_quant(B, quant_type="linear") + A2, SA = F.nvidia_transform(C1a, "col32") B2, SB = F.nvidia_transform(CB, formatB) A1, maxA = F.vectorwise_quant(A, dim=1) - c = 10.0*inner*scale - row_scale = torch.ones_like(maxA)/c + c = 10.0 * inner * scale + row_scale = torch.ones_like(maxA) / c outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) - C3, S = F.nvidia_transform(outC32, 'row', state=SC) + C3, S = F.nvidia_transform(outC32, "row", state=SC) maxval = torch.abs(C3).max() if maxval == 127: scale = 1.5 else: - scale = maxval/120 - out3 = C3*maxA*absmaxB*c/(127*127) + scale = maxval / 120 + out3 = C3 * maxA * absmaxB * c / (127 * 127) C4 = torch.matmul(C1a.float(), CB.float().t()) - C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) B2, SB = F.nvidia_transform(C2a, formatB) outC32, SC = F.igemmlt(A2, B2, SA, SB) out2 = F.mm_dequant(outC32, SC, stats1a, stats2a) - CA, SA = F.vectorwise_quant(A, dim=1, quant_type='vector') - CB, SB = F.vectorwise_quant(B, dim=1, quant_type='linear') + CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector") + CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear") C = torch.matmul(CA.float(), CB.t().float()) - out4 = C*SA*SB/(127*127) - #out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127) + out4 = C * SA * SB / (127 * 127) + # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127) - #print('='*80) - #print(out1) - #print(out2) - #print(out3) + # print('='*80) + # print(out1) + # print(out2) + # print(out3) - #print(out1) - #print(out2) - #print(out3) - err1.append(torch.abs(out1-out2).mean().item()) - err2.append(torch.abs(out1-out3).mean().item()) - err3.append(torch.abs(out1-out4).mean().item()) + # print(out1) + # print(out2) + # print(out3) + err1.append(torch.abs(out1 - out2).mean().item()) + err2.append(torch.abs(out1 - out3).mean().item()) + err3.append(torch.abs(out1 - out4).mean().item()) - #assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10) - print('') - print(sum(err1)/len(err1)) - print(sum(err2)/len(err2)) - print(sum(err3)/len(err3)) + # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10) + print("") + print(sum(err1) / len(err1)) + print(sum(err2) / len(err2)) + print(sum(err3) / len(err3)) dim1 = [1024, 2048] -inner = [12288*4, 4096*4] +inner = [12288 * 4, 4096 * 4] dim4 = [12288, 4096] values = list(zip(dim1, dim4, inner)) -names = ['dim1_{0}_dim4_{1}_inner_{2}'.format(*vals) for vals in values] +names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) @pytest.mark.skip("Row scale has some bugs for ampere") def test_row_scale_bench(dim1, dim4, inner): err1, err2, err3 = [], [], [] relerr1, relerr2 = [], [] scale = 1 - A = torch.randn(dim1, inner, device='cuda').half() - B = torch.randn(dim4, inner, device='cuda').half() + A = torch.randn(dim1, inner, device="cuda").half() + B = torch.randn(dim4, inner, device="cuda").half() torch.nn.init.xavier_uniform_(B) # warmpup for i in range(k): @@ -1082,23 +1191,22 @@ def test_row_scale_bench(dim1, dim4, inner): for i in range(k): C1 = torch.matmul(A, B.t()) torch.cuda.synchronize() - print('16', time.time()-t0) + print("16", time.time() - t0) C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A) - CB, absmaxB = F.vectorwise_quant(B, quant_type='linear') - A2, SA = F.nvidia_transform(C1a, 'col32') + CB, absmaxB = F.vectorwise_quant(B, quant_type="linear") + A2, SA = F.nvidia_transform(C1a, "col32") B2, SB = F.nvidia_transform(CB, formatB) A1, maxA = F.vectorwise_quant(A, dim=1) - c = 10.0*inner*scale - row_scale = maxA/c + c = 10.0 * inner * scale + row_scale = maxA / c torch.cuda.synchronize() t0 = time.time() for i in range(k): outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) torch.cuda.synchronize() - print('row-wise', time.time()-t0) - + print("row-wise", time.time() - t0) C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) B2, SB = F.nvidia_transform(C2a, formatB) @@ -1107,32 +1215,39 @@ def test_row_scale_bench(dim1, dim4, inner): for i in range(k): outC32, SC = F.igemmlt(A2, B2, SA, SB) torch.cuda.synchronize() - print('vector-wise', time.time()-t0) - - + print("vector-wise", time.time() - t0) n = 2 -dim1 = torch.randint(2,1024, size=(n,)).tolist() -dim2 = torch.randint(2,1024, size=(n,)).tolist() -#dim1 = [8*1024] -#dim2 = [4*1024] +dim1 = torch.randint(2, 1024, size=(n,)).tolist() +dim2 = torch.randint(2, 1024, size=(n,)).tolist() +# dim1 = [8*1024] +# dim2 = [4*1024] dim3 = [0] dtype = [torch.int8] -a_order = ['row'] -out_order = ['col32', 'col_turing', 'col_ampere'] +a_order = ["row"] +out_order = ["col32", "col_turing", "col_ampere"] transpose = [False, True] dims = [2] -values = list(product(dim1,dim2,dim3, dims,dtype, a_order, out_order, transpose)) -names = ['dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}'.format(*vals) for vals in values] -@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names) +values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)) +names = [ + "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format( + *vals + ) + for vals in values +] + + +@pytest.mark.parametrize( + "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names +) def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): for i in range(k): if dims == 2: - A = torch.randint(10, 99, size=(dim1, dim2), device='cuda').to(dtype) + A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype) elif dims == 3: - A = torch.randint(10, 99, size=(dim1, dim2, dim3), device='cuda').to(dtype) + A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype) A.view(-1)[-1] = -1 if transpose: @@ -1144,53 +1259,55 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): assert S1[0][0] == S2[0][0] assert S1[0][1] == S2[0][1] - #print(out1) - #print(out2) + # print(out1) + # print(out2) torch.testing.assert_allclose(out1, out2) + n = 2 -#dim1 = torch.randint(2,1024, size=(n,)).tolist() -#dim2 = torch.randint(2,1024, size=(n,)).tolist() +# dim1 = torch.randint(2,1024, size=(n,)).tolist() +# dim2 = torch.randint(2,1024, size=(n,)).tolist() dim1 = [1] dim2 = [33] dtype = [torch.int8] -#a_order = ['col_turing', 'col_ampere'] -a_order = ['col_turing'] -out_order = ['row'] -values = list(product(dim1,dim2,dtype, a_order, out_order)) -names = ['dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}'.format(*vals) for vals in values] +# a_order = ['col_turing', 'col_ampere'] +a_order = ["col_turing"] +out_order = ["row"] +values = list(product(dim1, dim2, dtype, a_order, out_order)) +names = [ + "dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals) + for vals in values +] + + @pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names) def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut): for i in range(1): - A = torch.randint(-127, 127, size=(dim1, dim2), device='cuda').to(dtype) + A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype) out2, S2 = F.transform(A, to_order=orderA) - A2, S3 = F.transform(out2, from_order=orderA, to_order='row', state=S2) + A2, S3 = F.transform(out2, from_order=orderA, to_order="row", state=S2) assert A2.shape[0] == A.shape[0] assert A2.shape[1] == A.shape[1] - - print('') + print("") print(A) print(out2) print(A2) - - #torch.testing.assert_allclose(A, A2) - - + # torch.testing.assert_allclose(A, A2) def test_overflow(): formatB = F.get_special_format_str() print(formatB) for i in range(2): - a = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 ) - b = torch.arange(5, 15).cuda().to(torch.int8).view(-1,1 ) + a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1) + b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1) - Ca, Sa = F.nvidia_transform(a, 'col32') + Ca, Sa = F.nvidia_transform(a, "col32") Cb, Sb = F.nvidia_transform(b, formatB) c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8) @@ -1198,46 +1315,51 @@ def test_overflow(): n = 2 -dim1 = torch.randint(1,4*1024, size=(n,)).tolist() -dim2 = torch.randint(1,4*1024, size=(n,)).tolist() -#dim1 = [4] -#dim2 = [5] +dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +# dim1 = [4] +# dim2 = [5] + +values = list(product(dim1, dim2)) +names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values] + -values = list(product(dim1,dim2)) -names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2", values, ids=names) def test_coo_double_quant(dim1, dim2): threshold = 3.00 for i in range(k): - A = torch.randn(dim1, dim2, device='cuda').half() + A = torch.randn(dim1, dim2, device="cuda").half() - idx = (torch.abs(A) >= threshold) + idx = torch.abs(A) >= threshold CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) if coo_tensor is not None: - A1 = A*idx + A1 = A * idx A2 = torch.zeros_like(A) A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values torch.testing.assert_allclose(A1, A2) - A1 = A*(idx==0) - A2 = (CA.float()*statsA.unsqueeze(1)/127).half() - torch.testing.assert_allclose(A*(idx==0), A2, rtol=0.05, atol=1.5e-2) + A1 = A * (idx == 0) + A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() + torch.testing.assert_allclose(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) + n = 2 -dim1 = torch.randint(1,1*1024, size=(n,)).tolist() -dim2 = torch.randint(1,1*1024, size=(n,)).tolist() -#dim1 = [7] -#dim2 = [11] +dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist() +dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist() +# dim1 = [7] +# dim2 = [11] transposed_B = [False, True] -values = list(product(dim1,dim2, transposed_B)) -names = ['dim1_{0}_dim2_{1}_transposed_B_{2}'.format(*vals) for vals in values] +values = list(product(dim1, dim2, transposed_B)) +names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names) def test_spmm_coo(dim1, dim2, transposed_B): threshold = 1.5 dim3 = torch.randint(32, 128, size=(1,)).item() - #dim3 = 17 + # dim3 = 17 for i in range(k): A = torch.randn(dim1, dim2).cuda().half() if transposed_B: @@ -1249,8 +1371,10 @@ def test_spmm_coo(dim1, dim2, transposed_B): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A*idx + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) + A2 = A * idx if transposed_B: out2 = F.spmm_coo(cooA, B.t()) @@ -1262,18 +1386,17 @@ def test_spmm_coo(dim1, dim2, transposed_B): assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30) - def test_spmm_bench(): batch = 2 - model = 1024*1 - hidden = model*4 + model = 1024 * 1 + hidden = model * 4 seq = 1024 - dim1 = batch*seq + dim1 = batch * seq dim2 = model dim3 = hidden threshold = 4 - A = torch.randn(dim1, dim2, device='cuda').half() - B = torch.randn(dim2, dim3, device='cuda').half() + A = torch.randn(dim1, dim2, device="cuda").half() + B = torch.randn(dim2, dim3, device="cuda").half() for i in range(10): C1 = bnb.matmul(A, B) @@ -1282,14 +1405,16 @@ def test_spmm_bench(): for i in range(k): C1 = bnb.matmul(A, B) torch.cuda.synchronize() - t8 = time.time()-t0 + t8 = time.time() - t0 idx = torch.abs(A) >= threshold nnz = (idx == 1).sum().item() - print(nnz/idx.numel()) + print(nnz / idx.numel()) rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) for i in range(10): out2 = F.spmm_coo(cooA, B) @@ -1299,20 +1424,22 @@ def test_spmm_bench(): for i in range(k): out2 = F.spmm_coo(cooA, B) torch.cuda.synchronize() - tsp = time.time()-t0 + tsp = time.time() - t0 print(tsp, t8) - print(tsp/t8) + print(tsp / t8) n = 2 -dim1 = torch.randint(256,1*1024, size=(n,)).tolist() -dim2 = torch.randint(256,1*1024, size=(n,)).tolist() -values = list(product(dim1,dim2)) -names = ['dim1_{0}_dim2_{1}'.format(*vals) for vals in values] +dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist() +dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist() +values = list(product(dim1, dim2)) +names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2", values, ids=names) def test_integrated_sparse_decomp(dim1, dim2): threshold = 3.0 - formatB = 'col_turing' + formatB = "col_turing" for i in range(k): A = torch.randn(dim1, dim2).cuda().half() w1 = torch.randn(dim1, dim2).cuda().half() @@ -1322,13 +1449,13 @@ def test_integrated_sparse_decomp(dim1, dim2): CTw1, Sw1 = F.transform(Cw1, formatB) CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) - C32A, SA = F.transform(CA, 'col32') + C32A, SA = F.transform(CA, "col32") out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) - C32A, SA = F.transform(CA, 'col32') + C32A, SA = F.transform(CA, "col32") out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) @@ -1338,8 +1465,8 @@ def test_integrated_sparse_decomp(dim1, dim2): out4 = F.spmm_coo(coo_tensor, w1.t()) out5 = out3 + out4 - err1 = torch.abs(out1-out2).mean().item() - err2 = torch.abs(out1-out5).mean().item() + err1 = torch.abs(out1 - out2).mean().item() + err2 = torch.abs(out1 - out5).mean().item() assert err2 < err1 @@ -1350,91 +1477,95 @@ def test_matmuls(): c2 = bnb.matmul(a, b) c3 = bnb.matmul(a, b) - err1 = torch.abs(c1-c2).mean().item() - err2 = torch.abs(c1-c3).mean().item() + err1 = torch.abs(c1 - c2).mean().item() + err2 = torch.abs(c1 - c3).mean().item() assert err1 < 0.2 assert err2 < 0.2 - n = 2 -#dim1 = torch.randint(1,1*1024, size=(n,)).tolist() -#dim2 = torch.randint(1,4*1024, size=(n,)).tolist() -dim1 = [1*2048] +# dim1 = torch.randint(1,1*1024, size=(n,)).tolist() +# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() +dim1 = [1 * 2048] dim2 = [12288] -#dim1 = [32] -#dim2 = [32] -#dtype = [torch.float16, torch.int8] +# dim1 = [32] +# dim2 = [32] +# dtype = [torch.float16, torch.int8] dtype = [torch.float16] -out_function = ['zeros', 'ones'] -values = list(product(dim1,dim2, dtype, out_function)) -names = ['dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}'.format(*vals) for vals in values] +out_function = ["zeros", "ones"] +values = list(product(dim1, dim2, dtype, out_function)) +names = ["dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names) def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): out_func = getattr(torch, out_func) threshold = 3.3 - #threshold = 2.8 - #threshold = 0.0 - A = torch.randn(dim1, dim2, device='cuda').half() + # threshold = 2.8 + # threshold = 0.0 + A = torch.randn(dim1, dim2, device="cuda").half() if dtype == torch.float16: - B = torch.randn(dim2, dim2*4, device='cuda').half() + B = torch.randn(dim2, dim2 * 4, device="cuda").half() torch.nn.init.xavier_uniform_(B) else: - B = torch.randn(dim2, dim2*4, device='cuda').half() + B = torch.randn(dim2, dim2 * 4, device="cuda").half() torch.nn.init.xavier_uniform_(B) - B, SB = F.vectorwise_quant(B, quant_type='linear') - #B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8) + B, SB = F.vectorwise_quant(B, quant_type="linear") + # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8) - print('') + print("") idx = torch.abs(A) >= threshold nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A*idx + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) + A2 = A * idx out1 = torch.matmul(A2.half(), B.half()) out = out_func(out1.shape, dtype=torch.float16, device=out1.device) out1 += out.clone() out2 = F.spmm_coo_very_sparse(cooA, B, out=out) - #print(B) - #print(out1) - #print(out2) - p = 200/(2048*12288*4) + # print(B) + # print(out1) + # print(out2) + p = 200 / (2048 * 12288 * 4) n = out1.numel() - count = math.ceil(p*n) + count = math.ceil(p * n) std = out1.std() out1 /= std out2 /= std assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count) - #assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count) + # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count) idx_col = torch.randint(0, A2.shape[-1], size=(15,)) - #torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001) + # torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001) - #Bt = torch.randn(dim2*4, dim2, device='cuda').half() - #torch.cuda.synchronize() - #t0 = time.time() - #print(A2.shape, B.shape) - #for i in range(100): + # Bt = torch.randn(dim2*4, dim2, device='cuda').half() + # torch.cuda.synchronize() + # t0 = time.time() + # print(A2.shape, B.shape) + # for i in range(100): # #out3 = F.spmm_coo(cooA, Bt.t()) # #out2 = F.spmm_coo(cooA, B) # #out2 = F.spmm_coo_very_sparse(cooA, B) # #out1 = torch.matmul(A, Bt.t()) - #torch.cuda.synchronize() - #print(time.time() - t0) + # torch.cuda.synchronize() + # print(time.time() - t0) + def test_layout(): - a1 = torch.rand(16, 64, device='cuda', dtype=torch.float16) - a1 = torch.arange(16* 64, device='cuda').reshape(16, 64).byte() - a2, s2 = F.transform(a1, 'col_turing') + a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16) + a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte() + a2, s2 = F.transform(a1, "col_turing") print(a2.shape) - print(a1.flatten()[8*64:8*64+32]) + print(a1.flatten()[8 * 64 : 8 * 64 + 32]) for i in range(4): - print(a2.flatten()[i*8*32:i*8*32+32], 0) + print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0) def test_coo2csr(): @@ -1444,14 +1575,16 @@ def test_coo2csr(): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A*idx + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) + A2 = A * idx csrA = F.coo2csr(cooA) counts = csrA.rowptr[1:] - csrA.rowptr[:-1] assert counts.numel() == A.shape[0] - torch.testing.assert_allclose(counts, (A2!=0).sum(1)) - idx = (A2!=0) + torch.testing.assert_allclose(counts, (A2 != 0).sum(1)) + idx = A2 != 0 torch.testing.assert_allclose(A2[idx], csrA.values) @@ -1462,41 +1595,43 @@ def test_coo2csc(): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A*idx + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) + A2 = A * idx cscA = F.coo2csc(cooA) counts = cscA.colptr[1:] - cscA.colptr[:-1] assert counts.numel() == A.shape[1] - torch.testing.assert_allclose(counts, (A2!=0).sum(0)) + torch.testing.assert_allclose(counts, (A2 != 0).sum(0)) # torch uses row-major -> use transpose to transfer to col-major - idx = (A2.t()!=0) + idx = A2.t() != 0 torch.testing.assert_allclose(A2.t()[idx], cscA.values) - n = 2 -#dim1 = torch.randint(1,1*1024, size=(n,)).tolist() -#dim2 = torch.randint(1,4*1024, size=(n,)).tolist() -dim1 = [1*2048] -#dim2 = [12288] +# dim1 = torch.randint(1,1*1024, size=(n,)).tolist() +# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() +dim1 = [1 * 2048] +# dim2 = [12288] dim2 = [2048] -#dim1 = [2] -#dim2 = [2] +# dim1 = [2] +# dim2 = [2] dtype = [torch.int8] -values = list(product(dim1,dim2, dtype)) -names = ['dim1_{0}_dim2_{1}_dtype_{2}'.format(*vals) for vals in values] +values = list(product(dim1, dim2, dtype)) +names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names) def test_spmm_coo_dequant(dim1, dim2, dtype): threshold = 6.0 - #threshold = 2.8 - #threshold = 0.0 - A = torch.randn(dim1, dim2, device='cuda').half() - B = torch.empty(dim2, dim2*4, device='cuda', dtype=torch.float16) + # threshold = 2.8 + # threshold = 0.0 + A = torch.randn(dim1, dim2, device="cuda").half() + B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16) torch.nn.init.xavier_uniform_(B) Bt = B.t().contiguous() - CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B) rowidx = torch.randint(0, A.shape[-1], size=(15,)) @@ -1507,12 +1642,14 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): nnz = (idx == 1).sum().item() rows, cols = torch.where(idx) values = A[idx] - cooA = F.COOSparseTensor(A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values) - A2 = A*idx + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) + A2 = A * idx out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) out1 = torch.matmul(A2, B.half()) out3 = F.spmm_coo_very_sparse(cooA, CBt.half()) - out3 = out3*statsBt.half()/127 + out3 = out3 * statsBt.half() / 127 values, counts = torch.unique(cooA.rowidx, return_counts=True) offset = counts.cumsum(0).int() @@ -1521,56 +1658,54 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001) - p = 200/(2048*12288*4) + p = 200 / (2048 * 12288 * 4) n = out1.numel() - count = math.ceil(p*n) + count = math.ceil(p * n) assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count) - - - #torch.cuda.synchronize() - #t0 = time.time() - #for i in range(100): + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(100): # out2 = F.spmm_coo_very_sparse(cooA, B) - #torch.cuda.synchronize() - #print('fp16', time.time() - t0) + # torch.cuda.synchronize() + # print('fp16', time.time() - t0) torch.cuda.synchronize() t0 = time.time() for i in range(100): - out2 = F.spmm_coo(cooA, B) + out2 = F.spmm_coo(cooA, B) torch.cuda.synchronize() - print('cusparse fp16', time.time() - t0) + print("cusparse fp16", time.time() - t0) torch.cuda.synchronize() t0 = time.time() for i in range(100): - out2 = F.spmm_coo_very_sparse(cooA, CBt) + out2 = F.spmm_coo_very_sparse(cooA, CBt) torch.cuda.synchronize() - print('int8', time.time() - t0) + print("int8", time.time() - t0) torch.cuda.synchronize() t0 = time.time() for i in range(100): - out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) + out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) torch.cuda.synchronize() - print('int8+dequant', time.time() - t0) + print("int8+dequant", time.time() - t0) torch.cuda.synchronize() t0 = time.time() for i in range(100): - out2 = torch.matmul(A, B) + out2 = torch.matmul(A, B) torch.cuda.synchronize() - print('matmul', time.time() - t0) + print("matmul", time.time() - t0) torch.cuda.synchronize() t0 = time.time() for i in range(100): out1 = bnb.matmul(A, Bt) out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) - out = out1+out2 + out = out1 + out2 torch.cuda.synchronize() - print('sparse+ matmul', time.time() - t0) + print("sparse+ matmul", time.time() - t0) torch.cuda.synchronize() t0 = time.time() @@ -1578,33 +1713,36 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): out1 = bnb.matmul(A, Bt) torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1) torch.cuda.synchronize() - print('partial matmul', time.time() - t0) + print("partial matmul", time.time() - t0) torch.cuda.synchronize() t0 = time.time() for i in range(100): out1 = bnb.matmul(A, Bt) torch.cuda.synchronize() - print('partial matmul', time.time() - t0) + print("partial matmul", time.time() - t0) + batch_size = 1 seqdim = 2048 values = [] -values.append((batch_size, seqdim, 768, 4*768)) -#values.append((batch_size, seqdim, 1024, 4*1024)) -#values.append((batch_size, seqdim, 1536, 4*1536)) -#values.append((batch_size, seqdim, 2048, 4*2048)) -#values.append((batch_size, seqdim, 2560, 4*2560)) -#values.append((batch_size, seqdim, 4096, 4*4096)) -#values.append((batch_size, seqdim, 5140, 4*5140)) -#values.append((batch_size, seqdim, 12288, 4*12288)) -names = ['batch_{0}_seq_{1}_model_{2}_hidden_{3}'.format(*vals) for vals in values] +values.append((batch_size, seqdim, 768, 4 * 768)) +# values.append((batch_size, seqdim, 1024, 4*1024)) +# values.append((batch_size, seqdim, 1536, 4*1536)) +# values.append((batch_size, seqdim, 2048, 4*2048)) +# values.append((batch_size, seqdim, 2560, 4*2560)) +# values.append((batch_size, seqdim, 4096, 4*4096)) +# values.append((batch_size, seqdim, 5140, 4*5140)) +# values.append((batch_size, seqdim, 12288, 4*12288)) +names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values] + + @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) def test_bench_matmul(batch, seq, model, hidden): formatB = F.get_special_format_str() - A = torch.randn(batch, seq, model, device='cuda').half() - B = torch.empty(hidden, model, dtype=torch.float16, device='cuda') + A = torch.randn(batch, seq, model, device="cuda").half() + B = torch.empty(hidden, model, dtype=torch.float16, device="cuda") torch.nn.init.xavier_uniform_(B) linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() @@ -1613,31 +1751,37 @@ def test_bench_matmul(batch, seq, model, hidden): outliers = torch.randint(0, model, size=(5,)).cuda() A[:, :, outliers] = 8.0 - linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() + linearMixedBit = ( + bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() + ) linearMixedBit.eval() # warmup for i in range(100): torch.matmul(A, B.t()) torch.cuda.synchronize() - print('') + print("") torch.cuda.synchronize() t0 = time.time() for i in range(100): torch.matmul(A, B.t()) torch.cuda.synchronize() - print(f'pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + print( + f"pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) torch.cuda.synchronize() t0 = time.time() for i in range(100): bnb.matmul(A, B) torch.cuda.synchronize() - print(f'bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + print( + f"bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) - C32A, SA = F.transform(CA, 'col32') + C32A, SA = F.transform(CA, "col32") CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) CxB, SB = F.transform(CB, to_order=formatB) torch.cuda.synchronize() @@ -1645,7 +1789,9 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(100): out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) torch.cuda.synchronize() - print(f'igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + print( + f"igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) BA, statsB = F.vectorwise_quant(B, dim=1) CxB, SB = F.nvidia_transform(CB, to_order=formatB) @@ -1654,26 +1800,30 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(100): A2 = A.view(-1, A.shape[-1]).contiguous() CA, statsA = F.vectorwise_quant(A2, dim=1) - C32A, SA = F.nvidia_transform(CA, 'col32') + C32A, SA = F.nvidia_transform(CA, "col32") out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32) + Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) torch.cuda.synchronize() - print(f'vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + print( + f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) - BA, statsB = F.vectorwise_quant(B, dim=1, quant_type='linear') + BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") CxB, SB = F.nvidia_transform(CB, to_order=formatB) torch.cuda.synchronize() t0 = time.time() for i in range(100): A2 = A.view(-1, A.shape[-1]).contiguous() - CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type='linear') - C32A, SA = F.nvidia_transform(CA, 'col32') + CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear") + C32A, SA = F.nvidia_transform(CA, "col32") out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - Cout, Sout = F.nvidia_transform(out32, 'row', state=Sout32) - out = Cout*statsB*statsA*(1.0/(127*127)) + Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) + out = Cout * statsB * statsA * (1.0 / (127 * 127)) torch.cuda.synchronize() - print(f'linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + print( + f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) linear8bit(A) torch.cuda.synchronize() @@ -1681,8 +1831,9 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(100): linear8bit(A) torch.cuda.synchronize() - print(f'bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') - + print( + f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) linearMixedBit(A) torch.cuda.synchronize() @@ -1690,65 +1841,66 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(100): linearMixedBit(A) torch.cuda.synchronize() - print(f'bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s') + print( + f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" + ) def test_zeropoint(): def min_max(x): maxA = torch.amax(x, dim=1, keepdim=True) minA = torch.amin(x, dim=1, keepdim=True) - midpoint = (maxA-minA)/2.0 - dyna = 252/(maxA-minA) - #dyna *= 0.98 - x = dyna*x - x = x - torch.round((dyna*(minA+midpoint))) + midpoint = (maxA - minA) / 2.0 + dyna = 252 / (maxA - minA) + # dyna *= 0.98 + x = dyna * x + x = x - torch.round((dyna * (minA + midpoint))) return x.to(torch.int8), minA, midpoint, dyna + batch = 2 seq = 2 model = 4 - hidden = 2*model - #batch = 4 - #seq = 2048 - #model = 1024 - #hidden = 8*model - A = torch.randn(batch*seq, model, device='cuda').half()-0.4 - B = torch.nn.Parameter(torch.randn(model, hidden, device='cuda').half()) + hidden = 2 * model + # batch = 4 + # seq = 2048 + # model = 1024 + # hidden = 8*model + A = torch.randn(batch * seq, model, device="cuda").half() - 0.4 + B = torch.nn.Parameter(torch.randn(model, hidden, device="cuda").half()) - #A[0] = 0 - #B[:, 0] = 0 - #A = A*(A>0) - #A[0, 0] = 0 - #A[0, 0] = 6.0 + # A[0] = 0 + # B[:, 0] = 0 + # A = A*(A>0) + # A[0, 0] = 0 + # A[0, 0] = 6.0 Ac, minA, midpoint, dyna = min_max(A) - #print(Ac[0, 0], 'zero') - #print(Ac, Ac.min(), Ac.max()) - Bc, maxB = F.vectorwise_quant(B, quant_type='linear') + # print(Ac[0, 0], 'zero') + # print(Ac, Ac.min(), Ac.max()) + Bc, maxB = F.vectorwise_quant(B, quant_type="linear") out = F.igemm(Ac, Bc) - out2 = torch.matmul(A,B) - offset = B.sum(0)*torch.round(dyna*(minA+midpoint))/dyna + out2 = torch.matmul(A, B) + offset = B.sum(0) * torch.round(dyna * (minA + midpoint)) / dyna out = out.float() - #print(out.shape, maxB.shape, scale.shape, offset.shape) - norm1 = maxB/127 - C4 = (out/dyna)*norm1+offset - + # print(out.shape, maxB.shape, scale.shape, offset.shape) + norm1 = maxB / 127 + C4 = (out / dyna) * norm1 + offset B1 = torch.nn.Parameter(B.clone()) B2 = torch.nn.Parameter(B.clone()) B3 = torch.nn.Parameter(B.clone()) B4 = torch.nn.Parameter(B.clone()) - C1 = torch.matmul(A, B1) - C2 = bnb.matmul_cublas(A, B2, None, 'linear') - C3 = bnb.matmul_cublas(A, B3, None, 'zeropoint') - C4 = bnb.matmul_cublas(A, B4, None, 'vector-zeropoint') + C2 = bnb.matmul_cublas(A, B2, None, "linear") + C3 = bnb.matmul_cublas(A, B3, None, "zeropoint") + C4 = bnb.matmul_cublas(A, B4, None, "vector-zeropoint") - err1 = torch.abs(C1-C2).mean().item() - err2 = torch.abs(C1-C3).mean().item() - err3 = torch.abs(C1-C4).mean().item() + err1 = torch.abs(C1 - C2).mean().item() + err2 = torch.abs(C1 - C3).mean().item() + err3 = torch.abs(C1 - C4).mean().item() print(err1, err2, err3) - #assert err1 > err2 + # assert err1 > err2 loss1 = C1.mean() loss2 = C2.mean() @@ -1765,40 +1917,38 @@ def test_zeropoint(): print(B2.grad) print(B3.grad) print(B4.grad) - err1 = torch.abs(B1.grad-B2.grad).mean().item() - err2 = torch.abs(B1.grad-B3.grad).mean().item() - err3 = torch.abs(B1.grad-B4.grad).mean().item() + err1 = torch.abs(B1.grad - B2.grad).mean().item() + err2 = torch.abs(B1.grad - B3.grad).mean().item() + err3 = torch.abs(B1.grad - B4.grad).mean().item() print(err1, err2, err3) - - def test_zp(): def quant_zp(x): dtype = x.dtype x = x.float() dyna = x.max() - x.min() - if dyna == 0: dyna = 1 - qx = 254./dyna + if dyna == 0: + dyna = 1 + qx = 254.0 / dyna minx = x.min() - #zpx = torch.round(minx* qx) - #zpx = 127 - torch.round(x.max()* qx) - zpx = torch.round(x.min()* qx) - 127 - x = (qx*x) + zpx + # zpx = torch.round(minx* qx) + # zpx = 127 - torch.round(x.max()* qx) + zpx = torch.round(x.min() * qx) - 127 + x = (qx * x) + zpx return x, qx, zpx + batch = 2 seq = 512 model = 1024 - hidden = 4*model - A = torch.randn(batch*seq, model, device='cuda').half()*0.1 - B = torch.randn(model, hidden, device='cuda').half()*0.1 - + hidden = 4 * model + A = torch.randn(batch * seq, model, device="cuda").half() * 0.1 + B = torch.randn(model, hidden, device="cuda").half() * 0.1 C0 = torch.matmul(A, B) - - #A, SA = F.vectorwise_quant(A, quant_type='linear') - #B, SB = F.vectorwise_quant(B, quant_type='linear') + # A, SA = F.vectorwise_quant(A, quant_type='linear') + # B, SB = F.vectorwise_quant(B, quant_type='linear') A = A.float() B = B.float() @@ -1806,69 +1956,68 @@ def test_zp(): C3 = bnb.matmul(A.half(), B.t().contiguous().half()) zp = 1 - #C2 = torch.matmul(A-zp, B) - #C2 += B.sum(0).view(1, -1)*zp - C2 = torch.matmul(A, B-zp) - C2 -= A.sum(1).view(-1, 1)*zp + # C2 = torch.matmul(A-zp, B) + # C2 += B.sum(0).view(1, -1)*zp + C2 = torch.matmul(A, B - zp) + C2 -= A.sum(1).view(-1, 1) * zp ca, cqa, cza = quant_zp(A) print(ca.min(), ca.max()) - print((ca-cza).min(), (ca-cza).max()) + print((ca - cza).min(), (ca - cza).max()) zp = 1 scale = 2.0 - C5 = torch.matmul((A*scale)-zp, B) - C5 += B.sum(0)*zp + C5 = torch.matmul((A * scale) - zp, B) + C5 += B.sum(0) * zp C5 /= scale CA, qa, zpa = quant_zp(A) C4 = torch.matmul(CA, B) - C4 -= B.sum(0)*zpa + C4 -= B.sum(0) * zpa C4 /= qa zpb = 1 zpa = 1 qa = 2 qb = 2 - C6 = torch.matmul((A*qa)+zpa, (B*qb)+zpb) - C6 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb) - C6 -= zpa*zpb*A.shape[1] - C6 /= qa*qb + C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb) + C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb) + C6 -= zpa * zpb * A.shape[1] + C6 /= qa * qb CA, qa, zpa = quant_zp(A) CB, qb, zpb = quant_zp(B) C7 = torch.matmul(CA, CB) - C7 -= (qb*B.sum(0).view(1, -1)*zpa) + (qa*A.sum(1).view(-1, 1)*zpb) - C7 -= zpa*zpb*A.shape[1] - C7 /= qa*qb + C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb) + C7 -= zpa * zpb * A.shape[1] + C7 /= qa * qb - print('') - #print(C0.flatten()[:10]) + print("") + # print(C0.flatten()[:10]) print(C1.flatten()[:10]) print(C2.flatten()[:10]) print(C3.flatten()[:10]) print(C5.flatten()[:10]) print(C6.flatten()[:10]) print(C7.flatten()[:10]) - err1 = torch.abs(C1-C2).mean().item() - err2 = torch.abs(C1-C3).mean().item() - err3 = torch.abs(C1-C4).mean().item() - err4 = torch.abs(C1-C5).mean().item() - err5 = torch.abs(C1-C6).mean().item() - err6 = torch.abs(C1-C7).mean().item() + err1 = torch.abs(C1 - C2).mean().item() + err2 = torch.abs(C1 - C3).mean().item() + err3 = torch.abs(C1 - C4).mean().item() + err4 = torch.abs(C1 - C5).mean().item() + err5 = torch.abs(C1 - C6).mean().item() + err6 = torch.abs(C1 - C7).mean().item() print(err1, err2, err3, err4, err5, err6) - def test_extract_outliers(): for i in range(k): - shapeA = (4096, 4096*4) + shapeA = (4096, 4096 * 4) idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda() - #idx = torch.Tensor([0]).int().cuda() - A = torch.randint(-128, 127, size=shapeA, device='cuda').to(torch.int8) + # idx = torch.Tensor([0]).int().cuda() + A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) outliers1 = A[:, idx.long()] - CA, SA = F.transform(A, 'col_turing') + CA, SA = F.transform(A, "col_turing") outliers2 = F.extract_outliers(CA, SA, idx) @@ -1877,7 +2026,7 @@ def test_extract_outliers(): torch.testing.assert_allclose(outliers1, outliers2) - CA, SA = F.transform(A, 'col_ampere') + CA, SA = F.transform(A, "col_ampere") outliers2 = F.extract_outliers(CA, SA, idx) diff --git a/tests/test_modules.py b/tests/test_modules.py index a2c950b..6b8d641 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -1,21 +1,27 @@ +from itertools import product + import pytest import torch - -from itertools import product from torch import nn import bitsandbytes as bnb + class MockArgs(object): def __init__(self, initial_data): for key in initial_data: setattr(self, key, initial_data[key]) + class MLP8bit(torch.nn.Module): def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0): super(MLP8bit, self).__init__() - self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold) - self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold) + self.fc1 = bnb.nn.Linear8bitLt( + dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold + ) + self.fc2 = bnb.nn.Linear8bitLt( + dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold + ) def forward(self, x): x = self.fc1(x) @@ -25,108 +31,114 @@ class MLP8bit(torch.nn.Module): def get_args(): args = MockArgs([]) - args.quant_type = 'vector' - args.use_8bit_training = 'full' + args.quant_type = "vector" + args.use_8bit_training = "full" args.clip_freq = 9999 return args + def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): idx = torch.isclose(a, b, rtol, atol) - sumval = (idx==0).sum().item() + sumval = (idx == 0).sum().item() if sumval > count: - print(f'Too many values not close: assert {sumval} < {count}') + print(f"Too many values not close: assert {sumval} < {count}") torch.testing.assert_allclose(a, b, rtol, atol) -class LinearFunction(torch.autograd.Function): +class LinearFunction(torch.autograd.Function): @staticmethod def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0): round_func = LinearFunction.round_stoachastic if stochastic else torch.round - norm = math.sqrt(math.pi)/math.sqrt(2.0) - #std = torch.abs(x).mean()*norm + norm = math.sqrt(math.pi) / math.sqrt(2.0) + # std = torch.abs(x).mean()*norm std = torch.std(x) - max1 = std*trim_value - x = x/max1*127 + max1 = std * trim_value + x = x / max1 * 127 x = round_func(x) x[x > 127] = 127 x[x < -127] = -127 - x = x/127*max1 + x = x / 127 * max1 return x def quant(x, quant_type, dim=1): - if quant_type == 'linear': + if quant_type == "linear": max1 = torch.abs(x).max().float() - xq = torch.round(x/max1*127).to(torch.int8) + xq = torch.round(x / max1 * 127).to(torch.int8) return xq, max1 - elif quant_type == 'vector': + elif quant_type == "vector": max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) - xq = torch.round(x/max1*127).to(torch.int8) + xq = torch.round(x / max1 * 127).to(torch.int8) return xq, max1 - elif quant_type == 'min-max': + elif quant_type == "min-max": maxA = torch.amax(x, dim=dim, keepdim=True).float() minA = torch.amin(x, dim=dim, keepdim=True).float() - scale = (maxA-minA)/2.0 - xq = torch.round(127*(x-minA-scale)/scale).to(torch.int8) + scale = (maxA - minA) / 2.0 + xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8) return xq, (minA.float(), scale.float()) - else: return None + else: + return None def dequant(xq, S1, S2, dtype, quant_type): - if quant_type == 'linear': - norm = S1*S2/(127*127) + if quant_type == "linear": + norm = S1 * S2 / (127 * 127) # double cast needed to prevent overflows - return (xq.float()*norm).to(dtype) - elif quant_type == 'vector': + return (xq.float() * norm).to(dtype) + elif quant_type == "vector": x = xq.float() - if len(xq.shape) == 2 and len(S1.shape) == 3: S1 = S1.squeeze(0) - if len(xq.shape) == 2 and len(S2.shape) == 3: S2 = S2.squeeze(0) - #print(x.shape, S1.shape, S2.shape) + if len(xq.shape) == 2 and len(S1.shape) == 3: + S1 = S1.squeeze(0) + if len(xq.shape) == 2 and len(S2.shape) == 3: + S2 = S2.squeeze(0) + # print(x.shape, S1.shape, S2.shape) if len(S1.shape) == 2: - x *= S1.t()/127 + x *= S1.t() / 127 else: - x *= S1/127 - x *= S2/127 + x *= S1 / 127 + x *= S2 / 127 return x.to(dtype) - else: return None + else: + return None def dequant_min_max(xq, A, B, SA, SB, dtype): - offset = B.float().t().sum(0)*(SA[0]+SA[1]) + offset = B.float().t().sum(0) * (SA[0] + SA[1]) x = xq.float() - if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0) - if len(xq.shape) == 2 and len(SA.shape) == 3: SA = SA.squeeze(0) + if len(xq.shape) == 2 and len(SB.shape) == 3: + SB = SB.squeeze(0) + if len(xq.shape) == 2 and len(SA.shape) == 3: + SA = SA.squeeze(0) if len(SB.shape) == 2: - x *= SB.t()/127 + x *= SB.t() / 127 else: - x *= SB/127 - x *= SA[1]/127 - x +=offset + x *= SB / 127 + x *= SA[1] / 127 + x += offset return x.to(dtype) - def get_8bit_linear(x, stochastic=False): round_func = LinearFunction.round_stoachastic if stochastic else torch.round max1 = torch.abs(x).max() - x = x/max1*127 - x = round_func(x)/127*max1 - #x = torch.round(x)/128*max1 + x = x / max1 * 127 + x = round_func(x) / 127 * max1 + # x = torch.round(x)/128*max1 return x @staticmethod def get_8bit_vector_wise(x, dim, stochastic=False): round_func = LinearFunction.round_stoachastic if stochastic else torch.round max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) - max1[max1==0] = 1.0 - x = (x*127)/max1 - x = round_func(x)/127*max1 + max1[max1 == 0] = 1.0 + x = (x * 127) / max1 + x = round_func(x) / 127 * max1 return x @staticmethod def round_stoachastic(x): sign = torch.sign(x) absx = torch.abs(x) - decimal = absx-torch.floor(absx) + decimal = absx - torch.floor(absx) rdm = torch.rand_like(decimal) - return sign*(torch.floor(absx)+(rdm < decimal).to(x.dtype)) + return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype)) @staticmethod def fake_8bit_storage(w, exponent_bits): @@ -140,10 +152,10 @@ class LinearFunction(torch.autograd.Function): @staticmethod def fake_8bit_storage_quantile(w, args): code = bnb.functional.estimate_quantiles(w.data, offset=args.offset) - #C = bnb.functional.quantize_no_absmax(code, w) - #out = bnb.functional.dequantize_no_absmax(code, C, out=w.data) - #print(out) - #out = out.half() + # C = bnb.functional.quantize_no_absmax(code, w) + # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data) + # print(out) + # out = out.half() code /= torch.max(torch.abs(code)) absmax, C = bnb.functional.quantize_blockwise(w.data, code=code) out = bnb.functional.dequantize_blockwise(absmax, C, code) @@ -162,7 +174,7 @@ class LinearFunction(torch.autograd.Function): @staticmethod def fake_8bit_storage_with_max(w, topk=8): - blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256) + blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256) max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True) idx = idx[:, :topk] max_val = max_val[:, :topk] @@ -191,22 +203,21 @@ class LinearFunction(torch.autograd.Function): w.copy_(unblocked_w) return unblocked_w - @staticmethod def forward(ctx, x, weight, bias=None, args=None): - if args.use_8bit_training != 'off': + if args.use_8bit_training != "off": weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2) outputq = bnb.functional.igemm(x8, weight8.t()) output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type) - #if torch.rand(1) < 0.01: - #output32 = torch.matmul(x, weight.t()) - #err = torch.abs(output-output32).float() - #relerr = err/(torch.abs(output32).float()+1e-8) - #print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy) + # if torch.rand(1) < 0.01: + # output32 = torch.matmul(x, weight.t()) + # err = torch.abs(output-output32).float() + # relerr = err/(torch.abs(output32).float()+1e-8) + # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy) else: - #output = torch.matmul(x, weight.t()) - output = torch.einsum('bsi,oi->bso', x, weight) + # output = torch.matmul(x, weight.t()) + output = torch.einsum("bsi,oi->bso", x, weight) ctx.save_for_backward(x, weight, bias) ctx.args = args @@ -221,37 +232,49 @@ class LinearFunction(torch.autograd.Function): args = ctx.args stochastic = False grad_input = grad_weight = grad_bias = None - if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0) + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(0) # weight and x are already 8bit # -> transform grad_output to 8-bit - if args.use_8bit_training == 'forward+wgrad': - grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) + if args.use_8bit_training == "forward+wgrad": + grad_output8, S1 = LinearFunction.quant( + grad_output, args.quant_type, dim=[0, 1] + ) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) grad_weight8 = bnb.functional.igemm(grad_output8, x8) - grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) + grad_weight = LinearFunction.dequant( + grad_weight8, S1, S2, grad_output.dtype, args.quant_type + ) - #grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x) + # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x) grad_input = grad_output.matmul(weight) - elif args.use_8bit_training == 'full': - grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) + elif args.use_8bit_training == "full": + grad_output8, S1 = LinearFunction.quant( + grad_output, args.quant_type, dim=[0, 1] + ) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) grad_weight8 = torch.zeros_like(weight, dtype=torch.int32) bnb.functional.igemm(grad_output8, x8, out=grad_weight8) - grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) + grad_weight = LinearFunction.dequant( + grad_weight8, S1, S2, grad_output.dtype, args.quant_type + ) grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2) weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0) grad_input8 = bnb.functional.igemm(grad_output8, weight8) - grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type) + grad_input = LinearFunction.dequant( + grad_input8, S1, S3, grad_output.dtype, args.quant_type + ) else: grad_input = grad_output.matmul(weight) - grad_weight = torch.einsum('bsi,bso->oi', x, grad_output) + grad_weight = torch.einsum("bsi,bso->oi", x, grad_output) return grad_input, grad_weight, grad_bias, None + class Linear8bit(nn.Module): def __init__(self, input_features, output_features, bias=True, args=None): super(Linear8bit, self).__init__() @@ -263,7 +286,7 @@ class Linear8bit(nn.Module): if bias: self.bias = nn.Parameter(torch.empty(output_features)) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) torch.nn.init.xavier_uniform_(self.weight) if self.bias is not None: @@ -275,12 +298,11 @@ class Linear8bit(nn.Module): return LinearFunction.apply(x, self.weight, self.bias, self.args) - def test_linear8bit(): l0 = torch.nn.Linear(32, 64).cuda().half() - l1 = bnb.nn.Linear8bit(32,64, args=get_args()).cuda().half() + l1 = bnb.nn.Linear8bit(32, 64, args=get_args()).cuda().half() l2 = Linear8bit(32, 64, args=get_args()).cuda().half() - l3 = bnb.nn.Linear8bitLt(32,64).cuda().half() + l3 = bnb.nn.Linear8bitLt(32, 64).cuda().half() l0.weight.data = l2.weight.data.clone() l0.bias.data = l2.bias.data.clone() @@ -292,8 +314,8 @@ def test_linear8bit(): l3.bias.data = l2.bias.data.clone() for i in range(100): - b1 = torch.randn(16, 8, 32, device='cuda').half() - t = torch.randn(16, 8, 64, device='cuda').half() + b1 = torch.randn(16, 8, 32, device="cuda").half() + t = torch.randn(16, 8, 64, device="cuda").half() b2 = b1.clone() b3 = b1.clone() b0 = b1.clone() @@ -318,16 +340,20 @@ def test_linear8bit(): assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2) assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2) - assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2) - assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2) + assert_all_approx_close( + l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2 + ) + assert_all_approx_close( + l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2 + ) - err1 = torch.abs(l0.weight.grad-l1.weight.grad).mean().item() - err2 = torch.abs(l0.weight.grad-l2.weight.grad).mean().item() - err3 = torch.abs(l0.weight.grad-l3.weight.grad).mean().item() + err1 = torch.abs(l0.weight.grad - l1.weight.grad).mean().item() + err2 = torch.abs(l0.weight.grad - l2.weight.grad).mean().item() + err3 = torch.abs(l0.weight.grad - l3.weight.grad).mean().item() - assert err1*0.8 < err2 - assert err2*0.8 < err3 - assert err3*0.8 < err1 + assert err1 * 0.8 < err2 + assert err2 * 0.8 < err3 + assert err3 * 0.8 < err1 l0.weight.grad = None l1.weight.grad = None @@ -341,23 +367,28 @@ def test_linear8bit(): threshold = [0.0, 3.0] values = threshold -names = ['threshold_{0}'.format(vals) for vals in values] +names = ["threshold_{0}".format(vals) for vals in values] + + @pytest.mark.parametrize("threshold", values, ids=names) def test_linear8bitlt_inference(threshold): - l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold).cuda().half() - assert l1.weight.device.type == 'cuda' + l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half() + assert l1.weight.device.type == "cuda" assert l1.weight.dtype == torch.float16 l1.eval() for i in range(100): - b1 = torch.randn(16, 8, 32, device='cuda').half() + b1 = torch.randn(16, 8, 32, device="cuda").half() o1 = l1(b1) if i == 1: assert l1.state.CxB is not None + def test_linear8bitlt_accumulated_gradient(): - l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32,32).cuda().half() for i in range(2)]) - l2 = torch.nn.Sequential(*[torch.nn.Linear(32,32).cuda().half() for i in range(2)]) + l1 = torch.nn.Sequential( + *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)] + ) + l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]) l2[0].weight = torch.nn.Parameter(l1[0].weight.clone()) l2[0].bias = torch.nn.Parameter(l1[0].bias.clone()) l2[1].weight = torch.nn.Parameter(l1[1].weight.clone()) @@ -367,9 +398,8 @@ def test_linear8bitlt_accumulated_gradient(): acc_steps = 10 - for i in range(10): - b1 = torch.randn(16, 8, 32, device='cuda').half() + b1 = torch.randn(16, 8, 32, device="cuda").half() o1 = l1(b1) o2 = l2(b1) loss1 = o1.mean() @@ -385,8 +415,12 @@ def test_linear8bitlt_accumulated_gradient(): opt1.zero_grad(True) opt2.step() opt2.zero_grad(True) - assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2) - assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2) + assert_all_approx_close( + l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2 + ) + assert_all_approx_close( + l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2 + ) # we do this copy because otherwise we have small divergences over time that add up l1[0].weight.data.copy_(l2[0].weight.data) l1[1].weight.data.copy_(l2[1].weight.data) @@ -397,15 +431,21 @@ def test_linear8bitlt_accumulated_gradient(): threshold = [0.0, 2.0] values = threshold -names = ['threshold_{0}'.format(vals) for vals in values] +names = ["threshold_{0}".format(vals) for vals in values] + + @pytest.mark.parametrize("threshold", values, ids=names) def test_linear8bitlt_no_fp16_weights(threshold): - l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold, has_fp16_weights=False).cuda().half() + l1 = ( + bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False) + .cuda() + .half() + ) assert l1.weight.dtype == torch.int8 l1.eval() for i in range(100): - b1 = torch.randn(16, 8, 32, device='cuda').half() + b1 = torch.randn(16, 8, 32, device="cuda").half() o1 = l1(b1) assert o1.dtype == torch.float16 @@ -414,57 +454,70 @@ def test_linear8bitlt_no_fp16_weights(threshold): assert mlp.fc2.weight.dtype == torch.int8 for i in range(100): - b1 = torch.randn(16, 8, 32, device='cuda').half() + b1 = torch.randn(16, 8, 32, device="cuda").half() o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None + if threshold > 0: + assert mlp.fc1.state.idx is not None + if threshold > 0: + assert mlp.fc2.state.idx is not None mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half() assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 for i in range(100): - b1 = torch.randn(16, 8, 32, device='cuda').half() + b1 = torch.randn(16, 8, 32, device="cuda").half() o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None + if threshold > 0: + assert mlp.fc1.state.idx is not None + if threshold > 0: + assert mlp.fc2.state.idx is not None mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda() for i in range(100): - b1 = torch.randn(16, 8, 32, device='cuda').half() + b1 = torch.randn(16, 8, 32, device="cuda").half() o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None + if threshold > 0: + assert mlp.fc1.state.idx is not None + if threshold > 0: + assert mlp.fc2.state.idx is not None assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 - - mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda') + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to("cuda") for i in range(100): - b1 = torch.randn(16, 8, 32, device='cuda').half() + b1 = torch.randn(16, 8, 32, device="cuda").half() o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None + if threshold > 0: + assert mlp.fc1.state.idx is not None + if threshold > 0: + assert mlp.fc2.state.idx is not None assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 - assert mlp.fc1.weight.device.type == 'cuda' - assert mlp.fc2.weight.device.type == 'cuda' + assert mlp.fc1.weight.device.type == "cuda" + assert mlp.fc2.weight.device.type == "cuda" - mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda') + mlp = ( + MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) + .to(torch.float16) + .to("cuda") + ) for i in range(100): - b1 = torch.randn(16, 8, 32, device='cuda').half() + b1 = torch.randn(16, 8, 32, device="cuda").half() o1 = mlp(b1) assert o1.dtype == torch.float16 - if threshold > 0: assert mlp.fc1.state.idx is not None - if threshold > 0: assert mlp.fc2.state.idx is not None + if threshold > 0: + assert mlp.fc1.state.idx is not None + if threshold > 0: + assert mlp.fc2.state.idx is not None assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 - assert mlp.fc1.weight.device.type == 'cuda' - assert mlp.fc2.weight.device.type == 'cuda' + assert mlp.fc1.weight.device.type == "cuda" + assert mlp.fc2.weight.device.type == "cuda" diff --git a/tests/test_optim.py b/tests/test_optim.py index b173eaa..b84425e 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -1,81 +1,132 @@ -import os -import time -import shutil -import uuid -import pytest import ctypes +import os +import shutil +import time +import uuid +from itertools import product +from os.path import join + +import pytest import torch + import bitsandbytes as bnb import bitsandbytes.functional as F -from os.path import join -from itertools import product - -#import apex +# import apex k = 20 + def get_temp_dir(): - path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4())) + path = "/tmp/autoswap/{0}".format(str(uuid.uuid4())) os.makedirs(path, exist_ok=True) return path + def rm_path(path): shutil.rmtree(path) + str2optimizers = {} -str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam) -#str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) -#str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) -str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam) -#str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam) -#str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam) +str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) +# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) +# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) +str2optimizers["momentum_pytorch"] = ( + None, + lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), + bnb.optim.Adam, +) +# str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam) +# str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam) -str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam) -#str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) -str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False)) -str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9)) -#str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB) -str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False)) -str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False)) -str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False)) -str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False)) -#str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit) -str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9)) +str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam) +# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) +str2optimizers["momentum"] = ( + lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False), +) +str2optimizers["lars"] = ( + lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9), +) +# str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB) +str2optimizers["rmsprop"] = ( + lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), +) +str2optimizers["adam8bit"] = ( + torch.optim.Adam, + lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False), +) +str2optimizers["momentum8bit"] = ( + lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False), +) +str2optimizers["rmsprop8bit"] = ( + lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False), +) +# str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit) +str2optimizers["lars8bit"] = ( + lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9), +) -str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) -str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True)) -str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True)) +str2optimizers["adam8bit_blockwise"] = ( + torch.optim.Adam, + lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True), +) +str2optimizers["momentum8bit_blockwise"] = ( + lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True), +) +str2optimizers["rmsprop8bit_blockwise"] = ( + lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True), +) str2statenames = {} -str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] -str2statenames['momentum'] = [('momentum_buffer', 'state1')] -str2statenames['lars'] = [('momentum_buffer', 'state1')] -str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] -str2statenames['rmsprop'] = [('square_avg', 'state1')] -str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] -str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] -str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')] -str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] -str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')] -str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] -str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')] -str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')] +str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["momentum"] = [("momentum_buffer", "state1")] +str2statenames["lars"] = [("momentum_buffer", "state1")] +str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["rmsprop"] = [("square_avg", "state1")] +str2statenames["adam8bit"] = [ + ("exp_avg", "state1", "qmap1", "max1"), + ("exp_avg_sq", "state2", "qmap2", "max2"), +] +str2statenames["lamb8bit"] = [ + ("exp_avg", "state1", "qmap1", "max1"), + ("exp_avg_sq", "state2", "qmap2", "max2"), +] +str2statenames["adam8bit_blockwise"] = [ + ("exp_avg", "state1", "qmap1", "absmax1"), + ("exp_avg_sq", "state2", "qmap2", "absmax2"), +] +str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] +str2statenames["momentum8bit_blockwise"] = [ + ("momentum_buffer", "state1", "qmap1", "absmax1") +] +str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] +str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")] +str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")] dim1 = [1024] dim2 = [32, 1024, 4097, 1] gtype = [torch.float32, torch.float16] -optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb'] -values = list(product(dim1,dim2, gtype, optimizer_names)) -names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] +optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lamb"] +values = list(product(dim1, dim2, gtype, optimizer_names)) +names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) def test_optimizer32bit(dim1, dim2, gtype, optim_name): - if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + if dim1 == 1 and dim2 == 1: + return + p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 p2 = p1.clone() p1 = p1.float() - torch_optimizer = str2optimizers[optim_name][0]([p1]) bnb_optimizer = str2optimizers[optim_name][1]([p2]) @@ -84,9 +135,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): else: atol, rtol = 1e-4, 1e-3 - for i in range(k): - g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 + g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 p1.grad = g.clone().float() p2.grad = g.clone() @@ -94,21 +144,31 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): torch_optimizer.step() for name1, name2 in str2statenames[optim_name]: - torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol) + torch.testing.assert_allclose( + torch_optimizer.state[p1][name1], + bnb_optimizer.state[p2][name2], + atol=atol, + rtol=rtol, + ) torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) - if i % (k//5) == 0 and i > 0: + if i % (k // 5) == 0 and i > 0: path = get_temp_dir() - torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt')) + torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt")) del bnb_optimizer bnb_optimizer = None bnb_optimizer = str2optimizers[optim_name][1]([p2]) - bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt'))) + bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) rm_path(path) torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) for name1, name2 in str2statenames[optim_name]: - torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol) + torch.testing.assert_allclose( + torch_optimizer.state[p1][name1], + bnb_optimizer.state[p2][name2], + atol=atol, + rtol=rtol, + ) if gtype == torch.float16: # the adam buffers should also be close because they are 32-bit @@ -118,20 +178,24 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): p1.data = p1.data.half().float() p2.copy_(p1.data) torch.testing.assert_allclose(p1.half(), p2) - if optim_name in ['lars', 'lamb']: - assert bnb_optimizer.state[p2]['unorm_vec'] > 0.0 + if optim_name in ["lars", "lamb"]: + assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0 + dim1 = [1024] dim2 = [32, 1024, 4097] gtype = [torch.float32, torch.float16] -values = list(product(dim1,dim2, gtype)) -names = ['dim1_{0}_dim2_{1}_gtype_{2}'.format(*vals) for vals in values] +values = list(product(dim1, dim2, gtype)) +names = ["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names) def test_global_config(dim1, dim2, gtype): - if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 - p2 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 - p3 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 + if dim1 == 1 and dim2 == 1: + return + p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 + p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 + p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 mask = torch.rand_like(p2) < 0.1 beta1 = 0.9 beta2 = 0.999 @@ -139,7 +203,7 @@ def test_global_config(dim1, dim2, gtype): eps = 1e-8 bnb.optim.GlobalOptimManager.get_instance().initialize() - bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8) + bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8) bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) p1 = p1.cuda() @@ -154,30 +218,41 @@ def test_global_config(dim1, dim2, gtype): atol, rtol = 1e-4, 1e-3 for i in range(50): - g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 - g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 - g3 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 + g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 + g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 + g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 p1.grad = g1 p2.grad = g2 p3.grad = g3 adam2.step() - assert adam2.state[p3]['state1'].dtype == torch.uint8 - assert adam2.state[p3]['state2'].dtype == torch.uint8 - + assert adam2.state[p3]["state1"].dtype == torch.uint8 + assert adam2.state[p3]["state2"].dtype == torch.uint8 dim1 = [1024] dim2 = [32, 1024, 4097] gtype = [torch.float32, torch.float16] -optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise'] -values = list(product(dim1,dim2, gtype, optimizer_names)) -names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] +optimizer_names = [ + "adam8bit", + "momentum8bit", + "rmsprop8bit", + "adam8bit_blockwise", + "lamb8bit", + "lars8bit", + "momentum8bit_blockwise", + "rmsprop8bit_blockwise", +] +values = list(product(dim1, dim2, gtype, optimizer_names)) +names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) def test_optimizer8bit(dim1, dim2, gtype, optim_name): - if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + if dim1 == 1 and dim2 == 1: + return + p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 p2 = p1.clone() p1 = p1.float() blocksize = 2048 @@ -197,7 +272,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): relerrors = [] for i in range(50): - g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 + g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 p1.grad = g.clone().float() p2.grad = g.clone() @@ -208,17 +283,31 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): dequant_states = [] for name1, name2, qmap, max_val in str2statenames[optim_name]: - #print(bnb_optimizer.state[p2][max_val], name1) - if 'blockwise' in optim_name: - s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize) + # print(bnb_optimizer.state[p2][max_val], name1) + if "blockwise" in optim_name: + s1 = F.dequantize_blockwise( + code=bnb_optimizer.state[p2][qmap], + absmax=bnb_optimizer.state[p2][max_val], + A=bnb_optimizer.state[p2][name2], + blocksize=blocksize, + ) else: - s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2]) - num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0 + s1 = F.dequantize( + code=bnb_optimizer.state[p2][qmap], + absmax=bnb_optimizer.state[p2][max_val], + A=bnb_optimizer.state[p2][name2], + ) + num_not_close = ( + torch.isclose( + torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol + ) + == 0 + ) assert num_not_close.sum().item() < 20 dequant_states.append(s1.clone()) - err = torch.abs(p1-p2) - relerr = err/torch.abs(p1) + err = torch.abs(p1 - p2) + relerr = err / torch.abs(p1) assert err.mean() < 0.0001 assert relerr.mean() < 0.001 @@ -226,28 +315,44 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): relerrors.append(relerr.mean().item()) if i % 10 == 0 and i > 0: - for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): + for (name1, name2, qmap, max_val), s in zip( + str2statenames[optim_name], dequant_states + ): s1cpy = s.clone() raws1cpy = bnb_optimizer.state[p2][name2].clone() qmap1 = bnb_optimizer.state[p2][qmap].clone() path = get_temp_dir() - torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt')) + torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt")) del bnb_optimizer bnb_optimizer = None bnb_optimizer = str2optimizers[optim_name][1]([p2]) - bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt'))) + bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) rm_path(path) torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2]) torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap]) - if 'blockwise' in optim_name: - s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize) + if "blockwise" in optim_name: + s1 = F.dequantize_blockwise( + code=bnb_optimizer.state[p2][qmap], + absmax=bnb_optimizer.state[p2][max_val], + A=bnb_optimizer.state[p2][name2], + blocksize=blocksize, + ) else: - s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2]) + s1 = F.dequantize( + code=bnb_optimizer.state[p2][qmap], + absmax=bnb_optimizer.state[p2][max_val], + A=bnb_optimizer.state[p2][name2], + ) torch.testing.assert_allclose(s1cpy, s1) - num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0 + num_not_close = ( + torch.isclose( + torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol + ) + == 0 + ) assert num_not_close.sum().item() < 20 torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol) @@ -256,24 +361,28 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): p1.data = p1.data.to(gtype).float() p2.copy_(p1.data) torch.testing.assert_allclose(p1.to(gtype), p2) - for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): + for (name1, name2, qmap, max_val), s in zip( + str2statenames[optim_name], dequant_states + ): torch_optimizer.state[p1][name1].copy_(s.data) - #print(sum(errors)/len(errors)) - #print(sum(relerrors)/len(relerrors)) - + # print(sum(errors)/len(errors)) + # print(sum(relerrors)/len(relerrors)) dim1 = [1024] dim2 = [32, 1024, 4097] gtype = [torch.float32] optim_bits = [32, 8] -values = list(product(dim1,dim2, gtype, optim_bits)) -names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}'.format(*vals) for vals in values] +values = list(product(dim1, dim2, gtype, optim_bits)) +names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names) def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): - if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 + if dim1 == 1 and dim2 == 1: + return + p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 beta1 = 0.9 beta2 = 0.999 lr = 0.001 @@ -281,19 +390,23 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): p1 = p1.cuda() p2 = p1.clone() adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits) - adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5) + adam2 = bnb.optim.Adam( + [p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5 + ) gnorm_vec = torch.zeros(100).cuda() step = 0 for i in range(50): step += 1 - g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + (0.01*i) + g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i) g2 = g1.clone() p2.grad = g2 - current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5) - g1 = (g1.float()*gnorm_scale).to(gtype) + current_gnorm, clip_val, gnorm_scale = F.percentile_clipping( + g1, gnorm_vec, step, 5 + ) + g1 = (g1.float() * gnorm_scale).to(gtype) p1.grad = g1 adam1.step() @@ -302,47 +415,69 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state if optim_bits == 32: torch.testing.assert_allclose(p1, p2) - torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=5e-5, rtol=1e-4) - torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=5e-5, rtol=1e-4) + torch.testing.assert_allclose( + adam1.state[p1]["state1"], + adam2.state[p2]["state1"], + atol=5e-5, + rtol=1e-4, + ) + torch.testing.assert_allclose( + adam1.state[p1]["state2"], + adam2.state[p2]["state2"], + atol=5e-5, + rtol=1e-4, + ) elif optim_bits == 8: torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3) - torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=2, rtol=1e-3) - torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=2, rtol=1e-3) - adam1.state[p1]['state1'].copy_(adam2.state[p2]['state1']) - adam1.state[p1]['state2'].copy_(adam2.state[p2]['state2']) + torch.testing.assert_allclose( + adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3 + ) + torch.testing.assert_allclose( + adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, rtol=1e-3 + ) + adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"]) + adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"]) if i % 10 == 0 and i > 0: path = get_temp_dir() - torch.save(adam2.state_dict(),join(path, 'opt.pt')) + torch.save(adam2.state_dict(), join(path, "opt.pt")) del adam2 adam2 = None - adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5) - adam2.load_state_dict(torch.load(join(path, 'opt.pt'))) - - + adam2 = bnb.optim.Adam( + [p2], + lr, + (beta1, beta2), + eps, + optim_bits=optim_bits, + percentile_clipping=5, + ) + adam2.load_state_dict(torch.load(join(path, "opt.pt"))) dim1 = [4096] dim2 = [4096] gtype = [torch.float32, torch.float16] -#optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit'] -#optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch'] -#optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch'] -#optimizer_names = ['lamb_apex', 'lamb8bit'] -#optimizer_names = ['lars_apex', 'lars8bit'] -optimizer_names = ['adam8bit_blockwise'] -values = list(product(dim1,dim2, gtype, optimizer_names)) -names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] +# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit'] +# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch'] +# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch'] +# optimizer_names = ['lamb_apex', 'lamb8bit'] +# optimizer_names = ['lars_apex', 'lars8bit'] +optimizer_names = ["adam8bit_blockwise"] +values = list(product(dim1, dim2, gtype, optimizer_names)) +names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values] + + @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): - if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + if dim1 == 1 and dim2 == 1: + return + p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 bnb_optimizer = str2optimizers[optim_name][1]([p1]) - g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 + g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 p1.grad = g for i in range(k): - if i == k//5: + if i == k // 5: # 100 iterations for burn-in torch.cuda.synchronize() t0 = time.time() @@ -350,10 +485,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): bnb_optimizer.step() torch.cuda.synchronize() - s = time.time()-t0 - print('') - params = (k-k//5)*dim1*dim2 - print(optim_name, gtype, s/params) - #assert s < 3.9 - - + s = time.time() - t0 + print("") + params = (k - k // 5) * dim1 * dim2 + print(optim_name, gtype, s / params) + # assert s < 3.9