diff --git a/Makefile b/Makefile index 7bee7ef..f7a4d65 100644 --- a/Makefile +++ b/Makefile @@ -20,6 +20,7 @@ CSRC := $(ROOT_DIR)/csrc BUILD_DIR:= $(ROOT_DIR)/build FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu +# FILES_HIP := $(CSRC)/ops.cu $(CSRC)/kernels.cu FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include @@ -43,6 +44,7 @@ CC_CUDA10x += -gencode arch=compute_75,code=sm_75 CC_CUDA110 := -gencode arch=compute_75,code=sm_75 CC_CUDA110 += -gencode arch=compute_80,code=sm_80 +# CC_CUDA11x := -gencode arch=compute_52,code=sm_52 CC_CUDA11x := -gencode arch=compute_75,code=sm_75 CC_CUDA11x += -gencode arch=compute_80,code=sm_80 CC_CUDA11x += -gencode arch=compute_86,code=sm_86 @@ -51,6 +53,7 @@ CC_CUDA11x += -gencode arch=compute_86,code=sm_86 CC_cublasLt110 := -gencode arch=compute_75,code=sm_75 CC_cublasLt110 += -gencode arch=compute_80,code=sm_80 +# CC_cublasLt111 := -gencode arch=compute_52,code=sm_52 CC_cublasLt111 := -gencode arch=compute_75,code=sm_75 CC_cublasLt111 += -gencode arch=compute_80,code=sm_80 CC_cublasLt111 += -gencode arch=compute_86,code=sm_86 @@ -107,6 +110,17 @@ cuda12x: $(BUILD_DIR) env cpuonly: $(BUILD_DIR) env $(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cpu.so + +HIP_INCLUDE := -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include +# -I /opt/rocm-5.3.0/hipcub/include +HIP_LIB := -L/opt/rocm-5.3.0/lib -L/opt/rocm-5.3.0/llvm/bin/../lib/clang/15.0.0/lib/linux -L/usr/lib/gcc/x86_64-linux-gnu/11 -L/usr/lib/gcc/x86_64-linux-gnu/11/../../../../lib64 -L/lib/x86_64-linux-gnu -L/lib/../lib64 -L/usr/lib/x86_64-linux-gnu -L/usr/lib/../lib64 -L/lib -L/usr/lib -lgcc_s -lgcc -lpthread -lm -lrt -lamdhip64 -lhipblas -lhipsparse -lclang_rt.builtins-x86_64 -lstdc++ -lm -lgcc_s -lgcc -lc -lgcc_s -lgcc + +hip: $(BUILD_DIR) + /usr/bin/hipcc -std=c++14 -c -fPIC --amdgpu-target=gfx1030 $(HIP_INCLUDE) -o $(BUILD_DIR)/ops.o -D NO_CUBLASLT $(CSRC)/ops.cu + /usr/bin/hipcc -std=c++14 -c -fPIC --amdgpu-target=gfx1030 $(HIP_INCLUDE) -o $(BUILD_DIR)/kernels.o -D NO_CUBLASLT $(CSRC)/kernels.cu + # /usr/bin/hipcc -fPIC -static $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.so + $(GPP) -std=c++14 -D__HIP_PLATFORM_AMD__ -DBUILD_CUDA -shared -fPIC -I /opt/rocm/include $(HIP_INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(FILES_CPP) $(HIP_LIB) -o ./bitsandbytes/libbitsandbytes_hip_nocublaslt.so + env: @echo "ENVIRONMENT" @echo "============================" @@ -124,9 +138,9 @@ $(BUILD_DIR): mkdir -p build mkdir -p dependencies -$(ROOT_DIR)/dependencies/cub: - git clone https://github.com/NVlabs/cub $(ROOT_DIR)/dependencies/cub - cd dependencies/cub; git checkout 1.11.0 +# $(ROOT_DIR)/dependencies/cub: +# git clone https://github.com/NVlabs/cub $(ROOT_DIR)/dependencies/cub +# cd dependencies/cub; git checkout 1.11.0 clean: rm build/* diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 041df4b..b968204 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import cuda_setup, utils from .autograd._functions import ( MatmulLtState, bmm_cublas, diff --git a/bitsandbytes/__main__.py b/bitsandbytes/__main__.py index ac7948b..c870bce 100644 --- a/bitsandbytes/__main__.py +++ b/bitsandbytes/__main__.py @@ -28,23 +28,7 @@ print() from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL -from .cuda_setup.env_vars import to_be_ignored -from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle -print_header("POTENTIALLY LIBRARY-PATH-LIKE ENV VARS") -for k, v in os.environ.items(): - if "/" in v and not to_be_ignored(k, v): - print(f"'{k}': '{v}'") -print_header("") - -print( - "\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n" -) - -print_header("OTHER") -print(f"{COMPILED_WITH_CUDA = }") -cuda = get_cuda_lib_handle() -print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities(cuda)}") print_header("") print_header("DEBUG INFO END") print_header("") diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 7a62c1e..c2c0a13 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -5,7 +5,59 @@ import torch from pathlib import Path from warnings import warn -from bitsandbytes.cuda_setup.main import CUDASetup + +class CUDASetup(object): + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def generate_instructions(self): + self.add_log_entry('CUDA SETUP: Something unexpected happened. Please compile from source:') + self.add_log_entry('git clone git@github.com:TimDettmers/bitsandbytes.git') + self.add_log_entry('cd bitsandbytes') + self.add_log_entry("") + self.add_log_entry('python setup.py install') + + def initialize(self): + self.has_printed = False + self.lib = None + self.run_cuda_setup() + + def run_cuda_setup(self): + self.initialized = True + self.cuda_setup_log = [] + + binary_name = "libbitsandbytes_hip_nocublaslt.so" + package_dir = Path(__file__).parent + binary_path = package_dir / binary_name + + try: + if not binary_path.exists(): + raise Exception('CUDA SETUP: Setup Failed!') + else: + self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...") + self.lib = ct.cdll.LoadLibrary(binary_path) + except Exception as ex: + self.add_log_entry(str(ex)) + self.print_log_stack() + + def add_log_entry(self, msg, is_warning=False): + self.cuda_setup_log.append((msg, is_warning)) + + def print_log_stack(self): + for msg, is_warning in self.cuda_setup_log: + if is_warning: + warn(msg) + else: + print(msg) + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance setup = CUDASetup.get_instance() diff --git a/bitsandbytes/cuda_setup/__init__.py b/bitsandbytes/cuda_setup/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/bitsandbytes/cuda_setup/env_vars.py b/bitsandbytes/cuda_setup/env_vars.py deleted file mode 100644 index 536a7d8..0000000 --- a/bitsandbytes/cuda_setup/env_vars.py +++ /dev/null @@ -1,51 +0,0 @@ -import os -from typing import Dict - - -def to_be_ignored(env_var: str, value: str) -> bool: - ignorable = { - "PWD", # PWD: this is how the shell keeps track of the current working dir - "OLDPWD", - "SSH_AUTH_SOCK", # SSH stuff, therefore unrelated - "SSH_TTY", - "HOME", # Linux shell default - "TMUX", # Terminal Multiplexer - "XDG_DATA_DIRS", # XDG: Desktop environment stuff - "XDG_RUNTIME_DIR", - "MAIL", # something related to emails - "SHELL", # binary for currently invoked shell - "DBUS_SESSION_BUS_ADDRESS", # hardware related - "PATH", # this is for finding binaries, not libraries - "LESSOPEN", # related to the `less` command - "LESSCLOSE", - "_", # current Python interpreter - } - return env_var in ignorable - - -def might_contain_a_path(candidate: str) -> bool: - return "/" in candidate - - -def is_active_conda_env(env_var: str) -> bool: - return "CONDA_PREFIX" == env_var - - -def is_other_conda_env_var(env_var: str) -> bool: - return "CONDA" in env_var - - -def is_relevant_candidate_env_var(env_var: str, value: str) -> bool: - return is_active_conda_env(env_var) or ( - might_contain_a_path(value) and not - is_other_conda_env_var(env_var) and not - to_be_ignored(env_var, value) - ) - - -def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]: - return { - env_var: value - for env_var, value in os.environ.items() - if is_relevant_candidate_env_var(env_var, value) - } diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py deleted file mode 100644 index cd9573f..0000000 --- a/bitsandbytes/cuda_setup/main.py +++ /dev/null @@ -1,412 +0,0 @@ -""" -extract factors the build is dependent on: -[X] compute capability - [ ] TODO: Q - What if we have multiple GPUs of different makes? -- CUDA version -- Software: - - CPU-only: only CPU quantization functions (no optimizer, no matrix multipl) - - CuBLAS-LT: full-build 8-bit optimizer - - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) - -evaluation: - - if paths faulty, return meaningful error - - else: - - determine CUDA version - - determine capabilities - - based on that set the default path -""" - -import ctypes as ct -import os -import errno -import torch -from warnings import warn - -from pathlib import Path -from typing import Set, Union -from .env_vars import get_potentially_lib_path_containing_env_vars - -CUDA_RUNTIME_LIB: str = "libcudart.so" - -class CUDASetup: - _instance = None - - def __init__(self): - raise RuntimeError("Call get_instance() instead") - - def generate_instructions(self): - if self.cuda is None: - self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.') - self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.') - self.add_log_entry('CUDA SETUP: Solution 2): If you do not have sudo rights, you can do the following:') - self.add_log_entry('CUDA SETUP: Solution 2a): Find the cuda library via: find / -name libcuda.so 2>/dev/null') - self.add_log_entry('CUDA SETUP: Solution 2b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_2a') - self.add_log_entry('CUDA SETUP: Solution 2c): For a permanent solution add the export from 2b into your .bashrc file, located at ~/.bashrc') - return - - if self.cudart_path is None: - self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA runtime library was not detected.') - self.add_log_entry('CUDA SETUP: Solution 1: To solve the issue the libcudart.so location needs to be added to the LD_LIBRARY_PATH variable') - self.add_log_entry('CUDA SETUP: Solution 1a): Find the cuda runtime library via: find / -name libcudart.so 2>/dev/null') - self.add_log_entry('CUDA SETUP: Solution 1b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_1a') - self.add_log_entry('CUDA SETUP: Solution 1c): For a permanent solution add the export from 1b into your .bashrc file, located at ~/.bashrc') - self.add_log_entry('CUDA SETUP: Solution 2: If no library was found in step 1a) you need to install CUDA.') - self.add_log_entry('CUDA SETUP: Solution 2a): Download CUDA install script: wget https://github.com/TimDettmers/bitsandbytes/blob/main/cuda_install.sh') - self.add_log_entry('CUDA SETUP: Solution 2b): Install desired CUDA version to desired location. The syntax is bash cuda_install.sh CUDA_VERSION PATH_TO_INSTALL_INTO.') - self.add_log_entry('CUDA SETUP: Solution 2b): For example, "bash cuda_install.sh 113 ~/local/" will download CUDA 11.3 and install into the folder ~/local') - return - - make_cmd = f'CUDA_VERSION={self.cuda_version_string}' - if len(self.cuda_version_string) < 3: - make_cmd += ' make cuda92' - elif self.cuda_version_string == '110': - make_cmd += ' make cuda110' - elif self.cuda_version_string[:2] == '11' and int(self.cuda_version_string[2]) > 0: - make_cmd += ' make cuda11x' - elif self.cuda_version_string == '100': - self.add_log_entry('CUDA SETUP: CUDA 10.0 not supported. Please use a different CUDA version.') - self.add_log_entry('CUDA SETUP: Before you try again running bitsandbytes, make sure old CUDA 10.0 versions are uninstalled and removed from $LD_LIBRARY_PATH variables.') - return - - - has_cublaslt = is_cublasLt_compatible(self.cc) - if not has_cublaslt: - make_cmd += '_nomatmul' - - self.add_log_entry('CUDA SETUP: Something unexpected happened. Please compile from source:') - self.add_log_entry('git clone git@github.com:TimDettmers/bitsandbytes.git') - self.add_log_entry('cd bitsandbytes') - self.add_log_entry(make_cmd) - self.add_log_entry('python setup.py install') - - def initialize(self): - if not getattr(self, 'initialized', False): - self.has_printed = False - self.lib = None - self.initialized = False - - def run_cuda_setup(self): - self.initialized = True - self.cuda_setup_log = [] - - binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup() - self.cudart_path = cudart_path - self.cuda = cuda - self.cc = cc - self.cuda_version_string = cuda_version_string - - package_dir = Path(__file__).parent.parent - binary_path = package_dir / binary_name - - try: - if not binary_path.exists(): - self.add_log_entry(f"CUDA SETUP: Required library version not found: {binary_name}. Maybe you need to compile it from source?") - legacy_binary_name = "libbitsandbytes_cpu.so" - self.add_log_entry(f"CUDA SETUP: Defaulting to {legacy_binary_name}...") - binary_path = package_dir / legacy_binary_name - if not binary_path.exists() or torch.cuda.is_available(): - self.add_log_entry('') - self.add_log_entry('='*48 + 'ERROR' + '='*37) - self.add_log_entry('CUDA SETUP: CUDA detection failed! Possible reasons:') - self.add_log_entry('1. CUDA driver not installed') - self.add_log_entry('2. CUDA not installed') - self.add_log_entry('3. You have multiple conflicting CUDA libraries') - self.add_log_entry('4. Required library not pre-compiled for this bitsandbytes release!') - self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.') - self.add_log_entry('CUDA SETUP: The CUDA version for the compile might depend on your conda install. Inspect CUDA version via `conda list | grep cuda`.') - self.add_log_entry('='*80) - self.add_log_entry('') - self.generate_instructions() - self.print_log_stack() - raise Exception('CUDA SETUP: Setup Failed!') - self.lib = ct.cdll.LoadLibrary(binary_path) - else: - self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...") - self.lib = ct.cdll.LoadLibrary(binary_path) - except Exception as ex: - self.add_log_entry(str(ex)) - self.print_log_stack() - - def add_log_entry(self, msg, is_warning=False): - self.cuda_setup_log.append((msg, is_warning)) - - def print_log_stack(self): - for msg, is_warning in self.cuda_setup_log: - if is_warning: - warn(msg) - else: - print(msg) - - @classmethod - def get_instance(cls): - if cls._instance is None: - cls._instance = cls.__new__(cls) - cls._instance.initialize() - return cls._instance - - -def is_cublasLt_compatible(cc): - has_cublaslt = False - if cc is not None: - cc_major, cc_minor = cc.split('.') - if int(cc_major) < 7 or (int(cc_major) == 7 and int(cc_minor) < 5): - cuda_setup.add_log_entry("WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!", is_warning=True) - else: - has_cublaslt = True - return has_cublaslt - -def extract_candidate_paths(paths_list_candidate: str) -> Set[Path]: - return {Path(ld_path) for ld_path in paths_list_candidate.split(":") if ld_path} - - -def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]: - existent_directories: Set[Path] = set() - for path in candidate_paths: - try: - if path.exists(): - existent_directories.add(path) - except OSError as exc: - if exc.errno != errno.ENAMETOOLONG: - raise exc - - non_existent_directories: Set[Path] = candidate_paths - existent_directories - if non_existent_directories: - CUDASetup.get_instance().add_log_entry("WARNING: The following directories listed in your path were found to " - f"be non-existent: {non_existent_directories}", is_warning=True) - - return existent_directories - - -def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]: - return { - path / CUDA_RUNTIME_LIB - for path in candidate_paths - if (path / CUDA_RUNTIME_LIB).is_file() - } - - -def resolve_paths_list(paths_list_candidate: str) -> Set[Path]: - """ - Searches a given environmental var for the CUDA runtime library, - i.e. `libcudart.so`. - """ - return remove_non_existent_dirs(extract_candidate_paths(paths_list_candidate)) - - -def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]: - return get_cuda_runtime_lib_paths( - resolve_paths_list(paths_list_candidate) - ) - - -def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None: - if len(results_paths) > 1: - warning_msg = ( - f"Found duplicate {CUDA_RUNTIME_LIB} files: {results_paths}.. " - "We'll flip a coin and try one of these, in order to fail forward.\n" - "Either way, this might cause trouble in the future:\n" - "If you get `CUDA error: invalid device function` errors, the above " - "might be the cause and the solution is to make sure only one " - f"{CUDA_RUNTIME_LIB} in the paths that we search based on your env.") - CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True) - - -def determine_cuda_runtime_lib_path() -> Union[Path, None]: - """ - Searches for a cuda installations, in the following order of priority: - 1. active conda env - 2. LD_LIBRARY_PATH - 3. any other env vars, while ignoring those that - - are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`) - - don't contain the path separator `/` - - If multiple libraries are found in part 3, we optimistically try one, - while giving a warning message. - """ - candidate_env_vars = get_potentially_lib_path_containing_env_vars() - - if "CONDA_PREFIX" in candidate_env_vars: - conda_libs_path = Path(candidate_env_vars["CONDA_PREFIX"]) / "lib" - - conda_cuda_libs = find_cuda_lib_in(str(conda_libs_path)) - warn_in_case_of_duplicates(conda_cuda_libs) - - if conda_cuda_libs: - return next(iter(conda_cuda_libs)) - - CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain ' - f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True) - - if "LD_LIBRARY_PATH" in candidate_env_vars: - lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"]) - - if lib_ld_cuda_libs: - return next(iter(lib_ld_cuda_libs)) - warn_in_case_of_duplicates(lib_ld_cuda_libs) - - CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain ' - f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True) - - remaining_candidate_env_vars = { - env_var: value for env_var, value in candidate_env_vars.items() - if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"} - } - - cuda_runtime_libs = set() - for env_var, value in remaining_candidate_env_vars.items(): - cuda_runtime_libs.update(find_cuda_lib_in(value)) - - if len(cuda_runtime_libs) == 0: - CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching /usr/local/cuda/lib64...') - cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64')) - - warn_in_case_of_duplicates(cuda_runtime_libs) - - return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None - - -def check_cuda_result(cuda, result_val): - # 3. Check for CUDA errors - if result_val != 0: - error_str = ct.c_char_p() - cuda.cuGetErrorString(result_val, ct.byref(error_str)) - if error_str.value is not None: - CUDASetup.get_instance().add_log_entry(f"CUDA exception! Error code: {error_str.value.decode()}") - else: - CUDASetup.get_instance().add_log_entry(f"Unknown CUDA exception! Please check your CUDA install. It might also be that your GPU is too old.") - - -# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION -def get_cuda_version(cuda, cudart_path): - if cuda is None: return None - - try: - cudart = ct.CDLL(cudart_path) - except OSError: - CUDASetup.get_instance().add_log_entry(f'ERROR: libcudart.so could not be read from path: {cudart_path}!') - return None - - version = ct.c_int() - try: - check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ct.byref(version))) - except AttributeError as e: - CUDASetup.get_instance().add_log_entry(f'ERROR: {str(e)}') - CUDASetup.get_instance().add_log_entry(f'CUDA SETUP: libcudart.so path is {cudart_path}') - CUDASetup.get_instance().add_log_entry(f'CUDA SETUP: Is seems that your cuda installation is not in your path. See https://github.com/TimDettmers/bitsandbytes/issues/85 for more information.') - version = int(version.value) - major = version//1000 - minor = (version-(major*1000))//10 - - if major < 11: - CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!') - - return f'{major}{minor}' - - -def get_cuda_lib_handle(): - # 1. find libcuda.so library (GPU driver) (/usr/lib) - try: - cuda = ct.CDLL("libcuda.so") - except OSError: - CUDASetup.get_instance().add_log_entry('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!') - return None - check_cuda_result(cuda, cuda.cuInit(0)) - - return cuda - - -def get_compute_capabilities(cuda): - """ - 1. find libcuda.so library (GPU driver) (/usr/lib) - init_device -> init variables -> call function by reference - 2. call extern C function to determine CC - (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html) - 3. Check for CUDA errors - https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api - # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549 - """ - - nGpus = ct.c_int() - cc_major = ct.c_int() - cc_minor = ct.c_int() - - device = ct.c_int() - - check_cuda_result(cuda, cuda.cuDeviceGetCount(ct.byref(nGpus))) - ccs = [] - for i in range(nGpus.value): - check_cuda_result(cuda, cuda.cuDeviceGet(ct.byref(device), i)) - ref_major = ct.byref(cc_major) - ref_minor = ct.byref(cc_minor) - # 2. call extern C function to determine CC - check_cuda_result(cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device)) - ccs.append(f"{cc_major.value}.{cc_minor.value}") - - return ccs - - -# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error -def get_compute_capability(cuda): - """ - Extracts the highest compute capbility from all available GPUs, as compute - capabilities are downwards compatible. If no GPUs are detected, it returns - None. - """ - if cuda is None: return None - - # TODO: handle different compute capabilities; for now, take the max - ccs = get_compute_capabilities(cuda) - if ccs: return ccs[-1] - - -def evaluate_cuda_setup(): - if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': - print('') - print('='*35 + 'BUG REPORT' + '='*35) - print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues') - print('='*80) - if not torch.cuda.is_available(): return 'libsbitsandbytes_cpu.so', None, None, None, None - - cuda_setup = CUDASetup.get_instance() - cudart_path = determine_cuda_runtime_lib_path() - cuda = get_cuda_lib_handle() - cc = get_compute_capability(cuda) - cuda_version_string = get_cuda_version(cuda, cudart_path) - - failure = False - if cudart_path is None: - failure = True - cuda_setup.add_log_entry("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!", is_warning=True) - else: - cuda_setup.add_log_entry(f"CUDA SETUP: CUDA runtime path found: {cudart_path}") - - if cc == '' or cc is None: - failure = True - cuda_setup.add_log_entry("WARNING: No GPU detected! Check your CUDA paths. Proceeding to load CPU-only library...", is_warning=True) - else: - cuda_setup.add_log_entry(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}") - - if cuda is None: - failure = True - else: - cuda_setup.add_log_entry(f'CUDA SETUP: Detected CUDA version {cuda_version_string}') - - # 7.5 is the minimum CC vor cublaslt - has_cublaslt = is_cublasLt_compatible(cc) - - # TODO: - # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) - # (2) Multiple CUDA versions installed - - # we use ls -l instead of nvcc to determine the cuda version - # since most installations will have the libcudart.so installed, but not the compiler - - if failure: - binary_name = "libbitsandbytes_cpu.so" - elif has_cublaslt: - binary_name = f"libbitsandbytes_cuda{cuda_version_string}.so" - else: - "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so" - binary_name = f"libbitsandbytes_cuda{cuda_version_string}_nocublaslt.so" - - return binary_name, cudart_path, cuda, cc, cuda_version_string diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h index 2ddf81e..f68a968 100644 --- a/csrc/cpu_ops.h +++ b/csrc/cpu_ops.h @@ -7,4 +7,4 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n); void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n); -#endif +#endif \ No newline at end of file diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 08b9b44..2a1acde 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3,15 +3,17 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. + +#include #include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include #define HLF_MAX 65504 #define TH 1024 @@ -19,29 +21,29 @@ #define NUM_BLOCK 4096 // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda -__device__ float atomicMax(float* address, float val) { - int* address_as_i = reinterpret_cast(address); - int old = *address_as_i, assumed; - do { - assumed = old; - old = atomicCAS( - reinterpret_cast(address), assumed, - __float_as_int(fmaxf(val, __int_as_float(assumed)))); - } while (assumed != old); - return __int_as_float(old); -} +// __device__ float atomicMax(float* address, float val) { +// int* address_as_i = reinterpret_cast(address); +// int old = *address_as_i, assumed; +// do { +// assumed = old; +// old = atomicCAS( +// reinterpret_cast(address), assumed, +// __float_as_int(fmaxf(val, __int_as_float(assumed)))); +// } while (assumed != old); +// return __int_as_float(old); +// } -__device__ float atomicMin(float* address, float val) { - int* address_as_i = reinterpret_cast(address); - int old = *address_as_i, assumed; - do { - assumed = old; - old = atomicCAS( - reinterpret_cast(address), assumed, - __float_as_int(fminf(val, __int_as_float(assumed)))); - } while (assumed != old); - return __int_as_float(old); -} +// __device__ float atomicMin(float* address, float val) { +// int* address_as_i = reinterpret_cast(address); +// int old = *address_as_i, assumed; +// do { +// assumed = old; +// old = atomicCAS( +// reinterpret_cast(address), assumed, +// __float_as_int(fminf(val, __int_as_float(assumed)))); +// } while (assumed != old); +// return __int_as_float(old); +// } template __device__ unsigned char dQuantize(float* smem_code, const float rand, float x) @@ -232,9 +234,9 @@ __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index template __global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n) { - typedef cub::WarpReduce WarpReduce; + typedef hipcub::WarpReduce WarpReduce; __shared__ typename WarpReduce::TempStorage temp_storage; - typedef cub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadT; __shared__ typename LoadT::TempStorage loadt; const int warp_idx = threadIdx.x/32; @@ -282,8 +284,8 @@ __global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* ou for(int i = 0; i < 8; i++) { // 3. do warp reduction + broadcast back - warp_max = WarpReduce(temp_storage).Reduce(max1, cub::Max()); - warp_max = cub::ShuffleIndex<32>(warp_max, 0, 0xffffffff); + warp_max = WarpReduce(temp_storage).Reduce(max1, hipcub::Max()); + warp_max = hipcub::ShuffleIndex<32>(warp_max, 0, 0xffffffff); // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest if(warp_max == max1) @@ -297,7 +299,9 @@ __global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* ou max2 = -64000.0f; } - __syncwarp(); + // __syncwarp(); + __syncthreads(); + } if(threadIdx.x % 32 < 8) @@ -324,8 +328,8 @@ __global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const f T vals[NUM_ESTIMATE]; - typedef cub::BlockRadixSort BlockRadixSort; - typedef cub::BlockLoad LoadFloat; + typedef hipcub::BlockRadixSort BlockRadixSort; + typedef hipcub::BlockLoad LoadFloat; __shared__ union { typename LoadFloat::TempStorage loadf; @@ -391,8 +395,8 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c unsigned char qvals[NUM]; //const int lane_id = threadIdx.x % 2; - typedef cub::BlockLoad LoadFloat; - typedef cub::BlockStore StoreChar; + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreChar; __shared__ typename LoadFloat::TempStorage loadf; __shared__ typename StoreChar::TempStorage storec; @@ -442,10 +446,10 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float float local_abs_max = 0.0f; int local_rand_idx = 0; - typedef cub::BlockLoad LoadT; - typedef cub::BlockStore StoreChar; - typedef cub::BlockReduce BlockReduce; - typedef cub::BlockLoad LoadFloat; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockReduce BlockReduce; + typedef hipcub::BlockLoad LoadFloat; __shared__ typename LoadT::TempStorage loadt; __shared__ typename LoadFloat::TempStorage loadf; @@ -473,7 +477,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float for(int j = 0; j < NUM_PER_TH; j++) local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); - local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items); + local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, hipcub::Max(), valid_items); if(threadIdx.x == 0) smem_absmax_value[0] = local_abs_max; @@ -485,7 +489,9 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float else local_abs_max = smem_absmax_value[0]; - __syncwarp(); + // __syncwarp(); + __syncthreads(); + local_abs_max = 1.0f/local_abs_max; @@ -521,8 +527,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs unsigned char qvals[NUM_PER_TH]; float local_abs_max = -FLT_MAX; - typedef cub::BlockLoad LoadChar; - typedef cub::BlockStore StoreT; + typedef hipcub::BlockLoad LoadChar; + typedef hipcub::BlockStore StoreT; __shared__ typename LoadChar::TempStorage loadchar; __shared__ typename StoreT::TempStorage storet; @@ -592,9 +598,9 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p, const float correction1 = 1.0f/(1.0f - powf(beta1, step)); const float correction2 = 1.0f/(1.0f - powf(beta2, step)); - typedef cub::BlockLoad Load; - typedef cub::BlockLoad LoadFloat; - typedef cub::BlockReduce BlockReduce; + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockReduce BlockReduce; __shared__ union { typename Load::TempStorage load; @@ -643,7 +649,9 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p, if(threadIdx.x == 0) atomicAdd(&unorm[0], s1_vals[0]); - __syncwarp(); + // __syncwarp(); + __syncthreads(); + } } @@ -681,11 +689,11 @@ __global__ void kOptimizer32bit2State(T* g, T* p, } else{ update_scale = 1.0f; } - typedef cub::BlockLoad Load; - typedef cub::BlockStore Store; + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockStore Store; - typedef cub::BlockLoad LoadFloat; - typedef cub::BlockStore StoreFloat; + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreFloat; __shared__ union { typename Load::TempStorage load; @@ -755,9 +763,9 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, float s1_vals[NUM_VALS]; - typedef cub::BlockLoad Load; - typedef cub::BlockLoad LoadFloat; - typedef cub::BlockReduce BlockReduce; + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockReduce BlockReduce; __shared__ union { typename Load::TempStorage load; @@ -813,7 +821,9 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p, if(threadIdx.x == 0) atomicAdd(&unorm[0], s1_vals[0]); - __syncwarp(); + // __syncwarp(); + __syncthreads(); + } } @@ -843,11 +853,11 @@ __global__ void kOptimizer32bit1State(T *g, T *p, float s1_vals[NUM_PER_THREAD]; - typedef cub::BlockLoad Load; - typedef cub::BlockStore Store; + typedef hipcub::BlockLoad Load; + typedef hipcub::BlockStore Store; - typedef cub::BlockLoad LoadFloat; - typedef cub::BlockStore StoreFloat; + typedef hipcub::BlockLoad LoadFloat; + typedef hipcub::BlockStore StoreFloat; __shared__ union { typename Load::TempStorage load; @@ -939,9 +949,9 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c unsigned char m_c1[NUM8BIT]; unsigned char r_c2[NUM8BIT]; - typedef cub::BlockLoad LoadT; - typedef cub::BlockLoad LoadUInt8; - typedef cub::BlockReduce BlockReduce; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadUInt8; + typedef hipcub::BlockReduce BlockReduce; __shared__ union { @@ -1008,13 +1018,13 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c } __syncthreads(); - local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items); __syncthreads(); - local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cub::Max(), valid_items); + local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, hipcub::Max(), valid_items); if(unorm != NULL) { __syncthreads(); - local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items); } if(threadIdx.x == 0) @@ -1068,11 +1078,11 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha unsigned char c2s[NUM_PER_THREAD2]; T p_vals[NUM_PER_THREAD2]; T g_vals[NUM_PER_THREAD2]; - typedef cub::BlockLoad LoadT; - typedef cub::BlockLoad LoadChar; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; - typedef cub::BlockStore StoreChar; - typedef cub::BlockStore StoreT; + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; __shared__ float smem_quantiles1[256]; __shared__ float smem_quantiles2[256]; @@ -1176,9 +1186,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c T g_vals[NUM8BIT]; unsigned char m_c1[NUM8BIT]; - typedef cub::BlockLoad LoadT; - typedef cub::BlockLoad LoadUInt8; - typedef cub::BlockReduce BlockReduce; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadUInt8; + typedef hipcub::BlockReduce BlockReduce; __shared__ union { @@ -1229,12 +1239,12 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c } __syncthreads(); - local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items); if(threadIdx.x == 0){ atomicMax(&new_max1[0], local_max_s1); } if(unorm != NULL) { __syncthreads(); - local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items); if(threadIdx.x == 0){ atomicAdd(&unorm[0], local_unorm); } } @@ -1271,11 +1281,11 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, unsigned char c1s[NUM_PER_THREAD2]; T p_vals[NUM_PER_THREAD2]; T g_vals[NUM_PER_THREAD2]; - typedef cub::BlockLoad LoadT; - typedef cub::BlockLoad LoadChar; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; - typedef cub::BlockStore StoreChar; - typedef cub::BlockStore StoreT; + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; __shared__ float smem_quantiles1[256]; @@ -1353,8 +1363,8 @@ __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int st const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); int valid_items = 0; - typedef cub::BlockReduce BlockReduce; - typedef cub::BlockLoad LoadT; + typedef hipcub::BlockReduce BlockReduce; + typedef hipcub::BlockLoad LoadT; __shared__ typename BlockReduce::TempStorage reduce; @@ -1426,16 +1436,16 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char unsigned char c1s[N_PER_TH]; unsigned char c2s[N_PER_TH]; T g_vals[N_PER_TH]; - typedef cub::BlockLoad LoadT; - typedef cub::BlockLoad LoadChar; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; - typedef cub::BlockStore StoreChar; - typedef cub::BlockStore StoreT; + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; __shared__ float smem_quantiles1[LANES][257]; __shared__ float smem_quantiles2[LANES][257]; - typedef cub::BlockReduce BlockReduce1; - typedef cub::BlockReduce BlockReduce2; + typedef hipcub::BlockReduce BlockReduce1; + typedef hipcub::BlockReduce BlockReduce2; __shared__ typename BlockReduce1::TempStorage reduce1; __shared__ typename BlockReduce2::TempStorage reduce2; __shared__ float smem_exchange1[1]; @@ -1504,8 +1514,8 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char // reduce: 2.51/1.60 -> 2.67/1.69 - new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max()); - new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cub::Max()); + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max()); + new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, hipcub::Max()); if(threadIdx.x == 0) { @@ -1599,14 +1609,14 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char T g_vals[N_PER_TH]; T p_vals[N_PER_TH]; - typedef cub::BlockLoad LoadT; - typedef cub::BlockLoad LoadChar; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockLoad LoadChar; - typedef cub::BlockStore StoreChar; - typedef cub::BlockStore StoreT; + typedef hipcub::BlockStore StoreChar; + typedef hipcub::BlockStore StoreT; __shared__ float smem_quantiles1[LANES][257]; - typedef cub::BlockReduce BlockReduce1; + typedef hipcub::BlockReduce BlockReduce1; __shared__ typename BlockReduce1::TempStorage reduce1; __shared__ float smem_exchange1[1]; @@ -1678,7 +1688,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char // reduce: 2.51/1.60 -> 2.67/1.69 - new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max()); + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max()); if(threadIdx.x == 0) smem_exchange1[0] = new_local_abs_max1; @@ -1756,10 +1766,10 @@ template LoadT; - typedef cub::BlockReduce BlockRowReduce; - typedef cub::BlockReduce BlockRowSum; - typedef cub::BlockExchange BlockExchange; + typedef hipcub::BlockLoad LoadT; + typedef hipcub::BlockReduce BlockRowReduce; + typedef hipcub::BlockReduce BlockRowSum; + typedef hipcub::BlockExchange BlockExchange; __shared__ union { typename BlockExchange::TempStorage exchange; @@ -1837,7 +1847,7 @@ template__global__ void kd float local_rowStats[ITEMS_PER_THREAD]; __shared__ float smem_rowStats[SUBTILE_ROWS]; - typedef cub::BlockLoad LoadInt32; - typedef cub::BlockExchange ExchangeInt32; + typedef hipcub::BlockLoad LoadInt32; + typedef hipcub::BlockExchange ExchangeInt32; __shared__ typename LoadInt32::TempStorage loadint32; __shared__ typename ExchangeInt32::TempStorage exchangeint32; @@ -2033,9 +2043,9 @@ template LoadHalf; + typedef hipcub::BlockLoad LoadHalf; __shared__ typename LoadHalf::TempStorage loadhalf; - typedef cub::BlockStore StoreInt8; + typedef hipcub::BlockStore StoreInt8; __shared__ typename StoreInt8::TempStorage storeint8; __shared__ float smem_row_stats[TILE_ROWS]; @@ -2168,7 +2178,7 @@ template BlockExchange; + typedef hipcub::BlockExchange BlockExchange; // we load row after row from the base_position // Load data row by row diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index d90ea13..ef5a294 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -3,8 +3,10 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. + +#include #include -#include +#include "ops.cuh" #ifndef kernels #define kernels diff --git a/csrc/ops.cu b/csrc/ops.cu index e770e10..1ef1138 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -3,14 +3,17 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -#include -#include -#include -#include -#include -#include -#include +#include +#include "ops.cuh" +#include "kernels.cuh" +// #include +#include +// #include +#include +#include +#include +// #include using namespace BinSearch; using std::cout; @@ -22,16 +25,16 @@ void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *sr int num_blocks = n/threads; num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1; kHistogramScatterAdd2D<<>>(histogram, index1, index2, src, maxidx1, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); } template void estimateQuantiles(T *A, float *code, float offset, int n) { int num_blocks = n/4096; num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; - CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float))); + CUDA_CHECK_RETURN(hipMemset(code, 0, 256*sizeof(float))); kEstimateQuantiles<<>>(A, code, offset, std::numeric_limits::max(), n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); } void quantize(float *code, float *A, unsigned char *out, int n) @@ -39,7 +42,7 @@ void quantize(float *code, float *A, unsigned char *out, int n) int num_blocks = n/1024; num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; kQuantize<<>>(code, A, out, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); } void dequantize(float *code, unsigned char *A, float *out, int n) @@ -47,7 +50,7 @@ void dequantize(float *code, unsigned char *A, float *out, int n) int num_blocks = n/1024; num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; kDequantize<<>>(code, A, out, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); } template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) @@ -73,7 +76,7 @@ template void quantizeBlockwise(float * code, T *A, kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); } template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) @@ -95,7 +98,7 @@ template void dequantizeBlockwise(float *code, unsigned char *A, flo else if(blocksize == 64) kDequantizeBlockwise<<>>(code, A, absmax, out, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); } template void optimizer32bit(T* g, T* p, @@ -110,12 +113,12 @@ template void optimizer32bit(T* g, T* p, case ADAM: if(max_unorm > 0.0f) { - CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); + CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); kPreconditionOptimizer32bit2State<<>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); } kOptimizer32bit2State<<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); break; case MOMENTUM: case RMSPROP: @@ -123,13 +126,13 @@ template void optimizer32bit(T* g, T* p, if(max_unorm > 0.0f) { - CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); + CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); } kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); break; } } @@ -147,28 +150,28 @@ template void optimizerStatic8bit(T* p, T* g, int num_blocks = n/4096; num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; - if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); } + if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); } switch(OPTIMIZER) { case ADAM: - CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); - CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float))); + CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); + CUDA_CHECK_RETURN(hipMemset(new_max2, 0, 1*sizeof(float))); kPreconditionOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); kOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); break; case MOMENTUM: case RMSPROP: case ADAGRAD: - CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); + CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float))); kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); break; default: break; @@ -193,7 +196,7 @@ template void optimizerStatic8bitBlockwise(T* p, T* g num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; kOptimizerStatic8bit2StateBlockwise<<>>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); break; case MOMENTUM: case RMSPROP: @@ -202,7 +205,7 @@ template void optimizerStatic8bitBlockwise(T* p, T* g num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr, quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); break; } } @@ -213,31 +216,39 @@ template void percentileClipping(T * g, float *gnorm_vec, int step, { int num_blocks = n/2048; num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1; - CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float))); + CUDA_CHECK_RETURN(hipMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float))); kPercentileClipping<<>>(g, gnorm_vec, step, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); } void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) { - const int falpha = 1; - const int fbeta = 0; - const void * alpha = &falpha; - const void * beta = &fbeta; - cublasStatus_t status; + cout << "" << endl; + cout << "=============================================" << endl; + cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl; + cout << "=============================================" << endl; + cout << "" << endl; + assert(false); - status = cublasGemmEx(context->m_handle, - transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, - transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, - m, n, k, - alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta, - C, CUDA_R_32I, ldc, - CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP); + return ; + // const int falpha = 1; + // const int fbeta = 0; + // const void * alpha = &falpha; + // const void * beta = &fbeta; + // hipblasStatus_t status; - if (status != CUBLAS_STATUS_SUCCESS) - { - std::cout << "CUBLAS ERROR: Status " << status << std::endl; - } + // status = hipblasGemmEx(context->m_handle, + // transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + // transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, + // m, n, k, + // alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta, + // C, HIPBLAS_R_32I, ldc, + // HIPBLAS_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP); + + // if (status != HIPBLAS_STATUS_SUCCESS) + // { + // std::cout << "CUBLAS ERROR: Status " << status << std::endl; + // } } @@ -248,7 +259,7 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i const int fbeta = 0; const void * alpha = &falpha; const void * beta = &fbeta; - cublasStatus_t status; + hipblasStatus_t status; //cout << transposeA << transposeB << endl; //printf("%i %i %i\n", m,n,k); @@ -256,15 +267,15 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i //printf("%i %i %i\n", strideA, strideB, strideC); //printf("%i\n", batchCount); - status = cublasGemmStridedBatchedEx(context->m_handle, - transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, - transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, + status = hipblasGemmStridedBatchedEx(context->m_handle, + transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, - alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta, - C, CUDA_R_32I, ldc, (long long int)strideC, batchCount, - CUDA_R_32I, CUBLAS_GEMM_DEFAULT); + alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta, + C, HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount, + HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); - if (status != CUBLAS_STATUS_SUCCESS) + if (status != HIPBLAS_STATUS_SUCCESS) { std::cout << "CUBLAS ERROR: Status " << status << std::endl; } @@ -276,42 +287,6 @@ int roundoff(int v, int d) { } -#ifdef NO_CUBLASLT -#else -template cublasLtOrder_t get_order() -{ - switch(ORDER) - { - case ROW: - return CUBLASLT_ORDER_ROW; - break; - case COL: - return CUBLASLT_ORDER_COL; - break; - case COL32: - return CUBLASLT_ORDER_COL32; - break; - case COL_TURING: - return CUBLASLT_ORDER_COL4_4R2_8C; - break; - case COL_AMPERE: - return CUBLASLT_ORDER_COL32_2R_4R4; - break; - default: - break; - } - - return CUBLASLT_ORDER_ROW; -} - -template cublasLtOrder_t get_order(); -template cublasLtOrder_t get_order(); -template cublasLtOrder_t get_order(); -template cublasLtOrder_t get_order(); -template cublasLtOrder_t get_order(); -#endif - - template int get_leading_dim(int dim1, int dim2) { switch(ORDER) @@ -345,47 +320,12 @@ template int get_leading_dim(int dim1, int dim2); template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) { -#ifdef NO_CUBLASLT -#else - cublasLtOrder_t orderA = get_order(); - cublasLtOrder_t orderOut = get_order(); - int ldA = get_leading_dim(dim1, dim2); - int ldOut = get_leading_dim(dim1, dim2); - - cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL; - cublasLtMatrixTransformDesc_t A2Out_desc = NULL; - cublasOperation_t opTranspose = CUBLAS_OP_T; - float transformAlpha = 1.0f, transformBeta = 0.0f; - - - if(DTYPE == 8) - { - checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, dim1, dim2, ldA)); - checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ldOut)); - } - else if(DTYPE == 32) - { - checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_32I, dim1, dim2, ldA)); - checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32I, dim1, dim2, ldOut)); - } - else - { - printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE); - } - - checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA))); - checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut))); - - checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, CUDA_R_32F)); - - if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); } - - checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0)); - - if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc)); - if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc)); - if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc)); -#endif + cout << "" << endl; + cout << "=============================================" << endl; + cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl; + cout << "=============================================" << endl; + cout << "" << endl; + assert(false); } template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); @@ -399,7 +339,6 @@ template void transform(cublasLtHandle_t ltHandl template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { -#ifdef NO_CUBLASLT cout << "" << endl; cout << "=============================================" << endl; cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl; @@ -408,62 +347,6 @@ template int igemmlt(cublasLtHandle assert(false); return 0; -#else - int has_error = 0; - cublasLtMatmulDesc_t matmulDesc = NULL; - cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; - cublasOperation_t opT = CUBLAS_OP_T; - cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; - cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32; - cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C; - cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4; - - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda)); - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb)); - - has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); - if(FORMATB == COL_TURING) - has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing))); - else - has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere))); - - if(DTYPE_OUT == 32) - { - has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I)); - has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc)); - has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); - int alpha = 1, beta = 0; - has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0)); - } - else - { - has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F)); - has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); - has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc)); - has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); - if(!SCALE_ROWS) - { - float alpha = 1.0f, beta = 0.0f; - has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0)); - } - else - { - has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec))); - has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0)); - } - } - - - if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc)); - if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc)); - if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc)); - if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc)); - if(has_error == 1) - printf("error detected"); - - return has_error; -#endif } int fill_up_to_nearest_multiple(int value, int multiple) @@ -484,7 +367,7 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, assert(threads <= tilesize); kdequant_mm_int32_fp16<4, 128, 512><<>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); } #define STATS_THREADS 64 @@ -505,7 +388,7 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); else if(nnz_threshold != 0.0) kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); } @@ -529,7 +412,7 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col else kDoubleRowColQuant<64, 4, 16, 64*4, 0><<>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); } template void transformRowToFormat(char * A, char *out, int rows, int cols) @@ -573,69 +456,27 @@ template void transformRowToFormat(char * A, char *o } kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<>>(A, out, rows, cols, tiledCols, outRows, outCols); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); } -void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) +void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) { -#ifdef NO_CUBLASLT -#else + cout << "" << endl; + cout << "=============================================" << endl; + cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl; + cout << "=============================================" << endl; + cout << "" << endl; + assert(false); - cusparseSpMatDescr_t descA; - cusparseDnMatDescr_t descB, descC; - - float alpha = 1.0f; - float beta = 0.0f; - void *dBuffer = NULL; - size_t bufferSize = 0; - - CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz, - A_rowidx, A_colidx, A_vals, - CUSPARSE_INDEX_32I, - CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) ); - // Create dense matrix C - CHECK_CUSPARSE( cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, - CUDA_R_16F, CUSPARSE_ORDER_ROW) ); - // Create dense matrix B - if(transposed_B) - { - int tmp = A_cols; - A_cols = B_cols; - B_cols = tmp; - } - - CHECK_CUSPARSE( cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, - CUDA_R_16F, CUSPARSE_ORDER_ROW) ); - // allocate an external buffer if needed - CHECK_CUSPARSE( cusparseSpMM_bufferSize( - handle, - CUSPARSE_OPERATION_NON_TRANSPOSE, - transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, - &alpha, descA, descB, &beta, descC, CUDA_R_32F, - CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize) ); - CUDA_CHECK_RETURN( cudaMalloc(&dBuffer, bufferSize) ); - - // execute SpMM - CHECK_CUSPARSE( cusparseSpMM(handle, - CUSPARSE_OPERATION_NON_TRANSPOSE, - transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, - &alpha, descA, descB, &beta, descC, CUDA_R_32F, - CUSPARSE_SPMM_ALG_DEFAULT, dBuffer)); - - // destroy matrix/vector descriptors - CHECK_CUSPARSE( cusparseDestroySpMat(descA) ); - CHECK_CUSPARSE( cusparseDestroyDnMat(descB) ); - CHECK_CUSPARSE( cusparseDestroyDnMat(descC) ); - CUDA_CHECK_RETURN( cudaFree(dBuffer) ); -#endif + return; } template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { kspmm_coo_very_sparse_naive<<>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); } @@ -658,7 +499,7 @@ template void extractOutliers(char * A, int *idx, char *out, int id } kExtractOutliers<<>>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols); - CUDA_CHECK_RETURN(cudaPeekAtLastError()); + CUDA_CHECK_RETURN(hipPeekAtLastError()); } //============================================================== diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 31d4dd8..48478c6 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -11,30 +11,31 @@ #include #include #include - -#include -#include -#include -#include -#include +#include +#include +#include +// #include +#include #include #include +typedef struct cublasLtContext* cublasLtHandle_t; + #define CUDA_CHECK_RETURN(value) { \ - cudaError_t _m_cudaStat = value; \ - if (_m_cudaStat != cudaSuccess) { \ + hipError_t _m_cudaStat = value; \ + if (_m_cudaStat != hipSuccess) { \ fprintf(stderr, "Error %s at line %d in file %s\n", \ - cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ + hipGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ exit(1); \ } } #define THREADS_PER_BLOCKS (512) #define CHECK_CUSPARSE(value) { \ - cusparseStatus_t _m_cudaStat = value; \ - if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \ - fprintf(stderr, "Error %s at line %d in file %s\n", \ - cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ + hipsparseStatus_t _m_cudaStat = value; \ + if (_m_cudaStat != HIPSPARSE_STATUS_SUCCESS) { \ + fprintf(stderr, "Error at line %d in file %s\n", \ + __LINE__, __FILE__); \ exit(1); \ } } @@ -42,15 +43,15 @@ #define THREADS_PER_BLOCKS (512) -inline void checkCudaStatus(cudaError_t status) { - if (status != cudaSuccess) { - printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status)); +inline void checkCudaStatus(hipError_t status) { + if (status != hipSuccess) { + printf("cuda API failed with status %d: %s\n", status, hipGetErrorString(status)); throw std::logic_error("cuda API failed"); } } -inline int checkCublasStatus(cublasStatus_t status) { - if (status != CUBLAS_STATUS_SUCCESS) { +inline int checkCublasStatus(hipblasStatus_t status) { + if (status != HIPBLAS_STATUS_SUCCESS) { printf("cuBLAS API failed with status %d\n", status); //throw std::logic_error("cuBLAS API failed"); return 1; @@ -84,40 +85,40 @@ typedef enum Transform_t class Context { public: - cublasHandle_t m_handle; + hipblasHandle_t m_handle; Context() { - cublasHandle_t handle; - cublasCreate_v2(&handle); + hipblasHandle_t handle; + hipblasCreate(&handle); m_handle = handle; } }; -class ContextLt -{ - public: - cublasLtHandle_t m_handle; +// class ContextLt +// { +// public: +// cublasLtHandle_t m_handle; - ContextLt() - { - cublasLtHandle_t handle; - cublasLtCreate(&handle); - m_handle = handle; - } +// ContextLt() +// { +// cublasLtHandle_t handle; +// cublasLtCreate(&handle); +// m_handle = handle; +// } -}; +// }; class ContextCusparse { public: - cusparseHandle_t m_handle; + hipsparseHandle_t m_handle; ContextCusparse() { - cusparseHandle_t handle; - cusparseCreate(&handle); + hipsparseHandle_t handle; + hipsparseCreate(&handle); m_handle = handle; } @@ -170,7 +171,7 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col template void transformRowToFormat(char * A, char *out, int rows, int cols); -void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B); +void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B); template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index d8b2290..a706139 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -275,7 +275,7 @@ extern "C" { transform_row2ampereT(A, out, rows, cols); } void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) - { spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); } + { spmm_coo((hipsparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); } void cspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) { spmm_coo_very_sparse_naive_fp16(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } diff --git a/include/Algo-Direct2.h b/include/Algo-Direct2.h index d5fa58d..b901c66 100644 --- a/include/Algo-Direct2.h +++ b/include/Algo-Direct2.h @@ -2,73 +2,75 @@ #include "Algo-Direct-Common.h" -namespace BinSearch { -namespace Details { - -template -struct AlgoScalarBase::value>::type> : DirectAux::DirectInfo<2, T, A> +namespace BinSearch { -private: - typedef DirectAux::DirectInfo<2, T, A> base_t; - static const size_t Offset=2; - -public: - AlgoScalarBase(const T* x, const uint32 n) - : base_t(x, n) + namespace Details { - } - FORCE_INLINE uint32 scalar(T z) const - { - const T* px = base_t::data.xi; - const uint32* buckets = reinterpret_cast(base_t::data.buckets); - uint32 bidx = base_t::fun_t::f(base_t::data.scaler, base_t::data.cst0, z); - uint32 iidx = buckets[bidx]; - px += iidx; - if (z < *px) - --iidx; - if (z < *(px+1)) - --iidx; - return iidx; - } -}; + template + struct AlgoScalarBase::value>::type> : DirectAux::DirectInfo<2, T, A> + { + private: + typedef DirectAux::DirectInfo<2, T, A> base_t; + static const size_t Offset = 2; + public: + AlgoScalarBase(const T *x, const uint32 n) + : base_t(x, n) + { + } -template -struct AlgoVecBase::value>::type> : AlgoScalarBase -{ - static const uint32 nElem = sizeof(typename InstrFloatTraits::vec_t) / sizeof(T); + FORCE_INLINE uint32 scalar(T z) const + { + const T *px = base_t::data.xi; + const uint32 *buckets = reinterpret_cast(base_t::data.buckets); + uint32 bidx = base_t::fun_t::f(base_t::data.scaler, base_t::data.cst0, z); + uint32 iidx = buckets[bidx]; + px += iidx; + if (z < *px) + --iidx; + if (z < *(px + 1)) + --iidx; + return iidx; + } + }; - typedef FVec fVec; - typedef IVec i128; + template + struct AlgoVecBase::value>::type> : AlgoScalarBase + { + static const uint32 nElem = sizeof(typename InstrFloatTraits::vec_t) / sizeof(T); - struct Constants - { - fVec vscaler; - fVec vcst0; - IVec one; - }; + typedef FVec fVec; + typedef IVec i128; -private: - typedef AlgoScalarBase base_t; + struct Constants + { + fVec vscaler; + fVec vcst0; + IVec one; + }; - FORCE_INLINE - //NO_INLINE - void resolve(const FVec& vz, const IVec& bidx, uint32 *pr) const - { - union U { - __m128i vec; - uint32 ui32[4]; - } u; + private: + typedef AlgoScalarBase base_t; - const uint32* buckets = reinterpret_cast(base_t::data.buckets); - const float *xi = base_t::data.xi; + FORCE_INLINE + // NO_INLINE + void resolve(const FVec &vz, const IVec &bidx, uint32 *pr) const + { + union U + { + __m128i vec; + uint32 ui32[4]; + } u; - // read indices t - const double *p3 = reinterpret_cast(&xi[(u.ui32[3] = buckets[bidx.get3()])]); - const double *p2 = reinterpret_cast(&xi[(u.ui32[2] = buckets[bidx.get2()])]); - const double *p1 = reinterpret_cast(&xi[(u.ui32[1] = buckets[bidx.get1()])]); - const double *p0 = reinterpret_cast(&xi[(u.ui32[0] = buckets[bidx.get0()])]); + const uint32 *buckets = reinterpret_cast(base_t::data.buckets); + const float *xi = base_t::data.xi; + + // read indices t + const double *p3 = reinterpret_cast(&xi[(u.ui32[3] = buckets[bidx.get3()])]); + const double *p2 = reinterpret_cast(&xi[(u.ui32[2] = buckets[bidx.get2()])]); + const double *p1 = reinterpret_cast(&xi[(u.ui32[1] = buckets[bidx.get1()])]); + const double *p0 = reinterpret_cast(&xi[(u.ui32[0] = buckets[bidx.get0()])]); #if 0 // read pairs ( X(t-1), X(t) ) @@ -87,65 +89,66 @@ private: __m128 vxm = _mm_shuffle_ps(u01, u23, (0) + (1 << 2) + (0 << 4) + (1 << 6)); __m128 vxp = _mm_shuffle_ps(u01, u23, (2) + (3 << 2) + (2 << 4) + (3 << 6)); #else - __m128 xp23 = _mm_castpd_ps(_mm_set_pd(*p3, *p2)); - __m128 xp01 = _mm_castpd_ps(_mm_set_pd(*p1, *p0)); - __m128 vxm = _mm_shuffle_ps(xp01, xp23, (0) + (2 << 2) + (0 << 4) + (2 << 6)); - __m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6)); + __m128 xp23 = _mm_castpd_ps(_mm_set_pd(*p3, *p2)); + __m128 xp01 = _mm_castpd_ps(_mm_set_pd(*p1, *p0)); + __m128 vxm = _mm_shuffle_ps(xp01, xp23, (0) + (2 << 2) + (0 << 4) + (2 << 6)); + __m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6)); #endif - IVec i(u.vec); - IVec vlem = vz < vxm; - IVec vlep = vz < vxp; - i = i + vlem + vlep; - i.store(pr); - } + IVec i(u.vec); + IVec vlem = (vz < vxm); + IVec vlep = (vz < vxp); + i = i + vlem + vlep; + i.store(pr); + } - FORCE_INLINE - //NO_INLINE - void resolve(const FVec& vz, const IVec& bidx, uint32 *pr) const - { - const uint32* buckets = reinterpret_cast(base_t::data.buckets); - const double *xi = base_t::data.xi; + FORCE_INLINE + // NO_INLINE + void resolve(const FVec &vz, const IVec &bidx, uint32 *pr) const + { + const uint32 *buckets = reinterpret_cast(base_t::data.buckets); + const double *xi = base_t::data.xi; - uint32 b1 = buckets[bidx.get1()]; - uint32 b0 = buckets[bidx.get0()]; + uint32 b1 = buckets[bidx.get1()]; + uint32 b0 = buckets[bidx.get0()]; - const double *p1 = &xi[b1]; - const double *p0 = &xi[b0]; + const double *p1 = &xi[b1]; + const double *p0 = &xi[b0]; - // read pairs ( X(t-1), X(t) ) - __m128d vx1 = _mm_loadu_pd(p1); - __m128d vx0 = _mm_loadu_pd(p0); + // read pairs ( X(t-1), X(t) ) + __m128d vx1 = _mm_loadu_pd(p1); + __m128d vx0 = _mm_loadu_pd(p0); - // build: - // { X(t(0)-1), X(t(1)-1) } - // { X(t(0)), X(t(1)) } - __m128d vxm = _mm_shuffle_pd(vx0, vx1, 0); - __m128d vxp = _mm_shuffle_pd(vx0, vx1, 3); + // build: + // { X(t(0)-1), X(t(1)-1) } + // { X(t(0)), X(t(1)) } + __m128d vxm = _mm_shuffle_pd(vx0, vx1, 0); + __m128d vxp = _mm_shuffle_pd(vx0, vx1, 3); - IVec i(b1, b0); - IVec vlem = (vz < vxm); - IVec vlep = (vz < vxp); - i = i + vlem + vlep; + IVec i(b1, b0); + IVec vlem = (vz < vxm); + IVec vlep = (vz < vxp); + i = i + vlem + vlep; - union { - __m128i vec; - uint32 ui32[4]; - } u; - u.vec = i; - pr[0] = u.ui32[0]; - pr[1] = u.ui32[2]; - } + union + { + __m128i vec; + uint32 ui32[4]; + } u; + u.vec = i; + pr[0] = u.ui32[0]; + pr[1] = u.ui32[2]; + } #ifdef USE_AVX - FORCE_INLINE - //NO_INLINE - void resolve(const FVec& vz, const IVec& bidx, uint32 *pr) const - { - const uint32* buckets = reinterpret_cast(base_t::data.buckets); - const float *xi = base_t::data.xi; + FORCE_INLINE + // NO_INLINE + void resolve(const FVec &vz, const IVec &bidx, uint32 *pr) const + { + const uint32 *buckets = reinterpret_cast(base_t::data.buckets); + const float *xi = base_t::data.xi; -#if 0 // use gather instructions +#if 0 // use gather instructions IVec idxm; idxm.setidx(buckets, bidx); @@ -159,21 +162,22 @@ private: #else // do not use gather instrucions - union U { - __m256i vec; - uint32 ui32[8]; - } u; + union U + { + __m256i vec; + uint32 ui32[8]; + } u; - // read indices t + // read indices t - const double *p7 = reinterpret_cast(&xi[(u.ui32[7] = buckets[bidx.get7()])]); - const double *p6 = reinterpret_cast(&xi[(u.ui32[6] = buckets[bidx.get6()])]); - const double *p5 = reinterpret_cast(&xi[(u.ui32[5] = buckets[bidx.get5()])]); - const double *p4 = reinterpret_cast(&xi[(u.ui32[4] = buckets[bidx.get4()])]); - const double *p3 = reinterpret_cast(&xi[(u.ui32[3] = buckets[bidx.get3()])]); - const double *p2 = reinterpret_cast(&xi[(u.ui32[2] = buckets[bidx.get2()])]); - const double *p1 = reinterpret_cast(&xi[(u.ui32[1] = buckets[bidx.get1()])]); - const double *p0 = reinterpret_cast(&xi[(u.ui32[0] = buckets[bidx.get0()])]); + const double *p7 = reinterpret_cast(&xi[(u.ui32[7] = buckets[bidx.get7()])]); + const double *p6 = reinterpret_cast(&xi[(u.ui32[6] = buckets[bidx.get6()])]); + const double *p5 = reinterpret_cast(&xi[(u.ui32[5] = buckets[bidx.get5()])]); + const double *p4 = reinterpret_cast(&xi[(u.ui32[4] = buckets[bidx.get4()])]); + const double *p3 = reinterpret_cast(&xi[(u.ui32[3] = buckets[bidx.get3()])]); + const double *p2 = reinterpret_cast(&xi[(u.ui32[2] = buckets[bidx.get2()])]); + const double *p1 = reinterpret_cast(&xi[(u.ui32[1] = buckets[bidx.get1()])]); + const double *p0 = reinterpret_cast(&xi[(u.ui32[0] = buckets[bidx.get0()])]); #if 0 // perform 8 loads in double precision @@ -210,96 +214,93 @@ private: IVec ip(u.vec); -#else // use __mm256_set_pd +#else // use __mm256_set_pd - // read pairs ( X(t-1), X(t) ) - __m256 x0145 = _mm256_castpd_ps(_mm256_set_pd(*p5, *p4, *p1, *p0)); // { x0(t-1), x0(t), x1(t-1), x1(t), x4(t-1), x4(t), x5(t-1), x5(t) } - __m256 x2367 = _mm256_castpd_ps(_mm256_set_pd(*p7, *p6, *p3, *p2)); // { x2(t-1), x2(t), x3(t-1), x3(t), x6(t-1), x6(t), x7(t-1), x7(t) } + // read pairs ( X(t-1), X(t) ) + __m256 x0145 = _mm256_castpd_ps(_mm256_set_pd(*p5, *p4, *p1, *p0)); // { x0(t-1), x0(t), x1(t-1), x1(t), x4(t-1), x4(t), x5(t-1), x5(t) } + __m256 x2367 = _mm256_castpd_ps(_mm256_set_pd(*p7, *p6, *p3, *p2)); // { x2(t-1), x2(t), x3(t-1), x3(t), x6(t-1), x6(t), x7(t-1), x7(t) } - // { x0(t-1), x1(t-1), x2(t-1), 3(t-1, x4(t-1), x5(t-1), x6(t-1), xt(t-1) } - FVec vxm = _mm256_shuffle_ps(x0145, x2367, 0 + (2 << 2) + (0 << 4) + (2 << 6) ); - // { x0(t), x1(t), x2(t), 3(t, x4(t), x5(t), x6(t), xt(t) } - FVec vxp = _mm256_shuffle_ps(x0145, x2367, 1 + (3 << 2) + (1 << 4) + (3 << 6) ); + // { x0(t-1), x1(t-1), x2(t-1), 3(t-1, x4(t-1), x5(t-1), x6(t-1), xt(t-1) } + FVec vxm = _mm256_shuffle_ps(x0145, x2367, 0 + (2 << 2) + (0 << 4) + (2 << 6)); + // { x0(t), x1(t), x2(t), 3(t, x4(t), x5(t), x6(t), xt(t) } + FVec vxp = _mm256_shuffle_ps(x0145, x2367, 1 + (3 << 2) + (1 << 4) + (3 << 6)); - IVec ip(u.vec); + IVec ip(u.vec); #endif #endif - IVec vlem = vz < vxm; - IVec vlep = vz < vxp; - ip = ip + vlem + vlep; + IVec vlem = vz < vxm; + IVec vlep = vz < vxp; + ip = ip + vlem + vlep; - ip.store(pr); - } + ip.store(pr); + } + FORCE_INLINE + // NO_INLINE + void resolve(const FVec &vz, const IVec &bidx, uint32 *pr) const + { + union + { + __m256i vec; + uint64 ui64[4]; + } u; + const uint32 *buckets = reinterpret_cast(base_t::data.buckets); + const double *xi = base_t::data.xi; - FORCE_INLINE - //NO_INLINE - void resolve(const FVec& vz, const IVec& bidx, uint32 *pr) const - { - union { - __m256i vec; - uint64 ui64[4]; - } u; + // read indices t + const double *p3 = &xi[(u.ui64[3] = buckets[bidx.get3()])]; + const double *p2 = &xi[(u.ui64[2] = buckets[bidx.get2()])]; + const double *p1 = &xi[(u.ui64[1] = buckets[bidx.get1()])]; + const double *p0 = &xi[(u.ui64[0] = buckets[bidx.get0()])]; - const uint32* buckets = reinterpret_cast(base_t::data.buckets); - const double *xi = base_t::data.xi; + // read pairs ( X(t-1), X(t) ) + __m128d xp3 = _mm_loadu_pd(p3); + __m128d xp2 = _mm_loadu_pd(p2); + __m128d xp1 = _mm_loadu_pd(p1); + __m128d xp0 = _mm_loadu_pd(p0); - // read indices t - const double *p3 = &xi[(u.ui64[3] = buckets[bidx.get3()])]; - const double *p2 = &xi[(u.ui64[2] = buckets[bidx.get2()])]; - const double *p1 = &xi[(u.ui64[1] = buckets[bidx.get1()])]; - const double *p0 = &xi[(u.ui64[0] = buckets[bidx.get0()])]; + // build: + // { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) } + // { X(t(0)), X(t(1)), X(t(2)), X(t(3)) } + __m256d x02 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp0), xp2, 1); + __m256d x13 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp1), xp3, 1); + FVec vxm = _mm256_unpacklo_pd(x02, x13); + FVec vxp = _mm256_unpackhi_pd(x02, x13); - // read pairs ( X(t-1), X(t) ) - __m128d xp3 = _mm_loadu_pd(p3); - __m128d xp2 = _mm_loadu_pd(p2); - __m128d xp1 = _mm_loadu_pd(p1); - __m128d xp0 = _mm_loadu_pd(p0); + // __m128d h01m = _mm_shuffle_pd(xp0, xp1, 0); + // __m128d h23m = _mm_shuffle_pd(xp2, xp3, 0); + // __m128d h01p = _mm_shuffle_pd(xp0, xp1, 3); + // __m128d h23p = _mm_shuffle_pd(xp2, xp3, 3); + // FVec vxm = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01m), h23m, 1); + // FVec vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1); - // build: - // { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) } - // { X(t(0)), X(t(1)), X(t(2)), X(t(3)) } - __m256d x02 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp0), xp2, 1); - __m256d x13 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp1), xp3, 1); - FVec vxm = _mm256_unpacklo_pd(x02,x13); - FVec vxp = _mm256_unpackhi_pd(x02,x13); - - -// __m128d h01m = _mm_shuffle_pd(xp0, xp1, 0); -// __m128d h23m = _mm_shuffle_pd(xp2, xp3, 0); -// __m128d h01p = _mm_shuffle_pd(xp0, xp1, 3); -// __m128d h23p = _mm_shuffle_pd(xp2, xp3, 3); -// FVec vxm = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01m), h23m, 1); -// FVec vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1); - - IVec i(u.vec); - IVec vlem = vz < vxm; - IVec vlep = vz < vxp; - i = i + vlem + vlep; - i.extractLo32s().store(pr); - } + IVec i(u.vec); + IVec vlem = vz < vxm; + IVec vlep = vz < vxp; + i = i + vlem + vlep; + i.extractLo32s().store(pr); + } #endif -public: + public: + AlgoVecBase(const T *x, const uint32 n) : base_t(x, n) {} - AlgoVecBase(const T* x, const uint32 n) : base_t(x, n) {} + void initConstants(Constants &cst) const + { + cst.vscaler.setN(base_t::data.scaler); + cst.vcst0.setN(base_t::data.cst0); + cst.one.setN(uint32(1)); + } - void initConstants(Constants& cst) const - { - cst.vscaler.setN(base_t::data.scaler); - cst.vcst0.setN(base_t::data.cst0); - cst.one.setN(uint32(1)); - } - - void vectorial(uint32 *pr, const T *pz, const Constants& cst) const - { - fVec vz(pz); - resolve(vz, base_t::fun_t::f(cst.vscaler, cst.vcst0, vz), pr); - } -}; -} // namespace Details + void vectorial(uint32 *pr, const T *pz, const Constants &cst) const + { + fVec vz(pz); + resolve(vz, base_t::fun_t::f(cst.vscaler, cst.vcst0, vz), pr); + } + }; + } // namespace Details } // namespace BinSearch