Fixed matmul_fp4 transpose.
This commit is contained in:
parent
cfe4705e32
commit
c361f84239
|
@ -496,7 +496,7 @@ class MatMulFP4(torch.autograd.Function):
|
||||||
|
|
||||||
# 1. Dequantize
|
# 1. Dequantize
|
||||||
# 2. MatmulnN
|
# 2. MatmulnN
|
||||||
output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias)
|
output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias)
|
||||||
|
|
||||||
# 3. Save state
|
# 3. Save state
|
||||||
ctx.state = state
|
ctx.state = state
|
||||||
|
@ -531,7 +531,7 @@ class MatMulFP4(torch.autograd.Function):
|
||||||
|
|
||||||
# not supported by PyTorch. TODO: create work-around
|
# not supported by PyTorch. TODO: create work-around
|
||||||
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
|
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
|
||||||
if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A))
|
if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A).t())
|
||||||
|
|
||||||
return grad_A, grad_B, None, grad_bias, None
|
return grad_A, grad_B, None, grad_bias, None
|
||||||
|
|
||||||
|
|
|
@ -485,10 +485,10 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
|
||||||
|
|
||||||
if not transpose[0] and transpose[1]:
|
if not transpose[0] and transpose[1]:
|
||||||
out_torch = funcs[0](A, B.t())
|
out_torch = funcs[0](A, B.t())
|
||||||
out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
|
out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
|
||||||
elif not transpose[0] and not transpose[1]:
|
elif not transpose[0] and not transpose[1]:
|
||||||
out_torch = funcs[0](A, B)
|
out_torch = funcs[0](A, B)
|
||||||
out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
|
out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
|
||||||
|
|
||||||
if has_bias:
|
if has_bias:
|
||||||
out_torch += bias
|
out_torch += bias
|
||||||
|
|
|
@ -1835,7 +1835,7 @@ def test_bench_matmul(batch, seq, model, hidden):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
for i in range(iters):
|
for i in range(iters):
|
||||||
bnb.matmul_fp4(A, B_fp4, quant_state=state)
|
bnb.matmul_fp4(A, B_fp4.t(), quant_state=state)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
|
print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user