CUTLASS compiles.

Tim Dettmers 2023-04-25 17:15:51 -07:00
parent 6e2544da25
commit 84964db937
5 changed files with 20 additions and 14 deletions

--- a/Makefile
+++ b/Makefile

@@ -1,7 +1,8 @@
 MKFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST)))
 ROOT_DIR := $(patsubst %/,%,$(dir $(MKFILE_PATH)))

-GPP:= /usr/bin/g++
+#GPP:= /usr/bin/g++
+GPP:= /sw/gcc/11.2.0/bin/g++
 ifeq ($(CUDA_HOME),)
 CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev)
 endif
@@ -25,7 +26,7 @@ FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c
 INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include
 INCLUDE_10x := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include
-INCLUDE_cutlass := -I $(ROOT_DIR)/dependencies/cutlass/include
+INCLUDE_cutlass := -I $(ROOT_DIR)/dependencies/cutlass/include -I $(ROOT_DIR)/dependencies/cutlass/tools/util/include/ -I $(ROOT_DIR)/dependencies/cutlass/include/cute/util/
 LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib

 # NVIDIA NVCC compilation flags
@@ -104,7 +105,7 @@ cuda11x: $(BUILD_DIR) env
 cuda11x_cutlass: $(BUILD_DIR) env cutlass
 	$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(INCLUDE_cutlass) $(LIB) --output-directory $(BUILD_DIR)
 	$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
-	$(GPP) -std=c++20 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
+	$(GPP) -std=c++17 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)

 cuda12x: $(BUILD_DIR) env
 	$(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
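The build change in short: the host compiler moves to g++ 11.2 at -std=c++17 (CUTLASS 3.x's CuTe headers require C++17), and INCLUDE_cutlass now also exposes the CUTLASS tools/util include tree for the utility headers that kernels.cu pulls in below. The entry point is make cuda11x_cutlass, whose cutlass prerequisite target provisions dependencies/cutlass before compilation. A hypothetical smoke test for the include paths (the file name and flags here are assumptions, not part of this commit):

// check_cutlass.cu -- verify the CuTe headers resolve with the Makefile's paths:
//   nvcc -std=c++17 -I dependencies/cutlass/include \
//        -I dependencies/cutlass/tools/util/include check_cutlass.cu
#include <cute/tensor.hpp>

int main()
{
    // Instantiate a trivial static layout; if this compiles and returns 0,
    // the include paths are wired up correctly.
    auto layout = cute::make_layout(cute::make_shape(cute::Int<2>{}, cute::Int<4>{}));
    return cute::size(layout) == 8 ? 0 : 1;
}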

--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py

@@ -176,7 +176,7 @@ def create_custom_map(seed=0, scale=0.01):
     #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.207
     #v = [1.6118251211466303, 1.188665228776879, 0.9112895004060624, 0.690763326564427, 0.4997008778346997, 0.3254280317127771, 0.16057446047146948] # 0.9465 24.30
     #v = [1.6027040905517569, 1.184321770169049, 0.9085808314549837, 0.6889461706317986, 0.4984841229538408, 0.32467299997597887, 0.1602117348657326] # 0.9455 24.293
-    v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.37 22.88
+    #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.37 22.88

     # 7B evo start
     #v = [1.62129629, 1.18870191, 0.90848106, 0.69108646, 0.50515268, 0.34927819905, 0.14122701] # 22.06
@@ -186,7 +186,7 @@ def create_custom_map(seed=0, scale=0.01):
     # 13B evo start
     #v = [1.6077535089716468, 1.1914902148179205, 0.8999752421085561, 0.6967904489387543, 0.4949093928311768, 0.30920472033044544, 0.15391602735952042]
     #v = [1.586363722436466, 1.202610827188916, 0.9003332576346587, 0.6904888715206972, 0.49490974688233724, 0.2971151461329376, 0.15683230810738283]
-    #v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908]
+    v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908]

     # mean evo 7B + 13B
     #v = [1.5993337549066253, 1.1965624035328402, 0.9000864380418481, 0.6925840978034195, 0.5011181210961458, 0.32040328389777434, 0.13570386022711237]

--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py

@@ -228,6 +228,7 @@ class LinearNF4(Linear4bit):
         super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4')


 class Int8Params(torch.nn.Parameter):
     def __new__(
         cls,

--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu

@@ -12,6 +12,14 @@
 #include <cub/block/block_reduce.cuh>
 #include <cub/cub.cuh>
 #include <math_constants.h>
+#include <thrust/host_vector.h>
+#include <thrust/device_vector.h>
+#include <cute/tensor.hpp>
+#include "cutlass/util/print_error.hpp"
+#include "cutlass/util/GPU_Clock.hpp"
+#include "cutlass/util/cublas_wrappers.hpp"
+#include "cutlass/util/helper_cuda.hpp"
+
 #define HLF_MAX 65504
 #define TH 1024
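Of the new includes, thrust and cute/tensor.hpp supply the GEMM machinery used further down; the cutlass/util headers are developer conveniences (error printing, event-based timing, cuBLAS reference wrappers). A sketch of the kind of use they enable, assuming only GPU_Clock's start()/seconds() timer API; this is illustrative, not code from the commit:

// timing_sketch.cu -- hypothetical standalone example.
#include <cuda_runtime.h>
#include "cutlass/util/GPU_Clock.hpp" // CUDA-event based timer

__global__ void scale_kernel(float *x) { x[threadIdx.x] *= 2.0f; }

int main()
{
    float *d = nullptr;
    cudaMalloc(&d, 32 * sizeof(float));

    GPU_Clock timer;  // assumption: default-constructed, global namespace
    timer.start();
    scale_kernel<<<1, 32>>>(d);
    double elapsed = timer.seconds(); // stops and synchronizes before reading

    cudaFree(d);
    return elapsed >= 0.0 ? 0 : 1;
}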
@@ -2709,7 +2717,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
   }
 }

-#define C 1.0f/127.0f
+#define DENORM 1.0f/127.0f
 #define MAX_SPARSE_COUNT 32
 #define SMEM_SIZE 8*256
 template <typename T, int SPMM_ITEMS, int BITS>
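The rename C -> DENORM avoids a preprocessor collision: an object-like macro named C rewrites every later token C in the translation unit, including function parameters such as gemm_device's output pointer below (renamed to out in the same commit). A minimal illustration of the failure mode, not code from the repo:

#define C 1.0f/127.0f // old name: the token C now belongs to the preprocessor
// template <typename TC> void gemm(TC * C);
//   ...would expand to: void gemm(TC * 1.0f/127.0f);  -- a syntax error
#undef C

#define DENORM 1.0f/127.0f // new name: no clash with matrix identifiers
float dequant(float v) { return v * DENORM; } // v * 1.0f / 127.0f, correct for products
// Parenthesizing as (1.0f/127.0f) would additionally guard uses like x/DENORM;
// the commit keeps the unparenthesized form, which is fine for the * uses here.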
@@ -2813,7 +2821,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
             float valB = local_valsB[k];
             float valA = local_valA[i];
             if(valB != 0.0 && valA != 0.0)
-              local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*C*valB*valA;
+              local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA;
           }
           else
             local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i];
@@ -2960,7 +2968,7 @@ void
 gemm_device(MShape M, NShape N, KShape K,
             TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
             TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
-            TC      * C, CStride dC, CBlockLayout          , CThreadLayout tC,
+            TC      * out, CStride dC, CBlockLayout          , CThreadLayout tC,
             Alpha alpha, Beta beta)
 {
   using namespace cute;
@@ -2991,7 +2999,7 @@ gemm_device(MShape M, NShape N, KShape K,
   // Represent the full tensors
   auto mA = make_tensor(make_gmem_ptr(A), make_shape(M,K), dA); // (M,K)
   auto mB = make_tensor(make_gmem_ptr(B), make_shape(N,K), dB); // (N,K)
-  auto mC = make_tensor(make_gmem_ptr(C), make_shape(M,N), dC); // (M,N)
+  auto mC = make_tensor(make_gmem_ptr(out), make_shape(M,N), dC); // (M,N)

   // Get the appropriate blocks for this thread block --
   // potential for thread block locality
@@ -3034,7 +3042,6 @@ gemm_device(MShape M, NShape N, KShape K,
   // Clear the accumulators
   clear(tCrC);

-#if 1

   // TUTORIAL: Example of a very simple compute loop
   // Data is read from global to shared memory via the tA|tB partitioning
@@ -3071,7 +3078,6 @@ gemm_device(MShape M, NShape N, KShape K,
     __syncthreads();
   }

-#endif

   axpby(alpha, tCrC, beta, tCgC);
 }
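The parameter rename carries into the CuTe view: make_gmem_ptr(out) wraps the raw output pointer into an (M,N) tensor, and the closing axpby(alpha, tCrC, beta, tCgC) is the epilogue tCgC = alpha*tCrC + beta*tCgC. A minimal sketch of the wrapping pattern with assumed dynamic shapes and a caller-supplied stride (a toy kernel, not the gemm_device above):

#include <cuda_runtime.h>
#include <cute/tensor.hpp>

// Hypothetical kernel mirroring the mC line above; `out` is a safe parameter
// name now that the file-scope macro is DENORM rather than C.
template <class TC, class CStride>
__global__ void zero_output_view(TC *out, int M, int N, CStride dC)
{
    using namespace cute;
    auto mC = make_tensor(make_gmem_ptr(out), make_shape(M, N), dC); // (M,N) view
    if (threadIdx.x == 0)
        clear(mC); // element-wise zero, the same helper gemm_device applies to tCrC
}

int main()
{
    float *d = nullptr;
    cudaMalloc(&d, 4 * 4 * sizeof(float));
    // Row-major stride (N, 1) for a 4x4 output
    zero_output_view<<<1, 32>>>(d, 4, 4, cute::make_stride(4, cute::Int<1>{}));
    cudaDeviceSynchronize();
    cudaFree(d);
    return 0;
}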

--- a/csrc/ops.cu
+++ b/csrc/ops.cu

@@ -666,11 +666,9 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
-#include <thrust/host_vector.h>
-#include <thrust/device_vector.h>
-#include <cute/tensor.hpp>

 template <typename TA, typename TB, typename TC,
           typename Alpha, typename Beta>
 void