Fixed FP4 import and data type conversion in backward.
This commit is contained in:
parent 7f0773aede
commit c93a90d075
@@ -525,13 +525,9 @@ class MatMulFP4(torch.autograd.Function):
         # compute grad_bias first before changing grad_output dtype
         grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
 
-        # Cast grad_output to fp16
-        if len(grad_output.shape) == 3:
-            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
-
         # not supported by PyTorch. TODO: create work-around
         #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
-        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(ctx.dtype_A).t())
+        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(grad_output.dtype).t())
 
         return grad_A, grad_B, None, grad_bias, None
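For context on the changed line: instead of casting grad_output to fp16 (and reshaping 3D inputs), the backward pass now casts the dequantized FP4 weight to grad_output's own dtype, so bf16 or fp32 gradients pass through unchanged. A minimal sketch of that step, with illustrative shapes and variable names rather than the repo's actual code (a CUDA device is required):

import torch
import bitsandbytes.functional as F

# Hypothetical weight; quantize_fp4 returns the packed tensor plus its quantization state.
W = torch.randn(32, 64, dtype=torch.float16, device="cuda")
B, state = F.quantize_fp4(W)

# Incoming gradient in bfloat16; no explicit cast of grad_output is needed anymore.
grad_output = torch.randn(8, 64, dtype=torch.bfloat16, device="cuda")

# The dequantized weight is cast to grad_output's dtype before the matmul.
grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, state).to(grad_output.dtype).t())
print(grad_A.shape, grad_A.dtype)  # torch.Size([8, 32]) torch.bfloat16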
@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding, LinearFP4
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, LinearFP4, FP4Params
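With FP4Params added to the bitsandbytes.nn exports, the quantized parameter class can be imported directly alongside LinearFP4. A short usage sketch, assuming this branch's constructor signatures (layer sizes are arbitrary and a CUDA device is required):

import torch
from bitsandbytes.nn import LinearFP4, FP4Params

# LinearFP4 stores its weight as an FP4Params instance; quantization happens on .cuda().
layer = LinearFP4(128, 256, bias=True).cuda()
print(isinstance(layer.weight, FP4Params))  # True

x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
y = layer(x)  # forward dequantizes the FP4 weight on the fly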