should be hippified, and all CUDA checks cleaned up; Makefile not updated yet

broncotc 2022-11-23 17:52:19 -08:00
parent c059bd2848
commit 2dcf38289d
12 changed files with 377 additions and 925 deletions

View File

@@ -20,9 +20,10 @@ CSRC := $(ROOT_DIR)/csrc
BUILD_DIR:= $(ROOT_DIR)/build
FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu
FILES_HIP := $(CSRC)/ops.cu $(CSRC)/kernels.cu
FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c
INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include
INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
# NVIDIA NVCC compilation flags
@@ -45,6 +46,7 @@ CC_CUDA10x += -gencode arch=compute_75,code=sm_75
CC_CUDA110 := -gencode arch=compute_75,code=sm_75
CC_CUDA110 += -gencode arch=compute_80,code=sm_80
# CC_CUDA11x := -gencode arch=compute_52,code=sm_52
CC_CUDA11x := -gencode arch=compute_75,code=sm_75
CC_CUDA11x += -gencode arch=compute_80,code=sm_80
CC_CUDA11x += -gencode arch=compute_86,code=sm_86
@@ -52,22 +54,23 @@ CC_CUDA11x += -gencode arch=compute_86,code=sm_86
CC_cublasLt110 := -gencode arch=compute_75,code=sm_75
CC_cublasLt110 += -gencode arch=compute_80,code=sm_80
# CC_cublasLt111 := -gencode arch=compute_52,code=sm_52
CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
all: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
cuda92: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)
cuda10x_nomatmul: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
cuda10x_nomatmul: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)
@@ -112,9 +115,9 @@ $(BUILD_DIR):
mkdir -p build
mkdir -p dependencies
$(ROOT_DIR)/dependencies/cub:
git clone https://github.com/NVlabs/cub $(ROOT_DIR)/dependencies/cub
cd dependencies/cub; git checkout 1.11.0
# $(ROOT_DIR)/dependencies/cub:
# git clone https://github.com/NVlabs/cub $(ROOT_DIR)/dependencies/cub
# cd dependencies/cub; git checkout 1.11.0
clean:
rm build/*

View File

@@ -12,43 +12,10 @@ class CUDASetup(object):
raise RuntimeError("Call get_instance() instead")
def generate_instructions(self):
if self.cuda is None:
self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.')
self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.')
self.add_log_entry('CUDA SETUP: Solution 2): If you do not have sudo rights, you can do the following:')
self.add_log_entry('CUDA SETUP: Solution 2a): Find the cuda library via: find / -name libcuda.so 2>/dev/null')
self.add_log_entry('CUDA SETUP: Solution 2b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_2a')
self.add_log_entry('CUDA SETUP: Solution 2c): For a permanent solution add the export from 2b into your .bashrc file, located at ~/.bashrc')
return
if self.cudart_path is None:
self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA runtime library was not detected.')
self.add_log_entry('CUDA SETUP: Solution 1: To solve the issue the libcudart.so location needs to be added to the LD_LIBRARY_PATH variable')
self.add_log_entry('CUDA SETUP: Solution 1a): Find the cuda runtime library via: find / -name libcudart.so 2>/dev/null')
self.add_log_entry('CUDA SETUP: Solution 1b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_1a')
self.add_log_entry('CUDA SETUP: Solution 1c): For a permanent solution add the export from 1b into your .bashrc file, located at ~/.bashrc')
self.add_log_entry('CUDA SETUP: Solution 2: If no library was found in step 1a) you need to install CUDA.')
self.add_log_entry('CUDA SETUP: Solution 2a): Download CUDA install script: wget https://github.com/TimDettmers/bitsandbytes/blob/main/cuda_install.sh')
self.add_log_entry('CUDA SETUP: Solution 2b): Install desired CUDA version to desired location. The syntax is bash cuda_install.sh CUDA_VERSION PATH_TO_INSTALL_INTO.')
self.add_log_entry('CUDA SETUP: Solution 2b): For example, "bash cuda_install.sh 113 ~/local/" will download CUDA 11.3 and install into the folder ~/local')
return
make_cmd = f'CUDA_VERSION={self.cuda_version_string}'
if len(self.cuda_version_string) < 3:
make_cmd += ' make cuda92'
elif self.cuda_version_string == '110':
make_cmd += ' make cuda110'
elif self.cuda_version_string[:2] == '11' and int(self.cuda_version_string[2]) > 0:
make_cmd += ' make cuda11x'
has_cublaslt = self.cc in ["7.5", "8.0", "8.6"]
if not has_cublaslt:
make_cmd += '_nomatmul'
self.add_log_entry('CUDA SETUP: Something unexpected happened. Please compile from source:')
self.add_log_entry('git clone git@github.com:TimDettmers/bitsandbytes.git')
self.add_log_entry('cd bitsandbytes')
self.add_log_entry(make_cmd)
self.add_log_entry("<make_cmd here, commented out>")
self.add_log_entry('python setup.py install')
def initialize(self):
@@ -60,37 +27,13 @@ class CUDASetup(object):
self.initialized = True
self.cuda_setup_log = []
from .cuda_setup.main import evaluate_cuda_setup
binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()
self.cudart_path = cudart_path
self.cuda = cuda
self.cc = cc
self.cuda_version_string = cuda_version_string
binary_name = "libbitsandbytes_hip.so"
package_dir = Path(__file__).parent
binary_path = package_dir / binary_name
try:
if not binary_path.exists():
self.add_log_entry(f"CUDA SETUP: Required library version not found: {binary_name}. Maybe you need to compile it from source?")
legacy_binary_name = "libbitsandbytes.so"
self.add_log_entry(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
binary_path = package_dir / legacy_binary_name
if not binary_path.exists():
self.add_log_entry('')
self.add_log_entry('='*48 + 'ERROR' + '='*37)
self.add_log_entry('CUDA SETUP: CUDA detection failed! Possible reasons:')
self.add_log_entry('1. CUDA driver not installed')
self.add_log_entry('2. CUDA not installed')
self.add_log_entry('3. You have multiple conflicting CUDA libraries')
self.add_log_entry('4. Required library not pre-compiled for this bitsandbytes release!')
self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
self.add_log_entry('='*80)
self.add_log_entry('')
self.generate_instructions()
self.print_log_stack()
raise Exception('CUDA SETUP: Setup Failed!')
self.lib = ct.cdll.LoadLibrary(binary_path)
raise Exception('CUDA SETUP: Setup Failed!')
else:
self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
self.lib = ct.cdll.LoadLibrary(binary_path)

View File

@@ -1,2 +0,0 @@
from .paths import CUDA_RUNTIME_LIB, extract_candidate_paths, determine_cuda_runtime_lib_path
from .main import evaluate_cuda_setup

View File

@@ -1,51 +0,0 @@
import os
from typing import Dict
def to_be_ignored(env_var: str, value: str) -> bool:
ignorable = {
"PWD", # PWD: this is how the shell keeps track of the current working dir
"OLDPWD",
"SSH_AUTH_SOCK", # SSH stuff, therefore unrelated
"SSH_TTY",
"HOME", # Linux shell default
"TMUX", # Terminal Multiplexer
"XDG_DATA_DIRS", # XDG: Desktop environment stuff
"XDG_RUNTIME_DIR",
"MAIL", # something related to emails
"SHELL", # binary for currently invoked shell
"DBUS_SESSION_BUS_ADDRESS", # hardware related
"PATH", # this is for finding binaries, not libraries
"LESSOPEN", # related to the `less` command
"LESSCLOSE",
"_", # current Python interpreter
}
return env_var in ignorable
def might_contain_a_path(candidate: str) -> bool:
return "/" in candidate
def is_active_conda_env(env_var: str) -> bool:
return "CONDA_PREFIX" == env_var
def is_other_conda_env_var(env_var: str) -> bool:
return "CONDA" in env_var
def is_relevant_candidate_env_var(env_var: str, value: str) -> bool:
return is_active_conda_env(env_var) or (
might_contain_a_path(value) and not
is_other_conda_env_var(env_var) and not
to_be_ignored(env_var, value)
)
def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]:
return {
env_var: value
for env_var, value in os.environ.items()
if is_relevant_candidate_env_var(env_var, value)
}

View File

@@ -1,163 +0,0 @@
"""
extract factors the build is dependent on:
[X] compute capability
[ ] TODO: Q - What if we have multiple GPUs of different makes?
- CUDA version
- Software:
- CPU-only: only CPU quantization functions (no optimizer, no matrix multiplication)
- CuBLAS-LT: full-build 8-bit optimizer
- no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
evaluation:
- if paths faulty, return meaningful error
- else:
- determine CUDA version
- determine capabilities
- based on that set the default path
"""
import ctypes
import torch
from .paths import determine_cuda_runtime_lib_path
from bitsandbytes.cextension import CUDASetup
def check_cuda_result(cuda, result_val):
# 3. Check for CUDA errors
if result_val != 0:
error_str = ctypes.c_char_p()
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
CUDASetup.get_instance().add_log_entry(f"CUDA exception! Error code: {error_str.value.decode()}")
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
def get_cuda_version(cuda, cudart_path):
if cuda is None: return None
try:
cudart = ctypes.CDLL(cudart_path)
except OSError:
CUDASetup.get_instance().add_log_entry(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
return None
version = ctypes.c_int()
check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version)))
version = int(version.value)
major = version//1000
minor = (version-(major*1000))//10
if major < 11:
CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA versions lower than 11 are currently not supported for LLM.int8(). You will only be able to use 8-bit optimizers and quantization routines!!')
return f'{major}{minor}'
def get_cuda_lib_handle():
# 1. find libcuda.so library (GPU driver) (/usr/lib)
try:
cuda = ctypes.CDLL("libcuda.so")
except OSError:
CUDASetup.get_instance().add_log_entry('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
return None
check_cuda_result(cuda, cuda.cuInit(0))
return cuda
def get_compute_capabilities(cuda):
"""
1. find libcuda.so library (GPU driver) (/usr/lib)
init_device -> init variables -> call function by reference
2. call extern C function to determine CC
(https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
3. Check for CUDA errors
https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
# bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
"""
nGpus = ctypes.c_int()
cc_major = ctypes.c_int()
cc_minor = ctypes.c_int()
device = ctypes.c_int()
check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
ccs = []
for i in range(nGpus.value):
check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
ref_major = ctypes.byref(cc_major)
ref_minor = ctypes.byref(cc_minor)
# 2. call extern C function to determine CC
check_cuda_result(cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device))
ccs.append(f"{cc_major.value}.{cc_minor.value}")
return ccs
# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
def get_compute_capability(cuda):
"""
Extracts the highest compute capability from all available GPUs, as compute
capabilities are downwards compatible. If no GPUs are detected, it returns
None.
"""
if cuda is None: return None
# TODO: handle different compute capabilities; for now, take the max
ccs = get_compute_capabilities(cuda)
if ccs: return ccs[-1]
def evaluate_cuda_setup():
# we remove this for now and see how things go
#print('')
#print('='*35 + 'BUG REPORT' + '='*35)
#print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
#print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
#print('='*80)
if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None, None
cuda_setup = CUDASetup.get_instance()
cudart_path = determine_cuda_runtime_lib_path()
cuda = get_cuda_lib_handle()
cc = get_compute_capability(cuda)
cuda_version_string = get_cuda_version(cuda, cudart_path)
failure = False
if cudart_path is None:
failure = True
cuda_setup.add_log_entry("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!", is_warning=True)
else:
cuda_setup.add_log_entry((f"CUDA SETUP: CUDA runtime path found: {cudart_path}"))
if cc == '' or cc is None:
failure = True
cuda_setup.add_log_entry("WARNING: No GPU detected! Check your CUDA paths. Proceeding to load CPU-only library...", is_warning=True)
else:
cuda_setup.add_log_entry(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
if cuda is None:
failure = True
else:
cuda_setup.add_log_entry(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
# 7.5 is the minimum CC for cublasLt
has_cublaslt = cc in ["7.5", "8.0", "8.6"]
# TODO:
# (1) CUDA missing cases (no CUDA installed but CUDA driver (nvidia-smi) accessible)
# (2) Multiple CUDA versions installed
# we use ls -l instead of nvcc to determine the cuda version
# since most installations will have the libcudart.so installed, but not the compiler
if failure:
binary_name = "libbitsandbytes_cpu.so"
elif has_cublaslt:
binary_name = f"libbitsandbytes_cuda{cuda_version_string}.so"
else:
"if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
binary_name = f"libbitsandbytes_cuda{cuda_version_string}_nocublaslt.so"
return binary_name, cudart_path, cuda, cc, cuda_version_string
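Aside: the deleted helper above reads compute capability by dlopening libcuda.so through ctypes. The HIP build no longer needs this query (the binary name is hardcoded to libbitsandbytes_hip.so), but for reference the ROCm-side equivalent goes through the HIP runtime. A minimal C++ sketch — an illustration, not part of this commit:

#include <hip/hip_runtime.h>
#include <cstdio>

// Hypothetical HIP analogue of get_compute_capabilities():
// enumerate devices and print their architecture info.
int main() {
    int nGpus = 0;
    if (hipGetDeviceCount(&nGpus) != hipSuccess) return 1;
    for (int i = 0; i < nGpus; i++) {
        hipDeviceProp_t prop;
        if (hipGetDeviceProperties(&prop, i) != hipSuccess) continue;
        // On AMD GPUs the gfx arch name (e.g. "gfx90a") is the meaningful
        // analogue of NVIDIA's major.minor compute capability.
        printf("device %d: %s (major=%d, minor=%d)\n",
               i, prop.gcnArchName, prop.major, prop.minor);
    }
    return 0;
}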

View File

@@ -1,118 +0,0 @@
import errno
from pathlib import Path
from typing import Set, Union
from bitsandbytes.cextension import CUDASetup
from .env_vars import get_potentially_lib_path_containing_env_vars
CUDA_RUNTIME_LIB: str = "libcudart.so"
def extract_candidate_paths(paths_list_candidate: str) -> Set[Path]:
return {Path(ld_path) for ld_path in paths_list_candidate.split(":") if ld_path}
def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
existent_directories: Set[Path] = set()
for path in candidate_paths:
try:
if path.exists():
existent_directories.add(path)
except OSError as exc:
if exc.errno != errno.ENAMETOOLONG:
raise exc
non_existent_directories: Set[Path] = candidate_paths - existent_directories
if non_existent_directories:
CUDASetup.get_instance().add_log_entry("WARNING: The following directories listed in your path were found to "
f"be non-existent: {non_existent_directories}", is_warning=True)
return existent_directories
def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]:
return {
path / CUDA_RUNTIME_LIB
for path in candidate_paths
if (path / CUDA_RUNTIME_LIB).is_file()
}
def resolve_paths_list(paths_list_candidate: str) -> Set[Path]:
"""
Searches a given environmental var for the CUDA runtime library,
i.e. `libcudart.so`.
"""
return remove_non_existent_dirs(extract_candidate_paths(paths_list_candidate))
def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]:
return get_cuda_runtime_lib_paths(
resolve_paths_list(paths_list_candidate)
)
def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
if len(results_paths) > 1:
warning_msg = (
f"Found duplicate {CUDA_RUNTIME_LIB} files: {results_paths}.. "
"We'll flip a coin and try one of these, in order to fail forward.\n"
"Either way, this might cause trouble in the future:\n"
"If you get `CUDA error: invalid device function` errors, the above "
"might be the cause and the solution is to make sure only one "
f"{CUDA_RUNTIME_LIB} in the paths that we search based on your env.")
CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True)
def determine_cuda_runtime_lib_path() -> Union[Path, None]:
"""
Searches for a CUDA installation, in the following order of priority:
1. active conda env
2. LD_LIBRARY_PATH
3. any other env vars, while ignoring those that
- are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`)
- don't contain the path separator `/`
If multiple libraries are found in part 3, we optimistically try one,
while giving a warning message.
"""
candidate_env_vars = get_potentially_lib_path_containing_env_vars()
if "CONDA_PREFIX" in candidate_env_vars:
conda_libs_path = Path(candidate_env_vars["CONDA_PREFIX"]) / "lib"
conda_cuda_libs = find_cuda_lib_in(str(conda_libs_path))
warn_in_case_of_duplicates(conda_cuda_libs)
if conda_cuda_libs:
return next(iter(conda_cuda_libs))
CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain '
f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True)
if "LD_LIBRARY_PATH" in candidate_env_vars:
lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"])
if lib_ld_cuda_libs:
return next(iter(lib_ld_cuda_libs))
warn_in_case_of_duplicates(lib_ld_cuda_libs)
CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain '
f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True)
remaining_candidate_env_vars = {
env_var: value for env_var, value in candidate_env_vars.items()
if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"}
}
cuda_runtime_libs = set()
for env_var, value in remaining_candidate_env_vars.items():
cuda_runtime_libs.update(find_cuda_lib_in(value))
if len(cuda_runtime_libs) == 0:
CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching /usr/local/cuda/lib64...')
cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64'))
warn_in_case_of_duplicates(cuda_runtime_libs)
return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None

View File

@@ -7,4 +7,4 @@
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n);
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n);
#endif
#endif

View File

@@ -3,6 +3,8 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include <hip/hip_runtime.h>
#include <kernels.cuh>
#include <cub/block/block_radix_sort.cuh>
#include <cub/warp/warp_reduce.cuh>
@@ -10,7 +12,7 @@
#include <cub/block/block_discontinuity.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/cub.cuh>
#include <hipcub/hipcub.hpp>
#include <math_constants.h>
#define HLF_MAX 65504
@@ -232,9 +234,9 @@ __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index
template<typename T, int BLOCK_SIZE, int NUM_MAX>
__global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n)
{
typedef cub::WarpReduce<T> WarpReduce;
typedef hipcub::WarpReduce<T> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage;
typedef cub::BlockLoad<T, BLOCK_SIZE/8 , 8, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<T, BLOCK_SIZE/8 , 8, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
__shared__ typename LoadT::TempStorage loadt;
const int warp_idx = threadIdx.x/32;
@@ -324,8 +326,8 @@ __global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const f
T vals[NUM_ESTIMATE];
typedef cub::BlockRadixSort<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::NullType, 4, true, cub::BLOCK_SCAN_RAKING> BlockRadixSort;
typedef cub::BlockLoad<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef hipcub::BlockRadixSort<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::NullType, 4, true, cub::BLOCK_SCAN_RAKING> BlockRadixSort;
typedef hipcub::BlockLoad<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
__shared__ union {
typename LoadFloat::TempStorage loadf;
@@ -391,8 +393,8 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
unsigned char qvals[NUM];
//const int lane_id = threadIdx.x % 2;
typedef cub::BlockLoad<float, TH, NUM, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef cub::BlockStore<unsigned char, TH, NUM, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef hipcub::BlockLoad<float, TH, NUM, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef hipcub::BlockStore<unsigned char, TH, NUM, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
__shared__ typename LoadFloat::TempStorage loadf;
__shared__ typename StoreChar::TempStorage storec;
@@ -442,10 +444,10 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
float local_abs_max = 0.0f;
int local_rand_idx = 0;
typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef hipcub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
typedef hipcub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
__shared__ typename LoadT::TempStorage loadt;
__shared__ typename LoadFloat::TempStorage loadf;
@@ -521,8 +523,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
unsigned char qvals[NUM_PER_TH];
float local_abs_max = -FLT_MAX;
typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef cub::BlockStore<T, THREADS, NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
typedef hipcub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef hipcub::BlockStore<T, THREADS, NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ typename LoadChar::TempStorage loadchar;
__shared__ typename StoreT::TempStorage storet;
@@ -592,9 +594,9 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
const float correction1 = 1.0f/(1.0f - powf(beta1, step));
const float correction2 = 1.0f/(1.0f - powf(beta2, step));
typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
typedef hipcub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef hipcub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
__shared__ union {
typename Load::TempStorage load;
@@ -681,11 +683,11 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
}
else{ update_scale = 1.0f; }
typedef cub::BlockLoad<T, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef cub::BlockStore<T, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> Store;
typedef hipcub::BlockLoad<T, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef hipcub::BlockStore<T, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> Store;
typedef cub::BlockLoad<float, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef cub::BlockStore<float, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreFloat;
typedef hipcub::BlockLoad<float, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef hipcub::BlockStore<float, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreFloat;
__shared__ union {
typename Load::TempStorage load;
@@ -755,9 +757,9 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float s1_vals[NUM_VALS];
typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
typedef hipcub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef hipcub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
__shared__ union {
typename Load::TempStorage load;
@@ -843,11 +845,11 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
float s1_vals[NUM_PER_THREAD];
typedef cub::BlockLoad<T, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef cub::BlockStore<T, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> Store;
typedef hipcub::BlockLoad<T, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef hipcub::BlockStore<T, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> Store;
typedef cub::BlockLoad<float, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef cub::BlockStore<float, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreFloat;
typedef hipcub::BlockLoad<float, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef hipcub::BlockStore<float, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreFloat;
__shared__ union {
typename Load::TempStorage load;
@@ -939,9 +941,9 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
unsigned char m_c1[NUM8BIT];
unsigned char r_c2[NUM8BIT];
typedef cub::BlockLoad<T, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef cub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
typedef cub::BlockReduce<float, NUM_THREADS> BlockReduce;
typedef hipcub::BlockLoad<T, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
typedef hipcub::BlockReduce<float, NUM_THREADS> BlockReduce;
__shared__ union {
@@ -1068,11 +1070,11 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
unsigned char c2s[NUM_PER_THREAD2];
T p_vals[NUM_PER_THREAD2];
T g_vals[NUM_PER_THREAD2];
typedef cub::BlockLoad<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef cub::BlockLoad<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef hipcub::BlockLoad<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef cub::BlockStore<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef cub::BlockStore<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
typedef hipcub::BlockStore<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef hipcub::BlockStore<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ float smem_quantiles1[256];
__shared__ float smem_quantiles2[256];
@@ -1176,9 +1178,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
T g_vals[NUM8BIT];
unsigned char m_c1[NUM8BIT];
typedef cub::BlockLoad<T, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef cub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
typedef cub::BlockReduce<float, NUM_THREADS> BlockReduce;
typedef hipcub::BlockLoad<T, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
typedef hipcub::BlockReduce<float, NUM_THREADS> BlockReduce;
__shared__ union {
@@ -1271,11 +1273,11 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
unsigned char c1s[NUM_PER_THREAD2];
T p_vals[NUM_PER_THREAD2];
T g_vals[NUM_PER_THREAD2];
typedef cub::BlockLoad<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef cub::BlockLoad<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef hipcub::BlockLoad<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef cub::BlockStore<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef cub::BlockStore<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
typedef hipcub::BlockStore<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef hipcub::BlockStore<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ float smem_quantiles1[256];
@@ -1353,8 +1355,8 @@ __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int st
const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
int valid_items = 0;
typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
typedef hipcub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
__shared__ typename BlockReduce::TempStorage reduce;
@@ -1426,16 +1428,16 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
unsigned char c1s[N_PER_TH];
unsigned char c2s[N_PER_TH];
T g_vals[N_PER_TH];
typedef cub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef cub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef hipcub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef cub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef cub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
typedef hipcub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef hipcub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ float smem_quantiles1[LANES][257];
__shared__ float smem_quantiles2[LANES][257];
typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1;
typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce2;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce2;
__shared__ typename BlockReduce1::TempStorage reduce1;
__shared__ typename BlockReduce2::TempStorage reduce2;
__shared__ float smem_exchange1[1];
@@ -1599,14 +1601,14 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
T g_vals[N_PER_TH];
T p_vals[N_PER_TH];
typedef cub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef cub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef hipcub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef cub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef cub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
typedef hipcub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef hipcub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ float smem_quantiles1[LANES][257];
typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1;
__shared__ typename BlockReduce1::TempStorage reduce1;
__shared__ float smem_exchange1[1];
@@ -1756,10 +1758,10 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
const int base_idx = (base_row*cols) + base_col;
const int items_per_load = ITEMS_PER_THREAD*THREADS;
typedef cub::BlockLoad<T, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadT;
typedef cub::BlockReduce<float, THREADS> BlockRowReduce;
typedef cub::BlockReduce<int, THREADS> BlockRowSum;
typedef cub::BlockExchange<float, THREADS, ITEMS_PER_THREAD> BlockExchange;
typedef hipcub::BlockLoad<T, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadT;
typedef hipcub::BlockReduce<float, THREADS> BlockRowReduce;
typedef hipcub::BlockReduce<int, THREADS> BlockRowSum;
typedef hipcub::BlockExchange<float, THREADS, ITEMS_PER_THREAD> BlockExchange;
__shared__ union {
typename BlockExchange::TempStorage exchange;
@@ -1945,8 +1947,8 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
float local_rowStats[ITEMS_PER_THREAD];
__shared__ float smem_rowStats[SUBTILE_ROWS];
typedef cub::BlockLoad<int, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_DIRECT> LoadInt32;
typedef cub::BlockExchange<int, THREADS, ITEMS_PER_THREAD> ExchangeInt32;
typedef hipcub::BlockLoad<int, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_DIRECT> LoadInt32;
typedef hipcub::BlockExchange<int, THREADS, ITEMS_PER_THREAD> ExchangeInt32;
__shared__ typename LoadInt32::TempStorage loadint32;
__shared__ typename ExchangeInt32::TempStorage exchangeint32;
@@ -2033,9 +2035,9 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int S
const int base_idx = (base_row*cols) + base_col;
const int items_per_load = ITEMS_PER_THREAD*THREADS;
typedef cub::BlockLoad<half, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadHalf;
typedef hipcub::BlockLoad<half, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadHalf;
__shared__ typename LoadHalf::TempStorage loadhalf;
typedef cub::BlockStore<char, THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_VECTORIZE> StoreInt8;
typedef hipcub::BlockStore<char, THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_VECTORIZE> StoreInt8;
__shared__ typename StoreInt8::TempStorage storeint8;
__shared__ float smem_row_stats[TILE_ROWS];
@@ -2168,7 +2170,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
// so that we can have contiguous stores
__shared__ char smem_data[32*33*ITEMS_PER_THREAD];
char local_data[ITEMS_PER_THREAD];
typedef cub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
typedef hipcub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
// we load row after row from the base_position
// Load data row by row
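Note that many of the converted typedefs in this file still pass cub::-qualified algorithm enums (e.g. hipcub::BlockLoad<..., cub::BLOCK_LOAD_WARP_TRANSPOSE>). On the AMD backend hipCUB does not provide a cub namespace, so those enums presumably need the hipcub:: prefix as well. A minimal fully-qualified sketch (illustration only, assuming hipCUB's cub-mirroring API):

#include <hip/hip_runtime.h>
#include <hipcub/hipcub.hpp>

// Hypothetical kernel: both the collective class templates AND the
// algorithm enums live in the hipcub:: namespace.
template <int THREADS, int ITEMS>
__global__ void kBlockSum(const float *in, float *out, int n)
{
  typedef hipcub::BlockLoad<float, THREADS, ITEMS, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadF;
  typedef hipcub::BlockReduce<float, THREADS> Reduce;
  __shared__ union {
    typename LoadF::TempStorage load;
    typename Reduce::TempStorage reduce;
  } smem;

  // Assumes the grid is sized so that base < n for every block.
  const int base = blockIdx.x * THREADS * ITEMS;
  float vals[ITEMS];
  // Guarded load: elements past n are filled with 0.0f.
  LoadF(smem.load).Load(in + base, vals, n - base, 0.0f);
  __syncthreads();
  const float sum = Reduce(smem.reduce).Sum(vals);
  if (threadIdx.x == 0)
    out[blockIdx.x] = sum;
}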

View File

@@ -3,8 +3,10 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include <hip/hip_runtime.h>
#include <float.h>
#include <ops.cuh>
#include "ops.cuh"
#ifndef kernels
#define kernels

View File

@@ -3,14 +3,17 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include <ops.cuh>
#include <kernels.cuh>
#include <hip/hip_runtime.h>
#include "ops.cuh"
#include "kernels.cuh"
#include <cub/device/device_scan.cuh>
#include <limits>
#include <BinSearch.h>
// #include <BinSearch.h>
#include <AAlloc.h>
#include <BinAlgo.h>
#include <cassert>
#include <common.h>
// #include <common.h>
using namespace BinSearch;
using std::cout;
@@ -22,16 +25,16 @@ void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *sr
int num_blocks = n/threads;
num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1;
kHistogramScatterAdd2D<<<num_blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n)
{
int num_blocks = n/4096;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float)));
CUDA_CHECK_RETURN(hipMemset(code, 0, 256*sizeof(float)));
kEstimateQuantiles<T><<<num_blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
void quantize(float *code, float *A, unsigned char *out, int n)
@@ -39,7 +42,7 @@ void quantize(float *code, float *A, unsigned char *out, int n)
int num_blocks = n/1024;
num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
kQuantize<<<num_blocks, 1024>>>(code, A, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
void dequantize(float *code, unsigned char *A, float *out, int n)
@@ -47,7 +50,7 @@ void dequantize(float *code, unsigned char *A, float *out, int n)
int num_blocks = n/1024;
num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
kDequantize<<<num_blocks, 1024>>>(code, A, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
@@ -73,7 +76,7 @@ template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A,
kQuantizeBlockwise<T, 64, 1, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
@@ -95,7 +98,7 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
else if(blocksize == 64)
kDequantizeBlockwise<T, 64, 64, 1><<<num_blocks, 64/1>>>(code, A, absmax, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
@@ -110,12 +113,12 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
case ADAM:
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
case MOMENTUM:
case RMSPROP:
@@ -123,13 +126,13 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
}
}
@@ -147,28 +150,28 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
int num_blocks = n/4096;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); }
if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); }
switch(OPTIMIZER)
{
case ADAM:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float)));
CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float)));
CUDA_CHECK_RETURN(hipMemset(new_max2, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
case MOMENTUM:
case RMSPROP:
case ADAGRAD:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
default:
break;
@@ -193,7 +196,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
case MOMENTUM:
case RMSPROP:
@@ -202,7 +205,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
}
}
@@ -213,9 +216,9 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
{
int num_blocks = n/2048;
num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1;
CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
CUDA_CHECK_RETURN(hipMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
kPercentileClipping<T, 2048, 4><<<num_blocks, 512>>>(g, gnorm_vec, step, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
@@ -224,17 +227,17 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in
const int fbeta = 0;
const void * alpha = &falpha;
const void * beta = &fbeta;
cublasStatus_t status;
hipblasStatus_t status;
status = cublasGemmEx(context->m_handle,
transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
status = hipblasGemmEx(context->m_handle,
transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
m, n, k,
alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta,
C, CUDA_R_32I, ldc,
CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta,
C, HIPBLAS_R_32I, ldc,
HIPBLAS_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
if (status != CUBLAS_STATUS_SUCCESS)
if (status != HIPBLAS_STATUS_SUCCESS)
{
std::cout << "CUBLAS ERROR: Status " << status << std::endl;
}
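Note: the converted gemmex call above still passes CUBLAS_GEMM_DEFAULT_TENSOR_OP as the algorithm; the strided variant below already uses HIPBLAS_GEMM_DEFAULT, which is the usual stand-in since hipBLAS has no tensor-op variant. A fully hipified call would presumably read (sketch only, same handle/flags/pointers as gemmex() above):

hipblasStatus_t status = hipblasGemmEx(context->m_handle,
                transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                m, n, k,
                alpha, A, HIPBLAS_R_8I, lda,
                B, HIPBLAS_R_8I, ldb, beta,
                C, HIPBLAS_R_32I, ldc,
                HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT); // no *_TENSOR_OP algo in hipBLAS
if (status != HIPBLAS_STATUS_SUCCESS)
  std::cout << "HIPBLAS ERROR: Status " << status << std::endl;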
@@ -248,7 +251,7 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i
const int fbeta = 0;
const void * alpha = &falpha;
const void * beta = &fbeta;
cublasStatus_t status;
hipblasStatus_t status;
//cout << transposeA << transposeB << endl;
//printf("%i %i %i\n", m,n,k);
@@ -256,15 +259,15 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i
//printf("%i %i %i\n", strideA, strideB, strideC);
//printf("%i\n", batchCount);
status = cublasGemmStridedBatchedEx(context->m_handle,
transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
status = hipblasGemmStridedBatchedEx(context->m_handle,
transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
m, n, k,
alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta,
C, CUDA_R_32I, ldc, (long long int)strideC, batchCount,
CUDA_R_32I, CUBLAS_GEMM_DEFAULT);
alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta,
C, HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount,
HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT);
if (status != CUBLAS_STATUS_SUCCESS)
if (status != HIPBLAS_STATUS_SUCCESS)
{
std::cout << "CUBLAS ERROR: Status " << status << std::endl;
}
@@ -276,42 +279,6 @@ int roundoff(int v, int d) {
}
#ifdef NO_CUBLASLT
#else
template<int ORDER> cublasLtOrder_t get_order()
{
switch(ORDER)
{
case ROW:
return CUBLASLT_ORDER_ROW;
break;
case COL:
return CUBLASLT_ORDER_COL;
break;
case COL32:
return CUBLASLT_ORDER_COL32;
break;
case COL_TURING:
return CUBLASLT_ORDER_COL4_4R2_8C;
break;
case COL_AMPERE:
return CUBLASLT_ORDER_COL32_2R_4R4;
break;
default:
break;
}
return CUBLASLT_ORDER_ROW;
}
template cublasLtOrder_t get_order<ROW>();
template cublasLtOrder_t get_order<COL>();
template cublasLtOrder_t get_order<COL32>();
template cublasLtOrder_t get_order<COL_TURING>();
template cublasLtOrder_t get_order<COL_AMPERE>();
#endif
template<int ORDER> int get_leading_dim(int dim1, int dim2)
{
switch(ORDER)
@@ -345,47 +312,12 @@ template int get_leading_dim<COL32>(int dim1, int dim2);
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
{
#ifdef NO_CUBLASLT
#else
cublasLtOrder_t orderA = get_order<SRC>();
cublasLtOrder_t orderOut = get_order<TARGET>();
int ldA = get_leading_dim<SRC>(dim1, dim2);
int ldOut = get_leading_dim<TARGET>(dim1, dim2);
cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL;
cublasLtMatrixTransformDesc_t A2Out_desc = NULL;
cublasOperation_t opTranspose = CUBLAS_OP_T;
float transformAlpha = 1.0f, transformBeta = 0.0f;
if(DTYPE == 8)
{
checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, dim1, dim2, ldA));
checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ldOut));
}
else if(DTYPE == 32)
{
checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_32I, dim1, dim2, ldA));
checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32I, dim1, dim2, ldOut));
}
else
{
printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE);
}
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA)));
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut)));
checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, CUDA_R_32F));
if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); }
checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0));
if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
#endif
cout << "" << endl;
cout << "=============================================" << endl;
cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
cout << "=============================================" << endl;
cout << "" << endl;
assert(false);
}
template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
@@ -399,7 +331,6 @@ template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandl
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{
#ifdef NO_CUBLASLT
cout << "" << endl;
cout << "=============================================" << endl;
cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
@@ -408,62 +339,6 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle
assert(false);
return 0;
#else
int has_error = 0;
cublasLtMatmulDesc_t matmulDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cublasOperation_t opT = CUBLAS_OP_T;
cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C;
cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4;
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda));
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb));
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
if(FORMATB == COL_TURING)
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing)));
else
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere)));
if(DTYPE_OUT == 32)
{
has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I));
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc));
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
int alpha = 1, beta = 0;
has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0));
}
else
{
has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F));
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc));
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
if(!SCALE_ROWS)
{
float alpha = 1.0f, beta = 0.0f;
has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
}
else
{
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec)));
has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
}
}
if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc));
if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc));
if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc));
if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc));
if(has_error == 1)
printf("error detected");
return has_error;
#endif
}
int fill_up_to_nearest_multiple(int value, int multiple)
@@ -484,7 +359,7 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
assert(threads <= tilesize);
kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
#define STATS_THREADS 64
@@ -505,7 +380,7 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
else if(nnz_threshold != 0.0)
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 1><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
@@ -529,7 +404,7 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col
else
kDoubleRowColQuant<64, 4, 16, 64*4, 0><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols)
@@ -573,69 +448,27 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
}
kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{
#ifdef NO_CUBLASLT
#else
cout << "" << endl;
cout << "=============================================" << endl;
cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
cout << "=============================================" << endl;
cout << "" << endl;
assert(false);
cusparseSpMatDescr_t descA;
cusparseDnMatDescr_t descB, descC;
float alpha = 1.0f;
float beta = 0.0f;
void *dBuffer = NULL;
size_t bufferSize = 0;
CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz,
A_rowidx, A_colidx, A_vals,
CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) );
// Create dense matrix C
CHECK_CUSPARSE( cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C,
CUDA_R_16F, CUSPARSE_ORDER_ROW) );
// Create dense matrix B
if(transposed_B)
{
int tmp = A_cols;
A_cols = B_cols;
B_cols = tmp;
}
CHECK_CUSPARSE( cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B,
CUDA_R_16F, CUSPARSE_ORDER_ROW) );
// allocate an external buffer if needed
CHECK_CUSPARSE( cusparseSpMM_bufferSize(
handle,
CUSPARSE_OPERATION_NON_TRANSPOSE,
transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
&alpha, descA, descB, &beta, descC, CUDA_R_32F,
CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize) );
CUDA_CHECK_RETURN( cudaMalloc(&dBuffer, bufferSize) );
// execute SpMM
CHECK_CUSPARSE( cusparseSpMM(handle,
CUSPARSE_OPERATION_NON_TRANSPOSE,
transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
&alpha, descA, descB, &beta, descC, CUDA_R_32F,
CUSPARSE_SPMM_ALG_DEFAULT, dBuffer));
// destroy matrix/vector descriptors
CHECK_CUSPARSE( cusparseDestroySpMat(descA) );
CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
CUDA_CHECK_RETURN( cudaFree(dBuffer) );
#endif
return;
}
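// (editor's sketch, not the commit's code) The body above still calls the
// cusparse* generic API with a hipsparseHandle_t, which cannot compile against
// the hipified header. Assuming hipSPARSE's generic SpMM API (ROCm >= 4.5) —
// enum spellings such as HIPSPARSE_SPMM_ALG_DEFAULT, HIPSPARSE_ORDER_ROW and
// HIP_R_16F should be verified against the installed hipsparse.h — the
// hipified equivalents would look like:
//
//     hipsparseSpMatDescr_t descA;
//     hipsparseDnMatDescr_t descB, descC;
//     CHECK_CUSPARSE( hipsparseCreateCoo(&descA, A_rows, A_cols, A_nnz,
//                                        A_rowidx, A_colidx, A_vals,
//                                        HIPSPARSE_INDEX_32I,
//                                        HIPSPARSE_INDEX_BASE_ZERO, HIP_R_16F) );
//     CHECK_CUSPARSE( hipsparseCreateDnMat(&descC, A_rows, B_cols, ldc, C,
//                                          HIP_R_16F, HIPSPARSE_ORDER_ROW) );
//     CHECK_CUSPARSE( hipsparseCreateDnMat(&descB, A_cols, B_cols, ldb, B,
//                                          HIP_R_16F, HIPSPARSE_ORDER_ROW) );
//     CHECK_CUSPARSE( hipsparseSpMM_bufferSize(handle,
//                         HIPSPARSE_OPERATION_NON_TRANSPOSE,
//                         transposed_B ? HIPSPARSE_OPERATION_TRANSPOSE
//                                      : HIPSPARSE_OPERATION_NON_TRANSPOSE,
//                         &alpha, descA, descB, &beta, descC, HIP_R_32F,
//                         HIPSPARSE_SPMM_ALG_DEFAULT, &bufferSize) );
//     CUDA_CHECK_RETURN( hipMalloc(&dBuffer, bufferSize) );
//     CHECK_CUSPARSE( hipsparseSpMM(handle,
//                         HIPSPARSE_OPERATION_NON_TRANSPOSE,
//                         transposed_B ? HIPSPARSE_OPERATION_TRANSPOSE
//                                      : HIPSPARSE_OPERATION_NON_TRANSPOSE,
//                         &alpha, descA, descB, &beta, descC, HIP_R_32F,
//                         HIPSPARSE_SPMM_ALG_DEFAULT, dBuffer) );
//     CHECK_CUSPARSE( hipsparseDestroySpMat(descA) );
//     CHECK_CUSPARSE( hipsparseDestroyDnMat(descB) );
//     CHECK_CUSPARSE( hipsparseDestroyDnMat(descC) );
//     CUDA_CHECK_RETURN( hipFree(dBuffer) );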
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{
kspmm_coo_very_sparse_naive<T, 8, BITS><<<nnz_rows, 256>>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
@ -658,7 +491,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
}
kExtractOutliers<FORMAT><<<num_blocks, threads>>>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
//==============================================================


@ -12,29 +12,31 @@
#include <unistd.h>
#include <assert.h>
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>
#include <cublasLt.h>
#include <cusparse.h>
#include <hip/hip_runtime_api.h>
#include <hip/hip_fp16.h>
#include <hipblas.h>
// #include <cublasLt.h>
#include <hipsparse.h>
#include <vector>
#include <functional>
typedef struct cublasLtContext* cublasLtHandle_t;
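// (editor's note) with cublasLt.h commented out above, this forward
// declaration keeps the cublasLtHandle_t spelling in the remaining signatures
// compiling as an opaque pointer until the Lt code paths are ported or removed.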
#define CUDA_CHECK_RETURN(value) { \
cudaError_t _m_cudaStat = value; \
if (_m_cudaStat != cudaSuccess) { \
hipError_t _m_cudaStat = value; \
if (_m_cudaStat != hipSuccess) { \
fprintf(stderr, "Error %s at line %d in file %s\n", \
cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
hipGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
exit(1); \
} }
#define THREADS_PER_BLOCKS (512)
#define CHECK_CUSPARSE(value) { \
cusparseStatus_t _m_cudaStat = value; \
if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \
fprintf(stderr, "Error %s at line %d in file %s\n", \
cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
hipsparseStatus_t _m_cudaStat = value; \
if (_m_cudaStat != HIPSPARSE_STATUS_SUCCESS) { \
fprintf(stderr, "Error <sparse error> at line %d in file %s\n", \
__LINE__, __FILE__); \
exit(1); \
} }
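// (editor's sketch, assuming the HIPSPARSE_STATUS_* enumerators in hipsparse.h:
// hipSPARSE has no drop-in cusparseGetErrorString here, so the macro above only
// prints the numeric status. A local helper could restore readable messages:)
//     static inline const char* hipsparseStatusString(hipsparseStatus_t s) {
//         switch (s) {
//             case HIPSPARSE_STATUS_SUCCESS:          return "success";
//             case HIPSPARSE_STATUS_NOT_INITIALIZED:  return "not initialized";
//             case HIPSPARSE_STATUS_ALLOC_FAILED:     return "allocation failed";
//             case HIPSPARSE_STATUS_INVALID_VALUE:    return "invalid value";
//             case HIPSPARSE_STATUS_EXECUTION_FAILED: return "execution failed";
//             case HIPSPARSE_STATUS_INTERNAL_ERROR:   return "internal error";
//             default:                                return "unknown status";
//         }
//     }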
@ -42,15 +44,15 @@
#define THREADS_PER_BLOCKS (512)
inline void checkCudaStatus(cudaError_t status) {
if (status != cudaSuccess) {
printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status));
inline void checkCudaStatus(hipError_t status) {
if (status != hipSuccess) {
printf("cuda API failed with status %d: %s\n", status, hipGetErrorString(status));
throw std::logic_error("cuda API failed");
}
}
inline int checkCublasStatus(cublasStatus_t status) {
if (status != CUBLAS_STATUS_SUCCESS) {
inline int checkCublasStatus(hipblasStatus_t status) {
if (status != HIPBLAS_STATUS_SUCCESS) {
printf("cuBLAS API failed with status %d\n", status);
//throw std::logic_error("cuBLAS API failed");
return 1;
@ -84,40 +86,40 @@ typedef enum Transform_t
class Context
{
public:
cublasHandle_t m_handle;
hipblasHandle_t m_handle;
Context()
{
cublasHandle_t handle;
cublasCreate_v2(&handle);
hipblasHandle_t handle;
hipblasCreate(&handle);
m_handle = handle;
}
};
class ContextLt
{
public:
cublasLtHandle_t m_handle;
// class ContextLt
// {
// public:
// cublasLtHandle_t m_handle;
ContextLt()
{
cublasLtHandle_t handle;
cublasLtCreate(&handle);
m_handle = handle;
}
// ContextLt()
// {
// cublasLtHandle_t handle;
// cublasLtCreate(&handle);
// m_handle = handle;
// }
};
// };
class ContextCusparse
{
public:
cusparseHandle_t m_handle;
hipsparseHandle_t m_handle;
ContextCusparse()
{
cusparseHandle_t handle;
cusparseCreate(&handle);
hipsparseHandle_t handle;
hipsparseCreate(&handle);
m_handle = handle;
}
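// (editor's note) neither Context nor ContextCusparse ever releases its
// handle; if teardown matters, hedged sketches would be
//     ~Context()         { hipblasDestroy(m_handle); }
//     ~ContextCusparse() { hipsparseDestroy(m_handle); }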
@ -170,7 +172,7 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols);
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B);
void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B);
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);


@ -2,73 +2,75 @@
#include "Algo-Direct-Common.h"
namespace BinSearch {
namespace Details {
template <typename T, Algos A>
struct AlgoScalarBase<T, A, typename std::enable_if<DirectAux::IsDirect2<A>::value>::type> : DirectAux::DirectInfo<2, T, A>
namespace BinSearch
{
private:
typedef DirectAux::DirectInfo<2, T, A> base_t;
static const size_t Offset=2;
public:
AlgoScalarBase(const T* x, const uint32 n)
: base_t(x, n)
namespace Details
{
}
FORCE_INLINE uint32 scalar(T z) const
{
const T* px = base_t::data.xi;
const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
uint32 bidx = base_t::fun_t::f(base_t::data.scaler, base_t::data.cst0, z);
uint32 iidx = buckets[bidx];
px += iidx;
if (z < *px)
--iidx;
if (z < *(px+1))
--iidx;
return iidx;
}
};
template <typename T, Algos A>
struct AlgoScalarBase<T, A, typename std::enable_if<DirectAux::IsDirect2<A>::value>::type> : DirectAux::DirectInfo<2, T, A>
{
private:
typedef DirectAux::DirectInfo<2, T, A> base_t;
static const size_t Offset = 2;
public:
AlgoScalarBase(const T *x, const uint32 n)
: base_t(x, n)
{
}
template <InstrSet I, typename T, Algos A>
struct AlgoVecBase<I, T, A, typename std::enable_if<DirectAux::IsDirect2<A>::value>::type> : AlgoScalarBase<T, A>
{
static const uint32 nElem = sizeof(typename InstrFloatTraits<I, T>::vec_t) / sizeof(T);
FORCE_INLINE uint32 scalar(T z) const
{
const T *px = base_t::data.xi;
const uint32 *buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
uint32 bidx = base_t::fun_t::f(base_t::data.scaler, base_t::data.cst0, z);
uint32 iidx = buckets[bidx];
px += iidx;
if (z < *px)
--iidx;
if (z < *(px + 1))
--iidx;
return iidx;
}
};
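// (editor's note) the two trailing compares in scalar() implement the Direct2
// correction: the precomputed bucket may point up to two positions past the
// answer, and each true comparison walks the index back by one. The SIMD
// resolve() paths below do the same with masks — a true lane is all-ones, i.e.
// -1, so i = i + vlem + vlep subtracts one per satisfied comparison.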
typedef FVec<I, T> fVec;
typedef IVec<SSE, T> i128;
template <InstrSet I, typename T, Algos A>
struct AlgoVecBase<I, T, A, typename std::enable_if<DirectAux::IsDirect2<A>::value>::type> : AlgoScalarBase<T, A>
{
static const uint32 nElem = sizeof(typename InstrFloatTraits<I, T>::vec_t) / sizeof(T);
struct Constants
{
fVec vscaler;
fVec vcst0;
IVec<I, T> one;
};
typedef FVec<I, T> fVec;
typedef IVec<SSE, T> i128;
private:
typedef AlgoScalarBase<T, A> base_t;
struct Constants
{
fVec vscaler;
fVec vcst0;
IVec<I, T> one;
};
FORCE_INLINE
//NO_INLINE
void resolve(const FVec<SSE, float>& vz, const IVec<SSE, float>& bidx, uint32 *pr) const
{
union U {
__m128i vec;
uint32 ui32[4];
} u;
private:
typedef AlgoScalarBase<T, A> base_t;
const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
const float *xi = base_t::data.xi;
FORCE_INLINE
// NO_INLINE
void resolve(const FVec<SSE, float> &vz, const IVec<SSE, float> &bidx, uint32 *pr) const
{
union U
{
__m128i vec;
uint32 ui32[4];
} u;
// read indices t
const double *p3 = reinterpret_cast<const double *>(&xi[(u.ui32[3] = buckets[bidx.get3()])]);
const double *p2 = reinterpret_cast<const double *>(&xi[(u.ui32[2] = buckets[bidx.get2()])]);
const double *p1 = reinterpret_cast<const double *>(&xi[(u.ui32[1] = buckets[bidx.get1()])]);
const double *p0 = reinterpret_cast<const double *>(&xi[(u.ui32[0] = buckets[bidx.get0()])]);
const uint32 *buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
const float *xi = base_t::data.xi;
// read indices t
const double *p3 = reinterpret_cast<const double *>(&xi[(u.ui32[3] = buckets[bidx.get3()])]);
const double *p2 = reinterpret_cast<const double *>(&xi[(u.ui32[2] = buckets[bidx.get2()])]);
const double *p1 = reinterpret_cast<const double *>(&xi[(u.ui32[1] = buckets[bidx.get1()])]);
const double *p0 = reinterpret_cast<const double *>(&xi[(u.ui32[0] = buckets[bidx.get0()])]);
#if 0
// read pairs ( X(t-1), X(t) )
@ -87,65 +89,66 @@ private:
__m128 vxm = _mm_shuffle_ps(u01, u23, (0) + (1 << 2) + (0 << 4) + (1 << 6));
__m128 vxp = _mm_shuffle_ps(u01, u23, (2) + (3 << 2) + (2 << 4) + (3 << 6));
#else
__m128 xp23 = _mm_castpd_ps(_mm_set_pd(*p3, *p2));
__m128 xp01 = _mm_castpd_ps(_mm_set_pd(*p1, *p0));
__m128 vxm = _mm_shuffle_ps(xp01, xp23, (0) + (2 << 2) + (0 << 4) + (2 << 6));
__m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6));
__m128 xp23 = _mm_castpd_ps(_mm_set_pd(*p3, *p2));
__m128 xp01 = _mm_castpd_ps(_mm_set_pd(*p1, *p0));
__m128 vxm = _mm_shuffle_ps(xp01, xp23, (0) + (2 << 2) + (0 << 4) + (2 << 6));
__m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6));
#endif
IVec<SSE, float> i(u.vec);
IVec<SSE, float> vlem = vz < vxm;
IVec<SSE, float> vlep = vz < vxp;
i = i + vlem + vlep;
i.store(pr);
}
IVec<SSE, float> i(u.vec);
IVec<SSE, float> vlem = (vz < vxm);
IVec<SSE, float> vlep = (vz < vxp);
i = i + vlem + vlep;
i.store(pr);
}
FORCE_INLINE
//NO_INLINE
void resolve(const FVec<SSE, double>& vz, const IVec<SSE, float>& bidx, uint32 *pr) const
{
const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
const double *xi = base_t::data.xi;
FORCE_INLINE
// NO_INLINE
void resolve(const FVec<SSE, double> &vz, const IVec<SSE, float> &bidx, uint32 *pr) const
{
const uint32 *buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
const double *xi = base_t::data.xi;
uint32 b1 = buckets[bidx.get1()];
uint32 b0 = buckets[bidx.get0()];
uint32 b1 = buckets[bidx.get1()];
uint32 b0 = buckets[bidx.get0()];
const double *p1 = &xi[b1];
const double *p0 = &xi[b0];
const double *p1 = &xi[b1];
const double *p0 = &xi[b0];
// read pairs ( X(t-1), X(t) )
__m128d vx1 = _mm_loadu_pd(p1);
__m128d vx0 = _mm_loadu_pd(p0);
// read pairs ( X(t-1), X(t) )
__m128d vx1 = _mm_loadu_pd(p1);
__m128d vx0 = _mm_loadu_pd(p0);
// build:
// { X(t(0)-1), X(t(1)-1) }
// { X(t(0)), X(t(1)) }
__m128d vxm = _mm_shuffle_pd(vx0, vx1, 0);
__m128d vxp = _mm_shuffle_pd(vx0, vx1, 3);
// build:
// { X(t(0)-1), X(t(1)-1) }
// { X(t(0)), X(t(1)) }
__m128d vxm = _mm_shuffle_pd(vx0, vx1, 0);
__m128d vxp = _mm_shuffle_pd(vx0, vx1, 3);
IVec<SSE, double> i(b1, b0);
IVec<SSE, double> vlem = (vz < vxm);
IVec<SSE, double> vlep = (vz < vxp);
i = i + vlem + vlep;
IVec<SSE, double> i(b1, b0);
IVec<SSE, double> vlem = (vz < vxm);
IVec<SSE, double> vlep = (vz < vxp);
i = i + vlem + vlep;
union {
__m128i vec;
uint32 ui32[4];
} u;
u.vec = i;
pr[0] = u.ui32[0];
pr[1] = u.ui32[2];
}
union
{
__m128i vec;
uint32 ui32[4];
} u;
u.vec = i;
pr[0] = u.ui32[0];
pr[1] = u.ui32[2];
}
#ifdef USE_AVX
FORCE_INLINE
//NO_INLINE
void resolve(const FVec<AVX, float>& vz, const IVec<AVX, float>& bidx, uint32 *pr) const
{
const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
const float *xi = base_t::data.xi;
FORCE_INLINE
// NO_INLINE
void resolve(const FVec<AVX, float> &vz, const IVec<AVX, float> &bidx, uint32 *pr) const
{
const uint32 *buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
const float *xi = base_t::data.xi;
#if 0 // use gather instructions
#if 0 // use gather instructions
IVec<AVX,float> idxm;
idxm.setidx(buckets, bidx);
@ -159,21 +162,22 @@ private:
#else // do not use gather instrucions
union U {
__m256i vec;
uint32 ui32[8];
} u;
union U
{
__m256i vec;
uint32 ui32[8];
} u;
// read indices t
// read indices t
const double *p7 = reinterpret_cast<const double *>(&xi[(u.ui32[7] = buckets[bidx.get7()])]);
const double *p6 = reinterpret_cast<const double *>(&xi[(u.ui32[6] = buckets[bidx.get6()])]);
const double *p5 = reinterpret_cast<const double *>(&xi[(u.ui32[5] = buckets[bidx.get5()])]);
const double *p4 = reinterpret_cast<const double *>(&xi[(u.ui32[4] = buckets[bidx.get4()])]);
const double *p3 = reinterpret_cast<const double *>(&xi[(u.ui32[3] = buckets[bidx.get3()])]);
const double *p2 = reinterpret_cast<const double *>(&xi[(u.ui32[2] = buckets[bidx.get2()])]);
const double *p1 = reinterpret_cast<const double *>(&xi[(u.ui32[1] = buckets[bidx.get1()])]);
const double *p0 = reinterpret_cast<const double *>(&xi[(u.ui32[0] = buckets[bidx.get0()])]);
const double *p7 = reinterpret_cast<const double *>(&xi[(u.ui32[7] = buckets[bidx.get7()])]);
const double *p6 = reinterpret_cast<const double *>(&xi[(u.ui32[6] = buckets[bidx.get6()])]);
const double *p5 = reinterpret_cast<const double *>(&xi[(u.ui32[5] = buckets[bidx.get5()])]);
const double *p4 = reinterpret_cast<const double *>(&xi[(u.ui32[4] = buckets[bidx.get4()])]);
const double *p3 = reinterpret_cast<const double *>(&xi[(u.ui32[3] = buckets[bidx.get3()])]);
const double *p2 = reinterpret_cast<const double *>(&xi[(u.ui32[2] = buckets[bidx.get2()])]);
const double *p1 = reinterpret_cast<const double *>(&xi[(u.ui32[1] = buckets[bidx.get1()])]);
const double *p0 = reinterpret_cast<const double *>(&xi[(u.ui32[0] = buckets[bidx.get0()])]);
#if 0 // perform 8 loads in double precision
@ -210,96 +214,93 @@ private:
IVec<AVX, float> ip(u.vec);
#else // use _mm256_set_pd
#else // use _mm256_set_pd
// read pairs ( X(t-1), X(t) )
__m256 x0145 = _mm256_castpd_ps(_mm256_set_pd(*p5, *p4, *p1, *p0)); // { x0(t-1), x0(t), x1(t-1), x1(t), x4(t-1), x4(t), x5(t-1), x5(t) }
__m256 x2367 = _mm256_castpd_ps(_mm256_set_pd(*p7, *p6, *p3, *p2)); // { x2(t-1), x2(t), x3(t-1), x3(t), x6(t-1), x6(t), x7(t-1), x7(t) }
// read pairs ( X(t-1), X(t) )
__m256 x0145 = _mm256_castpd_ps(_mm256_set_pd(*p5, *p4, *p1, *p0)); // { x0(t-1), x0(t), x1(t-1), x1(t), x4(t-1), x4(t), x5(t-1), x5(t) }
__m256 x2367 = _mm256_castpd_ps(_mm256_set_pd(*p7, *p6, *p3, *p2)); // { x2(t-1), x2(t), x3(t-1), x3(t), x6(t-1), x6(t), x7(t-1), x7(t) }
// { x0(t-1), x1(t-1), x2(t-1), x3(t-1), x4(t-1), x5(t-1), x6(t-1), x7(t-1) }
FVec<AVX, float> vxm = _mm256_shuffle_ps(x0145, x2367, 0 + (2 << 2) + (0 << 4) + (2 << 6) );
// { x0(t), x1(t), x2(t), x3(t), x4(t), x5(t), x6(t), x7(t) }
FVec<AVX, float> vxp = _mm256_shuffle_ps(x0145, x2367, 1 + (3 << 2) + (1 << 4) + (3 << 6) );
// { x0(t-1), x1(t-1), x2(t-1), x3(t-1), x4(t-1), x5(t-1), x6(t-1), x7(t-1) }
FVec<AVX, float> vxm = _mm256_shuffle_ps(x0145, x2367, 0 + (2 << 2) + (0 << 4) + (2 << 6));
// { x0(t), x1(t), x2(t), x3(t), x4(t), x5(t), x6(t), x7(t) }
FVec<AVX, float> vxp = _mm256_shuffle_ps(x0145, x2367, 1 + (3 << 2) + (1 << 4) + (3 << 6));
IVec<AVX, float> ip(u.vec);
IVec<AVX, float> ip(u.vec);
#endif
#endif
IVec<AVX, float> vlem = vz < vxm;
IVec<AVX, float> vlep = vz < vxp;
ip = ip + vlem + vlep;
IVec<AVX, float> vlem = vz < vxm;
IVec<AVX, float> vlep = vz < vxp;
ip = ip + vlem + vlep;
ip.store(pr);
}
ip.store(pr);
}
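// (editor's note) the disabled "#if 0" gather path above is kept for
// reference; scalar loads plus _mm256_shuffle_ps are used instead, presumably
// because vgatherdps was slower than discrete loads on the AVX2 hardware this
// was tuned for. This is the editor's reading, not a claim from the original.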
FORCE_INLINE
// NO_INLINE
void resolve(const FVec<AVX, double> &vz, const IVec<SSE, float> &bidx, uint32 *pr) const
{
union
{
__m256i vec;
uint64 ui64[4];
} u;
const uint32 *buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
const double *xi = base_t::data.xi;
FORCE_INLINE
//NO_INLINE
void resolve(const FVec<AVX, double>& vz, const IVec<SSE, float>& bidx, uint32 *pr) const
{
union {
__m256i vec;
uint64 ui64[4];
} u;
// read indices t
const double *p3 = &xi[(u.ui64[3] = buckets[bidx.get3()])];
const double *p2 = &xi[(u.ui64[2] = buckets[bidx.get2()])];
const double *p1 = &xi[(u.ui64[1] = buckets[bidx.get1()])];
const double *p0 = &xi[(u.ui64[0] = buckets[bidx.get0()])];
const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
const double *xi = base_t::data.xi;
// read pairs ( X(t-1), X(t) )
__m128d xp3 = _mm_loadu_pd(p3);
__m128d xp2 = _mm_loadu_pd(p2);
__m128d xp1 = _mm_loadu_pd(p1);
__m128d xp0 = _mm_loadu_pd(p0);
// read indices t
const double *p3 = &xi[(u.ui64[3] = buckets[bidx.get3()])];
const double *p2 = &xi[(u.ui64[2] = buckets[bidx.get2()])];
const double *p1 = &xi[(u.ui64[1] = buckets[bidx.get1()])];
const double *p0 = &xi[(u.ui64[0] = buckets[bidx.get0()])];
// build:
// { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) }
// { X(t(0)), X(t(1)), X(t(2)), X(t(3)) }
__m256d x02 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp0), xp2, 1);
__m256d x13 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp1), xp3, 1);
FVec<AVX, double> vxm = _mm256_unpacklo_pd(x02, x13);
FVec<AVX, double> vxp = _mm256_unpackhi_pd(x02, x13);
// read pairs ( X(t-1), X(t) )
__m128d xp3 = _mm_loadu_pd(p3);
__m128d xp2 = _mm_loadu_pd(p2);
__m128d xp1 = _mm_loadu_pd(p1);
__m128d xp0 = _mm_loadu_pd(p0);
// __m128d h01m = _mm_shuffle_pd(xp0, xp1, 0);
// __m128d h23m = _mm_shuffle_pd(xp2, xp3, 0);
// __m128d h01p = _mm_shuffle_pd(xp0, xp1, 3);
// __m128d h23p = _mm_shuffle_pd(xp2, xp3, 3);
// FVec<AVX, double> vxm = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01m), h23m, 1);
// FVec<AVX, double> vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1);
// build:
// { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) }
// { X(t(0)), X(t(1)), X(t(2)), X(t(3)) }
__m256d x02 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp0), xp2, 1);
__m256d x13 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp1), xp3, 1);
FVec<AVX, double> vxm = _mm256_unpacklo_pd(x02,x13);
FVec<AVX, double> vxp = _mm256_unpackhi_pd(x02,x13);
// __m128d h01m = _mm_shuffle_pd(xp0, xp1, 0);
// __m128d h23m = _mm_shuffle_pd(xp2, xp3, 0);
// __m128d h01p = _mm_shuffle_pd(xp0, xp1, 3);
// __m128d h23p = _mm_shuffle_pd(xp2, xp3, 3);
// FVec<AVX, double> vxm = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01m), h23m, 1);
// FVec<AVX, double> vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1);
IVec<AVX, double> i(u.vec);
IVec<AVX, double> vlem = vz < vxm;
IVec<AVX, double> vlep = vz < vxp;
i = i + vlem + vlep;
i.extractLo32s().store(pr);
}
IVec<AVX, double> i(u.vec);
IVec<AVX, double> vlem = vz < vxm;
IVec<AVX, double> vlep = vz < vxp;
i = i + vlem + vlep;
i.extractLo32s().store(pr);
}
#endif
public:
public:
AlgoVecBase(const T *x, const uint32 n) : base_t(x, n) {}
AlgoVecBase(const T* x, const uint32 n) : base_t(x, n) {}
void initConstants(Constants &cst) const
{
cst.vscaler.setN(base_t::data.scaler);
cst.vcst0.setN(base_t::data.cst0);
cst.one.setN(uint32(1));
}
void initConstants(Constants& cst) const
{
cst.vscaler.setN(base_t::data.scaler);
cst.vcst0.setN(base_t::data.cst0);
cst.one.setN(uint32(1));
}
void vectorial(uint32 *pr, const T *pz, const Constants& cst) const
{
fVec vz(pz);
resolve(vz, base_t::fun_t::f(cst.vscaler, cst.vcst0, vz), pr);
}
};
} // namespace Details
void vectorial(uint32 *pr, const T *pz, const Constants &cst) const
{
fVec vz(pz);
resolve(vz, base_t::fun_t::f(cst.vscaler, cst.vcst0, vz), pr);
}
};
} // namespace Details
} // namespace BinSearch