Fixed a bug where gemv_4bit would return a wrongly sized tensor.

This commit is contained in:
Tim Dettmers 2023-07-09 15:34:02 -07:00
parent 0f0390acb2
commit 6a905be5ce

View File

@ -1475,7 +1475,10 @@ def gemv_4bit(
absmax += offset
if out is None:
out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
if len(A.shape) == 3:
out = torch.zeros(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device)
else:
out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)