Added arbitrary data types; fixed a bug for small matrices.

Tim Dettmers 2023-07-09 12:04:09 -07:00
parent eefbf60270
commit 4b88d69de7
8 changed files with 98 additions and 109 deletions


@@ -562,6 +562,6 @@ def matmul(
 def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
     assert quant_state is not None
     if A.numel() == A.shape[-1] and A.requires_grad == False:
-        return F.cutlass3_gemm(A, B.t(), out, state=quant_state)
+        return F.gemv_4bit(A, B.t(), out, state=quant_state)
     else:
         return MatMul4Bit.apply(A, B, out, bias, quant_state)
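The dispatch above takes the new kernel only for inference-shaped inputs: A.numel() == A.shape[-1] holds exactly when A contains a single row vector (one decoded token), and the requires_grad check excludes training. A minimal sketch of the predicate (tensor shapes are illustrative):

import torch

def takes_gemv_path(A: torch.Tensor) -> bool:
    # True only for a single row vector (e.g. shape (1, 1, hidden) during
    # autoregressive decoding) with gradient tracking off.
    return A.numel() == A.shape[-1] and not A.requires_grad

assert takes_gemv_path(torch.empty(1, 1, 4096))        # one token -> gemv_4bit
assert not takes_gemv_path(torch.empty(8, 128, 4096))  # sequence -> MatMul4Bit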


@@ -240,17 +240,19 @@ def create_normal_map(offset=0.9677083, use_extra_value=True):
         v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()
         v2 = [0]*(256-15) ## we have 15 non-zero values in this data type
         v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
+        v = v1 + v2 + v3
     else:
         v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist()
         v2 = [0]*(256-14) ## we have 14 non-zero values in this data type
         v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
+        v = v1 + v2 + v3
-    v = v1 + v2 + v3

     values = torch.Tensor(v)
     values = values.sort().values
     values /= values.max()
+    assert values.numel() == 256
     return values

 def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
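The use_extra_value=True branch can be reproduced standalone: 8 positive and 7 negative normal quantiles plus zero give the 15 non-zero values, padded to 256 entries and normalized into [-1, 1]. A sketch, assuming scipy is available:

import torch
from scipy.stats import norm

offset = 0.9677083
v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()     # 8 positive quantiles
v2 = [0]*(256-15)                                               # zero padding
v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()  # 7 negative quantiles
values = torch.Tensor(v1 + v2 + v3).sort().values
values /= values.max()
assert values.numel() == 256 and values.min() == -1.0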
@@ -710,6 +712,47 @@ def dequantize_blockwise(
     return out

+def get_4bit_type(typename, device=None, blocksize=64):
+    if device is None: device = 'cuda'
+    data = None
+    if typename == 'nf4':
+        data = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635,
+                -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725,
+                0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
+                0.7229568362236023, 1.0]
+    elif typename == 'fp4':
+        # 0b000 = 0
+        # 0b001 = 0.0625
+        # 0b010 = 8
+        # 0b011 = 12
+        # 0b100 = 4
+        # 0b101 = 6
+        # 0b110 = 2
+        # 0b111 = 3
+        data = [0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0, -0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0]
+    elif typename == 'int4':
+        data = [7, 6, 5, 4, 3, 2, 1, 0, -0, -1, -2, -3, -4, -5, -6, -7]
+    elif typename == 'af4':
+        # Taken from: NF4 Isn't Information Theoretically Optimal (and that's Good)
+        # https://arxiv.org/abs/2306.06965
+        if blocksize == 64:
+            data = [-1., -0.69441008, -0.51243739, -0.3736951, -0.25607552, -0.14982478,
+                    -0.04934812, 0., 0.04273164, 0.12934483, 0.21961274, 0.31675666,
+                    0.42563882, 0.55496234, 0.72424863, 1.][::-1]
+        else:
+            raise NotImplementedError(f'4-bit AbnormalFloats currently only support blocksize 64.')
+
+    if data is None:
+        raise NotImplementedError(f'Typename {typename} not supported')
+
+    data = Tensor(data)
+    data /= data.abs().max()
+    assert data.numel() == 16
+
+    return data.to(device)
+
 def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4')
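Every table above is consumed the same way at dequantization time: a packed byte holds two 4-bit indices into the 16-entry table, and each looked-up value is scaled by the per-block absmax. A minimal sketch using the fp4 table (the packed byte and scale are illustrative):

import torch

# fp4 table from above, normalized the same way get_4bit_type does
table = torch.tensor([0, 0.0625, 8.0, 12.0, 4.0, 6.0, 2.0, 3.0,
                      -0.0, -0.0625, -8.0, -12.0, -4.0, -6.0, -2.0, -3.0])
table /= table.abs().max()

packed = torch.tensor([0x2B], dtype=torch.uint8)   # two codes: 0x2 (8.0) and 0xB (-12.0)
hi, lo = (packed >> 4).long(), (packed & 0x0F).long()
absmax = 1.5                                       # per-block scale
dequant = table[torch.cat([hi, lo])] * absmax      # tensor([1.0, -1.5])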
@@ -783,6 +826,8 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)

+    datatype = get_4bit_type(quant_type, device=A.device)
+
     if compress_statistics:
         offset = absmax.mean()
         absmax -= offset
@@ -790,9 +835,9 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
         #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
         qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
         del absmax
-        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
+        state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type, datatype]
     else:
-        state = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
+        state = [absmax, input_shape, A.dtype, blocksize, None, quant_type, datatype]

     return out, state
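The only visible change to the quant state is the trailing datatype entry, which lets the kernels dequantize against an arbitrary code table instead of a hard-coded NF4 map. An illustrative layout with dummy values:

import torch

state = [
    torch.rand(16),        # absmax: per-block scales (qabsmax when compressed)
    torch.Size([32, 32]),  # input_shape
    torch.float16,         # original dtype of A
    64,                    # blocksize
    None,                  # [offset, state2] if compress_statistics else None
    'nf4',                 # quant_type
    torch.rand(16),        # new: 16-entry code table from get_4bit_type
]
absmax, shape, dtype, blocksize, stats, qtype, datatype = state
assert datatype.numel() == 16    # what state[-1] hands to the kernel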
@@ -839,7 +884,7 @@ def dequantize_4bit(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax:
         shape = out.shape
         dtype = out.dtype
     else:
-        absmax, shape, dtype, blocksize, compressed_stats, quant_type = quant_state
+        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state

     if compressed_stats is not None:
@@ -1408,13 +1453,14 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
     return sout

-def cutlass3_gemm(
+def gemv_4bit(
     A: Tensor,
     B: Tensor,
     out: Tensor = None,
     transposed_A=False,
     transposed_B=False,
-    state=None
+    state=None,
+    storage_type='nf4'
 ):
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if state is None:
@@ -1491,8 +1537,6 @@ cutlass3_gemm(
     ldb = sA[2]
     ldc = m

-    ptr = CUBLAS_Context.get_instance().get_context(A.device)
-
     # B^T @ A^T = C^T
     # [km, nk -> mn]
     #lda = ldb = ldc = 1
@@ -1514,15 +1558,11 @@ cutlass3_gemm(
     if B.dtype == torch.uint8:
         if A.dtype == torch.float16:
-            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         elif A.dtype == torch.bfloat16:
-            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
-    elif A.dtype == torch.float32:
-        lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
-    elif A.dtype == torch.float16:
-        lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
     else:
         raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
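End to end, the renamed entry point is used as at the call sites below; a hedged usage sketch (assumes a CUDA device and bitsandbytes built from this revision):

import torch
import bitsandbytes.functional as F

B = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')
qB, state = F.quantize_nf4(B)               # packed 4-bit weights + quant state
A = torch.randn(1, 4096, dtype=torch.float16, device='cuda')
C = F.gemv_4bit(A, qB.t(), state=state)     # single-vector inference path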


@@ -3520,7 +3520,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
 }

 #define num_values_4bit 32
-template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
+template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
 {
     // per threadblock:
@@ -3568,7 +3570,9 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
         {
             #pragma unroll
             for(int j = 0; j < (num_values_8bit); j++)
-                if((inner_idx/2) + j < K)
+                if((inner_idx_halved) + j < K)
                     local_B_4bit[j] = B[offset_B+inner_idx_halved + j];
+                else
+                    local_B_4bit[j] = 0b01110111;
         }
     }
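This boundary check is the small-matrix fix from the commit message: when K is not a multiple of the tile width, trailing bytes are now filled with 0b01110111 instead of being read past the end of B. Both nibbles of the filler are 7, and entry 7 of the NF4 table above is exactly 0.0, so for nf4 the padding dequantizes to zero and drops out of the dot product. A quick check:

filler = 0b01110111
hi, lo = filler >> 4, filler & 0x0F
assert hi == lo == 7
nf4 = [-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
       -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
       0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
       0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
       0.7229568362236023, 1.0]
assert nf4[hi] == 0.0   # padded nibbles dequantize to zero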
@@ -3578,6 +3580,9 @@
         {
             local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
             local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
+
+            //if(threadIdx.x == 0)
+            //printf("%f %f %f %f\n", (float)local_B[k*2], (float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax, (float)local_B[k*2]- ((float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax), (float)local_absmax);
         }

         if(inner_idx+num_values_4bit)
@@ -3773,8 +3778,8 @@ template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, ha
 template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
 template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);

-template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
-template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
+template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);

 template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
 template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);


@@ -125,7 +125,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
 template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
-template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);

 template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n);


@@ -729,12 +729,12 @@ template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsi
     //kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
 }

-template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
+template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
 {
     int num_blocks = (m+3)/4;
-    kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
+    kgemm_4bit_inference_naive<T, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
 }
template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
@@ -757,8 +757,8 @@ template void func<float, ARANGE>(float *A, float *B, float value, long n);
 template void func<float, _MUL>(float *A, float *B, float value, long n);

 template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_4bit_inference_naive<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
-template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
+template void gemm_4bit_inference_naive<__nv_bfloat16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);

 //template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
 template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
 template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);


@@ -200,7 +200,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows
 template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
 template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
-template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
+template <typename T> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);

 template <typename T, int FUNC> void func(T *A, T *B, T value, long n);


@@ -28,11 +28,11 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l
 void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
 { gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }

-void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
-{ gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize)
+{ gemm_4bit_inference_naive<half>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

-void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
-{ gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
+{ gemm_4bit_inference_naive<__nv_bfloat16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

 #define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
 void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \
@@ -394,11 +394,11 @@ extern "C"
     CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
     CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)

-    void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
-    { gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+    void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize)
+    { gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

-    void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
-    { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
+    void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize)
+    { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); }

 #endif

     void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }


@@ -1816,7 +1816,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
     linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()

-    F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4)
+    F.gemv_4bit(A, B_nf4.t(), state=state_nf4)

     # warmup
     for i in range(iters):
@@ -1849,7 +1849,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     t0 = time.time()
     for i in range(iters):
         #bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
-        F.cutlass3_gemm(A, B_nf4.t(), state=state_nf4)
+        F.gemv_4bit(A, B_nf4.t(), state=state_nf4)
     torch.cuda.synchronize()
     print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
@@ -2351,76 +2351,14 @@ def test_normal_map_tree():
     print(pivots)

-#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
-@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
-def test_cutlass3_gemm(dtype):
-    debug = True
-    #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
-    #for dim in [4096, 5120, 6656, 8192]:
-    for dim in [4096]:
-    #for dim in [128+1]:
-        errs = []
-        relerrs = []
-        max_err = 0
-        max_relerr = 0
-        for i in range(100):
-            A = torch.randn(1, dim, dtype=dtype, device='cuda')
-            B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
-            #B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-            #print('')
-            #print(A)
-            #print(B.t())
-            #A[:, :-1] = 0
-            #B[:, :-1] = 0
-            C1 = torch.matmul(A, B.t())
-            C2 = F.cutlass3_gemm(A, B.t())
-            # tensor cores are non-deterministic
-            # so we need to analyze errors around the mean
-            # to test our implementation
-            err = torch.abs(C1-C2)
-            mag = torch.abs(C1)+1e-8
-            relerr = err/mag
-            max_err = max(err.max(), max_err)
-            max_relerr = max(relerr.max(), max_relerr)
-            err = err.mean().item()
-            relerr = relerr.mean().item()
-            errs.append(err)
-            relerrs.append(relerr)
-            #if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
-            #    print('')
-            #    print(i, err, relerr)
-            #    print(A.flatten()[-6:])
-            #    print(B.flatten()[-6:])
-            #    out = A.flatten()[-6:]*B.flatten()[-6:]
-            #    print(out)
-            #    print(out[:-1].sum())
-            #    print('='*80)
-            #    print(C1.flatten()[-6:])
-            #    print(C2.flatten()[-6:])
-            #    #assert False, 'ERROR'
-        c = int(C1.numel()*0.0014*(dim/256))+1
-        c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug)
-        #print(c/math.sqrt(dim))
-        print('')
-        print(dim, sum(errs)/len(errs)/math.sqrt(dim))
-        print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
-        print(dim, (max_err.item(), max_relerr.item()))

 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
 #@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
-def test_gemm_4bit(dtype):
+def test_gemv_4bit(dtype):
     print('')
-    #for dim in [64, 128, 256, 512, 1024, 2048, 4096]:
-    for dim in [4*1024]:
+    for dim in [64, 128, 256, 512, 1024, 2048, 4096]:
+    #for dim in [4*1024]:
     #for dim in [1*16]:
         errs = []
         relerrs = []
         max_err = 0
@@ -2446,9 +2384,10 @@ def test_gemm_4bit(dtype):
             qB, state = F.quantize_nf4(B)
             F.dequantize_nf4(qB, state)

             #C2 = bnb.matmul_4bit(A, qB.t(), state)
-            C2 = F.cutlass3_gemm(A, qB.t(), state=state)
-            C1 = torch.matmul(A, B.t())
+            C2 = F.gemv_4bit(A, qB.t(), state=state)
+            C3 = torch.matmul(A, B.t())
+            A.requires_grad = True
+            C1 = bnb.matmul_4bit(A, qB.t(), state)

             #print(state)
             #print(qB)
@@ -2457,8 +2396,7 @@ def test_gemm_4bit(dtype):
             #print(A)
             #print(B)
             #print('='*89)
-            #print(C1)
-            #print(C2)
+            #print(C3.flatten()[-20:])
+            #print(C3)

             #print(C1.shape, C2.shape)
@@ -2485,10 +2423,16 @@ def test_gemm_4bit(dtype):
         #print(dim, sum(errs)/len(errs)/math.sqrt(dim))
         #print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
         #print(dim, (max_err.item(), max_relerr.item()))
+        print(C1.flatten()[-20:])
+        print(C2.flatten()[-20:])
+        print(sum(errs)/len(errs)/math.sqrt(dim) , 0.00015)
+        print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , 0.0015)
-        assert sum(errs)/len(errs)/math.sqrt(dim) < 0.011
-        assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.15
+        if dtype == torch.float16:
+            assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5
+            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005
+        else:
+            assert sum(errs)/len(errs)/math.sqrt(dim) < 3e-4
+            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.003

 @pytest.mark.skip("Row scale has some bugs for ampere")
 def test_managed():
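The new assertions normalize mean absolute and relative error by sqrt(dim), since noise accumulated over a length-dim reduction grows roughly with the square root of the reduction length; the bf16 thresholds are looser because bfloat16 has fewer mantissa bits than float16. The metric as a standalone sketch (reference and output tensors are illustrative):

import math
import torch

def normalized_errors(C_ref: torch.Tensor, C: torch.Tensor, dim: int):
    err = (C_ref - C).abs()
    relerr = err / (C_ref.abs() + 1e-8)
    # divide by sqrt(dim): error of a length-dim dot product scales ~ sqrt(dim)
    return err.mean().item() / math.sqrt(dim), relerr.mean().item() / math.sqrt(dim)

C_ref = torch.randn(1, 64)
e, r = normalized_errors(C_ref, C_ref + 1e-4*torch.randn_like(C_ref), 64)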
@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():