Use workaround for ROCm wave32 recognition

just sets __AMDGCN_WAVEFRONT_SIZE forcefully to 32.
Not correct (some GPU's don't support wave32), but works
on the supported GPU's. Can disable with DISABLE_WARP_32

With this blockwise quantize works and with that nf4 is supported.
This commit is contained in:
arlo-phoenix 2023-08-08 18:50:26 +00:00
parent 615d47583f
commit 0b481bfcc2
5 changed files with 31 additions and 13 deletions

View File

@ -10,9 +10,23 @@ ifeq ($(ROCM_HOME),)
ROCM_HOME:= $(shell which hipcc | rev | cut -d'/' -f4- | rev) ROCM_HOME:= $(shell which hipcc | rev | cut -d'/' -f4- | rev)
endif endif
ifneq ($(CUDA_HOME),)
ifndef CUDA_VERSION
$(warning WARNING: CUDA_VERSION not set. Call make with CUDA string, for example: make cuda11x CUDA_VERSION=115 or make cpuonly CUDA_VERSION=CPU)
CUDA_VERSION:=
endif
else ifneq ($(ROCM_HOME),)
ifndef ROCM_TARGET
$(error ERROR: ROCM_TARGET not set. Call make with ROCM string (see https://www.llvm.org/docs/AMDGPUUsage.html#processors), for example: make hip ROCM_TARGET=gfx1030)
ROCM_TARGET:=
endif
endif
NVCC := $(CUDA_HOME)/bin/nvcc NVCC := $(CUDA_HOME)/bin/nvcc
HIPCC:= $(ROCM_HOME)/bin/hipcc
########################################### ###########################################
@ -114,12 +128,9 @@ cpuonly: $(BUILD_DIR) env
HIP_INCLUDE := -I $(ROCM_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include HIP_INCLUDE := -I $(ROCM_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include
HIP_LIB := -L $(ROCM_HOME)/lib -lhipblas -lhiprand -lhipsparse #-lhipblaslt #TODO: check if this is actually only gfx90a HIP_LIB := -L $(ROCM_HOME)/lib -lhipblas -lhiprand -lhipsparse #-lhipblaslt #TODO: check if this is actually only gfx90a
hip: $(BUILD_DIR) hip: $(BUILD_DIR)
# Add --offload-arch=gfx1030 if this fails $(HIPCC) -std=c++14 -c -fPIC --offload-arch=$(ROCM_TARGET) $(HIP_INCLUDE) -o $(BUILD_DIR)/ops.o -DNO_CUBLASLT -DBITS_AND_BYTES_USE_ROCM $(CSRC)/ops.cu
/usr/bin/hipcc -std=c++14 -c -fPIC $(HIP_INCLUDE) -o $(BUILD_DIR)/ops.o -D NO_CUBLASLT -D BITS_AND_BYTES_USE_ROCM $(CSRC)/ops.cu $(HIPCC) -std=c++14 -c -fPIC --offload-arch=$(ROCM_TARGET) $(HIP_INCLUDE) -o $(BUILD_DIR)/kernels.o -DNO_CUBLASLT -DBITS_AND_BYTES_USE_ROCM $(CSRC)/kernels.cu
/usr/bin/hipcc -std=c++14 -c -fPIC $(HIP_INCLUDE) -o $(BUILD_DIR)/kernels.o -D NO_CUBLASLT -D BITS_AND_BYTES_USE_ROCM $(CSRC)/kernels.cu
# /usr/bin/hipcc -fPIC -static $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.so
# HCC is deprecated, but used by hipBLASlt header. Since blas isn't even used doesn't matter, this is just so that it even compiles # HCC is deprecated, but used by hipBLASlt header. Since blas isn't even used doesn't matter, this is just so that it even compiles
$(GPP) -std=c++14 -D__HIP_PLATFORM_HCC__ -D__HIP_PLATFORM_AMD__ -DBUILD_CUDA -DBITS_AND_BYTES_USE_ROCM -shared -fPIC $(HIP_INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(FILES_CPP) $(HIP_LIB) -o ./bitsandbytes/libbitsandbytes_hip_nohipblaslt.so $(GPP) -std=c++14 -D__HIP_PLATFORM_HCC__ -D__HIP_PLATFORM_AMD__ -DBUILD_CUDA -DBITS_AND_BYTES_USE_ROCM -shared -fPIC $(HIP_INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(FILES_CPP) $(HIP_LIB) -o ./bitsandbytes/libbitsandbytes_hip_nohipblaslt.so

View File

@ -815,6 +815,7 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
if out is None: if out is None:
out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device) out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device)
#TODO: catch rocm wave64 only, pytorch has a property, but that one likely contains the wrong waveSize
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
prev_device = pre_call(A.device) prev_device = pre_call(A.device)

View File

@ -736,23 +736,26 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c
} }
} }
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE> template<typename T, int REQUESTED_BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
//__launch_bounds__(TH, 4) //__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
{ {
#ifdef BITS_AND_BYTES_USE_ROCM #ifdef DISABLE_WARP_32
printf("kQuantizeBlockwise is not supported on Rocm!"); //CUDA has warpsize 32 so just multiply by 2 to get amd warp size
//currently just stopping below with a return anyways so this size isn't actually used, just needed for compilation
const int BLOCK_SIZE=((REQUESTED_BLOCK_SIZE / NUM_PER_TH % 64) != 0 ) ? REQUESTED_BLOCK_SIZE * 2 : REQUESTED_BLOCK_SIZE;
//TODO: figure out how to make compiler recognize what isn't executed based on template arguments, without the code below in ifndef would trigger static_assert if //TODO: figure out how to make compiler recognize what isn't executed based on template arguments, without the code below in ifndef would trigger static_assert if
//this condition is true //this condition is true
if ((BLOCK_SIZE / NUM_PER_TH % 64) != 0) if ((REQUESTED_BLOCK_SIZE / NUM_PER_TH % 64) != 0)
{ {
printf("kQuantizeBlockwise not fully supported on Rocm! BLOCK_SIZE/NUM_PER_TH needs to be divisible by 64."); printf("kQuantizeBlockwise not fully supported on Rocm! BLOCK_SIZE/NUM_PER_TH needs to be divisible by 64.");
return; return;
} }
#else
const int BLOCK_SIZE=REQUESTED_BLOCK_SIZE;
#endif #endif
#ifndef BITS_AND_BYTES_USE_ROCM
const int n_full = gridDim.x * BLOCK_SIZE; const int n_full = gridDim.x * BLOCK_SIZE;
int valid_items = 0; int valid_items = 0;
const int base_idx = (blockIdx.x * BLOCK_SIZE); const int base_idx = (blockIdx.x * BLOCK_SIZE);
@ -854,7 +857,6 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
__syncthreads(); __syncthreads();
StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items); StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items);
} }
#endif
} }
template<typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE> template<typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>

View File

@ -11,7 +11,7 @@
#include <common.h> #include <common.h>
#ifdef BITS_AND_BYTES_USE_ROCM #ifdef BITS_AND_BYTES_USE_ROCM
//#include <hipcub/device/device_scan.hpp> #include <hipcub/device/device_scan.hpp>
#else #else
#include <cub/device/device_scan.cuh> #include <cub/device/device_scan.cuh>
#endif #endif

View File

@ -14,6 +14,10 @@
#ifdef BITS_AND_BYTES_USE_ROCM #ifdef BITS_AND_BYTES_USE_ROCM
#ifndef DISABLE_WARP_32
#define __AMDGCN_WAVEFRONT_SIZE 32 // check rocminfo | grep "Wavefront Size". Should be supported on all new GPU's
#endif
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#include <hip/hip_fp16.h> #include <hip/hip_fp16.h>
#include <hipblas/hipblas.h> #include <hipblas/hipblas.h>