forked from mrq/bitsandbytes-rocm

Reran black with line length 80 for greater readability.

This commit is contained in:
parent 3fd06fb620
commit ea7c14f8ef
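For reference, a formatting pass like the one described in the commit message can be reproduced roughly as sketched below. The exact targets are an assumption; only the tool (black) and the line length of 80 come from the commit message itself.

```python
# Hypothetical reproduction of the reformatting pass. The target paths are
# assumptions; only "black" and --line-length 80 are stated in the commit.
import subprocess

subprocess.run(
    ["black", "--line-length", "80", "bitsandbytes", "tests", "quicktest.py"],
    check=True,
)
```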
@@ -3,8 +3,13 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from .autograd._functions import (MatmulLtState, bmm_cublas, matmul,
-                                  matmul_cublas, mm_cublas)
+from .autograd._functions import (
+    MatmulLtState,
+    bmm_cublas,
+    matmul,
+    matmul_cublas,
+    mm_cublas,
+)
 from .cextension import COMPILED_WITH_CUDA
 from .nn import modules
@@ -111,7 +111,9 @@ class MatMul8bit(torch.autograd.Function):
             qgrad_output, S1 = F.vectorwise_quant(
                 grad_output, dim=dims, quant_type=quant_type
             )
-            qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
+            qA, S2 = F.vectorwise_quant(
+                A, dim=dims, quant_type=quant_type
+            )
             igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
             grad_B = F.vectorwise_mm_dequant(
                 igrad_B,
@@ -146,7 +148,11 @@ class MatMul8bit(torch.autograd.Function):
             qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
             igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
             grad_A = F.vectorwise_mm_dequant(
-                igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type
+                igrad_A,
+                S1,
+                S3.permute(permute_dim),
+                grad_output.dtype,
+                quant_type,
             )
 
         return grad_A, grad_B, None, None, None
@@ -211,7 +217,9 @@ class MatMul8bitLt(torch.autograd.Function):
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
-        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)
+        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
+            A, threshold=state.threshold
+        )
 
         if state.threshold > 0.0 and coo_tensorA is not None:
             if state.has_fp16_weights:
@@ -225,7 +233,9 @@ class MatMul8bitLt(torch.autograd.Function):
             if state.CxB is None:
                 # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                 # we also need to convert it to the turing/ampere format
-                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+                state.CxB, state.SB = F.transform(
+                    state.CB, to_order=formatB
+                )
             # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
             # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
             #     # generate outlier index and subB
@@ -259,7 +269,13 @@ class MatMul8bitLt(torch.autograd.Function):
 
             if (state.is_training and not has_grad) or state.CxB is None:
                 state.reset_grads()
-                CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B)
+                (
+                    CB,
+                    state.CBt,
+                    state.SCB,
+                    state.SCBt,
+                    coo_tensorB,
+                ) = F.double_quant(B)
                 state.CxB, state.SB = F.transform(CB, to_order=formatB)
             else:
                 has_grad = False
@@ -277,7 +293,10 @@ class MatMul8bitLt(torch.autograd.Function):
                 # state.idx = outlier_idx
                 outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
                 state.subB = (
-                    (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
+                    (outliers * state.SCB.view(-1, 1) / 127.0)
+                    .t()
+                    .contiguous()
+                    .half()
                 )
                 CA[:, state.idx.long()] = 0
                 CAt[:, state.idx.long()] = 0
@@ -325,10 +344,14 @@ class MatMul8bitLt(torch.autograd.Function):
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
         state = ctx.state
-        assert state.has_fp16_weights, "Backprop only supported for fp16 weights."
+        assert (
+            state.has_fp16_weights
+        ), "Backprop only supported for fp16 weights."
 
         if len(grad_output.shape) == 3:
-            grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
+            grad_output = grad_output.view(
+                -1, grad_output.shape[-1]
+            ).contiguous()
 
         grad_A = grad_B = None
 
@@ -359,7 +382,11 @@ matmul = MatMul8bitLt.apply
 
 
 def matmul(
-    A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0
+    A: tensor,
+    B: tensor,
+    out: tensor = None,
+    state: MatmulLtState = None,
+    threshold=0.0,
 ):
     state = state or MatmulLtState()
     if threshold > 0.0:
@@ -1,7 +1,7 @@
 """
-build is dependent on
-- compute capability
-  - dependent on GPU family
+extract factors the build is dependent on:
+[X] compute capability
+  [ ] TODO: Q - What if we have multiple GPUs of different makes?
 - CUDA version
 - Software:
   - CPU-only: only CPU quantization functions (no optimizer, no matrix multipl)
@@ -19,6 +19,8 @@ evaluation:
 """
 
 import ctypes
+import shlex
+import subprocess
 from os import environ as env
 from pathlib import Path
 from typing import Set, Union
@@ -26,10 +28,31 @@ from typing import Set, Union
 from .utils import print_err, warn_of_missing_prerequisite
 
 
+def execute_and_return(command_string: str) -> Tuple[str, str]:
+    def _decode(subprocess_err_out_tuple):
+        return tuple(
+            to_decode.decode("UTF-8").strip()
+            for to_decode in subprocess_err_out_tuple
+        )
+
+    def execute_and_return_decoded_std_streams(command_string):
+        return _decode(
+            subprocess.Popen(
+                shlex.split(command_string),
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE,
+            ).communicate()
+        )
+
+    std_out, std_err = execute_and_return_decoded_std_streams()
+    return std_out, std_err
+
+
 def check_cuda_result(cuda, result_val):
     if result_val != 0:
+        # TODO: undefined name 'error_str'
         cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
-        print(f"Count not initialize CUDA - failure!")
+        print("Count not initialize CUDA - failure!")
         raise Exception("CUDA exception!")
     return result_val
 
@@ -53,7 +76,9 @@ def get_compute_capability():
 
     result = ctypes.c_int()
     device = ctypes.c_int()
+    # TODO: local variable 'context' is assigned to but never used
     context = ctypes.c_void_p()
+    # TODO: local variable 'error_str' is assigned to but never used
     error_str = ctypes.c_char_p()
 
     result = check_cuda_result(cuda, cuda.cuInit(0))
@@ -61,7 +86,9 @@ def get_compute_capability():
     result = check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
     ccs = []
     for i in range(nGpus.value):
-        result = check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
+        result = check_cuda_result(
+            cuda, cuda.cuDeviceGet(ctypes.byref(device), i)
+        )
         result = check_cuda_result(
             cuda,
             cuda.cuDeviceComputeCapability(
@@ -114,11 +141,15 @@ def get_cuda_runtime_lib_path(
     } - non_existent_directories
 
     if len(cuda_runtime_libs) > 1:
-        err_msg = f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+        err_msg = (
+            f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+        )
         raise FileNotFoundError(err_msg)
 
     elif len(cuda_runtime_libs) < 1:
-        err_msg = f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+        err_msg = (
+            f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
+        )
         raise FileNotFoundError(err_msg)
 
     single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs))
@@ -17,14 +17,29 @@ if COMPILED_WITH_CUDA:
     """C FUNCTIONS FOR OPTIMIZERS"""
     str2optimizer32bit = {}
    str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
-    str2optimizer32bit["momentum"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
-    str2optimizer32bit["rmsprop"] = (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16)
-    str2optimizer32bit["adagrad"] = (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16)
-    str2optimizer32bit["lars"] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
+    str2optimizer32bit["momentum"] = (
+        lib.cmomentum32bit_g32,
+        lib.cmomentum32bit_g16,
+    )
+    str2optimizer32bit["rmsprop"] = (
+        lib.crmsprop32bit_g32,
+        lib.crmsprop32bit_g16,
+    )
+    str2optimizer32bit["adagrad"] = (
+        lib.cadagrad32bit_g32,
+        lib.cadagrad32bit_g16,
+    )
+    str2optimizer32bit["lars"] = (
+        lib.cmomentum32bit_g32,
+        lib.cmomentum32bit_g16,
+    )
     str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
 
     str2optimizer8bit = {}
-    str2optimizer8bit["adam"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+    str2optimizer8bit["adam"] = (
+        lib.cadam_static_8bit_g32,
+        lib.cadam_static_8bit_g16,
+    )
     str2optimizer8bit["momentum"] = (
         lib.cmomentum_static_8bit_g32,
         lib.cmomentum_static_8bit_g16,
@@ -33,7 +48,10 @@ if COMPILED_WITH_CUDA:
         lib.crmsprop_static_8bit_g32,
         lib.crmsprop_static_8bit_g16,
     )
-    str2optimizer8bit["lamb"] = (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16)
+    str2optimizer8bit["lamb"] = (
+        lib.cadam_static_8bit_g32,
+        lib.cadam_static_8bit_g16,
+    )
     str2optimizer8bit["lars"] = (
         lib.cmomentum_static_8bit_g32,
         lib.cmomentum_static_8bit_g16,
@@ -137,7 +155,9 @@ def create_dynamic_map(signed=True, n=7):
     if not signed:
         additional_items = 2 * additional_items
     for i in range(n):
-        fraction_items = 2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
+        fraction_items = (
+            2 ** (i + 7 - n) + 1 if signed else 2 ** (i + 7 - n + 1) + 1
+        )
         boundaries = torch.linspace(0.1, 1, fraction_items)
         means = (boundaries[:-1] + boundaries[1:]) / 2.0
         data += ((10 ** (-(n - 1) + i)) * means).tolist()
@@ -272,7 +292,13 @@ def get_transform_buffer(
 
 
 def nvidia_transform(
-    A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
+    A,
+    to_order,
+    from_order="row",
+    out=None,
+    transpose=False,
+    state=None,
+    ld=None,
 ):
     if state is None:
         state = (A.shape, from_order)
@@ -352,7 +378,11 @@ def estimate_quantiles(
 
 
 def quantize_blockwise(
-    A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None
+    A: Tensor,
+    code: Tensor = None,
+    absmax: Tensor = None,
+    rand=None,
+    out: Tensor = None,
 ) -> Tensor:
     """
     Quantize tensor A in blocks of size 4096 values.
@@ -629,7 +659,9 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
     """
     if out is None:
         out = torch.zeros_like(A, dtype=torch.float32)
-    lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
+    lib.cdequantize(
+        get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())
+    )
     return out
 
 
@@ -1005,7 +1037,9 @@ def histogram_scatter_add_2d(
     )
 
 
-def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
+def check_matmul(
+    A, B, out, transposed_A, transposed_B, expected_type=torch.int8
+):
     if not torch.cuda.is_initialized():
         torch.cuda.init()
     if A.dtype != expected_type or B.dtype != expected_type:
@@ -1097,7 +1131,11 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
 
 
 def igemm(
-    A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
+    A: Tensor,
+    B: Tensor,
+    out: Tensor = None,
+    transposed_A=False,
+    transposed_B=False,
 ):
     sout = check_matmul(A, B, out, transposed_A, transposed_B)
     if out is None:
@@ -1193,7 +1231,11 @@ def igemm(
 
 
 def batched_igemm(
-    A: Tensor, B: Tensor, out: Tensor = None, transposed_A=False, transposed_B=False
+    A: Tensor,
+    B: Tensor,
+    out: Tensor = None,
+    transposed_A=False,
+    transposed_B=False,
 ):
     if not len(A.shape) == 3 or not len(B.shape) == 3:
         raise ValueError(
@@ -1392,9 +1434,13 @@ def mm_dequant(
     if out is None:
         out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
     if new_row_stats is None:
-        new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
+        new_row_stats = torch.empty(
+            out_shape[0], dtype=torch.float32, device=A.device
+        )
     if new_col_stats is None:
-        new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
+        new_col_stats = torch.empty(
+            out_shape[1], dtype=torch.float32, device=A.device
+        )
     assert (
         new_row_stats.shape[0] == row_stats.shape[0]
     ), f"{new_row_stats.shape} vs {row_stats.shape}"
@@ -1440,13 +1486,13 @@ def get_colrow_absmax(
     col_tiles = (cols + 255) // 256
     tiled_rows = ((rows + 15) // 16) * 16
     if row_stats is None:
-        row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(
-            -50000.0
-        )
+        row_stats = torch.empty(
+            (rows,), dtype=torch.float32, device=device
+        ).fill_(-50000.0)
     if col_stats is None:
-        col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(
-            -50000.0
-        )
+        col_stats = torch.empty(
+            (cols,), dtype=torch.float32, device=device
+        ).fill_(-50000.0)
 
     if nnz_block_ptr is None and threshold > 0.0:
         nnz_block_ptr = torch.zeros(
@@ -1462,7 +1508,13 @@ def get_colrow_absmax(
 
     prev_device = pre_call(A.device)
     lib.cget_col_row_stats(
-        ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols
+        ptrA,
+        ptrRowStats,
+        ptrColStats,
+        ptrNnzrows,
+        ct.c_float(threshold),
+        rows,
+        cols,
     )
     post_call(prev_device)
 
@@ -1526,7 +1578,9 @@ class CSCSparseTensor(object):
 def coo2csr(cooA):
     values, counts = torch.unique(cooA.rowidx, return_counts=True)
     values.add_(1)
-    rowptr = torch.zeros((cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device)
+    rowptr = torch.zeros(
+        (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device
+    )
     rowptr.scatter_(index=values.long(), src=counts.int(), dim=0)
     rowptr.cumsum_(0)
     return CSRSparseTensor(
@@ -1540,10 +1594,14 @@ def coo2csc(cooA):
     values = cooA.values[col2rowidx]
     colvalues, counts = torch.unique(val, return_counts=True)
     colvalues.add_(1)
-    colptr = torch.zeros((cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device)
+    colptr = torch.zeros(
+        (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device
+    )
     colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0)
     colptr.cumsum_(0)
-    return CSCSparseTensor(cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values)
+    return CSCSparseTensor(
+        cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values
+    )
 
 
 def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
@@ -1568,7 +1626,9 @@ def double_quant(
     rows = A.shape[0]
 
     if row_stats is None or col_stats is None:
-        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)
+        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(
+            A, threshold=threshold
+        )
 
     if out_col is None:
         out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
@@ -1663,7 +1723,13 @@ def get_special_format_str():
 
 
 def transform(
-    A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None
+    A,
+    to_order,
+    from_order="row",
+    out=None,
+    transpose=False,
+    state=None,
+    ld=None,
 ):
     if state is None:
         state = (A.shape, from_order)
@@ -1716,7 +1782,9 @@ def transform(
 
 def spmm_coo(cooA, B, out=None):
     if out is None:
-        out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
+        out = torch.empty(
+            (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype
+        )
     nnz = cooA.nnz
     assert cooA.rowidx.numel() == nnz
     assert cooA.colidx.numel() == nnz
@@ -1982,7 +2050,9 @@ def extract_outliers(A, SA, idx):
     assert formatA in ["col_turing", "col_ampere"]
     assert A.device.type == "cuda"
 
-    out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device)
+    out = torch.zeros(
+        (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device
+    )
 
     idx_size = ct.c_int32(idx.numel())
     rows = ct.c_int32(shapeA[0])
@@ -2,8 +2,19 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Set,
-                    Tuple, TypeVar, Union, overload)
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+    overload,
+)
 
 import torch
 import torch.nn.functional as F
@@ -131,7 +142,12 @@ class Embedding(torch.nn.Embedding):
 
 class Int8Params(torch.nn.Parameter):
     def __new__(
-        cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None
+        cls,
+        data=None,
+        requires_grad=True,
+        has_fp16_weights=False,
+        CB=None,
+        SCB=None,
     ):
         cls.has_fp16_weights = has_fp16_weights
         cls.CB = None
@@ -186,7 +202,9 @@ class Int8Params(torch.nn.Parameter):
             return self.cuda(device)
         else:
             new_param = Int8Params(
-                super().to(device=device, dtype=dtype, non_blocking=non_blocking),
+                super().to(
+                    device=device, dtype=dtype, non_blocking=non_blocking
+                ),
                 requires_grad=self.requires_grad,
                 has_fp16_weights=self.has_fp16_weights,
             )
@@ -206,7 +224,9 @@ class Linear8bitLt(nn.Linear):
         threshold=0.0,
         index=None,
     ):
-        super(Linear8bitLt, self).__init__(input_features, output_features, bias)
+        super(Linear8bitLt, self).__init__(
+            input_features, output_features, bias
+        )
         self.state = bnb.MatmulLtState()
         self.index = index
 
@@ -215,7 +235,9 @@ class Linear8bitLt(nn.Linear):
         if threshold > 0.0 and not has_fp16_weights:
             self.state.use_pool = True
 
-        self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
+        self.weight = Int8Params(
+            self.weight.data, has_fp16_weights=has_fp16_weights
+        )
 
     def init_8bit_state(self):
         self.state.CB = self.weight.CB
@@ -23,7 +23,9 @@ class Adagrad(Optimizer1State):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if initial_accumulator_value != 0.0:
@@ -63,7 +65,9 @@ class Adagrad8bit(Optimizer1State):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if initial_accumulator_value != 0.0:
@@ -104,7 +108,9 @@ class Adagrad32bit(Optimizer1State):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
         if initial_accumulator_value != 0.0:
@@ -140,7 +140,11 @@ class AnalysisAdam(torch.optim.Optimizer):
         savedir=None,
     ):
         defaults = dict(
-            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            amsgrad=amsgrad,
         )
         super(AnalysisAdam, self).__init__(params, defaults)
         self.analysis = bnb_analysis
@@ -198,7 +202,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                 state["relerrors"] = torch.zeros(
                     (256, 256), device=p_data_fp32.device
                 )
-                state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
+                state["counts"] = torch.zeros(
+                    (256, 256), device=p_data_fp32.device
+                )
                 if amsgrad:
                     # Maintains max of all exp. moving avg. of sq. grad. values
                     state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
@@ -214,7 +220,9 @@ class AnalysisAdam(torch.optim.Optimizer):
             beta1, beta2 = group["betas"]
             bias_correction1 = 1 - beta1 ** state["step"]
             bias_correction2 = 1 - beta2 ** state["step"]
-            step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+            step_size = (
+                group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+            )
             e = state["abserrors"]
             rele = state["relerrors"]
             counts = state["counts"]
@@ -235,7 +243,10 @@ class AnalysisAdam(torch.optim.Optimizer):
             denom = exp_avg_sq.sqrt().add_(group["eps"])
             update_fp32 = exp_avg / denom
 
-            if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
+            if (
+                p_data_fp32.numel() <= 8192
+                or p_data_fp32.numel() > 50000 * 1000
+            ):
                 # embedding layer or too small
                 p_data_fp32 += -step_size * update_fp32
             else:
@@ -274,7 +285,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                     # 3. dequantize
                     # Error will be calculated automatically!
                 else:
-                    raise ValueError(f"Invalid analysis value: {self.analysis}!")
+                    raise ValueError(
+                        f"Invalid analysis value: {self.analysis}!"
+                    )
 
                 denom = state2.sqrt().add_(group["eps"])
                 update_8bit = state1 / denom
@@ -296,7 +309,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                 if self.savedir != "" and state["step"] % 100 == 0:
                     if not os.path.exists(self.savedir):
                         os.makedirs(self.savedir)
-                    shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
+                    shapestr = "_".join(
+                        [str(dim) for dim in p_data_fp32.shape]
+                    )
                     pathe = os.path.join(
                         self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
                     )
@@ -24,7 +24,9 @@ class LARS(Optimizer1State):
         max_unorm=0.02,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"LARS without momentum is not supported!")
+            raise NotImplementedError(
+                f"LARS without momentum is not supported!"
+            )
         super(LARS, self).__init__(
             "lars",
             params,
@@ -56,7 +58,9 @@ class LARS8bit(Optimizer1State):
         max_unorm=0.02,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"LARS without momentum is not supported!")
+            raise NotImplementedError(
+                f"LARS without momentum is not supported!"
+            )
         super(LARS8bit, self).__init__(
             "lars",
             params,
@@ -88,7 +92,9 @@ class LARS32bit(Optimizer1State):
         max_unorm=0.02,
     ):
         if momentum == 0:
-            raise NotImplementedError(f"LARS without momentum is not supported!")
+            raise NotImplementedError(
+                f"LARS without momentum is not supported!"
+            )
         super(LARS32bit, self).__init__(
             "lars",
             params,
@@ -121,7 +127,9 @@ class PytorchLARS(Optimizer):
         if momentum < 0.0:
             raise ValueError("Invalid momentum value: {}".format(momentum))
         if weight_decay < 0.0:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
 
         defaults = dict(
             lr=lr,
@@ -132,7 +140,9 @@ class PytorchLARS(Optimizer):
             max_unorm=max_unorm,
         )
         if nesterov and (momentum <= 0 or dampening != 0):
-            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+            raise ValueError(
+                "Nesterov momentum requires a momentum and zero dampening"
+            )
         super(PytorchLARS, self).__init__(params, defaults)
 
     def __setstate__(self, state):
@@ -46,9 +46,13 @@ class GlobalOptimManager(object):
         for group_index, group in enumerate(param_groups):
             for p_index, p in enumerate(group["params"]):
                 if id(p) in self.pid2config:
-                    self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
+                    self.index2config[(group_index, p_index)] = self.pid2config[
+                        id(p)
+                    ]
 
-    def override_config(self, parameters, key=None, value=None, key_value_dict=None):
+    def override_config(
+        self, parameters, key=None, value=None, key_value_dict=None
+    ):
         """
         Overrides initial optimizer config for specific parameters.
 
@@ -136,7 +140,8 @@ class Optimizer8bit(torch.optim.Optimizer):
 
         if len(groups) != len(saved_groups):
             raise ValueError(
-                "loaded state dict has a different number of " "parameter groups"
+                "loaded state dict has a different number of "
+                "parameter groups"
             )
         param_lens = (len(g["params"]) for g in groups)
         saved_lens = (len(g["params"]) for g in saved_groups)
@@ -192,7 +197,9 @@ class Optimizer8bit(torch.optim.Optimizer):
             new_group["params"] = group["params"]
             return new_group
 
-        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+        param_groups = [
+            update_group(g, ng) for g, ng in zip(groups, saved_groups)
+        ]
         self.__setstate__({"state": state, "param_groups": param_groups})
 
     def to_gpu(self):
@@ -222,9 +229,9 @@ class Optimizer8bit(torch.optim.Optimizer):
                     # found the matching parameter
                     # init override
                     self.mng.pid2config[id(p)] = config
-                    self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[
-                        id(p)
-                    ]
+                    self.mng.index2config[
+                        (gindex, pindex)
+                    ] = self.mng.pid2config[id(p)]
                     found = True
 
     @torch.no_grad()
@@ -280,7 +287,9 @@ class Optimizer8bit(torch.optim.Optimizer):
         raise NotImplementedError(f"init_state method needs to be overidden")
 
     def update_step(self, group, p, gindex, pindex):
-        raise NotImplementedError(f"The update_step method needs to be overidden")
+        raise NotImplementedError(
+            f"The update_step method needs to be overidden"
+        )
 
 
 class Optimizer2State(Optimizer8bit):
@@ -310,9 +319,13 @@ class Optimizer2State(Optimizer8bit):
         betas = [float(b) for b in betas]
         for i in range(len(betas)):
             if not 0.0 <= betas[i] < 1.0:
-                raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+                raise ValueError(
+                    f"Invalid beta parameter at index {i}: {betas[i]}"
+                )
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         super(Optimizer2State, self).__init__(params, defaults, optim_bits)
 
@@ -351,7 +364,9 @@ class Optimizer2State(Optimizer8bit):
         state = self.state[p]
         state["step"] = 0
 
-        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+        if dtype == torch.float32 or (
+            dtype == torch.uint8 and p.numel() < 4096
+        ):
             state["state1"] = torch.zeros_like(
                 p,
                 memory_format=torch.preserve_format,
@@ -368,8 +383,12 @@ class Optimizer2State(Optimizer8bit):
             if state["step"] == 0:
                 if "dynamic" not in self.name2qmap:
                     self.fill_qmap()
-                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
-                self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
+                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+                    p.device
+                )
+                self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
+                    p.device
+                )
 
             state["state1"] = torch.zeros_like(
                 p,
@@ -399,11 +418,15 @@ class Optimizer2State(Optimizer8bit):
                     (blocks,), dtype=torch.float32, device=p.device
                 )
             else:
-                state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+                state["max1"] = torch.zeros(
+                    (1,), dtype=torch.float32, device=p.device
+                )
                 state["new_max1"] = torch.zeros(
                     (1,), dtype=torch.float32, device=p.device
                 )
-                state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+                state["max2"] = torch.zeros(
+                    (1,), dtype=torch.float32, device=p.device
+                )
                 state["new_max2"] = torch.zeros(
                     (1,), dtype=torch.float32, device=p.device
                 )
@@ -470,7 +493,9 @@ class Optimizer2State(Optimizer8bit):
                 state["new_max2"],
                 config["weight_decay"],
                 gnorm_scale=gnorm_scale,
-                unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
+                unorm_vec=state["unorm_vec"]
+                if config["max_unorm"] > 0.0
+                else None,
                 max_unorm=config["max_unorm"],
             )
 
@@ -522,9 +547,13 @@ class Optimizer1State(Optimizer8bit):
             raise ValueError("Invalid epsilon value: {}".format(eps))
         for i in range(len(betas)):
             if not 0.0 <= betas[i] < 1.0:
-                raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
+                raise ValueError(
+                    f"Invalid beta parameter at index {i}: {betas[i]}"
+                )
         if not 0.0 <= weight_decay:
-            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+            raise ValueError(
+                "Invalid weight_decay value: {}".format(weight_decay)
+            )
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
         super(Optimizer1State, self).__init__(params, defaults, optim_bits)
 
@@ -563,7 +592,9 @@ class Optimizer1State(Optimizer8bit):
         state = self.state[p]
         state["step"] = 0
 
-        if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
+        if dtype == torch.float32 or (
+            dtype == torch.uint8 and p.numel() < 4096
+        ):
             state["state1"] = torch.zeros_like(
                 p,
                 memory_format=torch.preserve_format,
@@ -574,7 +605,9 @@ class Optimizer1State(Optimizer8bit):
             if state["step"] == 0:
                 if "dynamic" not in self.name2qmap:
                     self.fill_qmap()
-                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
+                self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(
+                    p.device
+                )
 
             state["state1"] = torch.zeros_like(
                 p,
@@ -593,7 +626,9 @@ class Optimizer1State(Optimizer8bit):
                     (blocks,), dtype=torch.float32, device=p.device
                 )
             else:
-                state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
+                state["max1"] = torch.zeros(
+                    (1,), dtype=torch.float32, device=p.device
+                )
                 state["new_max1"] = torch.zeros(
                     (1,), dtype=torch.float32, device=p.device
                 )
@@ -22,7 +22,9 @@ class RMSprop(Optimizer1State):
         block_wise=True,
     ):
         if alpha == 0:
-            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+            raise NotImplementedError(
+                f"RMSprop with alpha==0.0 is not supported!"
+            )
         if centered:
             raise NotImplementedError(f"Centered RMSprop is not supported!")
         super(RMSprop, self).__init__(
@@ -56,7 +58,9 @@ class RMSprop8bit(Optimizer1State):
         block_wise=True,
     ):
         if alpha == 0:
-            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+            raise NotImplementedError(
+                f"RMSprop with alpha==0.0 is not supported!"
+            )
         if centered:
             raise NotImplementedError(f"Centered RMSprop is not supported!")
         super(RMSprop8bit, self).__init__(
@@ -91,7 +95,9 @@ class RMSprop32bit(Optimizer1State):
     ):
 
         if alpha == 0:
-            raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
+            raise NotImplementedError(
+                f"RMSprop with alpha==0.0 is not supported!"
+            )
         if centered:
             raise NotImplementedError(f"Centered RMSprop is not supported!")
         super(RMSprop32bit, self).__init__(
@@ -1,6 +1,6 @@
+import sys
 import shlex
 import subprocess
-import sys
 
 
 def execute_and_return(command_string: str) -> Tuple[str, str]:

quicktest.py
@@ -14,23 +14,31 @@ def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
             torch.int8
         )
     elif dims == 3:
-        A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
-            torch.int8
-        )
-    B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
+        A = torch.randint(
+            -128, 127, size=(dim1, dim2, dim3), device="cuda"
+        ).to(torch.int8)
+    B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
+        torch.int8
+    )
     C1 = torch.matmul(A.float(), B.t().float())
 
     A2, SA = F.transform(A, "col32")
     B2, SB = F.transform(B, "colx")
     if dims == 2:
         C2, SC = F.transform(
-            torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"),
+            torch.zeros(
+                A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"
+            ),
             "col32",
         )
     else:
         C2, SC = F.transform(
             torch.zeros(
-                A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device="cuda"
+                A.shape[0],
+                A.shape[1],
+                B.shape[0],
+                dtype=torch.int32,
+                device="cuda",
             ),
             "col32",
         )
@@ -18,9 +18,13 @@ req_grad_str = ["FF", "TF", "TT", "FT"]
 transpose = [(False, False), (False, True), (True, True), (True, False)]
 str_transpose = ["FF", "FT", "TT", "TF"]
 dtype = [torch.float32, torch.float16]
-values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
+values = list(
+    product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
+)
 str_values = list(
-    product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose)
+    product(
+        dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
+    )
 )
 names = [
     "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names
|
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
|
||||||
|
values,
|
||||||
|
ids=names,
|
||||||
)
|
)
|
||||||
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
||||||
dim2 = dim2 - (dim2 % 16)
|
dim2 = dim2 - (dim2 % 16)
|
||||||
|
@ -79,7 +85,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
||||||
A.grad = None
|
A.grad = None
|
||||||
B.grad = None
|
B.grad = None
|
||||||
|
|
||||||
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
|
loss_torch = torch.nn.functional.mse_loss(
|
||||||
|
out_torch, target
|
||||||
|
).mean()
|
||||||
loss_torch.backward()
|
loss_torch.backward()
|
||||||
gradA2 = A.grad
|
gradA2 = A.grad
|
||||||
gradB2 = B.grad
|
gradB2 = B.grad
|
||||||
|
@@ -87,25 +95,35 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             B.grad = None
 
             if req_grad[0]:
-                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                torch.testing.assert_allclose(
+                    gradA1, gradA2, atol=0.015, rtol=0.1
+                )
             if req_grad[1]:
                 n = gradB1.numel()
                 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
                 assert (idx == 0).sum().item() < n * 0.1
                 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
                 assert (idx == 0).sum().item() < n * 0.02
-                torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+                torch.testing.assert_allclose(
+                    gradB1, gradB2, atol=0.18, rtol=0.3
+                )
 
     # batched matrix multiply
     if funcs[0] in [torch.bmm, torch.matmul]:
         A = torch.randn(
-            size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
+            size=(dim1, dim2, dim3),
+            device="cuda",
+            requires_grad=req_grad[0],
         )
         B = torch.randn(
-            size=(dim1, dim3, dim4), device="cuda", requires_grad=req_grad[1]
+            size=(dim1, dim3, dim4),
+            device="cuda",
+            requires_grad=req_grad[1],
         )
         target = torch.randn(
-            size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
+            size=(dim1, dim2, dim4),
+            device="cuda",
+            requires_grad=req_grad[1],
         )
         torch.nn.init.xavier_uniform_(B)
 
@@ -115,7 +133,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
         n = out_bnb.numel()
         idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
         assert (idx == 0).sum().item() < n * 0.01
-        torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)
+        torch.testing.assert_allclose(
+            out_bnb, out_torch, atol=0.027, rtol=0.2
+        )
 
         if any(req_grad):
             out_bnb.data.copy_(out_torch)
@@ -127,7 +147,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             A.grad = None
             B.grad = None
 
-            loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+            loss_torch = torch.nn.functional.mse_loss(
+                out_torch, target
+            ).mean()
             loss_torch.backward()
             gradA2 = A.grad
             gradB2 = B.grad
@@ -135,7 +157,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             B.grad = None
 
             if req_grad[0]:
-                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                torch.testing.assert_allclose(
+                    gradA1, gradA2, atol=0.015, rtol=0.1
+                )
             if req_grad[1]:
                 n = gradB1.numel()
                 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
@@ -146,12 +170,16 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
     if funcs[0] in [torch.matmul]:
         dim1 = dim1 - (dim1 % 16)
         A = torch.randn(
-            size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
+            size=(dim1, dim2, dim3),
+            device="cuda",
+            requires_grad=req_grad[0],
         )
         dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
         B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
         target = torch.randn(
-            size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
+            size=(dim1, dim2, dim4),
+            device="cuda",
+            requires_grad=req_grad[1],
         )
         torch.nn.init.xavier_uniform_(B)
 
@@ -178,7 +206,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             A.grad = None
             B.grad = None
 
-            loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+            loss_torch = torch.nn.functional.mse_loss(
+                out_torch, target
+            ).mean()
             loss_torch.backward()
             gradA2 = A.grad
             gradB2 = B.grad
@@ -186,7 +216,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
             B.grad = None
 
             if req_grad[0]:
-                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                torch.testing.assert_allclose(
+                    gradA1, gradA2, atol=0.015, rtol=0.1
+                )
             if req_grad[1]:
                 n = gradB1.numel()
                 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
@ -258,7 +290,16 @@ names = [
|
||||||
ids=names,
|
ids=names,
|
||||||
)
|
)
|
||||||
def test_matmullt(
|
def test_matmullt(
|
||||||
dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights
|
dim1,
|
||||||
|
dim2,
|
||||||
|
dim3,
|
||||||
|
dim4,
|
||||||
|
funcs,
|
||||||
|
dtype,
|
||||||
|
req_grad,
|
||||||
|
transpose,
|
||||||
|
decomp,
|
||||||
|
has_fp16_weights,
|
||||||
):
|
):
|
||||||
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
|
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
|
||||||
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
|
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
|
||||||
|
@ -278,7 +319,10 @@ def test_matmullt(
|
||||||
size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
|
size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
|
||||||
)
|
)
|
||||||
target = torch.randn(
|
target = torch.randn(
|
||||||
size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype
|
size=(dim2, dim4),
|
||||||
|
device="cuda",
|
||||||
|
requires_grad=req_grad[1],
|
||||||
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
torch.nn.init.xavier_uniform_(B)
|
torch.nn.init.xavier_uniform_(B)
|
||||||
B2 = B.clone()
|
B2 = B.clone()
|
||||||
|
@ -317,14 +361,18 @@ def test_matmullt(
|
||||||
if any(req_grad):
|
if any(req_grad):
|
||||||
out_bnb.data.copy_(out_torch)
|
out_bnb.data.copy_(out_torch)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
|
loss_bnb = torch.nn.functional.mse_loss(
|
||||||
|
out_bnb, target
|
||||||
|
).mean()
|
||||||
loss_bnb.backward()
|
loss_bnb.backward()
|
||||||
gradA1 = A.grad
|
gradA1 = A.grad
|
||||||
gradB1 = B.grad
|
gradB1 = B.grad
|
||||||
A.grad = None
|
A.grad = None
|
||||||
B.grad = None
|
B.grad = None
|
||||||
|
|
||||||
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
|
loss_torch = torch.nn.functional.mse_loss(
|
||||||
|
out_torch, target
|
||||||
|
).mean()
|
||||||
loss_torch.backward()
|
loss_torch.backward()
|
||||||
gradA2 = A.grad
|
gradA2 = A.grad
|
||||||
gradB2 = B.grad
|
gradB2 = B.grad
|
||||||
|
@ -332,7 +380,9 @@ def test_matmullt(
|
||||||
B.grad = None
|
B.grad = None
|
||||||
|
|
||||||
if req_grad[0]:
|
if req_grad[0]:
|
||||||
torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
|
torch.testing.assert_allclose(
|
||||||
|
gradA1, gradA2, atol=0.015, rtol=0.1
|
||||||
|
)
|
||||||
if req_grad[1]:
|
if req_grad[1]:
|
||||||
n = gradB1.numel()
|
n = gradB1.numel()
|
||||||
assert torch.abs(gradB1).sum() > 0.0
|
assert torch.abs(gradB1).sum() > 0.0
|
||||||
|
@ -341,4 +391,6 @@ def test_matmullt(
|
||||||
assert (idx == 0).sum().item() < n * 0.1
|
assert (idx == 0).sum().item() < n * 0.1
|
||||||
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
|
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
|
||||||
assert (idx == 0).sum().item() < n * 0.02
|
assert (idx == 0).sum().item() < n * 0.02
|
||||||
torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
|
torch.testing.assert_allclose(
|
||||||
|
gradB1, gradB2, atol=0.18, rtol=0.3
|
||||||
|
)
|
||||||
|
|
|
@@ -3,8 +3,12 @@ from typing import List, NamedTuple

 import pytest

-from bitsandbytes.cuda_setup import (CUDA_RUNTIME_LIB, evaluate_cuda_setup,
-    get_cuda_runtime_lib_path, tokenize_paths)
+from bitsandbytes.cuda_setup import (
+    CUDA_RUNTIME_LIB,
+    evaluate_cuda_setup,
+    get_cuda_runtime_lib_path,
+    tokenize_paths,
+)


 class InputAndExpectedOutput(NamedTuple):

@@ -13,11 +17,26 @@ class InputAndExpectedOutput(NamedTuple):


 HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
-    (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
-    (f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
-    (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:", f"dir/with/{CUDA_RUNTIME_LIB}"),
-    (f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
-    (f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir", f"dir/with/{CUDA_RUNTIME_LIB}"),
+    (
+        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
+        f"dir/with/{CUDA_RUNTIME_LIB}",
+    ),
+    (
+        f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
+        f"dir/with/{CUDA_RUNTIME_LIB}",
+    ),
+    (
+        f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:",
+        f"dir/with/{CUDA_RUNTIME_LIB}",
+    ),
+    (
+        f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}",
+        f"dir/with/{CUDA_RUNTIME_LIB}",
+    ),
+    (
+        f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir",
+        f"dir/with/{CUDA_RUNTIME_LIB}",
+    ),
     (
         f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
         f"dir/with/{CUDA_RUNTIME_LIB}",
@@ -86,7 +86,9 @@ def teardown():
 pass


-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=["float", "half"])
+@pytest.mark.parametrize(
+    "dtype", [torch.float32, torch.float16], ids=["float", "half"]
+)
 def test_estimate_quantiles(dtype):
 A = torch.rand(1024, 1024, device="cuda")
 A = A.to(dtype)
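The decorator above is only re-wrapped, not changed. For context, the values/ids pattern that recurs throughout these test files builds one test case per element of an itertools.product and labels it with a formatted name. A minimal sketch with made-up dimensions; the names here are illustrative, not taken from the test suite:

from itertools import product

import pytest
import torch

dim1 = [32, 64]
dim2 = [128]
values = list(product(dim1, dim2))
names = ["dim1_{0}_dim2_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim1, dim2", values, ids=names)
def test_shapes(dim1, dim2):
    # One test instance per (dim1, dim2) pair, labelled with the matching id.
    A = torch.randn(dim1, dim2)
    assert A.shape == (dim1, dim2)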
@@ -190,7 +192,9 @@ def test_dynamic_blockwise_stochastic_quantization():
 )


-@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=["float", "half"])
+@pytest.mark.parametrize(
+    "gtype", [torch.float32, torch.float16], ids=["float", "half"]
+)
 def test_percentile_clipping(gtype):
 gnorm_vec1 = torch.zeros(100, device="cuda")
 gnorm_vec2 = torch.zeros(100, device="cuda")

@@ -270,7 +274,13 @@ def mean(xx):
 dim1 = [1024 * 2]
 dim2 = [1024 * 16]
 methods = [
-    (lambda x, dim: quant(x), lambda x, dim: quant(x), dequant, dequant, mm_dequant)
+    (
+        lambda x, dim: quant(x),
+        lambda x, dim: quant(x),
+        dequant,
+        dequant,
+        mm_dequant,
+    )
 ]
 methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant))
 # methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant))

@@ -279,11 +289,14 @@ batched = [False, True]
 values = list(product(dim1, dim2, methods, batched))
 values_names = list(product(dim1, dim2, method_names, batched))
 names = [
-    "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals) for vals in values_names
+    "dim1_{0}_dim2_{1}_quant_{2}_batched_{3}".format(*vals)
+    for vals in values_names
 ]


-@pytest.mark.parametrize("dim1, dim2, quant_methods, batched", values, ids=names)
+@pytest.mark.parametrize(
+    "dim1, dim2, quant_methods, batched", values, ids=names
+)
 def test_approx_igemm(dim1, dim2, quant_methods, batched):
 dim1 = dim1 - (dim1 % 32)
 dim2 = dim2 - (dim2 % 32)

@@ -339,14 +352,18 @@ names = [
 ]


-@pytest.mark.parametrize("hidden_dim, batch_dim, transpose, seq_dim", values, ids=names)
+@pytest.mark.parametrize(
+    "hidden_dim, batch_dim, transpose, seq_dim", values, ids=names
+)
 def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
 hidden_dim = hidden_dim - (hidden_dim % 32)
 batch_dim = batch_dim - (batch_dim % 16)
 seq_dim = seq_dim - (seq_dim % 16)
 for i in range(k):
 shapeA = (
-    (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
+    (batch_dim, hidden_dim)
+    if not transpose[0]
+    else (hidden_dim, batch_dim)
 )
 shapeB = (
 (32 * random.randint(1, 4), hidden_dim)

@@ -394,7 +411,9 @@ seq_dim = torch.randint(32, 512, size=(n,)).tolist()
 hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist()
 batch_dim = torch.randint(2, 16, size=(n,)).tolist()
 values = list(product(seq_dim, hidden_dim, batch_dim))
-names = ["seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values]
+names = [
+    "seq_dim{0}_hidden_dim{1}_batch_dim{2}".format(*vals) for vals in values
+]


 @pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names)

@@ -406,11 +425,13 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
 A = torch.randint(
 -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
 ).to(torch.int8)
-B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(
-    torch.int8
-)
+B = torch.randint(
+    -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda"
+).to(torch.int8)
 out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
-iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
+iout = torch.empty(
+    A.shape[2], B.shape[2], dtype=torch.int32, device=A.device
+)
 out = F.igemm(A, B, out=iout)

 torch.testing.assert_allclose(out.float(), out2)

@@ -428,7 +449,9 @@ names = [
 ]


-@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim, transpose", values, ids=names)
+@pytest.mark.parametrize(
+    "seq_dim, hidden_dim, batch_dim, transpose", values, ids=names
+)
 def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
 def min_max(x):
 maxA = torch.amax(x, dim=2, keepdim=True)

@@ -444,7 +467,9 @@ def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose):
 errs2 = []
 relerrs2 = []
 for i in range(k):
-A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda")
+A = torch.normal(
+    0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda"
+)
 if transpose:
 B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
 else:

@@ -504,7 +529,8 @@ dim4 = torch.randint(32, 256, size=(n,)).tolist()
 transpose = [(False, False), (True, False), (False, True), (True, True)]
 values = list(product(dim1, dim2, dim3, dim4, transpose))
 names = [
-    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals) for vals in values
+    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_transpose_{4}".format(*vals)
+    for vals in values
 ]
@@ -529,7 +555,9 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
 out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
 out = F.igemm(A.permute([0, 2, 1]), B)
 elif transpose[0] and transpose[1]:
-out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
+out2 = torch.bmm(
+    A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
+)
 out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
 torch.testing.assert_allclose(out.float(), out2.float())

@@ -563,7 +591,9 @@ a_order = ["row"]
 out_order = ["col", "row", "col32"]
 transpose = [False]
 dims = [2, 3]
-values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
+values = list(
+    product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
+)

 names = [
 "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_transpose_{7}".format(

@@ -574,9 +604,13 @@ names = [


 @pytest.mark.parametrize(
-    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names
+    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
+    values,
+    ids=names,
 )
-def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
+def test_nvidia_transform(
+    dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose
+):
 if dims == 3 and out_order != "col32":
 return
 if dtype == torch.int32 and out_order != "col32":

@@ -586,7 +620,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
 if dims == 2:
 A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype)
 elif dims == 3:
-A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(dtype)
+A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
+    dtype
+)

 out, S = F.nvidia_transform(A, to_order=orderOut)

@@ -598,7 +634,11 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
 if dims == 2:
 n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
 elif dims == 3:
-n = A.shape[0] * A.shape[1] * (A.shape[2] + (32 - (A.shape[2] % 32)))
+n = (
+    A.shape[0]
+    * A.shape[1]
+    * (A.shape[2] + (32 - (A.shape[2] % 32)))
+)
 assert out.numel() == n
 elif orderOut == "col_turing":
 # 32 col 8 row tiles

@@ -613,7 +653,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
 j = col

 coltile = (col // 32) + (1 if col % 32 != 0 else 0)
-rowtile = ((row // 8) + (1 if row % 8 != 0 else 0)) * total_coltile
+rowtile = (
+    (row // 8) + (1 if row % 8 != 0 else 0)
+) * total_coltile
 offset = 32 * 8 * (rowtile + coltile)
 col2 = col % 32
 row2 = (row % 8) * 32

@@ -624,7 +666,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
 # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])

 if orderOut == "col32":
-out2, S = F.nvidia_transform(out, from_order=orderOut, to_order="row", state=S)
+out2, S = F.nvidia_transform(
+    out, from_order=orderOut, to_order="row", state=S
+)
 torch.testing.assert_allclose(A, out2)


@@ -657,10 +701,12 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
 torch.int8
 )
 elif dims == 3:
-A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
-    torch.int8
-)
-B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
+A = torch.randint(
+    -128, 127, size=(dim1, dim2, dim3), device="cuda"
+).to(torch.int8)
+B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(
+    torch.int8
+)
 C1 = torch.matmul(A.float(), B.t().float())

 A2, SA = F.transform(A, "col32")

@@ -670,7 +716,9 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
 torch.testing.assert_allclose(C1, C3.float())

 # transpose
-B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(torch.int8)
+B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
+    torch.int8
+)
 C1 = torch.matmul(A.float(), B.float())

 B2t, SBt = F.transform(B, "col_turing", transpose=True)

@@ -688,7 +736,8 @@ dims = (2,)
 # ldb = list(range(256, 1*1024, 256))
 values = list(product(dim1, dim2, dim3, dim4, dims))
 names = [
-    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals) for vals in values
+    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_dims_{4}".format(*vals)
+    for vals in values
 ]
@@ -699,7 +748,9 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
 if dims == 2:
 A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half()
 elif dims == 3:
-A = torch.normal(0, 0.5, size=(dim1, dim2, dim3), device="cuda").half()
+A = torch.normal(
+    0, 0.5, size=(dim1, dim2, dim3), device="cuda"
+).half()
 B = torch.randn((dim4, dim3), device="cuda").half()
 torch.nn.init.xavier_uniform_(B)
 C1 = torch.matmul(A, B.t())

@@ -742,7 +793,9 @@ values = [


 # values = list(product(batch, seq, model, hidden))
-names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values]
+names = [
+    "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
+]


 @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)

@@ -909,7 +962,9 @@ dims = (2,)
 # ldb = list(range(256, 1*1024, 256))
 formatB = ["col_turing", "col_ampere"]
 values = list(product(dim1, dim4, dims, formatB))
-names = ["dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values]
+names = [
+    "dim1_{0}_dim4_{1}_dims_{2}_formatB_{3}".format(*vals) for vals in values
+]


 @pytest.mark.parametrize("dim1, dim4, dims, formatB", values, ids=names)

@@ -992,7 +1047,9 @@ def test_colrow_absmax(dim1, dim2, dims):
 torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
 torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)

-row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(A, threshold=0.0)
+row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
+    A, threshold=0.0
+)

 torch.testing.assert_allclose(col_stats1, col_stats2)
 torch.testing.assert_allclose(row_stats1, row_stats2)

@@ -1023,8 +1080,12 @@ def test_double_quant(dim1, dim2):
 torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)

 n = CAt.numel()
-num_not_close_rows = (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
-num_not_close_cols = (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
+num_not_close_rows = (
+    (torch.isclose(CA, out_row1, atol=1) == 0).sum().item()
+)
+num_not_close_cols = (
+    (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item()
+)

 # allow for 1:500 error due to rounding differences
 min_error = 1 / 500

@@ -1123,7 +1184,9 @@ def test_igemmlt_row_scale(dim1, dim4, inner):

 c = 10.0 * inner * scale
 row_scale = torch.ones_like(maxA) / c
-outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
+outC32, SC = F.igemmlt(
+    A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
+)
 C3, S = F.nvidia_transform(outC32, "row", state=SC)
 maxval = torch.abs(C3).max()
 if maxval == 127:

@@ -1204,7 +1267,9 @@ def test_row_scale_bench(dim1, dim4, inner):
 torch.cuda.synchronize()
 t0 = time.time()
 for i in range(k):
-outC32, SC = F.igemmlt(A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale)
+outC32, SC = F.igemmlt(
+    A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale
+)
 torch.cuda.synchronize()
 print("row-wise", time.time() - t0)

@@ -1230,7 +1295,9 @@ a_order = ["row"]
 out_order = ["col32", "col_turing", "col_ampere"]
 transpose = [False, True]
 dims = [2]
-values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose))
+values = list(
+    product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)
+)
 names = [
 "dim1_{0}_dim2_{1}_dim3_{2}_dims_{3}_dtype_{4}_orderA_{5}_orderOut_{6}_{7}".format(
 *vals

@@ -1240,14 +1307,20 @@ names = [


 @pytest.mark.parametrize(
-    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", values, ids=names
+    "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
+    values,
+    ids=names,
 )
 def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
 for i in range(k):
 if dims == 2:
-A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(dtype)
+A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to(
+    dtype
+)
 elif dims == 3:
-A = torch.randint(10, 99, size=(dim1, dim2, dim3), device="cuda").to(dtype)
+A = torch.randint(
+    10, 99, size=(dim1, dim2, dim3), device="cuda"
+).to(dtype)

 A.view(-1)[-1] = -1
 if transpose:

@@ -1282,7 +1355,9 @@ names = [
 ]


-@pytest.mark.parametrize("dim1, dim2, dtype, orderA, orderOut", values, ids=names)
+@pytest.mark.parametrize(
+    "dim1, dim2, dtype, orderA, orderOut", values, ids=names
+)
 def test_transform_to_row(dim1, dim2, dtype, orderA, orderOut):
 for i in range(1):
 A = torch.randint(-127, 127, size=(dim1, dim2), device="cuda").to(dtype)
@@ -1332,17 +1407,23 @@ def test_coo_double_quant(dim1, dim2):

 idx = torch.abs(A) >= threshold
 CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
-CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
+CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
+    A, threshold=threshold
+)

 if coo_tensor is not None:
 A1 = A * idx
 A2 = torch.zeros_like(A)
-A2[coo_tensor.rowidx.long(), coo_tensor.colidx.long()] = coo_tensor.values
+A2[
+    coo_tensor.rowidx.long(), coo_tensor.colidx.long()
+] = coo_tensor.values
 torch.testing.assert_allclose(A1, A2)

 A1 = A * (idx == 0)
 A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
-torch.testing.assert_allclose(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
+torch.testing.assert_allclose(
+    A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
+)


 n = 2

@@ -1454,7 +1535,9 @@ def test_integrated_sparse_decomp(dim1, dim2):
 out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)
 out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)

-CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold)
+CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(
+    A, threshold=threshold
+)
 C32A, SA = F.transform(CA, "col32")

 out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1)

@@ -1494,7 +1577,9 @@ dim2 = [12288]
 dtype = [torch.float16]
 out_function = ["zeros", "ones"]
 values = list(product(dim1, dim2, dtype, out_function))
-names = ["dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values]
+names = [
+    "dim1_{0}_dim2_{1}_dtype_{2}_out_func_{3}".format(*vals) for vals in values
+]


 @pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names)

@@ -1536,7 +1621,9 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
 std = out1.std()
 out1 /= std
 out2 /= std
-assert_all_approx_close(out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count)
+assert_all_approx_close(
+    out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count
+)
 # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count)

 idx_col = torch.randint(0, A2.shape[-1], size=(15,))

@@ -1734,7 +1821,9 @@ values.append((batch_size, seqdim, 768, 4 * 768))
 # values.append((batch_size, seqdim, 4096, 4*4096))
 # values.append((batch_size, seqdim, 5140, 4*5140))
 # values.append((batch_size, seqdim, 12288, 4*12288))
-names = ["batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values]
+names = [
+    "batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
+]


 @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
@@ -48,7 +48,9 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
 class LinearFunction(torch.autograd.Function):
 @staticmethod
 def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
-round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+round_func = (
+    LinearFunction.round_stoachastic if stochastic else torch.round
+)
 norm = math.sqrt(math.pi) / math.sqrt(2.0)
 # std = torch.abs(x).mean()*norm
 std = torch.std(x)
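The hunk header above shows the signature assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10). Judging from the call sites elsewhere in this diff (count=2, count=count), it appears to tolerate up to count mismatching elements before failing. A plausible sketch under that assumption; this is not the repository's implementation:

import torch


def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
    # Assumed behaviour: fail only when more than `count` elements fall
    # outside the torch.isclose tolerances.
    idx = torch.isclose(a, b, atol=atol, rtol=rtol)
    num_mismatched = (idx == 0).sum().item()
    if num_mismatched > count:
        raise AssertionError(
            f"{num_mismatched} elements differ (allowed: {count})"
        )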
@@ -116,7 +118,9 @@ class LinearFunction(torch.autograd.Function):
 return x.to(dtype)

 def get_8bit_linear(x, stochastic=False):
-round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+round_func = (
+    LinearFunction.round_stoachastic if stochastic else torch.round
+)
 max1 = torch.abs(x).max()
 x = x / max1 * 127
 x = round_func(x) / 127 * max1
@@ -125,7 +129,9 @@ class LinearFunction(torch.autograd.Function):

 @staticmethod
 def get_8bit_vector_wise(x, dim, stochastic=False):
-round_func = LinearFunction.round_stoachastic if stochastic else torch.round
+round_func = (
+    LinearFunction.round_stoachastic if stochastic else torch.round
+)
 max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
 max1[max1 == 0] = 1.0
 x = (x * 127) / max1
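round_stoachastic, selected above when stochastic=True, is not shown in this diff. Stochastic rounding is commonly implemented by rounding down and then rounding up with probability equal to the fractional part, which makes the rounding unbiased in expectation. A hedged sketch of that idea only; the name and details here are ours, not the repository's code:

import torch


def round_stochastic(x):
    # Round down, then round up with probability equal to the fractional
    # part, so the result is unbiased in expectation.
    floor = torch.floor(x)
    frac = x - floor
    return floor + torch.bernoulli(frac)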
@@ -209,7 +215,9 @@ class LinearFunction(torch.autograd.Function):
 weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
 x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
 outputq = bnb.functional.igemm(x8, weight8.t())
-output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
+output = LinearFunction.dequant(
+    outputq, S1, S2, x.dtype, args.quant_type
+)
 # if torch.rand(1) < 0.01:
 # output32 = torch.matmul(x, weight.t())
 # err = torch.abs(output-output32).float()

@@ -261,7 +269,9 @@ class LinearFunction(torch.autograd.Function):
 grad_weight8, S1, S2, grad_output.dtype, args.quant_type
 )

-grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
+grad_output8, S1 = LinearFunction.quant(
+    grad_output, args.quant_type, dim=2
+)
 weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
 grad_input8 = bnb.functional.igemm(grad_output8, weight8)
 grad_input = LinearFunction.dequant(

@@ -338,8 +348,12 @@ def test_linear8bit():
 loss2.backward()
 loss3.backward()

-assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
-assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
+assert_all_approx_close(
+    l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
+)
+assert_all_approx_close(
+    l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2
+)
 assert_all_approx_close(
 l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
 )

@@ -388,7 +402,9 @@ def test_linear8bitlt_accumulated_gradient():
 l1 = torch.nn.Sequential(
 *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
 )
-l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
+l2 = torch.nn.Sequential(
+    *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
+)
 l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
 l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
 l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())

@@ -462,7 +478,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
 if threshold > 0:
 assert mlp.fc2.state.idx is not None

-mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
+mlp = (
+    MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+    .cuda()
+    .half()
+)
 assert mlp.fc1.weight.dtype == torch.int8
 assert mlp.fc2.weight.dtype == torch.int8

@@ -475,7 +495,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
 if threshold > 0:
 assert mlp.fc2.state.idx is not None

-mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
+mlp = (
+    MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+    .half()
+    .cuda()
+)

 for i in range(100):
 b1 = torch.randn(16, 8, 32, device="cuda").half()
@@ -488,7 +512,11 @@ def test_linear8bitlt_no_fp16_weights(threshold):
 assert mlp.fc1.weight.dtype == torch.int8
 assert mlp.fc2.weight.dtype == torch.int8

-mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to("cuda")
+mlp = (
+    MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
+    .half()
+    .to("cuda")
+)

 for i in range(100):
 b1 = torch.randn(16, 8, 32, device="cuda").half()
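The three hunks above differ only in the order of the dtype/device casts (.cuda().half(), .half().cuda(), .half().to("cuda")); the parentheses added by the formatter exist purely so the chain can wrap across lines. A tiny CPU-only sketch showing that the parenthesized chain is the same expression, using a plain nn.Linear rather than the test's MLP8bit:

import torch

# The parentheses only let the chained calls wrap across lines; the result is
# the same module as writing the chain on one line.
layer = (
    torch.nn.Linear(32, 64)
    .half()
    .to("cpu")
)
assert layer.weight.dtype == torch.float16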
@@ -103,20 +103,26 @@ str2statenames["adam8bit_blockwise"] = [
 ("exp_avg", "state1", "qmap1", "absmax1"),
 ("exp_avg_sq", "state2", "qmap2", "absmax2"),
 ]
-str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
+str2statenames["momentum8bit"] = [
+    ("momentum_buffer", "state1", "qmap1", "max1")
+]
 str2statenames["momentum8bit_blockwise"] = [
 ("momentum_buffer", "state1", "qmap1", "absmax1")
 ]
 str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
 str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
-str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
+str2statenames["rmsprop8bit_blockwise"] = [
+    ("square_avg", "state1", "qmap1", "absmax1")
+]

 dim1 = [1024]
 dim2 = [32, 1024, 4097, 1]
 gtype = [torch.float32, torch.float16]
 optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lamb"]
 values = list(product(dim1, dim2, gtype, optimizer_names))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+names = [
+    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]


 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)

@@ -203,9 +209,13 @@ def test_global_config(dim1, dim2, gtype):
 eps = 1e-8

 bnb.optim.GlobalOptimManager.get_instance().initialize()
-bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
+bnb.optim.GlobalOptimManager.get_instance().override_config(
+    p3, "optim_bits", 8
+)

-bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
+bnb.optim.GlobalOptimManager.get_instance().register_parameters(
+    [p1, p2, p3]
+)
 p1 = p1.cuda()
 p2 = p2.cuda()
 p3 = p3.cuda()

@@ -245,7 +255,9 @@ optimizer_names = [
 "rmsprop8bit_blockwise",
 ]
 values = list(product(dim1, dim2, gtype, optimizer_names))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+names = [
+    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]


 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)

@@ -329,8 +341,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
 bnb_optimizer = str2optimizers[optim_name][1]([p2])
 bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
 rm_path(path)
-torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
-torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])
+torch.testing.assert_allclose(
+    raws1cpy, bnb_optimizer.state[p2][name2]
+)
+torch.testing.assert_allclose(
+    qmap1, bnb_optimizer.state[p2][qmap]
+)

 if "blockwise" in optim_name:
 s1 = F.dequantize_blockwise(

@@ -349,12 +365,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):

 num_not_close = (
 torch.isclose(
-    torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
+    torch_optimizer.state[p1][name1],
+    s1,
+    atol=atol,
+    rtol=rtol,
 )
 == 0
 )
 assert num_not_close.sum().item() < 20
-torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
+torch.testing.assert_allclose(
+    p1, p2.float(), atol=patol, rtol=prtol
+)

 # the parameters diverge quickly. Here we keep them close
 # together so we can test against the Adam error

@@ -375,7 +396,10 @@ dim2 = [32, 1024, 4097]
 gtype = [torch.float32]
 optim_bits = [32, 8]
 values = list(product(dim1, dim2, gtype, optim_bits))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) for vals in values]
+names = [
+    "dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals)
+    for vals in values
+]


 @pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)

@@ -391,7 +415,12 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
 p2 = p1.clone()
 adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
 adam2 = bnb.optim.Adam(
-    [p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5
+    [p2],
+    lr,
+    (beta1, beta2),
+    eps,
+    optim_bits=optim_bits,
+    percentile_clipping=5,
 )

 gnorm_vec = torch.zeros(100).cuda()

@@ -399,7 +428,9 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):

 for i in range(50):
 step += 1
-g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
+g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (
+    0.01 * i
+)
 g2 = g1.clone()
 p2.grad = g2

@@ -430,10 +461,16 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
 elif optim_bits == 8:
 torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
 torch.testing.assert_allclose(
-    adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3
+    adam1.state[p1]["state1"],
+    adam2.state[p2]["state1"],
+    atol=2,
+    rtol=1e-3,
 )
 torch.testing.assert_allclose(
-    adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, rtol=1e-3
+    adam1.state[p1]["state2"],
+    adam2.state[p2]["state2"],
+    atol=2,
+    rtol=1e-3,
 )
 adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
 adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])

@@ -463,7 +500,9 @@ gtype = [torch.float32, torch.float16]
 # optimizer_names = ['lars_apex', 'lars8bit']
 optimizer_names = ["adam8bit_blockwise"]
 values = list(product(dim1, dim2, gtype, optimizer_names))
-names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
+names = [
+    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
+]


 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)