forked from mrq/bitsandbytes-rocm
should be hippified, and all cuda checkes cleaned up, makefile not updated yet
This commit is contained in:
parent
c059bd2848
commit
2dcf38289d
17
Makefile
17
Makefile
|
@ -20,9 +20,10 @@ CSRC := $(ROOT_DIR)/csrc
|
|||
BUILD_DIR:= $(ROOT_DIR)/build
|
||||
|
||||
FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu
|
||||
FILES_HIP := $(CSRC)/ops.cu $(CSRC)/kernels.cu
|
||||
FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c
|
||||
|
||||
INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include
|
||||
INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include
|
||||
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
|
||||
|
||||
# NVIDIA NVCC compilation flags
|
||||
|
@ -45,6 +46,7 @@ CC_CUDA10x += -gencode arch=compute_75,code=sm_75
|
|||
CC_CUDA110 := -gencode arch=compute_75,code=sm_75
|
||||
CC_CUDA110 += -gencode arch=compute_80,code=sm_80
|
||||
|
||||
# CC_CUDA11x := -gencode arch=compute_52,code=sm_52
|
||||
CC_CUDA11x := -gencode arch=compute_75,code=sm_75
|
||||
CC_CUDA11x += -gencode arch=compute_80,code=sm_80
|
||||
CC_CUDA11x += -gencode arch=compute_86,code=sm_86
|
||||
|
@ -52,22 +54,23 @@ CC_CUDA11x += -gencode arch=compute_86,code=sm_86
|
|||
CC_cublasLt110 := -gencode arch=compute_75,code=sm_75
|
||||
CC_cublasLt110 += -gencode arch=compute_80,code=sm_80
|
||||
|
||||
# CC_cublasLt111 := -gencode arch=compute_52,code=sm_52
|
||||
CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
|
||||
CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
|
||||
CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
|
||||
|
||||
|
||||
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
|
||||
all: $(BUILD_DIR) env
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
|
||||
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
|
||||
|
||||
cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
|
||||
cuda92: $(BUILD_DIR) env
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
|
||||
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)
|
||||
|
||||
cuda10x_nomatmul: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
|
||||
cuda10x_nomatmul: $(BUILD_DIR) env
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
|
||||
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)
|
||||
|
@ -112,9 +115,9 @@ $(BUILD_DIR):
|
|||
mkdir -p build
|
||||
mkdir -p dependencies
|
||||
|
||||
$(ROOT_DIR)/dependencies/cub:
|
||||
git clone https://github.com/NVlabs/cub $(ROOT_DIR)/dependencies/cub
|
||||
cd dependencies/cub; git checkout 1.11.0
|
||||
# $(ROOT_DIR)/dependencies/cub:
|
||||
# git clone https://github.com/NVlabs/cub $(ROOT_DIR)/dependencies/cub
|
||||
# cd dependencies/cub; git checkout 1.11.0
|
||||
|
||||
clean:
|
||||
rm build/*
|
||||
|
|
|
@ -12,43 +12,10 @@ class CUDASetup(object):
|
|||
raise RuntimeError("Call get_instance() instead")
|
||||
|
||||
def generate_instructions(self):
|
||||
if self.cuda is None:
|
||||
self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.')
|
||||
self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.')
|
||||
self.add_log_entry('CUDA SETUP: Solution 2): If you do not have sudo rights, you can do the following:')
|
||||
self.add_log_entry('CUDA SETUP: Solution 2a): Find the cuda library via: find / -name libcuda.so 2>/dev/null')
|
||||
self.add_log_entry('CUDA SETUP: Solution 2b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_2a')
|
||||
self.add_log_entry('CUDA SETUP: Solution 2c): For a permanent solution add the export from 2b into your .bashrc file, located at ~/.bashrc')
|
||||
return
|
||||
|
||||
if self.cudart_path is None:
|
||||
self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA runtime library was not detected.')
|
||||
self.add_log_entry('CUDA SETUP: Solution 1: To solve the issue the libcudart.so location needs to be added to the LD_LIBRARY_PATH variable')
|
||||
self.add_log_entry('CUDA SETUP: Solution 1a): Find the cuda runtime library via: find / -name libcudart.so 2>/dev/null')
|
||||
self.add_log_entry('CUDA SETUP: Solution 1b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_1a')
|
||||
self.add_log_entry('CUDA SETUP: Solution 1c): For a permanent solution add the export from 1b into your .bashrc file, located at ~/.bashrc')
|
||||
self.add_log_entry('CUDA SETUP: Solution 2: If no library was found in step 1a) you need to install CUDA.')
|
||||
self.add_log_entry('CUDA SETUP: Solution 2a): Download CUDA install script: wget https://github.com/TimDettmers/bitsandbytes/blob/main/cuda_install.sh')
|
||||
self.add_log_entry('CUDA SETUP: Solution 2b): Install desired CUDA version to desired location. The syntax is bash cuda_install.sh CUDA_VERSION PATH_TO_INSTALL_INTO.')
|
||||
self.add_log_entry('CUDA SETUP: Solution 2b): For example, "bash cuda_install.sh 113 ~/local/" will download CUDA 11.3 and install into the folder ~/local')
|
||||
return
|
||||
|
||||
make_cmd = f'CUDA_VERSION={self.cuda_version_string}'
|
||||
if len(self.cuda_version_string) < 3:
|
||||
make_cmd += ' make cuda92'
|
||||
elif self.cuda_version_string == '110':
|
||||
make_cmd += ' make cuda110'
|
||||
elif self.cuda_version_string[:2] == '11' and int(self.cuda_version_string[2]) > 0:
|
||||
make_cmd += ' make cuda11x'
|
||||
|
||||
has_cublaslt = self.cc in ["7.5", "8.0", "8.6"]
|
||||
if not has_cublaslt:
|
||||
make_cmd += '_nomatmul'
|
||||
|
||||
self.add_log_entry('CUDA SETUP: Something unexpected happened. Please compile from source:')
|
||||
self.add_log_entry('git clone git@github.com:TimDettmers/bitsandbytes.git')
|
||||
self.add_log_entry('cd bitsandbytes')
|
||||
self.add_log_entry(make_cmd)
|
||||
self.add_log_entry("<make_cmd here, commented out>")
|
||||
self.add_log_entry('python setup.py install')
|
||||
|
||||
def initialize(self):
|
||||
|
@ -60,37 +27,13 @@ class CUDASetup(object):
|
|||
self.initialized = True
|
||||
self.cuda_setup_log = []
|
||||
|
||||
from .cuda_setup.main import evaluate_cuda_setup
|
||||
binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup()
|
||||
self.cudart_path = cudart_path
|
||||
self.cuda = cuda
|
||||
self.cc = cc
|
||||
self.cuda_version_string = cuda_version_string
|
||||
|
||||
binary_name = "libbitsandbytes_hip.so"
|
||||
package_dir = Path(__file__).parent
|
||||
binary_path = package_dir / binary_name
|
||||
|
||||
try:
|
||||
if not binary_path.exists():
|
||||
self.add_log_entry(f"CUDA SETUP: Required library version not found: {binary_name}. Maybe you need to compile it from source?")
|
||||
legacy_binary_name = "libbitsandbytes.so"
|
||||
self.add_log_entry(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
|
||||
binary_path = package_dir / legacy_binary_name
|
||||
if not binary_path.exists():
|
||||
self.add_log_entry('')
|
||||
self.add_log_entry('='*48 + 'ERROR' + '='*37)
|
||||
self.add_log_entry('CUDA SETUP: CUDA detection failed! Possible reasons:')
|
||||
self.add_log_entry('1. CUDA driver not installed')
|
||||
self.add_log_entry('2. CUDA not installed')
|
||||
self.add_log_entry('3. You have multiple conflicting CUDA libraries')
|
||||
self.add_log_entry('4. Required library not pre-compiled for this bitsandbytes release!')
|
||||
self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
|
||||
self.add_log_entry('='*80)
|
||||
self.add_log_entry('')
|
||||
self.generate_instructions()
|
||||
self.print_log_stack()
|
||||
raise Exception('CUDA SETUP: Setup Failed!')
|
||||
self.lib = ct.cdll.LoadLibrary(binary_path)
|
||||
raise Exception('CUDA SETUP: Setup Failed!')
|
||||
else:
|
||||
self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...")
|
||||
self.lib = ct.cdll.LoadLibrary(binary_path)
|
||||
|
|
|
@ -1,2 +0,0 @@
|
|||
from .paths import CUDA_RUNTIME_LIB, extract_candidate_paths, determine_cuda_runtime_lib_path
|
||||
from .main import evaluate_cuda_setup
|
|
@ -1,51 +0,0 @@
|
|||
import os
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def to_be_ignored(env_var: str, value: str) -> bool:
|
||||
ignorable = {
|
||||
"PWD", # PWD: this is how the shell keeps track of the current working dir
|
||||
"OLDPWD",
|
||||
"SSH_AUTH_SOCK", # SSH stuff, therefore unrelated
|
||||
"SSH_TTY",
|
||||
"HOME", # Linux shell default
|
||||
"TMUX", # Terminal Multiplexer
|
||||
"XDG_DATA_DIRS", # XDG: Desktop environment stuff
|
||||
"XDG_RUNTIME_DIR",
|
||||
"MAIL", # something related to emails
|
||||
"SHELL", # binary for currently invoked shell
|
||||
"DBUS_SESSION_BUS_ADDRESS", # hardware related
|
||||
"PATH", # this is for finding binaries, not libraries
|
||||
"LESSOPEN", # related to the `less` command
|
||||
"LESSCLOSE",
|
||||
"_", # current Python interpreter
|
||||
}
|
||||
return env_var in ignorable
|
||||
|
||||
|
||||
def might_contain_a_path(candidate: str) -> bool:
|
||||
return "/" in candidate
|
||||
|
||||
|
||||
def is_active_conda_env(env_var: str) -> bool:
|
||||
return "CONDA_PREFIX" == env_var
|
||||
|
||||
|
||||
def is_other_conda_env_var(env_var: str) -> bool:
|
||||
return "CONDA" in env_var
|
||||
|
||||
|
||||
def is_relevant_candidate_env_var(env_var: str, value: str) -> bool:
|
||||
return is_active_conda_env(env_var) or (
|
||||
might_contain_a_path(value) and not
|
||||
is_other_conda_env_var(env_var) and not
|
||||
to_be_ignored(env_var, value)
|
||||
)
|
||||
|
||||
|
||||
def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]:
|
||||
return {
|
||||
env_var: value
|
||||
for env_var, value in os.environ.items()
|
||||
if is_relevant_candidate_env_var(env_var, value)
|
||||
}
|
|
@ -1,163 +0,0 @@
|
|||
"""
|
||||
extract factors the build is dependent on:
|
||||
[X] compute capability
|
||||
[ ] TODO: Q - What if we have multiple GPUs of different makes?
|
||||
- CUDA version
|
||||
- Software:
|
||||
- CPU-only: only CPU quantization functions (no optimizer, no matrix multipl)
|
||||
- CuBLAS-LT: full-build 8-bit optimizer
|
||||
- no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
|
||||
|
||||
evaluation:
|
||||
- if paths faulty, return meaningful error
|
||||
- else:
|
||||
- determine CUDA version
|
||||
- determine capabilities
|
||||
- based on that set the default path
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
import torch
|
||||
|
||||
from .paths import determine_cuda_runtime_lib_path
|
||||
from bitsandbytes.cextension import CUDASetup
|
||||
|
||||
|
||||
def check_cuda_result(cuda, result_val):
|
||||
# 3. Check for CUDA errors
|
||||
if result_val != 0:
|
||||
error_str = ctypes.c_char_p()
|
||||
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
|
||||
CUDASetup.get_instance().add_log_entry(f"CUDA exception! Error code: {error_str.value.decode()}")
|
||||
|
||||
|
||||
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
|
||||
def get_cuda_version(cuda, cudart_path):
|
||||
if cuda is None: return None
|
||||
|
||||
try:
|
||||
cudart = ctypes.CDLL(cudart_path)
|
||||
except OSError:
|
||||
CUDASetup.get_instance().add_log_entry(f'ERROR: libcudart.so could not be read from path: {cudart_path}!')
|
||||
return None
|
||||
|
||||
version = ctypes.c_int()
|
||||
check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ctypes.byref(version)))
|
||||
version = int(version.value)
|
||||
major = version//1000
|
||||
minor = (version-(major*1000))//10
|
||||
|
||||
if major < 11:
|
||||
CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currenlty not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!')
|
||||
|
||||
return f'{major}{minor}'
|
||||
|
||||
|
||||
def get_cuda_lib_handle():
|
||||
# 1. find libcuda.so library (GPU driver) (/usr/lib)
|
||||
try:
|
||||
cuda = ctypes.CDLL("libcuda.so")
|
||||
except OSError:
|
||||
CUDASetup.get_instance().add_log_entry('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
|
||||
return None
|
||||
check_cuda_result(cuda, cuda.cuInit(0))
|
||||
|
||||
return cuda
|
||||
|
||||
|
||||
def get_compute_capabilities(cuda):
|
||||
"""
|
||||
1. find libcuda.so library (GPU driver) (/usr/lib)
|
||||
init_device -> init variables -> call function by reference
|
||||
2. call extern C function to determine CC
|
||||
(https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
|
||||
3. Check for CUDA errors
|
||||
https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
|
||||
# bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
|
||||
"""
|
||||
|
||||
nGpus = ctypes.c_int()
|
||||
cc_major = ctypes.c_int()
|
||||
cc_minor = ctypes.c_int()
|
||||
|
||||
device = ctypes.c_int()
|
||||
|
||||
check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
|
||||
ccs = []
|
||||
for i in range(nGpus.value):
|
||||
check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
|
||||
ref_major = ctypes.byref(cc_major)
|
||||
ref_minor = ctypes.byref(cc_minor)
|
||||
# 2. call extern C function to determine CC
|
||||
check_cuda_result(cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device))
|
||||
ccs.append(f"{cc_major.value}.{cc_minor.value}")
|
||||
|
||||
return ccs
|
||||
|
||||
|
||||
# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error
|
||||
def get_compute_capability(cuda):
|
||||
"""
|
||||
Extracts the highest compute capbility from all available GPUs, as compute
|
||||
capabilities are downwards compatible. If no GPUs are detected, it returns
|
||||
None.
|
||||
"""
|
||||
if cuda is None: return None
|
||||
|
||||
# TODO: handle different compute capabilities; for now, take the max
|
||||
ccs = get_compute_capabilities(cuda)
|
||||
if ccs: return ccs[-1]
|
||||
|
||||
|
||||
def evaluate_cuda_setup():
|
||||
# we remove this for now and see how things go
|
||||
#print('')
|
||||
#print('='*35 + 'BUG REPORT' + '='*35)
|
||||
#print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
|
||||
#print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
|
||||
#print('='*80)
|
||||
if not torch.cuda.is_available(): return 'libsbitsandbytes_cpu.so', None, None, None, None
|
||||
|
||||
cuda_setup = CUDASetup.get_instance()
|
||||
cudart_path = determine_cuda_runtime_lib_path()
|
||||
cuda = get_cuda_lib_handle()
|
||||
cc = get_compute_capability(cuda)
|
||||
cuda_version_string = get_cuda_version(cuda, cudart_path)
|
||||
|
||||
failure = False
|
||||
if cudart_path is None:
|
||||
failure = True
|
||||
cuda_setup.add_log_entry("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!", is_warning=True)
|
||||
else:
|
||||
cuda_setup.add_log_entry((f"CUDA SETUP: CUDA runtime path found: {cudart_path}"))
|
||||
|
||||
if cc == '' or cc is None:
|
||||
failure = True
|
||||
cuda_setup.add_log_entry("WARNING: No GPU detected! Check your CUDA paths. Proceeding to load CPU-only library...", is_warning=True)
|
||||
else:
|
||||
cuda_setup.add_log_entry(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}")
|
||||
|
||||
if cuda is None:
|
||||
failure = True
|
||||
else:
|
||||
cuda_setup.add_log_entry(f'CUDA SETUP: Detected CUDA version {cuda_version_string}')
|
||||
|
||||
# 7.5 is the minimum CC vor cublaslt
|
||||
has_cublaslt = cc in ["7.5", "8.0", "8.6"]
|
||||
|
||||
# TODO:
|
||||
# (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
|
||||
# (2) Multiple CUDA versions installed
|
||||
|
||||
# we use ls -l instead of nvcc to determine the cuda version
|
||||
# since most installations will have the libcudart.so installed, but not the compiler
|
||||
|
||||
if failure:
|
||||
binary_name = "libbitsandbytes_cpu.so"
|
||||
elif has_cublaslt:
|
||||
binary_name = f"libbitsandbytes_cuda{cuda_version_string}.so"
|
||||
else:
|
||||
"if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so"
|
||||
binary_name = f"libbitsandbytes_cuda{cuda_version_string}_nocublaslt.so"
|
||||
|
||||
return binary_name, cudart_path, cuda, cc, cuda_version_string
|
|
@ -1,118 +0,0 @@
|
|||
import errno
|
||||
from pathlib import Path
|
||||
from typing import Set, Union
|
||||
from bitsandbytes.cextension import CUDASetup
|
||||
|
||||
from .env_vars import get_potentially_lib_path_containing_env_vars
|
||||
|
||||
CUDA_RUNTIME_LIB: str = "libcudart.so"
|
||||
|
||||
|
||||
def extract_candidate_paths(paths_list_candidate: str) -> Set[Path]:
|
||||
return {Path(ld_path) for ld_path in paths_list_candidate.split(":") if ld_path}
|
||||
|
||||
|
||||
def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
|
||||
existent_directories: Set[Path] = set()
|
||||
for path in candidate_paths:
|
||||
try:
|
||||
if path.exists():
|
||||
existent_directories.add(path)
|
||||
except OSError as exc:
|
||||
if exc.errno != errno.ENAMETOOLONG:
|
||||
raise exc
|
||||
|
||||
non_existent_directories: Set[Path] = candidate_paths - existent_directories
|
||||
if non_existent_directories:
|
||||
CUDASetup.get_instance().add_log_entry("WARNING: The following directories listed in your path were found to "
|
||||
f"be non-existent: {non_existent_directories}", is_warning=True)
|
||||
|
||||
return existent_directories
|
||||
|
||||
|
||||
def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]:
|
||||
return {
|
||||
path / CUDA_RUNTIME_LIB
|
||||
for path in candidate_paths
|
||||
if (path / CUDA_RUNTIME_LIB).is_file()
|
||||
}
|
||||
|
||||
|
||||
def resolve_paths_list(paths_list_candidate: str) -> Set[Path]:
|
||||
"""
|
||||
Searches a given environmental var for the CUDA runtime library,
|
||||
i.e. `libcudart.so`.
|
||||
"""
|
||||
return remove_non_existent_dirs(extract_candidate_paths(paths_list_candidate))
|
||||
|
||||
|
||||
def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]:
|
||||
return get_cuda_runtime_lib_paths(
|
||||
resolve_paths_list(paths_list_candidate)
|
||||
)
|
||||
|
||||
|
||||
def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
|
||||
if len(results_paths) > 1:
|
||||
warning_msg = (
|
||||
f"Found duplicate {CUDA_RUNTIME_LIB} files: {results_paths}.. "
|
||||
"We'll flip a coin and try one of these, in order to fail forward.\n"
|
||||
"Either way, this might cause trouble in the future:\n"
|
||||
"If you get `CUDA error: invalid device function` errors, the above "
|
||||
"might be the cause and the solution is to make sure only one "
|
||||
f"{CUDA_RUNTIME_LIB} in the paths that we search based on your env.")
|
||||
CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True)
|
||||
|
||||
|
||||
def determine_cuda_runtime_lib_path() -> Union[Path, None]:
|
||||
"""
|
||||
Searches for a cuda installations, in the following order of priority:
|
||||
1. active conda env
|
||||
2. LD_LIBRARY_PATH
|
||||
3. any other env vars, while ignoring those that
|
||||
- are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`)
|
||||
- don't contain the path separator `/`
|
||||
|
||||
If multiple libraries are found in part 3, we optimistically try one,
|
||||
while giving a warning message.
|
||||
"""
|
||||
candidate_env_vars = get_potentially_lib_path_containing_env_vars()
|
||||
|
||||
if "CONDA_PREFIX" in candidate_env_vars:
|
||||
conda_libs_path = Path(candidate_env_vars["CONDA_PREFIX"]) / "lib"
|
||||
|
||||
conda_cuda_libs = find_cuda_lib_in(str(conda_libs_path))
|
||||
warn_in_case_of_duplicates(conda_cuda_libs)
|
||||
|
||||
if conda_cuda_libs:
|
||||
return next(iter(conda_cuda_libs))
|
||||
|
||||
CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain '
|
||||
f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True)
|
||||
|
||||
if "LD_LIBRARY_PATH" in candidate_env_vars:
|
||||
lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"])
|
||||
|
||||
if lib_ld_cuda_libs:
|
||||
return next(iter(lib_ld_cuda_libs))
|
||||
warn_in_case_of_duplicates(lib_ld_cuda_libs)
|
||||
|
||||
CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain '
|
||||
f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True)
|
||||
|
||||
remaining_candidate_env_vars = {
|
||||
env_var: value for env_var, value in candidate_env_vars.items()
|
||||
if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"}
|
||||
}
|
||||
|
||||
cuda_runtime_libs = set()
|
||||
for env_var, value in remaining_candidate_env_vars.items():
|
||||
cuda_runtime_libs.update(find_cuda_lib_in(value))
|
||||
|
||||
if len(cuda_runtime_libs) == 0:
|
||||
CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching /usr/local/cuda/lib64...')
|
||||
cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64'))
|
||||
|
||||
warn_in_case_of_duplicates(cuda_runtime_libs)
|
||||
|
||||
return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None
|
|
@ -7,4 +7,4 @@
|
|||
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n);
|
||||
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n);
|
||||
|
||||
#endif
|
||||
#endif
|
128
csrc/kernels.cu
128
csrc/kernels.cu
|
@ -3,6 +3,8 @@
|
|||
// This source code is licensed under the MIT license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <kernels.cuh>
|
||||
#include <cub/block/block_radix_sort.cuh>
|
||||
#include <cub/warp/warp_reduce.cuh>
|
||||
|
@ -10,7 +12,7 @@
|
|||
#include <cub/block/block_discontinuity.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
#include <cub/block/block_reduce.cuh>
|
||||
#include <cub/cub.cuh>
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#include <math_constants.h>
|
||||
|
||||
#define HLF_MAX 65504
|
||||
|
@ -232,9 +234,9 @@ __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index
|
|||
template<typename T, int BLOCK_SIZE, int NUM_MAX>
|
||||
__global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n)
|
||||
{
|
||||
typedef cub::WarpReduce<T> WarpReduce;
|
||||
typedef hipcub::WarpReduce<T> WarpReduce;
|
||||
__shared__ typename WarpReduce::TempStorage temp_storage;
|
||||
typedef cub::BlockLoad<T, BLOCK_SIZE/8 , 8, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
typedef hipcub::BlockLoad<T, BLOCK_SIZE/8 , 8, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
__shared__ typename LoadT::TempStorage loadt;
|
||||
|
||||
const int warp_idx = threadIdx.x/32;
|
||||
|
@ -324,8 +326,8 @@ __global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const f
|
|||
|
||||
T vals[NUM_ESTIMATE];
|
||||
|
||||
typedef cub::BlockRadixSort<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::NullType, 4, true, cub::BLOCK_SCAN_RAKING> BlockRadixSort;
|
||||
typedef cub::BlockLoad<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
|
||||
typedef hipcub::BlockRadixSort<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::NullType, 4, true, cub::BLOCK_SCAN_RAKING> BlockRadixSort;
|
||||
typedef hipcub::BlockLoad<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
|
||||
|
||||
__shared__ union {
|
||||
typename LoadFloat::TempStorage loadf;
|
||||
|
@ -391,8 +393,8 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
|
|||
unsigned char qvals[NUM];
|
||||
//const int lane_id = threadIdx.x % 2;
|
||||
|
||||
typedef cub::BlockLoad<float, TH, NUM, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
|
||||
typedef cub::BlockStore<unsigned char, TH, NUM, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
|
||||
typedef hipcub::BlockLoad<float, TH, NUM, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
|
||||
typedef hipcub::BlockStore<unsigned char, TH, NUM, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
|
||||
|
||||
__shared__ typename LoadFloat::TempStorage loadf;
|
||||
__shared__ typename StoreChar::TempStorage storec;
|
||||
|
@ -442,10 +444,10 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
|
|||
float local_abs_max = 0.0f;
|
||||
int local_rand_idx = 0;
|
||||
|
||||
typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
|
||||
typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
|
||||
typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
|
||||
typedef hipcub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
typedef hipcub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
|
||||
typedef hipcub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
|
||||
typedef hipcub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
|
||||
|
||||
__shared__ typename LoadT::TempStorage loadt;
|
||||
__shared__ typename LoadFloat::TempStorage loadf;
|
||||
|
@ -521,8 +523,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
|
|||
unsigned char qvals[NUM_PER_TH];
|
||||
float local_abs_max = -FLT_MAX;
|
||||
|
||||
typedef cub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
|
||||
typedef cub::BlockStore<T, THREADS, NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
|
||||
typedef hipcub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
|
||||
typedef hipcub::BlockStore<T, THREADS, NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
|
||||
|
||||
__shared__ typename LoadChar::TempStorage loadchar;
|
||||
__shared__ typename StoreT::TempStorage storet;
|
||||
|
@ -592,9 +594,9 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
|
|||
const float correction1 = 1.0f/(1.0f - powf(beta1, step));
|
||||
const float correction2 = 1.0f/(1.0f - powf(beta2, step));
|
||||
|
||||
typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
|
||||
typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
|
||||
typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
|
||||
typedef hipcub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
|
||||
typedef hipcub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
|
||||
typedef hipcub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
|
||||
|
||||
__shared__ union {
|
||||
typename Load::TempStorage load;
|
||||
|
@ -681,11 +683,11 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
|
|||
}
|
||||
else{ update_scale = 1.0f; }
|
||||
|
||||
typedef cub::BlockLoad<T, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
|
||||
typedef cub::BlockStore<T, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> Store;
|
||||
typedef hipcub::BlockLoad<T, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
|
||||
typedef hipcub::BlockStore<T, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> Store;
|
||||
|
||||
typedef cub::BlockLoad<float, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
|
||||
typedef cub::BlockStore<float, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreFloat;
|
||||
typedef hipcub::BlockLoad<float, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
|
||||
typedef hipcub::BlockStore<float, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreFloat;
|
||||
|
||||
__shared__ union {
|
||||
typename Load::TempStorage load;
|
||||
|
@ -755,9 +757,9 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
|
|||
|
||||
float s1_vals[NUM_VALS];
|
||||
|
||||
typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
|
||||
typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
|
||||
typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
|
||||
typedef hipcub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
|
||||
typedef hipcub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
|
||||
typedef hipcub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
|
||||
|
||||
__shared__ union {
|
||||
typename Load::TempStorage load;
|
||||
|
@ -843,11 +845,11 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
|
|||
|
||||
float s1_vals[NUM_PER_THREAD];
|
||||
|
||||
typedef cub::BlockLoad<T, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
|
||||
typedef cub::BlockStore<T, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> Store;
|
||||
typedef hipcub::BlockLoad<T, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
|
||||
typedef hipcub::BlockStore<T, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> Store;
|
||||
|
||||
typedef cub::BlockLoad<float, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
|
||||
typedef cub::BlockStore<float, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreFloat;
|
||||
typedef hipcub::BlockLoad<float, TH, NUM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
|
||||
typedef hipcub::BlockStore<float, TH, NUM_PER_THREAD, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreFloat;
|
||||
|
||||
__shared__ union {
|
||||
typename Load::TempStorage load;
|
||||
|
@ -939,9 +941,9 @@ kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned c
|
|||
unsigned char m_c1[NUM8BIT];
|
||||
unsigned char r_c2[NUM8BIT];
|
||||
|
||||
typedef cub::BlockLoad<T, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
typedef cub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
|
||||
typedef cub::BlockReduce<float, NUM_THREADS> BlockReduce;
|
||||
typedef hipcub::BlockLoad<T, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
typedef hipcub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
|
||||
typedef hipcub::BlockReduce<float, NUM_THREADS> BlockReduce;
|
||||
|
||||
|
||||
__shared__ union {
|
||||
|
@ -1068,11 +1070,11 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
|
|||
unsigned char c2s[NUM_PER_THREAD2];
|
||||
T p_vals[NUM_PER_THREAD2];
|
||||
T g_vals[NUM_PER_THREAD2];
|
||||
typedef cub::BlockLoad<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
typedef cub::BlockLoad<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
|
||||
typedef hipcub::BlockLoad<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
typedef hipcub::BlockLoad<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
|
||||
|
||||
typedef cub::BlockStore<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
|
||||
typedef cub::BlockStore<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
|
||||
typedef hipcub::BlockStore<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
|
||||
typedef hipcub::BlockStore<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
|
||||
|
||||
__shared__ float smem_quantiles1[256];
|
||||
__shared__ float smem_quantiles2[256];
|
||||
|
@ -1176,9 +1178,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
|
|||
T g_vals[NUM8BIT];
|
||||
unsigned char m_c1[NUM8BIT];
|
||||
|
||||
typedef cub::BlockLoad<T, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
typedef cub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
|
||||
typedef cub::BlockReduce<float, NUM_THREADS> BlockReduce;
|
||||
typedef hipcub::BlockLoad<T, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
typedef hipcub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
|
||||
typedef hipcub::BlockReduce<float, NUM_THREADS> BlockReduce;
|
||||
|
||||
|
||||
__shared__ union {
|
||||
|
@ -1271,11 +1273,11 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
|||
unsigned char c1s[NUM_PER_THREAD2];
|
||||
T p_vals[NUM_PER_THREAD2];
|
||||
T g_vals[NUM_PER_THREAD2];
|
||||
typedef cub::BlockLoad<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
typedef cub::BlockLoad<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
|
||||
typedef hipcub::BlockLoad<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
typedef hipcub::BlockLoad<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
|
||||
|
||||
typedef cub::BlockStore<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
|
||||
typedef cub::BlockStore<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
|
||||
typedef hipcub::BlockStore<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
|
||||
typedef hipcub::BlockStore<T, NUM_THREADS2, NUM_PER_THREAD2, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
|
||||
|
||||
__shared__ float smem_quantiles1[256];
|
||||
|
||||
|
@ -1353,8 +1355,8 @@ __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int st
|
|||
const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
|
||||
int valid_items = 0;
|
||||
|
||||
typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
|
||||
typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
typedef hipcub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
|
||||
typedef hipcub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
|
||||
__shared__ typename BlockReduce::TempStorage reduce;
|
||||
|
||||
|
@ -1426,16 +1428,16 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
unsigned char c1s[N_PER_TH];
|
||||
unsigned char c2s[N_PER_TH];
|
||||
T g_vals[N_PER_TH];
|
||||
typedef cub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
typedef cub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
|
||||
typedef hipcub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
typedef hipcub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
|
||||
|
||||
typedef cub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
|
||||
typedef cub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
|
||||
typedef hipcub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
|
||||
typedef hipcub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
|
||||
|
||||
__shared__ float smem_quantiles1[LANES][257];
|
||||
__shared__ float smem_quantiles2[LANES][257];
|
||||
typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1;
|
||||
typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce2;
|
||||
typedef hipcub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1;
|
||||
typedef hipcub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce2;
|
||||
__shared__ typename BlockReduce1::TempStorage reduce1;
|
||||
__shared__ typename BlockReduce2::TempStorage reduce2;
|
||||
__shared__ float smem_exchange1[1];
|
||||
|
@ -1599,14 +1601,14 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
T g_vals[N_PER_TH];
|
||||
T p_vals[N_PER_TH];
|
||||
|
||||
typedef cub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
typedef cub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
|
||||
typedef hipcub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
|
||||
typedef hipcub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
|
||||
|
||||
typedef cub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
|
||||
typedef cub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
|
||||
typedef hipcub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
|
||||
typedef hipcub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
|
||||
|
||||
__shared__ float smem_quantiles1[LANES][257];
|
||||
typedef cub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1;
|
||||
typedef hipcub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1;
|
||||
__shared__ typename BlockReduce1::TempStorage reduce1;
|
||||
__shared__ float smem_exchange1[1];
|
||||
|
||||
|
@ -1756,10 +1758,10 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_
|
|||
const int base_idx = (base_row*cols) + base_col;
|
||||
const int items_per_load = ITEMS_PER_THREAD*THREADS;
|
||||
|
||||
typedef cub::BlockLoad<T, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadT;
|
||||
typedef cub::BlockReduce<float, THREADS> BlockRowReduce;
|
||||
typedef cub::BlockReduce<int, THREADS> BlockRowSum;
|
||||
typedef cub::BlockExchange<float, THREADS, ITEMS_PER_THREAD> BlockExchange;
|
||||
typedef hipcub::BlockLoad<T, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadT;
|
||||
typedef hipcub::BlockReduce<float, THREADS> BlockRowReduce;
|
||||
typedef hipcub::BlockReduce<int, THREADS> BlockRowSum;
|
||||
typedef hipcub::BlockExchange<float, THREADS, ITEMS_PER_THREAD> BlockExchange;
|
||||
|
||||
__shared__ union {
|
||||
typename BlockExchange::TempStorage exchange;
|
||||
|
@ -1945,8 +1947,8 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
|
|||
float local_rowStats[ITEMS_PER_THREAD];
|
||||
__shared__ float smem_rowStats[SUBTILE_ROWS];
|
||||
|
||||
typedef cub::BlockLoad<int, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_DIRECT> LoadInt32;
|
||||
typedef cub::BlockExchange<int, THREADS, ITEMS_PER_THREAD> ExchangeInt32;
|
||||
typedef hipcub::BlockLoad<int, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_DIRECT> LoadInt32;
|
||||
typedef hipcub::BlockExchange<int, THREADS, ITEMS_PER_THREAD> ExchangeInt32;
|
||||
__shared__ typename LoadInt32::TempStorage loadint32;
|
||||
__shared__ typename ExchangeInt32::TempStorage exchangeint32;
|
||||
|
||||
|
@ -2033,9 +2035,9 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int S
|
|||
const int base_idx = (base_row*cols) + base_col;
|
||||
const int items_per_load = ITEMS_PER_THREAD*THREADS;
|
||||
|
||||
typedef cub::BlockLoad<half, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadHalf;
|
||||
typedef hipcub::BlockLoad<half, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadHalf;
|
||||
__shared__ typename LoadHalf::TempStorage loadhalf;
|
||||
typedef cub::BlockStore<char, THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_VECTORIZE> StoreInt8;
|
||||
typedef hipcub::BlockStore<char, THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_VECTORIZE> StoreInt8;
|
||||
__shared__ typename StoreInt8::TempStorage storeint8;
|
||||
|
||||
__shared__ float smem_row_stats[TILE_ROWS];
|
||||
|
@ -2168,7 +2170,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
|
|||
// so that we can have contiguous stores
|
||||
__shared__ char smem_data[32*33*ITEMS_PER_THREAD];
|
||||
char local_data[ITEMS_PER_THREAD];
|
||||
typedef cub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
|
||||
typedef hipcub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
|
||||
|
||||
// we load row after row from the base_position
|
||||
// Load data row by row
|
||||
|
|
|
@ -3,8 +3,10 @@
|
|||
// This source code is licensed under the MIT license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <float.h>
|
||||
#include <ops.cuh>
|
||||
#include "ops.cuh"
|
||||
|
||||
#ifndef kernels
|
||||
#define kernels
|
||||
|
|
305
csrc/ops.cu
305
csrc/ops.cu
|
@ -3,14 +3,17 @@
|
|||
// This source code is licensed under the MIT license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#include <ops.cuh>
|
||||
#include <kernels.cuh>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "ops.cuh"
|
||||
#include "kernels.cuh"
|
||||
#include <cub/device/device_scan.cuh>
|
||||
#include <limits>
|
||||
#include <BinSearch.h>
|
||||
// #include <BinSearch.h>
|
||||
#include <AAlloc.h>
|
||||
#include <BinAlgo.h>
|
||||
#include <cassert>
|
||||
#include <common.h>
|
||||
|
||||
// #include <common.h>
|
||||
|
||||
using namespace BinSearch;
|
||||
using std::cout;
|
||||
|
@ -22,16 +25,16 @@ void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *sr
|
|||
int num_blocks = n/threads;
|
||||
num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1;
|
||||
kHistogramScatterAdd2D<<<num_blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n)
|
||||
{
|
||||
int num_blocks = n/4096;
|
||||
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
|
||||
CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float)));
|
||||
CUDA_CHECK_RETURN(hipMemset(code, 0, 256*sizeof(float)));
|
||||
kEstimateQuantiles<T><<<num_blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
void quantize(float *code, float *A, unsigned char *out, int n)
|
||||
|
@ -39,7 +42,7 @@ void quantize(float *code, float *A, unsigned char *out, int n)
|
|||
int num_blocks = n/1024;
|
||||
num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
|
||||
kQuantize<<<num_blocks, 1024>>>(code, A, out, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
void dequantize(float *code, unsigned char *A, float *out, int n)
|
||||
|
@ -47,7 +50,7 @@ void dequantize(float *code, unsigned char *A, float *out, int n)
|
|||
int num_blocks = n/1024;
|
||||
num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
|
||||
kDequantize<<<num_blocks, 1024>>>(code, A, out, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
|
||||
|
@ -73,7 +76,7 @@ template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A,
|
|||
kQuantizeBlockwise<T, 64, 1, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
|
||||
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
|
||||
|
@ -95,7 +98,7 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
|
|||
else if(blocksize == 64)
|
||||
kDequantizeBlockwise<T, 64, 64, 1><<<num_blocks, 64/1>>>(code, A, absmax, out, n);
|
||||
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
||||
|
@ -110,12 +113,12 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
|||
case ADAM:
|
||||
if(max_unorm > 0.0f)
|
||||
{
|
||||
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||
CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
}
|
||||
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
break;
|
||||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
|
@ -123,13 +126,13 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
|||
|
||||
if(max_unorm > 0.0f)
|
||||
{
|
||||
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||
CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -147,28 +150,28 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
|
|||
int num_blocks = n/4096;
|
||||
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
|
||||
|
||||
if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); }
|
||||
if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); }
|
||||
|
||||
switch(OPTIMIZER)
|
||||
{
|
||||
case ADAM:
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float)));
|
||||
CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float)));
|
||||
CUDA_CHECK_RETURN(hipMemset(new_max2, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
kOptimizerStatic8bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
|
||||
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
break;
|
||||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||
CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
|
||||
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
|
@ -193,7 +196,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
|
|||
num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
|
||||
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
|
||||
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
break;
|
||||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
|
@ -202,7 +205,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
|
|||
num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
|
||||
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
|
||||
quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -213,9 +216,9 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
|
|||
{
|
||||
int num_blocks = n/2048;
|
||||
num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1;
|
||||
CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
|
||||
CUDA_CHECK_RETURN(hipMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
|
||||
kPercentileClipping<T, 2048, 4><<<num_blocks, 512>>>(g, gnorm_vec, step, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
|
||||
|
@ -224,17 +227,17 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in
|
|||
const int fbeta = 0;
|
||||
const void * alpha = &falpha;
|
||||
const void * beta = &fbeta;
|
||||
cublasStatus_t status;
|
||||
hipblasStatus_t status;
|
||||
|
||||
status = cublasGemmEx(context->m_handle,
|
||||
transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
|
||||
transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
|
||||
status = hipblasGemmEx(context->m_handle,
|
||||
transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
|
||||
transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
|
||||
m, n, k,
|
||||
alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta,
|
||||
C, CUDA_R_32I, ldc,
|
||||
CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta,
|
||||
C, HIPBLAS_R_32I, ldc,
|
||||
HIPBLAS_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
||||
if (status != CUBLAS_STATUS_SUCCESS)
|
||||
if (status != HIPBLAS_STATUS_SUCCESS)
|
||||
{
|
||||
std::cout << "CUBLAS ERROR: Status " << status << std::endl;
|
||||
}
|
||||
|
@ -248,7 +251,7 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i
|
|||
const int fbeta = 0;
|
||||
const void * alpha = &falpha;
|
||||
const void * beta = &fbeta;
|
||||
cublasStatus_t status;
|
||||
hipblasStatus_t status;
|
||||
|
||||
//cout << transposeA << transposeB << endl;
|
||||
//printf("%i %i %i\n", m,n,k);
|
||||
|
@ -256,15 +259,15 @@ void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, i
|
|||
//printf("%i %i %i\n", strideA, strideB, strideC);
|
||||
//printf("%i\n", batchCount);
|
||||
|
||||
status = cublasGemmStridedBatchedEx(context->m_handle,
|
||||
transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
|
||||
transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
|
||||
status = hipblasGemmStridedBatchedEx(context->m_handle,
|
||||
transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
|
||||
transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
|
||||
m, n, k,
|
||||
alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta,
|
||||
C, CUDA_R_32I, ldc, (long long int)strideC, batchCount,
|
||||
CUDA_R_32I, CUBLAS_GEMM_DEFAULT);
|
||||
alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta,
|
||||
C, HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount,
|
||||
HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT);
|
||||
|
||||
if (status != CUBLAS_STATUS_SUCCESS)
|
||||
if (status != HIPBLAS_STATUS_SUCCESS)
|
||||
{
|
||||
std::cout << "CUBLAS ERROR: Status " << status << std::endl;
|
||||
}
|
||||
|
@ -276,42 +279,6 @@ int roundoff(int v, int d) {
|
|||
}
|
||||
|
||||
|
||||
#ifdef NO_CUBLASLT
|
||||
#else
|
||||
template<int ORDER> cublasLtOrder_t get_order()
|
||||
{
|
||||
switch(ORDER)
|
||||
{
|
||||
case ROW:
|
||||
return CUBLASLT_ORDER_ROW;
|
||||
break;
|
||||
case COL:
|
||||
return CUBLASLT_ORDER_COL;
|
||||
break;
|
||||
case COL32:
|
||||
return CUBLASLT_ORDER_COL32;
|
||||
break;
|
||||
case COL_TURING:
|
||||
return CUBLASLT_ORDER_COL4_4R2_8C;
|
||||
break;
|
||||
case COL_AMPERE:
|
||||
return CUBLASLT_ORDER_COL32_2R_4R4;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return CUBLASLT_ORDER_ROW;
|
||||
}
|
||||
|
||||
template cublasLtOrder_t get_order<ROW>();
|
||||
template cublasLtOrder_t get_order<COL>();
|
||||
template cublasLtOrder_t get_order<COL32>();
|
||||
template cublasLtOrder_t get_order<COL_TURING>();
|
||||
template cublasLtOrder_t get_order<COL_AMPERE>();
|
||||
#endif
|
||||
|
||||
|
||||
template<int ORDER> int get_leading_dim(int dim1, int dim2)
|
||||
{
|
||||
switch(ORDER)
|
||||
|
@ -345,47 +312,12 @@ template int get_leading_dim<COL32>(int dim1, int dim2);
|
|||
|
||||
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
|
||||
{
|
||||
#ifdef NO_CUBLASLT
|
||||
#else
|
||||
cublasLtOrder_t orderA = get_order<SRC>();
|
||||
cublasLtOrder_t orderOut = get_order<TARGET>();
|
||||
int ldA = get_leading_dim<SRC>(dim1, dim2);
|
||||
int ldOut = get_leading_dim<TARGET>(dim1, dim2);
|
||||
|
||||
cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL;
|
||||
cublasLtMatrixTransformDesc_t A2Out_desc = NULL;
|
||||
cublasOperation_t opTranspose = CUBLAS_OP_T;
|
||||
float transformAlpha = 1.0f, transformBeta = 0.0f;
|
||||
|
||||
|
||||
if(DTYPE == 8)
|
||||
{
|
||||
checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, dim1, dim2, ldA));
|
||||
checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ldOut));
|
||||
}
|
||||
else if(DTYPE == 32)
|
||||
{
|
||||
checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_32I, dim1, dim2, ldA));
|
||||
checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32I, dim1, dim2, ldOut));
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE);
|
||||
}
|
||||
|
||||
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA)));
|
||||
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut)));
|
||||
|
||||
checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, CUDA_R_32F));
|
||||
|
||||
if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); }
|
||||
|
||||
checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0));
|
||||
|
||||
if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
|
||||
if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
|
||||
if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
|
||||
#endif
|
||||
cout << "" << endl;
|
||||
cout << "=============================================" << endl;
|
||||
cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
|
||||
cout << "=============================================" << endl;
|
||||
cout << "" << endl;
|
||||
assert(false);
|
||||
}
|
||||
|
||||
template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
|
||||
|
@ -399,7 +331,6 @@ template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandl
|
|||
|
||||
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
{
|
||||
#ifdef NO_CUBLASLT
|
||||
cout << "" << endl;
|
||||
cout << "=============================================" << endl;
|
||||
cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
|
||||
|
@ -408,62 +339,6 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle
|
|||
assert(false);
|
||||
|
||||
return 0;
|
||||
#else
|
||||
int has_error = 0;
|
||||
cublasLtMatmulDesc_t matmulDesc = NULL;
|
||||
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
|
||||
cublasOperation_t opT = CUBLAS_OP_T;
|
||||
cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
|
||||
cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
|
||||
cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C;
|
||||
cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4;
|
||||
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda));
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb));
|
||||
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
|
||||
if(FORMATB == COL_TURING)
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing)));
|
||||
else
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere)));
|
||||
|
||||
if(DTYPE_OUT == 32)
|
||||
{
|
||||
has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I));
|
||||
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc));
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
|
||||
int alpha = 1, beta = 0;
|
||||
has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0));
|
||||
}
|
||||
else
|
||||
{
|
||||
has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F));
|
||||
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc));
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
|
||||
if(!SCALE_ROWS)
|
||||
{
|
||||
float alpha = 1.0f, beta = 0.0f;
|
||||
has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
|
||||
}
|
||||
else
|
||||
{
|
||||
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec)));
|
||||
has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc));
|
||||
if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc));
|
||||
if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc));
|
||||
if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc));
|
||||
if(has_error == 1)
|
||||
printf("error detected");
|
||||
|
||||
return has_error;
|
||||
#endif
|
||||
}
|
||||
|
||||
int fill_up_to_nearest_multiple(int value, int multiple)
|
||||
|
@ -484,7 +359,7 @@ void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out,
|
|||
assert(threads <= tilesize);
|
||||
|
||||
kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
#define STATS_THREADS 64
|
||||
|
@ -505,7 +380,7 @@ void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_r
|
|||
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
|
||||
else if(nnz_threshold != 0.0)
|
||||
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 1><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
|
||||
}
|
||||
|
||||
|
@ -529,7 +404,7 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col
|
|||
else
|
||||
kDoubleRowColQuant<64, 4, 16, 64*4, 0><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
|
||||
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols)
|
||||
|
@ -573,69 +448,27 @@ template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *o
|
|||
}
|
||||
|
||||
kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
|
||||
void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
|
||||
{
|
||||
|
||||
#ifdef NO_CUBLASLT
|
||||
#else
|
||||
cout << "" << endl;
|
||||
cout << "=============================================" << endl;
|
||||
cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
|
||||
cout << "=============================================" << endl;
|
||||
cout << "" << endl;
|
||||
assert(false);
|
||||
|
||||
cusparseSpMatDescr_t descA;
|
||||
cusparseDnMatDescr_t descB, descC;
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
void *dBuffer = NULL;
|
||||
size_t bufferSize = 0;
|
||||
|
||||
CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz,
|
||||
A_rowidx, A_colidx, A_vals,
|
||||
CUSPARSE_INDEX_32I,
|
||||
CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) );
|
||||
// Create dense matrix C
|
||||
CHECK_CUSPARSE( cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C,
|
||||
CUDA_R_16F, CUSPARSE_ORDER_ROW) );
|
||||
// Create dense matrix B
|
||||
if(transposed_B)
|
||||
{
|
||||
int tmp = A_cols;
|
||||
A_cols = B_cols;
|
||||
B_cols = tmp;
|
||||
}
|
||||
|
||||
CHECK_CUSPARSE( cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B,
|
||||
CUDA_R_16F, CUSPARSE_ORDER_ROW) );
|
||||
// allocate an external buffer if needed
|
||||
CHECK_CUSPARSE( cusparseSpMM_bufferSize(
|
||||
handle,
|
||||
CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
&alpha, descA, descB, &beta, descC, CUDA_R_32F,
|
||||
CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize) );
|
||||
CUDA_CHECK_RETURN( cudaMalloc(&dBuffer, bufferSize) );
|
||||
|
||||
// execute SpMM
|
||||
CHECK_CUSPARSE( cusparseSpMM(handle,
|
||||
CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
&alpha, descA, descB, &beta, descC, CUDA_R_32F,
|
||||
CUSPARSE_SPMM_ALG_DEFAULT, dBuffer));
|
||||
|
||||
// destroy matrix/vector descriptors
|
||||
CHECK_CUSPARSE( cusparseDestroySpMat(descA) );
|
||||
CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
|
||||
CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
|
||||
CUDA_CHECK_RETURN( cudaFree(dBuffer) );
|
||||
#endif
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
|
||||
{
|
||||
|
||||
kspmm_coo_very_sparse_naive<T, 8, BITS><<<nnz_rows, 256>>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
|
||||
|
@ -658,7 +491,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
|
|||
}
|
||||
|
||||
kExtractOutliers<FORMAT><<<num_blocks, threads>>>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
CUDA_CHECK_RETURN(hipPeekAtLastError());
|
||||
}
|
||||
|
||||
//==============================================================
|
||||
|
|
72
csrc/ops.cuh
72
csrc/ops.cuh
|
@ -12,29 +12,31 @@
|
|||
#include <unistd.h>
|
||||
#include <assert.h>
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cublasLt.h>
|
||||
#include <cusparse.h>
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hipblas.h>
|
||||
// #include <cublasLt.h>
|
||||
#include <hipsparse.h>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
|
||||
typedef struct cublasLtContext* cublasLtHandle_t;
|
||||
|
||||
#define CUDA_CHECK_RETURN(value) { \
|
||||
cudaError_t _m_cudaStat = value; \
|
||||
if (_m_cudaStat != cudaSuccess) { \
|
||||
hipError_t _m_cudaStat = value; \
|
||||
if (_m_cudaStat != hipSuccess) { \
|
||||
fprintf(stderr, "Error %s at line %d in file %s\n", \
|
||||
cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
|
||||
hipGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
|
||||
exit(1); \
|
||||
} }
|
||||
|
||||
#define THREADS_PER_BLOCKS (512)
|
||||
|
||||
#define CHECK_CUSPARSE(value) { \
|
||||
cusparseStatus_t _m_cudaStat = value; \
|
||||
if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \
|
||||
fprintf(stderr, "Error %s at line %d in file %s\n", \
|
||||
cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
|
||||
hipsparseStatus_t _m_cudaStat = value; \
|
||||
if (_m_cudaStat != HIPSPARSE_STATUS_SUCCESS) { \
|
||||
fprintf(stderr, "Error <sparse error> at line %d in file %s\n", \
|
||||
__LINE__, __FILE__); \
|
||||
exit(1); \
|
||||
} }
|
||||
|
||||
|
@ -42,15 +44,15 @@
|
|||
#define THREADS_PER_BLOCKS (512)
|
||||
|
||||
|
||||
inline void checkCudaStatus(cudaError_t status) {
|
||||
if (status != cudaSuccess) {
|
||||
printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status));
|
||||
inline void checkCudaStatus(hipError_t status) {
|
||||
if (status != hipSuccess) {
|
||||
printf("cuda API failed with status %d: %s\n", status, hipGetErrorString(status));
|
||||
throw std::logic_error("cuda API failed");
|
||||
}
|
||||
}
|
||||
|
||||
inline int checkCublasStatus(cublasStatus_t status) {
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
inline int checkCublasStatus(hipblasStatus_t status) {
|
||||
if (status != HIPBLAS_STATUS_SUCCESS) {
|
||||
printf("cuBLAS API failed with status %d\n", status);
|
||||
//throw std::logic_error("cuBLAS API failed");
|
||||
return 1;
|
||||
|
@ -84,40 +86,40 @@ typedef enum Transform_t
|
|||
class Context
|
||||
{
|
||||
public:
|
||||
cublasHandle_t m_handle;
|
||||
hipblasHandle_t m_handle;
|
||||
|
||||
Context()
|
||||
{
|
||||
cublasHandle_t handle;
|
||||
cublasCreate_v2(&handle);
|
||||
hipblasHandle_t handle;
|
||||
hipblasCreate(&handle);
|
||||
m_handle = handle;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class ContextLt
|
||||
{
|
||||
public:
|
||||
cublasLtHandle_t m_handle;
|
||||
// class ContextLt
|
||||
// {
|
||||
// public:
|
||||
// cublasLtHandle_t m_handle;
|
||||
|
||||
ContextLt()
|
||||
{
|
||||
cublasLtHandle_t handle;
|
||||
cublasLtCreate(&handle);
|
||||
m_handle = handle;
|
||||
}
|
||||
// ContextLt()
|
||||
// {
|
||||
// cublasLtHandle_t handle;
|
||||
// cublasLtCreate(&handle);
|
||||
// m_handle = handle;
|
||||
// }
|
||||
|
||||
};
|
||||
// };
|
||||
|
||||
class ContextCusparse
|
||||
{
|
||||
public:
|
||||
cusparseHandle_t m_handle;
|
||||
hipsparseHandle_t m_handle;
|
||||
|
||||
ContextCusparse()
|
||||
{
|
||||
cusparseHandle_t handle;
|
||||
cusparseCreate(&handle);
|
||||
hipsparseHandle_t handle;
|
||||
hipsparseCreate(&handle);
|
||||
m_handle = handle;
|
||||
}
|
||||
|
||||
|
@ -170,7 +172,7 @@ void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col
|
|||
|
||||
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols);
|
||||
|
||||
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B);
|
||||
void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B);
|
||||
|
||||
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
|
||||
|
||||
|
|
|
@ -2,73 +2,75 @@
|
|||
|
||||
#include "Algo-Direct-Common.h"
|
||||
|
||||
namespace BinSearch {
|
||||
namespace Details {
|
||||
|
||||
template <typename T, Algos A>
|
||||
struct AlgoScalarBase<T, A, typename std::enable_if<DirectAux::IsDirect2<A>::value>::type> : DirectAux::DirectInfo<2, T, A>
|
||||
namespace BinSearch
|
||||
{
|
||||
private:
|
||||
typedef DirectAux::DirectInfo<2, T, A> base_t;
|
||||
static const size_t Offset=2;
|
||||
|
||||
public:
|
||||
AlgoScalarBase(const T* x, const uint32 n)
|
||||
: base_t(x, n)
|
||||
namespace Details
|
||||
{
|
||||
}
|
||||
|
||||
FORCE_INLINE uint32 scalar(T z) const
|
||||
{
|
||||
const T* px = base_t::data.xi;
|
||||
const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
|
||||
uint32 bidx = base_t::fun_t::f(base_t::data.scaler, base_t::data.cst0, z);
|
||||
uint32 iidx = buckets[bidx];
|
||||
px += iidx;
|
||||
if (z < *px)
|
||||
--iidx;
|
||||
if (z < *(px+1))
|
||||
--iidx;
|
||||
return iidx;
|
||||
}
|
||||
};
|
||||
template <typename T, Algos A>
|
||||
struct AlgoScalarBase<T, A, typename std::enable_if<DirectAux::IsDirect2<A>::value>::type> : DirectAux::DirectInfo<2, T, A>
|
||||
{
|
||||
private:
|
||||
typedef DirectAux::DirectInfo<2, T, A> base_t;
|
||||
static const size_t Offset = 2;
|
||||
|
||||
public:
|
||||
AlgoScalarBase(const T *x, const uint32 n)
|
||||
: base_t(x, n)
|
||||
{
|
||||
}
|
||||
|
||||
template <InstrSet I, typename T, Algos A>
|
||||
struct AlgoVecBase<I, T, A, typename std::enable_if<DirectAux::IsDirect2<A>::value>::type> : AlgoScalarBase<T, A>
|
||||
{
|
||||
static const uint32 nElem = sizeof(typename InstrFloatTraits<I, T>::vec_t) / sizeof(T);
|
||||
FORCE_INLINE uint32 scalar(T z) const
|
||||
{
|
||||
const T *px = base_t::data.xi;
|
||||
const uint32 *buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
|
||||
uint32 bidx = base_t::fun_t::f(base_t::data.scaler, base_t::data.cst0, z);
|
||||
uint32 iidx = buckets[bidx];
|
||||
px += iidx;
|
||||
if (z < *px)
|
||||
--iidx;
|
||||
if (z < *(px + 1))
|
||||
--iidx;
|
||||
return iidx;
|
||||
}
|
||||
};
|
||||
|
||||
typedef FVec<I, T> fVec;
|
||||
typedef IVec<SSE, T> i128;
|
||||
template <InstrSet I, typename T, Algos A>
|
||||
struct AlgoVecBase<I, T, A, typename std::enable_if<DirectAux::IsDirect2<A>::value>::type> : AlgoScalarBase<T, A>
|
||||
{
|
||||
static const uint32 nElem = sizeof(typename InstrFloatTraits<I, T>::vec_t) / sizeof(T);
|
||||
|
||||
struct Constants
|
||||
{
|
||||
fVec vscaler;
|
||||
fVec vcst0;
|
||||
IVec<I, T> one;
|
||||
};
|
||||
typedef FVec<I, T> fVec;
|
||||
typedef IVec<SSE, T> i128;
|
||||
|
||||
private:
|
||||
typedef AlgoScalarBase<T, A> base_t;
|
||||
struct Constants
|
||||
{
|
||||
fVec vscaler;
|
||||
fVec vcst0;
|
||||
IVec<I, T> one;
|
||||
};
|
||||
|
||||
FORCE_INLINE
|
||||
//NO_INLINE
|
||||
void resolve(const FVec<SSE, float>& vz, const IVec<SSE, float>& bidx, uint32 *pr) const
|
||||
{
|
||||
union U {
|
||||
__m128i vec;
|
||||
uint32 ui32[4];
|
||||
} u;
|
||||
private:
|
||||
typedef AlgoScalarBase<T, A> base_t;
|
||||
|
||||
const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
|
||||
const float *xi = base_t::data.xi;
|
||||
FORCE_INLINE
|
||||
// NO_INLINE
|
||||
void resolve(const FVec<SSE, float> &vz, const IVec<SSE, float> &bidx, uint32 *pr) const
|
||||
{
|
||||
union U
|
||||
{
|
||||
__m128i vec;
|
||||
uint32 ui32[4];
|
||||
} u;
|
||||
|
||||
// read indices t
|
||||
const double *p3 = reinterpret_cast<const double *>(&xi[(u.ui32[3] = buckets[bidx.get3()])]);
|
||||
const double *p2 = reinterpret_cast<const double *>(&xi[(u.ui32[2] = buckets[bidx.get2()])]);
|
||||
const double *p1 = reinterpret_cast<const double *>(&xi[(u.ui32[1] = buckets[bidx.get1()])]);
|
||||
const double *p0 = reinterpret_cast<const double *>(&xi[(u.ui32[0] = buckets[bidx.get0()])]);
|
||||
const uint32 *buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
|
||||
const float *xi = base_t::data.xi;
|
||||
|
||||
// read indices t
|
||||
const double *p3 = reinterpret_cast<const double *>(&xi[(u.ui32[3] = buckets[bidx.get3()])]);
|
||||
const double *p2 = reinterpret_cast<const double *>(&xi[(u.ui32[2] = buckets[bidx.get2()])]);
|
||||
const double *p1 = reinterpret_cast<const double *>(&xi[(u.ui32[1] = buckets[bidx.get1()])]);
|
||||
const double *p0 = reinterpret_cast<const double *>(&xi[(u.ui32[0] = buckets[bidx.get0()])]);
|
||||
|
||||
#if 0
|
||||
// read pairs ( X(t-1), X(t) )
|
||||
|
@ -87,65 +89,66 @@ private:
|
|||
__m128 vxm = _mm_shuffle_ps(u01, u23, (0) + (1 << 2) + (0 << 4) + (1 << 6));
|
||||
__m128 vxp = _mm_shuffle_ps(u01, u23, (2) + (3 << 2) + (2 << 4) + (3 << 6));
|
||||
#else
|
||||
__m128 xp23 = _mm_castpd_ps(_mm_set_pd(*p3, *p2));
|
||||
__m128 xp01 = _mm_castpd_ps(_mm_set_pd(*p1, *p0));
|
||||
__m128 vxm = _mm_shuffle_ps(xp01, xp23, (0) + (2 << 2) + (0 << 4) + (2 << 6));
|
||||
__m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6));
|
||||
__m128 xp23 = _mm_castpd_ps(_mm_set_pd(*p3, *p2));
|
||||
__m128 xp01 = _mm_castpd_ps(_mm_set_pd(*p1, *p0));
|
||||
__m128 vxm = _mm_shuffle_ps(xp01, xp23, (0) + (2 << 2) + (0 << 4) + (2 << 6));
|
||||
__m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6));
|
||||
#endif
|
||||
IVec<SSE, float> i(u.vec);
|
||||
IVec<SSE, float> vlem = vz < vxm;
|
||||
IVec<SSE, float> vlep = vz < vxp;
|
||||
i = i + vlem + vlep;
|
||||
i.store(pr);
|
||||
}
|
||||
IVec<SSE, float> i(u.vec);
|
||||
IVec<SSE, float> vlem = (vz < vxm);
|
||||
IVec<SSE, float> vlep = (vz < vxp);
|
||||
i = i + vlem + vlep;
|
||||
i.store(pr);
|
||||
}
|
||||
|
||||
FORCE_INLINE
|
||||
//NO_INLINE
|
||||
void resolve(const FVec<SSE, double>& vz, const IVec<SSE, float>& bidx, uint32 *pr) const
|
||||
{
|
||||
const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
|
||||
const double *xi = base_t::data.xi;
|
||||
FORCE_INLINE
|
||||
// NO_INLINE
|
||||
void resolve(const FVec<SSE, double> &vz, const IVec<SSE, float> &bidx, uint32 *pr) const
|
||||
{
|
||||
const uint32 *buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
|
||||
const double *xi = base_t::data.xi;
|
||||
|
||||
uint32 b1 = buckets[bidx.get1()];
|
||||
uint32 b0 = buckets[bidx.get0()];
|
||||
uint32 b1 = buckets[bidx.get1()];
|
||||
uint32 b0 = buckets[bidx.get0()];
|
||||
|
||||
const double *p1 = &xi[b1];
|
||||
const double *p0 = &xi[b0];
|
||||
const double *p1 = &xi[b1];
|
||||
const double *p0 = &xi[b0];
|
||||
|
||||
// read pairs ( X(t-1), X(t) )
|
||||
__m128d vx1 = _mm_loadu_pd(p1);
|
||||
__m128d vx0 = _mm_loadu_pd(p0);
|
||||
// read pairs ( X(t-1), X(t) )
|
||||
__m128d vx1 = _mm_loadu_pd(p1);
|
||||
__m128d vx0 = _mm_loadu_pd(p0);
|
||||
|
||||
// build:
|
||||
// { X(t(0)-1), X(t(1)-1) }
|
||||
// { X(t(0)), X(t(1)) }
|
||||
__m128d vxm = _mm_shuffle_pd(vx0, vx1, 0);
|
||||
__m128d vxp = _mm_shuffle_pd(vx0, vx1, 3);
|
||||
// build:
|
||||
// { X(t(0)-1), X(t(1)-1) }
|
||||
// { X(t(0)), X(t(1)) }
|
||||
__m128d vxm = _mm_shuffle_pd(vx0, vx1, 0);
|
||||
__m128d vxp = _mm_shuffle_pd(vx0, vx1, 3);
|
||||
|
||||
IVec<SSE, double> i(b1, b0);
|
||||
IVec<SSE, double> vlem = (vz < vxm);
|
||||
IVec<SSE, double> vlep = (vz < vxp);
|
||||
i = i + vlem + vlep;
|
||||
IVec<SSE, double> i(b1, b0);
|
||||
IVec<SSE, double> vlem = (vz < vxm);
|
||||
IVec<SSE, double> vlep = (vz < vxp);
|
||||
i = i + vlem + vlep;
|
||||
|
||||
union {
|
||||
__m128i vec;
|
||||
uint32 ui32[4];
|
||||
} u;
|
||||
u.vec = i;
|
||||
pr[0] = u.ui32[0];
|
||||
pr[1] = u.ui32[2];
|
||||
}
|
||||
union
|
||||
{
|
||||
__m128i vec;
|
||||
uint32 ui32[4];
|
||||
} u;
|
||||
u.vec = i;
|
||||
pr[0] = u.ui32[0];
|
||||
pr[1] = u.ui32[2];
|
||||
}
|
||||
|
||||
#ifdef USE_AVX
|
||||
|
||||
FORCE_INLINE
|
||||
//NO_INLINE
|
||||
void resolve(const FVec<AVX, float>& vz, const IVec<AVX, float>& bidx, uint32 *pr) const
|
||||
{
|
||||
const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
|
||||
const float *xi = base_t::data.xi;
|
||||
FORCE_INLINE
|
||||
// NO_INLINE
|
||||
void resolve(const FVec<AVX, float> &vz, const IVec<AVX, float> &bidx, uint32 *pr) const
|
||||
{
|
||||
const uint32 *buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
|
||||
const float *xi = base_t::data.xi;
|
||||
|
||||
#if 0 // use gather instructions
|
||||
#if 0 // use gather instructions
|
||||
|
||||
IVec<AVX,float> idxm;
|
||||
idxm.setidx(buckets, bidx);
|
||||
|
@ -159,21 +162,22 @@ private:
|
|||
|
||||
#else // do not use gather instrucions
|
||||
|
||||
union U {
|
||||
__m256i vec;
|
||||
uint32 ui32[8];
|
||||
} u;
|
||||
union U
|
||||
{
|
||||
__m256i vec;
|
||||
uint32 ui32[8];
|
||||
} u;
|
||||
|
||||
// read indices t
|
||||
// read indices t
|
||||
|
||||
const double *p7 = reinterpret_cast<const double *>(&xi[(u.ui32[7] = buckets[bidx.get7()])]);
|
||||
const double *p6 = reinterpret_cast<const double *>(&xi[(u.ui32[6] = buckets[bidx.get6()])]);
|
||||
const double *p5 = reinterpret_cast<const double *>(&xi[(u.ui32[5] = buckets[bidx.get5()])]);
|
||||
const double *p4 = reinterpret_cast<const double *>(&xi[(u.ui32[4] = buckets[bidx.get4()])]);
|
||||
const double *p3 = reinterpret_cast<const double *>(&xi[(u.ui32[3] = buckets[bidx.get3()])]);
|
||||
const double *p2 = reinterpret_cast<const double *>(&xi[(u.ui32[2] = buckets[bidx.get2()])]);
|
||||
const double *p1 = reinterpret_cast<const double *>(&xi[(u.ui32[1] = buckets[bidx.get1()])]);
|
||||
const double *p0 = reinterpret_cast<const double *>(&xi[(u.ui32[0] = buckets[bidx.get0()])]);
|
||||
const double *p7 = reinterpret_cast<const double *>(&xi[(u.ui32[7] = buckets[bidx.get7()])]);
|
||||
const double *p6 = reinterpret_cast<const double *>(&xi[(u.ui32[6] = buckets[bidx.get6()])]);
|
||||
const double *p5 = reinterpret_cast<const double *>(&xi[(u.ui32[5] = buckets[bidx.get5()])]);
|
||||
const double *p4 = reinterpret_cast<const double *>(&xi[(u.ui32[4] = buckets[bidx.get4()])]);
|
||||
const double *p3 = reinterpret_cast<const double *>(&xi[(u.ui32[3] = buckets[bidx.get3()])]);
|
||||
const double *p2 = reinterpret_cast<const double *>(&xi[(u.ui32[2] = buckets[bidx.get2()])]);
|
||||
const double *p1 = reinterpret_cast<const double *>(&xi[(u.ui32[1] = buckets[bidx.get1()])]);
|
||||
const double *p0 = reinterpret_cast<const double *>(&xi[(u.ui32[0] = buckets[bidx.get0()])]);
|
||||
|
||||
#if 0 // perform 8 loads in double precision
|
||||
|
||||
|
@ -210,96 +214,93 @@ private:
|
|||
|
||||
IVec<AVX, float> ip(u.vec);
|
||||
|
||||
#else // use __mm256_set_pd
|
||||
#else // use __mm256_set_pd
|
||||
|
||||
// read pairs ( X(t-1), X(t) )
|
||||
__m256 x0145 = _mm256_castpd_ps(_mm256_set_pd(*p5, *p4, *p1, *p0)); // { x0(t-1), x0(t), x1(t-1), x1(t), x4(t-1), x4(t), x5(t-1), x5(t) }
|
||||
__m256 x2367 = _mm256_castpd_ps(_mm256_set_pd(*p7, *p6, *p3, *p2)); // { x2(t-1), x2(t), x3(t-1), x3(t), x6(t-1), x6(t), x7(t-1), x7(t) }
|
||||
// read pairs ( X(t-1), X(t) )
|
||||
__m256 x0145 = _mm256_castpd_ps(_mm256_set_pd(*p5, *p4, *p1, *p0)); // { x0(t-1), x0(t), x1(t-1), x1(t), x4(t-1), x4(t), x5(t-1), x5(t) }
|
||||
__m256 x2367 = _mm256_castpd_ps(_mm256_set_pd(*p7, *p6, *p3, *p2)); // { x2(t-1), x2(t), x3(t-1), x3(t), x6(t-1), x6(t), x7(t-1), x7(t) }
|
||||
|
||||
// { x0(t-1), x1(t-1), x2(t-1), 3(t-1, x4(t-1), x5(t-1), x6(t-1), xt(t-1) }
|
||||
FVec<AVX, float> vxm = _mm256_shuffle_ps(x0145, x2367, 0 + (2 << 2) + (0 << 4) + (2 << 6) );
|
||||
// { x0(t), x1(t), x2(t), 3(t, x4(t), x5(t), x6(t), xt(t) }
|
||||
FVec<AVX, float> vxp = _mm256_shuffle_ps(x0145, x2367, 1 + (3 << 2) + (1 << 4) + (3 << 6) );
|
||||
// { x0(t-1), x1(t-1), x2(t-1), 3(t-1, x4(t-1), x5(t-1), x6(t-1), xt(t-1) }
|
||||
FVec<AVX, float> vxm = _mm256_shuffle_ps(x0145, x2367, 0 + (2 << 2) + (0 << 4) + (2 << 6));
|
||||
// { x0(t), x1(t), x2(t), 3(t, x4(t), x5(t), x6(t), xt(t) }
|
||||
FVec<AVX, float> vxp = _mm256_shuffle_ps(x0145, x2367, 1 + (3 << 2) + (1 << 4) + (3 << 6));
|
||||
|
||||
IVec<AVX, float> ip(u.vec);
|
||||
IVec<AVX, float> ip(u.vec);
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
IVec<AVX, float> vlem = vz < vxm;
|
||||
IVec<AVX, float> vlep = vz < vxp;
|
||||
ip = ip + vlem + vlep;
|
||||
IVec<AVX, float> vlem = vz < vxm;
|
||||
IVec<AVX, float> vlep = vz < vxp;
|
||||
ip = ip + vlem + vlep;
|
||||
|
||||
ip.store(pr);
|
||||
}
|
||||
ip.store(pr);
|
||||
}
|
||||
|
||||
FORCE_INLINE
|
||||
// NO_INLINE
|
||||
void resolve(const FVec<AVX, double> &vz, const IVec<SSE, float> &bidx, uint32 *pr) const
|
||||
{
|
||||
union
|
||||
{
|
||||
__m256i vec;
|
||||
uint64 ui64[4];
|
||||
} u;
|
||||
|
||||
const uint32 *buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
|
||||
const double *xi = base_t::data.xi;
|
||||
|
||||
FORCE_INLINE
|
||||
//NO_INLINE
|
||||
void resolve(const FVec<AVX, double>& vz, const IVec<SSE, float>& bidx, uint32 *pr) const
|
||||
{
|
||||
union {
|
||||
__m256i vec;
|
||||
uint64 ui64[4];
|
||||
} u;
|
||||
// read indices t
|
||||
const double *p3 = &xi[(u.ui64[3] = buckets[bidx.get3()])];
|
||||
const double *p2 = &xi[(u.ui64[2] = buckets[bidx.get2()])];
|
||||
const double *p1 = &xi[(u.ui64[1] = buckets[bidx.get1()])];
|
||||
const double *p0 = &xi[(u.ui64[0] = buckets[bidx.get0()])];
|
||||
|
||||
const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
|
||||
const double *xi = base_t::data.xi;
|
||||
// read pairs ( X(t-1), X(t) )
|
||||
__m128d xp3 = _mm_loadu_pd(p3);
|
||||
__m128d xp2 = _mm_loadu_pd(p2);
|
||||
__m128d xp1 = _mm_loadu_pd(p1);
|
||||
__m128d xp0 = _mm_loadu_pd(p0);
|
||||
|
||||
// read indices t
|
||||
const double *p3 = &xi[(u.ui64[3] = buckets[bidx.get3()])];
|
||||
const double *p2 = &xi[(u.ui64[2] = buckets[bidx.get2()])];
|
||||
const double *p1 = &xi[(u.ui64[1] = buckets[bidx.get1()])];
|
||||
const double *p0 = &xi[(u.ui64[0] = buckets[bidx.get0()])];
|
||||
// build:
|
||||
// { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) }
|
||||
// { X(t(0)), X(t(1)), X(t(2)), X(t(3)) }
|
||||
__m256d x02 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp0), xp2, 1);
|
||||
__m256d x13 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp1), xp3, 1);
|
||||
FVec<AVX, double> vxm = _mm256_unpacklo_pd(x02, x13);
|
||||
FVec<AVX, double> vxp = _mm256_unpackhi_pd(x02, x13);
|
||||
|
||||
// read pairs ( X(t-1), X(t) )
|
||||
__m128d xp3 = _mm_loadu_pd(p3);
|
||||
__m128d xp2 = _mm_loadu_pd(p2);
|
||||
__m128d xp1 = _mm_loadu_pd(p1);
|
||||
__m128d xp0 = _mm_loadu_pd(p0);
|
||||
// __m128d h01m = _mm_shuffle_pd(xp0, xp1, 0);
|
||||
// __m128d h23m = _mm_shuffle_pd(xp2, xp3, 0);
|
||||
// __m128d h01p = _mm_shuffle_pd(xp0, xp1, 3);
|
||||
// __m128d h23p = _mm_shuffle_pd(xp2, xp3, 3);
|
||||
// FVec<AVX, double> vxm = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01m), h23m, 1);
|
||||
// FVec<AVX, double> vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1);
|
||||
|
||||
// build:
|
||||
// { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) }
|
||||
// { X(t(0)), X(t(1)), X(t(2)), X(t(3)) }
|
||||
__m256d x02 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp0), xp2, 1);
|
||||
__m256d x13 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp1), xp3, 1);
|
||||
FVec<AVX, double> vxm = _mm256_unpacklo_pd(x02,x13);
|
||||
FVec<AVX, double> vxp = _mm256_unpackhi_pd(x02,x13);
|
||||
|
||||
|
||||
// __m128d h01m = _mm_shuffle_pd(xp0, xp1, 0);
|
||||
// __m128d h23m = _mm_shuffle_pd(xp2, xp3, 0);
|
||||
// __m128d h01p = _mm_shuffle_pd(xp0, xp1, 3);
|
||||
// __m128d h23p = _mm_shuffle_pd(xp2, xp3, 3);
|
||||
// FVec<AVX, double> vxm = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01m), h23m, 1);
|
||||
// FVec<AVX, double> vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1);
|
||||
|
||||
IVec<AVX, double> i(u.vec);
|
||||
IVec<AVX, double> vlem = vz < vxm;
|
||||
IVec<AVX, double> vlep = vz < vxp;
|
||||
i = i + vlem + vlep;
|
||||
i.extractLo32s().store(pr);
|
||||
}
|
||||
IVec<AVX, double> i(u.vec);
|
||||
IVec<AVX, double> vlem = vz < vxm;
|
||||
IVec<AVX, double> vlep = vz < vxp;
|
||||
i = i + vlem + vlep;
|
||||
i.extractLo32s().store(pr);
|
||||
}
|
||||
#endif
|
||||
|
||||
public:
|
||||
public:
|
||||
AlgoVecBase(const T *x, const uint32 n) : base_t(x, n) {}
|
||||
|
||||
AlgoVecBase(const T* x, const uint32 n) : base_t(x, n) {}
|
||||
void initConstants(Constants &cst) const
|
||||
{
|
||||
cst.vscaler.setN(base_t::data.scaler);
|
||||
cst.vcst0.setN(base_t::data.cst0);
|
||||
cst.one.setN(uint32(1));
|
||||
}
|
||||
|
||||
void initConstants(Constants& cst) const
|
||||
{
|
||||
cst.vscaler.setN(base_t::data.scaler);
|
||||
cst.vcst0.setN(base_t::data.cst0);
|
||||
cst.one.setN(uint32(1));
|
||||
}
|
||||
|
||||
void vectorial(uint32 *pr, const T *pz, const Constants& cst) const
|
||||
{
|
||||
fVec vz(pz);
|
||||
resolve(vz, base_t::fun_t::f(cst.vscaler, cst.vcst0, vz), pr);
|
||||
}
|
||||
};
|
||||
} // namespace Details
|
||||
void vectorial(uint32 *pr, const T *pz, const Constants &cst) const
|
||||
{
|
||||
fVec vz(pz);
|
||||
resolve(vz, base_t::fun_t::f(cst.vscaler, cst.vcst0, vz), pr);
|
||||
}
|
||||
};
|
||||
} // namespace Details
|
||||
} // namespace BinSearch
|
||||
|
|
Loading…
Reference in New Issue
Block a user