reran black with linelength 80 for greater readability

Titus von Koeller 2022-08-01 09:32:47 -07:00
parent 3fd06fb620
commit ea7c14f8ef
17 changed files with 665 additions and 203 deletions
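Note: the exact formatter invocation is not recorded on this page. A minimal sketch of how an equivalent pass could be reproduced from the repository root, assuming black is installed in the environment (the invocation below is an assumption, not part of the commit):

    # Re-run black over the whole repository at an 80-character line length
    # (assumed invocation; adjust the target path as needed).
    import subprocess

    subprocess.run(["black", "--line-length", "80", "."], check=True)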

View File

@@ -3,8 +3,13 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .autograd._functions import (MatmulLtState, bmm_cublas, matmul,
-                                  matmul_cublas, mm_cublas)
+from .autograd._functions import (
+    MatmulLtState,
+    bmm_cublas,
+    matmul,
+    matmul_cublas,
+    mm_cublas,
+)
 from .cextension import COMPILED_WITH_CUDA
 from .nn import modules

View File

@@ -111,7 +111,9 @@ class MatMul8bit(torch.autograd.Function):
             qgrad_output, S1 = F.vectorwise_quant(
                 grad_output, dim=dims, quant_type=quant_type
             )
-            qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
+            qA, S2 = F.vectorwise_quant(
+                A, dim=dims, quant_type=quant_type
+            )
             igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
             grad_B = F.vectorwise_mm_dequant(
                 igrad_B,
@@ -146,7 +148,11 @@ class MatMul8bit(torch.autograd.Function):
             qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
             igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
             grad_A = F.vectorwise_mm_dequant(
-                igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type
+                igrad_A,
+                S1,
+                S3.permute(permute_dim),
+                grad_output.dtype,
+                quant_type,
             )

         return grad_A, grad_B, None, None, None
@@ -211,7 +217,9 @@ class MatMul8bitLt(torch.autograd.Function):
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
-        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)
+        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
+            A, threshold=state.threshold
+        )

         if state.threshold > 0.0 and coo_tensorA is not None:
             if state.has_fp16_weights:
@@ -225,7 +233,9 @@ class MatMul8bitLt(torch.autograd.Function):
                 if state.CxB is None:
                     # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                     # we also need to convert it to the turing/ampere format
-                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+                    state.CxB, state.SB = F.transform(
+                        state.CB, to_order=formatB
+                    )
                     # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
                     # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
                     #     # generate outlier index and subB
@@ -259,7 +269,13 @@ class MatMul8bitLt(torch.autograd.Function):
             if (state.is_training and not has_grad) or state.CxB is None:
                 state.reset_grads()
-                CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B)
+                (
+                    CB,
+                    state.CBt,
+                    state.SCB,
+                    state.SCBt,
+                    coo_tensorB,
+                ) = F.double_quant(B)
                 state.CxB, state.SB = F.transform(CB, to_order=formatB)
         else:
             has_grad = False
@@ -277,7 +293,10 @@ class MatMul8bitLt(torch.autograd.Function):
                 # state.idx = outlier_idx
                 outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
                 state.subB = (
-                    (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
+                    (outliers * state.SCB.view(-1, 1) / 127.0)
+                    .t()
+                    .contiguous()
+                    .half()
                 )
                 CA[:, state.idx.long()] = 0
                 CAt[:, state.idx.long()] = 0
@@ -325,10 +344,14 @@ class MatMul8bitLt(torch.autograd.Function):
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert state.has_fp16_weights, "Backprop only supported for fp16 weights."
+        assert (
+            state.has_fp16_weights
+        ), "Backprop only supported for fp16 weights."

         if len(grad_output.shape) == 3:
-            grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
+            grad_output = grad_output.view(
+                -1, grad_output.shape[-1]
+            ).contiguous()

         grad_A = grad_B = None
@@ -359,7 +382,11 @@ matmul = MatMul8bitLt.apply


 def matmul(
-    A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0
+    A: tensor,
+    B: tensor,
+    out: tensor = None,
+    state: MatmulLtState = None,
+    threshold=0.0,
 ):
     state = state or MatmulLtState()
     if threshold > 0.0:

View File

@@ -1,7 +1,7 @@
 """
-build is dependent on
-- compute capability
-  - dependent on GPU family
+extract factors the build is dependent on:
+[X] compute capability
+  [ ] TODO: Q - What if we have multiple GPUs of different makes?
 - CUDA version
 - Software:
     - CPU-only: only CPU quantization functions (no optimizer, no matrix multipl)
@@ -19,6 +19,8 @@ evaluation:
 """

 import ctypes
+import shlex
+import subprocess
 from os import environ as env
 from pathlib import Path
 from typing import Set, Union
@@ -26,10 +28,31 @@ from typing import Set, Union
 from .utils import print_err, warn_of_missing_prerequisite


+def execute_and_return(command_string: str) -> Tuple[str, str]:
+    def _decode(subprocess_err_out_tuple):
+        return tuple(
+            to_decode.decode("UTF-8").strip()
+            for to_decode in subprocess_err_out_tuple
+        )
+
+    def execute_and_return_decoded_std_streams(command_string):
+        return _decode(
+            subprocess.Popen(
+                shlex.split(command_string),
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+            ).communicate()
+        )
+
+    std_out, std_err = execute_and_return_decoded_std_streams()
+    return std_out, std_err
+
+
 def check_cuda_result(cuda, result_val):
     if result_val != 0:
+        # TODO: undefined name 'error_str'
         cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
-        print(f"Count not initialize CUDA - failure!")
+        print("Count not initialize CUDA - failure!")
         raise Exception("CUDA exception!")
     return result_val
@@ -53,7 +76,9 @@ def get_compute_capability():
     result = ctypes.c_int()
     device = ctypes.c_int()
+    # TODO: local variable 'context' is assigned to but never used
     context = ctypes.c_void_p()
+    # TODO: local variable 'error_str' is assigned to but never used
     error_str = ctypes.c_char_p()

     result = check_cuda_result(cuda, cuda.cuInit(0))
@@ -61,7 +86,9 @@ def get_compute_capability():
     result = check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
     ccs = []
     for i in range(nGpus.value):
-        result = check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
+        result = check_cuda_result(
+            cuda, cuda.cuDeviceGet(ctypes.byref(device), i)
+        )
         result = check_cuda_result(
             cuda,
             cuda.cuDeviceComputeCapability(
@@ -114,11 +141,15 @@ def get_cuda_runtime_lib_path(
     } - non_existent_directories

     if len(cuda_runtime_libs) > 1:
-        err_msg = f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+        err_msg = (
+            f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+        )
         raise FileNotFoundError(err_msg)

     elif len(cuda_runtime_libs) < 1:
-        err_msg = f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+        err_msg = (
+            f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+        )
         raise FileNotFoundError(err_msg)

     single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs))
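The hunks above also introduce an execute_and_return helper built on shlex and subprocess that decodes both output streams. A self-contained sketch of the same pattern, for illustration only (the function name run_and_capture and the example command are hypothetical and not part of this diff; note that, as rendered here, the committed helper calls execute_and_return_decoded_std_streams without passing command_string):

    import shlex
    import subprocess
    from typing import Tuple


    def run_and_capture(command: str) -> Tuple[str, str]:
        # Split the command shell-style, run it, and capture stdout/stderr.
        proc = subprocess.Popen(
            shlex.split(command),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        out, err = proc.communicate()
        # Decode both streams the same way the helper above does.
        return out.decode("UTF-8").strip(), err.decode("UTF-8").strip()


    std_out, std_err = run_and_capture("echo hello")  # arbitrary example command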

View File

@@ -17,14 +17,29 @@ if COMPILED_WITH_CUDA:
     """C FUNCTIONS FOR OPTIMIZERS"""
     str2optimizer32bit = {}
     str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
-    str2optimizer32bit["momentum"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
-    str2optimizer32bit["rmsprop"] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
-    str2optimizer32bit["adagrad"] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16)
-    str2optimizer32bit["lars"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
+    str2optimizer32bit["momentum"] = (
+        lib.cmomentum32bit_g32,
+        lib.cmomentum32bit_g16,
+    )
+    str2optimizer32bit["rmsprop"] = (
+        lib.crmsprop32bit_g32,
+        lib.crmsprop32bit_g16,
+    )
+    str2optimizer32bit["adagrad"] = (
+        lib.cadagrad32bit_g32,
+        lib.cadagrad32bit_g16,
+    )
+    str2optimizer32bit["lars"] = (
+        lib.cmomentum32bit_g32,
+        lib.cmomentum32bit_g16,
+    )
     str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)

     str2optimizer8bit = {}
-    str2optimizer8bit["adam"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+    str2optimizer8bit["adam"] = (
+        lib.cadam_static_8bit_g32,
+        lib.cadam_static_8bit_g16,
+    )
     str2optimizer8bit["momentum"] = (
         lib.cmomentum_static_8bit_g32,
         lib.cmomentum_static_8bit_g16,
@@ -33,7 +48,10 @@ if COMPILED_WITH_CUDA:
         lib.crmsprop_static_8bit_g32,
         lib.crmsprop_static_8bit_g16,
     )
-    str2optimizer8bit["lamb"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+    str2optimizer8bit["lamb"] = (
+        lib.cadam_static_8bit_g32,
+        lib.cadam_static_8bit_g16,
+    )
     str2optimizer8bit["lars"] = (
         lib.cmomentum_static_8bit_g32,
         lib.cmomentum_static_8bit_g16,
@@ -137,7 +155,9 @@ def create_dynamic_map(signed=True, n=7):
     if not signed:
         additional_items = 2 * additional_items
     for i in range(n):
-        fraction_items = 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
+        fraction_items = (
+            2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
+        )
         boundaries = torch.linspace(0.1, 1, fraction_items)
         means = (boundaries[:-1] + boundaries[1:]) / 2.0
         data += ((10 ** (-(n - 1) + i)) * means).tolist()
@@ -272,7 +292,13 @@ def get_transform_buffer(


 def nvidia_transform(
-    A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
+    A,
+    to_order,
+    from_order="row",
+    out=None,
+    transpose=False,
+    state=None,
+    ld=None,
 ):
     if state is None:
         state = (A.shape, from_order)
@@ -352,7 +378,11 @@ def estimate_quantiles(


 def quantize_blockwise(
-    A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None
+    A: Tensor,
+    code: Tensor = None,
+    absmax: Tensor = None,
+    rand=None,
+    out: Tensor = None,
 ) -> Tensor:
     """
     Quantize tensor A in blocks of size 4096 values.
@@ -629,7 +659,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
     """
     if out is None:
         out = torch.zeros_like(A, dtype=torch.float32)
-    lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
+    lib.cdequantize(
+        get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())
+    )
     return out
@@ -1005,7 +1037,9 @@ def histogram_scatter_add_2d(
     )


-def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
+def check_matmul(
+    A, B, out, transposed_A, transposed_B, expected_type=torch.int8
+):
     if not torch.cuda.is_initialized():
         torch.cuda.init()
     if A.dtype != expected_type or B.dtype != expected_type:
@@ -1097,7 +1131,11 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):


 def igemm(
-    A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
+    A: Tensor,
+    B: Tensor,
+    out: Tensor = None,
+    transposed_A=False,
+    transposed_B=False,
 ):
     sout = check_matmul(A, B, out, transposed_A, transposed_B)
     if out is None:
@@ -1193,7 +1231,11 @@ def igemm(


 def batched_igemm(
-    A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
+    A: Tensor,
+    B: Tensor,
+    out: Tensor = None,
+    transposed_A=False,
+    transposed_B=False,
 ):
     if not len(A.shape) == 3 or not len(B.shape) == 3:
         raise ValueError(
@@ -1392,9 +1434,13 @@ def mm_dequant(
     if out is None:
         out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
     if new_row_stats is None:
-        new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
+        new_row_stats = torch.empty(
+            out_shape[0], dtype=torch.float32, device=A.device
+        )
     if new_col_stats is None:
-        new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
+        new_col_stats = torch.empty(
+            out_shape[1], dtype=torch.float32, device=A.device
+        )
     assert (
         new_row_stats.shape[0] == row_stats.shape[0]
     ), f"{new_row_stats.shape} vs {row_stats.shape}"
@@ -1440,13 +1486,13 @@ def get_colrow_absmax(
     col_tiles = (cols + 255) // 256
     tiled_rows = ((rows + 15) // 16) * 16
     if row_stats is None:
-        row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(
-            -50000.0
-        )
+        row_stats = torch.empty(
+            (rows,), dtype=torch.float32, device=device
+        ).fill_(-50000.0)
     if col_stats is None:
-        col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(
-            -50000.0
-        )
+        col_stats = torch.empty(
+            (cols,), dtype=torch.float32, device=device
+        ).fill_(-50000.0)

     if nnz_block_ptr is None and threshold > 0.0:
         nnz_block_ptr = torch.zeros(
@@ -1462,7 +1508,13 @@ def get_colrow_absmax(
     prev_device = pre_call(A.device)
     lib.cget_col_row_stats(
-        ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols
+        ptrA,
+        ptrRowStats,
+        ptrColStats,
+        ptrNnzrows,
+        ct.c_float(threshold),
+        rows,
+        cols,
     )
     post_call(prev_device)
@@ -1526,7 +1578,9 @@ class CSCSparseTensor(object):
 def coo2csr(cooA):
     values, counts = torch.unique(cooA.rowidx, return_counts=True)
     values.add_(1)
-    rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device)
+    rowptr = torch.zeros(
+        (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device
+    )
     rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
     rowptr.cumsum_(0)
     return CSRSparseTensor(
@@ -1540,10 +1594,14 @@ def coo2csc(cooA):
     values = cooA.values[col2rowidx]
     colvalues, counts = torch.unique(val, return_counts=True)
     colvalues.add_(1)
-    colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device)
+    colptr = torch.zeros(
+        (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device
+    )
     colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
     colptr.cumsum_(0)
-    return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)
+    return CSCSparseTensor(
+        cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values
+    )


 def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
@@ -1568,7 +1626,9 @@ def double_quant(
         rows = A.shape[0]

     if row_stats is None or col_stats is None:
-        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
+        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(
+            A, threshold=threshold
+        )

     if out_col is None:
         out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
@@ -1663,7 +1723,13 @@ def get_special_format_str():


 def transform(
-    A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
+    A,
+    to_order,
+    from_order="row",
+    out=None,
+    transpose=False,
+    state=None,
+    ld=None,
 ):
     if state is None:
         state = (A.shape, from_order)
@@ -1716,7 +1782,9 @@ def transform(

 def spmm_coo(cooA, B, out=None):
     if out is None:
-        out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
+        out = torch.empty(
+            (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype
+        )
     nnz = cooA.nnz
     assert cooA.rowidx.numel() == nnz
     assert cooA.colidx.numel() == nnz
@@ -1982,7 +2050,9 @@ def extract_outliers(A, SA, idx):
     assert formatA in ["col_turing", "col_ampere"]
     assert A.device.type == "cuda"

-    out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
+    out = torch.zeros(
+        (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device
+    )

     idx_size = ct.c_int32(idx.numel())
     rows = ct.c_int32(shapeA[0])

View File

@@ -2,8 +2,19 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Set,
-                    Tuple, TypeVar, Union, overload)
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+    overload,
+)

 import torch
 import torch.nn.functional as F
@@ -131,7 +142,12 @@ class Embedding(torch.nn.Embedding):

 class Int8Params(torch.nn.Parameter):
     def __new__(
-        cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None
+        cls,
+        data=None,
+        requires_grad=True,
+        has_fp16_weights=False,
+        CB=None,
+        SCB=None,
     ):
         cls.has_fp16_weights = has_fp16_weights
         cls.CB = None
@@ -186,7 +202,9 @@ class Int8Params(torch.nn.Parameter):
             return self.cuda(device)
         else:
             new_param = Int8Params(
-                super().to(device=device, dtype=dtype, non_blocking=non_blocking),
+                super().to(
+                    device=device, dtype=dtype, non_blocking=non_blocking
+                ),
                 requires_grad=self.requires_grad,
                 has_fp16_weights=self.has_fp16_weights,
             )
@@ -206,7 +224,9 @@ class Linear8bitLt(nn.Linear):
         threshold=0.0,
         index=None,
     ):
-        super(Linear8bitLt, self).__init__(input_features, output_features, bias)
+        super(Linear8bitLt, self).__init__(
+            input_features, output_features, bias
+        )
         self.state = bnb.MatmulLtState()
         self.index = index
@@ -215,7 +235,9 @@ class Linear8bitLt(nn.Linear):
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True

-        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+        self.weight = Int8Params(
+            self.weight.data, has_fp16_weights=has_fp16_weights
+        )

     def init_8bit_state(self):
         self.state.CB = self.weight.CB

View File

@@ -23,7 +23,9 @@ class Adagrad(Optimizer1State):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if initial_accumulator_value != 0.0:
@@ -63,7 +65,9 @@ class Adagrad8bit(Optimizer1State):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if initial_accumulator_value != 0.0:
@@ -104,7 +108,9 @@ class Adagrad32bit(Optimizer1State):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if initial_accumulator_value != 0.0:

View File

@@ -140,7 +140,11 @@ class AnalysisAdam(torch.optim.Optimizer):
         savedir=None,
     ):
         defaults = dict(
-            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsgrad=amsgrad,
         )
         super(AnalysisAdam, self).__init__(params, defaults)
         self.analysis = bnb_analysis
@@ -198,7 +202,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                    state["relerrors"] = torch.zeros(
                        (256, 256), device=p_data_fp32.device
                    )
-                    state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
+                    state["counts"] = torch.zeros(
+                        (256, 256), device=p_data_fp32.device
+                    )
                 if amsgrad:
                     # Maintains max of all exp. moving avg. of sq. grad. values
                     state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
@@ -214,7 +220,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                 beta1, beta2 = group["betas"]
                 bias_correction1 = 1 - beta1 ** state["step"]
                 bias_correction2 = 1 - beta2 ** state["step"]
-                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+                step_size = (
+                    group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+                )
                 e = state["abserrors"]
                 rele = state["relerrors"]
                 counts = state["counts"]
@@ -235,7 +243,10 @@ class AnalysisAdam(torch.optim.Optimizer):
                 denom = exp_avg_sq.sqrt().add_(group["eps"])
                 update_fp32 = exp_avg / denom

-                if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
+                if (
+                    p_data_fp32.numel() <= 8192
+                    or p_data_fp32.numel() > 50000 * 1000
+                ):
                     # embedding layer or too small
                     p_data_fp32 += -step_size * update_fp32
                 else:
@@ -274,7 +285,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                         # 3. dequantize
                         # Error will be calculated automatically!
                     else:
-                        raise ValueError(f"Invalid analysis value: {self.analysis}!")
+                        raise ValueError(
+                            f"Invalid analysis value: {self.analysis}!"
+                        )

                     denom = state2.sqrt().add_(group["eps"])
                     update_8bit = state1 / denom
@@ -296,7 +309,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                    if self.savedir != "" and state["step"] % 100 == 0:
                        if not os.path.exists(self.savedir):
                            os.makedirs(self.savedir)
-                        shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
+                        shapestr = "_".join(
+                            [str(dim) for dim in p_data_fp32.shape]
+                        )
                        pathe = os.path.join(
                            self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
                        )

View File

@@ -24,7 +24,9 @@ class LARS(Optimizer1State):
         max_unorm=0.02,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"LARS without momentum is not supported!")
+            raise NotImplementedError(
+                f"LARS without momentum is not supported!"
+            )
         super(LARS, self).__init__(
             "lars",
             params,
@@ -56,7 +58,9 @@ class LARS8bit(Optimizer1State):
         max_unorm=0.02,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"LARS without momentum is not supported!")
+            raise NotImplementedError(
+                f"LARS without momentum is not supported!"
+            )
         super(LARS8bit, self).__init__(
             "lars",
             params,
@@ -88,7 +92,9 @@ class LARS32bit(Optimizer1State):
         max_unorm=0.02,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"LARS without momentum is not supported!")
+            raise NotImplementedError(
+                f"LARS without momentum is not supported!"
+            )
         super(LARS32bit, self).__init__(
             "lars",
             params,
@@ -121,7 +127,9 @@ class PytorchLARS(Optimizer):
         if momentum < 0.0:
             raise ValueError("Invalid momentum value: {}".format(momentum))
         if weight_decay < 0.0:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )

         defaults = dict(
             lr=lr,
@@ -132,7 +140,9 @@ class PytorchLARS(Optimizer):
             max_unorm=max_unorm,
         )
         if nesterov and (momentum <= 0 or dampening != 0):
-            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+            raise ValueError(
+                "Nesterov momentum requires a momentum and zero dampening"
+            )
         super(PytorchLARS, self).__init__(params, defaults)

     def __setstate__(self, state):

View File

@@ -46,9 +46,13 @@ class GlobalOptimManager(object):
         for group_index, group in enumerate(param_groups):
             for p_index, p in enumerate(group["params"]):
                 if id(p) in self.pid2config:
-                    self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
+                    self.index2config[(group_index, p_index)] = self.pid2config[
+                        id(p)
+                    ]

-    def override_config(self, parameters, key=None, value=None, key_value_dict=None):
+    def override_config(
+        self, parameters, key=None, value=None, key_value_dict=None
+    ):
         """
         Overrides initial optimizer config for specific parameters.
@@ -136,7 +140,8 @@ class Optimizer8bit(torch.optim.Optimizer):
         if len(groups) != len(saved_groups):
             raise ValueError(
-                "loaded state dict has a different number of " "parameter groups"
+                "loaded state dict has a different number of "
+                "parameter groups"
             )
         param_lens = (len(g["params"]) for g in groups)
         saved_lens = (len(g["params"]) for g in saved_groups)
@@ -192,7 +197,9 @@ class Optimizer8bit(torch.optim.Optimizer):
             new_group["params"] = group["params"]
             return new_group

-        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+        param_groups = [
+            update_group(g, ng) for g, ng in zip(groups, saved_groups)
+        ]
         self.__setstate__({"state": state, "param_groups": param_groups})

     def to_gpu(self):
@@ -222,9 +229,9 @@ class Optimizer8bit(torch.optim.Optimizer):
                     # found the matching parameter
                     # init override
                     self.mng.pid2config[id(p)] = config
-                    self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[
-                        id(p)
-                    ]
+                    self.mng.index2config[
+                        (gindex, pindex)
+                    ] = self.mng.pid2config[id(p)]
                     found = True

     @torch.no_grad()
@@ -280,7 +287,9 @@ class Optimizer8bit(torch.optim.Optimizer):
         raise NotImplementedError(f"init_state method needs to be overidden")

     def update_step(self, group, p, gindex, pindex):
-        raise NotImplementedError(f"The update_step method needs to be overidden")
+        raise NotImplementedError(
+            f"The update_step method needs to be overidden"
+        )


 class Optimizer2State(Optimizer8bit):
@@ -310,9 +319,13 @@ class Optimizer2State(Optimizer8bit):
             betas = [float(b) for b in betas]
         for i in range(len(betas)):
             if not 0.0 <= betas[i] < 1.0:
-                raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+                raise ValueError(
+                    f"Invalid beta parameter at index {i}: {betas[i]}"
+                )
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         super(Optimizer2State, self).__init__(params, defaults, optim_bits)
@@ -351,7 +364,9 @@ class Optimizer2State(Optimizer8bit):
         state = self.state[p]
         state["step"] = 0

-        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+        if dtype == torch.float32 or (
+            dtype == torch.uint8 and p.numel() < 4096
+        ):
             state["state1"] = torch.zeros_like(
                 p,
                 memory_format=torch.preserve_format,
@@ -368,8 +383,12 @@ class Optimizer2State(Optimizer8bit):
             if state["step"] == 0:
                 if "dynamic" not in self.name2qmap:
                     self.fill_qmap()
-                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
-                self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
+                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+                    p.device
+                )
+                self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
+                    p.device
+                )

             state["state1"] = torch.zeros_like(
                 p,
@@ -399,11 +418,15 @@ class Optimizer2State(Optimizer8bit):
                     (blocks,), dtype=torch.float32, device=p.device
                 )
             else:
-                state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+                state["max1"] = torch.zeros(
+                    (1,), dtype=torch.float32, device=p.device
+                )
                 state["new_max1"] = torch.zeros(
                     (1,), dtype=torch.float32, device=p.device
                 )
-                state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+                state["max2"] = torch.zeros(
+                    (1,), dtype=torch.float32, device=p.device
+                )
                 state["new_max2"] = torch.zeros(
                     (1,), dtype=torch.float32, device=p.device
                 )
@@ -470,7 +493,9 @@ class Optimizer2State(Optimizer8bit):
                 state["new_max2"],
                 config["weight_decay"],
                 gnorm_scale=gnorm_scale,
-                unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+                unorm_vec=state["unorm_vec"]
+                if config["max_unorm"] > 0.0
+                else None,
                 max_unorm=config["max_unorm"],
             )
@@ -522,9 +547,13 @@ class Optimizer1State(Optimizer8bit):
             raise ValueError("Invalid epsilon value: {}".format(eps))
         for i in range(len(betas)):
             if not 0.0 <= betas[i] < 1.0:
-                raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+                raise ValueError(
+                    f"Invalid beta parameter at index {i}: {betas[i]}"
+                )
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         super(Optimizer1State, self).__init__(params, defaults, optim_bits)
@@ -563,7 +592,9 @@ class Optimizer1State(Optimizer8bit):
         state = self.state[p]
         state["step"] = 0

-        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+        if dtype == torch.float32 or (
+            dtype == torch.uint8 and p.numel() < 4096
+        ):
             state["state1"] = torch.zeros_like(
                 p,
                 memory_format=torch.preserve_format,
@@ -574,7 +605,9 @@ class Optimizer1State(Optimizer8bit):
             if state["step"] == 0:
                 if "dynamic" not in self.name2qmap:
                     self.fill_qmap()
-                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
+                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+                    p.device
+                )

             state["state1"] = torch.zeros_like(
                 p,
@@ -593,7 +626,9 @@ class Optimizer1State(Optimizer8bit):
                     (blocks,), dtype=torch.float32, device=p.device
                 )
             else:
-                state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+                state["max1"] = torch.zeros(
+                    (1,), dtype=torch.float32, device=p.device
+                )
                 state["new_max1"] = torch.zeros(
                     (1,), dtype=torch.float32, device=p.device
                 )

View File

@@ -22,7 +22,9 @@ class RMSprop(Optimizer1State):
         block_wise=True,
     ):
         if alpha == 0:
-            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+            raise NotImplementedError(
+                f"RMSprop with alpha==0.0 is not supported!"
+            )
         if centered:
             raise NotImplementedError(f"Centered RMSprop is not supported!")
         super(RMSprop, self).__init__(
@@ -56,7 +58,9 @@ class RMSprop8bit(Optimizer1State):
         block_wise=True,
     ):
         if alpha == 0:
-            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+            raise NotImplementedError(
+                f"RMSprop with alpha==0.0 is not supported!"
+            )
         if centered:
             raise NotImplementedError(f"Centered RMSprop is not supported!")
         super(RMSprop8bit, self).__init__(
@@ -91,7 +95,9 @@ class RMSprop32bit(Optimizer1State):
     ):
         if alpha == 0:
-            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+            raise NotImplementedError(
+                f"RMSprop with alpha==0.0 is not supported!"
+            )
         if centered:
             raise NotImplementedError(f"Centered RMSprop is not supported!")
         super(RMSprop32bit, self).__init__(

View File

@@ -1,6 +1,6 @@
-import sys
 import shlex
 import subprocess
+import sys


 def execute_and_return(command_string: str) -> Tuple[str, str]:

View File

@@ -14,23 +14,31 @@ def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
             torch.int8
         )
     elif dims == 3:
-        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
-            torch.int8
-        )
-    B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
+        A = torch.randint(
+            -128, 127, size=(dim1, dim2, dim3), device="cuda"
+        ).to(torch.int8)
+    B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
+        torch.int8
+    )
     C1 = torch.matmul(A.float(), B.t().float())

     A2, SA = F.transform(A, "col32")
     B2, SB = F.transform(B, "colx")
     if dims == 2:
         C2, SC = F.transform(
-            torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"),
+            torch.zeros(
+                A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"
+            ),
             "col32",
         )
     else:
         C2, SC = F.transform(
             torch.zeros(
-                A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device="cuda"
+                A.shape[0],
+                A.shape[1],
+                B.shape[0],
+                dtype=torch.int32,
+                device="cuda",
             ),
             "col32",
         )

View File

@@ -18,9 +18,13 @@ req_grad_str = ["FF", "TF", "TT", "FT"]
 transpose = [(False, False), (False, True), (True, True), (True, False)]
 str_transpose = ["FF", "FT", "TT", "TF"]
 dtype = [torch.float32, torch.float16]
-values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
+values = list(
+    product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
+)
 str_values = list(
-    product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose)
+    product(
+        dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
+    )
 )
 names = [
     "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
@@ -31,7 +35,9 @@ names = [


 @pytest.mark.parametrize(
-    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names
+    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
+    values,
+    ids=names,
 )
 def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
     dim2 = dim2 - (dim2 % 16)
@@ -79,7 +85,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             A.grad = None
             B.grad = None

-            loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+            loss_torch = torch.nn.functional.mse_loss(
+                out_torch, target
+            ).mean()
             loss_torch.backward()
             gradA2 = A.grad
             gradB2 = B.grad
@@ -87,25 +95,35 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             B.grad = None

             if req_grad[0]:
-                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                torch.testing.assert_allclose(
+                    gradA1, gradA2, atol=0.015, rtol=0.1
+                )
             if req_grad[1]:
                 n = gradB1.numel()
                 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                 assert (idx == 0).sum().item() < n * 0.1
                 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                 assert (idx == 0).sum().item() < n * 0.02
-                torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+                torch.testing.assert_allclose(
+                    gradB1, gradB2, atol=0.18, rtol=0.3
+                )

     # batched matrix multiply
     if funcs[0] in [torch.bmm, torch.matmul]:
         A = torch.randn(
-            size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
+            size=(dim1, dim2, dim3),
+            device="cuda",
+            requires_grad=req_grad[0],
         )
         B = torch.randn(
-            size=(dim1, dim3, dim4), device="cuda", requires_grad=req_grad[1]
+            size=(dim1, dim3, dim4),
+            device="cuda",
+            requires_grad=req_grad[1],
         )
         target = torch.randn(
-            size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
+            size=(dim1, dim2, dim4),
+            device="cuda",
+            requires_grad=req_grad[1],
         )
         torch.nn.init.xavier_uniform_(B)
@@ -115,7 +133,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
         n = out_bnb.numel()
         idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
         assert (idx == 0).sum().item() < n * 0.01
-        torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)
+        torch.testing.assert_allclose(
+            out_bnb, out_torch, atol=0.027, rtol=0.2
+        )

         if any(req_grad):
             out_bnb.data.copy_(out_torch)
@@ -127,7 +147,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             A.grad = None
             B.grad = None

-            loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+            loss_torch = torch.nn.functional.mse_loss(
+                out_torch, target
+            ).mean()
             loss_torch.backward()
             gradA2 = A.grad
             gradB2 = B.grad
@@ -135,7 +157,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             B.grad = None

             if req_grad[0]:
-                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                torch.testing.assert_allclose(
+                    gradA1, gradA2, atol=0.015, rtol=0.1
+                )
             if req_grad[1]:
                 n = gradB1.numel()
                 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
@@ -146,12 +170,16 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
     if funcs[0] in [torch.matmul]:
         dim1 = dim1 - (dim1 % 16)
         A = torch.randn(
-            size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
+            size=(dim1, dim2, dim3),
+            device="cuda",
+            requires_grad=req_grad[0],
         )
         dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
         B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
         target = torch.randn(
-            size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
+            size=(dim1, dim2, dim4),
+            device="cuda",
+            requires_grad=req_grad[1],
         )
         torch.nn.init.xavier_uniform_(B)
@@ -178,7 +206,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             A.grad = None
             B.grad = None

-            loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+            loss_torch = torch.nn.functional.mse_loss(
+                out_torch, target
+            ).mean()
             loss_torch.backward()
             gradA2 = A.grad
             gradB2 = B.grad
@@ -186,7 +216,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             B.grad = None

             if req_grad[0]:
-                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                torch.testing.assert_allclose(
+                    gradA1, gradA2, atol=0.015, rtol=0.1
+                )
             if req_grad[1]:
                 n = gradB1.numel()
                 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
@@ -258,7 +290,16 @@ names = [
     ids=names,
 )
 def test_matmullt(
-    dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights
+    dim1,
+    dim2,
+    dim3,
+    dim4,
+    funcs,
+    dtype,
+    req_grad,
+    transpose,
+    decomp,
+    has_fp16_weights,
 ):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
@@ -278,7 +319,10 @@ def test_matmullt(
             size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
         )
         target = torch.randn(
-            size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype
+            size=(dim2, dim4),
+            device="cuda",
+            requires_grad=req_grad[1],
+            dtype=dtype,
         )
         torch.nn.init.xavier_uniform_(B)
         B2 = B.clone()
@@ -317,14 +361,18 @@ def test_matmullt(
         if any(req_grad):
             out_bnb.data.copy_(out_torch)
             torch.cuda.synchronize()
-            loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+            loss_bnb = torch.nn.functional.mse_loss(
+                out_bnb, target
+            ).mean()
             loss_bnb.backward()
             gradA1 = A.grad
             gradB1 = B.grad
             A.grad = None
             B.grad = None

-            loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+            loss_torch = torch.nn.functional.mse_loss(
+                out_torch, target
+            ).mean()
             loss_torch.backward()
             gradA2 = A.grad
             gradB2 = B.grad
@@ -332,7 +380,9 @@ def test_matmullt(
             B.grad = None

             if req_grad[0]:
-                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                torch.testing.assert_allclose(
+                    gradA1, gradA2, atol=0.015, rtol=0.1
+                )
             if req_grad[1]:
                 n = gradB1.numel()
                 assert torch.abs(gradB1).sum() > 0.0
@@ -341,4 +391,6 @@ def test_matmullt(
                 assert (idx == 0).sum().item() < n * 0.1
                 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                 assert (idx == 0).sum().item() < n * 0.02
-                torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+                torch.testing.assert_allclose(
+                    gradB1, gradB2, atol=0.18, rtol=0.3
+                )

View File

@@ -3,8 +3,12 @@ from typing import List, NamedTuple

 import pytest

-from bitsandbytes.cuda_setup import (CUDA_RUNTIME_LIB, evaluate_cuda_setup,
-                                     get_cuda_runtime_lib_path, tokenize_paths)
+from bitsandbytes.cuda_setup import (
+    CUDA_RUNTIME_LIB,
+    evaluate_cuda_setup,
+    get_cuda_runtime_lib_path,
+    tokenize_paths,
+)


 class InputAndExpectedOutput(NamedTuple):
@@ -13,11 +17,26 @@ class InputAndExpectedOutput(NamedTuple):


 HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
-    (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
-    (f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
-    (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:", f"dir/with/{CUDA_RUNTIME_LIB}"),
-    (f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
-    (f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir", f"dir/with/{CUDA_RUNTIME_LIB}"),
+    (
+        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
+        f"dir/with/{CUDA_RUNTIME_LIB}",
+    ),
+    (
+        f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
+        f"dir/with/{CUDA_RUNTIME_LIB}",
+    ),
+    (
+        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:",
+        f"dir/with/{CUDA_RUNTIME_LIB}",
+    ),
+    (
+        f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}",
+        f"dir/with/{CUDA_RUNTIME_LIB}",
+    ),
+    (
+        f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir",
+        f"dir/with/{CUDA_RUNTIME_LIB}",
+    ),
     (
         f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
         f"dir/with/{CUDA_RUNTIME_LIB}",

View File

@ -86,7 +86,9 @@ def teardown():
pass pass
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"]) @pytest.mark.parametrize(
"dtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_estimate_quantiles(dtype): def test_estimate_quantiles(dtype):
A = torch.rand(1024, 1024, device="cuda") A = torch.rand(1024, 1024, device="cuda")
A = A.to(dtype) A = A.to(dtype)
@ -190,7 +192,9 @@ def test_dynamic_blockwise_stochastic_quantization():
) )
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"]) @pytest.mark.parametrize(
"gtype", [torch.float32, torch.float16], ids=["float", "half"]
)
def test_percentile_clipping(gtype): def test_percentile_clipping(gtype):
gnorm_vec1 = torch.zeros(100, device="cuda") gnorm_vec1 = torch.zeros(100, device="cuda")
gnorm_vec2 = torch.zeros(100, device="cuda") gnorm_vec2 = torch.zeros(100, device="cuda")
@ -270,7 +274,13 @@ def mean(xx):
dim1 = [1024 * 2] dim1 = [1024 * 2]
dim2 = [1024 * 16] dim2 = [1024 * 16]
methods = [ methods = [
(lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant) (
lambda x, dim: quant(x),
lambda x, dim: quant(x),
dequant,
dequant,
mm_dequant,
)
] ]
methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant)) methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant)) # methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))
@ -279,11 +289,14 @@ batched = [False, True]
values = list(product(dim1, dim2, methods, batched)) values = list(product(dim1, dim2, methods, batched))
values_names = list(product(dim1, dim2, method_names, batched)) values_names = list(product(dim1, dim2, method_names, batched))
names = [ names = [
"dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals) for vals in values_names "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals)
for vals in values_names
] ]
@pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names) @pytest.mark.parametrize(
"dim1, dim2, quant_methods, batched", values, ids=names
)
def test_approx_igemm(dim1, dim2, quant_methods, batched): def test_approx_igemm(dim1, dim2, quant_methods, batched):
dim1 = dim1 - (dim1 % 32) dim1 = dim1 - (dim1 % 32)
dim2 = dim2 - (dim2 % 32) dim2 = dim2 - (dim2 % 32)
@ -339,14 +352,18 @@ names = [
] ]
@pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names) @pytest.mark.parametrize(
"hidden_dim, batch_dim, transpose, seq_dim", values, ids=names
)
def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
hidden_dim = hidden_dim - (hidden_dim % 32) hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 16) batch_dim = batch_dim - (batch_dim % 16)
seq_dim = seq_dim - (seq_dim % 16) seq_dim = seq_dim - (seq_dim % 16)
for i in range(k): for i in range(k):
shapeA = ( shapeA = (
(batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim) (batch_dim, hidden_dim)
if not transpose[0]
else (hidden_dim, batch_dim)
) )
shapeB = ( shapeB = (
(32 * random.randint(1, 4), hidden_dim) (32 * random.randint(1, 4), hidden_dim)
@ -394,7 +411,9 @@ seq_dim = torch.randint(32, 512, size=(n,)).tolist()
hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist() hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
batch_dim = torch.randint(2, 16, size=(n,)).tolist() batch_dim = torch.randint(2, 16, size=(n,)).tolist()
values = list(product(seq_dim, hidden_dim, batch_dim)) values = list(product(seq_dim, hidden_dim, batch_dim))
names = ["seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values] names = [
"seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values
]
@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names) @pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)
@ -406,11 +425,13 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
A = torch.randint( A = torch.randint(
-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda" -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
).to(torch.int8) ).to(torch.int8)
B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to( B = torch.randint(
torch.int8 -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
) ).to(torch.int8)
out2 = torch.einsum("bsi, bso->io", A.float(), B.float()) out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device) iout = torch.empty(
A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
)
out = F.igemm(A, B, out=iout) out = F.igemm(A, B, out=iout)
torch.testing.assert_allclose(out.float(), out2) torch.testing.assert_allclose(out.float(), out2)
@ -428,7 +449,9 @@ names = [
] ]
@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names) @pytest.mark.parametrize(
"seq_dim, hidden_dim, batch_dim, transpose", values, ids=names
)
def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
def min_max(x): def min_max(x):
maxA = torch.amax(x, dim=2, keepdim=True) maxA = torch.amax(x, dim=2, keepdim=True)
@ -444,7 +467,9 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
errs2 = [] errs2 = []
relerrs2 = [] relerrs2 = []
for i in range(k): for i in range(k):
A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda") A = torch.normal(
0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
)
if transpose: if transpose:
B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda") B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
else: else:
@ -504,7 +529,8 @@ dim4 = torch.randint(32, 256, size=(n,)).tolist()
transpose = [(False, False), (True, False), (False, True), (True, True)] transpose = [(False, False), (True, False), (False, True), (True, True)]
values = list(product(dim1, dim2, dim3, dim4, transpose)) values = list(product(dim1, dim2, dim3, dim4, transpose))
names = [ names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals) for vals in values "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals)
for vals in values
] ]
@ -529,7 +555,9 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float()) out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
out = F.igemm(A.permute([0, 2, 1]), B) out = F.igemm(A.permute([0, 2, 1]), B)
elif transpose[0] and transpose[1]: elif transpose[0] and transpose[1]:
out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()) out2 = torch.bmm(
A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
)
out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
torch.testing.assert_allclose(out.float(), out2.float()) torch.testing.assert_allclose(out.float(), out2.float())
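
For readers skimming hunk by hunk: the igemm tests above compare the raw int8 GEMM against a float matmul. A minimal usage sketch, assuming a CUDA device and an installed bitsandbytes; the shapes are illustrative and not taken from this commit:

import torch
import bitsandbytes.functional as F

# int8 inputs; igemm accumulates in int32, so values this small match a float matmul exactly
A = torch.randint(-128, 127, size=(64, 256), device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=(256, 128), device="cuda").to(torch.int8)

out = F.igemm(A, B)                        # (64, 128), int32
ref = torch.matmul(A.float(), B.float())
torch.testing.assert_allclose(out.float(), ref)
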
@ -563,7 +591,9 @@ a_order = ["row"]
out_order = ["col", "row", "col32"] out_order = ["col", "row", "col32"]
transpose = [False] transpose = [False]
dims = [2, 3] dims = [2, 3]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)) values = list(
product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)
names = [ names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format( "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(
@ -574,9 +604,13 @@ names = [
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
values,
ids=names,
) )
def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): def test_nvidia_transform(
dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose
):
if dims == 3 and out_order != "col32": if dims == 3 and out_order != "col32":
return return
if dtype == torch.int32 and out_order != "col32": if dtype == torch.int32 and out_order != "col32":
@ -586,7 +620,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
if dims == 2: if dims == 2:
A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype) A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
elif dims == 3: elif dims == 3:
A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype) A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
dtype
)
out, S = F.nvidia_transform(A, to_order=orderOut) out, S = F.nvidia_transform(A, to_order=orderOut)
@ -598,7 +634,11 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
if dims == 2: if dims == 2:
n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32))) n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
elif dims == 3: elif dims == 3:
n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32))) n = (
A.shape[0]
* A.shape[1]
* (A.shape[2] + (32 - (A.shape[2] % 32)))
)
assert out.numel() == n assert out.numel() == n
elif orderOut == "col_turing": elif orderOut == "col_turing":
# 32 col 8 row tiles # 32 col 8 row tiles
@ -613,7 +653,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
j = col j = col
coltile = (col // 32) + (1 if col % 32 != 0 else 0) coltile = (col // 32) + (1 if col % 32 != 0 else 0)
rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile rowtile = (
(row // 8) + (1 if row % 8 != 0 else 0)
) * total_coltile
offset = 32 * 8 * (rowtile + coltile) offset = 32 * 8 * (rowtile + coltile)
col2 = col % 32 col2 = col % 32
row2 = (row % 8) * 32 row2 = (row % 8) * 32
@ -624,7 +666,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
# torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
if orderOut == "col32": if orderOut == "col32":
out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S) out2, S = F.nvidia_transform(
out, from_order=orderOut, to_order="row", state=S
)
torch.testing.assert_allclose(A, out2) torch.testing.assert_allclose(A, out2)
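
The transform test above checks a layout round trip. A minimal sketch of that round trip, assuming CUDA; the 64x96 shape is illustrative:

import torch
import bitsandbytes.functional as F

A = torch.randint(-128, 127, size=(64, 96), device="cuda").to(torch.int8)

# row-major -> col32 tiled layout, then back to row-major using the returned state
A_col32, S = F.nvidia_transform(A, to_order="col32")
A_row, _ = F.nvidia_transform(A_col32, from_order="col32", to_order="row", state=S)
torch.testing.assert_allclose(A, A_row)
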
@ -657,10 +701,12 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
torch.int8 torch.int8
) )
elif dims == 3: elif dims == 3:
A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to( A = torch.randint(
torch.int8 -128, 127, size=(dim1, dim2, dim3), device="cuda"
) ).to(torch.int8)
B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8) B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
torch.int8
)
C1 = torch.matmul(A.float(), B.t().float()) C1 = torch.matmul(A.float(), B.t().float())
A2, SA = F.transform(A, "col32") A2, SA = F.transform(A, "col32")
@ -670,7 +716,9 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
torch.testing.assert_allclose(C1, C3.float()) torch.testing.assert_allclose(C1, C3.float())
# transpose # transpose
B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8) B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
torch.int8
)
C1 = torch.matmul(A.float(), B.float()) C1 = torch.matmul(A.float(), B.float())
B2t, SBt = F.transform(B, "col_turing", transpose=True) B2t, SBt = F.transform(B, "col_turing", transpose=True)
@ -688,7 +736,8 @@ dims = (2,)
# ldb = list(range(256, 1*1024, 256)) # ldb = list(range(256, 1*1024, 256))
values = list(product(dim1, dim2, dim3, dim4, dims)) values = list(product(dim1, dim2, dim3, dim4, dims))
names = [ names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals) for vals in values "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals)
for vals in values
] ]
@ -699,7 +748,9 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
if dims == 2: if dims == 2:
A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half() A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
elif dims == 3: elif dims == 3:
A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half() A = torch.normal(
0, 0.5, size=(dim1, dim2, dim3), device="cuda"
).half()
B = torch.randn((dim4, dim3), device="cuda").half() B = torch.randn((dim4, dim3), device="cuda").half()
torch.nn.init.xavier_uniform_(B) torch.nn.init.xavier_uniform_(B)
C1 = torch.matmul(A, B.t()) C1 = torch.matmul(A, B.t())
@ -742,7 +793,9 @@ values = [
# values = list(product(batch, seq, model, hidden)) # values = list(product(batch, seq, model, hidden))
names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values] names = [
"batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
@ -909,7 +962,9 @@ dims = (2,)
# ldb = list(range(256, 1*1024, 256)) # ldb = list(range(256, 1*1024, 256))
formatB = ["col_turing", "col_ampere"] formatB = ["col_turing", "col_ampere"]
values = list(product(dim1, dim4, dims, formatB)) values = list(product(dim1, dim4, dims, formatB))
names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values] names = [
"dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values
]
@pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names) @pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names)
@ -992,7 +1047,9 @@ def test_colrow_absmax(dim1, dim2, dims):
torch.testing.assert_allclose(row_stats1_trunc, row_stats2) torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2) torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0) row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
A, threshold=0.0
)
torch.testing.assert_allclose(col_stats1, col_stats2) torch.testing.assert_allclose(col_stats1, col_stats2)
torch.testing.assert_allclose(row_stats1, row_stats2) torch.testing.assert_allclose(row_stats1, row_stats2)
@ -1023,8 +1080,12 @@ def test_double_quant(dim1, dim2):
torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0) torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)
n = CAt.numel() n = CAt.numel()
num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item() num_not_close_rows = (
num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item() (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
)
num_not_close_cols = (
(torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
)
# allow for 1:500 error due to rounding differences # allow for 1:500 error due to rounding differences
min_error = 1 / 500 min_error = 1 / 500
@ -1123,7 +1184,9 @@ def test_igemmlt_row_scale(dim1, dim4, inner):
c = 10.0 * inner * scale c = 10.0 * inner * scale
row_scale = torch.ones_like(maxA) / c row_scale = torch.ones_like(maxA) / c
outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) outC32, SC = F.igemmlt(
A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
)
C3, S = F.nvidia_transform(outC32, "row", state=SC) C3, S = F.nvidia_transform(outC32, "row", state=SC)
maxval = torch.abs(C3).max() maxval = torch.abs(C3).max()
if maxval == 127: if maxval == 127:
@ -1204,7 +1267,9 @@ def test_row_scale_bench(dim1, dim4, inner):
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
for i in range(k): for i in range(k):
outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale) outC32, SC = F.igemmlt(
A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
)
torch.cuda.synchronize() torch.cuda.synchronize()
print("row-wise", time.time() - t0) print("row-wise", time.time() - t0)
@ -1230,7 +1295,9 @@ a_order = ["row"]
out_order = ["col32", "col_turing", "col_ampere"] out_order = ["col32", "col_turing", "col_ampere"]
transpose = [False, True] transpose = [False, True]
dims = [2] dims = [2]
values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)) values = list(
product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
)
names = [ names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format( "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
*vals *vals
@ -1240,14 +1307,20 @@ names = [
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
values,
ids=names,
) )
def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
for i in range(k): for i in range(k):
if dims == 2: if dims == 2:
A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype) A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
dtype
)
elif dims == 3: elif dims == 3:
A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype) A = torch.randint(
10, 99, size=(dim1, dim2, dim3), device="cuda"
).to(dtype)
A.view(-1)[-1] = -1 A.view(-1)[-1] = -1
if transpose: if transpose:
@ -1282,7 +1355,9 @@ names = [
] ]
@pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names) @pytest.mark.parametrize(
"dim1, dim2, dtype, orderA, orderOut", values, ids=names
)
def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut): def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
for i in range(1): for i in range(1):
A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype) A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype)
@ -1332,17 +1407,23 @@ def test_coo_double_quant(dim1, dim2):
idx = torch.abs(A) >= threshold idx = torch.abs(A) >= threshold
CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
A, threshold=threshold
)
if coo_tensor is not None: if coo_tensor is not None:
A1 = A * idx A1 = A * idx
A2 = torch.zeros_like(A) A2 = torch.zeros_like(A)
A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values A2[
coo_tensor.rowidx.long(), coo_tensor.colidx.long()
] = coo_tensor.values
torch.testing.assert_allclose(A1, A2) torch.testing.assert_allclose(A1, A2)
A1 = A * (idx == 0) A1 = A * (idx == 0)
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
torch.testing.assert_allclose(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) torch.testing.assert_allclose(
A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
)
n = 2 n = 2
@ -1454,7 +1535,9 @@ def test_integrated_sparse_decomp(dim1, dim2):
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
A, threshold=threshold
)
C32A, SA = F.transform(CA, "col32") C32A, SA = F.transform(CA, "col32")
out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
@ -1494,7 +1577,9 @@ dim2 = [12288]
dtype = [torch.float16] dtype = [torch.float16]
out_function = ["zeros", "ones"] out_function = ["zeros", "ones"]
values = list(product(dim1, dim2, dtype, out_function)) values = list(product(dim1, dim2, dtype, out_function))
names = ["dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values] names = [
"dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values
]
@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names) @pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)
@ -1536,7 +1621,9 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
std = out1.std() std = out1.std()
out1 /= std out1 /= std
out2 /= std out2 /= std
assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count) assert_all_approx_close(
out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
)
# assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count) # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)
idx_col = torch.randint(0, A2.shape[-1], size=(15,)) idx_col = torch.randint(0, A2.shape[-1], size=(15,))
@ -1734,7 +1821,9 @@ values.append((batch_size, seqdim, 768, 4 * 768))
# values.append((batch_size, seqdim, 4096, 4*4096)) # values.append((batch_size, seqdim, 4096, 4*4096))
# values.append((batch_size, seqdim, 5140, 4*5140)) # values.append((batch_size, seqdim, 5140, 4*5140))
# values.append((batch_size, seqdim, 12288, 4*12288)) # values.append((batch_size, seqdim, 12288, 4*12288))
names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values] names = [
"batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
]
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)


@ -48,7 +48,9 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
class LinearFunction(torch.autograd.Function): class LinearFunction(torch.autograd.Function):
@staticmethod @staticmethod
def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0): def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round round_func = (
LinearFunction.round_stoachastic if stochastic else torch.round
)
norm = math.sqrt(math.pi) / math.sqrt(2.0) norm = math.sqrt(math.pi) / math.sqrt(2.0)
# std = torch.abs(x).mean()*norm # std = torch.abs(x).mean()*norm
std = torch.std(x) std = torch.std(x)
@ -116,7 +118,9 @@ class LinearFunction(torch.autograd.Function):
return x.to(dtype) return x.to(dtype)
def get_8bit_linear(x, stochastic=False): def get_8bit_linear(x, stochastic=False):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round round_func = (
LinearFunction.round_stoachastic if stochastic else torch.round
)
max1 = torch.abs(x).max() max1 = torch.abs(x).max()
x = x / max1 * 127 x = x / max1 * 127
x = round_func(x) / 127 * max1 x = round_func(x) / 127 * max1
@ -125,7 +129,9 @@ class LinearFunction(torch.autograd.Function):
@staticmethod @staticmethod
def get_8bit_vector_wise(x, dim, stochastic=False): def get_8bit_vector_wise(x, dim, stochastic=False):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round round_func = (
LinearFunction.round_stoachastic if stochastic else torch.round
)
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
max1[max1 == 0] = 1.0 max1[max1 == 0] = 1.0
x = (x * 127) / max1 x = (x * 127) / max1
@ -209,7 +215,9 @@ class LinearFunction(torch.autograd.Function):
weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1) weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
outputq = bnb.functional.igemm(x8, weight8.t()) outputq = bnb.functional.igemm(x8, weight8.t())
output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type) output = LinearFunction.dequant(
outputq, S1, S2, x.dtype, args.quant_type
)
# if torch.rand(1) < 0.01: # if torch.rand(1) < 0.01:
# output32 = torch.matmul(x, weight.t()) # output32 = torch.matmul(x, weight.t())
# err = torch.abs(output-output32).float() # err = torch.abs(output-output32).float()
@ -261,7 +269,9 @@ class LinearFunction(torch.autograd.Function):
grad_weight8, S1, S2, grad_output.dtype, args.quant_type grad_weight8, S1, S2, grad_output.dtype, args.quant_type
) )
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2) grad_output8, S1 = LinearFunction.quant(
grad_output, args.quant_type, dim=2
)
weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0) weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
grad_input8 = bnb.functional.igemm(grad_output8, weight8) grad_input8 = bnb.functional.igemm(grad_output8, weight8)
grad_input = LinearFunction.dequant( grad_input = LinearFunction.dequant(
@ -338,8 +348,12 @@ def test_linear8bit():
loss2.backward() loss2.backward()
loss3.backward() loss3.backward()
assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2) assert_all_approx_close(
assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2) l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
)
assert_all_approx_close(
l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
)
assert_all_approx_close( assert_all_approx_close(
l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2 l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
) )
@ -388,7 +402,9 @@ def test_linear8bitlt_accumulated_gradient():
l1 = torch.nn.Sequential( l1 = torch.nn.Sequential(
*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)] *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
) )
l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]) l2 = torch.nn.Sequential(
*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
)
l2[0].weight = torch.nn.Parameter(l1[0].weight.clone()) l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
l2[0].bias = torch.nn.Parameter(l1[0].bias.clone()) l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
l2[1].weight = torch.nn.Parameter(l1[1].weight.clone()) l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
@ -462,7 +478,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
if threshold > 0: if threshold > 0:
assert mlp.fc2.state.idx is not None assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half() mlp = (
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
.cuda()
.half()
)
assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
@ -475,7 +495,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
if threshold > 0: if threshold > 0:
assert mlp.fc2.state.idx is not None assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda() mlp = (
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
.half()
.cuda()
)
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half() b1 = torch.randn(16, 8, 32, device="cuda").half()
@ -488,7 +512,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to("cuda") mlp = (
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
.half()
.to("cuda")
)
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device="cuda").half() b1 = torch.randn(16, 8, 32, device="cuda").half()
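
The module tests above drive bnb.nn.Linear8bitLt with fp16 inputs and int8 weights. A minimal usage sketch, assuming CUDA; the sizes and the threshold are illustrative:

import torch
import bitsandbytes as bnb

# with has_fp16_weights=False the weight is stored as int8 once moved to the GPU
layer = (
    bnb.nn.Linear8bitLt(32, 64, threshold=6.0, has_fp16_weights=False)
    .cuda()
    .half()
)
assert layer.weight.dtype == torch.int8

x = torch.randn(16, 8, 32, device="cuda").half()
y = layer(x)  # fp16 activations, int8 weight; outliers above threshold run in fp16
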


@ -103,20 +103,26 @@ str2statenames["adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"), ("exp_avg_sq", "state2", "qmap2", "absmax2"),
] ]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] str2statenames["momentum8bit"] = [
("momentum_buffer", "state1", "qmap1", "max1")
]
str2statenames["momentum8bit_blockwise"] = [ str2statenames["momentum8bit_blockwise"] = [
("momentum_buffer", "state1", "qmap1", "absmax1") ("momentum_buffer", "state1", "qmap1", "absmax1")
] ]
str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")] str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")] str2statenames["rmsprop8bit_blockwise"] = [
("square_avg", "state1", "qmap1", "absmax1")
]
dim1 = [1024] dim1 = [1024]
dim2 = [32, 1024, 4097, 1] dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16] gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lamb"] optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lamb"]
values = list(product(dim1, dim2, gtype, optimizer_names)) values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values] names = [
"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@ -203,9 +209,13 @@ def test_global_config(dim1, dim2, gtype):
eps = 1e-8 eps = 1e-8
bnb.optim.GlobalOptimManager.get_instance().initialize() bnb.optim.GlobalOptimManager.get_instance().initialize()
bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8) bnb.optim.GlobalOptimManager.get_instance().override_config(
p3, "optim_bits", 8
)
bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) bnb.optim.GlobalOptimManager.get_instance().register_parameters(
[p1, p2, p3]
)
p1 = p1.cuda() p1 = p1.cuda()
p2 = p2.cuda() p2 = p2.cuda()
p3 = p3.cuda() p3 = p3.cuda()
@ -245,7 +255,9 @@ optimizer_names = [
"rmsprop8bit_blockwise", "rmsprop8bit_blockwise",
] ]
values = list(product(dim1, dim2, gtype, optimizer_names)) values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values] names = [
"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@ -329,8 +341,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path) rm_path(path)
torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2]) torch.testing.assert_allclose(
torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap]) raws1cpy, bnb_optimizer.state[p2][name2]
)
torch.testing.assert_allclose(
qmap1, bnb_optimizer.state[p2][qmap]
)
if "blockwise" in optim_name: if "blockwise" in optim_name:
s1 = F.dequantize_blockwise( s1 = F.dequantize_blockwise(
@ -349,12 +365,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
num_not_close = ( num_not_close = (
torch.isclose( torch.isclose(
torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol torch_optimizer.state[p1][name1],
s1,
atol=atol,
rtol=rtol,
) )
== 0 == 0
) )
assert num_not_close.sum().item() < 20 assert num_not_close.sum().item() < 20
torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol) torch.testing.assert_allclose(
p1, p2.float(), atol=patol, rtol=prtol
)
# the parameters diverge quickly. Here we keep them close # the parameters diverge quickly. Here we keep them close
# together so we can test against the Adam error # together so we can test against the Adam error
@ -375,7 +396,10 @@ dim2 = [32, 1024, 4097]
gtype = [torch.float32] gtype = [torch.float32]
optim_bits = [32, 8] optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits)) values = list(product(dim1, dim2, gtype, optim_bits))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) for vals in values] names = [
"dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals)
for vals in values
]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
@ -391,7 +415,12 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
p2 = p1.clone() p2 = p1.clone()
adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits) adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
adam2 = bnb.optim.Adam( adam2 = bnb.optim.Adam(
[p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5 [p2],
lr,
(beta1, beta2),
eps,
optim_bits=optim_bits,
percentile_clipping=5,
) )
gnorm_vec = torch.zeros(100).cuda() gnorm_vec = torch.zeros(100).cuda()
@ -399,7 +428,9 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
for i in range(50): for i in range(50):
step += 1 step += 1
g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i) g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (
0.01 * i
)
g2 = g1.clone() g2 = g1.clone()
p2.grad = g2 p2.grad = g2
@ -430,10 +461,16 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
elif optim_bits == 8: elif optim_bits == 8:
torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3) torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
torch.testing.assert_allclose( torch.testing.assert_allclose(
adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3 adam1.state[p1]["state1"],
adam2.state[p2]["state1"],
atol=2,
rtol=1e-3,
) )
torch.testing.assert_allclose( torch.testing.assert_allclose(
adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, rtol=1e-3 adam1.state[p1]["state2"],
adam2.state[p2]["state2"],
atol=2,
rtol=1e-3,
) )
adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"]) adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"]) adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
@ -463,7 +500,9 @@ gtype = [torch.float32, torch.float16]
# optimizer_names = ['lars_apex', 'lars8bit'] # optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise"] optimizer_names = ["adam8bit_blockwise"]
values = list(product(dim1, dim2, gtype, optimizer_names)) values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values] names = [
"dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)