recast to fp16

This commit is contained in:
justheuristic 2022-09-17 23:34:22 +03:00
parent fc4a135ed1
commit a9fe0ff98c

View File

@ -275,7 +275,7 @@ class MatMul8bitLt(torch.autograd.Function):
state.SCB,
state.SCBt,
coo_tensorB,
) = F.double_quant(B)
) = F.double_quant(B.to(torch.float16))
state.CxB, state.SB = F.transform(CB, to_order=formatB)
else:
has_grad = False