forked from mrq/bitsandbytes-rocm
req_gradA for casted inputs & more efficient and accurate fp16 backward
This commit is contained in:
parent b3fee1ed6a
commit 8d34d36f15
@@ -213,10 +213,6 @@ class MatMul8bitLt(torch.autograd.Function):
         else:
             return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)

-        # Cast A to fp16
-        A_dtype = A.dtype
-        A = A.to(torch.float16)
-
         # 1. Quantize A
         # 2. Quantize B
         # 3. Matmul
@@ -229,6 +225,11 @@ class MatMul8bitLt(torch.autograd.Function):
         input_shape = A.shape
         if state.outlier_pool is None:
             state.outlier_pool = GlobalOutlierPooler.get_instance()
+
+        # Cast A to fp16
+        A_dtype = A.dtype
+        A = A.to(torch.float16)
+
         assert (
             A.dtype == torch.float16
         ), f"The input data type needs to be fp16 but {A.dtype} was found!"
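
The two hunks above are a relocation: the fp16 cast of A now happens after the outlier pool is initialized and immediately before the dtype assert, so a float32 or bfloat16 input is converted instead of tripping the assertion, and A_dtype records the caller's dtype (presumably so the result can be cast back downstream). A minimal sketch of this save-dtype/cast/restore pattern, with a plain matmul standing in for the quantized path:

import torch

def forward_with_cast(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # Remember the caller's dtype so the result can be cast back.
    A_dtype = A.dtype
    A = A.to(torch.float16)  # Tensor.to returns a new tensor; A must be rebound
    assert A.dtype == torch.float16, f"The input data type needs to be fp16 but {A.dtype} was found!"
    out = A @ B.to(torch.float16)  # stand-in for the int8 quantize/matmul steps
    return out.to(A_dtype)  # hand the caller back their original dtype

A = torch.randn(4, 8)  # float32 input survives the fp16 path
print(forward_with_cast(A, torch.randn(8, 3)).dtype)  # torch.float32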
@@ -337,14 +338,14 @@ class MatMul8bitLt(torch.autograd.Function):
             bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

-        # Cast grad_output to fp16
-        grad_output_dtype = grad_output.dtype
-        grad_output.to(torch.float16)
-
         req_gradA, req_gradB, req_gradBias = ctx.req_grads
         assert not req_gradB, "TODO: support weight updates as well"
         state = ctx.state

+        # Cast grad_output to fp16
+        grad_output_dtype = grad_output.dtype
+        grad_output = grad_output.to(torch.float16)
+
         if len(grad_output.shape) == 3:
             grad_output = grad_output.reshape(
                 -1, grad_output.shape[-1]
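
This hunk is more than a relocation: the removed line called grad_output.to(torch.float16) without binding the result, and since Tensor.to is not in-place, the cast was silently discarded. The replacement rebinds grad_output. A short demonstration of the pitfall:

import torch

g = torch.randn(2, 2)    # float32
g.to(torch.float16)      # returns a new tensor; g itself is unchanged
print(g.dtype)           # torch.float32 -- the old code's silent no-op
g = g.to(torch.float16)  # rebinding, as the new code does
print(g.dtype)           # torch.float16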
@@ -354,9 +355,9 @@ class MatMul8bitLt(torch.autograd.Function):

         if req_gradA:
             CB = state.CB.half()
-            SCB = state.SCB.unsqueeze(1).half()
-            B = (CB * SCB) / 127.0
-            grad_A = torch.mm(grad_output, B).view(ctx.grad_shape)
+            SCB = (state.SCB.unsqueeze(1) / 127.0).half()
+            CB *= SCB
+            grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)

         if req_gradBias:
             grad_bias = grad_output.sum(0)
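
The req_gradA rewrite changes the dequantization order: instead of building a separate tensor B = (CB * SCB) / 127.0 entirely in fp16, it divides SCB by 127 in its original precision, casts once, and scales CB in place. That saves an allocation and avoids the old intermediate CB * SCB, which is roughly 127 times larger than the final value, where fp16 loses low-order bits (and can overflow for large scales). A rough sketch with synthetic stand-ins for state.CB and state.SCB (shapes and magnitudes here are illustrative assumptions, not the library's):

import torch

CB_int8 = torch.randint(-127, 128, (64, 32), dtype=torch.int8)  # quantized weights
SCB = torch.rand(64, dtype=torch.float32) * 100.0               # per-row scales

# Old ordering: multiply in fp16 first, divide by 127 afterwards.
B_old = (CB_int8.half() * SCB.unsqueeze(1).half()) / 127.0

# New ordering: divide in fp32, cast once, scale in place -- no extra tensor.
CB_new = CB_int8.half()
CB_new *= (SCB.unsqueeze(1) / 127.0).half()

# Compare both against an fp32 reference dequantization.
ref = CB_int8.float() * SCB.unsqueeze(1) / 127.0
print((B_old.float() - ref).abs().max(), (CB_new.float() - ref).abs().max())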