Added matmul_fp4 to the benchmark.
This commit is contained in:
parent
13c0a4dc5d
commit
cfe4705e32
|
@ -495,7 +495,7 @@ class MatMulFP4(torch.autograd.Function):
|
||||||
|
|
||||||
|
|
||||||
# 1. Dequantize
|
# 1. Dequantize
|
||||||
# 2. Matmul
|
# 2. MatmulnN
|
||||||
output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias)
|
output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias)
|
||||||
|
|
||||||
# 3. Save state
|
# 3. Save state
|
||||||
|
@ -550,5 +550,6 @@ def matmul(
|
||||||
return MatMul8bitLt.apply(A, B, out, bias, state)
|
return MatMul8bitLt.apply(A, B, out, bias, state)
|
||||||
|
|
||||||
|
|
||||||
def matmul_fp4(A: tensor, B: tensor, out: tensor = None, quant_state: List = None, bias=None):
|
def matmul_fp4(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
|
||||||
|
assert quant_state is not None
|
||||||
return MatMulFP4.apply(A, B, out, bias, quant_state)
|
return MatMulFP4.apply(A, B, out, bias, quant_state)
|
||||||
|
|
|
@ -169,7 +169,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
|
||||||
lst = list(itertools.product([0, 1], repeat=precision_bits))
|
lst = list(itertools.product([0, 1], repeat=precision_bits))
|
||||||
#for ev in evalues:
|
#for ev in evalues:
|
||||||
bias = 2**(exponent_bits-1)+1
|
bias = 2**(exponent_bits-1)+1
|
||||||
print(bias)
|
|
||||||
for evalue in range(2**(exponent_bits)):
|
for evalue in range(2**(exponent_bits)):
|
||||||
for bit_pattern in lst:
|
for bit_pattern in lst:
|
||||||
value = (1 if evalue != 0 else 0)
|
value = (1 if evalue != 0 else 0)
|
||||||
|
@ -180,9 +179,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
|
||||||
value = value*2**-(bias)
|
value = value*2**-(bias)
|
||||||
else:
|
else:
|
||||||
# normals
|
# normals
|
||||||
print(value, 1)
|
|
||||||
value = value*2**-(evalue-bias-1)
|
value = value*2**-(evalue-bias-1)
|
||||||
print(value, 2)
|
|
||||||
values.append(value)
|
values.append(value)
|
||||||
if signed:
|
if signed:
|
||||||
values.append(-value)
|
values.append(-value)
|
||||||
|
@ -196,7 +193,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
|
||||||
values.append(0)
|
values.append(0)
|
||||||
values.sort()
|
values.sort()
|
||||||
code = torch.Tensor(values)
|
code = torch.Tensor(values)
|
||||||
#code /= code.max()
|
code /= code.max()
|
||||||
|
|
||||||
return code
|
return code
|
||||||
|
|
||||||
|
|
|
@ -485,10 +485,10 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
|
||||||
|
|
||||||
if not transpose[0] and transpose[1]:
|
if not transpose[0] and transpose[1]:
|
||||||
out_torch = funcs[0](A, B.t())
|
out_torch = funcs[0](A, B.t())
|
||||||
out_bnb = funcs[1](A, B2, quant_state=quant_state, bias=bias2)
|
out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
|
||||||
elif not transpose[0] and not transpose[1]:
|
elif not transpose[0] and not transpose[1]:
|
||||||
out_torch = funcs[0](A, B)
|
out_torch = funcs[0](A, B)
|
||||||
out_bnb = funcs[1](A, B2.t(), quant_state=quant_state, bias=bias2)
|
out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
|
||||||
|
|
||||||
if has_bias:
|
if has_bias:
|
||||||
out_torch += bias
|
out_torch += bias
|
||||||
|
@ -498,7 +498,7 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
|
||||||
n = out_bnb.numel()
|
n = out_bnb.numel()
|
||||||
err = torch.abs(out_bnb - out_torch).float().mean().item()
|
err = torch.abs(out_bnb - out_torch).float().mean().item()
|
||||||
if n > 0:
|
if n > 0:
|
||||||
assert err < 0.11
|
assert err < 0.115
|
||||||
|
|
||||||
if any(req_grad):
|
if any(req_grad):
|
||||||
out_bnb.data.copy_(out_torch)
|
out_bnb.data.copy_(out_torch)
|
||||||
|
|
|
@ -1788,18 +1788,14 @@ batch_size = 1
|
||||||
seqdim = 1
|
seqdim = 1
|
||||||
values = []
|
values = []
|
||||||
values.append((batch_size, seqdim, 768, 4 * 768))
|
values.append((batch_size, seqdim, 768, 4 * 768))
|
||||||
# values.append((batch_size, seqdim, 1024, 4*1024))
|
#values.append((batch_size, seqdim, 1024, 4*1024))
|
||||||
# values.append((batch_size, seqdim, 1536, 4*1536))
|
#values.append((batch_size, seqdim, 1536, 4*1536))
|
||||||
# values.append((batch_size, seqdim, 2048, 4*2048))
|
#values.append((batch_size, seqdim, 2048, 4*2048))
|
||||||
# values.append((batch_size, seqdim, 2560, 4*2560))
|
#values.append((batch_size, seqdim, 2560, 4*2560))
|
||||||
# values.append((batch_size, seqdim, 4096, 4*4096))
|
#values.append((batch_size, seqdim, 4096, 4*4096))
|
||||||
# values.append((batch_size, seqdim, 5140, 4*5140))
|
#values.append((batch_size, seqdim, 5140, 4*5140))
|
||||||
#values.append((batch_size, seqdim, 12288, 4*12288))
|
#values.append((batch_size, seqdim, 12288, 4*12288))
|
||||||
names = [
|
names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
|
||||||
"batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
|
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
|
||||||
def test_bench_matmul(batch, seq, model, hidden):
|
def test_bench_matmul(batch, seq, model, hidden):
|
||||||
iters = 128
|
iters = 128
|
||||||
|
@ -1809,17 +1805,20 @@ def test_bench_matmul(batch, seq, model, hidden):
|
||||||
B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
|
B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
|
||||||
torch.nn.init.xavier_uniform_(B)
|
torch.nn.init.xavier_uniform_(B)
|
||||||
|
|
||||||
|
B_fp4, state = F.quantize_fp4(B)
|
||||||
|
|
||||||
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
|
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
|
||||||
linear8bit.eval()
|
linear8bit.eval()
|
||||||
|
|
||||||
outliers = torch.randint(0, model, size=(5,)).cuda()
|
outliers = torch.randint(0, model, size=(5,)).cuda()
|
||||||
A[:, :, outliers] = 8.0
|
A[:, :, outliers] = 8.0
|
||||||
|
|
||||||
linearMixedBit = (
|
linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half())
|
||||||
bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
|
|
||||||
)
|
|
||||||
linearMixedBit.eval()
|
linearMixedBit.eval()
|
||||||
|
|
||||||
|
linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
|
||||||
|
linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
|
||||||
|
|
||||||
# warmup
|
# warmup
|
||||||
for i in range(iters):
|
for i in range(iters):
|
||||||
torch.matmul(A, B.t())
|
torch.matmul(A, B.t())
|
||||||
|
@ -1831,9 +1830,14 @@ def test_bench_matmul(batch, seq, model, hidden):
|
||||||
for i in range(iters):
|
for i in range(iters):
|
||||||
torch.matmul(A, B.t())
|
torch.matmul(A, B.t())
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
print(
|
print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
|
||||||
f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
|
|
||||||
)
|
torch.cuda.synchronize()
|
||||||
|
t0 = time.time()
|
||||||
|
for i in range(iters):
|
||||||
|
bnb.matmul_fp4(A, B_fp4, quant_state=state)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
@ -1872,7 +1876,7 @@ def test_bench_matmul(batch, seq, model, hidden):
|
||||||
Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
|
Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
|
||||||
F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
|
F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
#print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||||
|
|
||||||
BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
|
BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
|
||||||
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
|
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
|
||||||
|
@ -1886,7 +1890,7 @@ def test_bench_matmul(batch, seq, model, hidden):
|
||||||
Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
|
Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
|
||||||
out = Cout * statsB * statsA * (1.0 / (127 * 127))
|
out = Cout * statsB * statsA * (1.0 / (127 * 127))
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
#print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||||
|
|
||||||
linear8bit(A)
|
linear8bit(A)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
@ -1894,9 +1898,7 @@ def test_bench_matmul(batch, seq, model, hidden):
|
||||||
for i in range(iters):
|
for i in range(iters):
|
||||||
linear8bit(A)
|
linear8bit(A)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
print(
|
print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||||
f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
linearMixedBit(A)
|
linearMixedBit(A)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
@ -1904,9 +1906,23 @@ def test_bench_matmul(batch, seq, model, hidden):
|
||||||
for i in range(iters):
|
for i in range(iters):
|
||||||
linearMixedBit(A)
|
linearMixedBit(A)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
print(
|
print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||||
f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
|
|
||||||
)
|
linear8bit_train(A)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
t0 = time.time()
|
||||||
|
for i in range(iters):
|
||||||
|
linear8bit_train(A)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||||
|
|
||||||
|
linear8bit_train_thresh(A)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
t0 = time.time()
|
||||||
|
for i in range(iters):
|
||||||
|
linear8bit_train(A)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||||
|
|
||||||
def test_zeropoint():
|
def test_zeropoint():
|
||||||
def quant_zp(x):
|
def quant_zp(x):
|
||||||
|
@ -2050,7 +2066,6 @@ def test_fp8_quant():
|
||||||
p_bits = 7-e_bits
|
p_bits = 7-e_bits
|
||||||
code = F.create_fp8_map(True, e_bits, p_bits).cuda()
|
code = F.create_fp8_map(True, e_bits, p_bits).cuda()
|
||||||
|
|
||||||
print(e_bits, p_bits)
|
|
||||||
abserr = []
|
abserr = []
|
||||||
relerr = []
|
relerr = []
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
|
@ -2189,7 +2204,6 @@ def test_bench_dequantization():
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
for i in range(100):
|
for i in range(100):
|
||||||
#F.dequantize_blockwise(qa, SA, blocksize=2048)
|
|
||||||
qa, SA = F.quantize_blockwise(a)
|
qa, SA = F.quantize_blockwise(a)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
#print((time.time()-t0)/1e6)
|
#print((time.time()-t0)/1e6)
|
||||||
|
@ -2240,7 +2254,7 @@ def test_bench_fp4_dequant():
|
||||||
num_bytes = input_size+output_size
|
num_bytes = input_size+output_size
|
||||||
GB = num_bytes/1e9
|
GB = num_bytes/1e9
|
||||||
max_theoretical_s = GB/768
|
max_theoretical_s = GB/768
|
||||||
print(max_theoretical_s*1e6)
|
#print(max_theoretical_s*1e6)
|
||||||
b = torch.randn(128, 1024*12, device='cuda').half()
|
b = torch.randn(128, 1024*12, device='cuda').half()
|
||||||
|
|
||||||
iters = 5
|
iters = 5
|
||||||
|
@ -2250,14 +2264,14 @@ def test_bench_fp4_dequant():
|
||||||
F.dequantize_fp4(qa, SA, blocksize=blocksize)
|
F.dequantize_fp4(qa, SA, blocksize=blocksize)
|
||||||
#b.copy_(a)
|
#b.copy_(a)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
print((time.time()-t0)/iters*1e6)
|
#print((time.time()-t0)/iters*1e6)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
#torch.cuda.synchronize()
|
||||||
t0 = time.time()
|
#t0 = time.time()
|
||||||
for i in range(iters):
|
#for i in range(iters):
|
||||||
torch.matmul(b, a.t())
|
# torch.matmul(b, a.t())
|
||||||
torch.cuda.synchronize()
|
#torch.cuda.synchronize()
|
||||||
print((time.time()-t0)/iters*1e6)
|
#print((time.time()-t0)/iters*1e6)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user