Compare commits
10 Commits (3c9aca9124 ... c88f97a9c8):

- c88f97a9c8
- e38b9e91b7
- c97c78bd66
- 0b481bfcc2
- 615d47583f
- 705bc024d2
- 40361ecfbb
- 3682106eb0
- d10197bc93
- 18e827d666
@@ -310,3 +310,8 @@ User experience:

Performance:

- improved 4-bit inference performance for A100 GPUs. This degraded performance for A40/RTX3090 and RTX 4090 GPUs slightly.

### 0.41.0

Bug fixes:

- Fixed bugs in dynamic exponent data type creation. Thank you @RossM, @KohakuBlueleaf, @ArrowM #659 #227 #262 #152
Makefile (21)
@@ -6,15 +6,27 @@ GPP:= /usr/bin/g++

ifeq ($(CUDA_HOME),)
	CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev)
endif
ifeq ($(ROCM_HOME),)
	ROCM_HOME:= $(shell which hipcc | rev | cut -d'/' -f4- | rev)
endif

ifneq ($(CUDA_HOME),)
ifndef CUDA_VERSION
$(warning WARNING: CUDA_VERSION not set. Call make with CUDA string, for example: make cuda11x CUDA_VERSION=115 or make cpuonly CUDA_VERSION=CPU)
CUDA_VERSION:=
endif

else ifneq ($(ROCM_HOME),)
ifndef ROCM_TARGET
$(error ERROR: ROCM_TARGET not set. Call make with ROCM string (see https://www.llvm.org/docs/AMDGPUUsage.html#processors), for example: make hip ROCM_TARGET=gfx1030)
ROCM_TARGET:=
endif
endif

NVCC := $(CUDA_HOME)/bin/nvcc
HIPCC:= $(ROCM_HOME)/bin/hipcc

###########################################
@@ -113,6 +125,15 @@ cuda12x: $(BUILD_DIR) env

cpuonly: $(BUILD_DIR) env
	$(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cpu.so

HIP_INCLUDE := -I $(ROCM_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include
HIP_LIB := -L $(ROCM_HOME)/lib -lhipblas -lhiprand -lhipsparse #-lhipblaslt #TODO: check if this is actually only gfx90a
hip: $(BUILD_DIR)
	$(HIPCC) -std=c++14 -c -fPIC --offload-arch=$(ROCM_TARGET) $(HIP_INCLUDE) -o $(BUILD_DIR)/ops.o -DNO_CUBLASLT -DBITS_AND_BYTES_USE_ROCM $(CSRC)/ops.cu
	$(HIPCC) -std=c++14 -c -fPIC --offload-arch=$(ROCM_TARGET) $(HIP_INCLUDE) -o $(BUILD_DIR)/kernels.o -DNO_CUBLASLT -DBITS_AND_BYTES_USE_ROCM $(CSRC)/kernels.cu
	# HCC is deprecated, but it is still used by the hipBLASLt header. Since BLAS isn't even used here it doesn't matter; the define is only needed so the file compiles.
	$(GPP) -std=c++14 -D__HIP_PLATFORM_HCC__ -D__HIP_PLATFORM_AMD__ -DBUILD_CUDA -DBITS_AND_BYTES_USE_ROCM -shared -fPIC $(HIP_INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(FILES_CPP) $(HIP_LIB) -o ./bitsandbytes/libbitsandbytes_hip_nohipblaslt.so

env:
	@echo "ENVIRONMENT"
	@echo "============================"
README.md (25)
@@ -9,6 +9,31 @@ Resources:

- [LLM.int8() Paper](https://arxiv.org/abs/2208.07339) -- [LLM.int8() Software Blog Post](https://huggingface.co/blog/hf-bitsandbytes-integration) -- [LLM.int8() Emergent Features Blog Post](https://timdettmers.com/2022/08/17/llm-int8-and-emergent-features/)

## Quickstart ROCm

Works well with these docker images:
- [rocm/pytorch](https://hub.docker.com/r/rocm/pytorch)
- [rocm/pytorch-nightly](https://hub.docker.com/r/rocm/pytorch-nightly)

To install, run:
```bash
git clone https://github.com/arlo-phoenix/bitsandbytes-rocm-5.6.git bitsandbytes
cd bitsandbytes
make hip ROCM_TARGET=gfx...
pip install .
```
See https://www.llvm.org/docs/AMDGPUUsage.html#processors to find your ROCM_TARGET (e.g. gfx1030 for the 6800 XT / 6900 XT).
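If you are unsure which gfx target your card reports, the ROCm build of PyTorch inside the docker images above can usually tell you. The following is a minimal sketch; it assumes a PyTorch version that exposes the `gcnArchName` device property (older builds may not, in which case `rocminfo | grep gfx` gives the same information):

```python
import torch

# Sketch: derive ROCM_TARGET from the first GPU's reported architecture.
# Assumes a ROCm build of PyTorch that exposes gcnArchName; if it doesn't,
# run `rocminfo | grep gfx` inside the container instead.
props = torch.cuda.get_device_properties(0)
arch = getattr(props, "gcnArchName", None)
if arch:
    # gcnArchName may include feature flags, e.g. "gfx1030:sramecc+:xnack-"
    print("ROCM_TARGET =", arch.split(":")[0])
else:
    print("gcnArchName not available on this PyTorch build; use `rocminfo | grep gfx`")
```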
## Info about this port / Credits

Instead of just using the [hipified](https://github.com/ROCm-Developer-Tools/HIPIFY) output, I went through all the different variables/functions and used defines to make the CUDA code use the HIP equivalents. That idea is taken from the [llama.cpp rocblas port](https://github.com/ggerganov/llama.cpp/pull/1087).

The Python/Makefile/compatibility changes are copied from [this clean older ROCm port](https://github.com/agrocylo/bitsandbytes-rocm) by @agrocylo. Thanks for that; it was easy to look through.

I very much recommend using docker if you want to run this. Since this port just redefines some CUDA variables/functions, I also had to include all the needed dependency headers, including [hipBLASLt](https://github.com/ROCmSoftwarePlatform/hipBLASLt), which is still in its infancy and not supported by most architectures; the header itself works, though. That is also why some of the newer functions won't work and will just log that they are not functioning. The optimizers like AdamW8bit should work, though, and this fork will be a lot easier to keep up to date when the CUDA source files change.
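For reference, a rough sketch of the optimizer usage mentioned above (the model, sizes and learning rate are made up for illustration; it assumes the package was installed with the HIP build from the Quickstart):

```python
import torch
import bitsandbytes as bnb

# Minimal sketch: one 8-bit AdamW step on a toy model. On a ROCm build of
# PyTorch the "cuda" device maps to the AMD GPU.
model = torch.nn.Linear(512, 512).to("cuda")
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-4)

x = torch.randn(8, 512, device="cuda")
loss = model(x).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```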
## TL;DR
**Requirements**
Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0.
@@ -224,7 +224,7 @@ matmul_cublas = MatMul8bit.apply

def supports_igemmlt(device: torch.device) -> bool:
    """check if this device supports the optimized int8 kernel"""
-    if torch.cuda.get_device_capability(device=device) < (7, 5):
+    if torch.cuda.get_device_capability(device=device) < (7, 5) or torch.version.hip:
        return False
    device_name = torch.cuda.get_device_name(device=device)
    nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series
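The added condition relies on `torch.version.hip`, which is a version string on ROCm builds of PyTorch and `None` on CUDA builds, so the optimized int8 path is always skipped under HIP. A minimal sketch of that check in isolation (the helper name is made up for illustration):

```python
import torch

def running_on_rocm() -> bool:
    # torch.version.hip is a version string on ROCm builds of PyTorch and
    # None on CUDA builds, so igemmlt is disabled whenever it is set.
    return torch.version.hip is not None
```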
@@ -332,7 +332,9 @@ def evaluate_cuda_setup():

    cuda_setup.add_log_entry(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'),
                             ('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues'))
    cuda_setup.add_log_entry('='*80)

    if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None
    if torch.version.hip: return 'libbitsandbytes_hip_nohipblaslt.so', None, None, None

    cudart_path = determine_cuda_runtime_lib_path()
    ccs = get_compute_capabilities()
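Taken together, the selection order becomes: CPU stub when no GPU is visible, the HIP build when `torch.version.hip` is set, otherwise the regular CUDA lookup. A condensed sketch of that order (the function name and the CUDA placeholder are illustrative, not the library's actual API):

```python
import torch

def pick_binary_name() -> str:
    # Illustrative only: mirrors the order of the checks in the hunk above.
    if not torch.cuda.is_available():
        return "libbitsandbytes_cpu.so"
    if torch.version.hip:                      # ROCm build of PyTorch
        return "libbitsandbytes_hip_nohipblaslt.so"
    return "libbitsandbytes_cuda<version>.so"  # placeholder; picked from the detected CUDA runtime
```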
@@ -815,6 +815,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz

    if out is None:
        out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device)

    #TODO: catch rocm wave64 only, pytorch has a property, but that one likely contains the wrong waveSize
    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]

    prev_device = pre_call(A.device)
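For orientation, a small usage sketch of the blockwise 4-bit API this function implements; it assumes the `quantize_4bit`/`dequantize_4bit` pair in `bitsandbytes.functional` from this version of the library, and the tensor shape and block size are arbitrary picks from the asserted list above:

```python
import torch
import bitsandbytes.functional as F

# Sketch: blockwise 4-bit quantize/dequantize round trip on the GPU.
A = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
q, state = F.quantize_4bit(A, blocksize=128, quant_type="fp4")
A_hat = F.dequantize_4bit(q, state)
print((A - A_hat).abs().mean())  # mean absolute quantization error
```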
@@ -4,6 +4,23 @@

// LICENSE file in the root directory of this source tree.

#include <kernels.cuh>

#ifdef BITS_AND_BYTES_USE_ROCM
#include <hip/hip_runtime.h>
#include <hipcub/hipcub.hpp>
#include <hipcub/block/block_radix_sort.hpp>
#include <hipcub/warp/warp_reduce.hpp>
#include <hipcub/block/block_load.hpp>
#include <hipcub/block/block_discontinuity.hpp>
#include <hipcub/block/block_store.hpp>
#include <hipcub/block/block_reduce.hpp>
#include <hip/hip_math_constants.h>
#define cub hipcub
#define __syncwarp __syncthreads //TODO: HIP doesn't have this so just sync threads

#else
#include <math_constants.h>
#include <mma.h>
#include <cub/block/block_radix_sort.cuh>
#include <cub/warp/warp_reduce.cuh>
#include <cub/block/block_load.cuh>
@@ -11,18 +28,17 @@

#include <cub/block/block_store.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/cub.cuh>
#include <math_constants.h>
#endif

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <mma.h>


#define HLF_MAX 65504
#define TH 1024
#define NUM 4
#define NUM_BLOCK 4096


#ifndef BITS_AND_BYTES_USE_ROCM
// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
__device__ float atomicMax(float* address, float val) {
    int* address_as_i = reinterpret_cast<int*>(address);
@@ -47,6 +63,7 @@ __device__ float atomicMin(float* address, float val) {

    } while (assumed != old);
    return __int_as_float(old);
}
#endif

__device__ float dDequantizeFP4(unsigned char val, float absmax)
{
@@ -719,10 +736,26 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c

  }
}

-template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
+template<typename T, int REQUESTED_BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
{
#ifdef DISABLE_WARP_32
  // CUDA has warp size 32, so multiply by 2 to get the AMD wavefront size.
  // Currently we just stop below with a return anyway, so this size isn't actually used; it is only needed for compilation.
  const int BLOCK_SIZE = ((REQUESTED_BLOCK_SIZE / NUM_PER_TH % 64) != 0) ? REQUESTED_BLOCK_SIZE * 2 : REQUESTED_BLOCK_SIZE;

  //TODO: figure out how to make the compiler recognize what isn't executed based on the template arguments; without the #ifndef guard, the code below would trigger a static_assert when this condition is true
  if ((REQUESTED_BLOCK_SIZE / NUM_PER_TH % 64) != 0)
  {
    printf("kQuantizeBlockwise not fully supported on ROCm! BLOCK_SIZE/NUM_PER_TH needs to be divisible by 64.");
    return;
  }
#else
  const int BLOCK_SIZE = REQUESTED_BLOCK_SIZE;
#endif

  const int n_full = gridDim.x * BLOCK_SIZE;
  int valid_items = 0;
  const int base_idx = (blockIdx.x * BLOCK_SIZE);
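In other words, the ROCm path only handles template configurations where the per-block element count divided by the per-thread count is a multiple of the 64-wide wavefront. A host-side sketch of that constraint (an illustrative helper, not part of the library):

```python
def rocm_blockwise_supported(block_size: int, num_per_th: int, wavefront: int = 64) -> bool:
    # Mirrors the guard in the kernel above: bail out unless
    # BLOCK_SIZE / NUM_PER_TH is a multiple of the wavefront size.
    return (block_size // num_per_th) % wavefront == 0

# e.g. rocm_blockwise_supported(4096, 4) -> True, rocm_blockwise_supported(64, 2) -> False
```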
@@ -3956,6 +3989,7 @@ MAKE_optimizerStatic8bit2State(ADAM, float)

template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
template __global__ void kPercentileClipping<half, 2048, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);


#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \
template __global__ void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \
@@ -5,12 +5,17 @@

#include <ops.cuh>
#include <kernels.cuh>
#include <cub/device/device_scan.cuh>
#include <limits>
#include <BinSearch.h>
#include <cassert>
#include <common.h>

#ifdef BITS_AND_BYTES_USE_ROCM
#include <hipcub/device/device_scan.hpp>
#else
#include <cub/device/device_scan.cuh>
#endif


using namespace BinSearch;
using std::cout;
csrc/ops.cuh (48)
@@ -12,16 +12,60 @@

#include <unistd.h>
#include <assert.h>


#ifdef BITS_AND_BYTES_USE_ROCM
#ifndef DISABLE_WARP_32
#define __AMDGCN_WAVEFRONT_SIZE 32 // check rocminfo | grep "Wavefront Size". Should be supported on all new GPUs
#endif

#include <hip/hip_runtime_api.h>
#include <hip/hip_fp16.h>
#include <hipblas/hipblas.h>
#include <hipsparse/hipsparse.h>
#include <hip/hip_bfloat16.h>

#define cudaPeekAtLastError hipPeekAtLastError
#define cudaMemset hipMemset
#define cublasGemmEx hipblasGemmEx
#define cublasStatus_t hipblasStatus_t
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUDA_R_8I HIPBLAS_R_8I
#define CUDA_R_32I HIPBLAS_R_32I
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define cublasStatus_t hipblasStatus_t
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasOperation_t hipblasOperation_t
#define cublasLtMatrixLayoutCreate hipblasMatrixLayoutCreate
#define cudaError_t hipError_t
#define cudaGetErrorString hipGetErrorString
#define cudaSuccess hipSuccess
#define cusparseStatus_t hipsparseStatus_t
#define CUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS
#define cublasStatus_t hipblasStatus_t
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define cublasHandle_t hipblasHandle_t
#define cublasCreate_v2 hipblasCreate
#define cusparseHandle_t hipsparseHandle_t
#define cusparseCreate hipsparseCreate
#define __nv_bfloat16 hip_bfloat16
#define cublasLtHandle_t hipblasHandle_t
#define cublasLtCreate hipblasCreate
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT //TODO: HIP didn't have the right one, might cause issues

#else
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>
#include <cublasLt.h>
#include <cusparse.h>
#include <vector>
#include <functional>

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#endif
#include <vector>
#include <functional>
@@ -5,9 +5,22 @@

#if BUILD_CUDA
#include <ops.cuh>

#ifdef BITS_AND_BYTES_USE_ROCM
#include <hipblas/hipblas.h>
#define cublasLtHandle_t hipblasHandle_t
#define cudaMemAttachHost hipMemAttachHost
#define cudaMallocManaged hipMallocManaged
#define cudaDevAttrConcurrentManagedAccess hipDeviceAttributeConcurrentManagedAccess
#define cudaDeviceGetAttribute hipDeviceGetAttribute
#define cudaMemPrefetchAsync hipMemPrefetchAsync
#endif

#endif
#include <cpu_ops.h>


// We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary.
// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to
// maintain all that boilerplate
@@ -93,8 +93,8 @@ private:

    __m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6));
#endif
    IVec<SSE, float> i(u.vec);
-   IVec<SSE, float> vlem = vz < vxm;
-   IVec<SSE, float> vlep = vz < vxp;
+   IVec<SSE, float> vlem = operator< (vz, vxm);
+   IVec<SSE, float> vlep = operator< (vz, vxp);
    i = i + vlem + vlep;
    i.store(pr);
}

@@ -123,8 +123,8 @@ private:

    __m128d vxp = _mm_shuffle_pd(vx0, vx1, 3);

    IVec<SSE, double> i(b1, b0);
-   IVec<SSE, double> vlem = (vz < vxm);
-   IVec<SSE, double> vlep = (vz < vxp);
+   IVec<SSE, double> vlem = operator< (vz, vxm);
+   IVec<SSE, double> vlep = operator< (vz, vxp);
    i = i + vlem + vlep;

    union {

@@ -227,8 +227,8 @@ private:

#endif

-   IVec<AVX, float> vlem = vz < vxm;
-   IVec<AVX, float> vlep = vz < vxp;
+   IVec<AVX, float> vlem = operator< (vz, vxm);
+   IVec<AVX, float> vlep = operator< (vz, vxp);
    ip = ip + vlem + vlep;

    ip.store(pr);

@@ -277,8 +277,8 @@ private:

    // FVec<AVX, double> vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1);

    IVec<AVX, double> i(u.vec);
-   IVec<AVX, double> vlem = vz < vxm;
-   IVec<AVX, double> vlep = vz < vxp;
+   IVec<AVX, double> vlem = operator< (vz, vxm);
+   IVec<AVX, double> vlep = operator< (vz, vxp);
    i = i + vlem + vlep;
    i.extractLo32s().store(pr);
}