cast edge case

This commit is contained in:
justheuristic 2022-09-18 00:35:42 +03:00
parent e35e2c665a
commit cbfdf0b5ef

View File

@ -212,9 +212,9 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.B = B
ctx.bias = bias
if A.shape[-1] == B.shape[0]:
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
else:
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
# 1. Quantize A
# 2. Quantize B