forked from mrq/bitsandbytes-rocm
Merge branch 'main' into main
This commit is contained in: commit be5cecb88f

@@ -49,7 +49,7 @@ Features:

Bug fixes:
- Fixed a bug where weight decay was incorrectly applied to 32-bit Adam. #13
- Fixed an unsafe use of eval. #8
- Fixed a bug where the StableEmbedding layer 32-bit optimizer override would not work without registering the whole model first (`bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())`). #13 #15 (a usage sketch follows this excerpt)

Docs:
- Added instructions on how to solve "\_\_fatbinwrap_" errors.
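For context on the StableEmbedding entry above, here is a minimal, hedged sketch of the layer it refers to. The dimensions and learning rate are illustrative assumptions, not part of this commit; the point is that, after the fix, the 32-bit optimizer override for the embedding weight applies without registering the whole model first.

```python
import torch
import bitsandbytes as bnb

# StableEmbedding flags its weight for 32-bit optimizer state internally.
emb = bnb.nn.StableEmbedding(1024, 128)
model = torch.nn.Sequential(emb, torch.nn.Linear(128, 2))

# No explicit register_parameters(...) call is needed anymore for the override to take effect.
adam = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)
```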
@@ -149,3 +149,9 @@ Bug fixes:

Bug fixes:
- Fixed a bug in the CUDA Setup which led to an incomprehensible error if no GPU was detected.

### 0.35.4

Bug fixes:
- Fixed a bug where the CUDA Setup failed when the CUDA runtime was found, but not the CUDA library.
- Fixed a bug where not finding the CUDA runtime led to an incomprehensible error.
@@ -28,4 +28,4 @@ outlined on that page and do not file a public issue.

## License
By contributing to bitsandbytes, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
24 Makefile

@@ -26,14 +26,14 @@ INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/inclu
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib

# NVIDIA NVCC compilation flags
COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta

# CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
CC_CUDA92 := -gencode arch=compute_30,code=sm_30
@@ -58,38 +58,38 @@ CC_cublasLt111 += -gencode arch=compute_86,code=sm_86

all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)

cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)

cuda10x_nomatmul: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)

cuda110_nomatmul: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)

cuda11x_nomatmul: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)

cuda110: $(BUILD_DIR) env
$(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)

cuda11x: $(BUILD_DIR) env
$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)

cpuonly: $(BUILD_DIR) env

@@ -117,7 +117,7 @@ $(ROOT_DIR)/dependencies/cub:
cd dependencies/cub; git checkout 1.11.0

clean:
rm build/*

cleaneggs:
rm -rf *.egg*
@@ -1,6 +1,6 @@
# bitsandbytes

bitsandbytes is a lightweight wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and quantization functions.

@@ -48,7 +48,7 @@ out = linear(x.to(torch.float16))

Requirements: anaconda, cudatoolkit, pytorch

Hardware requirements:
- LLM.int8(): NVIDIA Turing (RTX 20xx; T4) or Ampere GPU (RTX 30xx; A4-A100); (a GPU from 2018 or newer).
- 8-bit optimizers and quantization: NVIDIA Maxwell GPU or newer (>=GTX 9XX).
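To make the hardware split above concrete, here is a minimal, hedged sketch of the LLM.int8() linear layer that the surrounding README context (`out = linear(x.to(torch.float16))`) refers to. The layer sizes and the threshold value are illustrative assumptions, not part of this diff, and a Turing/Ampere GPU is assumed.

```python
import torch
import bitsandbytes as bnb

# Replace an fp16 nn.Linear with the int8 variant for inference.
linear = bnb.nn.Linear8bitLt(1024, 1024, has_fp16_weights=False, threshold=6.0).cuda()

x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
out = linear(x)  # runs the LLM.int8() matmul internally
```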
@@ -87,7 +87,7 @@ Note that by default all parameter tensors with less than 4096 elements are kept
```
# parameter tensors with less than 16384 values are optimized in 32-bit
# it is recommended to use multiples of 4096
adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384)
```

### Change Bits and other Hyperparameters for Individual Parameters
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import cuda_setup, utils
from .autograd._functions import (
MatmulLtState,
bmm_cublas,

@@ -12,7 +13,6 @@ from .autograd._functions import (
)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules
from . import cuda_setup, utils

if COMPILED_WITH_CUDA:
from .optim import adam
@@ -1,6 +1,3 @@
# from bitsandbytes.debug_cli import cli

# cli()
import os
import sys
from warnings import warn

@@ -31,8 +28,8 @@ print()

from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL
from .cuda_setup.env_vars import to_be_ignored
from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle

print_header("POTENTIALLY LIBRARY-PATH-LIKE ENV VARS")
for k, v in os.environ.items():
@@ -1,12 +1,13 @@
import operator
import warnings

import torch
import bitsandbytes.functional as F

from dataclasses import dataclass
from functools import reduce  # Required in Python 3

import torch

import bitsandbytes.functional as F


# math.prod not compatible with python < 3.8
def prod(iterable):
return reduce(operator.mul, iterable, 1)

@@ -15,10 +16,10 @@ tensor = torch.Tensor

"""
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
"""
class GlobalOutlierPooler(object):
class GlobalOutlierPooler:
_instance = None

def __init__(self):

@@ -49,8 +50,9 @@ class GlobalOutlierPooler(object):

class MatMul8bit(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):
def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
if precision is None:
precision = [8, 8, 8]
if precision[0] != 8:
with torch.no_grad():
output = torch.matmul(A, B)
@@ -1,11 +1,11 @@
import ctypes as ct
import torch

from pathlib import Path
from warnings import warn

import torch


class CUDASetup(object):
class CUDASetup:
_instance = None

def __init__(self):

@@ -52,8 +52,13 @@ class CUDASetup(object):
self.add_log_entry('python setup.py install')

def initialize(self):
self.cuda_setup_log = []
self.has_printed = False
self.lib = None
self.run_cuda_setup()

def run_cuda_setup(self):
self.initialized = True
self.cuda_setup_log = []

from .cuda_setup.main import evaluate_cuda_setup
binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()

@@ -89,7 +94,8 @@ class CUDASetup(object):
else:
self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
self.lib = ct.cdll.LoadLibrary(binary_path)
except:
except Exception as ex:
self.add_log_entry(str(ex))
self.print_log_stack()

def add_log_entry(self, msg, is_warning=False):

@@ -116,7 +122,7 @@ try:
CUDASetup.get_instance().generate_instructions()
CUDASetup.get_instance().print_log_stack()
raise RuntimeError('''
CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs to fix your environment!
CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment!
If you cannot find any issues and suspect a bug, please open an issue with detals about your environment:
https://github.com/TimDettmers/bitsandbytes/issues''')
lib.cadam32bit_g32

@@ -124,8 +130,6 @@ try:
lib.get_cusparse.restype = ct.c_void_p
COMPILED_WITH_CUDA = True
except AttributeError:
warn(
"The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers and GPU quantization are unavailable."
)
warn("The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers and GPU quantization are unavailable.")
COMPILED_WITH_CUDA = False
@@ -1,2 +1,6 @@
from .paths import CUDA_RUNTIME_LIB, extract_candidate_paths, determine_cuda_runtime_lib_path
from .main import evaluate_cuda_setup
from .paths import (
CUDA_RUNTIME_LIB,
determine_cuda_runtime_lib_path,
extract_candidate_paths,
)
@@ -19,9 +19,12 @@ evaluation:
import ctypes
import os

from .paths import determine_cuda_runtime_lib_path
import torch

from bitsandbytes.cextension import CUDASetup

from .paths import determine_cuda_runtime_lib_path


def check_cuda_result(cuda, result_val):
# 3. Check for CUDA errors

@@ -30,8 +33,11 @@ def check_cuda_result(cuda, result_val):
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
CUDASetup.get_instance().add_log_entry(f"CUDA exception! Error code: {error_str.value.decode()}")


# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
def get_cuda_version(cuda, cudart_path):
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
if cuda is None: return None

try:
cudart = ctypes.CDLL(cudart_path)
except OSError:

@@ -45,7 +51,7 @@ def get_cuda_version(cuda, cudart_path):
minor = (version-(major*1000))//10

if major < 11:
CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currenlty not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!')
CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!')

return f'{major}{minor}'

@@ -73,7 +79,6 @@ def get_compute_capabilities(cuda):
# bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
"""

nGpus = ctypes.c_int()
cc_major = ctypes.c_int()
cc_minor = ctypes.c_int()

@@ -100,44 +105,45 @@ def get_compute_capability(cuda):
capabilities are downwards compatible. If no GPUs are detected, it returns
None.
"""
if cuda is None: return None

# TODO: handle different compute capabilities; for now, take the max
ccs = get_compute_capabilities(cuda)
if ccs:
# TODO: handle different compute capabilities; for now, take the max
return ccs[-1]
return None
if ccs: return ccs[-1]


def evaluate_cuda_setup():
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
print('')
print('=' * 35 + 'BUG REPORT' + '=' * 35)
print('='*35 + 'BUG REPORT' + '='*35)
print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
print('To hide this message, set the BITSANDBYTES_NOWELCOME variable like so: export BITSANDBYTES_NOWELCOME=1')
print('=' * 80)

# if not torch.cuda.is_available():
# print('No GPU detected. Loading CPU library...')
# return binary_name

binary_name = "libbitsandbytes_cpu.so"
print('='*80)
if not torch.cuda.is_available(): return 'libsbitsandbytes_cpu.so', None, None, None, None

cuda_setup = CUDASetup.get_instance()
cudart_path = determine_cuda_runtime_lib_path()
if cudart_path is None:
cuda_setup.add_log_entry("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!", is_warning=True)
return binary_name

cuda_setup.add_log_entry((f"CUDA SETUP: CUDA runtime path found: {cudart_path}"))
cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda)
cuda_setup.add_log_entry(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
cuda_version_string = get_cuda_version(cuda, cudart_path)

failure = False
if cudart_path is None:
failure = True
cuda_setup.add_log_entry("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!", is_warning=True)
else:
cuda_setup.add_log_entry(f"CUDA SETUP: CUDA runtime path found: {cudart_path}")

if cc == '' or cc is None:
cuda_setup.add_log_entry("WARNING: No GPU detected! Check your CUDA paths. Processing to load CPU-only library...", is_warning=True)
return binary_name, cudart_path, cuda, cc, cuda_version_string
failure = True
cuda_setup.add_log_entry("WARNING: No GPU detected! Check your CUDA paths. Proceeding to load CPU-only library...", is_warning=True)
else:
cuda_setup.add_log_entry(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")

if cuda is None:
failure = True
else:
cuda_setup.add_log_entry(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')

# 7.5 is the minimum CC vor cublaslt
has_cublaslt = cc in ["7.5", "8.0", "8.6"]

@@ -148,16 +154,13 @@ def evaluate_cuda_setup():

# we use ls -l instead of nvcc to determine the cuda version
# since most installations will have the libcudart.so installed, but not the compiler
cuda_setup.add_log_entry(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')

def get_binary_name():
if failure:
binary_name = "libbitsandbytes_cpu.so"
elif has_cublaslt:
binary_name = f"libbitsandbytes_cuda{cuda_version_string}.so"
else:
"if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
bin_base_name = "libbitsandbytes_cuda"
if has_cublaslt:
return f"{bin_base_name}{cuda_version_string}.so"
else:
return f"{bin_base_name}{cuda_version_string}_nocublaslt.so"

binary_name = get_binary_name()
binary_name = f"libbitsandbytes_cuda{cuda_version_string}_nocublaslt.so"

return binary_name, cudart_path, cuda, cc, cuda_version_string
@@ -1,6 +1,7 @@
import errno
from pathlib import Path
from typing import Set, Union

from bitsandbytes.cextension import CUDASetup

from .env_vars import get_potentially_lib_path_containing_env_vars
@@ -1,26 +0,0 @@
import typer

cli = typer.Typer()


@cli.callback()
def callback():
"""
Awesome Portal Gun
"""


@cli.command()
def shoot():
"""
Shoot the portal gun
"""
typer.echo("Shooting portal gun")


@cli.command()
def load():
"""
Load the portal gun
"""
typer.echo("Loading portal gun")
@@ -3,15 +3,19 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ctypes as ct
import itertools
import operator
import random
import torch
import itertools
import math

from functools import reduce  # Required in Python 3
from typing import Tuple
from torch import Tensor

from .cextension import COMPILED_WITH_CUDA, lib
from functools import reduce  # Required in Python 3


# math.prod not compatible with python < 3.8
def prod(iterable):

@@ -82,7 +86,7 @@ if COMPILED_WITH_CUDA:
)


class CUBLAS_Context(object):
class CUBLAS_Context:
_instance = None

def __init__(self):

@@ -112,7 +116,7 @@ class CUBLAS_Context(object):
return self.context[device.index]


class Cusparse_Context(object):
class Cusparse_Context:
_instance = None

def __init__(self):

@@ -129,14 +133,73 @@ class Cusparse_Context(object):
return cls._instance


def create_linear_map(signed=True):
if signed:
return torch.linspace(-1.0, 1.0, 256)
def create_linear_map(signed=True, total_bits=8, add_zero=True):
sign = (-1.0 if signed else 0.0)
total_values = 2**total_bits
if add_zero or total_bits < 8:
# add a zero
# since we simulate less bits by having zeros in the data type, we
# we need to center the quantization around zero and as such lose
# a single value
total_values = (2**total_bits if not signed else 2**total_bits-1)

values = torch.linspace(sign, 1.0, total_values)
gap = 256 - values.numel()
if gap == 0:
return values
else:
return torch.linspace(0.0, 1.0, 256)
l = values.numel()//2
#return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist())
return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())


def create_dynamic_map(signed=True, n=7):
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
e = exponent_bits
p = precision_bits
has_sign = 1 if signed else 0
assert e+p == total_bits-has_sign
# the exponent is biased to 2^(e-1) -1 == 0
evalues = []
pvalues = []
for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)):
evalues.append(2**val)


values = []
lst = list(itertools.product([0, 1], repeat=precision_bits))
#for ev in evalues:
bias = 2**(exponent_bits-1)-1
for evalue in range(2**(exponent_bits)):
for bit_pattern in lst:
value = (1 if evalue != 0 else 0)
for i, pval in enumerate(list(bit_pattern)):
value += pval*(2**-(i+1))
if evalue == 0:
# subnormals
value = value*2**-(bias-1)
else:
# normals
value = value*2**-(evalue-bias-2)
values.append(value)
if signed:
values.append(-value)


assert len(values) == 2**total_bits
values.sort()
if total_bits < 8:
gap = 256 - len(values)
for i in range(gap):
values.append(0)
values.sort()
code = torch.Tensor(values)
code /= code.max()

return code


def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
"""
Creates the dynamic quantiztion map.

@@ -157,41 +220,58 @@ def create_dynamic_map(signed=True, n=7):
# these are additional items that come from the case
# where all the exponent bits are zero and no
# indicator bit is present
additional_items = 2 ** (7 - n) - 1
non_sign_bits = total_bits - (1 if signed else 0)
additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
if not signed:
additional_items = 2 * additional_items
for i in range(n):
fraction_items = (
2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
)
for i in range(max_exponent_bits):
fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1))
boundaries = torch.linspace(0.1, 1, fraction_items)
means = (boundaries[:-1] + boundaries[1:]) / 2.0
data += ((10 ** (-(n - 1) + i)) * means).tolist()
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
if signed:
data += (-(10 ** (-(n - 1) + i)) * means).tolist()
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()

if additional_items > 0:
boundaries = torch.linspace(0.1, 1, additional_items + 1)
means = (boundaries[:-1] + boundaries[1:]) / 2.0
data += ((10 ** (-(n - 1) + i)) * means).tolist()
if signed:
data += (-(10 ** (-(n - 1) + i)) * means).tolist()
if additional_items > 0:
boundaries = torch.linspace(0.1, 1, additional_items + 1)
means = (boundaries[:-1] + boundaries[1:]) / 2.0
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
if signed:
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()

data.append(0)
data.append(1.0)

gap = 256 - len(data)
for i in range(gap):
data.append(0)

data.sort()
return Tensor(data)

def create_quantile_map(A, total_bits=8):
q = estimate_quantiles(A, num_quantiles=2**total_bits-1)
q = q.tolist()
q.append(0)

gap = 256 - len(q)
for i in range(gap):
q.append(0)

q.sort()

q = Tensor(q)
q = q/q.abs().max()
return q

def get_special_format_str():
if not torch.cuda.is_available(): return 'col_turing'
major, minor = torch.cuda.get_device_capability()
major, _minor = torch.cuda.get_device_capability()
if major <= 7:
return "col_turing"
elif major == 8:
if major == 8:
return "col_ampere"
else:
return "col_turing"
return "col_turing"
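The map builders above (`create_linear_map`, `create_fp8_map`, `create_dynamic_map`, `create_quantile_map`) all produce 256-entry codebooks, padding with zeros when fewer than 8 bits are simulated. A hedged sketch of how such a codebook might be inspected, assuming the post-merge `bitsandbytes.functional` API shown in this diff:

```python
import bitsandbytes.functional as F

# Each map is a 256-entry codebook normalized so its largest magnitude is 1.0.
dynamic_code = F.create_dynamic_map(signed=True)
fp8_code = F.create_fp8_map(signed=True, exponent_bits=5, precision_bits=2)

assert dynamic_code.numel() == 256 and fp8_code.numel() == 256
print(dynamic_code.min().item(), dynamic_code.max().item())  # roughly -1.0 and 1.0
```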
@@ -318,16 +398,12 @@ def nvidia_transform(
dim2 = ct.c_int32(shape[2])

ptr = CUBLAS_Context.get_instance().get_context(A.device)
ptrA = get_ptr(A)
ptrOut = get_ptr(out)
func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)

return out, new_state


def estimate_quantiles(
A: Tensor, out: Tensor = None, offset: float = 1 / 512
) -> Tensor:
def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
'''
Estimates 256 equidistant quantiles on the input tensor eCDF.

@@ -347,25 +423,37 @@ def estimate_quantiles(
out : torch.Tensor
Tensor with the 256 estimated quantiles.
offset : float
The offset for the first and last quantile from 0 and 1. Default: 1/512
The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
num_quantiles : int
The number of equally spaced quantiles.

Returns
-------
torch.Tensor:
The 256 quantiles in float32 datatype.
'''
if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
if num_quantiles < 256 and offset == 1/(512):
# override default arguments
offset = 1/(2*num_quantiles)

if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device)
is_on_gpu([A, out])
device = pre_call(A.device)
if A.dtype == torch.float32:
lib.cestimate_quantiles_fp32(
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
)
lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
elif A.dtype == torch.float16:
lib.cestimate_quantiles_fp16(
get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
)
lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
else:
raise NotImplementedError(f"Not supported data type {A.dtype}")
post_call(device)

if num_quantiles < 256:
step = round(256/num_quantiles)
idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
out = out[idx]

return out
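The docstring above introduces the `num_quantiles` argument of `estimate_quantiles`. A hedged usage sketch (the tensor contents are illustrative, and a CUDA device is assumed since the estimation kernel runs on GPU):

```python
import torch
import bitsandbytes.functional as F

A = torch.randn(4096, device="cuda")
# 255 equally spaced quantiles, e.g. to seed an 8-bit quantile codebook as in create_quantile_map
q = F.estimate_quantiles(A, num_quantiles=255)
print(q.shape)  # torch.Size([255])
```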
@@ -398,15 +486,14 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
The quantization state to undo the quantization.
"""

if code is None:
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
code = name2qmap["dynamic"]
code = code.to(A.device)

if absmax is None:
n = A.numel()
blocksize = (blocksize if A.device.type == 'cpu' else 4096)
blocks = n // blocksize
blocks += 1 if n % blocksize > 0 else 0
absmax = torch.zeros((blocks,), device=A.device)

@@ -415,8 +502,13 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
out = torch.zeros_like(A, dtype=torch.uint8)

if A.device.type != 'cpu':
is_on_gpu([code, A, absmax, out, rand])
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
cblocksize = ct.c_int32(blocksize)
prev_device = pre_call(A.device)
code = code.to(A.device)
if rand is not None:
is_on_gpu([code, A, out, absmax, rand])
assert blocksize==4096
assert rand.numel() >= 1024
rand_offset = random.randint(0, 1023)
if A.dtype == torch.float32:

@@ -424,20 +516,19 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
elif A.dtype == torch.float16:
lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
else:
raise ValueError(
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
)
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
else:
is_on_gpu([code, A, out, absmax])
if A.dtype == torch.float32:
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
elif A.dtype == torch.float16:
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
else:
raise ValueError(
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
)
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)
else:
# cpu
code = code.cpu()
assert rand is None
lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))

@@ -482,27 +573,30 @@ def dequantize_blockwise(
if "dynamic" not in name2qmap:
name2qmap["dynamic"] = create_dynamic_map().to(A.device)
code = name2qmap["dynamic"]
code = code.to(A.device)

if out is None:
out = torch.zeros_like(A, dtype=torch.float32)
if quant_state is None:
quant_state = (absmax, code)
else:
absmax, code = quant_state


if A.device.type != 'cpu':
if blocksize not in [2048, 4096]:
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]")
device = pre_call(A.device)
code = code.to(A.device)
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
is_on_gpu([A, out])
if out.dtype == torch.float32:
lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
elif out.dtype == torch.float16:
lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
else:
raise ValueError(
f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
)
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)
else:
code = code.cpu()
lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))

return out
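A hedged round-trip sketch of the blockwise API shown in the two hunks above. The tensor size is illustrative, a CUDA device is assumed, and the sketch assumes the post-merge signatures where `quantize_blockwise` returns the uint8 codes together with an `(absmax, code)` state tuple:

```python
import torch
import bitsandbytes.functional as F

x = torch.randn(4096, device="cuda", dtype=torch.float32)
q, state = F.quantize_blockwise(x)        # uint8 codes + (absmax, code) quantization state
x_hat = F.dequantize_blockwise(q, state)  # approximate float32 reconstruction
print((x - x_hat).abs().max().item())     # small blockwise quantization error
```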
@@ -958,7 +1052,7 @@ def histogram_scatter_add_2d(

maxdim1 = ct.c_int32(histogram.shape[0])
n = ct.c_int32(index1.numel())
is_on_gpu([histogram, index1, index2d, source])
is_on_gpu([histogram, index1, index2, source])
lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)

def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):

@@ -1133,7 +1227,7 @@ def igemm(
ptr = CUBLAS_Context.get_instance().get_context(A.device)

# B^T @ A^T = C^T
# [km, nk -> mn]
is_on_gpu([B, A, out])
lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))

@@ -1417,7 +1511,7 @@ def get_colrow_absmax(
return row_stats, col_stats, nnz_block_ptr


class COOSparseTensor(object):
class COOSparseTensor:
def __init__(self, rows, cols, nnz, rowidx, colidx, values):
assert rowidx.dtype == torch.int32
assert colidx.dtype == torch.int32

@@ -1434,7 +1528,7 @@ class COOSparseTensor(object):
self.values = values


class CSRSparseTensor(object):
class CSRSparseTensor:
def __init__(self, rows, cols, nnz, rowptr, colidx, values):
assert rowptr.dtype == torch.int32
assert colidx.dtype == torch.int32

@@ -1451,7 +1545,7 @@ class CSRSparseTensor(object):
self.values = values


class CSCSparseTensor(object):
class CSCSparseTensor:
def __init__(self, rows, cols, nnz, colptr, rowidx, values):
assert colptr.dtype == torch.int32
assert rowidx.dtype == torch.int32

@@ -1615,8 +1709,6 @@ def transform(A, to_order, from_order='row', out=None, transpose=False, state=No
dim1 = ct.c_int32(shape[0] * shape[1])
dim2 = ct.c_int32(shape[2])

ptrA = get_ptr(A)
ptrOut = get_ptr(out)
is_on_gpu([A, out])
if to_order == 'col32':
if transpose:
@@ -2,24 +2,11 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import (
Any,
Callable,
Dict,
Iterator,
Mapping,
Optional,
Set,
Tuple,
TypeVar,
Union,
overload,
)
from typing import Optional, TypeVar, Union, overload

import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
from torch.nn.parameter import Parameter

import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager

@@ -39,7 +26,7 @@ class StableEmbedding(torch.nn.Embedding):
sparse: bool = False,
_weight: Optional[Tensor] = None,
) -> None:
super(StableEmbedding, self).__init__(
super().__init__(
num_embeddings,
embedding_dim,
padding_idx,

@@ -96,7 +83,7 @@ class Embedding(torch.nn.Embedding):
sparse: bool = False,
_weight: Optional[Tensor] = None,
) -> None:
super(Embedding, self).__init__(
super().__init__(
num_embeddings,
embedding_dim,
padding_idx,

@@ -225,7 +212,7 @@ class Linear8bitLt(nn.Linear):
threshold=0.0,
index=None,
):
super(Linear8bitLt, self).__init__(
super().__init__(
input_features, output_features, bias
)
self.state = bnb.MatmulLtState()

@@ -267,7 +254,7 @@ class Linear8bitLt(nn.Linear):
self.weight.data = self.state.CxB
elif self.state.memory_efficient_backward and self.state.CxB is not None:
# For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
# Thus, we delete CxB from the state.
del self.state.CxB

return out
@@ -5,12 +5,11 @@

from bitsandbytes.cextension import COMPILED_WITH_CUDA

from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .adam import Adam, Adam8bit, Adam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit
from .sgd import SGD, SGD8bit, SGD32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lamb import LAMB, LAMB8bit, LAMB32bit
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .optimizer import GlobalOptimManager
@@ -21,18 +21,18 @@ class Adagrad(Optimizer1State):
block_wise=True,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
f"Invalid weight_decay value: {weight_decay}"
)
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0:
raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
raise ValueError("Lr Decay != 0.0 not supported!")
super(Adagrad, self).__init__(
super().__init__(
"adagrad",
params,
lr,

@@ -63,19 +63,19 @@ class Adagrad8bit(Optimizer1State):
block_wise=True,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
f"Invalid weight_decay value: {weight_decay}"
)
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0:
raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
raise ValueError("Lr Decay != 0.0 not supported!")
assert block_wise
super(Adagrad8bit, self).__init__(
super().__init__(
"adagrad",
params,
lr,

@@ -106,18 +106,18 @@ class Adagrad32bit(Optimizer1State):
block_wise=True,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
f"Invalid weight_decay value: {weight_decay}"
)
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
raise ValueError(f"Invalid epsilon value: {eps}")
if initial_accumulator_value != 0.0:
raise ValueError("Initial accumulator value != 0.0 not supported!")
if lr_decay != 0.0:
raise ValueError("Lr Decay != 0.0 not supported!")
super(Adagrad32bit, self).__init__(
super().__init__(
"adagrad",
params,
lr,
@@ -28,7 +28,7 @@ class Adam(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
super(Adam, self).__init__(
super().__init__(
"adam",
params,
lr,

@@ -57,7 +57,7 @@ class Adam8bit(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
super(Adam8bit, self).__init__(
super().__init__(
"adam",
params,
lr,

@@ -86,7 +86,7 @@ class Adam32bit(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
super(Adam32bit, self).__init__(
super().__init__(
"adam",
params,
lr,

@@ -146,7 +146,7 @@ class AnalysisAdam(torch.optim.Optimizer):
weight_decay=weight_decay,
amsgrad=amsgrad,
)
super(AnalysisAdam, self).__init__(params, defaults)
super().__init__(params, defaults)
self.analysis = bnb_analysis
self.savedir = savedir
@@ -20,7 +20,7 @@ class AdamW(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
super(AdamW, self).__init__(
super().__init__(
"adam",
params,
lr,

@@ -49,7 +49,7 @@ class AdamW8bit(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
super(AdamW8bit, self).__init__(
super().__init__(
"adam",
params,
lr,

@@ -78,7 +78,7 @@ class AdamW32bit(Optimizer2State):
percentile_clipping=100,
block_wise=True,
):
super(AdamW32bit, self).__init__(
super().__init__(
"adam",
params,
lr,
@@ -23,7 +23,7 @@ class LAMB(Optimizer2State):
block_wise=False,
max_unorm=1.0,
):
super(LAMB, self).__init__(
super().__init__(
"lamb",
params,
lr,

@@ -56,7 +56,7 @@ class LAMB8bit(Optimizer2State):
block_wise=False,
max_unorm=1.0,
):
super(LAMB8bit, self).__init__(
super().__init__(
"lamb",
params,
lr,

@@ -89,7 +89,7 @@ class LAMB32bit(Optimizer2State):
block_wise=False,
max_unorm=1.0,
):
super(LAMB32bit, self).__init__(
super().__init__(
"lamb",
params,
lr,
@@ -25,9 +25,9 @@ class LARS(Optimizer1State):
):
if momentum == 0:
raise NotImplementedError(
f"LARS without momentum is not supported!"
"LARS without momentum is not supported!"
)
super(LARS, self).__init__(
super().__init__(
"lars",
params,
lr,

@@ -59,9 +59,9 @@ class LARS8bit(Optimizer1State):
):
if momentum == 0:
raise NotImplementedError(
f"LARS without momentum is not supported!"
"LARS without momentum is not supported!"
)
super(LARS8bit, self).__init__(
super().__init__(
"lars",
params,
lr,

@@ -93,9 +93,9 @@ class LARS32bit(Optimizer1State):
):
if momentum == 0:
raise NotImplementedError(
f"LARS without momentum is not supported!"
"LARS without momentum is not supported!"
)
super(LARS32bit, self).__init__(
super().__init__(
"lars",
params,
lr,

@@ -123,12 +123,12 @@ class PytorchLARS(Optimizer):
max_unorm=0.02,
):
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError(f"Invalid learning rate: {lr}")
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
raise ValueError(f"Invalid momentum value: {momentum}")
if weight_decay < 0.0:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
f"Invalid weight_decay value: {weight_decay}"
)

defaults = dict(

@@ -143,10 +143,10 @@ class PytorchLARS(Optimizer):
raise ValueError(
"Nesterov momentum requires a momentum and zero dampening"
)
super(PytorchLARS, self).__init__(params, defaults)
super().__init__(params, defaults)

def __setstate__(self, state):
super(PytorchLARS, self).__setstate__(state)
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("nesterov", False)

@@ -181,7 +181,7 @@ class PytorchLARS(Optimizer):
state = self.state[p]
d_p = p.grad
if weight_decay != 0:
d_p = d_p.add(param, alpha=weight_decay)
d_p = d_p.add(p, alpha=weight_decay)

if momentum != 0:
buf = state.get("momentum_buffer", None)
@@ -12,13 +12,13 @@ import torch
import bitsandbytes.functional as F


class MockArgs(object):
class MockArgs:
def __init__(self, initial_data):
for key in initial_data:
setattr(self, key, initial_data[key])


class GlobalOptimManager(object):
class GlobalOptimManager:
_instance = None

def __init__(self):

@@ -56,9 +56,9 @@ class GlobalOptimManager(object):
"""
Overrides initial optimizer config for specific parameters.

The key-values of the optimizer config for the input parameters are overidden
The key-values of the optimizer config for the input parameters are overridden
This can be both, optimizer parameters like "betas", or "lr" or it can be
8-bit specific paramters like "optim_bits", "percentile_clipping".
8-bit specific parameters like "optim_bits", "percentile_clipping".

Parameters
----------
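As a hedged illustration of the `override_config` docstring above: the model below is a placeholder, and the sketch assumes the dict form of the call together with an 8-bit optimizer created from the same parameters.

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 2))

mng = bnb.optim.GlobalOptimManager.get_instance()
mng.register_parameters(model.parameters())
# Override both an optimizer parameter and an 8-bit-specific parameter for one weight.
mng.override_config(model[0].weight, key_value_dict={"optim_bits": 32, "percentile_clipping": 5})

adam = bnb.optim.Adam8bit(model.parameters())
```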
@@ -93,13 +93,12 @@ class GlobalOptimManager(object):

class Optimizer8bit(torch.optim.Optimizer):
def __init__(self, params, defaults, optim_bits=32):
super(Optimizer8bit, self).__init__(params, defaults)
super().__init__(params, defaults)
self.initialized = False
self.name2qmap = {}

self.mng = GlobalOptimManager.get_instance()
self.non_castable_tensor_keys = set(
[
self.non_castable_tensor_keys = {
"qmap1",
"qmap2",
"max1",

@@ -112,8 +111,7 @@ class Optimizer8bit(torch.optim.Optimizer):
"absmax1",
"absmax2",
"unorm_vec",
]
)
}

if optim_bits == 8:
self.fill_qmap()

@@ -123,7 +121,7 @@ class Optimizer8bit(torch.optim.Optimizer):
self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)

def __setstate__(self, state):
super(Optimizer8bit, self).__setstate__(state)
super().__setstate__(state)

def load_state_dict(self, state_dict):
r"""Loads the optimizer state.

@@ -155,8 +153,8 @@ class Optimizer8bit(torch.optim.Optimizer):
id_map = {
old_id: p
for old_id, p in zip(
chain.from_iterable((g["params"] for g in saved_groups)),
chain.from_iterable((g["params"] for g in groups)),
chain.from_iterable(g["params"] for g in saved_groups),
chain.from_iterable(g["params"] for g in groups),
)
}

@@ -284,11 +282,11 @@ class Optimizer8bit(torch.optim.Optimizer):
return config

def init_state(self, group, p, gindex, pindex):
raise NotImplementedError(f"init_state method needs to be overidden")
raise NotImplementedError("init_state method needs to be overridden")

def update_step(self, group, p, gindex, pindex):
raise NotImplementedError(
f"The update_step method needs to be overidden"
"The update_step method needs to be overridden"
)


@@ -310,9 +308,9 @@ class Optimizer2State(Optimizer8bit):
skip_zeros=False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
raise ValueError(f"Invalid epsilon value: {eps}")
if isinstance(betas, str):
# format: '(beta1, beta2)'
betas = betas.replace("(", "").replace(")", "").strip().split(",")

@@ -324,10 +322,10 @@ class Optimizer2State(Optimizer8bit):
)
if not 0.0 <= weight_decay:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
f"Invalid weight_decay value: {weight_decay}"
)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer2State, self).__init__(params, defaults, optim_bits)
super().__init__(params, defaults, optim_bits)

if args is None:
args = {}

@@ -542,9 +540,9 @@ class Optimizer1State(Optimizer8bit):
skip_zeros=False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
raise ValueError(f"Invalid epsilon value: {eps}")
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
raise ValueError(

@@ -552,10 +550,10 @@ class Optimizer1State(Optimizer8bit):
)
if not 0.0 <= weight_decay:
raise ValueError(
"Invalid weight_decay value: {}".format(weight_decay)
f"Invalid weight_decay value: {weight_decay}"
)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer1State, self).__init__(params, defaults, optim_bits)
super().__init__(params, defaults, optim_bits)

if args is None:
args = {}
@@ -23,11 +23,11 @@ class RMSprop(Optimizer1State):
):
if alpha == 0:
raise NotImplementedError(
f"RMSprop with alpha==0.0 is not supported!"
"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop, self).__init__(
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__(
"rmsprop",
params,
lr,

@@ -59,11 +59,11 @@ class RMSprop8bit(Optimizer1State):
):
if alpha == 0:
raise NotImplementedError(
f"RMSprop with alpha==0.0 is not supported!"
"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop8bit, self).__init__(
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__(
"rmsprop",
params,
lr,

@@ -96,11 +96,11 @@ class RMSprop32bit(Optimizer1State):

if alpha == 0:
raise NotImplementedError(
f"RMSprop with alpha==0.0 is not supported!"
"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop32bit, self).__init__(
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__(
"rmsprop",
params,
lr,
@@ -21,8 +21,8 @@ class SGD(Optimizer1State):
block_wise=True,
):
if momentum == 0:
raise NotImplementedError(f"SGD without momentum is not supported!")
super(SGD, self).__init__(
raise NotImplementedError("SGD without momentum is not supported!")
super().__init__(
"momentum",
params,
lr,

@@ -52,8 +52,8 @@ class SGD8bit(Optimizer1State):
block_wise=True,
):
if momentum == 0:
raise NotImplementedError(f"SGD without momentum is not supported!")
super(SGD8bit, self).__init__(
raise NotImplementedError("SGD without momentum is not supported!")
super().__init__(
"momentum",
params,
lr,

@@ -83,8 +83,8 @@ class SGD32bit(Optimizer1State):
block_wise=True,
):
if momentum == 0:
raise NotImplementedError(f"SGD without momentum is not supported!")
super(SGD32bit, self).__init__(
raise NotImplementedError("SGD without momentum is not supported!")
super().__init__(
"momentum",
params,
lr,
@@ -4,7 +4,7 @@ Basic steps.
1. `make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cpuonly`
2. `CUDA_VERSION=XXX python setup.py install`

To run these steps you will need to have the nvcc compiler installed that comes with a CUDA installation. If you use anaconda (recommended) then you can figure out which version of CUDA you are using with PyTorch via the command `conda list | grep cudatoolkit`. Then you can install the nvcc compiler by downloading and installing the same CUDA version from the [CUDA toolkit archive](https://developer.nvidia.com/cuda-toolkit-archive).
||||
For your convenience, there is an installation script in the root directory that installs CUDA 11.1 locally and configures it automatically. After installing you should add the `bin` sub-directory to the `$PATH` variable to make the compiler visible to your system. To do this you can add this to your `.bashrc` by executing these commands:
|
||||
```bash
|
||||
|
@ -13,7 +13,7 @@ echo "export PATH=$PATH:/usr/local/cuda/bin/" >> ~/.bashrc
|
|||
source ~/.bashrc
|
||||
```
|
||||
|
||||
By default, the Makefile will look at your `CUDA_HOME` environment variable to find your CUDA version for compiling the library. If this path is not set it is inferred from the path of your `nvcc` compiler.
|
||||
By default, the Makefile will look at your `CUDA_HOME` environment variable to find your CUDA version for compiling the library. If this path is not set it is inferred from the path of your `nvcc` compiler.
|
||||
|
||||
Either `nvcc` needs to be on the `PATH` or the `CUDA_HOME` variable needs to be set to the CUDA root directory (e.g. `/usr/local/cuda`) in order for compilation to succeed.
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
|
|||
|
||||
for (int i = 0; i < valid_chunks; i++)
|
||||
int err = pthread_join(threads[i], NULL);
|
||||
|
||||
|
||||
free(threads);
|
||||
for (int i = 0; i < valid_chunks; i++)
|
||||
free(args[i]);
|
||||
|
|
160
csrc/kernels.cu
160
csrc/kernels.cu
|
@ -1,6 +1,6 @@
|
|||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
//
|
||||
// This source code is licensed under the MIT license found in the
|
||||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
//
|
||||
// This source code is licensed under the MIT license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#include <kernels.cuh>
|
||||
|
@ -303,7 +303,7 @@ __global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* ou
|
|||
if(threadIdx.x % 32 < 8)
|
||||
{
|
||||
// offset: 8 values per 256 input values
|
||||
//
|
||||
//
|
||||
int offset = BLOCK_SIZE*blockIdx.x*BLOCK_SIZE/32*8;
|
||||
}
|
||||
|
||||
|
@ -428,16 +428,16 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
|
|||
}
|
||||
|
||||
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC>
|
||||
__launch_bounds__(TH, 4)
|
||||
//__launch_bounds__(TH, 4)
|
||||
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
|
||||
{
|
||||
const int n_full = gridDim.x * BLOCK_SIZE;
|
||||
int valid_items = 0;
|
||||
const int base_idx = (blockIdx.x * BLOCK_SIZE);
|
||||
|
||||
T vals[NUM];
|
||||
float rand_vals[NUM];
|
||||
unsigned char qvals[NUM];
|
||||
T vals[NUM_PER_TH];
|
||||
float rand_vals[NUM_PER_TH];
|
||||
unsigned char qvals[NUM_PER_TH];
|
||||
//float local_abs_max = -FLT_MAX;
|
||||
float local_abs_max = 0.0f;
|
||||
int local_rand_idx = 0;
|
||||
|
@ -454,8 +454,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
|
|||
__shared__ float smem_code[256];
|
||||
__shared__ float smem_absmax_value[1];
|
||||
|
||||
if(threadIdx.x < 256)
|
||||
smem_code[threadIdx.x] = code[threadIdx.x];
|
||||
for(int i = threadIdx.x; i < 256; i+=blockDim.x)
|
||||
smem_code[i] = code[i];
|
||||
|
||||
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
|
||||
{
|
||||
|
@ -510,15 +510,15 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
|
|||
}
|
||||
|
||||
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH>
|
||||
__global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n)
|
||||
__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n)
|
||||
{
|
||||
|
||||
const int n_full = gridDim.x * BLOCK_SIZE;
|
||||
int valid_items = 0;
|
||||
const int base_idx = (blockIdx.x * BLOCK_SIZE);
|
||||
|
||||
T vals[NUM];
|
||||
unsigned char qvals[NUM];
|
||||
T vals[NUM_PER_TH];
|
||||
unsigned char qvals[NUM_PER_TH];
|
||||
float local_abs_max = -FLT_MAX;
|
||||
|
||||
typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
|
||||
|
@ -526,10 +526,11 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c
|
|||
|
||||
__shared__ typename LoadChar::TempStorage loadchar;
|
||||
__shared__ typename StoreT::TempStorage storet;
|
||||
__shared__ float smem_code[256];
|
||||
//__shared__ float smem_code[256];
|
||||
//float local_code[16];
|
||||
|
||||
if(threadIdx.x < 256)
|
||||
smem_code[threadIdx.x] = code[threadIdx.x];
|
||||
//if(threadIdx.x < 256)
|
||||
//smem_code[threadIdx.x] = code[threadIdx.x];
|
||||
|
||||
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
|
||||
{
|
||||
|
@ -539,9 +540,10 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ c
|
|||
__syncthreads();
|
||||
LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128);
|
||||
|
||||
// load code through read-only cache via __ldg
|
||||
#pragma unroll NUM_PER_TH
|
||||
for(int j = 0; j < NUM_PER_TH; j++)
|
||||
vals[j] = smem_code[qvals[j]]*local_abs_max;
|
||||
vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
|
||||
|
||||
__syncthreads();
|
||||
StoreT(storet).Store(&(out[i]), vals, valid_items);
|
||||
|
@ -572,7 +574,7 @@ __global__ void kDequantize(float *code, unsigned char *A, float *out, const int
|
|||
|
||||
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
|
||||
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
|
||||
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
|
||||
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
|
||||
float* state1, float* state2, float *unorm,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, const int n)
|
||||
|
@ -620,7 +622,7 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
|
|||
{
|
||||
switch(OPTIMIZER)
|
||||
{
|
||||
case ADAM:
|
||||
case ADAM:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
|
||||
s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
|
||||
s1_vals[j] *= correction1;
|
||||
|
@ -651,7 +653,7 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
|
|||
|
||||
template<typename T, int OPTIMIZER>
|
||||
__launch_bounds__(TH, 1)
|
||||
__global__ void kOptimizer32bit2State(T* g, T* p,
|
||||
__global__ void kOptimizer32bit2State(T* g, T* p,
|
||||
float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
|
||||
|
@ -714,7 +716,7 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
|
|||
{
|
||||
switch(OPTIMIZER)
|
||||
{
|
||||
case ADAM:
|
||||
case ADAM:
|
||||
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
|
||||
{
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
|
||||
|
@ -739,7 +741,7 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
|
|||
|
||||
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
|
||||
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
|
||||
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
|
||||
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
|
||||
float* state1, float *unorm,
|
||||
const float beta1, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, const int n)
|
||||
|
@ -781,19 +783,19 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
|
|||
{
|
||||
switch(OPTIMIZER)
|
||||
{
|
||||
case MOMENTUM:
|
||||
case MOMENTUM:
|
||||
if(step == 1)
|
||||
s1_vals[j] = (float)g_vals[j]; // state update
|
||||
else
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update
|
||||
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
|
||||
break;
|
||||
case RMSPROP:
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
|
||||
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
|
||||
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
|
||||
break;
|
||||
case ADAGRAD:
|
||||
case ADAGRAD:
|
||||
s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update
|
||||
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
|
||||
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
|
||||
|
@ -817,7 +819,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
|
|||
|
||||
template<typename T, int OPTIMIZER>
|
||||
__launch_bounds__(TH, 1)
|
||||
__global__ void kOptimizer32bit1State(T *g, T *p,
|
||||
__global__ void kOptimizer32bit1State(T *g, T *p,
|
||||
float *state1, float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
|
||||
|
@ -880,7 +882,7 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
|
|||
{
|
||||
switch(OPTIMIZER)
|
||||
{
|
||||
case MOMENTUM:
|
||||
case MOMENTUM:
|
||||
if(step == 1)
|
||||
s1_vals[j] = (float)g_vals[j];
|
||||
else
|
||||
|
@ -888,11 +890,11 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
|
|||
|
||||
p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
|
||||
break;
|
||||
case RMSPROP:
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
|
||||
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
|
||||
break;
|
||||
case ADAGRAD:
|
||||
case ADAGRAD:
|
||||
s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]);
|
||||
p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
|
||||
break;
|
||||
|
@ -1154,12 +1156,12 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
|
|||
template<typename T, int OPTIMIZER>
|
||||
__global__ void
|
||||
__launch_bounds__(NUM_THREADS, 2)
|
||||
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
|
||||
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
|
||||
float *unorm,
|
||||
const float beta1,
|
||||
const float beta1,
|
||||
const float eps, const int step,
|
||||
float* __restrict__ const quantiles1,
|
||||
float* max1, float* new_max1,
|
||||
float* __restrict__ const quantiles1,
|
||||
float* max1, float* new_max1,
|
||||
const float weight_decay,
|
||||
const float gnorm_scale, const int n)
|
||||
{
|
||||
|
@ -1209,7 +1211,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
|
|||
s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0];
|
||||
switch(OPTIMIZER)
|
||||
{
|
||||
case MOMENTUM:
|
||||
case MOMENTUM:
|
||||
if(step == 1)
|
||||
s1_vals[j] = (float)g_vals[j];
|
||||
else
|
||||
|
@ -1217,7 +1219,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
|
|||
if(unorm != NULL)
|
||||
local_unorm += s1_vals[j]*s1_vals[j];
|
||||
break;
|
||||
case RMSPROP:
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
|
||||
break;
|
||||
}
|
||||
|
@ -1242,10 +1244,10 @@ template<typename T, int OPTIMIZER>
|
|||
__global__ void
|
||||
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
||||
const float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1,
|
||||
const float beta1,
|
||||
const float eps, const int step, const float lr,
|
||||
float* __restrict__ const quantiles1,
|
||||
float* max1, float* new_max1,
|
||||
float* __restrict__ const quantiles1,
|
||||
float* max1, float* new_max1,
|
||||
float weight_decay,
|
||||
const float gnorm_scale, const int n)
|
||||
{
|
||||
|
@ -1311,7 +1313,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
|||
|
||||
switch(OPTIMIZER)
|
||||
{
|
||||
case MOMENTUM:
|
||||
case MOMENTUM:
|
||||
if(step == 1)
|
||||
s1_vals[j] = g_vals[j];
|
||||
else
|
||||
|
@ -1319,7 +1321,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
|||
|
||||
p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
|
||||
break;
|
||||
case RMSPROP:
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
|
||||
p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
|
||||
break;
|
||||
|
@ -1399,7 +1401,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
const float beta1, const float beta2,
|
||||
const float eps, const int step, const float lr,
|
||||
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
|
||||
float* absmax1, float* absmax2,
|
||||
float* absmax1, float* absmax2,
|
||||
float weight_decay,
|
||||
const float gnorm_scale, const bool skip_zeros, const int n)
|
||||
{
|
||||
|
@ -1543,7 +1545,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
StoreT(temp_storage.storeh).Store(&(p[i]), g_vals, valid_items);
|
||||
|
||||
// quantization: 2.67/1.70 -> 3.4/3.3
|
||||
# pragma unroll N_PER_TH
|
||||
# pragma unroll N_PER_TH
|
||||
for(unsigned int j = 0; j < N_PER_TH; j++)
|
||||
{
|
||||
c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1));
|
||||
|
@ -1656,16 +1658,16 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
|
||||
switch(OPTIMIZER)
|
||||
{
|
||||
case MOMENTUM:
|
||||
case MOMENTUM:
|
||||
if(step == 1)
|
||||
s1_vals[j] = g_val;
|
||||
else
|
||||
s1_vals[j] = (s1_vals[j]*beta1) + g_val;
|
||||
break;
|
||||
case RMSPROP:
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
|
||||
break;
|
||||
case ADAGRAD:
|
||||
case ADAGRAD:
|
||||
s1_vals[j] = s1_vals[j] + (g_val*g_val);
|
||||
break;
|
||||
}
|
||||
|
@ -1696,14 +1698,14 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
{
|
||||
switch(OPTIMIZER)
|
||||
{
|
||||
case MOMENTUM:
|
||||
case MOMENTUM:
|
||||
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
|
||||
break;
|
||||
case RMSPROP:
|
||||
case RMSPROP:
|
||||
g_val = g_vals[j];
|
||||
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
|
||||
break;
|
||||
case ADAGRAD:
|
||||
case ADAGRAD:
|
||||
g_val = g_vals[j];
|
||||
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
|
||||
break;
|
||||
|
@ -1716,7 +1718,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
|
||||
|
||||
// quantization: 2.67/1.70 -> 3.4/3.3
|
||||
# pragma unroll N_PER_TH
|
||||
# pragma unroll N_PER_TH
|
||||
for(unsigned int j = 0; j < N_PER_TH; j++)
|
||||
{
|
||||
c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1));
|
||||
|
@ -1893,9 +1895,9 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
|
|||
{
|
||||
|
||||
// Strategy: To dequantize we need to load col/row statistics. This can be very expensive
|
||||
// since different row/col stats need to be loaded with each thread.
|
||||
// since different row/col stats need to be loaded with each thread.
|
||||
// (1, bad algorithm) Loading 32 items per thread would only incur 1 row load, but this increases register pressure
|
||||
// and would lead to low global load utilization.
|
||||
// and would lead to low global load utilization.
|
||||
// (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do a lot of row loads
|
||||
// for each thread and this is duplicated by a factor of 32/num-cols-per-thread.
|
||||
// (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock.
|
||||
|
@ -1903,7 +1905,7 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
|
|||
// We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has
|
||||
// the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts
|
||||
// we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the
|
||||
// shared memory loads.
|
||||
// shared memory loads.
|
||||
|
||||
// data is in 32 column-tile major with tile width 32 columns and numRows rows
|
||||
// L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
|
||||
|
@ -2140,7 +2142,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
|
|||
|
||||
|
||||
// To have efficient loads and stores if we transpose we need 128 consecutive bytes which at 1 byte are 128 values
|
||||
// As such we need:
|
||||
// As such we need:
|
||||
// at least 32*4 shared memory tiles for col32; preferably 32*32
|
||||
// at least 32*6 shared memory tiles for col32_ampere: preferably 32*32
|
||||
// at least 32*8 shared memory tiles for col4_turing: preferably 32*32
|
||||
|
@ -2150,7 +2152,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
|
|||
// we have 64k shared mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy
|
||||
// for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough
|
||||
// register pressure should be low with: 8 registers from local memory per block and 64 registers per SM
|
||||
//
|
||||
//
|
||||
// to make the shared memory work with that occupancy we might need to union the block loads/stores
|
||||
|
||||
// each block loads TILE_COLs columns and TILE_ROW rows
|
||||
|
@ -2239,7 +2241,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
|
|||
|
||||
switch(FORMAT)
|
||||
{
|
||||
case COL32:
|
||||
case COL32:
|
||||
if(TRANSPOSE)
|
||||
{
|
||||
// data lies in shared memory in the following way:
|
||||
|
@ -2264,7 +2266,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
|
|||
|
||||
// each 32 columns we have new tile
|
||||
// each tile has size outRows*32 and base_row is done in increments of 32
|
||||
offset = base_row*outRows;
|
||||
offset = base_row*outRows;
|
||||
out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data;
|
||||
}
|
||||
}
|
||||
|
@ -2310,7 +2312,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
|
|||
// we increase by row_tile_column every 32 columns
|
||||
// base_row increase in increments of 32
|
||||
//int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements
|
||||
//int col_offset = (base_row/32)*row_tile_column;
|
||||
//int col_offset = (base_row/32)*row_tile_column;
|
||||
// -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
|
||||
// 256*outRows/8*base_row/32 = outRows*base_row
|
||||
int col_offset = outRows*base_row;
|
||||
|
@ -2347,7 +2349,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
|
|||
// this happens every 8 rows anew (subrow % 8)
|
||||
// one writes 4 columns at once that is (col % 4) for the particular index in the subtile
|
||||
int subcol = warp_lane;
|
||||
|
||||
|
||||
// add local offset (4x4 sub-tile)
|
||||
if(subrow % 2 == 1)
|
||||
// odd
|
||||
|
@ -2387,7 +2389,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
|
|||
// we increase by row_tile_column every 32 columns
|
||||
// base_row increase in increments of 32
|
||||
//int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements
|
||||
//int col_offset = (base_row/32)*row_tile_column;
|
||||
//int col_offset = (base_row/32)*row_tile_column;
|
||||
// -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
|
||||
// 1024*outRows/32*base_row/32 = outRows*base_row
|
||||
int col_offset = outRows*base_row;
|
||||
|
@ -2445,7 +2447,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
|
|||
#define C 1.0f/127.0f
|
||||
#define MAX_SPARSE_COUNT 32
|
||||
#define SMEM_SIZE 8*256
|
||||
template <typename T, int SPMM_ITEMS, int BITS>
|
||||
template <typename T, int SPMM_ITEMS, int BITS>
|
||||
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
|
||||
{
|
||||
|
||||
|
@ -2575,7 +2577,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
|
|||
#pragma unroll num_items
|
||||
for(int k = 0; k < num_items; k++)
|
||||
local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k];
|
||||
|
||||
|
||||
reinterpret_cast<float4*>(out)[idx_val/num_items] = reinterpret_cast<float4(&)[num_items]>(local_valC)[j/num_items];
|
||||
}
|
||||
else
|
||||
|
@ -2589,11 +2591,11 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
|
|||
|
||||
idx_col_B += blockDim.x*SPMM_ITEMS;
|
||||
local_idx_col_B_offset += blockDim.x*SPMM_ITEMS;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA)
|
||||
{
|
||||
{
|
||||
int local_colidx = idx[blockIdx.x];
|
||||
|
||||
if(FORMAT==COL_TURING)
|
||||
|
@ -2653,7 +2655,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
|
|||
out[out_idx] = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//==============================================================
|
||||
// TEMPLATE DEFINITIONS
|
||||
|
@ -2791,11 +2793,33 @@ template __global__ void kQuantizeBlockwise<half, 4096, 4, 0>(float * code, half
|
|||
template __global__ void kQuantizeBlockwise<float, 4096, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 4096, 4, 1>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 4096, 4, 1>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 2048, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 2048, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 1024, 4, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 1024, 4, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 512, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 512, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 256, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 256, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 128, 2, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 128, 2, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<half, 64, 1, 0>(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template __global__ void kQuantizeBlockwise<float, 64, 1, 0>(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
|
||||
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 4096, 1024, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 2048, 512, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 1024, 256, 4>(float *code, unsigned char * A, float * absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 512, 256, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 256, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 256, 128, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 256, 128, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 128, 64, 2>(float *code, unsigned char * A, float * absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 128, 64, 2>(float *code, unsigned char * A, float * absmax, float *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<half, 64, 64, 1>(float *code, unsigned char * A, float * absmax, half *out, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 64, 64, 1>(float *code, unsigned char * A, float * absmax, float *out, const int n);
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
//
|
||||
// This source code is licensed under the MIT license found in the
|
||||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
//
|
||||
// This source code is licensed under the MIT license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#include <float.h>
|
||||
|
@ -15,52 +15,52 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
|
|||
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);
|
||||
|
||||
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n);
|
||||
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
|
||||
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
|
||||
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
|
||||
float* state1, float* state2, float *unorm,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER>
|
||||
__global__ void kOptimizer32bit2State(T* g, T* p,
|
||||
__global__ void kOptimizer32bit2State(T* g, T* p,
|
||||
float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
|
||||
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
|
||||
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
|
||||
float* state1, float *unorm,
|
||||
const float beta1, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER>
|
||||
__global__ void kOptimizer32bit1State(T* g, T* p,
|
||||
__global__ void kOptimizer32bit1State(T* g, T* p,
|
||||
float* state1, float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER>
|
||||
__global__ void
|
||||
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
|
||||
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
|
||||
float *unorm,
|
||||
const float beta1,
|
||||
const float eps, const int step,
|
||||
float* __restrict__ const quantiles1,
|
||||
float* max1, float* new_max1,
|
||||
const float beta1,
|
||||
const float eps, const int step,
|
||||
float* __restrict__ const quantiles1,
|
||||
float* max1, float* new_max1,
|
||||
const float weight_decay,
|
||||
const float gnorm_scale, const int n);
|
||||
|
||||
|
||||
template<typename T, int OPTIMIZER>
|
||||
__global__ void
|
||||
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
||||
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
||||
const float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1,
|
||||
const float eps, const int step, const float lr,
|
||||
float* __restrict__ const quantiles1,
|
||||
float* max1, float* new_max1,
|
||||
const float beta1,
|
||||
const float eps, const int step, const float lr,
|
||||
float* __restrict__ const quantiles1,
|
||||
float* max1, float* new_max1,
|
||||
float weight_decay, const float gnorm_scale, const int n);
|
||||
|
||||
|
||||
|
@ -70,7 +70,7 @@ __global__ void
|
|||
kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2,
|
||||
float *unorm,
|
||||
const float beta1, const float beta2,
|
||||
const float eps, const int step,
|
||||
const float eps, const int step,
|
||||
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
|
||||
float* max1, float* max2, float* new_max1, float* new_max2,
|
||||
const float gnorm_scale, const int n);
|
||||
|
@ -81,7 +81,7 @@ __global__ void
|
|||
kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2,
|
||||
const float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1, const float beta2,
|
||||
const float eps, const int step, const float lr,
|
||||
const float eps, const int step, const float lr,
|
||||
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
|
||||
float* max1, float* max2, float* new_max1, float* new_max2,
|
||||
float weight_decay, const float gnorm_scale, const int n);
|
||||
|
@ -121,5 +121,3 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
|
|||
template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
|
|
57
csrc/ops.cu
57
csrc/ops.cu
|
@ -1,6 +1,6 @@
|
|||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
//
|
||||
// This source code is licensed under the MIT license found in the
|
||||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
//
|
||||
// This source code is licensed under the MIT license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#include <ops.cuh>
|
||||
|
@ -50,11 +50,29 @@ void dequantize(float *code, unsigned char *A, float *out, int n)
|
|||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n)
|
||||
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
|
||||
{
|
||||
int num_blocks = n/4096;
|
||||
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
|
||||
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
int num_blocks = n/blocksize;
|
||||
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
|
||||
if(STOCHASTIC == 1)
|
||||
assert(blocksize == 4096);
|
||||
|
||||
if(blocksize == 4096)
|
||||
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 2048)
|
||||
kQuantizeBlockwise<T, 2048, 4, 0><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 1024)
|
||||
kQuantizeBlockwise<T, 1024, 4, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 512)
|
||||
kQuantizeBlockwise<T, 512, 2, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 256)
|
||||
kQuantizeBlockwise<T, 256, 2, 0><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 128)
|
||||
kQuantizeBlockwise<T, 128, 2, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 64)
|
||||
kQuantizeBlockwise<T, 64, 1, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
|
||||
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
@ -66,6 +84,17 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
|
|||
kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 2048)
|
||||
kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 1024)
|
||||
kDequantizeBlockwise<T, 1024, 256, 4><<<num_blocks, 1024/4>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 512)
|
||||
kDequantizeBlockwise<T, 512, 256, 2><<<num_blocks, 512/2>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 256)
|
||||
kDequantizeBlockwise<T, 256, 128, 2><<<num_blocks, 256/2>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 128)
|
||||
kDequantizeBlockwise<T, 128, 64, 2><<<num_blocks, 128/2>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 64)
|
||||
kDequantizeBlockwise<T, 64, 64, 1><<<num_blocks, 64/1>>>(code, A, absmax, out, n);
|
||||
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
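On the Python side this dispatch surfaces as the new `blocksize` argument. A short sketch mirroring the updated `test_dynamic_blockwise_quantization` further down in this diff (assumes a CUDA device; note that stochastic rounding remains tied to blocksize 4096 per the assert above):

```python
# Sketch of the user-facing blocksize argument added above
# (block sizes taken from the updated test; assumes a CUDA device).
import torch
from bitsandbytes import functional as F

A = torch.randn(1024, 1024, device="cuda")

for blocksize in (4096, 2048, 1024, 512):
    C, S = F.quantize_blockwise(A, blocksize=blocksize)      # uint8 codes + quant state
    A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
    print(blocksize, (A - A2).abs().mean().item())            # mean abs error per blocksize
```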
|
||||
|
||||
|
@ -212,7 +241,7 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in
|
|||
|
||||
}
|
||||
|
||||
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
|
||||
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
|
||||
long long int strideA, long long int strideB, long long int strideC, int batchCount)
|
||||
{
|
||||
const int falpha = 1;
|
||||
|
@ -322,7 +351,7 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans
|
|||
cublasLtOrder_t orderOut = get_order<TARGET>();
|
||||
int ldA = get_leading_dim<SRC>(dim1, dim2);
|
||||
int ldOut = get_leading_dim<TARGET>(dim1, dim2);
|
||||
|
||||
|
||||
cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL;
|
||||
cublasLtMatrixTransformDesc_t A2Out_desc = NULL;
|
||||
cublasOperation_t opTranspose = CUBLAS_OP_T;
|
||||
|
@ -368,7 +397,7 @@ template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHa
|
|||
template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
|
||||
template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
|
||||
|
||||
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
{
|
||||
#ifdef NO_CUBLASLT
|
||||
cout << "" << endl;
|
||||
|
@ -659,10 +688,10 @@ template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows,
|
|||
template void estimateQuantiles(half *A, float *code, float offset, int n);
|
||||
template void estimateQuantiles(float *A, float *code, float offset, int n);
|
||||
|
||||
template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
|
||||
|
|
18
csrc/ops.cuh
18
csrc/ops.cuh
|
@ -1,6 +1,6 @@
|
|||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
//
|
||||
// This source code is licensed under the MIT license found in the
|
||||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
//
|
||||
// This source code is licensed under the MIT license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
|
@ -128,10 +128,10 @@ template <typename T> void estimateQuantiles(T *A, float *code, float offset, in
|
|||
|
||||
void quantize(float *code, float *A, unsigned char *out, int n);
|
||||
void dequantize(float *code, unsigned char *A, float *out, int n);
|
||||
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
||||
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
||||
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
|
||||
float beta1, float beta2, float eps, float weight_decay,
|
||||
int step, float lr, const float gnorm_scale, bool skip_zeros, int n);
|
||||
|
@ -139,15 +139,15 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
|||
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2,
|
||||
float *unorm, float max_unorm, float param_norm,
|
||||
float beta1, float beta2,
|
||||
float eps, int step, float lr,
|
||||
float eps, int step, float lr,
|
||||
float* quantiles1, float* quantiles2,
|
||||
float* max1, float* max2, float* new_max1, float* new_max2,
|
||||
float weight_decay,
|
||||
const float gnorm_scale, int n);
|
||||
|
||||
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
|
||||
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
|
||||
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale,
|
||||
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
|
||||
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale,
|
||||
bool skip_zeros, int n);
|
||||
|
||||
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);
|
||||
|
@ -155,7 +155,7 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
|
|||
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);
|
||||
|
||||
void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
|
||||
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
|
||||
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
|
||||
long long int strideA, long long int strideB, long long int strideC, int batchCount);
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
//
|
||||
// This source code is licensed under the MIT license found in the
|
||||
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||
//
|
||||
// This source code is licensed under the MIT license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#if BUILD_CUDA
|
||||
|
@ -9,7 +9,7 @@
|
|||
#include <cpu_ops.h>
|
||||
|
||||
// We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary.
|
||||
// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but it's better than having to
|
||||
// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but it's better than having to
|
||||
// maintain all that boilerplate
|
||||
//===================================================================================
|
||||
// UNMANGLED CALLS
|
||||
|
@ -75,10 +75,10 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
|
|||
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
|
||||
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
|
||||
|
||||
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, n); }
|
||||
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, n); }
|
||||
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, n); }
|
||||
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, n); }
|
||||
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
|
||||
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
|
||||
|
||||
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
|
||||
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
|
||||
|
@ -140,8 +140,8 @@ extern "C"
|
|||
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
|
||||
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
|
||||
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
|
||||
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, n); }
|
||||
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, n); }
|
||||
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); }
|
||||
void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); }
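Illustrative only: the C wrappers above are what the Python layer calls through ctypes, with the new `blocksize` argument sitting just before `n`. A hedged sketch of invoking one directly (the library path and the stand-in quantization map are assumptions; bitsandbytes normally handles all of this internally):

```python
# Hedged sketch of calling cquantize_blockwise_fp32 directly via ctypes
# (library path is assumed; the code map here is a stand-in, not a real quantization map).
import ctypes as ct
import torch

lib = ct.cdll.LoadLibrary("bitsandbytes/libbitsandbytes_cuda111.so")  # assumed build artifact

def ptr(t):
    # raw device pointer of a torch tensor, as expected by the C wrappers
    return ct.c_void_p(t.data_ptr())

n, blocksize = 4096, 2048
code = torch.linspace(-1.0, 1.0, 256, device="cuda")      # stand-in 8-bit quantization map
A = torch.randn(n, device="cuda", dtype=torch.float32)
absmax = torch.empty(n // blocksize, device="cuda")
out = torch.empty(n, device="cuda", dtype=torch.uint8)

# new signature: code, A, absmax, out, blocksize, n
lib.cquantize_blockwise_fp32(ptr(code), ptr(A), ptr(absmax), ptr(out),
                             ct.c_int32(blocksize), ct.c_int32(n))
torch.cuda.synchronize()
```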
|
||||
|
||||
|
@ -290,4 +290,3 @@ extern "C"
|
|||
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
|
||||
}
|
||||
|
||||
|
|
|
@ -76,6 +76,3 @@ if [[ -n "$CUDA_VERSION" ]]; then
|
|||
else
|
||||
echo ""
|
||||
fi
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -14,16 +14,16 @@ mng.register_parameters(model.parameters()) # 1. register parameters while still
|
|||
|
||||
model = model.cuda()
|
||||
# use 8-bit optimizer states for all parameters
|
||||
adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)
|
||||
adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)
|
||||
|
||||
# 2a. override: the parameter model.fc1.weight now uses 32-bit Adam
|
||||
mng.override_config(model.fc1.weight, 'optim_bits', 32)
|
||||
mng.override_config(model.fc1.weight, 'optim_bits', 32)
|
||||
|
||||
# 2b. override: the two special layers use
|
||||
# sparse optimization + different learning rate + different Adam betas
|
||||
mng.override_config([model.special.weight, model.also_special.weight],
|
||||
key_value_dict={'is_sparse': True, 'lr': 1e-5, 'betas': (0.9, 0.98)})
|
||||
```
|
||||
key_value_dict={'is_sparse': True, 'lr': 1e-5, 'betas': (0.9, 0.98)})
|
||||
```
|
||||
Possible options for the config override are: `betas, eps, weight_decay, lr, optim_bits, min_8bit_size, percentile_clipping, block_wise, max_unorm`
|
||||
|
||||
For overrides for particular layers we recommend overriding locally in each module. You can do this by passing the module, the parameter, and its attribute name to the GlobalOptimManager:
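A sketch of that module-local pattern, assuming the `register_module_override` helper on `GlobalOptimManager` (this is the mechanism `StableEmbedding` uses internally to force 32-bit optimizer state for its weight):

```python
# Sketch of a module-local override via register_module_override
# (assumed helper; config keys follow the override options listed above).
import torch
import bitsandbytes as bnb

class SparseLayer(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(dim, dim))
        # keep this parameter's optimizer state in 32 bits
        bnb.optim.GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )
```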
|
||||
|
|
|
@ -121,7 +121,7 @@ template <unsigned char Gap, typename T>
|
|||
struct DirectTraits<true,Gap,T>
|
||||
{
|
||||
typedef FVec1<SSE, T> fVec1;
|
||||
|
||||
|
||||
static void checkH(T scaler, T H_Times_x0, T xN)
|
||||
{
|
||||
union {
|
||||
|
@ -177,9 +177,9 @@ struct DirectInfo
|
|||
, cst0(fun_t::cst0(H, x[0]))
|
||||
{
|
||||
myassert(((bws != NULL) && (isAligned(bws,64))), "bucket pointer not allocated or incorrectly aligned");
|
||||
|
||||
|
||||
uint32 nb = 1 + fun_t::f(H, cst0, x[n-1]);
|
||||
|
||||
|
||||
const uint32 npad = Gap-1;
|
||||
const uint32 n_sz = n + npad; // size of padded vector
|
||||
|
||||
|
@ -320,7 +320,7 @@ struct DirectInfo
|
|||
T cst0 = fun_t::cst0(H, px[0]);
|
||||
const uint32 maxIndex = fun_t::f(H, cst0, px[n-1]);
|
||||
buckets.resize(maxIndex + 1);
|
||||
|
||||
|
||||
data = Data(px, n, H, buckets.begin(), (npad? xi.begin(): NULL));
|
||||
}
|
||||
|
||||
|
|
|
@ -203,7 +203,7 @@ struct IVec<SSE, double> : IVecBase<SSE>
|
|||
#if 1
|
||||
// takes 4 cycles
|
||||
__m128i hi = _mm_shuffle_epi32(vec, 2); // 1 cycle
|
||||
__m128i s = _mm_add_epi32(vec, hi);
|
||||
__m128i s = _mm_add_epi32(vec, hi);
|
||||
int32 x = _mm_cvtsi128_si32(s);
|
||||
return -x;
|
||||
#else
|
||||
|
|
5
setup.py
5
setup.py
|
@ -18,7 +18,7 @@ def read(fname):
|
|||
|
||||
setup(
|
||||
name=f"bitsandbytes",
|
||||
version=f"0.35.3",
|
||||
version=f"0.35.4",
|
||||
author="Tim Dettmers",
|
||||
author_email="dettmers@cs.washington.edu",
|
||||
description="8-bit optimizers and matrix multiplication routines.",
|
||||
|
@ -26,9 +26,6 @@ setup(
|
|||
keywords="gpu optimizers optimization 8-bit quantization compression",
|
||||
url="https://github.com/TimDettmers/bitsandbytes",
|
||||
packages=find_packages(),
|
||||
entry_points={
|
||||
"console_scripts": ["debug_cuda = bitsandbytes.debug_cli:cli"],
|
||||
},
|
||||
package_data={"": libs},
|
||||
long_description=read("README.md"),
|
||||
long_description_content_type="text/markdown",
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from itertools import product, permutations
|
||||
from itertools import permutations, product
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
@ -27,7 +27,7 @@ str_values = list(
|
|||
)
|
||||
)
|
||||
names = [
|
||||
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
|
||||
"dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(
|
||||
*vals
|
||||
)
|
||||
for vals in str_values
|
||||
|
@ -286,7 +286,7 @@ str_values = list(
|
|||
has_bias
|
||||
)
|
||||
)
|
||||
names = ["dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}_has_bias_{10}".format(*vals) for vals in str_values]
|
||||
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -336,7 +336,7 @@ def test_matmullt(
|
|||
)
|
||||
bias = None
|
||||
bias2 = None
|
||||
if has_bias:
|
||||
if has_bias:
|
||||
bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
|
||||
bias2 = bias.clone()
|
||||
torch.nn.init.xavier_uniform_(B)
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
import os
|
||||
import pytest
|
||||
import bitsandbytes as bnb
|
||||
|
||||
from typing import List, NamedTuple
|
||||
|
||||
import pytest
|
||||
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.cuda_setup import (
|
||||
CUDA_RUNTIME_LIB,
|
||||
evaluate_cuda_setup,
|
||||
determine_cuda_runtime_lib_path,
|
||||
evaluate_cuda_setup,
|
||||
extract_candidate_paths,
|
||||
)
|
||||
|
||||
|
|
|
@ -6,12 +6,14 @@ from itertools import product
|
|||
import einops
|
||||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes import functional as F
|
||||
from scipy.stats import norm
|
||||
|
||||
torch.set_printoptions(
|
||||
precision=4, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
|
||||
precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
|
||||
)
|
||||
k = 20
|
||||
|
||||
|
@ -26,7 +28,7 @@ def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0):
|
|||
|
||||
class FFN(torch.nn.Module):
|
||||
def __init__(self, input_features, hidden_size, bias=True):
|
||||
super(FFN, self).__init__()
|
||||
super().__init__()
|
||||
self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias)
|
||||
self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias)
|
||||
|
||||
|
@ -40,7 +42,7 @@ class FFN(torch.nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class Timer(object):
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.starts = {}
|
||||
self.ends = {}
|
||||
|
@ -67,7 +69,7 @@ class Timer(object):
|
|||
self.ends.pop(name)
|
||||
|
||||
if print_ms and name in self.agg:
|
||||
print("{0} took: {1:.5f}s".format(name, self.agg[name] / 1000.0))
|
||||
print(f"{name} took: {self.agg[name] / 1000.0:.5f}s")
|
||||
|
||||
return self.agg[name]
|
||||
|
||||
|
@ -149,30 +151,41 @@ def test_dynamic_quantization():
|
|||
|
||||
|
||||
def test_dynamic_blockwise_quantization():
|
||||
diffs = []
|
||||
reldiffs = []
|
||||
for i in range(100):
|
||||
A1 = torch.randn(1024, 1024, device="cuda")
|
||||
C, S = F.quantize_blockwise(A1)
|
||||
A2 = F.dequantize_blockwise(C, S)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff / torch.abs(A1 + 1e-8)
|
||||
diffs.append(diff.mean().item())
|
||||
reldiffs.append(reldiff.mean().item())
|
||||
assert diffs[-1] < 0.011
|
||||
# print(sum(diffs)/len(diffs))
|
||||
# print(sum(reldiffs)/len(reldiffs))
|
||||
#print('')
|
||||
for blocksize in [4096, 2048, 1024, 512]:
|
||||
diffs = []
|
||||
reldiffs = []
|
||||
for i in range(100):
|
||||
A1 = torch.randn(1024, 1024, device="cuda")
|
||||
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
|
||||
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff / torch.abs(A1 + 1e-8)
|
||||
diffs.append(diff.mean().item())
|
||||
reldiffs.append(reldiff.mean().item())
|
||||
abserr = sum(diffs)/len(diffs)
|
||||
relerr = sum(reldiffs)/len(reldiffs)
|
||||
assert abserr < 0.011
|
||||
assert relerr < 0.018
|
||||
#print('randn', blocksize, sum(diffs)/len(diffs))
|
||||
#print('randn', blocksize, sum(reldiffs)/len(reldiffs))
|
||||
|
||||
diffs = []
|
||||
for i in range(100):
|
||||
A1 = torch.rand(1024, 1024, device="cuda")
|
||||
C, S = F.quantize_blockwise(A1)
|
||||
A2 = F.dequantize_blockwise(C, S)
|
||||
diff = torch.abs(A1 - A2).mean().item()
|
||||
assert diff < 0.0033
|
||||
diffs.append(diff)
|
||||
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
|
||||
# print(sum(diffs)/len(diffs))
|
||||
diffs = []
|
||||
for i in range(100):
|
||||
A1 = torch.rand(1024, 1024, device="cuda")
|
||||
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
|
||||
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff / torch.abs(A1 + 1e-8)
|
||||
diffs.append(diff.mean().item())
|
||||
reldiffs.append(reldiff.mean().item())
|
||||
#torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
|
||||
abserr = sum(diffs)/len(diffs)
|
||||
relerr = sum(reldiffs)/len(reldiffs)
|
||||
assert abserr < 0.0035
|
||||
assert relerr < 0.015
|
||||
#print('rand', blocksize, sum(diffs)/len(diffs))
|
||||
#print('rand', blocksize, sum(reldiffs)/len(reldiffs))
|
||||
|
||||
|
||||
def test_dynamic_blockwise_stochastic_quantization():
|
||||
|
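The reworked test above sweeps several block sizes when measuring blockwise quantization error. A minimal sketch of the round trip it exercises, assuming a CUDA device and the functional API exactly as used in the test:

import torch
from bitsandbytes import functional as F

A1 = torch.randn(1024, 1024, device="cuda")
# quantize_blockwise / dequantize_blockwise take an optional blocksize, as swept in the test.
C, S = F.quantize_blockwise(A1, blocksize=2048)
A2 = F.dequantize_blockwise(C, S, blocksize=2048)
print((A1 - A2).abs().mean().item())  # mean absolute error; the test asserts roughly 0.011 for randn inputs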
@ -289,7 +302,7 @@ batched = [False, True]
|
|||
values = list(product(dim1, dim2, methods, batched))
|
||||
values_names = list(product(dim1, dim2, method_names, batched))
|
||||
names = [
|
||||
"dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals)
|
||||
"dim1_{}_dim2_{}_quant_{}_batched_{}".format(*vals)
|
||||
for vals in values_names
|
||||
]
|
||||
|
||||
|
@ -347,7 +360,7 @@ seq_dim = torch.randint(16, 256, size=(n,)).tolist()
|
|||
transpose = [(False, False), (False, True), (True, False), (True, True)]
|
||||
values = list(product(hidden_dim, batch_dim, transpose, seq_dim))
|
||||
names = [
|
||||
"hidden_dim_{0}_batch_dim_{1},transpose_{2}_seq_dim_{3}".format(*vals)
|
||||
"hidden_dim_{}_batch_dim_{},transpose_{}_seq_dim_{}".format(*vals)
|
||||
for vals in values
|
||||
]
|
||||
|
||||
|
@ -412,7 +425,7 @@ hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
|
|||
batch_dim = torch.randint(2, 16, size=(n,)).tolist()
|
||||
values = list(product(seq_dim, hidden_dim, batch_dim))
|
||||
names = [
|
||||
"seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values
|
||||
"seq_dim{}_hidden_dim{}_batch_dim{}".format(*vals) for vals in values
|
||||
]
|
||||
|
||||
|
||||
|
@ -444,7 +457,7 @@ batch_dim = torch.randint(2, 16, size=(n,)).tolist()
|
|||
transpose = [False, True]
|
||||
values = list(product(seq_dim, hidden_dim, batch_dim, transpose))
|
||||
names = [
|
||||
"seq_dim={0}_hidden_dim={1}_batch_dim={2}_transpose{3}".format(*vals)
|
||||
"seq_dim={}_hidden_dim={}_batch_dim={}_transpose{}".format(*vals)
|
||||
for vals in values
|
||||
]
|
||||
|
||||
|
@ -529,7 +542,7 @@ dim4 = torch.randint(32, 256, size=(n,)).tolist()
|
|||
transpose = [(False, False), (True, False), (False, True), (True, True)]
|
||||
values = list(product(dim1, dim2, dim3, dim4, transpose))
|
||||
names = [
|
||||
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals)
|
||||
"dim1_{}_dim2_{}_dim3_{}_dim4_{}_transpose_{}".format(*vals)
|
||||
for vals in values
|
||||
]
|
||||
|
||||
|
@ -567,7 +580,7 @@ dim1 = torch.randint(1, 64, size=(n,)).tolist()
|
|||
dim2 = torch.randint(32, 128, size=(n,)).tolist()
|
||||
dim3 = torch.randint(32, 256, size=(n,)).tolist()
|
||||
values = list(product(dim1, dim2, dim3))
|
||||
names = ["dim1_{0}_dim2_{1}_dim3_{2}".format(*vals) for vals in values]
|
||||
names = ["dim1_{}_dim2_{}_dim3_{}".format(*vals) for vals in values]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names)
|
||||
|
@ -596,7 +609,7 @@ transpose = [False]
|
|||
dims = [2, 3]
|
||||
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
|
||||
|
||||
names = ["dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(*vals)for vals in values]
|
||||
names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names)
|
||||
|
@ -678,7 +691,7 @@ ldb = [0]
|
|||
# ldb = list(range(256, 1*1024, 256))
|
||||
values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
|
||||
names = [
|
||||
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}_ldb_{5}".format(*vals)
|
||||
"dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}_ldb_{}".format(*vals)
|
||||
for vals in values
|
||||
]
|
||||
|
||||
|
@ -726,7 +739,7 @@ dims = (2,)
|
|||
# ldb = list(range(256, 1*1024, 256))
|
||||
values = list(product(dim1, dim2, dim3, dim4, dims))
|
||||
names = [
|
||||
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals)
|
||||
"dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}".format(*vals)
|
||||
for vals in values
|
||||
]
|
||||
|
||||
|
@ -784,7 +797,7 @@ values = [
|
|||
|
||||
# values = list(product(batch, seq, model, hidden))
|
||||
names = [
|
||||
"batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
|
||||
"batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
|
||||
]
|
||||
|
||||
|
||||
|
@ -952,7 +965,7 @@ dims = (2,)
|
|||
formatB = ["col_turing", "col_ampere"]
|
||||
has_bias = [True, False]
|
||||
values = list(product(dim1, dim4, dims, formatB, has_bias))
|
||||
names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}_has_bias_{4}".format(*vals) for vals in values]
|
||||
names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
|
||||
|
@ -1002,7 +1015,7 @@ dim2 = [1 * 1024]
|
|||
dims = (2,)
|
||||
# ldb = list(range(256, 1*1024, 256))
|
||||
values = list(product(dim1, dim2, dims))
|
||||
names = ["dim1_{0}_dim2_{1}_dims_{2}".format(*vals) for vals in values]
|
||||
names = ["dim1_{}_dim2_{}_dims_{}".format(*vals) for vals in values]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names)
|
||||
|
@ -1058,7 +1071,7 @@ dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
|
|||
dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
|
||||
|
||||
values = list(product(dim1, dim2))
|
||||
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
|
||||
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
|
||||
|
@ -1105,7 +1118,7 @@ dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
|
|||
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
|
||||
|
||||
values = list(zip(dim1, dim4, inner))
|
||||
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
|
||||
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
|
||||
|
@ -1149,7 +1162,7 @@ dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
|
|||
inner = torch.randint(1, 4 * 1024, size=(n,)).tolist()
|
||||
|
||||
values = list(zip(dim1, dim4, inner))
|
||||
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
|
||||
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
|
||||
|
@ -1224,7 +1237,7 @@ inner = [12288 * 4, 4096 * 4]
|
|||
dim4 = [12288, 4096]
|
||||
|
||||
values = list(zip(dim1, dim4, inner))
|
||||
names = ["dim1_{0}_dim4_{1}_inner_{2}".format(*vals) for vals in values]
|
||||
names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names)
|
||||
|
@ -1290,7 +1303,7 @@ values = list(
|
|||
product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
|
||||
)
|
||||
names = [
|
||||
"dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
|
||||
"dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_{}".format(
|
||||
*vals
|
||||
)
|
||||
for vals in values
|
||||
|
@ -1341,7 +1354,7 @@ a_order = ["col_turing"]
|
|||
out_order = ["row"]
|
||||
values = list(product(dim1, dim2, dtype, a_order, out_order))
|
||||
names = [
|
||||
"dim1_{0}_dim2_{1}_dtype_{2}_orderA_{3}_orderOut_{4}".format(*vals)
|
||||
"dim1_{}_dim2_{}_dtype_{}_orderA_{}_orderOut_{}".format(*vals)
|
||||
for vals in values
|
||||
]
|
||||
|
||||
|
@ -1367,7 +1380,7 @@ dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist()
|
|||
# dim2 = [5]
|
||||
|
||||
values = list(product(dim1, dim2))
|
||||
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
|
||||
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
|
||||
|
@ -1404,7 +1417,7 @@ dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist()
|
|||
# dim2 = [11]
|
||||
transposed_B = [False, True]
|
||||
values = list(product(dim1, dim2, transposed_B))
|
||||
names = ["dim1_{0}_dim2_{1}_transposed_B_{2}".format(*vals) for vals in values]
|
||||
names = ["dim1_{}_dim2_{}_transposed_B_{}".format(*vals) for vals in values]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names)
|
||||
|
@ -1485,7 +1498,7 @@ n = 2
|
|||
dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
|
||||
dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist()
|
||||
values = list(product(dim1, dim2))
|
||||
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]
|
||||
names = ["dim1_{}_dim2_{}".format(*vals) for vals in values]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim1, dim2", values, ids=names)
|
||||
|
@ -1550,7 +1563,7 @@ dtype = [torch.float16]
|
|||
out_function = ["zeros", "ones"]
|
||||
values = list(product(dim1, dim2, dtype, out_function))
|
||||
names = [
|
||||
"dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values
|
||||
"dim1_{}_dim2_{}_dtype_{}_out_func_{}".format(*vals) for vals in values
|
||||
]
|
||||
|
||||
|
||||
|
@ -1616,17 +1629,6 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
|
|||
# print(time.time() - t0)
|
||||
|
||||
|
||||
def test_layout():
|
||||
a1 = torch.rand(16, 64, device="cuda", dtype=torch.float16)
|
||||
a1 = torch.arange(16 * 64, device="cuda").reshape(16, 64).byte()
|
||||
a2, s2 = F.transform(a1, "col_turing")
|
||||
print(a2.shape)
|
||||
|
||||
print(a1.flatten()[8 * 64 : 8 * 64 + 32])
|
||||
for i in range(4):
|
||||
print(a2.flatten()[i * 8 * 32 : i * 8 * 32 + 32], 0)
|
||||
|
||||
|
||||
def test_coo2csr():
|
||||
threshold = 1
|
||||
A = torch.randn(128, 128).half().cuda()
|
||||
|
@ -1678,7 +1680,7 @@ dim2 = [2048]
|
|||
# dim2 = [2]
|
||||
dtype = [torch.int8]
|
||||
values = list(product(dim1, dim2, dtype))
|
||||
names = ["dim1_{0}_dim2_{1}_dtype_{2}".format(*vals) for vals in values]
|
||||
names = ["dim1_{}_dim2_{}_dtype_{}".format(*vals) for vals in values]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names)
|
||||
|
@ -1794,7 +1796,7 @@ values.append((batch_size, seqdim, 768, 4 * 768))
|
|||
# values.append((batch_size, seqdim, 5140, 4*5140))
|
||||
#values.append((batch_size, seqdim, 12288, 4*12288))
|
||||
names = [
|
||||
"batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
|
||||
"batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
|
||||
]
|
||||
|
||||
|
||||
|
@ -2040,3 +2042,154 @@ def test_blockwise_cpu_large():
|
|||
assert diffs[-1] < 0.011
|
||||
# print(sum(diffs)/len(diffs))
|
||||
# print(sum(reldiffs)/len(reldiffs))
|
||||
|
||||
|
||||
|
||||
def test_fp8_quant():
|
||||
for e_bits in range(1, 7):
|
||||
p_bits = 7-e_bits
|
||||
code = F.create_fp8_map(True, e_bits, p_bits).cuda()
|
||||
|
||||
print(e_bits, p_bits)
|
||||
abserr = []
|
||||
relerr = []
|
||||
for i in range(100):
|
||||
A1 = torch.randn(1024, 1024, device="cuda")
|
||||
C, SC = F.quantize_blockwise(A1, code=code)
|
||||
A2 = F.dequantize_blockwise(C, SC)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff/torch.abs(A1+1e-8)
|
||||
abserr.append(diff.mean().item())
|
||||
relerr.append(reldiff.mean().item())
|
||||
#assert diff < 0.0075
|
||||
#print(sum(abserr)/len(abserr))
|
||||
#print(sum(relerr)/len(relerr))
|
||||
|
||||
abserr = []
|
||||
relerr = []
|
||||
for i in range(100):
|
||||
A1 = torch.rand(1024, 1024, device="cuda")
|
||||
C, SC = F.quantize_blockwise(A1, code=code)
|
||||
A2 = F.dequantize_blockwise(C, SC)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff/torch.abs(A1+1e-8)
|
||||
abserr.append(diff.mean().item())
|
||||
relerr.append(reldiff.mean().item())
|
||||
#assert diff < 0.0075
|
||||
#print(sum(abserr)/len(abserr))
|
||||
#print(sum(relerr)/len(relerr))
|
||||
|
||||
abserr = []
|
||||
relerr = []
|
||||
for i in range(100):
|
||||
A1 = torch.randn(1024, 1024, device="cuda")
|
||||
C, SC = F.quantize_blockwise(A1)
|
||||
A2 = F.dequantize_blockwise(C, SC)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff/torch.abs(A1+1e-8)
|
||||
abserr.append(diff.mean().item())
|
||||
relerr.append(reldiff.mean().item())
|
||||
#assert diff < 0.0075
|
||||
#print(3, sum(abserr)/len(abserr))
|
||||
#print(3, sum(relerr)/len(relerr))
|
||||
|
||||
|
||||
def test_few_bit_quant():
|
||||
|
||||
#print('')
|
||||
for bits in range(2, 9):
|
||||
#print('='*30, bits, '='*30)
|
||||
for method in ['linear', 'fp8', 'dynamic', 'quantile']:
|
||||
abserrs = []
|
||||
relerrs = []
|
||||
code = None
|
||||
if method == 'linear':
|
||||
code = F.create_linear_map(True, total_bits=bits).cuda()
|
||||
elif method == 'fp8':
|
||||
ebits = math.ceil(bits/2)
|
||||
pbits = bits-ebits-1
|
||||
code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
|
||||
elif method == 'dynamic':
|
||||
code = F.create_dynamic_map(True, bits-0, bits).cuda()
|
||||
elif method == 'quantile':
|
||||
values = torch.randn(2048, 2048, device='cuda')
|
||||
code = F.create_quantile_map(values, bits).cuda()
|
||||
# for some data types we have no zero
|
||||
# for some data types we have one zero
|
||||
# for some data types we have two zeros
|
||||
assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}'
|
||||
#print(method, (code==0).sum())
|
||||
assert code.numel() == 256
|
||||
for i in range(10):
|
||||
|
||||
values = torch.randn(1, 32, device='cuda')
|
||||
values /= values.abs().max()
|
||||
#values[values.abs() < 1e-6] += 1e-5
|
||||
|
||||
q1 = []
|
||||
v1 = []
|
||||
for v in values[0]:
|
||||
idx = torch.abs(v-code).argmin()
|
||||
q1.append(idx.item())
|
||||
v1.append(code[idx].item())
|
||||
|
||||
q1 = torch.Tensor(q1).cuda()
|
||||
v1 = torch.Tensor(v1).cuda()
|
||||
|
||||
q2, S2 = F.quantize_blockwise(values, code=code)
|
||||
v2 = F.dequantize_blockwise(q2, S2)
|
||||
|
||||
idx = torch.isclose(q1.int(), q2.int())
|
||||
err2 = torch.abs(v2-values)
|
||||
abserrs.append(err2.mean().item())
|
||||
relerrs.append((err2/(1e-10+values).abs()).mean().item())
|
||||
if idx.sum():
|
||||
# some weird cases
|
||||
err1 = torch.abs(v1-values).mean()
|
||||
#assert err2.mean() <= err1
|
||||
|
||||
else:
|
||||
torch.testing.assert_allclose(q1, q2)
|
||||
#print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
|
||||
#assert False
|
||||
|
||||
|
||||
def test_kbit_quantile_estimation():
|
||||
for i in range(100):
|
||||
data = torch.randn(1024, 1024, device='cuda')
|
||||
for bits in range(2, 9):
|
||||
p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits)
|
||||
val1 = torch.Tensor(norm.ppf(p)).cuda()
|
||||
val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits)
|
||||
err = torch.abs(val1-val2).mean()
|
||||
assert err < 0.038
|
||||
|
||||
for i in range(100):
|
||||
data = torch.randn(1024, 1024, device='cuda')
|
||||
for bits in range(2, 4):
|
||||
total_values = 2**bits-1
|
||||
p = np.linspace(0, 1, 2*total_values+1)
|
||||
idx = np.arange(1, 2*total_values+1, 2)
|
||||
p = p[idx]
|
||||
offset = 1/(2*total_values)
|
||||
p = np.linspace(offset, 1-offset, total_values)
|
||||
val1 = torch.Tensor(norm.ppf(p)).cuda()
|
||||
val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1)
|
||||
err = torch.abs(val1-val2).mean()
|
||||
assert err < 0.035
|
||||
|
||||
|
||||
def test_bench_dequantization():
|
||||
a = torch.rand(1024, 1024, device='cuda').half()
|
||||
qa, SA = F.quantize_blockwise(a)
|
||||
|
||||
max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
|
||||
#print(max_theoretical_mu)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(100):
|
||||
F.dequantize_blockwise(qa, SA, blocksize=2048)
|
||||
torch.cuda.synchronize()
|
||||
#print((time.time()-t0)/1e6)
|
||||
|
||||
|
|
|
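The block of new tests above builds custom quantization codebooks (FP8-style, linear, dynamic, quantile) and passes them to the blockwise quantizer via the code argument. A minimal sketch of that pattern, assuming a CUDA device and the create_fp8_map signature as called in test_fp8_quant:

import torch
from bitsandbytes import functional as F

# Signed FP8-style codebook with 4 exponent bits and 3 precision bits (e_bits + p_bits = 7),
# mirroring F.create_fp8_map(True, e_bits, p_bits) from the new test.
code = F.create_fp8_map(True, 4, 3).cuda()

A1 = torch.randn(1024, 1024, device="cuda")
C, SC = F.quantize_blockwise(A1, code=code)  # quantize against the custom codebook
A2 = F.dequantize_blockwise(C, SC)
print((A1 - A2).abs().mean().item())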
@ -7,7 +7,7 @@ from torch import nn
|
|||
import bitsandbytes as bnb
|
||||
|
||||
|
||||
class MockArgs(object):
|
||||
class MockArgs:
|
||||
def __init__(self, initial_data):
|
||||
for key in initial_data:
|
||||
setattr(self, key, initial_data[key])
|
||||
|
@ -15,7 +15,7 @@ class MockArgs(object):
|
|||
|
||||
class MLP8bit(torch.nn.Module):
|
||||
def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0):
|
||||
super(MLP8bit, self).__init__()
|
||||
super().__init__()
|
||||
self.fc1 = bnb.nn.Linear8bitLt(
|
||||
dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward,
|
||||
threshold=threshold
|
||||
|
@ -289,7 +289,7 @@ class LinearFunction(torch.autograd.Function):
|
|||
|
||||
class Linear8bit(nn.Module):
|
||||
def __init__(self, input_features, output_features, bias=True, args=None):
|
||||
super(Linear8bit, self).__init__()
|
||||
super().__init__()
|
||||
self.input_features = input_features
|
||||
self.output_features = output_features
|
||||
self.args = args
|
||||
|
@ -312,7 +312,7 @@ class Linear8bit(nn.Module):
|
|||
|
||||
threshold = [0.0, 3.0]
|
||||
values = threshold
|
||||
names = ["threshold_{0}".format(vals) for vals in values]
|
||||
names = [f"threshold_{vals}" for vals in values]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("threshold", values, ids=names)
|
||||
|
@ -378,7 +378,7 @@ def test_linear8bitlt_accumulated_gradient():
|
|||
|
||||
threshold = [0.0, 2.0]
|
||||
values = threshold
|
||||
names = ["threshold_{0}".format(vals) for vals in values]
|
||||
names = [f"threshold_{vals}" for vals in values]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("threshold", values, ids=names)
|
||||
|
|
|
@ -18,7 +18,7 @@ k = 20
|
|||
|
||||
|
||||
def get_temp_dir():
|
||||
path = "/tmp/autoswap/{0}".format(str(uuid.uuid4()))
|
||||
path = f"/tmp/autoswap/{str(uuid.uuid4())}"
|
||||
os.makedirs(path, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
@ -116,7 +116,7 @@ gtype = [torch.float32, torch.float16]
|
|||
optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
|
||||
values = list(product(dim1, dim2, gtype, optimizer_names))
|
||||
names = [
|
||||
"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
|
||||
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
|
||||
]
|
||||
|
||||
|
||||
|
@ -187,7 +187,7 @@ dim1 = [1024]
|
|||
dim2 = [32, 1024, 4097]
|
||||
gtype = [torch.float32, torch.float16]
|
||||
values = list(product(dim1, dim2, gtype))
|
||||
names = ["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values]
|
||||
names = ["dim1_{}_dim2_{}_gtype_{}".format(*vals) for vals in values]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
|
||||
|
@ -250,7 +250,7 @@ optimizer_names = [
|
|||
]
|
||||
values = list(product(dim1, dim2, gtype, optimizer_names))
|
||||
names = [
|
||||
"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
|
||||
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
|
||||
]
|
||||
|
||||
|
||||
|
@ -391,7 +391,7 @@ gtype = [torch.float32]
|
|||
optim_bits = [32, 8]
|
||||
values = list(product(dim1, dim2, gtype, optim_bits))
|
||||
names = [
|
||||
"dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals)
|
||||
"dim1_{}_dim2_{}_gtype_{}_optim_bits_{}".format(*vals)
|
||||
for vals in values
|
||||
]
|
||||
|
||||
|
@ -495,7 +495,7 @@ gtype = [torch.float32, torch.float16]
|
|||
optimizer_names = ["adam8bit_blockwise"]
|
||||
values = list(product(dim1, dim2, gtype, optimizer_names))
|
||||
names = [
|
||||
"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
|
||||
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
|
||||
]
|
||||
|
||||
|
||||
|
|
|
@ -1,149 +0,0 @@
|
|||
./setup.py:20:10: F541 f-string is missing placeholders
|
||||
./setup.py:21:13: F541 f-string is missing placeholders
|
||||
./quicktest.py:5:1: F401 'bitsandbytes as bnb' imported but unused
|
||||
./bitsandbytes/cuda_setup.py:42:56: F821 undefined name 'error_str'
|
||||
./bitsandbytes/cuda_setup.py:43:15: F541 f-string is missing placeholders
|
||||
./bitsandbytes/cuda_setup.py:67:5: F841 local variable 'context' is assigned to but never used
|
||||
./bitsandbytes/cuda_setup.py:68:5: F841 local variable 'error_str' is assigned to but never used
|
||||
./bitsandbytes/cuda_setup.py:76:9: F841 local variable 'result' is assigned to but never used
|
||||
./bitsandbytes/cuda_setup.py:144:13: F841 local variable 'has_gpu' is assigned to but never used
|
||||
./bitsandbytes/functional.py:294:13: F821 undefined name 'math'
|
||||
./bitsandbytes/functional.py:295:16: F821 undefined name 'math'
|
||||
./bitsandbytes/functional.py:303:5: F841 local variable 'ptrA' is assigned to but never used
|
||||
./bitsandbytes/functional.py:304:5: F841 local variable 'ptrOut' is assigned to but never used
|
||||
./bitsandbytes/functional.py:1057:17: W503 line break before binary operator
|
||||
./bitsandbytes/functional.py:1058:17: W503 line break before binary operator
|
||||
./bitsandbytes/functional.py:1059:17: W503 line break before binary operator
|
||||
./bitsandbytes/functional.py:1649:1: F811 redefinition of unused 'get_special_format_str' from line 160
|
||||
./bitsandbytes/functional.py:1687:5: F841 local variable 'ptrA' is assigned to but never used
|
||||
./bitsandbytes/functional.py:1688:5: F841 local variable 'ptrOut' is assigned to but never used
|
||||
./bitsandbytes/functional.py:1802:5: F841 local variable 'ccolsA' is assigned to but never used
|
||||
./bitsandbytes/functional.py:1805:5: F841 local variable 'cldb' is assigned to but never used
|
||||
./bitsandbytes/functional.py:1806:5: F841 local variable 'cldc' is assigned to but never used
|
||||
./bitsandbytes/functional.py:1873:9: F841 local variable 'dtype' is assigned to but never used
|
||||
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.MatmulLtState' imported but unused
|
||||
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.bmm_cublas' imported but unused
|
||||
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.matmul' imported but unused
|
||||
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.matmul_cublas' imported but unused
|
||||
./bitsandbytes/__init__.py:6:1: F401 '.autograd._functions.mm_cublas' imported but unused
|
||||
./bitsandbytes/__init__.py:9:1: F401 '.nn.modules' imported but unused
|
||||
./bitsandbytes/__init__.py:12:5: F401 '.optim.adam' imported but unused
|
||||
./bitsandbytes/autograd/_functions.py:5:1: F401 'bitsandbytes as bnb' imported but unused
|
||||
./bitsandbytes/autograd/_functions.py:12:75: W291 trailing whitespace
|
||||
./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Int8Params' imported but unused
|
||||
./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Linear8bit' imported but unused
|
||||
./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.Linear8bitLt' imported but unused
|
||||
./bitsandbytes/nn/__init__.py:5:1: F401 '.modules.StableEmbedding' imported but unused
|
||||
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Any' imported but unused
|
||||
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Callable' imported but unused
|
||||
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Dict' imported but unused
|
||||
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Iterator' imported but unused
|
||||
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Mapping' imported but unused
|
||||
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Set' imported but unused
|
||||
./bitsandbytes/nn/modules.py:5:1: F401 'typing.Tuple' imported but unused
|
||||
./bitsandbytes/nn/modules.py:11:1: F401 'torch.nn.parameter.Parameter' imported but unused
|
||||
./bitsandbytes/nn/modules.py:183:13: W503 line break before binary operator
|
||||
./bitsandbytes/nn/modules.py:184:13: W503 line break before binary operator
|
||||
./bitsandbytes/nn/modules.py:272:24: F821 undefined name 'dist'
|
||||
./bitsandbytes/nn/modules.py:272:49: F821 undefined name 'dist'
|
||||
./bitsandbytes/optim/optimizer.py:243:9: F841 local variable 'overflows' is assigned to but never used
|
||||
./bitsandbytes/optim/optimizer.py:280:35: F541 f-string is missing placeholders
|
||||
./bitsandbytes/optim/optimizer.py:283:35: F541 f-string is missing placeholders
|
||||
./bitsandbytes/optim/lars.py:27:39: F541 f-string is missing placeholders
|
||||
./bitsandbytes/optim/lars.py:59:39: F541 f-string is missing placeholders
|
||||
./bitsandbytes/optim/lars.py:91:39: F541 f-string is missing placeholders
|
||||
./bitsandbytes/optim/lars.py:157:13: F841 local variable 'params_with_grad' is assigned to but never used
|
||||
./bitsandbytes/optim/lars.py:158:13: F841 local variable 'd_p_list' is assigned to but never used
|
||||
./bitsandbytes/optim/lars.py:159:13: F841 local variable 'momentum_buffer_list' is assigned to but never used
|
||||
./bitsandbytes/optim/lars.py:174:35: F821 undefined name 'param'
|
||||
./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam8bit' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:9:5: F401 '.adam.Adam32bit' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW8bit' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:10:5: F401 '.adamw.AdamW32bit' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD8bit' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:11:5: F401 '.sgd.SGD32bit' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS8bit' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.LARS32bit' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:12:5: F401 '.lars.PytorchLARS' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB8bit' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:13:5: F401 '.lamb.LAMB32bit' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop8bit' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:14:5: F401 '.rmsprop.RMSprop32bit' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad8bit' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:15:5: F401 '.adagrad.Adagrad32bit' imported but unused
|
||||
./bitsandbytes/optim/__init__.py:17:1: F401 '.optimizer.GlobalOptimManager' imported but unused
|
||||
./bitsandbytes/optim/adam.py:229:21: F841 local variable 'max_exp_avg_sq' is assigned to but never used
|
||||
./bitsandbytes/optim/rmsprop.py:25:39: F541 f-string is missing placeholders
|
||||
./bitsandbytes/optim/rmsprop.py:27:39: F541 f-string is missing placeholders
|
||||
./bitsandbytes/optim/rmsprop.py:59:39: F541 f-string is missing placeholders
|
||||
./bitsandbytes/optim/rmsprop.py:61:39: F541 f-string is missing placeholders
|
||||
./bitsandbytes/optim/rmsprop.py:94:39: F541 f-string is missing placeholders
|
||||
./bitsandbytes/optim/rmsprop.py:96:39: F541 f-string is missing placeholders
|
||||
./bitsandbytes/optim/sgd.py:24:39: F541 f-string is missing placeholders
|
||||
./bitsandbytes/optim/sgd.py:55:39: F541 f-string is missing placeholders
|
||||
./bitsandbytes/optim/sgd.py:86:39: F541 f-string is missing placeholders
|
||||
./tests/test_optim.py:1:1: F401 'ctypes' imported but unused
|
||||
./tests/test_optim.py:199:5: F841 local variable 'mask' is assigned to but never used
|
||||
./tests/test_optim.py:218:9: F841 local variable 'atol' is assigned to but never used
|
||||
./tests/test_optim.py:218:15: F841 local variable 'rtol' is assigned to but never used
|
||||
./tests/test_optim.py:304:17: W503 line break before binary operator
|
||||
./tests/test_optim.py:354:21: W503 line break before binary operator
|
||||
./tests/test_autograd.py:309:13: F841 local variable 'err' is assigned to but never used
|
||||
./tests/test_cuda_setup_evaluator.py:31:9: F821 undefined name 'test_dir'
|
||||
./tests/test_cuda_setup_evaluator.py:33:14: F821 undefined name 'test_input'
|
||||
./tests/test_cuda_setup_evaluator.py:81:32: E203 whitespace before ':'
|
||||
./tests/test_functional.py:55:13: F841 local variable 'ms' is assigned to but never used
|
||||
./tests/test_functional.py:177:5: F841 local variable 'diffs' is assigned to but never used
|
||||
./tests/test_functional.py:178:5: F841 local variable 'reldiffs' is assigned to but never used
|
||||
./tests/test_functional.py:260:5: F841 local variable 'minA' is assigned to but never used
|
||||
./tests/test_functional.py:261:5: F841 local variable 'maxA' is assigned to but never used
|
||||
./tests/test_functional.py:584:5: F841 local variable 'func' is assigned to but never used
|
||||
./tests/test_functional.py:617:17: F841 local variable 'offset' is assigned to but never used
|
||||
./tests/test_functional.py:618:17: F841 local variable 'col2' is assigned to but never used
|
||||
./tests/test_functional.py:619:17: F841 local variable 'row2' is assigned to but never used
|
||||
./tests/test_functional.py:705:9: F841 local variable 'C1' is assigned to but never used
|
||||
./tests/test_functional.py:706:9: F841 local variable 'C2' is assigned to but never used
|
||||
./tests/test_functional.py:715:9: F841 local variable 'output' is assigned to but never used
|
||||
./tests/test_functional.py:750:5: F841 local variable 'formatB' is assigned to but never used
|
||||
./tests/test_functional.py:754:5: F841 local variable 'w2' is assigned to but never used
|
||||
./tests/test_functional.py:763:5: F841 local variable 'dtype' is assigned to but never used
|
||||
./tests/test_functional.py:770:9: F841 local variable 'out1' is assigned to but never used
|
||||
./tests/test_functional.py:1108:5: F841 local variable 'relerr1' is assigned to but never used
|
||||
./tests/test_functional.py:1108:14: F841 local variable 'relerr2' is assigned to but never used
|
||||
./tests/test_functional.py:1114:9: F841 local variable 'C1' is assigned to but never used
|
||||
./tests/test_functional.py:1135:9: F841 local variable 'C4' is assigned to but never used
|
||||
./tests/test_functional.py:1179:5: F841 local variable 'err1' is assigned to but never used
|
||||
./tests/test_functional.py:1179:11: F841 local variable 'err2' is assigned to but never used
|
||||
./tests/test_functional.py:1179:17: F841 local variable 'err3' is assigned to but never used
|
||||
./tests/test_functional.py:1180:5: F841 local variable 'relerr1' is assigned to but never used
|
||||
./tests/test_functional.py:1180:14: F841 local variable 'relerr2' is assigned to but never used
|
||||
./tests/test_functional.py:1192:9: F841 local variable 'C1' is assigned to but never used
|
||||
./tests/test_functional.py:1313:9: F841 local variable 'c' is assigned to but never used
|
||||
./tests/test_functional.py:1314:9: F841 local variable 'c2' is assigned to but never used
|
||||
./tests/test_functional.py:1406:9: F841 local variable 'C1' is assigned to but never used
|
||||
./tests/test_functional.py:1425:9: F841 local variable 'out2' is assigned to but never used
|
||||
./tests/test_functional.py:1542:5: F841 local variable 'idx_col' is assigned to but never used
|
||||
./tests/test_functional.py:1566:30: E203 whitespace before ':'
|
||||
./tests/test_functional.py:1568:38: E203 whitespace before ':'
|
||||
./tests/test_functional.py:1655:5: F841 local variable 'offset' is assigned to but never used
|
||||
./tests/test_functional.py:1706:9: F841 local variable 'out' is assigned to but never used
|
||||
./tests/test_functional.py:1822:9: F841 local variable 'out' is assigned to but never used
|
||||
./tests/test_functional.py:1882:5: F841 local variable 'out2' is assigned to but never used
|
||||
./tests/test_functional.py:1928:9: F841 local variable 'dtype' is assigned to but never used
|
||||
./tests/test_functional.py:1934:9: F841 local variable 'minx' is assigned to but never used
|
||||
./tests/test_functional.py:1948:5: F841 local variable 'C0' is assigned to but never used
|
||||
./tests/test_modules.py:1:1: F401 'itertools.product' imported but unused
|
||||
./tests/test_modules.py:52:9: F841 local variable 'norm' is assigned to but never used
|
||||
./tests/test_modules.py:52:16: F821 undefined name 'math'
|
||||
./tests/test_modules.py:52:26: F821 undefined name 'math'
|
||||
./tests/test_modules.py:52:37: F821 undefined name 'math'
|
||||
./tests/test_modules.py:177:21: F821 undefined name 'einops'
|
||||
./tests/test_modules.py:233:9: F841 local variable 'stochastic' is assigned to but never used
|
||||
./tests/test_modules.py:382:9: F841 local variable 'o1' is assigned to but never used
|