diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index eeef93b..7848b7e 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -562,6 +562,6 @@ def matmul(
 def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
-        return F.cutlass3_gemm(A, B.t(), out, state=quant_state)
+        return F.gemv_4bit(A, B.t(), out, state=quant_state)
     else:
         return MatMul4Bit.apply(A, B, out, bias, quant_state)
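Note on the dispatch above: matmul_4bit only takes the renamed gemv_4bit path when the input collapses to a single row (A.numel() == A.shape[-1]) and no gradient is required, i.e. single-token inference; everything else still goes through MatMul4Bit. A minimal sketch of the two cases (the shapes are illustrative, not from the patch):

    import torch

    # decode step: one token, so numel equals the last dimension -> gemv_4bit path
    A_decode = torch.randn(1, 1, 4096)
    print(A_decode.numel() == A_decode.shape[-1])    # True

    # prefill / batched input -> falls back to MatMul4Bit.apply
    A_prefill = torch.randn(1, 32, 4096)
    print(A_prefill.numel() == A_prefill.shape[-1])  # False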
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 95a15d5..e09b267 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -240,17 +240,19 @@ def create_normal_map(offset=0.9677083, use_extra_value=True):
         v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()
         v2 = [0]*(256-15) ## we have 15 non-zero values in this data type
         v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
-        v = v1 + v2 + v3
     else:
         v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist()
         v2 = [0]*(256-14) ## we have 14 non-zero values in this data type
         v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
-        v = v1 + v2 + v3
+
+    v = v1 + v2 + v3
 
     values = torch.Tensor(v)
     values = values.sort().values
     values /= values.max()
+    assert values.numel() == 256
+
     return values
 
 def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
@@ -710,6 +712,47 @@ def dequantize_blockwise(
     return out
 
 
+def get_4bit_type(typename, device=None, blocksize=64):
+    if device is None: device = 'cuda'
+    data = None
+    if typename == 'nf4':
+        data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635,
+                -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725,
+                0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
+                0.7229568362236023, 1.0]
+    elif typename == 'fp4':
+        # 0b000 = 0
+        # 0b001 = 0.0625
+        # 0b010 = 8
+        # 0b011 = 12
+        # 0b100 = 4
+        # 0b101 = 6
+        # 0b110 = 2
+        # 0b111 = 3
+        data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0]
+    elif typename == 'int4':
+        data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7]
+    elif typename == 'af4':
+        # Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good)
+        # https://arxiv.org/abs/2306.06965
+        if blocksize == 64:
+            data = [-1., -0.69441008, -0.51243739, -0.3736951, -0.25607552, -0.14982478,
+                    -0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666,
+                    0.42563882, 0.55496234, 0.72424863, 1.][::-1]
+        else:
+            raise NotImplementedError(f'4-bit AbnormalFloats currently only support blocksize 64.')
+
+    if data is None:
+        raise NotImplementedError(f'Typename {typename} not supported')
+
+    data = Tensor(data)
+    data /= data.abs().max()
+    assert data.numel() == 16
+
+    return data.to(device)
+
+
+
 def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4')
@@ -783,6 +826,8 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)
 
+    datatype = get_4bit_type(quant_type, device=A.device)
+
     if compress_statistics:
         offset = absmax.mean()
         absmax -= offset
@@ -790,9 +835,9 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
         #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
         qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
         del absmax
-        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
+        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype]
     else:
-        state = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
+        state = [absmax, input_shape, A.dtype, blocksize, None, quant_type, datatype]
 
     return out, state
 
@@ -839,7 +884,7 @@ def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
         shape = out.shape
         dtype = out.dtype
     else:
-        absmax, shape, dtype, blocksize, compressed_stats, quant_type = quant_state
+        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
 
 
     if compressed_stats is not None:
@@ -1408,13 +1453,14 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
     return sout
 
 
-def cutlass3_gemm(
+def gemv_4bit(
     A: Tensor,
     B: Tensor,
     out: Tensor = None,
     transposed_A=False,
     transposed_B=False,
-    state=None
+    state=None,
+    storage_type='nf4'
 ):
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if state is None:
@@ -1491,8 +1537,6 @@
         ldb = sA[2]
     ldc = m
 
-    ptr = CUBLAS_Context.get_instance().get_context(A.device)
-
     # B^T @ A^T = C^T
     # [km, nk -> mn]
     #lda = ldb = ldc = 1
@@ -1514,15 +1558,11 @@
 
     if B.dtype == torch.uint8:
         if A.dtype == torch.float16:
-            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         elif A.dtype == torch.bfloat16:
-            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
-    elif A.dtype == torch.float32:
-        lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
-    elif A.dtype == torch.float16:
-        lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
     else:
         raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
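For orientation, a rough sketch of how the new pieces fit together from the Python side; it assumes a CUDA build of bitsandbytes is importable, and the shapes are made up:

    import torch
    import bitsandbytes.functional as F

    # get_4bit_type returns the 16-entry codebook that quantize_4bit now appends
    # to the quant state and that gemv_4bit forwards to the kernel.
    code = F.get_4bit_type('nf4', device='cuda')
    assert code.numel() == 16

    B = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')
    qB, state = F.quantize_nf4(B)
    # state layout: [absmax, shape, dtype, blocksize, compressed_stats, quant_type, datatype]
    assert state[-2] == 'nf4' and state[-1].numel() == 16

    A = torch.randn(1, 4096, dtype=torch.float16, device='cuda')
    C = F.gemv_4bit(A, qB.t(), state=state)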
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index dd1f6f2..4131477 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -3520,7 +3520,7 @@ template __global__ void kgemm_4bit_inference(int M, i
 }
 
 #define num_values_4bit 32
-template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
+template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
 {
 
   // per threadblock:
@@ -3568,7 +3568,9 @@ template __global__ void kgemm_4bit_inference_naive(in
     {
       #pragma unroll
       for(int j = 0; j < (num_values_8bit); j++)
-        if((inner_idx/2) + j < K)
+        if((inner_idx_halved) + j < K)
+          local_B_4bit[j] = B[offset_B+inner_idx_halved + j];
+        else
           local_B_4bit[j] = 0b01110111;
     }
   }
@@ -3578,6 +3580,9 @@ template __global__ void kgemm_4bit_inference_naive(in
     {
      local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
      local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
+
+      //if(threadIdx.x == 0)
+        //printf("%f %f %f %f\n", (float)local_B[k*2], (float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax, (float)local_B[k*2]- ((float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax), (float)local_absmax);
     }
 
     if(inner_idx+num_values_4bit)
@@ -3773,8 +3778,8 @@ template __global__ void kgemm_4bit_inference(int M, int N, int K, ha
 template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 
-template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
-template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
 
 template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index 05d0715..d5349d6 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -125,7 +125,7 @@ template __global__ void kExtractOutliers(char *A, int *idx, char *
 template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
 template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
-template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
 
 template __global__ void kfunc(T *A, T *B, T value, long n);
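As a reference for the kernel change above: the new datatype argument carries the 16-entry codebook behind the quant_map lookups, and each packed byte of B contributes two values, the high nibble and then the low nibble, each looked up in the codebook and scaled by the block's absmax. A pure-Python sketch of that per-byte dequantization (CPU-only, for illustration; the names are mine, not from the kernel):

    import torch

    def dequant_4bit_reference(packed: torch.Tensor, codebook: torch.Tensor, absmax: float) -> torch.Tensor:
        # packed: uint8 tensor holding two 4-bit indices per byte; codebook: 16 floats
        hi = codebook[(packed >> 4).long()] * absmax    # high nibble first, as in local_B[k*2]
        lo = codebook[(packed & 0x0F).long()] * absmax  # low nibble, as in local_B[k*2 + 1]
        return torch.stack([hi, lo], dim=-1).flatten()

    codebook = torch.linspace(-1, 1, 16)                # stand-in for the NF4 table
    packed = torch.randint(0, 256, (8,), dtype=torch.uint8)
    print(dequant_4bit_reference(packed, codebook, absmax=0.5))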
diff --git a/csrc/ops.cu b/csrc/ops.cu
index 902129f..8bcee2c 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -729,12 +729,12 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi
 	//kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
 }
 
-template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
+template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
 {
   int num_blocks = (m+3)/4;
 
-  kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+  kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
 }
 
 template void func(T *A, T *B, T value, long n)
 {
@@ -757,8 +757,8 @@ template void func(float *A, float *B, float value, long n);
 template void func(float *A, float *B, float value, long n);
 
 template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_4bit_inference_naive<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
 //template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
 template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
 template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols);
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index e4df195..699ff20 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -200,7 +200,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
 
 template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
 template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
-template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
 
 template void func(T *A, T *B, T value, long n);
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index 4cbabae..b1a079f 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -28,11 +28,11 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l
 void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
 { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
 
-void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
-{ gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize)
+{ gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
 
-void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
-{ gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
+{ gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
 
 #define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
 void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func(A, B, value, n); } \
@@ -394,11 +394,11 @@ extern "C"
 	CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
 	CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
 
-	void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
-	{ gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+	void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize)
+	{ gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
 
-	void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
-	{ gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+	void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
+	{ gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }
 
 #endif
 	void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
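The C entry points gain the codebook pointer in a fixed position, between absmax and out, and that order has to match at every layer down to the kernel launch. A hypothetical ctypes declaration spelling out the new fp16 argument order (bitsandbytes performs this binding through its own loader; the explicit argtypes here are only documentary):

    import ctypes as ct
    from bitsandbytes.cextension import lib  # handle to the compiled CUDA library

    # cgemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize)
    lib.cgemm_4bit_inference_naive_fp16.argtypes = [
        ct.c_int32, ct.c_int32, ct.c_int32,  # m, n, k
        ct.c_void_p,                         # A, fp16 activations
        ct.c_void_p,                         # B, packed 4-bit weights (uint8)
        ct.c_void_p,                         # absmax, per-block scales
        ct.c_void_p,                         # datatype, the 16-entry codebook (new argument)
        ct.c_void_p,                         # out, fp16 result
        ct.c_int32, ct.c_int32, ct.c_int32,  # lda, ldb, ldc
        ct.c_int32,                          # blocksize
    ]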
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 552ccaa..54af27d 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1816,7 +1816,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
     linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
 
-    F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4)
+    F.gemv_4bit(A, B_nf4.t(), state=state_nf4)
 
     # warmup
     for i in range(iters):
@@ -1849,7 +1849,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     t0 = time.time()
     for i in range(iters):
         #bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
-        F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4)
+        F.gemv_4bit(A, B_nf4.t(), state=state_nf4)
     torch.cuda.synchronize()
     print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
 
@@ -2351,76 +2351,14 @@ def test_normal_map_tree():
     print(pivots)
 
 
-#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
-@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
-def test_cutlass3_gemm(dtype):
-    debug = True
-    #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
-    #for dim in [4096, 5120, 6656, 8192]:
-    for dim in [4096]:
-    #for dim in [128+1]:
-        errs = []
-        relerrs = []
-        max_err = 0
-        max_relerr = 0
-        for i in range(100):
-            A = torch.randn(1, dim, dtype=dtype, device='cuda')
-            B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
-            #B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-
-            #print('')
-            #print(A)
-            #print(B.t())
-            #A[:, :-1] = 0
-            #B[:, :-1] = 0
-
-
-            C1 = torch.matmul(A, B.t())
-            C2 = F.cutlass3_gemm(A, B.t())
-
-            # tensor cores are non-deterministic
-            # so we need to analyze errors around the mean
-            # to test our implementation
-            err = torch.abs(C1-C2)
-            mag = torch.abs(C1)+1e-8
-            relerr = err/mag
-            max_err = max(err.max(), max_err)
-            max_relerr = max(relerr.max(), max_relerr)
-            err = err.mean().item()
-            relerr = relerr.mean().item()
-
-            errs.append(err)
-            relerrs.append(relerr)
-
-            #if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
-            #    print('')
-            #    print(i, err, relerr)
-            #    print(A.flatten()[-6:])
-            #    print(B.flatten()[-6:])
-            #    out = A.flatten()[-6:]*B.flatten()[-6:]
-            #    print(out)
-            #    print(out[:-1].sum())
-            #    print('='*80)
-            #    print(C1.flatten()[-6:])
-            #    print(C2.flatten()[-6:])
-            #    #assert False, 'ERROR'
-
-        c = int(C1.numel()*0.0014*(dim/256))+1
-
-        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug)
-        #print(c/math.sqrt(dim))
-        print('')
-        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
-        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
-        print(dim, (max_err.item(), max_relerr.item()))
-
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
 #@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
-def test_gemm_4bit(dtype):
+def test_gemv_4bit(dtype):
     print('')
-    #for dim in [64, 128, 256, 512, 1024, 2048, 4096]:
-    for dim in [4*1024]:
+    for dim in [64, 128, 256, 512, 1024, 2048, 4096]:
+    #for dim in [4*1024]:
+    #for dim in [1*16]:
         errs = []
         relerrs = []
         max_err = 0
@@ -2446,9 +2384,10 @@ def test_gemm_4bit(dtype):
             qB, state = F.quantize_nf4(B)
             F.dequantize_nf4(qB, state)
 
-            #C2 = bnb.matmul_4bit(A, qB.t(), state)
-            C2 = F.cutlass3_gemm(A, qB.t(), state=state)
-            C1 = torch.matmul(A, B.t())
+            C2 = F.gemv_4bit(A, qB.t(), state=state)
+            C3 = torch.matmul(A, B.t())
+            A.requires_grad = True
+            C1 = bnb.matmul_4bit(A, qB.t(), state)
 
             #print(state)
             #print(qB)
@@ -2457,8 +2396,7 @@ def test_gemm_4bit(dtype):
             #print(A)
             #print(B)
            #print('='*89)
-            #print(C1)
-            #print(C2)
+            #print(C3.flatten()[-20:])
 
             #print(C3)
             #print(C1.shape, C2.shape)
@@ -2485,10 +2423,16 @@ def test_gemm_4bit(dtype):
         #print(dim, sum(errs)/len(errs)/math.sqrt(dim))
         #print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
        #print(dim, (max_err.item(), max_relerr.item()))
+        print(C1.flatten()[-20:])
+        print(C2.flatten()[-20:])
         print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015)
         print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015)
-        assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011
-        assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15
+        if dtype == torch.float16:
+            assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5
+            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005
+        else:
+            assert sum(errs)/len(errs)/math.sqrt(dim) < 3e-4
+            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.003
 
 @pytest.mark.skip("Row scale has some bugs for ampere")
 def test_managed():
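The tightened assertions at the end normalize the accumulated mean absolute and relative error by sqrt(dim) before comparing against per-dtype thresholds (5e-5 / 0.0005 for fp16, 3e-4 / 0.003 for bf16). A small helper that computes the per-iteration quantities the test averages (the function name is mine, not part of the test suite):

    import math
    import torch

    def normalized_errors(C_ref: torch.Tensor, C_test: torch.Tensor, dim: int):
        err = torch.abs(C_ref - C_test)
        relerr = err / (torch.abs(C_ref) + 1e-8)
        # same sqrt(dim) normalization the test applies to the running means
        return err.mean().item() / math.sqrt(dim), relerr.mean().item() / math.sqrt(dim)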