Added paging.
commit ec38ba95b0 (parent 264a948539)
@@ -26,6 +26,8 @@ try:
    lib.cadam_8bit_blockwise_fp32
    lib.get_context.restype = ct.c_void_p
    lib.get_cusparse.restype = ct.c_void_p
    lib.cget_managed_ptr.restype = ct.c_void_p
    lib.cget_stream.restype = ct.c_void_p
    COMPILED_WITH_CUDA = True
except AttributeError:
    warn("The installed version of bitsandbytes was compiled without GPU support. "
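Setting restype on the two new exports matters because ctypes treats foreign return values as a C int by default, which would truncate the 64-bit managed pointer returned by cget_managed_ptr. A minimal illustrative sketch of the binding pattern (the library filename below is an assumption, not taken from the commit):

import ctypes as ct

# Load the compiled extension; the exact filename is assumed for this sketch.
lib = ct.cdll.LoadLibrary('libbitsandbytes_cuda.so')

# Declare pointer-returning exports before calling them; without restype,
# ctypes would hand back a truncated 32-bit int instead of a usable pointer.
lib.cget_managed_ptr.restype = ct.c_void_p
lib.cget_stream.restype = ct.c_void_p

ptr = lib.cget_managed_ptr(ct.c_size_t(4096))  # CUDA managed allocation of 4 KiB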
@@ -130,6 +130,61 @@ class Cusparse_Context:
            cls._instance.initialize()
        return cls._instance

dtype2bytes = {}
dtype2bytes[torch.float32] = 4
dtype2bytes[torch.float16] = 2
dtype2bytes[torch.bfloat16] = 2
dtype2bytes[torch.uint8] = 1
dtype2bytes[torch.int8] = 1

def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)):
    num_bytes = dtype2bytes[dtype]*prod(shape)
    cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
    c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
    new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
    out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape))
    out.is_paged = True
    out.page_deviceid = device.index
    return out

def prefetch_tensor(A, to_cpu=False):
    assert A.is_paged, 'Only paged tensors can be prefetched!'
    if to_cpu:
        deviceid = -1
    else:
        deviceid = A.page_deviceid

    num_bytes = dtype2bytes[A.dtype]*A.numel()
    lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid))

def elementwise_func(func_name, A, B, value, prefetch=True):
    func = None
    if A.dtype == torch.float32:
        func = getattr(lib, f'c{func_name}_fp32', None)
        cvalue = ct.c_float(value)
    elif A.dtype == torch.uint8:
        func = getattr(lib, f'c{func_name}_uint8', None)
        cvalue = ct.c_uint8(value)

    if func is None: raise NotImplementedError(f'Function not implemented: {func_name}')

    is_managed = getattr(A, 'is_managed', False)
    if is_managed and prefetch:
        prefetch_tensor(A)
        if B is not None: prefetch_tensor(B)

    func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel()))
    if A.is_paged or B.is_paged:
        # Paged functions are fully asynchronous. If we return from this
        # function, we want the tensor to be in the correct state, i.e. the
        # final state after the operation has occurred, so we synchronize.
        torch.cuda.synchronize()

def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value)
def arange(A, device=None): elementwise_func('arange', A, None, 0)
def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0)


def create_linear_map(signed=True, total_bits=8, add_zero=True):
    sign = (-1.0 if signed else 0.0)
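A minimal usage sketch of the paging helpers added above, mirroring test_managed later in this commit (shape and device index are illustrative):

import torch
import bitsandbytes.functional as F

# Allocate a tensor backed by CUDA unified (managed) memory on device 0.
A = F.get_paged(320, 320, dtype=torch.float32)
assert A.is_paged and A.page_deviceid == 0

# Elementwise ops on paged tensors dispatch to the new C kernels and
# synchronize before returning (see elementwise_func above).
F.fill(A, 17.0)
assert (A == 17).sum().item() == 320 * 320

# Optionally migrate the pages ahead of use: to the owning GPU, or back
# to host memory with to_cpu=True.
F.prefetch_tensor(A)
F.prefetch_tensor(A, to_cpu=True)
torch.cuda.synchronize()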
@@ -3522,49 +3522,23 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
//}

__device__ void compute(float* global_out, float const* shared_in)
template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n)
{

}
template <size_t stages_count /* Pipeline with stages_count stages */>
__global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz) {
  auto grid = cooperative_groups::this_grid();
  auto block = cooperative_groups::this_thread_block();
  assert(size == batch_sz * grid.size()); // Assume input size fits batch_sz * grid_size

  extern __shared__ float shared[]; // stages_count * block.size() * sizeof(int) bytes
  size_t shared_offset[stages_count];
  for (int s = 0; s < stages_count; ++s) shared_offset[s] = s * block.size();

  __shared__ cuda::pipeline_shared_state<
    cuda::thread_scope::thread_scope_block,
    stages_count
  > shared_state;
  auto pipeline = cuda::make_pipeline(block, &shared_state);

  auto block_batch = [&](size_t batch) -> int {
    return block.group_index().x * block.size() + grid.size() * batch;
  };

  // compute_batch: next batch to process
  // fetch_batch: next batch to fetch from global memory
  for (size_t compute_batch = 0, fetch_batch = 0; compute_batch < batch_sz; ++compute_batch) {
    // The outer loop iterates over the computation of the batches
    for (; fetch_batch < batch_sz && fetch_batch < (compute_batch + stages_count); ++fetch_batch) {
      // This inner loop iterates over the memory transfers, making sure that the pipeline is always full
      pipeline.producer_acquire();
      size_t shared_idx = fetch_batch % stages_count;
      size_t batch_idx = fetch_batch;
      size_t block_batch_idx = block_batch(batch_idx);
      cuda::memcpy_async(block, shared + shared_offset[shared_idx], global_in + block_batch_idx, sizeof(float) * block.size(), pipeline);
      pipeline.producer_commit();
    }
    pipeline.consumer_wait();
    int shared_idx = compute_batch % stages_count;
    int batch_idx = compute_batch;
    compute(global_out + block_batch(batch_idx), shared + shared_offset[shared_idx]);
    pipeline.consumer_release();
  for(long i = (blockDim.x*blockIdx.x) + threadIdx.x; i < n; i+=(blockDim.x*gridDim.x))
  {
    switch(FUNC)
    {
      case FILL:
        A[i] = (T)value;
        break;
      case ARANGE:
        A[i] = (T)i;
        break;
      case _MUL:
        A[i] = A[i]*B[i];
        break;
    }
  }
}
@@ -3572,19 +3546,10 @@ __global__ void with_staging_unified(float const* global_in, float * global_out,
// TEMPLATE DEFINITIONS
//==============================================================

//template <class MShape, class NShape, class KShape,
//          class TA, class AStride, class ABlockLayout, class AThreadLayout,
//          class TB, class BStride, class BBlockLayout, class BThreadLayout,
//          class TC, class CStride, class CBlockLayout, class CThreadLayout,
//          class Alpha, class Beta>
//__global__ static
//__launch_bounds__(decltype(size(CThreadLayout{}))::value)
//void
//gemm_device(MShape M, NShape N, KShape K,
//            TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
//            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
//            TC * out, CStride dC, CBlockLayout , CThreadLayout tC,
//            half alpha, half beta);
template __global__ void kfunc<float, FILL>(float *A, float *B, float value, long n);
template __global__ void kfunc<unsigned char, FILL>(unsigned char *A, unsigned char *B, unsigned char value, long n);
template __global__ void kfunc<float, ARANGE>(float *A, float *B, float value, long n);
template __global__ void kfunc<float, _MUL>(float *A, float *B, float value, long n);

// these are not used and make no sense, but the compiler needs them
//template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
@@ -3611,9 +3576,6 @@ template __global__ void gemm_device<half, 16, 96>(int M, int N, int K, half * _
template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);


//template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
template __global__ void with_staging_unified<2>(float const* global_in, float * global_out, size_t size, size_t batch_sz);
template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
@@ -122,23 +122,9 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T

template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);

//template <class MShape, class NShape, class KShape,
//          class TA, class AStride, class ABlockLayout, class AThreadLayout,
//          class TB, class BStride, class BBlockLayout, class BThreadLayout,
//          class TC, class CStride, class CBlockLayout, class CThreadLayout,
//          class Alpha, class Beta>
//__global__ static
//__launch_bounds__(decltype(size(CThreadLayout{}))::value)
//void
//gemm_device(MShape M, NShape N, KShape K,
//            TA const* A, AStride dA, ABlockLayout blockA, AThreadLayout tA,
//            TB const* B, BStride dB, BBlockLayout blockB, BThreadLayout tB,
//            TC * out, CStride dC, CBlockLayout , CThreadLayout tC,
//            Alpha alpha, Beta beta);
template <size_t stages_count /* Pipeline with stages_count stages */>
__global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz);

template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);

template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n);

#endif
csrc/ops.cu
@@ -663,16 +663,6 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
}


void pipeline_test(float *A, float *B, size_t n, size_t batch_size)
{

  int threads = 256;
  int num_blocks = (n+(256*batch_size)+1)/(batch_size*256);

  with_staging_unified<2><<<num_blocks, threads>>>(A, B, n, batch_size);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}



template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
@@ -717,10 +707,25 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
  //kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}

template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
{
  int threads = 512;
  int blocks = n/threads;
  blocks = n % threads == 0 ? blocks : blocks + 1;
  blocks = blocks > 65535 ? 65535 : blocks;
  kfunc<T, FUNC><<<blocks, 512>>>(A, B, value, n);
  CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================

template void func<float, FILL>(float *A, float *B, float value, long n);
template void func<unsigned char, FILL>(unsigned char *A, unsigned char *B, unsigned char value, long n);
template void func<float, ARANGE>(float *A, float *B, float value, long n);
template void func<float, _MUL>(float *A, float *B, float value, long n);

template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
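The host wrapper func<T, FUNC> above rounds n/512 up to whole blocks and caps the grid at 65535; anything beyond the cap is covered by the grid-stride loop inside kfunc. A small sketch of the same sizing arithmetic, in Python for illustration only (not part of the commit):

def launch_blocks(n, threads=512, max_blocks=65535):
    # Ceil-divide, then clamp to the grid limit used in func<T, FUNC>.
    blocks = n // threads + (1 if n % threads else 0)
    return min(blocks, max_blocks)

assert launch_blocks(1024) == 2             # small n: one block per 512 elements
assert launch_blocks(100_000_000) == 65535  # large n: cap applies, grid-stride loop covers the rest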
@@ -93,6 +93,13 @@ typedef enum DataType_t
  NF4 = 2,
} DataType_t;

typedef enum Funcs_t
{
  FILL = 0,
  ARANGE = 1,
  _MUL = 2,
} Funcs_t;

class Context
{
  public:
@@ -193,6 +200,6 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);

template <typename T, int FUNC> void func(T *A, T *B, T value, long n);

void pipeline_test(float *A, float *B, size_t n, size_t batch_size);
#endif
@@ -28,6 +28,14 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l
void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
{ gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }

#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \

MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)


#define MAKE_FUNC32(fname, oname, gtype, gbits) \
void fname##32bit_g##gbits(gtype *g, gtype *p, \
@@ -314,7 +322,6 @@ extern "C"

  void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
  void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
  void cpipeline_test(float *A, float *B, size_t n, size_t batch_size){ pipeline_test(A, B, n, batch_size); }

  //void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
  //{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
@@ -325,6 +332,29 @@ extern "C"
  void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
  { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }

  void *cget_managed_ptr(size_t bytes)
  {
    void *ptr;
    CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost));
    CUDA_CHECK_RETURN(cudaPeekAtLastError());

    return ptr;
  }

  void cprefetch(void *ptr, size_t bytes, int device)
  {
    CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0));
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
  }

  #define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
  void c##fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ fname##_##type_name(A, B, value, n); } \

  CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
  CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
  CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
  CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)

#endif
  void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
  void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
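The CMAKE_ELEMENTWISE_FUNC macro expands c##fname##_##type_name, so the symbols exported here are cfill_fp32, cfill_uint8, carange_fp32 and c_mul_fp32, which are exactly the names that elementwise_func in functional.py builds via getattr. A small sketch of that naming convention (the helper below is hypothetical, for illustration only):

import torch

# Assumed dtype-to-suffix mapping, matching the fp32/uint8 branches in elementwise_func.
dtype2suffix = {torch.float32: 'fp32', torch.uint8: 'uint8'}

def exported_symbol(func_name, dtype):
    # elementwise_func does: getattr(lib, f'c{func_name}_{suffix}', None)
    return f'c{func_name}_{dtype2suffix[dtype]}'

assert exported_symbol('fill', torch.float32) == 'cfill_fp32'
assert exported_symbol('fill', torch.uint8) == 'cfill_uint8'
assert exported_symbol('arange', torch.float32) == 'carange_fp32'
assert exported_symbol('_mul', torch.float32) == 'c_mul_fp32'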
@@ -2489,8 +2489,38 @@ def test_gemm_4bit(dtype):
    print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
    print(dim, (max_err.item(), max_relerr.item()))

def test_pipeline_func():
    a = torch.rand(2, 4).cuda()
    out = F.pipeline_test(a, 2)
    print(a)
    print(out)
def test_managed():
    n = 32*10
    A = F.get_paged(n, n, dtype=torch.float32)
    B = F.get_paged(n, n, dtype=torch.uint8)
    B2 = F.get_paged(n, n, dtype=torch.float32)
    assert A.is_paged
    assert B.is_paged
    assert A.page_deviceid==0
    assert B.page_deviceid==0
    F.fill(A, 17.0)
    F.fill(B, 17)
    F.fill(B2, 2)
    assert (A==17).sum().item() == n*n
    assert (B==17).sum().item() == n*n
    C = A*B.float()
    assert (C==289).sum().item() == n*n
    F._mul(A, B2)
    F._mul(A, B2)
    F._mul(A, B2)
    assert (A==17*(2**3)).sum().item() == n*n
    # F.prefetch_tensor(A)
    # F.prefetch_tensor(B)

    # F.fill(B2, 17.0)
    # F._mul(A, B2)

    # F.prefetch_tensor(A, to_cpu=True)
    # F.prefetch_tensor(B, to_cpu=True)
    # F.prefetch_tensor(B2, to_cpu=True)
    # torch.cuda.synchronize()

    # assert (A==17).sum().item() == n*n

    # torch.testing.assert_allclose(A, torch.ones(A.shape)*289)