Added matmul_fp4 to the benchmark.

Tim Dettmers 2023-02-04 22:00:04 -08:00
parent 13c0a4dc5d
commit cfe4705e32
4 changed files with 56 additions and 44 deletions

View File

@@ -495,7 +495,7 @@ class MatMulFP4(torch.autograd.Function):
         # 1. Dequantize
-        # 2. Matmul
+        # 2. MatmulnN
         output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias)

         # 3. Save state
@@ -550,5 +550,6 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)

-def matmul_fp4(A: tensor, B: tensor, out: tensor = None, quant_state: List = None, bias=None):
+def matmul_fp4(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
+    assert quant_state is not None
     return MatMulFP4.apply(A, B, out, bias, quant_state)

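Note on the signature change above: quant_state is now a required positional parameter instead of an optional keyword, so callers can no longer omit it. A minimal usage sketch of the new call shape (tensor shapes here are illustrative, not from this commit):

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(1, 1, 768, dtype=torch.float16, device="cuda")  # activations
B = torch.randn(3072, 768, dtype=torch.float16, device="cuda")  # fp16 weight
B_fp4, state = F.quantize_fp4(B)       # quantize the weight, keep its quant state
out = bnb.matmul_fp4(A, B_fp4, state)  # state is now positional and mandatory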
View File

@@ -169,7 +169,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     lst = list(itertools.product([0, 1], repeat=precision_bits))
     #for ev in evalues:
     bias = 2**(exponent_bits-1)+1
-    print(bias)
     for evalue in range(2**(exponent_bits)):
         for bit_pattern in lst:
             value = (1 if evalue != 0 else 0)
@@ -180,9 +179,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
                 value = value*2**-(bias)
             else:
                 # normals
-                print(value, 1)
                 value = value*2**-(evalue-bias-1)
-                print(value, 2)
             values.append(value)
             if signed:
                 values.append(-value)
@@ -196,7 +193,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
         values.append(0)
     values.sort()
     code = torch.Tensor(values)
-    #code /= code.max()
+    code /= code.max()
     return code

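For context on the hunks above: the loop enumerates every (exponent, fraction) bit pattern and decodes it to a real value, with evalue == 0 treated as a subnormal that has no implicit leading one, and the newly uncommented code /= code.max() rescales the finished codebook into [-1, 1]. A standalone sketch of that decoding (the fraction accumulation is reconstructed from context and is an assumption, not copied from the commit):

# mirror the per-value decoding in create_fp8_map
exponent_bits, precision_bits = 4, 3
bias = 2**(exponent_bits - 1) + 1  # same bias convention as above

def decode(evalue, frac_bits):
    frac = sum(bit * 2**-(i + 1) for i, bit in enumerate(frac_bits))
    if evalue == 0:
        return frac * 2**-bias                    # subnormal: leading 0
    return (1 + frac) * 2**-(evalue - bias - 1)   # normal: matches value*2**-(evalue-bias-1)

print(decode(0, [1, 0, 0]), decode(8, [0, 0, 0]))  # one subnormal, one normal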
View File

@@ -485,10 +485,10 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
         if not transpose[0] and transpose[1]:
             out_torch = funcs[0](A, B.t())
-            out_bnb = funcs[1](A, B2, quant_state=quant_state, bias=bias2)
+            out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
         elif not transpose[0] and not transpose[1]:
             out_torch = funcs[0](A, B)
-            out_bnb = funcs[1](A, B2.t(), quant_state=quant_state, bias=bias2)
+            out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)

         if has_bias:
             out_torch += bias
@@ -498,7 +498,7 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
         n = out_bnb.numel()
         err = torch.abs(out_bnb - out_torch).float().mean().item()
         if n > 0:
-            assert err < 0.11
+            assert err < 0.115

         if any(req_grad):
             out_bnb.data.copy_(out_torch)

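The assertion above bounds the mean absolute error between the FP4 matmul and the fp16 reference; the slightly looser 0.115 bound leaves headroom for FP4 round-trip error accumulating over the inner dimension. A quick way to eyeball the per-weight error scale (a sketch, not part of the test):

import torch
import bitsandbytes.functional as F

B = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
qB, state = F.quantize_fp4(B)
B_deq = F.dequantize_fp4(qB, state)
print((B - B_deq).abs().mean().item())  # mean FP4 quantize/dequantize error per weight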
View File

@@ -1788,18 +1788,14 @@ batch_size = 1
 seqdim = 1
 values = []
 values.append((batch_size, seqdim, 768, 4 * 768))
-# values.append((batch_size, seqdim, 1024, 4*1024))
-# values.append((batch_size, seqdim, 1536, 4*1536))
-# values.append((batch_size, seqdim, 2048, 4*2048))
-# values.append((batch_size, seqdim, 2560, 4*2560))
-# values.append((batch_size, seqdim, 4096, 4*4096))
-# values.append((batch_size, seqdim, 5140, 4*5140))
+#values.append((batch_size, seqdim, 1024, 4*1024))
+#values.append((batch_size, seqdim, 1536, 4*1536))
+#values.append((batch_size, seqdim, 2048, 4*2048))
+#values.append((batch_size, seqdim, 2560, 4*2560))
+#values.append((batch_size, seqdim, 4096, 4*4096))
+#values.append((batch_size, seqdim, 5140, 4*5140))
+#values.append((batch_size, seqdim, 12288, 4*12288))
-names = [
-    "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
-]
+names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
 def test_bench_matmul(batch, seq, model, hidden):
     iters = 128
@@ -1809,17 +1805,20 @@ def test_bench_matmul(batch, seq, model, hidden):
     B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
     torch.nn.init.xavier_uniform_(B)
+    B_fp4, state = F.quantize_fp4(B)

     linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
     linear8bit.eval()

     outliers = torch.randint(0, model, size=(5,)).cuda()
     A[:, :, outliers] = 8.0

-    linearMixedBit = (
-        bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
-    )
+    linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half())
     linearMixedBit.eval()

+    linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
+    linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()

     # warmup
     for i in range(iters):
         torch.matmul(A, B.t())
@@ -1831,9 +1830,14 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         torch.matmul(A, B.t())
     torch.cuda.synchronize()
-    print(
-        f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        bnb.matmul_fp4(A, B_fp4, quant_state=state)
+    torch.cuda.synchronize()
+    print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

     torch.cuda.synchronize()
     t0 = time.time()
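Every timing block in this benchmark, including the new bnb fp4 one above, uses the same discipline: synchronize before starting the clock and again before reading it, because CUDA kernel launches are asynchronous and time.time() would otherwise measure launch overhead rather than execution. A reusable version of the pattern (a sketch; the bench helper is invented for illustration):

import time
import torch

def bench(fn, iters=128, warmup=8):
    for _ in range(warmup):       # warmup: lazy init, caching, autotuning
        fn()
    torch.cuda.synchronize()      # drain all queued work before timing
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()      # wait for the timed kernels to finish
    return (time.time() - t0) / iters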
@@ -1872,7 +1876,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
         F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
     torch.cuda.synchronize()
-    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
     CxB, SB = F.nvidia_transform(CB, to_order=formatB)
@@ -1886,7 +1890,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
         out = Cout * statsB * statsA * (1.0 / (127 * 127))
     torch.cuda.synchronize()
-    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     linear8bit(A)
     torch.cuda.synchronize()
@@ -1894,9 +1898,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         linear8bit(A)
     torch.cuda.synchronize()
-    print(
-        f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     linearMixedBit(A)
     torch.cuda.synchronize()
@@ -1904,9 +1906,23 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         linearMixedBit(A)
     torch.cuda.synchronize()
-    print(
-        f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

+    linear8bit_train(A)
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        linear8bit_train(A)
+    torch.cuda.synchronize()
+    print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

+    linear8bit_train_thresh(A)
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        linear8bit_train_thresh(A)  # fixed: was linear8bit_train, which re-benchmarked the no-threshold layer
+    torch.cuda.synchronize()
+    print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

 def test_zeropoint():
     def quant_zp(x):
@@ -2050,7 +2066,6 @@ def test_fp8_quant():
         p_bits = 7-e_bits
         code = F.create_fp8_map(True, e_bits, p_bits).cuda()
-        print(e_bits, p_bits)

         abserr = []
         relerr = []
         for i in range(100):
@@ -2189,7 +2204,6 @@ def test_bench_dequantization():
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(100):
-        #F.dequantize_blockwise(qa, SA, blocksize=2048)
         qa, SA = F.quantize_blockwise(a)
     torch.cuda.synchronize()
     #print((time.time()-t0)/1e6)
@@ -2240,7 +2254,7 @@ def test_bench_fp4_dequant():
     num_bytes = input_size+output_size
     GB = num_bytes/1e9
     max_theoretical_s = GB/768
-    print(max_theoretical_s*1e6)
+    #print(max_theoretical_s*1e6)
     b = torch.randn(128, 1024*12, device='cuda').half()

     iters = 5
@@ -2250,14 +2264,14 @@ def test_bench_fp4_dequant():
         F.dequantize_fp4(qa, SA, blocksize=blocksize)
         #b.copy_(a)
     torch.cuda.synchronize()
-    print((time.time()-t0)/iters*1e6)
+    #print((time.time()-t0)/iters*1e6)

-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        torch.matmul(b, a.t())
-    torch.cuda.synchronize()
-    print((time.time()-t0)/iters*1e6)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    torch.matmul(b, a.t())
+    #torch.cuda.synchronize()
+    #print((time.time()-t0)/iters*1e6)
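The (now silenced) max_theoretical_s = GB/768 estimate above is a memory-bandwidth floor: total bytes read plus written, divided by an assumed ~768 GB/s of device bandwidth, gives the fastest the dequantization kernel could possibly run. A worked example with illustrative sizes (the 0.5 byte/element FP4 payload ignores the small absmax state):

n = 1024 * 12
input_size = n * n * 0.5    # packed FP4 weight, bytes
output_size = n * n * 2     # dequantized fp16 output, bytes
GB = (input_size + output_size) / 1e9
print(GB / 768 * 1e6, "microseconds lower bound at 768 GB/s")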