Fixed a bug where gemv_4bit would return a wrongly sized tensor.
This commit is contained in:
parent
0f0390acb2
commit
6a905be5ce
|
@ -1475,7 +1475,10 @@ def gemv_4bit(
|
||||||
absmax += offset
|
absmax += offset
|
||||||
|
|
||||||
if out is None:
|
if out is None:
|
||||||
out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
|
if len(A.shape) == 3:
|
||||||
|
out = torch.zeros(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device)
|
||||||
|
else:
|
||||||
|
out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user