Fixed matmul_fp4 transpose.
This commit is contained in:
parent
cfe4705e32
commit
c361f84239
|
@ -496,7 +496,7 @@ class MatMulFP4(torch.autograd.Function):
|
|||
|
||||
# 1. Dequantize
|
||||
# 2. MatmulnN
|
||||
output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias)
|
||||
output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias)
|
||||
|
||||
# 3. Save state
|
||||
ctx.state = state
|
||||
|
@ -531,7 +531,7 @@ class MatMulFP4(torch.autograd.Function):
|
|||
|
||||
# not supported by PyTorch. TODO: create work-around
|
||||
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
|
||||
if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A))
|
||||
if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A).t())
|
||||
|
||||
return grad_A, grad_B, None, grad_bias, None
|
||||
|
||||
|
|
|
@ -485,10 +485,10 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
|
|||
|
||||
if not transpose[0] and transpose[1]:
|
||||
out_torch = funcs[0](A, B.t())
|
||||
out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
|
||||
out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
|
||||
elif not transpose[0] and not transpose[1]:
|
||||
out_torch = funcs[0](A, B)
|
||||
out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
|
||||
out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
|
||||
|
||||
if has_bias:
|
||||
out_torch += bias
|
||||
|
|
|
@ -1835,7 +1835,7 @@ def test_bench_matmul(batch, seq, model, hidden):
|
|||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(iters):
|
||||
bnb.matmul_fp4(A, B_fp4, quant_state=state)
|
||||
bnb.matmul_fp4(A, B_fp4.t(), quant_state=state)
|
||||
torch.cuda.synchronize()
|
||||
print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user