Merge branch 'main' into patch-1

Kashif Rasul 2023-02-03 09:01:48 +01:00 committed by GitHub
commit c52365ac1d
10 changed files with 220 additions and 75 deletions

CHANGELOG.md

@@ -189,3 +189,15 @@ Improvements:
- StableEmbedding layer now has device and dtype parameters to make it 1:1 replaceable with regular Embedding layers (@lostmsu)
- runtime performance of block-wise quantization slightly improved
- added error message for the case multiple libcudart.so are installed and bitsandbytes picks the wrong one
### 0.37.0
#### Int8 Matmul + backward support for all GPUs
Features:
- Int8 MatmulLt now supports backward through inversion of the ColTuring/ColAmpere format. Slow, but memory efficient. Big thanks to @borzunov
- Int8 now supported on all GPUs. On devices with compute capability < 7.5, the Int weights are cast to 16/32-bit for the matrix multiplication. Contributed by @borzunov
Improvements:
- Improved logging for the CUDA detection mechanism.
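To make the 0.37.0 entries above concrete, here is a minimal sketch (not part of the changelog) of the new capability, assuming a CUDA-enabled install of bitsandbytes and PyTorch; the layer sizes are illustrative:

```python
import torch
import bitsandbytes as bnb

# Int8 linear layer; on GPUs with compute capability < 7.5 the int8 weights
# are cast to 16/32-bit for the matmul, so this also runs on older devices.
layer = bnb.nn.Linear8bitLt(1024, 1024, has_fp16_weights=False, threshold=6.0).cuda()

x = torch.randn(4, 1024, dtype=torch.half, device="cuda", requires_grad=True)
out = layer(x)
# backward through the frozen int8 weights now works by inverting the
# col_turing/col_ampere layout (slow, but memory efficient)
out.sum().backward()
print(x.grad.shape)
```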

Makefile

@@ -60,8 +60,8 @@ CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
-$(NVCC) $(COMPUTE_CAPABILITY) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
-$(NVCC) $(COMPUTE_CAPABILITY) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
+$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)

cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env

README.md

@@ -12,6 +12,7 @@ Resources:
## TL;DR
**Requirements**
Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0. LLM.int8() requires Turing or Ampere GPUs.
**Installation**:
``pip install bitsandbytes``

@@ -58,6 +59,10 @@ The bitsandbytes library is currently only supported on Linux distributions. Win
The requirements can best be fulfilled by installing pytorch via anaconda. You can install PyTorch by following the ["Get Started"](https://pytorch.org/get-started/locally/) instructions on the official website.
To install run:
``pip install bitsandbytes``
## Using bitsandbytes
### Using Int8 Matrix Multiplication
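The heading above introduces the README's Int8 usage docs; as a hedged illustration (not part of the diff), assuming a CUDA GPU and illustrative shapes, Int8 matrix multiplication via bnb.matmul looks roughly like this:

```python
import torch
import bitsandbytes as bnb

a = torch.randn(32, 64, dtype=torch.float16, device="cuda")
w = torch.randn(128, 64, dtype=torch.float16, device="cuda")  # nn.Linear-style weight (out, in)

# int8 matmul with fp16 mixed-precision decomposition for outlier columns above the threshold
out = bnb.matmul(a, w, threshold=6.0)
print(out.shape)  # torch.Size([32, 128])
```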

bitsandbytes/autograd/__init__.py

@@ -0,0 +1 @@
from ._functions import undo_layout, get_inverse_transform_indices

bitsandbytes/autograd/_functions.py

@@ -2,6 +2,7 @@ import operator
import warnings
from dataclasses import dataclass
from functools import reduce  # Required in Python 3
from typing import Tuple, Optional

import torch
@@ -14,6 +15,12 @@ def prod(iterable):
tensor = torch.Tensor

# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py

"""
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
@@ -48,6 +55,51 @@ class GlobalOutlierPooler:
        return torch.Tensor(list(self.outliers)).to(torch.int64)

def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]):
    """
    Compute a permutation of indices that invert the specified (tiled) matrix transformation

    :param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2]
    :param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere
    :note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size
    :example: transform_tile function for the turing layout (bitsandbytes.functional as F)
    :returns: indices
    """
    d1, d2 = tile_size
    assert 0 < d1 * d2 < 2**64
    tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2)
    # encode each position in tile as a tuple of <= 8 unique bytes
    permuted_tile_indices = torch.zeros_like(tile_indices)
    for i in range(8):
        # select i-th byte, apply transformation and trace where each index ended up
        ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256
        sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous()
        assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow"
        permuted_tile_i = transform_tile(sample_tile_i)
        ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128
        permuted_tile_indices += ith_permuted_indices * (256**i)
        if d1 * d2 < 256**i:
            break  # if all indices fit in i bytes, stop early
    return permuted_tile_indices


def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
    """
    Undo a tiled permutation such as turing or ampere layout

    :param permuted_tensor: torch tensor in a permuted layout
    :param tile_indices: reverse transformation indices, from get_inverse_transform_indices
    :return: contiguous row-major tensor
    """
    (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape
    assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"
    tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t()
    outputs = torch.empty_like(tensor)  # note: not using .index_copy because it was slower on cuda
    outputs[tile_indices.flatten()] = tensor
    outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows)
    outputs = outputs.permute(3, 0, 2, 1)  # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols)
    return outputs.reshape(rows, cols).contiguous()

class MatMul8bit(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
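A quick way to sanity-check the two helpers added in this hunk is a round trip through F.transform; a minimal sketch, assuming a Turing-or-newer GPU, mirroring the test added later in this commit:

```python
import torch
from bitsandbytes import functional as F
from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout

x = (torch.randn(1024, 1024) * 10).to(torch.int8).cuda()
# forward transform into the tiled col_turing layout (8x32 tiles)
transform = lambda t: F.transform(t.cuda(), from_order="row", to_order="col_turing")[0].to(t.device)
tile_indices = get_inverse_transform_indices(transform, (8, 32))
# undoing the layout with the precomputed indices recovers the original row-major tensor
assert torch.equal(undo_layout(transform(x), tile_indices), x)
```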
@@ -171,6 +223,8 @@ matmul_cublas = MatMul8bit.apply

@dataclass
class MatmulLtState:
    tile_indices: Optional[torch.Tensor] = None
    force_no_igemmlt: bool = False
    CB = None
    CxB = None
    SB = None
@@ -202,11 +256,22 @@ class MatmulLtState:
        self.SBt = None
        self.CBt = None

    def get_tile_size(self):
        assert self.formatB in (
            "col_turing",
            "col_ampere",
        ), f"please find this assert and manually enter tile size for {self.formatB}"
        return (8, 32) if self.formatB == "col_turing" else (32, 32)


class MatMul8bitLt(torch.autograd.Function):
    # forward is the same, but we added the fallback for pre-turing GPUs
    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

    @staticmethod
-    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
+    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
-        # default to pytorch behavior if inputs are empty
+        using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt
+        # default of pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
@@ -214,9 +279,9 @@ class MatMul8bitLt(torch.autograd.Function):
            ctx.B = B
            ctx.bias = bias
            if A.shape[-1] == B.shape[0]:
-                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
+                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
            else:
-                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
+                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)

        # 1. Quantize A
        # 2. Quantize B
@@ -235,9 +300,7 @@ class MatMul8bitLt(torch.autograd.Function):
        # 1. Quantize A
        if len(A.shape) == 3:
            A = A.view(-1, A.shape[-1]).contiguous()
-        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
-            A.to(torch.float16), threshold=state.threshold
-        )
+        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)

        if state.threshold > 0.0 and coo_tensorA is not None:
            if state.has_fp16_weights:
@@ -248,12 +311,12 @@ class MatMul8bitLt(torch.autograd.Function):
                state.subB = B[:, idx].t().contiguous()
                state.idx = idx
            else:
-                if state.CxB is None:
+                if state.CxB is None and using_igemmlt:
                    # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                    # we also need to convert it to the turing/ampere format
                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
        else:
-            if not state.has_fp16_weights and state.CxB is None:
+            if not state.has_fp16_weights and state.CxB is None and using_igemmlt:
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
            subA = None
@@ -273,7 +336,10 @@ class MatMul8bitLt(torch.autograd.Function):
                    state.SCBt,
                    coo_tensorB,
                ) = F.double_quant(B.to(torch.float16))
-                state.CxB, state.SB = F.transform(CB, to_order=formatB)
+                if using_igemmlt:
+                    state.CxB, state.SB = F.transform(CB, to_order=formatB)
+                else:
+                    state.CB = CB
        else:
            has_grad = False
@@ -288,18 +354,17 @@ class MatMul8bitLt(torch.autograd.Function):
            # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
            # else:
            # state.idx = outlier_idx
-            outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
-            state.subB = (
-                (outliers * state.SCB.view(-1, 1) / 127.0)
-                .t()
-                .contiguous()
-                .to(A.dtype)
-            )
+            if state.CxB is not None:
+                outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
+            else:
+                outliers = state.CB[:, state.idx.long()].clone()
+
+            state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
            CA[:, state.idx.long()] = 0
            CAt[:, state.idx.long()] = 0
            subA = A[:, state.idx.long()]

-        shapeB = state.SB[0]
+        shapeB = state.SB[0] if state.SB else B.shape

        if len(input_shape) == 3:
            output_shape = (input_shape[0], input_shape[1], shapeB[0])
@@ -307,16 +372,25 @@ class MatMul8bitLt(torch.autograd.Function):
            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
-        C32A, SA = F.transform(CA, "col32")
-        out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
-        # we apply the fused bias here
-        if bias is None or bias.dtype == torch.float16:
-            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
-            output = output.to(A.dtype)
-        else:  # apply bias separately
-            output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
-            output = output.to(A.dtype).add_(bias)
+        if using_igemmlt:
+            C32A, SA = F.transform(CA, "col32")
+            out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
+            if bias is None or bias.dtype == torch.float16:
+                # we apply the fused bias here
+                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
+                output = output.to(A.dtype)
+            else:  # apply bias separately
+                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
+                output = output.to(A.dtype).add_(bias)
+        else:
+            A_wo_outliers = A.clone()
+            if state.idx is not None:
+                A_wo_outliers[:, state.idx.long()] = 0
+            output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
+            output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
+            if bias is not None:
+                output = output.add_(bias)

        # 4. Mixed-precision decomposition matmul
        if coo_tensorA is not None and subA is not None:
@@ -337,14 +411,13 @@ class MatMul8bitLt(torch.autograd.Function):
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

-        clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
+        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
        return clone_func(output.view(output_shape))

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
-            bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
+            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
        CAt, subA = ctx.tensors
@@ -359,9 +432,7 @@ class MatMul8bitLt(torch.autograd.Function):
        # Cast grad_output to fp16
        if len(grad_output.shape) == 3:
-            grad_output = grad_output.reshape(
-                -1, grad_output.shape[-1]
-            ).contiguous()
+            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
        if req_gradB:
@@ -376,17 +447,29 @@ class MatMul8bitLt(torch.autograd.Function):
            if state.CBt is not None:
                C32grad, Sgrad = F.transform(Cgrad, "col32")
                if state.CxBt is None:
-                    state.CxBt, state.SBt = F.transform(
-                        state.CBt, to_order=formatB, transpose=True
-                    )
+                    state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
            elif state.CB is not None:
-                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
+                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
            elif state.CxB is not None:
                if state.tile_indices is None:
                    order, tile_size = state.formatB, state.get_tile_size()
                    transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
                    with torch.no_grad():
                        state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device)

                CB = (
                    undo_layout(state.CxB, state.tile_indices)
                    .to(ctx.dtype_A)
                    .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                )
                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
            else:
-                raise Exception('State must contain either CBt or CB matrix for backward')
+                raise Exception("State must contain either CBt or CB or CxB matrix for backward")

        return grad_A, grad_B, None, grad_bias, None

bitsandbytes/cuda_setup/main.py

@@ -80,9 +80,10 @@ class CUDASetup:
        self.add_log_entry('python setup.py install')

    def initialize(self):
-        self.has_printed = False
-        self.lib = None
-        self.initialized = False
+        if not getattr(self, 'initialized', False):
+            self.has_printed = False
+            self.lib = None
+            self.initialized = False

    def run_cuda_setup(self):
        self.initialized = True
@@ -103,7 +104,7 @@ class CUDASetup:
            legacy_binary_name = "libbitsandbytes_cpu.so"
            self.add_log_entry(f"CUDA SETUP: Defaulting to {legacy_binary_name}...")
            binary_path = package_dir / legacy_binary_name
-            if not binary_path.exists():
+            if not binary_path.exists() or torch.cuda.is_available():
                self.add_log_entry('')
                self.add_log_entry('='*48 + 'ERROR' + '='*37)
                self.add_log_entry('CUDA SETUP: CUDA detection failed! Possible reasons:')
@@ -112,6 +113,7 @@ class CUDASetup:
                self.add_log_entry('3. You have multiple conflicting CUDA libraries')
                self.add_log_entry('4. Required library not pre-compiled for this bitsandbytes release!')
                self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.')
                self.add_log_entry('CUDA SETUP: The CUDA version for the compile might depend on your conda install. Inspect CUDA version via `conda list | grep cuda`.')
                self.add_log_entry('='*80)
                self.add_log_entry('')
                self.generate_instructions()
@@ -148,7 +150,7 @@ def is_cublasLt_compatible(cc):
    if cc is not None:
        cc_major, cc_minor = cc.split('.')
        if int(cc_major) < 7 or (int(cc_major) == 7 and int(cc_minor) < 5):
-            CUDASetup.get_instance().add_log_entry("WARNING: Compute capability < 7.5 detected! Proceeding to load CPU-only library...", is_warning=True)
+            CUDASetup.get_instance().add_log_entry("WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!", is_warning=True)
        else:
            has_cublaslt = True
    return has_cublaslt
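This check is what separates the fast cublasLt int8 kernels from the new slow fallback; the same test can be reproduced from user code as a small sketch (assuming a CUDA device is present):

```python
import torch

cc_major, cc_minor = torch.cuda.get_device_capability()
# cublasLt int8 kernels need compute capability >= 7.5 (Turing or newer);
# with 0.37.0, older GPUs fall back to a slower fp16 matmul instead of a CPU-only library
print(f"compute capability {cc_major}.{cc_minor}, cublasLt-compatible: {(cc_major, cc_minor) >= (7, 5)}")
```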
@@ -362,7 +364,6 @@ def evaluate_cuda_setup():
    print('')
    print('='*35 + 'BUG REPORT' + '='*35)
    print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
-    print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
    print('='*80)
    if not torch.cuda.is_available(): return 'libsbitsandbytes_cpu.so', None, None, None, None

bitsandbytes/nn/modules.py

@@ -209,19 +209,10 @@ class Int8Params(torch.nn.Parameter):

class Linear8bitLt(nn.Linear):
-    def __init__(
-        self,
-        input_features,
-        output_features,
-        bias=True,
-        has_fp16_weights=True,
-        memory_efficient_backward=False,
-        threshold=0.0,
-        index=None,
-    ):
-        super().__init__(
-            input_features, output_features, bias
-        )
+    def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
+                 memory_efficient_backward=False, threshold=0.0, index=None):
+        super().__init__(input_features, output_features, bias)
+        assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
        self.state = bnb.MatmulLtState()
        self.index = index
@@ -231,9 +222,7 @@ class Linear8bitLt(nn.Linear):
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

-        self.weight = Int8Params(
-            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
-        )
+        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)

    def init_8bit_state(self):
        self.state.CB = self.weight.CB
@@ -241,27 +230,20 @@ class Linear8bitLt(nn.Linear):
        self.weight.CB = None
        self.weight.SCB = None

-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
        self.state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
-        if self.bias is not None and self.bias.dtype != torch.float16:
-            self.bias.data = self.bias.data.half()
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights:
-            if not self.state.memory_efficient_backward and self.state.CB is not None:
+            if self.state.CB is not None and self.state.CxB is not None:
                # we converted 8-bit row major to turing/ampere format in the first inference pass
                # we no longer need the row-major weight
                del self.state.CB
                self.weight.data = self.state.CxB
-            elif self.state.memory_efficient_backward and self.state.CxB is not None:
-                # For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
-                # Thus, we delete CxB from the state.
-                del self.state.CxB

        return out
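For reference, replacing an existing fp16 torch.nn.Linear with the layer above typically follows the pattern used by the new test in this commit; a hedged sketch with illustrative dimensions:

```python
import torch
import bitsandbytes as bnb

fp16_linear = torch.nn.Linear(1024, 4096).half()

int8_linear = bnb.nn.Linear8bitLt(1024, 4096, bias=True, has_fp16_weights=False, threshold=6.0)
int8_linear.weight = bnb.nn.Int8Params(
    fp16_linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
)
int8_linear.bias = fp16_linear.bias
int8_linear = int8_linear.cuda()  # the fp16 weights are quantized to int8 on the .cuda() call

x = torch.randn(8, 1024, dtype=torch.half, device="cuda")
y = int8_linear(x)
```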

setup.py

@@ -18,7 +18,7 @@ def read(fname):
setup(
    name=f"bitsandbytes",
-    version=f"0.36.0-2",
+    version=f"0.37.0",
    author="Tim Dettmers",
    author_email="dettmers@cs.washington.edu",
    description="8-bit optimizers and matrix multiplication routines.",

tests/test_linear8bitlt.py

@@ -0,0 +1,61 @@
import bitsandbytes as bnb
import pytest
import torch

from bitsandbytes import functional as F
from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
from bitsandbytes.nn.modules import Linear8bitLt

# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py


@pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5),
    reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs",
)
def test_layout_exact_match():
    x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda()
    for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"):
        transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
        tile_indices = get_inverse_transform_indices(transform, tile_size)
        cxb = transform(x)

        torch.cuda.synchronize()
        restored_x = undo_layout(cxb, tile_indices)
        torch.cuda.synchronize()
        assert restored_x.is_contiguous()
        assert torch.all(torch.eq(restored_x, x))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
def test_linear_no_igemmlt():
    linear = torch.nn.Linear(1024, 3072)
    x = torch.randn(3, 1024, dtype=torch.half)
    linear_custom = Linear8bitLt(
        linear.in_features,
        linear.out_features,
        linear.bias is not None,
        has_fp16_weights=False,
        threshold=6.0,
    )
    linear_custom.state.force_no_igemmlt = True

    linear_custom.weight = bnb.nn.Int8Params(
        linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
    ).to(linear.weight.dtype)
    linear_custom.bias = linear.bias
    linear_custom = linear_custom.cuda()
    linear = linear.half().cuda()
    x_ref = x.clone().cuda().requires_grad_(True)
    x_ours = x.clone().cuda().requires_grad_(True)
    fx_ref = linear(x_ref).float()
    grad_proj = torch.randn_like(fx_ref)
    (fx_ref * grad_proj).mean().backward()

    fx_ours = linear_custom(x_ours).float()
    (fx_ours * grad_proj).mean().backward()
    assert torch.allclose(fx_ref, fx_ours, atol=0.02)
    assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01)
    assert not linear_custom.state.has_fp16_weights
    assert linear_custom.state.CB is not None
    assert linear_custom.state.CxB is None

tests/test_modules.py

@@ -382,7 +382,7 @@ names = [f"threshold_{vals}" for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
-@pytest.mark.parametrize("memory_efficient_backward", [True, False])
+@pytest.mark.parametrize("memory_efficient_backward", [False])
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
    l1 = (
        bnb.nn.Linear8bitLt(