Fixed prod Python < 3.7 compatibility in function.py.

This commit is contained in:
Tim Dettmers 2022-08-08 09:13:22 -07:00
parent 62441815bc
commit f9cbe2fe99
3 changed files with 11 additions and 4 deletions

View File

@ -6,6 +6,7 @@ import bitsandbytes.functional as F
from dataclasses import dataclass from dataclasses import dataclass
from functools import reduce # Required in Python 3 from functools import reduce # Required in Python 3
# math.prod not compatible with python < 3.8
def prod(iterable): def prod(iterable):
return reduce(operator.mul, iterable, 1) return reduce(operator.mul, iterable, 1)

View File

@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import ctypes as ct import ctypes as ct
import operator
import random import random
import math import math
import torch import torch
@ -11,6 +12,11 @@ from typing import Tuple
from torch import Tensor from torch import Tensor
from .cextension import COMPILED_WITH_CUDA, lib from .cextension import COMPILED_WITH_CUDA, lib
from functools import reduce # Required in Python 3
# math.prod not compatible with python < 3.8
def prod(iterable):
return reduce(operator.mul, iterable, 1)
name2qmap = {} name2qmap = {}
@ -326,8 +332,8 @@ def nvidia_transform(
dim1 = ct.c_int32(shape[0]) dim1 = ct.c_int32(shape[0])
dim2 = ct.c_int32(shape[1]) dim2 = ct.c_int32(shape[1])
elif ld is not None: elif ld is not None:
n = math.prod(shape) n = prod(shape)
dim1 = math.prod([shape[i] for i in ld]) dim1 = prod([shape[i] for i in ld])
dim2 = ct.c_int32(n // dim1) dim2 = ct.c_int32(n // dim1)
dim1 = ct.c_int32(dim1) dim1 = ct.c_int32(dim1)
else: else:
@ -1314,7 +1320,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
m = shapeA[0] * shapeA[1] m = shapeA[0] * shapeA[1]
rows = n = shapeB[0] rows = n = shapeB[0]
assert math.prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'
# if the tensor is empty, return a transformed empty tensor with the right dimensions # if the tensor is empty, return a transformed empty tensor with the right dimensions
if shapeA[0] == 0 and dimsA == 2: if shapeA[0] == 0 and dimsA == 2:

View File

@ -18,7 +18,7 @@ def read(fname):
setup( setup(
name=f"bitsandbytes", name=f"bitsandbytes",
version=f"0.31.5", version=f"0.31.7",
author="Tim Dettmers", author="Tim Dettmers",
author_email="dettmers@cs.washington.edu", author_email="dettmers@cs.washington.edu",
description="8-bit optimizers and matrix multiplication routines.", description="8-bit optimizers and matrix multiplication routines.",