cast edge case
This commit is contained in:
parent
e35e2c665a
commit
cbfdf0b5ef
|
@ -212,9 +212,9 @@ class MatMul8bitLt(torch.autograd.Function):
|
||||||
ctx.B = B
|
ctx.B = B
|
||||||
ctx.bias = bias
|
ctx.bias = bias
|
||||||
if A.shape[-1] == B.shape[0]:
|
if A.shape[-1] == B.shape[0]:
|
||||||
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
|
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
|
||||||
else:
|
else:
|
||||||
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
|
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
|
||||||
|
|
||||||
# 1. Quantize A
|
# 1. Quantize A
|
||||||
# 2. Quantize B
|
# 2. Quantize B
|
||||||
|
|
Loading…
Reference in New Issue
Block a user