forked from mrq/bitsandbytes-rocm
commit
843ad0631c
14
README.md
14
README.md
|
@ -23,12 +23,12 @@ Resources:
|
|||
1. Comment out torch.nn.Linear: ``#linear = torch.nn.Linear(...)``
|
||||
2. Add bnb 8-bit linear light module: ``linear = bnb.nn.Linear8bitLt(...)`` (base arguments stay the same)
|
||||
3. There are two modes:
|
||||
- Mixed 8-bit training with 16-bit main weights. Pass the argument ``use_fp16_weights=True`` (default)
|
||||
- Int8 inference. Pass the argument ``use_fp16_weights=False``
|
||||
- Mixed 8-bit training with 16-bit main weights. Pass the argument ``has_fp16_weights=True`` (default)
|
||||
- Int8 inference. Pass the argument ``has_fp16_weights=False``
|
||||
4. To use the full LLM.int8() method, use the ``threshold=k`` argument. We recommend ``k=6.0``.
|
||||
```python
|
||||
# LLM.int8()
|
||||
linear = bnb.nn.Linear8bitLt(dim1, dim2, bias=True, use_fp16_weights=False, threshold=6.0)
|
||||
linear = bnb.nn.Linear8bitLt(dim1, dim2, bias=True, has_fp16_weights=False, threshold=6.0)
|
||||
# inputs need to be fp16
|
||||
out = linear(x.to(torch.float16))
|
||||
```
|
||||
|
@ -115,7 +115,8 @@ We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fa
|
|||
|
||||
## How to cite us
|
||||
If you found this library and found LLM.int8() useful, please consider citing our work:
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{dettmers2022llmint8,
|
||||
title={LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale},
|
||||
author={Dettmers, Tim and Lewis, Mike and Belkada, Younes and Zettlemoyer, Luke},
|
||||
|
@ -124,8 +125,9 @@ If you found this library and found LLM.int8() useful, please consider citing ou
|
|||
}
|
||||
```
|
||||
|
||||
For 8-bit optimizers or quantization routines please consider citing the following work.
|
||||
```
|
||||
For 8-bit optimizers or quantization routines, please consider citing the following work:
|
||||
|
||||
```bibtex
|
||||
@article{dettmers2022optimizers,
|
||||
title={8-bit Optimizers via Block-wise Quantization},
|
||||
author={Dettmers, Tim and Lewis, Mike and Shleifer, Sam and Zettlemoyer, Luke},
|
||||
|
|
|
@ -12,7 +12,7 @@ from .autograd._functions import (
|
|||
)
|
||||
from .cextension import COMPILED_WITH_CUDA
|
||||
from .nn import modules
|
||||
from . import cuda_setup
|
||||
from . import cuda_setup, utils
|
||||
|
||||
if COMPILED_WITH_CUDA:
|
||||
from .optim import adam
|
||||
|
|
|
@ -3,8 +3,9 @@
|
|||
# cli()
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
from warnings import warn
|
||||
|
||||
import torch
|
||||
|
||||
HEADER_WIDTH = 60
|
||||
|
||||
|
@ -32,8 +33,6 @@ print()
|
|||
from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL
|
||||
from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle
|
||||
from .cuda_setup.env_vars import to_be_ignored
|
||||
from .utils import print_stderr
|
||||
|
||||
|
||||
print_header("POTENTIALLY LIBRARY-PATH-LIKE ENV VARS")
|
||||
for k, v in os.environ.items():
|
||||
|
@ -84,7 +83,7 @@ try:
|
|||
|
||||
except ImportError:
|
||||
print()
|
||||
print_stderr(
|
||||
warn(
|
||||
f"WARNING: {__package__} is currently running as CPU-only!\n"
|
||||
"Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
|
||||
f"If you think that this is so erroneously,\nplease report an issue!"
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import operator
|
||||
import torch
|
||||
import bitsandbytes as bnb
|
||||
import bitsandbytes.functional as F
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
@ -368,9 +367,6 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
return grad_A, grad_B, None, grad_bias, None
|
||||
|
||||
|
||||
matmul = MatMul8bitLt.apply
|
||||
|
||||
|
||||
def matmul(
|
||||
A: tensor,
|
||||
B: tensor,
|
||||
|
|
|
@ -1,79 +0,0 @@
|
|||
import ctypes
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class CudaLibVals:
|
||||
# code bits taken from
|
||||
# https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
|
||||
|
||||
nGpus: ctypes.c_int = field(default=ctypes.c_int())
|
||||
cc_major: ctypes.c_int = field(default=ctypes.c_int())
|
||||
cc_minor: ctypes.c_int = field(default=ctypes.c_int())
|
||||
device: ctypes.c_int = field(default=ctypes.c_int())
|
||||
error_str: ctypes.c_char_p = field(default=ctypes.c_char_p())
|
||||
cuda: ctypes.CDLL = field(init=False, repr=False)
|
||||
ccs: List[str, ...] = field(init=False)
|
||||
|
||||
def _initialize_driver_API(self):
|
||||
self.check_cuda_result(self.cuda.cuInit(0))
|
||||
|
||||
def _load_cuda_lib(self):
|
||||
"""
|
||||
1. find libcuda.so library (GPU driver) (/usr/lib)
|
||||
init_device -> init variables -> call function by reference
|
||||
"""
|
||||
libnames = "libcuda.so"
|
||||
for libname in libnames:
|
||||
try:
|
||||
self.cuda = ctypes.CDLL(libname)
|
||||
except OSError:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise OSError("could not load any of: " + " ".join(libnames))
|
||||
|
||||
def call_cuda_func(self, function_obj, **kwargs):
|
||||
CUDA_SUCCESS = 0 # constant taken from cuda.h
|
||||
pass
|
||||
# if (CUDA_SUCCESS := function_obj(
|
||||
|
||||
def _error_handle(cuda_lib_call_return_value):
|
||||
"""
|
||||
2. call extern C function to determine CC
|
||||
(see https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)
|
||||
"""
|
||||
CUDA_SUCCESS = 0 # constant taken from cuda.h
|
||||
|
||||
if cuda_lib_call_return_value != CUDA_SUCCESS:
|
||||
self.cuda.cuGetErrorString(
|
||||
cuda_lib_call_return_value,
|
||||
ctypes.byref(self.error_str),
|
||||
)
|
||||
print("Count not initialize CUDA - failure!")
|
||||
raise Exception("CUDA exception!")
|
||||
return cuda_lib_call_return_value
|
||||
|
||||
def __post_init__(self):
|
||||
self._load_cuda_lib()
|
||||
self._initialize_driver_API()
|
||||
self.check_cuda_result(
|
||||
self.cuda, self.cuda.cuDeviceGetCount(ctypes.byref(self.nGpus))
|
||||
)
|
||||
tmp_ccs = []
|
||||
for gpu_index in range(self.nGpus.value):
|
||||
check_cuda_result(
|
||||
self.cuda,
|
||||
self.cuda.cuDeviceGet(ctypes.byref(self.device), gpu_index),
|
||||
)
|
||||
check_cuda_result(
|
||||
self.cuda,
|
||||
self.cuda.cuDeviceComputeCapability(
|
||||
ctypes.byref(self.cc_major),
|
||||
ctypes.byref(self.cc_minor),
|
||||
self.device,
|
||||
),
|
||||
)
|
||||
tmp_ccs.append(f"{self.cc_major.value}.{self.cc_minor.value}")
|
||||
self.ccs = sorted(tmp_ccs, reverse=True)
|
|
@ -17,9 +17,7 @@ evaluation:
|
|||
"""
|
||||
|
||||
import ctypes
|
||||
from pathlib import Path
|
||||
|
||||
from ..utils import execute_and_return
|
||||
from .paths import determine_cuda_runtime_lib_path
|
||||
|
||||
|
||||
|
@ -28,7 +26,7 @@ def check_cuda_result(cuda, result_val):
|
|||
if result_val != 0:
|
||||
error_str = ctypes.c_char_p()
|
||||
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
|
||||
raise Exception(f"CUDA exception! Error code: {error_str.value.decode()}")
|
||||
print(f"CUDA exception! Error code: {error_str.value.decode()}")
|
||||
|
||||
def get_cuda_version(cuda, cudart_path):
|
||||
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
|
||||
|
@ -57,7 +55,7 @@ def get_cuda_lib_handle():
|
|||
cuda = ctypes.CDLL("libcuda.so")
|
||||
except OSError:
|
||||
# TODO: shouldn't we error or at least warn here?
|
||||
raise Exception('CUDA SETUP: ERROR! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
|
||||
print('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!')
|
||||
return None
|
||||
check_cuda_result(cuda, cuda.cuInit(0))
|
||||
|
||||
|
@ -80,7 +78,6 @@ def get_compute_capabilities(cuda):
|
|||
cc_major = ctypes.c_int()
|
||||
cc_minor = ctypes.c_int()
|
||||
|
||||
result = ctypes.c_int()
|
||||
device = ctypes.c_int()
|
||||
|
||||
check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
|
||||
|
@ -119,6 +116,10 @@ def evaluate_cuda_setup():
|
|||
print('For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link')
|
||||
print('='*80)
|
||||
binary_name = "libbitsandbytes_cpu.so"
|
||||
#if not torch.cuda.is_available():
|
||||
#print('No GPU detected. Loading CPU library...')
|
||||
#return binary_name
|
||||
|
||||
cudart_path = determine_cuda_runtime_lib_path()
|
||||
if cudart_path is None:
|
||||
print(
|
||||
|
|
|
@ -2,23 +2,11 @@ from pathlib import Path
|
|||
from typing import Set, Union
|
||||
from warnings import warn
|
||||
|
||||
from ..utils import print_stderr
|
||||
from .env_vars import get_potentially_lib_path_containing_env_vars
|
||||
|
||||
|
||||
CUDA_RUNTIME_LIB: str = "libcudart.so"
|
||||
|
||||
|
||||
def purge_unwanted_semicolon(tentative_path: Path) -> Path:
|
||||
"""
|
||||
Special function to handle the following exception:
|
||||
__LMOD_REF_COUNT_PATH=/sw/cuda/11.6.2/bin:2;/mmfs1/home/dettmers/git/sched/bin:1;/mmfs1/home/dettmers/data/anaconda3/bin:1;/mmfs1/home/dettmers/data/anaconda3/condabin:1;/mmfs1/home/dettmers/.local/bin:1;/mmfs1/home/dettmers/bin:1;/usr/local/bin:1;/usr/bin:1;/usr/local/sbin:1;/usr/sbin:1;/mmfs1/home/dettmers/.fzf/bin:1;/mmfs1/home/dettmers/data/local/cuda-11.4/bin:1
|
||||
"""
|
||||
# if ';' in str(tentative_path):
|
||||
# path_as_str, _ = str(tentative_path).split(';')
|
||||
pass
|
||||
|
||||
|
||||
def extract_candidate_paths(paths_list_candidate: str) -> Set[Path]:
|
||||
return {Path(ld_path) for ld_path in paths_list_candidate.split(":") if ld_path}
|
||||
|
||||
|
@ -29,7 +17,7 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
|
|||
}
|
||||
|
||||
if non_existent_directories:
|
||||
print_stderr(
|
||||
warn(
|
||||
"WARNING: The following directories listed in your path were found to "
|
||||
f"be non-existent: {non_existent_directories}"
|
||||
)
|
||||
|
@ -117,8 +105,6 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
|
|||
if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"}
|
||||
}
|
||||
|
||||
|
||||
|
||||
cuda_runtime_libs = set()
|
||||
for env_var, value in remaining_candidate_env_vars.items():
|
||||
cuda_runtime_libs.update(find_cuda_lib_in(value))
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
import ctypes as ct
|
||||
import operator
|
||||
import random
|
||||
import math
|
||||
import torch
|
||||
|
||||
from typing import Tuple
|
||||
|
@ -185,14 +184,9 @@ def create_dynamic_map(signed=True, n=7):
|
|||
|
||||
|
||||
def get_special_format_str():
|
||||
if not torch.cuda.is_available(): return 'col_turing'
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if major < 7:
|
||||
print(
|
||||
f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
|
||||
)
|
||||
assert major >= 7
|
||||
|
||||
if major == 7:
|
||||
if major <= 7:
|
||||
return "col_turing"
|
||||
elif major == 8:
|
||||
return "col_ampere"
|
||||
|
@ -248,23 +242,6 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False):
|
|||
return getattr(lib, name)
|
||||
|
||||
|
||||
class GlobalData(object):
|
||||
_instance = None
|
||||
|
||||
def __init__(self):
|
||||
raise RuntimeError("Call get_instance() instead")
|
||||
|
||||
def initialize(self):
|
||||
self.data = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = cls.__new__(cls)
|
||||
cls._instance.initialize()
|
||||
return cls._instance
|
||||
|
||||
|
||||
def get_transform_buffer(
|
||||
shape, dtype, device, to_order, from_order="row", transpose=False
|
||||
):
|
||||
|
@ -1685,21 +1662,6 @@ def double_quant(
|
|||
return out_row, out_col, row_stats, col_stats, coo_tensor
|
||||
|
||||
|
||||
def get_special_format_str():
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
if major < 7:
|
||||
print(
|
||||
f"Device with CUDA capability of {major} not supported for 8-bit matmul. Device has no tensor cores!"
|
||||
)
|
||||
assert major >= 7
|
||||
|
||||
if major == 7: return 'col_turing'
|
||||
elif major == 8: return 'col_ampere'
|
||||
else: return 'col_turing'
|
||||
|
||||
|
||||
|
||||
|
||||
def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
|
||||
prev_device = pre_call(A.device)
|
||||
if state is None: state = (A.shape, from_order)
|
||||
|
|
|
@ -5,13 +5,12 @@
|
|||
|
||||
from bitsandbytes.cextension import COMPILED_WITH_CUDA
|
||||
|
||||
if COMPILED_WITH_CUDA:
|
||||
from .adam import Adam, Adam8bit, Adam32bit
|
||||
from .adamw import AdamW, AdamW8bit, AdamW32bit
|
||||
from .sgd import SGD, SGD8bit, SGD32bit
|
||||
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
|
||||
from .lamb import LAMB, LAMB8bit, LAMB32bit
|
||||
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
|
||||
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
|
||||
from .adam import Adam, Adam8bit, Adam32bit
|
||||
from .adamw import AdamW, AdamW8bit, AdamW32bit
|
||||
from .sgd import SGD, SGD8bit, SGD32bit
|
||||
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
|
||||
from .lamb import LAMB, LAMB8bit, LAMB32bit
|
||||
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
|
||||
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
|
||||
|
||||
from .optimizer import GlobalOptimManager
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
|
@ -22,11 +21,3 @@ def execute_and_return(command_string: str) -> Tuple[str, str]:
|
|||
|
||||
std_out, std_err = execute_and_return_decoded_std_streams(command_string)
|
||||
return std_out, std_err
|
||||
|
||||
|
||||
def print_stderr(s: str) -> None:
|
||||
print(s, file=sys.stderr)
|
||||
|
||||
|
||||
def warn_of_missing_prerequisite(s: str) -> None:
|
||||
print_stderr("WARNING, missing pre-requisite: " + s)
|
||||
|
|
|
@ -371,7 +371,11 @@ template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandl
|
|||
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
{
|
||||
#ifdef NO_CUBLASLT
|
||||
printf("ERROR: Your GPU does not support Int8 Matmul!");
|
||||
cout << "" << endl;
|
||||
cout << "=============================================" << endl;
|
||||
cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl;
|
||||
cout << "=============================================" << endl;
|
||||
cout << "" << endl;
|
||||
assert(false);
|
||||
|
||||
return 0;
|
||||
|
|
2
setup.py
2
setup.py
|
@ -18,7 +18,7 @@ def read(fname):
|
|||
|
||||
setup(
|
||||
name=f"bitsandbytes",
|
||||
version=f"0.32.1",
|
||||
version=f"0.32.3",
|
||||
author="Tim Dettmers",
|
||||
author_email="dettmers@cs.washington.edu",
|
||||
description="8-bit optimizers and matrix multiplication routines.",
|
||||
|
|
|
@ -40,6 +40,7 @@ names = [
|
|||
ids=names,
|
||||
)
|
||||
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
||||
if not torch.cuda.is_available(): pytest.skip('No GPU found.')
|
||||
if dim2 > 0:
|
||||
dim2 = dim2 - (dim2 % 16)
|
||||
dim3 = dim3 - (dim3 % 16)
|
||||
|
@ -306,6 +307,7 @@ def test_matmullt(
|
|||
has_fp16_weights,
|
||||
has_bias
|
||||
):
|
||||
if not torch.cuda.is_available(): pytest.skip('No GPU found.')
|
||||
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
|
||||
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
|
||||
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
|
||||
|
|
|
@ -1813,16 +1813,16 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
|
|||
|
||||
|
||||
batch_size = 1
|
||||
seqdim = 2048
|
||||
seqdim = 1
|
||||
values = []
|
||||
values.append((batch_size, seqdim, 768, 4 * 768))
|
||||
#values.append((batch_size, seqdim, 768, 4 * 768))
|
||||
# values.append((batch_size, seqdim, 1024, 4*1024))
|
||||
# values.append((batch_size, seqdim, 1536, 4*1536))
|
||||
# values.append((batch_size, seqdim, 2048, 4*2048))
|
||||
# values.append((batch_size, seqdim, 2560, 4*2560))
|
||||
# values.append((batch_size, seqdim, 4096, 4*4096))
|
||||
# values.append((batch_size, seqdim, 5140, 4*5140))
|
||||
# values.append((batch_size, seqdim, 12288, 4*12288))
|
||||
values.append((batch_size, seqdim, 12288, 4*12288))
|
||||
names = [
|
||||
"batch_{0}_seq_{1}_model_{2}_hidden_{3}".format(*vals) for vals in values
|
||||
]
|
||||
|
@ -1830,6 +1830,7 @@ names = [
|
|||
|
||||
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
|
||||
def test_bench_matmul(batch, seq, model, hidden):
|
||||
iters = 128
|
||||
formatB = F.get_special_format_str()
|
||||
|
||||
A = torch.randn(batch, seq, model, device="cuda").half()
|
||||
|
@ -1848,28 +1849,33 @@ def test_bench_matmul(batch, seq, model, hidden):
|
|||
linearMixedBit.eval()
|
||||
|
||||
# warmup
|
||||
for i in range(100):
|
||||
for i in range(iters):
|
||||
torch.matmul(A, B.t())
|
||||
torch.cuda.synchronize()
|
||||
print("")
|
||||
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(100):
|
||||
for i in range(iters):
|
||||
torch.matmul(A, B.t())
|
||||
torch.cuda.synchronize()
|
||||
print(
|
||||
f"pytorch: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
|
||||
f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(100):
|
||||
for i in range(iters):
|
||||
bnb.matmul(A, B)
|
||||
torch.cuda.synchronize()
|
||||
print(
|
||||
f"bnb lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
|
||||
)
|
||||
print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(iters):
|
||||
bnb.matmul(A, B, threshold=6.0)
|
||||
torch.cuda.synchronize()
|
||||
print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||
|
||||
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
|
||||
C32A, SA = F.transform(CA, "col32")
|
||||
|
@ -1877,18 +1883,16 @@ def test_bench_matmul(batch, seq, model, hidden):
|
|||
CxB, SB = F.transform(CB, to_order=formatB)
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(100):
|
||||
for i in range(iters):
|
||||
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
|
||||
torch.cuda.synchronize()
|
||||
print(
|
||||
f"igemmlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
|
||||
)
|
||||
print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||
|
||||
BA, statsB = F.vectorwise_quant(B, dim=1)
|
||||
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(100):
|
||||
for i in range(iters):
|
||||
A2 = A.view(-1, A.shape[-1]).contiguous()
|
||||
CA, statsA = F.vectorwise_quant(A2, dim=1)
|
||||
C32A, SA = F.nvidia_transform(CA, "col32")
|
||||
|
@ -1896,15 +1900,13 @@ def test_bench_matmul(batch, seq, model, hidden):
|
|||
Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
|
||||
F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
|
||||
torch.cuda.synchronize()
|
||||
print(
|
||||
f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
|
||||
)
|
||||
#print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||
|
||||
BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
|
||||
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(100):
|
||||
for i in range(iters):
|
||||
A2 = A.view(-1, A.shape[-1]).contiguous()
|
||||
CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
|
||||
C32A, SA = F.nvidia_transform(CA, "col32")
|
||||
|
@ -1912,14 +1914,12 @@ def test_bench_matmul(batch, seq, model, hidden):
|
|||
Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
|
||||
out = Cout * statsB * statsA * (1.0 / (127 * 127))
|
||||
torch.cuda.synchronize()
|
||||
print(
|
||||
f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
|
||||
)
|
||||
#print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||
|
||||
linear8bit(A)
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(100):
|
||||
for i in range(iters):
|
||||
linear8bit(A)
|
||||
torch.cuda.synchronize()
|
||||
print(
|
||||
|
@ -1929,7 +1929,7 @@ def test_bench_matmul(batch, seq, model, hidden):
|
|||
linearMixedBit(A)
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(100):
|
||||
for i in range(iters):
|
||||
linearMixedBit(A)
|
||||
torch.cuda.synchronize()
|
||||
print(
|
||||
|
|
Loading…
Reference in New Issue
Block a user