Added cast to fp4 layer for speed.

Tim Dettmers 2023-02-24 10:17:57 -08:00
parent c93a90d075
commit 9851a10b46
2 changed files with 9 additions and 4 deletions

View File

@@ -404,10 +404,10 @@ class MatMul8bitLt(torch.autograd.Function):
         ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
         if any(ctx.needs_input_grad[:2]):
-            ctx.tensors = (CAt, subA)
+            ctx.tensors = (CAt, subA, A)
             ctx.tensor_states = (SCAt, state.idx)
         else:
-            ctx.tensors = [None, None]
+            ctx.tensors = [None, None, A]
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)
@@ -420,7 +420,7 @@ class MatMul8bitLt(torch.autograd.Function):
             bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
         req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
-        CAt, subA = ctx.tensors
+        CAt, subA, A = ctx.tensors
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
@@ -436,6 +436,7 @@ class MatMul8bitLt(torch.autograd.Function):
         Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
         if req_gradB:
+            #grad_B = torch.matmul(grad_output.t(), A)
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
             C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
             gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
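
Note on the hunks above: stashing A in ctx.tensors makes the input activations available in backward(), where the commented-out hint #grad_B = torch.matmul(grad_output.t(), A) spells out the weight gradient that could be computed from them directly. Below is a minimal sketch of that identity in a custom autograd function; the class and variable names are illustrative only, not bitsandbytes code, and no int8/fp4 quantization is involved.

import torch

class MatMulSketch(torch.autograd.Function):
    # toy version of out = A @ B.t() that saves A for the weight gradient
    @staticmethod
    def forward(ctx, A, B):
        ctx.save_for_backward(A, B)
        return A @ B.t()

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        grad_A = grad_output @ B          # gradient w.r.t. the input
        grad_B = grad_output.t() @ A      # same identity as the commented-out hint
        return grad_A, grad_B

# quick check: shapes match a plain linear layer with weight B of shape (out, in)
A = torch.randn(4, 8, requires_grad=True)
B = torch.randn(16, 8, requires_grad=True)
MatMulSketch.apply(A, B).sum().backward()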

View File

@@ -190,7 +190,11 @@ class LinearFP4(nn.Linear):
         if getattr(self.weight, 'quant_state', None) is None:
             print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
-        out = bnb.matmul_fp4(x, self.weight.t(), bias=self.bias, quant_state=self.weight.quant_state)
+        inp_dtype = x.dtype
+        x = x.to(torch.float16)
+        out = bnb.matmul_fp4(x, self.weight.t(), bias=self.bias.half(), quant_state=self.weight.quant_state)
+        out = out.to(inp_dtype)
         return out
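
The LinearFP4 change above is a dtype round trip: remember the caller's dtype, run the matmul in fp16, and cast the result back. Below is a minimal sketch of that pattern with torch.nn.functional.linear standing in for bnb.matmul_fp4; the function name and arguments are illustrative, and the fp16 cast only pays off for CUDA tensors with fast half-precision kernels.

import torch

def fp16_roundtrip_linear(x, weight, bias=None):
    inp_dtype = x.dtype                                    # remember the caller's dtype (fp32/bf16/...)
    x = x.to(torch.float16)                                # compute in fp16, as the diff above does
    bias = None if bias is None else bias.half()           # None check just keeps the sketch self-contained
    out = torch.nn.functional.linear(x, weight.to(torch.float16), bias)
    return out.to(inp_dtype)                               # hand the original dtype back to the caller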