Fixed accidental deletion of limits in kernel.
parent 2221f4cee0
commit 306f6b2362
@@ -3595,7 +3595,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
         #endif
       }
 
-      if(inner_idx+num_values_4bit)
+      if(inner_idx+num_values_4bit < K)
       {
         if(BITS==16)
         {
@@ -3619,11 +3619,17 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
         }
         else
           for(int k = 0; k < num_values_4bit; k++)
-            local_A[k] = A[inner_idx +k];
+            if(inner_idx + k < K)
+              local_A[k] = A[inner_idx + k];
+            else
+              local_A[k] = T(0.0f);
+
 
       #pragma unroll
       for(int k = 0; k < num_values_4bit; k++)
       {
+        if((float)local_A[k] < -10.0f || (float)local_B[k] < -10.0f || local_C > 10.0f)
+          printf("%f %f = %f\n", (float)local_A[k], (float)local_B[k], local_C);
         #if __CUDA_ARCH__ >= 800
           local_C += (float)(local_A[k]*local_B[k]);
         #else
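For context: each thread in this kernel loads a fixed-size chunk of num_values_4bit elements of A, so when K is not a multiple of the chunk size the last chunk must be zero-padded rather than read past the end of the buffer. That is the limit the commit restores. The following is a minimal standalone CUDA sketch of that bounds-guarded chunk load, not the bitsandbytes kernel itself; the kernel name guarded_load_demo, the chunk size NUM_VALUES, and the float-only setup are illustrative assumptions.

// Standalone sketch of a bounds-guarded per-thread chunk load with zero padding.
// Names (guarded_load_demo, NUM_VALUES) are hypothetical, not from the repository.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int NUM_VALUES = 8;   // per-thread chunk size, analogous to num_values_4bit

__global__ void guarded_load_demo(const float* A, float* out, int K)
{
    float local_A[NUM_VALUES];
    int inner_idx = (blockIdx.x * blockDim.x + threadIdx.x) * NUM_VALUES;

    if (inner_idx + NUM_VALUES < K)
    {
        // Fast path: the whole chunk is in bounds, load it unconditionally.
        #pragma unroll
        for (int k = 0; k < NUM_VALUES; k++)
            local_A[k] = A[inner_idx + k];
    }
    else
    {
        // Tail path: check every element and zero-pad past K, so the
        // accumulation below is unaffected by out-of-range positions.
        #pragma unroll
        for (int k = 0; k < NUM_VALUES; k++)
            local_A[k] = (inner_idx + k < K) ? A[inner_idx + k] : 0.0f;
    }

    // Simple per-thread reduction over the loaded chunk.
    float acc = 0.0f;
    #pragma unroll
    for (int k = 0; k < NUM_VALUES; k++)
        acc += local_A[k];
    if (inner_idx < K)
        out[blockIdx.x * blockDim.x + threadIdx.x] = acc;
}

int main()
{
    const int K = 37;                                    // deliberately not a multiple of NUM_VALUES
    const int threads = (K + NUM_VALUES - 1) / NUM_VALUES;
    float hA[K], hOut[threads];
    for (int i = 0; i < K; i++) hA[i] = 1.0f;

    float *dA, *dOut;
    cudaMalloc(&dA, K * sizeof(float));
    cudaMalloc(&dOut, threads * sizeof(float));
    cudaMemcpy(dA, hA, K * sizeof(float), cudaMemcpyHostToDevice);

    guarded_load_demo<<<1, threads>>>(dA, dOut, K);
    cudaMemcpy(hOut, dOut, threads * sizeof(float), cudaMemcpyDeviceToHost);

    float total = 0.0f;
    for (int i = 0; i < threads; i++) total += hOut[i];
    printf("sum = %f (expected %d)\n", total, K);        // zero padding keeps the sum at exactly K

    cudaFree(dA);
    cudaFree(dOut);
    return 0;
}

With all inputs set to 1.0 and K = 37, the zero-padded tail should leave the accumulated sum at exactly K, which is the behaviour the fixed kernel relies on for its final partial chunk.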
@@ -2378,8 +2378,8 @@ def test_gemv_4bit(dtype, storage_type, double_quant):
     #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
     #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
     A = torch.randn(1, dim, dtype=dtype, device='cuda')
-    B = torch.randn(4*dim, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
-    #B = torch.randn(1, dim+2, dtype=dtype, device='cuda')/math.sqrt(dim)
+    #B = torch.randn(4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
+    B = torch.randn(dim*4, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
 
     #print('')
     #print(A)
@@ -2432,13 +2432,18 @@ def test_gemv_4bit(dtype, storage_type, double_quant):
         #print(dim, (max_err.item(), max_relerr.item()))
         print(C1.flatten()[-20:])
         print(C2.flatten()[-20:])
-        print(C3.flatten()[-20:])
+        #print(C1.flatten())
+        #print(C2.flatten())
+        #print(C3.flatten()[-20:])
         print(sum(errs)/len(errs)/math.sqrt(dim) , dim)
         print(sum(relerrs)/len(relerrs)/math.sqrt(dim) , dim)
         if dtype == torch.float16:
             assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5
             assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005
-        else:
+        elif dtype == torch.float32:
+            assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-8
+            assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 1e-8
+        elif dtype == torch.bfloat16:
             assert sum(errs)/len(errs)/math.sqrt(dim) < 3e-4
             assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.003
 