Fixed a bug where gemv_4bit would return a wrongly sized tensor.
This commit is contained in:
parent
0f0390acb2
commit
6a905be5ce
|
@ -1475,7 +1475,10 @@ def gemv_4bit(
|
|||
absmax += offset
|
||||
|
||||
if out is None:
|
||||
out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
|
||||
if len(A.shape) == 3:
|
||||
out = torch.zeros(size=(A.shape[0], A.shape[1], bout), dtype=A.dtype, device=A.device)
|
||||
else:
|
||||
out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user