Added Int8 matmul support for all GPUs. Full backward support.
parent 92ab6a8d5f · commit de53588934

Makefile
@@ -60,8 +60,8 @@ CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90


 all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
-	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
-	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
+	$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+	$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
 	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)

 cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
bitsandbytes/autograd/__init__.py (new file)

@@ -0,0 +1 @@
+from ._functions import undo_layout, get_inverse_transform_indices
bitsandbytes/autograd/_functions.py

@@ -2,6 +2,7 @@ import operator
 import warnings
 from dataclasses import dataclass
 from functools import reduce  # Required in Python 3
+from typing import Tuple, Optional

 import torch

@@ -14,6 +15,12 @@ def prod(iterable):

 tensor = torch.Tensor

+
+# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
+# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
+
+
+
 """
     This class pools outlier dimensions across layers.
     This is particularly important for small models where outlier features
@@ -48,6 +55,51 @@ class GlobalOutlierPooler:
         return torch.Tensor(list(self.outliers)).to(torch.int64)


+def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]):
+    """
+    Compute a permutation of indices that invert the specified (tiled) matrix transformation
+
+    :param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2]
+    :param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere
+    :note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size
+    :example: transform_tile function for the turing layout (bitsandbytes.functional as F)
+    :returns: indices
+    """
+    d1, d2 = tile_size
+    assert 0 < d1 * d2 < 2**64
+    tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2)
+    # encode each position in tile as a tuple of <= 8 unique bytes
+    permuted_tile_indices = torch.zeros_like(tile_indices)
+    for i in range(8):
+        # select i-th byte, apply transformation and trace where each index ended up
+        ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256
+        sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous()
+        assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow"
+        permuted_tile_i = transform_tile(sample_tile_i)
+        ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128
+        permuted_tile_indices += ith_permuted_indices * (256**i)
+        if d1 * d2 < 256**i:
+            break  # if all indices fit in i bytes, stop early
+    return permuted_tile_indices
+
+def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
+    """
+    Undo a tiled permutation such as turing or ampere layout
+
+    :param permuted_tensor: torch tensor in a permuted layout
+    :param tile_indices: reverse transformation indices, from get_inverse_transform_indices
+    :return: contiguous row-major tensor
+    """
+    (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape
+    assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"
+    tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t()
+    outputs = torch.empty_like(tensor)  # note: not using .index_copy because it was slower on cuda
+    outputs[tile_indices.flatten()] = tensor
+    outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows)
+    outputs = outputs.permute(3, 0, 2, 1)  # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols)
+    return outputs.reshape(rows, cols).contiguous()
+
+
 class MatMul8bit(torch.autograd.Function):
     @staticmethod
     def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
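The two helpers above are the core of the new backward path: get_inverse_transform_indices is run once per layout to learn where each element of an (8, 32) or (32, 32) tile ends up, and undo_layout uses those indices to map a Turing/Ampere-formatted tensor back to row-major order. A minimal round-trip sketch (not part of the diff; it assumes a CUDA GPU and mirrors the new test added at the bottom of this commit, with an arbitrary tensor shape):

    import torch
    from bitsandbytes import functional as F
    from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout

    order, tile_size = "col_turing", (8, 32)   # use "col_ampere" and (32, 32) on Ampere
    transform = lambda t: F.transform(t.cuda(), from_order="row", to_order=order)[0].to(t.device)

    tile_indices = get_inverse_transform_indices(transform, tile_size)  # computed once per layout
    x = (torch.randn(256, 512) * 10).to(torch.int8).cuda()              # whole number of 8x32 tiles
    cxb = transform(x)                                                  # weight in the tiled layout
    assert torch.equal(undo_layout(cxb, tile_indices), x)               # exact round trip

In the backward pass below, the computed indices are cached on MatmulLtState so the inversion is only paid once per layer.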
@@ -171,6 +223,8 @@ matmul_cublas = MatMul8bit.apply

 @dataclass
 class MatmulLtState:
+    tile_indices: Optional[torch.Tensor] = None
+    force_no_igemmlt: bool = False
     CB = None
     CxB = None
     SB = None
@@ -202,11 +256,22 @@ class MatmulLtState:
         self.SBt = None
         self.CBt = None

+    def get_tile_size(self):
+        assert self.formatB in (
+            "col_turing",
+            "col_ampere",
+        ), f"please find this assert and manually enter tile size for {self.formatB}"
+        return (8, 32) if self.formatB == "col_turing" else (32, 32)
+

 class MatMul8bitLt(torch.autograd.Function):
+    # forward is the same, but we added the fallback for pre-turing GPUs
+    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
+
     @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
-        # default to pytorch behavior if inputs are empty
+    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
+        using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt
+        # default of pytorch behavior if inputs are empty
         ctx.is_empty = False
         if prod(A.shape) == 0:
             ctx.is_empty = True
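The new using_igemmlt flag is what routes between the two code paths further down: the cuBLASLt int8 kernel (F.igemmlt) needs a GPU with compute capability 7.5 (Turing) or newer, and force_no_igemmlt lets callers take the fallback even on such GPUs. A small sketch of that dispatch rule (illustration only, not library code; the helper name is made up):

    import torch
    import bitsandbytes as bnb

    def would_use_igemmlt(state, device: torch.device) -> bool:
        # hypothetical helper: mirrors the check at the top of MatMul8bitLt.forward
        if device.type != "cuda":
            return False
        return torch.cuda.get_device_capability(device=device) >= (7, 5) and not state.force_no_igemmlt

    state = bnb.MatmulLtState()
    state.force_no_igemmlt = True   # opt into the dequantize + fp16 matmul path everywhere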
@@ -214,9 +279,9 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.B = B
             ctx.bias = bias
             if A.shape[-1] == B.shape[0]:
-                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
+                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
             else:
-                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
+                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)

         # 1. Quantize A
         # 2. Quantize B
@@ -235,9 +300,7 @@ class MatMul8bitLt(torch.autograd.Function):
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
-        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
-            A.to(torch.float16), threshold=state.threshold
-        )
+        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)

         if state.threshold > 0.0 and coo_tensorA is not None:
             if state.has_fp16_weights:
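For context on step 1: F.double_quant returns row-wise and column-wise int8 quantizations of the fp16 input together with their absmax statistics, plus a sparse COO tensor holding the entries whose magnitude exceeds the threshold (None when nothing crosses it, which is why the code checks for that). A short sketch of the call as used here (assumes a CUDA device; sizes are arbitrary):

    import torch
    from bitsandbytes import functional as F

    A = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
    CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=6.0)
    # CA: int8, quantized row-wise; CAt: int8, quantized column-wise
    # SCA / SCAt: the corresponding absmax scaling statistics
    # coo_tensorA: outlier entries with |value| > 6.0, or None if there are none
    assert CA.dtype == torch.int8 and CA.shape == A.shape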
@@ -248,12 +311,12 @@ class MatMul8bitLt(torch.autograd.Function):
                 state.subB = B[:, idx].t().contiguous()
                 state.idx = idx
             else:
-                if state.CxB is None:
+                if state.CxB is None and using_igemmlt:
                     # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                     # we also need to convert it to the turing/ampere format
                     state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
         else:
-            if not state.has_fp16_weights and state.CxB is None:
+            if not state.has_fp16_weights and state.CxB is None and using_igemmlt:
                 state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
             subA = None

@@ -273,7 +336,10 @@ class MatMul8bitLt(torch.autograd.Function):
                     state.SCBt,
                     coo_tensorB,
                 ) = F.double_quant(B.to(torch.float16))
-                state.CxB, state.SB = F.transform(CB, to_order=formatB)
+                if using_igemmlt:
+                    state.CxB, state.SB = F.transform(CB, to_order=formatB)
+                else:
+                    state.CB = CB
         else:
             has_grad = False

@@ -288,18 +354,17 @@ class MatMul8bitLt(torch.autograd.Function):
             #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
             # else:
             #     state.idx = outlier_idx
-            outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
-            state.subB = (
-                (outliers * state.SCB.view(-1, 1) / 127.0)
-                .t()
-                .contiguous()
-                .to(A.dtype)
-            )
+            if state.CxB is not None:
+                outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
+            else:
+                outliers = state.CB[:, state.idx.long()].clone()
+
+            state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
             CA[:, state.idx.long()] = 0
             CAt[:, state.idx.long()] = 0
             subA = A[:, state.idx.long()]

-        shapeB = state.SB[0]
+        shapeB = state.SB[0] if state.SB else B.shape

         if len(input_shape) == 3:
             output_shape = (input_shape[0], input_shape[1], shapeB[0])
@@ -307,16 +372,25 @@ class MatMul8bitLt(torch.autograd.Function):
             output_shape = (input_shape[0], shapeB[0])

         # 3. Matmul
-        C32A, SA = F.transform(CA, "col32")
-        out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
-        # we apply the fused bias here
+        if using_igemmlt:
+            C32A, SA = F.transform(CA, "col32")
+            out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
+            if bias is None or bias.dtype == torch.float16:
+                # we apply the fused bias here
+                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+                output = output.to(A.dtype)
+            else:  # apply bias separately
+                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+                output = output.to(A.dtype).add_(bias)

-        if bias is None or bias.dtype == torch.float16:
-            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
-            output = output.to(A.dtype)
-        else:  # apply bias separately
-            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
-            output = output.to(A.dtype).add_(bias)
+        else:
+            A_wo_outliers = A.clone()
+            if state.idx is not None:
+                A_wo_outliers[:, state.idx.long()] = 0
+            output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
+            output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
+            if bias is not None:
+                output = output.add_(bias)

         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
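The else: branch above is the new pre-Turing fallback (also taken when force_no_igemmlt is set): instead of the col32/igemmlt kernels, it zeroes the outlier columns of A, multiplies against the int8 weight cast to the activation dtype, and rescales the result by SCB / 127. A standalone sketch of the same arithmetic (illustrative only; the function and argument names are not from the library):

    import torch

    def int8_linear_fallback(A, CB, SCB, bias=None, outlier_idx=None):
        # A: fp16 activations [batch, in]; CB: int8 weight [out, in], row-major
        # SCB: per-output-row absmax scales, so that W ~= CB * SCB.unsqueeze(1) / 127
        A = A.clone()
        if outlier_idx is not None:
            A[:, outlier_idx] = 0             # outliers go through the separate fp16 matmul (step 4)
        out = torch.nn.functional.linear(A, CB.to(A.dtype))
        out = out.mul_(SCB.unsqueeze(0).mul(1.0 / 127.0))
        return out if bias is None else out.add_(bias)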
@@ -337,14 +411,13 @@ class MatMul8bitLt(torch.autograd.Function):
             ctx.tensor_states = (None, None)
             ctx.save_for_backward(None, None)

-
-        clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
+        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
         return clone_func(output.view(output_shape))

     @staticmethod
     def backward(ctx, grad_output):
         if ctx.is_empty:
-            bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
+            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
         req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
         CAt, subA = ctx.tensors
@@ -359,9 +432,7 @@ class MatMul8bitLt(torch.autograd.Function):

         # Cast grad_output to fp16
         if len(grad_output.shape) == 3:
-            grad_output = grad_output.reshape(
-                -1, grad_output.shape[-1]
-            ).contiguous()
+            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

         Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
         if req_gradB:
@@ -376,17 +447,29 @@ class MatMul8bitLt(torch.autograd.Function):
         if req_gradA:
             if state.CBt is not None:
                 C32grad, Sgrad = F.transform(Cgrad, "col32")
                 if state.CxBt is None:
-                    state.CxBt, state.SBt = F.transform(
-                        state.CBt, to_order=formatB, transpose=True
-                    )
+                    state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
                 gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
                 grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)

             elif state.CB is not None:
-                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
+                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                 grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+            elif state.CxB is not None:
+
+                if state.tile_indices is None:
+                    order, tile_size = state.formatB, state.get_tile_size()
+                    transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
+                    with torch.no_grad():
+                        state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device)
+
+                CB = (
+                    undo_layout(state.CxB, state.tile_indices)
+                    .to(ctx.dtype_A)
+                    .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
+                )
+                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
             else:
-                raise Exception('State must contain either CBt or CB matrix for backward')
+                raise Exception("State must contain either CBt or CB or CxB matrix for backward")

         return grad_A, grad_B, None, grad_bias, None
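This is the one extra backward clause mentioned at the top of MatMul8bitLt: when only the Turing/Ampere-format weight CxB is cached, the inverse tile permutation is computed once, stored in state.tile_indices, and used to recover the row-major int8 weight; grad_A then follows from the usual dequantize-and-matmul identity. A small numeric sketch of that identity (stand-in tensors, not library code):

    import torch

    grad_output = torch.randn(4, 3072)                                  # upstream gradient [batch, out]
    W_int8 = torch.randint(-127, 128, (3072, 1024), dtype=torch.int8)   # stands in for undo_layout(state.CxB, ...)
    SCB = torch.rand(3072) * 10                                         # per-output-row absmax scales

    W = W_int8.float() * SCB.unsqueeze(1) / 127.0                       # dequantized weight, i.e. CB in the diff
    grad_A = grad_output @ W                                            # gradient w.r.t. the matmul input
    assert grad_A.shape == (4, 1024)

The sketch uses fp32 so it runs on CPU; the library code stays in ctx.dtype_A (typically fp16) on the GPU.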
bitsandbytes/nn/modules.py

@@ -209,19 +209,10 @@ class Int8Params(torch.nn.Parameter):


 class Linear8bitLt(nn.Linear):
-    def __init__(
-        self,
-        input_features,
-        output_features,
-        bias=True,
-        has_fp16_weights=True,
-        memory_efficient_backward=False,
-        threshold=0.0,
-        index=None,
-    ):
-        super().__init__(
-            input_features, output_features, bias
-        )
+    def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
+                 memory_efficient_backward=False, threshold=0.0, index=None):
+        super().__init__(input_features, output_features, bias)
+        assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
         self.state = bnb.MatmulLtState()
         self.index = index

@@ -231,9 +222,7 @@ class Linear8bitLt(nn.Linear):
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True

-        self.weight = Int8Params(
-            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
-        )
+        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)

     def init_8bit_state(self):
         self.state.CB = self.weight.CB
@@ -241,27 +230,20 @@ class Linear8bitLt(nn.Linear):
         self.weight.CB = None
         self.weight.SCB = None

-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         self.state.is_training = self.training

         if self.weight.CB is not None:
             self.init_8bit_state()

         # weights are cast automatically as Int8Params, but the bias has to be cast manually
-        if self.bias is not None and self.bias.dtype != torch.float16:
-            self.bias.data = self.bias.data.half()
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)

         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

         if not self.state.has_fp16_weights:
-            if not self.state.memory_efficient_backward and self.state.CB is not None:
+            if self.state.CB is not None and self.state.CxB is not None:
                 # we converted 8-bit row major to turing/ampere format in the first inference pass
                 # we no longer need the row-major weight
                 del self.state.CB
                 self.weight.data = self.state.CxB
-            elif self.state.memory_efficient_backward and self.state.CxB is not None:
-                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
-                # Thus, we delete CxB from the state.
-                del self.state.CxB

         return out

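With these changes, Linear8bitLt can be dropped in on any CUDA GPU, and gradients now flow to the inputs even when the weights are frozen in int8. A minimal usage sketch, mirroring the new test below (assumes a CUDA device; layer sizes are arbitrary):

    import torch
    import bitsandbytes as bnb

    fp16_linear = torch.nn.Linear(1024, 3072).half()
    int8_linear = bnb.nn.Linear8bitLt(1024, 3072, bias=True, has_fp16_weights=False, threshold=6.0)
    int8_linear.weight = bnb.nn.Int8Params(fp16_linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False)
    int8_linear.bias = fp16_linear.bias
    int8_linear = int8_linear.cuda()           # weights are quantized to int8 when moved to the GPU

    x = torch.randn(3, 1024, dtype=torch.half, device="cuda", requires_grad=True)
    y = int8_linear(x)                         # igemmlt on Turing/Ampere, fallback path otherwise
    y.float().sum().backward()                 # full backward support: x.grad is populated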
tests/test_linear8bitlt.py (new file, 61 lines)
@@ -0,0 +1,61 @@
+import bitsandbytes as bnb
+import pytest
+import torch
+from bitsandbytes import functional as F
+
+from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
+from bitsandbytes.nn.modules import Linear8bitLt
+
+# contributed by Alex Borzunov, see:
+# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
+
+@pytest.mark.skipif(
+    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
+    reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
+)
+def test_layout_exact_match():
+    x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda()
+    for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"):
+        transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
+        tile_indices = get_inverse_transform_indices(transform, tile_size)
+        cxb = transform(x)
+
+        torch.cuda.synchronize()
+        restored_x = undo_layout(cxb, tile_indices)
+        torch.cuda.synchronize()
+        assert restored_x.is_contiguous()
+        assert torch.all(torch.eq(restored_x, x))
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
+def test_linear_no_igemmlt():
+    linear = torch.nn.Linear(1024, 3072)
+    x = torch.randn(3, 1024, dtype=torch.half)
+    linear_custom = Linear8bitLt(
+        linear.in_features,
+        linear.out_features,
+        linear.bias is not None,
+        has_fp16_weights=False,
+        threshold=6.0,
+    )
+    linear_custom.state.force_no_igemmlt = True
+
+    linear_custom.weight = bnb.nn.Int8Params(
+        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
+    ).to(linear.weight.dtype)
+    linear_custom.bias = linear.bias
+    linear = linear_custom.cuda()
+    linear = linear.half().cuda()
+
+    x_ref = x.clone().cuda().requires_grad_(True)
+    x_ours = x.clone().cuda().requires_grad_(True)
+    fx_ref = linear(x_ref).float()
+    grad_proj = torch.randn_like(fx_ref)
+    (fx_ref * grad_proj).mean().backward()
+
+    fx_ours = linear_custom(x_ours).float()
+    (fx_ours * grad_proj).mean().backward()
+    assert torch.allclose(fx_ref, fx_ours, atol=0.02)
+    assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01)
+    assert not linear_custom.state.has_fp16_weights
+    assert linear_custom.state.CB is not None
+    assert linear_custom.state.CxB is None
tests/test_modules.py
@@ -382,7 +382,7 @@ names = [f"threshold_{vals}" for vals in values]


 @pytest.mark.parametrize("threshold", values, ids=names)
-@pytest.mark.parametrize("memory_efficient_backward", [True, False])
+@pytest.mark.parametrize("memory_efficient_backward", [False])
 def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
     l1 = (
         bnb.nn.Linear8bitLt(