From cfe4705e321d884bae48ce785f29d4a0aff5518b Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Sat, 4 Feb 2023 22:00:04 -0800
Subject: [PATCH] Added matmul_fp4 to the benchmark.

---
 bitsandbytes/autograd/_functions.py |  5 +-
 bitsandbytes/functional.py          |  5 +-
 tests/test_autograd.py              |  6 +--
 tests/test_functional.py            | 84 +++++++++++++++++------------
 4 files changed, 56 insertions(+), 44 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 29c0b93..01d1eb2 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -495,7 +495,7 @@ class MatMulFP4(torch.autograd.Function):
 
         # 1. Dequantize
-        # 2. Matmul
+        # 2. MatmulnN
         output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias)
 
         # 3. Save state
@@ -550,5 +550,6 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)
 
 
-def matmul_fp4(A: tensor, B: tensor, out: tensor = None, quant_state: List = None, bias=None):
+def matmul_fp4(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
+    assert quant_state is not None
     return MatMulFP4.apply(A, B, out, bias, quant_state)
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 92ac670..b38ba1d 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -169,7 +169,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     lst = list(itertools.product([0, 1], repeat=precision_bits))
     #for ev in evalues:
     bias = 2**(exponent_bits-1)+1
-    print(bias)
     for evalue in range(2**(exponent_bits)):
         for bit_pattern in lst:
             value = (1 if evalue != 0 else 0)
@@ -180,9 +179,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
                 value = value*2**-(bias)
             else:
                 # normals
-                print(value, 1)
                 value = value*2**-(evalue-bias-1)
-                print(value, 2)
             values.append(value)
             if signed:
                 values.append(-value)
@@ -196,7 +193,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
             values.append(0)
     values.sort()
     code = torch.Tensor(values)
-    #code /= code.max()
+    code /= code.max()
 
     return code
 
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index ccbcc87..a8b9207 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -485,10 +485,10 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
 
         if not transpose[0] and transpose[1]:
             out_torch = funcs[0](A, B.t())
-            out_bnb = funcs[1](A, B2, quant_state=quant_state, bias=bias2)
+            out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
         elif not transpose[0] and not transpose[1]:
             out_torch = funcs[0](A, B)
-            out_bnb = funcs[1](A, B2.t(), quant_state=quant_state, bias=bias2)
+            out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
 
         if has_bias:
             out_torch += bias
@@ -498,7 +498,7 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
         n = out_bnb.numel()
         err = torch.abs(out_bnb - out_torch).float().mean().item()
         if n > 0:
-            assert err < 0.11
+            assert err < 0.115
 
         if any(req_grad):
             out_bnb.data.copy_(out_torch)
diff --git a/tests/test_functional.py b/tests/test_functional.py
index e6b7b81..49022dc 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1788,18 +1788,14 @@ batch_size = 1
 seqdim = 1
 values = []
 values.append((batch_size, seqdim, 768, 4 * 768))
-# values.append((batch_size, seqdim, 1024, 4*1024))
-# values.append((batch_size, seqdim, 1536, 4*1536))
-# values.append((batch_size, seqdim, 2048, 4*2048))
-# values.append((batch_size, seqdim, 2560, 4*2560))
-# values.append((batch_size, seqdim, 4096, 4*4096))
-# values.append((batch_size, seqdim, 5140, 4*5140))
+#values.append((batch_size, seqdim, 1024, 4*1024))
+#values.append((batch_size, seqdim, 1536, 4*1536))
+#values.append((batch_size, seqdim, 2048, 4*2048))
+#values.append((batch_size, seqdim, 2560, 4*2560))
+#values.append((batch_size, seqdim, 4096, 4*4096))
+#values.append((batch_size, seqdim, 5140, 4*5140))
 #values.append((batch_size, seqdim, 12288, 4*12288))
-names = [
-    "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
-]
-
-
+names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
 def test_bench_matmul(batch, seq, model, hidden):
     iters = 128
@@ -1809,17 +1805,20 @@ def test_bench_matmul(batch, seq, model, hidden):
     B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
     torch.nn.init.xavier_uniform_(B)
 
+    B_fp4, state = F.quantize_fp4(B)
+
     linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
     linear8bit.eval()
 
     outliers = torch.randint(0, model, size=(5,)).cuda()
     A[:, :, outliers] = 8.0
 
-    linearMixedBit = (
-        bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
-    )
+    linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half())
     linearMixedBit.eval()
 
+    linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
+    linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
+
     # warmup
     for i in range(iters):
         torch.matmul(A, B.t())
@@ -1831,9 +1830,14 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         torch.matmul(A, B.t())
     torch.cuda.synchronize()
-    print(
-        f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
+
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        bnb.matmul_fp4(A, B_fp4, quant_state=state)
+    torch.cuda.synchronize()
+    print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
 
     torch.cuda.synchronize()
     t0 = time.time()
@@ -1872,7 +1876,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
         F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
     torch.cuda.synchronize()
-    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
     BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
     CxB, SB = F.nvidia_transform(CB, to_order=formatB)
@@ -1886,7 +1890,7 @@ def test_bench_matmul(batch, seq, model, hidden):
         Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
         out = Cout * statsB * statsA * (1.0 / (127 * 127))
     torch.cuda.synchronize()
-    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
     linear8bit(A)
     torch.cuda.synchronize()
@@ -1894,9 +1898,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         linear8bit(A)
     torch.cuda.synchronize()
-    print(
-        f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
     linearMixedBit(A)
     torch.cuda.synchronize()
@@ -1904,9 +1906,23 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         linearMixedBit(A)
     torch.cuda.synchronize()
-    print(
-        f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
-    )
+    print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
+    linear8bit_train(A)
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        linear8bit_train(A)
+    torch.cuda.synchronize()
+    print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
+    linear8bit_train_thresh(A)
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        linear8bit_train_thresh(A)
+    torch.cuda.synchronize()
+    print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
 
 def test_zeropoint():
     def quant_zp(x):
@@ -2050,7 +2066,6 @@ def test_fp8_quant():
         p_bits = 7-e_bits
         code = F.create_fp8_map(True, e_bits, p_bits).cuda()
 
-        print(e_bits, p_bits)
         abserr = []
         relerr = []
         for i in range(100):
@@ -2189,7 +2204,6 @@ def test_bench_dequantization():
     torch.cuda.synchronize()
     t0 = time.time()
    for i in range(100):
-        #F.dequantize_blockwise(qa, SA, blocksize=2048)
         qa, SA = F.quantize_blockwise(a)
     torch.cuda.synchronize()
     #print((time.time()-t0)/1e6)
@@ -2240,7 +2254,7 @@ def test_bench_fp4_dequant():
     num_bytes = input_size+output_size
     GB = num_bytes/1e9
     max_theoretical_s = GB/768
-    print(max_theoretical_s*1e6)
+    #print(max_theoretical_s*1e6)
     b = torch.randn(128, 1024*12, device='cuda').half()
 
     iters = 5
@@ -2250,14 +2264,14 @@ def test_bench_fp4_dequant():
         F.dequantize_fp4(qa, SA, blocksize=blocksize)
         #b.copy_(a)
     torch.cuda.synchronize()
-    print((time.time()-t0)/iters*1e6)
+    #print((time.time()-t0)/iters*1e6)
 
-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        torch.matmul(b, a.t())
-    torch.cuda.synchronize()
-    print((time.time()-t0)/iters*1e6)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    torch.matmul(b, a.t())
+    #torch.cuda.synchronize()
+    #print((time.time()-t0)/iters*1e6)
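
For reference, the hunks above make quant_state a required argument of matmul_fp4 and add an FP4 path to test_bench_matmul. The sketch below shows the same call pattern outside the test harness; it is a minimal sketch that assumes a CUDA device and the bitsandbytes build from this commit, uses only functions that appear in this patch (F.quantize_fp4, bnb.matmul_fp4, F.dequantize_fp4), and the shapes, which mirror the 768/4*768 benchmark case, are illustrative only.

    # Minimal usage sketch (assumes CUDA and this commit of bitsandbytes).
    import torch
    import bitsandbytes as bnb
    from bitsandbytes import functional as F

    batch, seq, model, hidden = 1, 1, 768, 4 * 768

    # fp16 activations and weight, as in test_bench_matmul
    A = torch.randn(batch, seq, model, device="cuda", dtype=torch.float16)
    B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
    torch.nn.init.xavier_uniform_(B)

    # Quantize the weight to FP4; `state` carries the metadata needed to dequantize.
    B_fp4, state = F.quantize_fp4(B)

    # quant_state is now mandatory (the patch asserts it is not None); it can be
    # passed positionally, as in the updated tests, or by keyword, as in the benchmark.
    out = bnb.matmul_fp4(A, B_fp4, quant_state=state)  # -> [batch, seq, hidden]

    # Reference result via explicit dequantization, mirroring MatMulFP4.forward.
    ref = torch.nn.functional.linear(A, F.dequantize_fp4(B_fp4, state).to(A.dtype))
    print((out - ref).abs().float().mean().item())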