add dtype <-> fp16 cast

dbaranchuk 2022-08-26 04:11:40 +03:00
parent 4d6174bc63
commit b3fee1ed6a


@@ -213,6 +213,10 @@ class MatMul8bitLt(torch.autograd.Function):
        else:
            return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)

        # Cast A to fp16
        A_dtype = A.dtype
        A = A.to(torch.float16)

        # 1. Quantize A
        # 2. Quantize B
        # 3. Matmul
@@ -322,14 +326,21 @@ class MatMul8bitLt(torch.autograd.Function):
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

        # Cast fp16 output back to A.dtype
        output = output.to(A_dtype)

        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
        #clone_func = torch.clone
        return clone_func(output.view(output_shape))

    def backward(ctx, grad_output):
        if ctx.is_empty:
            bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
        # Cast grad_output to fp16 (Tensor.to is not in-place, so assign the result back)
        grad_output_dtype = grad_output.dtype
        grad_output = grad_output.to(torch.float16)
        req_gradA, req_gradB, req_gradBias = ctx.req_grads
        assert not req_gradB, "TODO: support weight updates as well"
        state = ctx.state
@@ -350,6 +361,9 @@ class MatMul8bitLt(torch.autograd.Function):
        if req_gradBias:
            grad_bias = grad_output.sum(0)

        # Cast grad_A back to grad_output_dtype before returning it
        grad_A = grad_A.to(grad_output_dtype)

        return grad_A, grad_B, None, grad_bias, None
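
For reference, below is a minimal, self-contained sketch of the dtype <-> fp16 round-trip pattern this commit introduces: the input and the incoming gradient are cast to float16 for the compute, and the result is cast back to the caller's dtype before it is returned. The Fp16RoundTripMatMul class is an illustrative stand-in, not part of bitsandbytes; a plain fp16 matmul takes the place of the int8 quantize/matmul/dequantize pipeline in MatMul8bitLt.

import torch


class Fp16RoundTripMatMul(torch.autograd.Function):
    """Illustrative sketch only: a plain fp16 matmul stands in for the
    int8 quantize/matmul/dequantize steps of MatMul8bitLt."""

    @staticmethod
    def forward(ctx, A, B):
        A_dtype = A.dtype                  # remember the caller's dtype
        A16 = A.to(torch.float16)          # cast inputs to fp16 for the kernel
        B16 = B.to(torch.float16)
        ctx.save_for_backward(A16, B16)
        output = A16 @ B16                 # compute happens in fp16
        return output.to(A_dtype)          # cast the result back to A.dtype

    @staticmethod
    def backward(ctx, grad_output):
        A16, B16 = ctx.saved_tensors
        grad_output_dtype = grad_output.dtype
        g16 = grad_output.to(torch.float16)              # incoming grad -> fp16
        grad_A = (g16 @ B16.t()).to(grad_output_dtype)   # grads go back to the
        grad_B = (A16.t() @ g16).to(grad_output_dtype)   # original dtype
        return grad_A, grad_B


# Hypothetical usage (assumes a CUDA device with fp16 matmul support):
# x = torch.randn(4, 8, device="cuda", requires_grad=True)   # fp32 input
# w = torch.randn(8, 16, device="cuda", requires_grad=True)
# y = Fp16RoundTripMatMul.apply(x, w)    # y.dtype == torch.float32
# y.sum().backward()                     # x.grad.dtype == torch.float32

The only bookkeeping the pattern needs is saving the original dtypes (A_dtype and grad_output_dtype in the commit) and assigning the result of .to(...) back to a variable, since Tensor.to is not an in-place operation.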