forked from mrq/bitsandbytes-rocm
add dtype <-> fp16 cast
This commit is contained in:
parent 4d6174bc63
commit b3fee1ed6a
@@ -213,6 +213,10 @@ class MatMul8bitLt(torch.autograd.Function):
             else:
                 return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
 
+        # Cast A to fp16
+        A_dtype = A.dtype
+        A = A.to(torch.float16)
+
         # 1. Quantize A
         # 2. Quantize B
         # 3. Matmul
@@ -322,14 +326,21 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)
 
+        # Cast fp16 output back to A.dtype
+        output = output.to(A_dtype)
+
         clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
         #clone_func = torch.clone
         return clone_func(output.view(output_shape))
 
     def backward(ctx, grad_output):
         if ctx.is_empty:
             bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
 
+        # Cast grad_output to fp16 (Tensor.to is not in-place, so assign the result)
+        grad_output_dtype = grad_output.dtype
+        grad_output = grad_output.to(torch.float16)
+
         req_gradA, req_gradB, req_gradBias = ctx.req_grads
         assert not req_gradB, "TODO: support weight updates as well"
         state = ctx.state
@@ -350,6 +361,9 @@ class MatMul8bitLt(torch.autograd.Function):
         if req_gradBias:
             grad_bias = grad_output.sum(0)
 
+        # Cast grad_A back to grad_output_dtype before returning it
+        grad_A = grad_A.to(grad_output_dtype)
+
         return grad_A, grad_B, None, grad_bias, None