forked from mrq/bitsandbytes-rocm
Fixed prod Python < 3.7 compatibility in function.py.
This commit is contained in:
parent
62441815bc
commit
f9cbe2fe99
|
@ -6,6 +6,7 @@ import bitsandbytes.functional as F
|
|||
from dataclasses import dataclass
|
||||
from functools import reduce # Required in Python 3
|
||||
|
||||
# math.prod not compatible with python < 3.8
|
||||
def prod(iterable):
|
||||
return reduce(operator.mul, iterable, 1)
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import ctypes as ct
|
||||
import operator
|
||||
import random
|
||||
import math
|
||||
import torch
|
||||
|
@ -11,6 +12,11 @@ from typing import Tuple
|
|||
from torch import Tensor
|
||||
|
||||
from .cextension import COMPILED_WITH_CUDA, lib
|
||||
from functools import reduce # Required in Python 3
|
||||
|
||||
# math.prod not compatible with python < 3.8
|
||||
def prod(iterable):
|
||||
return reduce(operator.mul, iterable, 1)
|
||||
|
||||
name2qmap = {}
|
||||
|
||||
|
@ -326,8 +332,8 @@ def nvidia_transform(
|
|||
dim1 = ct.c_int32(shape[0])
|
||||
dim2 = ct.c_int32(shape[1])
|
||||
elif ld is not None:
|
||||
n = math.prod(shape)
|
||||
dim1 = math.prod([shape[i] for i in ld])
|
||||
n = prod(shape)
|
||||
dim1 = prod([shape[i] for i in ld])
|
||||
dim2 = ct.c_int32(n // dim1)
|
||||
dim1 = ct.c_int32(dim1)
|
||||
else:
|
||||
|
@ -1314,7 +1320,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
|
|||
m = shapeA[0] * shapeA[1]
|
||||
|
||||
rows = n = shapeB[0]
|
||||
assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
|
||||
assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
|
||||
|
||||
# if the tensor is empty, return a transformed empty tensor with the right dimensions
|
||||
if shapeA[0] == 0 and dimsA == 2:
|
||||
|
|
Loading…
Reference in New Issue
Block a user