Fixed matmul_fp4 transpose.

commit c361f84239
parent cfe4705e32
Author: Tim Dettmers
Date:   2023-02-05 06:16:56 -08:00

3 changed files with 5 additions and 5 deletions

@@ -496,7 +496,7 @@ class MatMulFP4(torch.autograd.Function):
         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias)
+        output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias)

         # 3. Save state
         ctx.state = state
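The added `.t()` matters because `torch.nn.functional.linear(input, weight, bias)` computes `input @ weight.T + bias`. Transposing the dequantized weight therefore makes the forward pass evaluate `A @ dequantize_fp4(B, state) + bias`, matching a weight that callers now hand over pre-transposed. A minimal sketch of the identity, with plain tensors standing in for the dequantized weight (shapes are illustrative, not from the commit):

```python
import torch

A = torch.randn(4, 8)    # input: (batch, in_features)
W = torch.randn(8, 16)   # stand-in for the dequantized weight as it reaches forward()
bias = torch.randn(16)

# linear(A, M, bias) computes A @ M.T + bias, so passing W.t() yields A @ W + bias
out = torch.nn.functional.linear(A, W.t(), bias)
assert torch.allclose(out, A @ W + bias, atol=1e-6)
```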
@@ -531,7 +531,7 @@ class MatMulFP4(torch.autograd.Function):
         # not supported by PyTorch. TODO: create work-around
         #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
-        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A))
+        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A).t())

         return grad_A, grad_B, None, grad_bias, None
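The backward edit keeps the gradient consistent with the new forward: if `out = A @ W` with `W = dequantize_fp4(B, state)`, the chain rule gives `grad_A = grad_output @ W.T`, hence the matching `.t()` on the dequantized weight. A quick check against PyTorch autograd (illustrative shapes, plain tensors in place of FP4 storage):

```python
import torch

A = torch.randn(4, 8, requires_grad=True)
W = torch.randn(8, 16)

out = A @ W
out.backward(torch.ones_like(out))

# Manual gradient, as in the patched backward: grad_A = grad_output @ W.T
grad_output = torch.ones(4, 16)
assert torch.allclose(A.grad, grad_output @ W.t(), atol=1e-6)
```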

@@ -485,10 +485,10 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
         if not transpose[0] and transpose[1]:
             out_torch = funcs[0](A, B.t())
-            out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
+            out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
         elif not transpose[0] and not transpose[1]:
             out_torch = funcs[0](A, B)
-            out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
+            out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
         if has_bias:
             out_torch += bias
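With the transpose now applied inside `MatMulFP4.forward`, the test passes the quantized weight to `bnb.matmul_fp4` in exactly the layout the dense reference uses: `B2.t()` when the reference computes `A @ B.t()`, plain `B2` otherwise. A hedged sketch of the resulting call pattern, assuming `F.quantize_fp4` returns the packed FP4 tensor plus its quantization state (shapes and device are illustrative, not from the commit):

```python
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(4, 32, dtype=torch.float16, device="cuda")
B = torch.randn(64, 32, dtype=torch.float16, device="cuda")  # (out_features, in_features)

B_fp4, quant_state = F.quantize_fp4(B)  # assumed API: packed FP4 storage + state

# Dense reference and FP4 call now take the weight in the same layout
out_torch = torch.matmul(A, B.t())
out_bnb = bnb.matmul_fp4(A, B_fp4.t(), quant_state=quant_state)
```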

@@ -1835,7 +1835,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
-        bnb.matmul_fp4(A, B_fp4, quant_state=state)
+        bnb.matmul_fp4(A, B_fp4.t(), quant_state=state)
     torch.cuda.synchronize()
     print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
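The benchmark brackets the loop with `torch.cuda.synchronize()` because CUDA kernels launch asynchronously: the first call drains pending work before the clock starts, and the second ensures every queued `matmul_fp4` has finished before the clock stops. The same harness, factored out (the helper name and warmup step are illustrative, not part of the commit):

```python
import time
import torch

def bench(fn, iters=100):
    # Warm up once so one-time costs (kernel loading, caching) don't skew the timing
    fn()
    torch.cuda.synchronize()   # drain queued work before starting the clock
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()   # wait for all launched kernels before stopping the clock
    return (time.time() - t0) / iters
```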