Compare commits

...

10 Commits

Author SHA1 Message Date
mrq
c88f97a9c8 drop support for gfx903 because depending on hipblaslt gums up too many things 2023-10-12 19:16:14 -05:00
arlo-phoenix
e38b9e91b7 Revert get_cuda_version ROCM version change
not called anymore
2023-08-08 21:31:20 +02:00
arlo-phoenix
c97c78bd66 Update README rocm quickstart 2023-08-08 21:28:37 +02:00
arlo-phoenix
0b481bfcc2 Use workaround for ROCm wave32 recognition
just sets __AMDGCN_WAVEFRONT_SIZE forcefully to 32.
Not correct (some GPUs don't support wave32), but works
on the supported GPUs. Can be disabled with DISABLE_WARP_32.

With this, blockwise quantize works, and with that, nf4 is supported.
2023-08-08 18:50:26 +00:00
arlo-phoenix
615d47583f README: Add quickstart and info section 2023-08-05 02:42:13 +02:00
arlo-phoenix
705bc024d2 Makefile: Add make hip 2023-08-05 02:41:58 +02:00
arlo-phoenix
40361ecfbb Adapt python to work with HIP 2023-08-05 02:12:48 +02:00
arlo-phoenix
3682106eb0 Algo-Direct2.h: fix hipcc issue
from https://github.com/agrocylo/bitsandbytes-rocm, thanks
2023-08-05 02:12:14 +02:00
arlo-phoenix
d10197bc93 Add HIP to cuda defines
collected by hipifying all files and then comparing with the original
CUDA files
2023-08-05 02:11:46 +02:00
Tim Dettmers
18e827d666 Version 0.41.1. 2023-08-03 20:01:10 -07:00
12 changed files with 168 additions and 18 deletions

View File

@ -310,3 +310,8 @@ User experience:
Performance:
- improved 4-bit inference performance for A100 GPUs. This degraded performance for A40/RTX3090 and RTX 4090 GPUs slightly.
### 0.41.0
Bug fixes:
- Fixed bugs in dynamic exponent data type creation. Thank you @RossM, @KohakuBlueleaf, @ArrowM #659 #227 #262 #152

View File

@ -6,15 +6,27 @@ GPP:= /usr/bin/g++
ifeq ($(CUDA_HOME),)
CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev)
endif
ifeq ($(ROCM_HOME),)
ROCM_HOME:= $(shell which hipcc | rev | cut -d'/' -f4- | rev)
endif
ifneq ($(CUDA_HOME),)
ifndef CUDA_VERSION
$(warning WARNING: CUDA_VERSION not set. Call make with CUDA string, for example: make cuda11x CUDA_VERSION=115 or make cpuonly CUDA_VERSION=CPU)
CUDA_VERSION:=
endif
else ifneq ($(ROCM_HOME),)
ifndef ROCM_TARGET
$(error ERROR: ROCM_TARGET not set. Call make with ROCM string (see https://www.llvm.org/docs/AMDGPUUsage.html#processors), for example: make hip ROCM_TARGET=gfx1030)
ROCM_TARGET:=
endif
endif
NVCC := $(CUDA_HOME)/bin/nvcc
HIPCC:= $(ROCM_HOME)/bin/hipcc
###########################################
@ -113,6 +125,15 @@ cuda12x: $(BUILD_DIR) env
cpuonly: $(BUILD_DIR) env
$(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cpu.so
HIP_INCLUDE := -I $(ROCM_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include
HIP_LIB := -L $(ROCM_HOME)/lib -lhipblas -lhiprand -lhipsparse #-lhipblaslt #TODO: check if this is actually only gfx90a
hip: $(BUILD_DIR)
$(HIPCC) -std=c++14 -c -fPIC --offload-arch=$(ROCM_TARGET) $(HIP_INCLUDE) -o $(BUILD_DIR)/ops.o -DNO_CUBLASLT -DBITS_AND_BYTES_USE_ROCM $(CSRC)/ops.cu
$(HIPCC) -std=c++14 -c -fPIC --offload-arch=$(ROCM_TARGET) $(HIP_INCLUDE) -o $(BUILD_DIR)/kernels.o -DNO_CUBLASLT -DBITS_AND_BYTES_USE_ROCM $(CSRC)/kernels.cu
# HCC is deprecated, but the hipBLASLt header still uses it. Since BLAS isn't used anyway it doesn't matter; this is just so that it compiles
$(GPP) -std=c++14 -D__HIP_PLATFORM_HCC__ -D__HIP_PLATFORM_AMD__ -DBUILD_CUDA -DBITS_AND_BYTES_USE_ROCM -shared -fPIC $(HIP_INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(FILES_CPP) $(HIP_LIB) -o ./bitsandbytes/libbitsandbytes_hip_nohipblaslt.so
env:
@echo "ENVIRONMENT"
@echo "============================"

View File

@ -9,6 +9,31 @@ Resources:
- [LLM.int8() Paper](https://arxiv.org/abs/2208.07339) -- [LLM.int8() Software Blog Post](https://huggingface.co/blog/hf-bitsandbytes-integration) -- [LLM.int8() Emergent Features Blog Post](https://timdettmers.com/2022/08/17/llm-int8-and-emergent-features/)
## Quickstart ROCm
Works well with these docker images:
- [rocm/pytorch](https://hub.docker.com/r/rocm/pytorch)
- [rocm/pytorch-nightly](https://hub.docker.com/r/rocm/pytorch-nightly).
For installation, run:
```bash
git clone https://github.com/arlo-phoenix/bitsandbytes-rocm-5.6.git bitsandbytes
cd bitsandbytes
make hip ROCM_TARGET=gfx...
pip install .
```
See https://www.llvm.org/docs/AMDGPUUsage.html#processors to find your ROCM_TARGET (e.g. gfx1030 for the 6800 XT / 6900 XT).
## Info about this port / Credits
Instead of just using the [hipified](https://github.com/ROCm-Developer-Tools/HIPIFY) output, I went through all the different variables/functions and used defines to make the CUDA code use the HIP equivalents. That idea is taken from the [llama.cpp rocblas port](https://github.com/ggerganov/llama.cpp/pull/1087).
The Python/Makefile/compatibility changes are copied from [this clean older ROCm port](https://github.com/agrocylo/bitsandbytes-rocm) by @agrocylo. Thanks for that; it was easy to look through.
I very much recommend using Docker if you want to run this. As this just redefines some CUDA variables/functions, I also had to include all the needed dependency headers, including [hipBLASLt](https://github.com/ROCmSoftwarePlatform/hipBLASLt), which is still in its infancy and not supported by most architectures (the header works, though). That's also why some of the newer functions won't work and will just log that they are unavailable. Optimizers like AdamW8bit should work, though, and this fork will be a lot easier to keep up to date when the CUDA source files change.
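To make the define approach concrete, here is a minimal sketch (my simplified illustration, not code from this repo; the real mapping lives in csrc/ops.cuh, shown further down in this diff):

```cpp
// Minimal sketch of the define-based redirection used by this port.
#include <cstdio>

#ifdef BITS_AND_BYTES_USE_ROCM
  #include <hip/hip_runtime_api.h>
  #define cudaError_t         hipError_t
  #define cudaSuccess         hipSuccess
  #define cudaPeekAtLastError hipPeekAtLastError
  #define cudaGetErrorString  hipGetErrorString
#else
  #include <cuda_runtime_api.h>
#endif

// The CUDA-flavoured source below then compiles unchanged with hipcc on ROCm.
void checkLastError()
{
  cudaError_t err = cudaPeekAtLastError();       // hipPeekAtLastError() under ROCm
  if (err != cudaSuccess)
    std::printf("%s\n", cudaGetErrorString(err)); // hipGetErrorString() under ROCm
}
```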
## TL;DR
**Requirements**
Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0.

View File

@ -224,7 +224,7 @@ matmul_cublas = MatMul8bit.apply
def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
if torch.cuda.get_device_capability(device=device) < (7, 5):
if torch.cuda.get_device_capability(device=device) < (7, 5) or torch.version.hip:
return False
device_name = torch.cuda.get_device_name(device=device)
nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series

View File

@ -332,7 +332,9 @@ def evaluate_cuda_setup():
cuda_setup.add_log_entry(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'),
('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues'))
cuda_setup.add_log_entry('='*80)
if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None
if torch.version.hip: return 'libbitsandbytes_hip_nohipblaslt.so', None, None, None
cudart_path = determine_cuda_runtime_lib_path()
ccs = get_compute_capabilities()

View File

@ -815,6 +815,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
if out is None:
out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device)
    #TODO: only catch ROCm wave64 here; PyTorch has a property for this, but it likely reports the wrong waveSize
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
prev_device = pre_call(A.device)

View File

@ -4,6 +4,23 @@
// LICENSE file in the root directory of this source tree.
#include <kernels.cuh>
#ifdef BITS_AND_BYTES_USE_ROCM
#include <hip/hip_runtime.h>
#include <hipcub/hipcub.hpp>
#include <hipcub/block/block_radix_sort.hpp>
#include <hipcub/warp/warp_reduce.hpp>
#include <hipcub/block/block_load.hpp>
#include <hipcub/block/block_discontinuity.hpp>
#include <hipcub/block/block_store.hpp>
#include <hipcub/block/block_reduce.hpp>
#include <hip/hip_math_constants.h>
#define cub hipcub
#define __syncwarp __syncthreads //TODO: HIP doesn't have __syncwarp, so just sync all threads instead
#else
#include <math_constants.h>
#include <mma.h>
#include <cub/block/block_radix_sort.cuh>
#include <cub/warp/warp_reduce.cuh>
#include <cub/block/block_load.cuh>
@ -11,18 +28,17 @@
#include <cub/block/block_store.cuh>
#include <cub/block/block_reduce.cuh>
#include <cub/cub.cuh>
#include <math_constants.h>
#endif
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <mma.h>
#define HLF_MAX 65504
#define TH 1024
#define NUM 4
#define NUM_BLOCK 4096
#ifndef BITS_AND_BYTES_USE_ROCM
// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
__device__ float atomicMax(float* address, float val) {
int* address_as_i = reinterpret_cast<int*>(address);
@ -47,6 +63,7 @@ __device__ float atomicMin(float* address, float val) {
} while (assumed != old);
return __int_as_float(old);
}
#endif
__device__ float dDequantizeFP4(unsigned char val, float absmax)
{
@ -719,10 +736,26 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
}
}
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
template<typename T, int REQUESTED_BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
{
#ifdef DISABLE_WARP_32
//CUDA has warp size 32, so just multiply by 2 to get the AMD (wave64) warp size.
//We currently just bail out below with a return anyway, so this size isn't actually used; it is only needed for compilation.
const int BLOCK_SIZE=((REQUESTED_BLOCK_SIZE / NUM_PER_TH % 64) != 0 ) ? REQUESTED_BLOCK_SIZE * 2 : REQUESTED_BLOCK_SIZE;
//TODO: figure out how to make the compiler recognize which code isn't executed based on the template arguments; without the BLOCK_SIZE adjustment above,
//the code below would trigger a static_assert when this condition is true.
if ((REQUESTED_BLOCK_SIZE / NUM_PER_TH % 64) != 0)
{
printf("kQuantizeBlockwise not fully supported on Rocm! BLOCK_SIZE/NUM_PER_TH needs to be divisible by 64.");
return;
}
#else
const int BLOCK_SIZE=REQUESTED_BLOCK_SIZE;
#endif
const int n_full = gridDim.x * BLOCK_SIZE;
int valid_items = 0;
const int base_idx = (blockIdx.x * BLOCK_SIZE);
@ -3956,6 +3989,7 @@ MAKE_optimizerStatic8bit2State(ADAM, float)
template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
template __global__ void kPercentileClipping<half, 2048, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);
#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \
template __global__ void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \
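
A worked example of the wave64 constraint in kQuantizeBlockwise above (my arithmetic, with illustrative values only): REQUESTED_BLOCK_SIZE=4096 with NUM_PER_TH=4 launches 4096/4 = 1024 threads per block, which is a multiple of AMD's 64-wide wavefront and runs normally; REQUESTED_BLOCK_SIZE=64 with NUM_PER_TH=2 would launch only 32 threads, so under DISABLE_WARP_32 that instantiation prints the warning and returns. The same check, written host-side as a sketch:

```cpp
// Hypothetical host-side helper mirroring the kernel's divisibility check
// (not part of the repo; just to illustrate which instantiations survive wave64).
constexpr bool supported_on_wave64(int requested_block_size, int num_per_th)
{
    // threads per block must be a multiple of the 64-wide AMD wavefront
    return (requested_block_size / num_per_th) % 64 == 0;
}

static_assert( supported_on_wave64(4096, 4), "1024 threads per block: fine");
static_assert(!supported_on_wave64(  64, 2), "32 threads per block: falls back");
```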

View File

@ -5,12 +5,17 @@
#include <ops.cuh>
#include <kernels.cuh>
#include <cub/device/device_scan.cuh>
#include <limits>
#include <BinSearch.h>
#include <cassert>
#include <common.h>
#ifdef BITS_AND_BYTES_USE_ROCM
#include <hipcub/device/device_scan.hpp>
#else
#include <cub/device/device_scan.cuh>
#endif
using namespace BinSearch;
using std::cout;

View File

@ -12,16 +12,60 @@
#include <unistd.h>
#include <assert.h>
#ifdef BITS_AND_BYTES_USE_ROCM
#ifndef DISABLE_WARP_32
#define __AMDGCN_WAVEFRONT_SIZE 32 // check rocminfo | grep "Wavefront Size". Should be supported on all new GPUs
#endif
#include <hip/hip_runtime_api.h>
#include <hip/hip_fp16.h>
#include <hipblas/hipblas.h>
#include <hipsparse/hipsparse.h>
#include <hip/hip_bfloat16.h>
#define cudaPeekAtLastError hipPeekAtLastError
#define cudaMemset hipMemset
#define cublasGemmEx hipblasGemmEx
#define cublasStatus_t hipblasStatus_t
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUDA_R_8I HIPBLAS_R_8I
#define CUDA_R_32I HIPBLAS_R_32I
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define cublasStatus_t hipblasStatus_t
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasOperation_t hipblasOperation_t
#define cublasLtMatrixLayoutCreate hipblasMatrixLayoutCreate
#define cudaError_t hipError_t
#define cudaGetErrorString hipGetErrorString
#define cudaSuccess hipSuccess
#define cusparseStatus_t hipsparseStatus_t
#define CUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS
#define cublasStatus_t hipblasStatus_t
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define cublasHandle_t hipblasHandle_t
#define cublasCreate_v2 hipblasCreate
#define cusparseHandle_t hipsparseHandle_t
#define cusparseCreate hipsparseCreate
#define __nv_bfloat16 hip_bfloat16
#define cublasLtHandle_t hipblasHandle_t
#define cublasLtCreate hipblasCreate
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT //TODO: HIP doesn't have an exact equivalent; might cause issues
#else
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>
#include <cublasLt.h>
#include <cusparse.h>
#include <vector>
#include <functional>
#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#endif
#include <vector>
#include <functional>

View File

@ -5,9 +5,22 @@
#if BUILD_CUDA
#include <ops.cuh>
#ifdef BITS_AND_BYTES_USE_ROCM
#include <hipblas/hipblas.h>
#define cublasLtHandle_t hipblasHandle_t
#define cudaMemAttachHost hipMemAttachHost
#define cudaMallocManaged hipMallocManaged
#define cudaDevAttrConcurrentManagedAccess hipDeviceAttributeConcurrentManagedAccess
#define cudaDeviceGetAttribute hipDeviceGetAttribute
#define cudaMemPrefetchAsync hipMemPrefetchAsync
#endif
#endif
#include <cpu_ops.h>
// We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary.
// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but it's better than
// maintaining all that boilerplate
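
A minimal sketch of that wrapping pattern (hypothetical names, not the actual bitsandbytes entry points; the real wrappers are produced by the macros in this file):

```cpp
// Templated C++ implementation, hidden behind the C interface.
template <typename T>
void quantizeBlockwiseImpl(const T *A, unsigned char *out, int n)
{
    (void)A; (void)out; (void)n;  // placeholder body for illustration only
}

extern "C"
{
    // One plain C symbol per instantiation, callable from Python via ctypes.
    void cquantize_blockwise_fp32(const float *A, unsigned char *out, int n)
    { quantizeBlockwiseImpl<float>(A, out, n); }

    void cquantize_blockwise_fp64(const double *A, unsigned char *out, int n)
    { quantizeBlockwiseImpl<double>(A, out, n); }
}
```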

View File

@ -93,8 +93,8 @@ private:
__m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6));
#endif
IVec<SSE, float> i(u.vec);
IVec<SSE, float> vlem = vz < vxm;
IVec<SSE, float> vlep = vz < vxp;
IVec<SSE, float> vlem = operator< (vz,vxm);
IVec<SSE, float> vlep = operator< (vz,vxp);
i = i + vlem + vlep;
i.store(pr);
}
@ -123,8 +123,8 @@ private:
__m128d vxp = _mm_shuffle_pd(vx0, vx1, 3);
IVec<SSE, double> i(b1, b0);
IVec<SSE, double> vlem = (vz < vxm);
IVec<SSE, double> vlep = (vz < vxp);
IVec<SSE, double> vlem = operator< (vz, vxm);
IVec<SSE, double> vlep = operator< (vz, vxp);
i = i + vlem + vlep;
union {
@ -227,8 +227,8 @@ private:
#endif
IVec<AVX, float> vlem = vz < vxm;
IVec<AVX, float> vlep = vz < vxp;
IVec<AVX, float> vlem = operator< (vz, vxm);
IVec<AVX, float> vlep = operator< (vz, vxp);
ip = ip + vlem + vlep;
ip.store(pr);
@ -277,8 +277,8 @@ private:
// FVec<AVX, double> vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1);
IVec<AVX, double> i(u.vec);
IVec<AVX, double> vlem = vz < vxm;
IVec<AVX, double> vlep = vz < vxp;
IVec<AVX, double> vlem = operator< (vz,vxm);
IVec<AVX, double> vlep = operator< (vz,vxp);
i = i + vlem + vlep;
i.extractLo32s().store(pr);
}
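
For context (my reading of this workaround, hedged): hipcc presumably trips over the infix `vz < vxm` form of this overloaded comparison, which returns a mask vector rather than a bool, so the fix spells the call out as an explicit `operator<` invocation, naming the intended free-function overload directly. A self-contained illustration of the two spellings, with made-up types:

```cpp
#include <cstdio>

struct FVec { float v; };
struct IVec { int mask; };

// Free-function overload returning a mask vector rather than bool.
IVec operator<(const FVec &a, const FVec &b) { return IVec{ a.v < b.v ? -1 : 0 }; }

int main()
{
    FVec vz{1.0f}, vxm{2.0f};
    IVec vlem = vz < vxm;              // the infix form used in the original code
    IVec vlep = operator< (vz, vxm);   // equivalent explicit call used as the workaround
    std::printf("%d %d\n", vlem.mask, vlep.mask);
}
```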

View File

@ -18,7 +18,7 @@ def read(fname):
setup(
name=f"bitsandbytes",
version=f"0.41.0",
version=f"0.41.1",
author="Tim Dettmers",
author_email="dettmers@cs.washington.edu",
description="k-bit optimizers and matrix multiplication routines.",