4-bit draft.
This commit is contained in:
parent
cad839941b
commit
21723f796a
|
@ -1380,10 +1380,15 @@ def cutlass3_gemm(
|
|||
out: Tensor = None,
|
||||
transposed_A=False,
|
||||
transposed_B=False,
|
||||
state=None
|
||||
):
|
||||
sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
|
||||
#sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
|
||||
if state is None:
|
||||
Bshape = B.shape
|
||||
else:
|
||||
Bshape = state[1]
|
||||
if out is None:
|
||||
out = torch.zeros(size=sout, dtype=A.dtype, device=A.device)
|
||||
out = torch.zeros(size=(A.shape[0], Bshape[1]), dtype=A.dtype, device=A.device)
|
||||
|
||||
sA = A.shape
|
||||
sB = B.shape
|
||||
|
@ -1456,7 +1461,13 @@ def cutlass3_gemm(
|
|||
# [km, nk -> mn]
|
||||
#lda = ldb = ldc = 1
|
||||
#lda = 1
|
||||
#print(m, n, k, lda, ldb, ldc)
|
||||
if state is not None:
|
||||
m = Bshape[0]
|
||||
k = Bshape[1]
|
||||
lda = Bshape[1]
|
||||
ldc = Bshape[0]
|
||||
ldb = (ldb+1)//2
|
||||
print(m, n, k, lda, ldb, ldc)
|
||||
is_on_gpu([B, A, out])
|
||||
m = ct.c_int32(m)
|
||||
n = ct.c_int32(n)
|
||||
|
@ -1464,7 +1475,10 @@ def cutlass3_gemm(
|
|||
lda = ct.c_int32(lda)
|
||||
ldb = ct.c_int32(ldb)
|
||||
ldc = ct.c_int32(ldc)
|
||||
if A.dtype == torch.float32:
|
||||
|
||||
if B.dtype == torch.uint8:
|
||||
lib.cgemm_4bit_inference(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
|
||||
elif A.dtype == torch.float32:
|
||||
lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
|
||||
elif A.dtype == torch.float16:
|
||||
lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
|
||||
|
|
222
csrc/kernels.cu
222
csrc/kernels.cu
|
@ -69,6 +69,27 @@ __device__ float dDequantizeFP4(unsigned char val, float absmax)
|
|||
}
|
||||
}
|
||||
|
||||
// Decode a 4-bit FP4 code into its unscaled float value (no absmax applied).
//
// Bit 3 selects the sign, bits 2..1 select the magnitude class and bit 0
// selects the 1.0x / 1.5x step within that class. Codes whose bits 2..1 are
// both clear form the subnormal pair {0, 0.0625}. Only the low four bits of
// `val` are inspected.
__device__ float d2DequantizeFP4(unsigned char val)
{
  const bool negative = (val & 0b1000) != 0;
  const bool low_bit = (val & 0b0001) != 0;

  float magnitude;
  switch(val & 0b0110)
  {
    case 0b0000:
      // subnormal: zero is returned without a sign, the other code is +-0.0625
      if(!low_bit)
        return 0.0f;
      magnitude = 0.0625f;
      break;
    case 0b0010:
      magnitude = low_bit ? 12.0f : 8.0f;   // 8 * {1.0, 1.5}
      break;
    case 0b0100:
      magnitude = low_bit ? 6.0f : 4.0f;    // 4 * {1.0, 1.5}
      break;
    default: // 0b0110
      magnitude = low_bit ? 3.0f : 2.0f;    // 2 * {1.0, 1.5}
      break;
  }

  return negative ? -magnitude : magnitude;
}
|
||||
|
||||
__device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
|
||||
{
|
||||
float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
|
||||
|
@ -145,7 +166,7 @@ __device__ unsigned char dQuantizeFP4(float x)
|
|||
return 0b0000+sign;
|
||||
}
|
||||
|
||||
__device__ float dDequantizeNF4(unsigned char val, float absmax)
|
||||
__device__ half dhDequantizeNF4(unsigned char val)
|
||||
{
|
||||
// the values for this tree was generated by test_normal_map_tree
|
||||
// in the file tests/test_functional.py
|
||||
|
@ -153,49 +174,103 @@ __device__ float dDequantizeNF4(unsigned char val, float absmax)
|
|||
if((val & 0b0100) == 4) // 1
|
||||
if((val & 0b0010) == 2) // 11
|
||||
if((val & 0b0001) == 1) // 111
|
||||
return 1.0f*absmax;
|
||||
return 1.0f;
|
||||
else
|
||||
return 0.7229568362236023f*absmax;
|
||||
return 0.7229568362236023f;
|
||||
else
|
||||
if((val & 0b0001) == 1) // 110
|
||||
return 0.5626170039176941f*absmax;
|
||||
return 0.5626170039176941f;
|
||||
else
|
||||
return 0.44070982933044434f*absmax;
|
||||
return 0.44070982933044434f;
|
||||
else
|
||||
if((val & 0b0010) == 2) //10
|
||||
if((val & 0b0001) == 1) // 101
|
||||
return 0.33791524171829224f*absmax;
|
||||
return 0.33791524171829224f;
|
||||
else
|
||||
return 0.24611230194568634f*absmax;
|
||||
return 0.24611230194568634f;
|
||||
else
|
||||
if((val & 0b0001) == 1) // 100
|
||||
return 0.16093020141124725f*absmax;
|
||||
return 0.16093020141124725f;
|
||||
else
|
||||
return 0.07958029955625534f*absmax;
|
||||
return 0.07958029955625534f;
|
||||
|
||||
else
|
||||
if((val & 0b0100) == 4) // 0
|
||||
if((val & 0b0010) == 2) //01
|
||||
if((val & 0b0001) == 1) // 011
|
||||
return 0.0f*absmax;
|
||||
return 0.0f;
|
||||
else
|
||||
return -0.09105003625154495f*absmax;
|
||||
return -0.09105003625154495f;
|
||||
else
|
||||
if((val & 0b0001) == 1) // 010
|
||||
return -0.18477343022823334f*absmax;
|
||||
return -0.18477343022823334f;
|
||||
else
|
||||
return -0.28444138169288635f*absmax;
|
||||
return -0.28444138169288635f;
|
||||
else
|
||||
if((val & 0b0010) == 2) //00
|
||||
if((val & 0b0001) == 1) // 001
|
||||
return -0.39491748809814453f*absmax;
|
||||
return -0.39491748809814453f;
|
||||
else
|
||||
return -0.5250730514526367f*absmax;
|
||||
return -0.5250730514526367f;
|
||||
else
|
||||
if((val & 0b0001) == 1) // 000
|
||||
return -0.6961928009986877f*absmax;
|
||||
return -0.6961928009986877f;
|
||||
else
|
||||
return -1.0f*absmax;
|
||||
return -1.0f;
|
||||
|
||||
}
|
||||
|
||||
// Map a 4-bit NF4 code to its unscaled table value in [-1, 1].
//
// Only the low four bits of `val` are inspected. The constants are the ones
// generated by test_normal_map_tree in tests/test_functional.py. No absmax
// scaling happens here; the caller multiplies the result by its block scale.
// The lookup is kept as a branch tree over the individual bits.
__device__ float dDequantizeNF4(unsigned char val)
{
  const bool b3 = (val & 0b1000) != 0;
  const bool b2 = (val & 0b0100) != 0;
  const bool b1 = (val & 0b0010) != 0;
  const bool b0 = (val & 0b0001) != 0;

  if(b3)
  {
    // codes 8..15: strictly positive values
    if(b2)
    {
      if(b1)
        return b0 ? 1.0f : 0.7229568362236023f;                  // 1111 / 1110
      return b0 ? 0.5626170039176941f : 0.44070982933044434f;    // 1101 / 1100
    }
    if(b1)
      return b0 ? 0.33791524171829224f : 0.24611230194568634f;   // 1011 / 1010
    return b0 ? 0.16093020141124725f : 0.07958029955625534f;     // 1001 / 1000
  }

  // codes 0..7: zero (0111) and the negative values
  if(b2)
  {
    if(b1)
      return b0 ? 0.0f : -0.09105003625154495f;                  // 0111 / 0110
    return b0 ? -0.18477343022823334f : -0.28444138169288635f;   // 0101 / 0100
  }
  if(b1)
    return b0 ? -0.39491748809814453f : -0.5250730514526367f;    // 0011 / 0010
  return b0 ? -0.6961928009986877f : -1.0f;                      // 0001 / 0000
}
|
||||
|
||||
|
@ -800,8 +875,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
|
|||
#pragma unroll NUM_PER_TH
|
||||
for(int j = 0; j < NUM_PER_TH; j++)
|
||||
{
|
||||
vals[j*2] = dDequantizeNF4(qvals[j] >> 4, local_abs_max);
|
||||
vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F, local_abs_max);
|
||||
vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max;
|
||||
vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
@ -2947,7 +3022,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
|
|||
//// 9. write outputs to matmul output matrix
|
||||
//}
|
||||
|
||||
template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit)
|
||||
template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f)
|
||||
{
|
||||
if(limit_base + ITEMS <= limit)
|
||||
reinterpret_cast<TCAST*>(local)[0] = reinterpret_cast<TCAST*>(buffer)[idx/ITEMS];
|
||||
|
@ -2958,7 +3033,7 @@ template <typename T, typename TCAST, int ITEMS> __device__ inline void vector_l
|
|||
if(limit_base + k < limit)
|
||||
local[k] = buffer[idx+k];
|
||||
else
|
||||
local[k] = 0.0f;
|
||||
local[k] = (T)zero_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -3024,6 +3099,109 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
|
|||
out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
|
||||
}
|
||||
|
||||
// Draft kernel for a vector-times-NF4-matrix product.
//
// Each thread block owns up to 8 consecutive outputs, out[blockIdx.x*8 + c]
// for c in [0, 8), bounded by M. Every thread walks the reduction dimension K
// in chunks of 32 elements: it loads 32 values of A, and for each of the 8
// outputs loads 16 packed bytes of B (= 32 4-bit codes), dequantizes them,
// and accumulates the dot product. Partial sums are then combined across the
// block with cub::BlockReduce.
//
// Parameters:
//   M         - number of outputs (upper bound for indices into `out`)
//   K         - length of the reduction dimension (bound for loads from A/B)
//   A         - input vector, read as A[idx] for idx < K
//   B         - packed 4-bit codes, two per byte (high nibble first)
//   absmax    - per-block scales, indexed by (element index) / blocksize
//   out       - output, one value per owned index below M
//   ldb       - row stride used to locate a row of B (see NOTE below)
//   blocksize - number of elements sharing one absmax entry
//   N, lda, ldc - accepted for signature parity; not read by this kernel
//
// Dequantization follows kDequantizeBlockwise: table value of the NF4 code
// multiplied by the block's absmax. The scale is kept in float and the
// product is narrowed to T once, rather than scaling the raw 4-bit index.
//
// NOTE(review): offset_B is added to the element index `idx` before the sum is
// halved for the byte address, and the same sum is divided by `blocksize` for
// the absmax index. That arithmetic presumes ldb is expressed in elements, not
// packed bytes - confirm against the caller before relying on the results.
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{
  typedef cub::BlockReduce<T, THREADS> BlockReduce;
  __shared__ typename BlockReduce::TempStorage reduce;
  __shared__ T smem_C[8];

  // first output index owned by this block
  int col_offset = blockIdx.x *8;

  T local_A[32];
  unsigned char local_B_4bit[16];
  T local_B[32];
  T local_C[8];

  if(threadIdx.x < 8)
    smem_C[threadIdx.x] = T(0);
  __syncthreads();

  #pragma unroll 8
  for(int k = 0; k < 8; k++)
    local_C[k] = T(0);

  for(int idx = threadIdx.x*32; idx < K; idx+=blockDim.x*32)
  {
    // we load only 8 values per call from A, so we need 4 loads for every
    // single load from B: the 16 packed bytes of B expand to 32 4-bit values
    vector_load<T, int4, 8>(local_A, A, idx, idx, K);
    vector_load<T, int4, 8>(&(local_A[8]), A, idx+8, idx+8, K);
    vector_load<T, int4, 8>(&(local_A[16]), A, idx+16, idx+16, K);
    vector_load<T, int4, 8>(&(local_A[24]), A, idx+24, idx+24, K);

    for(int col = 0; col < 8; col++)
    {
      if((col + col_offset) >= M){ break; }

      int offset_B = (col_offset+col)*ldb;
      // 0111 -> 0.0f in NF4, so out-of-range bytes are padded with
      // cat(0b0111, 0b0111) = 0b01110111 and contribute nothing to the sum
      vector_load<unsigned char, int4, 16>(local_B_4bit, B, (offset_B+idx+1)/2, (idx+1)/2, (K+1)/2, 0b01110111);

      int absidx = (idx + offset_B)/blocksize;
      // keep the scale in float: absmax is stored as float and the table
      // lookup returns float, so the product is formed at full precision
      float local_absmax = __ldg(&(absmax[absidx]));

      #pragma unroll 16
      for(int k = 0; k < 16; k++)
      {
        // high nibble is the even element, low nibble the odd one
        local_B[k*2] = (T)(dDequantizeNF4(local_B_4bit[k] >> 4)*local_absmax);
        local_B[k*2 + 1] = (T)(dDequantizeNF4(local_B_4bit[k] & 0x0F)*local_absmax);
      }

      #pragma unroll 32
      for(int k = 0; k < 32; k++)
        local_C[col] += local_A[k]*local_B[k];
    }
  }

  // combine the per-thread partial sums; thread 0 consumes the result below
  #pragma unroll 8
  for(int k = 0; k < 8; k++)
  {
    local_C[k] = BlockReduce(reduce).Reduce(local_C[k], cub::Sum());
    // TempStorage is reused by the next reduction
    __syncthreads();
  }

  if(threadIdx.x == 0)
  {
    #pragma unroll 8
    for(int k = 0; k < 8; k++)
      smem_C[k] = local_C[k];
  }
  else if(threadIdx.x >= 32)
    // early return for unused warps
    return;

  __syncwarp();

  // the first 8 threads of the first warp write one output each
  if(threadIdx.x < 8 && col_offset + threadIdx.x < M)
    out[col_offset + threadIdx.x ] = smem_C[threadIdx.x];
}
|
||||
|
||||
//#define ROWS 2
|
||||
//template <typename T, int ITEMS, int THREADS> __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc)
|
||||
//{
|
||||
|
@ -3207,6 +3385,8 @@ template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half *
|
|||
template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
|
||||
template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
|
||||
|
||||
template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
|
||||
|
||||
//template __global__ void kMatmul_inference_4bit<NF4, half, half, half>(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);
|
||||
template __global__ void with_staging_unified<2>(float const* global_in, float * global_out, size_t size, size_t batch_sz);
|
||||
|
|
|
@ -139,5 +139,6 @@ template <size_t stages_count /* Pipeline with stages_count stages */>
|
|||
__global__ void with_staging_unified(float const* global_in, float * global_out, size_t size, size_t batch_sz);
|
||||
|
||||
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
|
||||
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
|
||||
|
||||
#endif
|
||||
|
|
18
csrc/ops.cu
18
csrc/ops.cu
|
@ -695,10 +695,28 @@ template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out
|
|||
gemm_device<T, 16, 128><<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||
}
|
||||
|
||||
// Launch kgemm_4bit_inference on the default stream.
//
// The grid is sized so that every block of 128 threads covers 8 of the `m`
// outputs (the kernel derives its first output index from blockIdx.x*8).
// All arguments are forwarded to the kernel unchanged. The launch is
// asynchronous; no error checking is performed here.
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{
  // nothing to compute, and a zero-sized grid is not a valid launch configuration
  if(m <= 0)
    return;

  dim3 dimBlock(128);
  int num_blocks = (m+7)/8;

  kgemm_4bit_inference<T, 128><<< num_blocks, dimBlock, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}
|
||||
|
||||
//==============================================================
|
||||
// TEMPLATE DEFINITIONS
|
||||
//==============================================================
|
||||
|
||||
template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
|
||||
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
|
||||
template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
|
||||
|
|
|
@ -191,6 +191,7 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
|
|||
void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);
|
||||
|
||||
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
|
||||
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
|
||||
|
||||
|
||||
void pipeline_test(float *A, float *B, size_t n, size_t batch_size);
|
||||
|
|
|
@ -25,6 +25,9 @@ void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, in
|
|||
// Non-template fp16 entry point: forwards to gemm_host<half> with bits = 16.
void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
{
  gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc, 16);
}
|
||||
|
||||
// Non-template entry point for 4-bit inference: forwards every argument to
// the half instantiation of the gemm_4bit_inference template.
void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
{
  gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}
|
||||
|
||||
|
||||
#define MAKE_FUNC32(fname, oname, gtype, gbits) \
|
||||
void fname##32bit_g##gbits(gtype *g, gtype *p, \
|
||||
|
@ -319,6 +322,9 @@ extern "C"
|
|||
// C-linkage shim for gemm_host_fp16; arguments are passed through unchanged.
void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
{
  gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc);
}
|
||||
|
||||
// C-linkage shim for gemm_4bit_inference; arguments are passed through
// unchanged.
void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
{
  gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}
|
||||
|
||||
#endif
|
||||
// C-linkage shim for the CPU blockwise quantizer; forwards to quantize_cpu.
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n)
{
  quantize_cpu(code, A, absmax, out, blocksize, n);
}
|
||||
// C-linkage shim for the CPU blockwise dequantizer; forwards to dequantize_cpu.
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n)
{
  dequantize_cpu(code, A, absmax, out, blocksize, n);
}
|
||||
|
|
|
@ -2352,8 +2352,8 @@ def test_normal_map_tree():
|
|||
print(pivots)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
|
||||
#@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
|
||||
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
|
||||
def test_cutlass3_gemm(dtype):
|
||||
for i in range(1):
|
||||
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
|
||||
|
@ -2373,6 +2373,32 @@ def test_cutlass3_gemm(dtype):
|
|||
|
||||
torch.testing.assert_close(C1, C2, atol=1e-05, rtol=0.005)
|
||||
|
||||
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
def test_gemm_4bit(dtype):
    """Smoke-test the draft 4-bit GEMM path on CUDA.

    Quantizes a random weight matrix to NF4, runs the dequantization once,
    and calls ``F.cutlass3_gemm`` with the quantization state. A dense
    ``torch.matmul`` reference is computed alongside it.

    No numerical assertion is made: this only checks that the calls run
    without raising.
    """
    hidden = 4096
    A = torch.rand(1, hidden, dtype=dtype, device='cuda')
    B = torch.rand(4 * hidden, hidden, dtype=dtype, device='cuda')

    qB, state = F.quantize_nf4(B)
    # Result intentionally discarded; only exercises the dequantization call.
    F.dequantize_nf4(qB, state)

    reference = torch.matmul(A, B.t())
    result = F.cutlass3_gemm(A, qB.t(), state=state)

    # TODO: compare `result` against `reference` once the kernel output is
    # expected to match; the assert_close check is currently disabled.
|
||||
|
||||
|
||||
def test_pipeline_func():
|
||||
a = torch.rand(2, 4).cuda()
|
||||
|
|
Loading…
Reference in New Issue
Block a user