Added matmul_fp4 to the benchmark.

This commit is contained in:
parent 13c0a4dc5d
commit cfe4705e32
@@ -495,7 +495,7 @@ class MatMulFP4(torch.autograd.Function):

         # 1. Dequantize
-        # 2. Matmul
+        # 2. MatmulnN
         output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias)

         # 3. Save state
@@ -550,5 +550,6 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)


-def matmul_fp4(A: tensor, B: tensor, out: tensor = None, quant_state: List = None, bias=None):
+def matmul_fp4(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
+    assert quant_state is not None
     return MatMulFP4.apply(A, B, out, bias, quant_state)
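The signature change above makes `quant_state` a required positional argument, and the new assert fails fast when a caller omits it rather than erroring somewhere inside the kernel. A minimal usage sketch under the new signature (shapes are illustrative, and this assumes a CUDA build of bitsandbytes at this revision):

    import torch
    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    # Illustrative shapes: (batch, seq, model) activations, (hidden, model) weight.
    A = torch.randn(1, 1, 768, dtype=torch.float16, device="cuda")
    W = torch.randn(4 * 768, 768, dtype=torch.float16, device="cuda")

    # quantize_fp4 returns the packed 4-bit tensor plus the quantization state
    # (absmax, blocksize, original shape, ...) needed to dequantize it later.
    W_fp4, state = F.quantize_fp4(W)

    # quant_state is now positional and mandatory; omitting it trips the assert.
    out = bnb.matmul_fp4(A, W_fp4, state)  # -> (1, 1, 4 * 768)
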
@@ -169,7 +169,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     lst = list(itertools.product([0, 1], repeat=precision_bits))
     #for ev in evalues:
     bias = 2**(exponent_bits-1)+1
-    print(bias)
     for evalue in range(2**(exponent_bits)):
         for bit_pattern in lst:
             value = (1 if evalue != 0 else 0)
@@ -180,9 +179,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
                 value = value*2**-(bias)
             else:
                 # normals
-                print(value, 1)
                 value = value*2**-(evalue-bias-1)
-                print(value, 2)
             values.append(value)
             if signed:
                 values.append(-value)
@@ -196,7 +193,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     values.append(0)
     values.sort()
     code = torch.Tensor(values)
-    #code /= code.max()
+    code /= code.max()

     return code
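Uncommenting `code /= code.max()` normalizes the code map into [-1, 1], which is what the blockwise quantization routines expect from a codebook (stored values are rescaled by each block's absmax at runtime). A standalone sketch of the construction following the lines visible in these hunks; this is a hypothetical re-implementation for illustration, not the library's exact create_fp8_map, and it omits the total_bits sizing the real function performs:

    import itertools
    import torch

    def fp8_code_map_sketch(signed=True, exponent_bits=5, precision_bits=2):
        values = []
        bias = 2**(exponent_bits - 1) + 1
        for evalue in range(2**exponent_bits):
            for bit_pattern in itertools.product([0, 1], repeat=precision_bits):
                # Implicit leading 1 for normals; 0 for subnormals (evalue == 0).
                value = 1.0 if evalue != 0 else 0.0
                for i, bit in enumerate(bit_pattern):
                    value += bit * 2.0**-(i + 1)  # mantissa bits
                if evalue == 0:
                    value = value * 2**-(bias)               # subnormal scale
                else:
                    value = value * 2**-(evalue - bias - 1)  # normal scale
                values.append(value)
                if signed:
                    values.append(-value)
        values.append(0.0)
        code = torch.Tensor(sorted(values))
        code /= code.max()  # the normalization this commit switches on
        return code

    print(fp8_code_map_sketch(exponent_bits=3, precision_bits=2))
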
@@ -485,10 +485,10 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,

     if not transpose[0] and transpose[1]:
         out_torch = funcs[0](A, B.t())
-        out_bnb = funcs[1](A, B2, quant_state=quant_state, bias=bias2)
+        out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
     elif not transpose[0] and not transpose[1]:
         out_torch = funcs[0](A, B)
-        out_bnb = funcs[1](A, B2.t(), quant_state=quant_state, bias=bias2)
+        out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)

     if has_bias:
         out_torch += bias
@@ -498,7 +498,7 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
     n = out_bnb.numel()
     err = torch.abs(out_bnb - out_torch).float().mean().item()
     if n > 0:
-        assert err < 0.11
+        assert err < 0.115

     if any(req_grad):
         out_bnb.data.copy_(out_torch)
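The tolerance on the fp4-versus-fp16 mismatch is loosened slightly, from 0.11 to 0.115. The quantity being bounded is the elementwise mean absolute error between the two outputs; a tiny self-contained sketch of the metric with random stand-ins:

    import torch

    out_torch = torch.randn(32, 64, dtype=torch.float16)
    out_bnb = out_torch + 0.05 * torch.randn_like(out_torch)  # stand-in for the fp4 result

    err = torch.abs(out_bnb - out_torch).float().mean().item()
    assert err < 0.115  # the relaxed bound from this hunk
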
@@ -1788,18 +1788,14 @@ batch_size = 1
 seqdim = 1
 values = []
 values.append((batch_size, seqdim, 768, 4 * 768))
-# values.append((batch_size, seqdim, 1024, 4*1024))
-# values.append((batch_size, seqdim, 1536, 4*1536))
-# values.append((batch_size, seqdim, 2048, 4*2048))
-# values.append((batch_size, seqdim, 2560, 4*2560))
-# values.append((batch_size, seqdim, 4096, 4*4096))
-# values.append((batch_size, seqdim, 5140, 4*5140))
-names = [
-    "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
-]
-
-
+#values.append((batch_size, seqdim, 1024, 4*1024))
+#values.append((batch_size, seqdim, 1536, 4*1536))
+#values.append((batch_size, seqdim, 2048, 4*2048))
+#values.append((batch_size, seqdim, 2560, 4*2560))
+#values.append((batch_size, seqdim, 4096, 4*4096))
+#values.append((batch_size, seqdim, 5140, 4*5140))
+#values.append((batch_size, seqdim, 12288, 4*12288))
+names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
 def test_bench_matmul(batch, seq, model, hidden):
     iters = 128
@@ -1809,17 +1805,20 @@ def test_bench_matmul(batch, seq, model, hidden):
     B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
     torch.nn.init.xavier_uniform_(B)

+    B_fp4, state = F.quantize_fp4(B)
+
     linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
     linear8bit.eval()

     outliers = torch.randint(0, model, size=(5,)).cuda()
     A[:, :, outliers] = 8.0

-    linearMixedBit = (
-        bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
-    )
+    linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half())
     linearMixedBit.eval()

+    linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
+    linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
+
     # warmup
     for i in range(iters):
         torch.matmul(A, B.t())
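For context on why the fp4 path earns a benchmark entry: FP4 packs two weights per byte, so the (hidden, model) weight quantized above shrinks to roughly a quarter of its fp16 size plus a small per-block absmax overhead. A back-of-envelope check with the benchmark's smallest shape (the blocksize of 64 is an assumed default, not something this diff states):

    model, hidden = 768, 4 * 768
    n = model * hidden

    fp16_bytes = n * 2                  # 2 bytes per fp16 weight
    fp4_bytes = n // 2                  # two 4-bit weights per byte
    blocksize = 64                      # assumed quantize_fp4 default
    state_bytes = (n // blocksize) * 4  # one fp32 absmax per block

    print(fp16_bytes)               # 4718592
    print(fp4_bytes + state_bytes)  # 1327104, about 3.6x smaller
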
@@ -1831,9 +1830,14 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         torch.matmul(A, B.t())
     torch.cuda.synchronize()
-    print(
-        f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
+
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        bnb.matmul_fp4(A, B_fp4, quant_state=state)
+    torch.cuda.synchronize()
+    print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )

     torch.cuda.synchronize()
     t0 = time.time()
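The new fp4 entry repeats the synchronize/time/loop/synchronize pattern used by every other kernel in this benchmark. A hypothetical helper that factors the pattern out (the name `bench` is mine, not the test suite's):

    import time
    import torch

    def bench(fn, iters=128, warmup=8):
        """Time a CUDA op. Both synchronizations matter: the first drains
        pending work so the timer starts clean, the second waits for the
        asynchronously launched kernels to actually finish."""
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
        t0 = time.time()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()
        return (time.time() - t0) / iters

    # Usage against the benchmark's tensors (assumed already on the GPU):
    #   per_call = bench(lambda: bnb.matmul_fp4(A, B_fp4, quant_state=state))
    #   print(f"bnb fp4: {per_call * 1e6:.1f} us/call")
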
@@ -1872,7 +1876,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
         F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
     torch.cuda.synchronize()
-    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
     CxB, SB = F.nvidia_transform(CB, to_order=formatB)
@@ -1886,7 +1890,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
         out = Cout * statsB * statsA * (1.0 / (127 * 127))
     torch.cuda.synchronize()
-    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     linear8bit(A)
     torch.cuda.synchronize()
@@ -1894,9 +1898,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         linear8bit(A)
     torch.cuda.synchronize()
-    print(
-        f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     linearMixedBit(A)
     torch.cuda.synchronize()
@@ -1904,9 +1906,23 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         linearMixedBit(A)
     torch.cuda.synchronize()
-    print(
-        f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
+    linear8bit_train(A)
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        linear8bit_train(A)
+    torch.cuda.synchronize()
+    print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
+    linear8bit_train_thresh(A)
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        linear8bit_train(A)
+    torch.cuda.synchronize()
+    print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

 def test_zeropoint():
     def quant_zp(x):
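The threshold variants only measure something interesting because the benchmark earlier seeds five input columns with magnitude 8.0, above threshold=6.0, so the LLM.int8()-style mixed-precision path actually triggers: columns whose magnitude crosses the threshold stay in fp16 while the rest run in int8. A sketch of that split (illustrative only, not the library's internal routine):

    import torch

    threshold = 6.0
    A = torch.randn(1, 1, 768, dtype=torch.float16)
    A[:, :, torch.randint(0, 768, (5,))] = 8.0  # seeded outliers, as in the benchmark

    # Columns whose absolute maximum exceeds the threshold are treated as
    # outliers and kept in fp16; all remaining columns go through int8.
    col_absmax = A.abs().amax(dim=(0, 1))
    outlier_cols = (col_absmax > threshold).nonzero().flatten()
    print(outlier_cols.numel(), "columns would stay in fp16")
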
@@ -2050,7 +2066,6 @@ def test_fp8_quant():
         p_bits = 7-e_bits
         code = F.create_fp8_map(True, e_bits, p_bits).cuda()

-        print(e_bits, p_bits)
         abserr = []
         relerr = []
         for i in range(100):
@@ -2189,7 +2204,6 @@ def test_bench_dequantization():
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(100):
-        #F.dequantize_blockwise(qa, SA, blocksize=2048)
         qa, SA = F.quantize_blockwise(a)
     torch.cuda.synchronize()
     #print((time.time()-t0)/1e6)
@@ -2240,7 +2254,7 @@ def test_bench_fp4_dequant():
     num_bytes = input_size+output_size
     GB = num_bytes/1e9
     max_theoretical_s = GB/768
-    print(max_theoretical_s*1e6)
+    #print(max_theoretical_s*1e6)
     b = torch.randn(128, 1024*12, device='cuda').half()

     iters = 5
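`max_theoretical_s` is a roofline estimate: the bytes that must cross memory, divided by an assumed ~768 GB/s of bandwidth (an A100-class figure), give the fastest possible dequantization time, and the *1e6 converts it to microseconds. Worked out for an illustrative 1024x1024 fp4 tensor:

    input_size = 1024 * 1024 // 2   # packed fp4: half a byte per element
    output_size = 1024 * 1024 * 2   # fp16 output: two bytes per element
    num_bytes = input_size + output_size
    GB = num_bytes / 1e9
    max_theoretical_s = GB / 768    # assumed ~768 GB/s memory bandwidth
    print(max_theoretical_s * 1e6)  # ~3.4 microseconds at the roofline
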
@@ -2250,14 +2264,14 @@ def test_bench_fp4_dequant():
         F.dequantize_fp4(qa, SA, blocksize=blocksize)
         #b.copy_(a)
     torch.cuda.synchronize()
-    print((time.time()-t0)/iters*1e6)
+    #print((time.time()-t0)/iters*1e6)

-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        torch.matmul(b, a.t())
-    torch.cuda.synchronize()
-    print((time.time()-t0)/iters*1e6)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    torch.matmul(b, a.t())
+    #torch.cuda.synchronize()
+    #print((time.time()-t0)/iters*1e6)