forked from mrq/bitsandbytes-rocm
add dtype <-> fp16 cast
This commit is contained in:
parent
4d6174bc63
commit
b3fee1ed6a
|
@ -213,6 +213,10 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
else:
|
||||
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
|
||||
|
||||
# Cast A to fp16
|
||||
A_dtype = A.dtype
|
||||
A = A.to(torch.float16)
|
||||
|
||||
# 1. Quantize A
|
||||
# 2. Quantize B
|
||||
# 3. Matmul
|
||||
|
@ -322,14 +326,21 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
ctx.tensor_states = (None, None)
|
||||
ctx.save_for_backward(None, None)
|
||||
|
||||
# Cast fp16 output back to A.dtype
|
||||
output = output.to(A_dtype)
|
||||
|
||||
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
|
||||
#clone_func = torch.clone
|
||||
return clone_func(output.view(output_shape))
|
||||
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.is_empty:
|
||||
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
|
||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
|
||||
|
||||
# Cast grad_output to fp16
|
||||
grad_output_dtype = grad_output.dtype
|
||||
grad_output.to(torch.float16)
|
||||
|
||||
req_gradA, req_gradB, req_gradBias = ctx.req_grads
|
||||
assert not req_gradB, "TODO: support weight updates as well"
|
||||
state = ctx.state
|
||||
|
@ -350,6 +361,9 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
if req_gradBias:
|
||||
grad_bias = grad_output.sum(0)
|
||||
|
||||
# Cast grad_A back to grad_output_dtype
|
||||
grad_output.to(grad_output_dtype)
|
||||
|
||||
return grad_A, grad_B, None, grad_bias, None
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user