From c361f84239d52844ddae724e40c2c9a5d49284d5 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Sun, 5 Feb 2023 06:16:56 -0800
Subject: [PATCH] Fixed matmul_fp4 transpose.

---
 bitsandbytes/autograd/_functions.py | 4 ++--
 tests/test_autograd.py              | 4 ++--
 tests/test_functional.py            | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 01d1eb2..6db90f5 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -496,7 +496,7 @@ class MatMulFP4(torch.autograd.Function):
 
         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias)
+        output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias)
 
         # 3. Save state
         ctx.state = state
@@ -531,7 +531,7 @@ class MatMulFP4(torch.autograd.Function):
 
         # not supported by PyTorch. TODO: create work-around
         #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
-        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A))
+        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A).t())
 
         return grad_A, grad_B, None, grad_bias, None
 
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index a8b9207..436c6b1 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -485,10 +485,10 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
 
         if not transpose[0] and transpose[1]:
             out_torch = funcs[0](A, B.t())
-            out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
+            out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
         elif not transpose[0] and not transpose[1]:
             out_torch = funcs[0](A, B)
-            out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
+            out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
 
         if has_bias:
             out_torch += bias
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 49022dc..23b7558 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1835,7 +1835,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.cuda.synchronize()
     t0 = time.time()
    for i in range(iters):
-        bnb.matmul_fp4(A, B_fp4, quant_state=state)
+        bnb.matmul_fp4(A, B_fp4.t(), quant_state=state)
     torch.cuda.synchronize()
     print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
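
With this change, callers pass the FP4-quantized weight in transposed form and MatMulFP4 transposes the dequantized matrix back before the matmul. A minimal usage sketch of the resulting call pattern, mirroring the fixed tests/test_functional.py above; the shapes, sizes, and variable names here are illustrative assumptions rather than values taken from the repository:

    import torch
    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    # Assumed sizes, analogous to the (batch, seq, model, hidden) parameters in the benchmark.
    batch, seq, model, hidden = 1, 128, 1024, 4096
    A = torch.randn(batch, seq, model, dtype=torch.float16, device="cuda")  # activations
    B = torch.randn(hidden, model, dtype=torch.float16, device="cuda")      # linear weight (out_features, in_features)

    # Quantize the weight once to packed FP4 storage plus its quantization state.
    B_fp4, state = F.quantize_fp4(B)

    # Pass the transposed FP4 tensor, as in the updated tests and benchmark.
    out = bnb.matmul_fp4(A, B_fp4.t(), quant_state=state)
    # out should approximate torch.nn.functional.linear(A, B), up to FP4
    # quantization error, with shape (batch, seq, hidden).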