Merge remote-tracking branch 'origin/main' into merge
commit 675baa79d2

27
CHANGELOG.md
@@ -201,3 +201,30 @@ Features:

Improvements:

- Improved logging for the CUDA detection mechanism.

### 0.38.0

#### 8-bit Lion, Load/Store 8-bit Models directly from/to HF Hub

Features:

- Support for 32-bit and 8-bit Lion has been added. Thank you @lucidrains
- Support for serialization of Linear8bitLt layers (LLM.int8()). This allows storing and loading 8-bit weights directly from the HuggingFace Hub (see the sketch below). Thank you @myrab
- New bug report feature: `python -m bitsandbytes` now gives extensive debugging details to debug CUDA setup failures.
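The serialization item can be illustrated with a short, hedged sketch; the model id and save path are placeholders, and the exact behaviour depends on the installed `transformers` version:

```python
from transformers import AutoModelForCausalLM

# load a causal LM with LLM.int8() weights (placeholder model id)
model = AutoModelForCausalLM.from_pretrained('some-org/some-7b-model',
                                             load_in_8bit=True,
                                             device_map='auto')

# with Linear8bitLt serialization in place, the 8-bit checkpoint can be saved
# locally or pushed to the HuggingFace Hub (path and repo name are placeholders)
model.save_pretrained('./my-8bit-checkpoint')
# model.push_to_hub('my-user/my-8bit-checkpoint')
```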
Bug fixes:

- Fixed a bug where some bitsandbytes methods failed in a model-parallel setup on multiple GPUs. Thank you @tonylins
- Fixed a bug where cudart.so libraries could not be found in newer PyTorch releases.

Improvements:

- Improved the CUDA Setup procedure by doing a more extensive search for CUDA libraries

Deprecated:

- Devices with compute capability 3.0 (GTX 700s, K10) and 3.2 (Tegra K1, Jetson TK1) are now deprecated and support will be removed in 0.39.0.
- Support for CUDA 10.0 and 10.2 will be removed in bitsandbytes 0.39.0


### 0.38.1

Features:

- Added Int8 SwitchBack layers
- Added Fake FP8 layers for research purposes (available under `bnb.research.nn. ...`)
51
README.md

@@ -11,11 +11,41 @@ Resources:
## TL;DR
**Requirements**
Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0. LLM.int8() requires Turing or Ampere GPUs.
Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0.

(Deprecated: CUDA 10.0 is deprecated; only CUDA >= 11.0 will be supported with release 0.39.0.)

**Installation**:

``pip install bitsandbytes``

In some cases you may need to compile from source. If this happens, please consider submitting a bug report with the `python -m bitsandbytes` information. What follows are short instructions which might work out of the box if `nvcc` is installed. If these do not work, see further below.

Compilation quickstart:
```bash
git clone https://github.com/timdettmers/bitsandbytes.git
cd bitsandbytes

# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121}
# make argument in {cuda110, cuda11x, cuda12x}
# if you do not know what CUDA you have, try looking at the output of: python -m bitsandbytes
CUDA_VERSION=117 make cuda11x
python setup.py install
```

**Using Int8 inference with HuggingFace Transformers**

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    'decapoda-research/llama-7b-hf',
    device_map='auto',
    load_in_8bit=True,
    # cap GPU 0, keeping ~2 GiB of the currently free memory as head-room
    max_memory={0: f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'})
```
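In the snippet above, `max_memory` maps a device index to a memory cap, and the `-2` keeps roughly 2 GiB of the currently free GPU memory as head-room; increase that margin if generation runs out of memory.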

A more detailed example can be found in [examples/int8_inference_huggingface.py](examples/int8_inference_huggingface.py).

**Using 8-bit optimizer**:
1. Comment out optimizer: ``#torch.optim.Adam(....)``
2. Add 8-bit optimizer of your choice ``bnb.optim.Adam8bit(....)`` (arguments stay the same); a short sketch follows below.
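Concretely, the two steps above amount to a one-line swap. A minimal sketch, assuming a toy linear model on GPU; the model and hyper-parameters are placeholders, not part of this README:

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()  # placeholder model; the 8-bit optimizers expect CUDA parameters

# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.995))  # 1. comment out the 32-bit optimizer
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.995))  # 2. add the 8-bit optimizer, same arguments

loss = model(torch.randn(16, 1024, device='cuda')).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```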
@@ -40,7 +70,7 @@ out = linear(x.to(torch.float16))

## Features
- 8-bit Matrix multiplication with mixed precision decomposition
- LLM.int8() inference
- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB (saves 75% memory)
- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory)
- Stable Embedding Layer: Improved stability through better initialization and normalization (see the sketch below)
- 8-bit quantization: Quantile, Linear, and Dynamic quantization
- Fast quantile estimation: Up to 100x faster than other algorithms
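As a small illustration of the Stable Embedding Layer listed above: it is intended as a drop-in replacement for `torch.nn.Embedding`; the sizes in this sketch are made up:

```python
import torch
import bitsandbytes as bnb

vocab_size, dim = 32000, 1024                    # placeholder sizes
# emb = torch.nn.Embedding(vocab_size, dim)      # standard embedding layer
emb = bnb.nn.StableEmbedding(vocab_size, dim)    # improved initialization plus layer norm

token_ids = torch.randint(0, vocab_size, (8, 128))
vectors = emb(token_ids)                         # shape: (8, 128, 1024)
```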
@@ -113,8 +143,23 @@ For upcoming features and changes and full history see [Patch Notes](CHANGELOG.m

2. __fatbinwrap_.. [Solution](errors_and_solutions.md#fatbinwrap_)

## Compile from source
To compile from source, you need an installation of CUDA. If `nvcc` is not installed, you can install the CUDA Toolkit with nvcc through the following commands.
To compile from source, please follow the [compile_from_source.md](compile_from_source.md) instructions.

```bash
wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh
# Syntax: cuda_install.sh CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH
# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121}
# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True

# For example, the following installs CUDA 11.8 to ~/local/cuda-11.8 and exports the path to your .bashrc
bash cuda_install.sh 118 ~/local 1
```

To use a specific CUDA version just for a single compile run, you can set the variable `CUDA_HOME`. For example, the following command compiles `libbitsandbytes_cuda117.so` using compiler flags for cuda11x with the CUDA version at `~/local/cuda-11.7`:

``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x``

For more detailed instructions, please follow the [compile_from_source.md](compile_from_source.md) instructions.

## License
4
benchmarking/switchback/README.md
Normal file

@@ -0,0 +1,4 @@

Steps:

1. Run `python speed_benchmark/speed_benchmark.py`, which times the operations and writes their timings to `speed_benchmark/info_a100_py2.jsonl` (change the name of the jsonl to a different name for your profiling).
2. Run `python speed_benchmark/make_plot_with_jsonl.py`, which produces `speed_benchmark/plot_with_info.pdf`. Again, make sure you change the jsonl which is being processed.
60
benchmarking/switchback/info_a100_py2.jsonl
Normal file

@@ -0,0 +1,60 @@
{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.28139352798461914, "standard_gw": 0.2811811864376068, "standard_gx": 0.30258670449256897, "rowwise_fwd": 0.1994594931602478, "rowwise_bwd": 0.16159191727638245, "global_fwd": 0.19502267241477966, "global_bwd": 0.16080215573310852, "x_quantize_rowwise": 0.03306940197944641, "g_quantize_rowwise": 0.08210167288780212, "w_quantize_rowwise": 0.03385916352272034, "w_quantize_colwise_transpose": 0.08635595440864563, "w_quantize_global": 0.09237229824066162, "w_quantize_global_transpose": 0.10007619857788086, "time_standard": 0.8651614189147949, "time_rowwise": 0.8776187896728516, "time_global": 0.944625586271286}
{"repeat": 64, "batch_size": 8192, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.262625515460968, "standard_gw": 0.2806223928928375, "standard_gx": 0.31118839979171753, "rowwise_fwd": 0.1828707754611969, "rowwise_bwd": 0.21236762404441833, "global_fwd": 0.16665831208229065, "global_bwd": 0.19929558038711548, "x_quantize_rowwise": 0.08227676153182983, "g_quantize_rowwise": 0.03310292959213257, "w_quantize_rowwise": 0.032648444175720215, "w_quantize_colwise_transpose": 0.09015202522277832, "w_quantize_global": 0.0988692045211792, "w_quantize_global_transpose": 0.10057538747787476, "time_standard": 0.8544363081455231, "time_rowwise": 0.9140409529209137, "time_global": 0.96140056848526}
{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.5731917917728424, "standard_gw": 0.5709454417228699, "standard_gx": 0.5963630974292755, "rowwise_fwd": 0.37662312388420105, "rowwise_bwd": 0.281747430562973, "global_fwd": 0.36768242716789246, "global_bwd": 0.28043612837791443, "x_quantize_rowwise": 0.046547502279281616, "g_quantize_rowwise": 0.15532970428466797, "w_quantize_rowwise": 0.032436102628707886, "w_quantize_colwise_transpose": 0.08635222911834717, "w_quantize_global": 0.0947415828704834, "w_quantize_global_transpose": 0.10129809379577637, "time_standard": 1.7405003309249878, "time_rowwise": 1.5499815344810486, "time_global": 1.616980880498886}
{"repeat": 64, "batch_size": 16384, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.5341619253158569, "standard_gw": 0.5690865218639374, "standard_gx": 0.599835067987442, "rowwise_fwd": 0.3233291208744049, "rowwise_bwd": 0.41359663009643555, "global_fwd": 0.2831108868122101, "global_bwd": 0.37280842661857605, "x_quantize_rowwise": 0.15563145279884338, "g_quantize_rowwise": 0.046741217374801636, "w_quantize_rowwise": 0.03306940197944641, "w_quantize_colwise_transpose": 0.09020790457725525, "w_quantize_global": 0.0925213098526001, "w_quantize_global_transpose": 0.09945780038833618, "time_standard": 1.7030835151672363, "time_rowwise": 1.6316622495651245, "time_global": 1.6193576157093048}
{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 1.2199915945529938, "standard_gw": 1.1069811880588531, "standard_gx": 1.09761580824852, "rowwise_fwd": 0.738043338060379, "rowwise_bwd": 0.5549229681491852, "global_fwd": 0.7219798862934113, "global_bwd": 0.5512163043022156, "x_quantize_rowwise": 0.08748471736907959, "g_quantize_rowwise": 0.3023110330104828, "w_quantize_rowwise": 0.03182142972946167, "w_quantize_colwise_transpose": 0.08632615208625793, "w_quantize_global": 0.09445473551750183, "w_quantize_global_transpose": 0.10032951831817627, "time_standard": 3.424588590860367, "time_rowwise": 2.9078908264636993, "time_global": 2.9647573828697205}
{"repeat": 64, "batch_size": 32768, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 1.1040829122066498, "standard_gw": 1.1221766471862793, "standard_gx": 1.1548101902008057, "rowwise_fwd": 0.581938773393631, "rowwise_bwd": 0.7480122148990631, "global_fwd": 0.5537159740924835, "global_bwd": 0.7232688367366791, "x_quantize_rowwise": 0.30193477869033813, "g_quantize_rowwise": 0.08745118975639343, "w_quantize_rowwise": 0.03374740481376648, "w_quantize_colwise_transpose": 0.09068101644515991, "w_quantize_global": 0.09645149111747742, "w_quantize_global_transpose": 0.10189786553382874, "time_standard": 3.3810697495937347, "time_rowwise": 2.9659420251846313, "time_global": 2.9868967831134796}
{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 2.4533793330192566, "standard_gw": 2.1938569843769073, "standard_gx": 2.179361879825592, "rowwise_fwd": 1.4615543186664581, "rowwise_bwd": 1.0522231459617615, "global_fwd": 1.4288239181041718, "global_bwd": 1.0450035333633423, "x_quantize_rowwise": 0.1691766083240509, "g_quantize_rowwise": 0.5951300263404846, "w_quantize_rowwise": 0.03337860107421875, "w_quantize_colwise_transpose": 0.08653849363327026, "w_quantize_global": 0.0940859317779541, "w_quantize_global_transpose": 0.09976327419281006, "time_standard": 6.826598197221756, "time_rowwise": 5.5918581783771515, "time_global": 5.625840276479721}
{"repeat": 64, "batch_size": 65536, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 2.1698065102100372, "standard_gw": 2.1875128149986267, "standard_gx": 2.2887587547302246, "rowwise_fwd": 1.0762326419353485, "rowwise_bwd": 1.4638006687164307, "global_fwd": 1.0450668632984161, "global_bwd": 1.4308765530586243, "x_quantize_rowwise": 0.5953535437583923, "g_quantize_rowwise": 0.16899779438972473, "w_quantize_rowwise": 0.03240257501602173, "w_quantize_colwise_transpose": 0.09106099605560303, "w_quantize_global": 0.09546056389808655, "w_quantize_global_transpose": 0.09852275252342224, "time_standard": 6.6460780799388885, "time_rowwise": 5.615361034870148, "time_global": 5.621790885925293}
{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 4.858218133449554, "standard_gw": 4.3631307780742645, "standard_gx": 4.404045641422272, "rowwise_fwd": 2.9063820838928223, "rowwise_bwd": 2.094462513923645, "global_fwd": 2.8426870703697205, "global_bwd": 2.0792782306671143, "x_quantize_rowwise": 0.33241137862205505, "g_quantize_rowwise": 1.1817105114459991, "w_quantize_rowwise": 0.03374367952346802, "w_quantize_colwise_transpose": 0.08633732795715332, "w_quantize_global": 0.09231641888618469, "w_quantize_global_transpose": 0.100012868642807, "time_standard": 13.62539455294609, "time_rowwise": 10.998178273439407, "time_global": 10.991547256708145}
{"repeat": 64, "batch_size": 131072, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 4.246581345796585, "standard_gw": 4.42587211728096, "standard_gx": 4.581417888402939, "rowwise_fwd": 2.1114833652973175, "rowwise_bwd": 2.9050447046756744, "global_fwd": 2.0806826651096344, "global_bwd": 2.85966694355011, "x_quantize_rowwise": 1.1816024780273438, "g_quantize_rowwise": 0.33330172300338745, "w_quantize_rowwise": 0.033445656299591064, "w_quantize_colwise_transpose": 0.09065866470336914, "w_quantize_global": 0.09239837527275085, "w_quantize_global_transpose": 0.09984523057937622, "time_standard": 13.253871351480484, "time_rowwise": 11.081408709287643, "time_global": 11.073369532823563}
{"repeat": 64, "batch_size": 8192, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.4859529435634613, "standard_gw": 0.46338513493537903, "standard_gx": 0.42321905493736267, "rowwise_fwd": 0.2761557698249817, "rowwise_bwd": 0.20775198936462402, "global_fwd": 0.2713911235332489, "global_bwd": 0.20639970898628235, "x_quantize_rowwise": 0.033095479011535645, "g_quantize_rowwise": 0.11894106864929199, "w_quantize_rowwise": 0.03125518560409546, "w_quantize_colwise_transpose": 0.1424551010131836, "w_quantize_global": 0.07288157939910889, "w_quantize_global_transpose": 0.08071959018707275, "time_standard": 1.372557133436203, "time_rowwise": 1.2730397284030914, "time_global": 1.2468136847019196}
{"repeat": 64, "batch_size": 8192, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.3920421004295349, "standard_gw": 0.44424086809158325, "standard_gx": 0.4759356379508972, "rowwise_fwd": 0.23231282830238342, "rowwise_bwd": 0.28430670499801636, "global_fwd": 0.20883232355117798, "global_bwd": 0.2741999924182892, "x_quantize_rowwise": 0.12018159031867981, "g_quantize_rowwise": 0.03195926547050476, "w_quantize_rowwise": 0.026017427444458008, "w_quantize_colwise_transpose": 0.14733895659446716, "w_quantize_global": 0.07734447717666626, "w_quantize_global_transpose": 0.0788569450378418, "time_standard": 1.3122186064720154, "time_rowwise": 1.2863576412200928, "time_global": 1.235615462064743}
{"repeat": 64, "batch_size": 16384, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 1.0111741721630096, "standard_gw": 0.9267590939998627, "standard_gx": 0.8254274725914001, "rowwise_fwd": 0.5434826016426086, "rowwise_bwd": 0.4077926278114319, "global_fwd": 0.5318708717823029, "global_bwd": 0.40537863969802856, "x_quantize_rowwise": 0.059738755226135254, "g_quantize_rowwise": 0.2299174666404724, "w_quantize_rowwise": 0.02545863389968872, "w_quantize_colwise_transpose": 0.14269724488258362, "w_quantize_global": 0.07300823926925659, "w_quantize_global_transpose": 0.07878988981246948, "time_standard": 2.7633607387542725, "time_rowwise": 2.335846424102783, "time_global": 2.305462956428528}
{"repeat": 64, "batch_size": 16384, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.8095316588878632, "standard_gw": 0.8607134222984314, "standard_gx": 0.9204968810081482, "rowwise_fwd": 0.4275888204574585, "rowwise_bwd": 0.5485899746417999, "global_fwd": 0.41000545024871826, "global_bwd": 0.5317628383636475, "x_quantize_rowwise": 0.2301819622516632, "g_quantize_rowwise": 0.059254467487335205, "w_quantize_rowwise": 0.02466142177581787, "w_quantize_colwise_transpose": 0.14865398406982422, "w_quantize_global": 0.07582828402519226, "w_quantize_global_transpose": 0.08231401443481445, "time_standard": 2.5907419621944427, "time_rowwise": 2.2996440529823303, "time_global": 2.2500604391098022}
{"repeat": 64, "batch_size": 32768, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 2.0658522844314575, "standard_gw": 1.718364655971527, "standard_gx": 1.6660578548908234, "rowwise_fwd": 1.066897064447403, "rowwise_bwd": 0.8070804178714752, "global_fwd": 1.0473169386386871, "global_bwd": 0.8021742105484009, "x_quantize_rowwise": 0.11274218559265137, "g_quantize_rowwise": 0.4518181085586548, "w_quantize_rowwise": 0.026501715183258057, "w_quantize_colwise_transpose": 0.14259666204452515, "w_quantize_global": 0.07484853267669678, "w_quantize_global_transpose": 0.07976219058036804, "time_standard": 5.450274795293808, "time_rowwise": 4.326000809669495, "time_global": 4.287026822566986}
{"repeat": 64, "batch_size": 32768, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 2.7549192309379578, "standard_gw": 1.6954988241195679, "standard_gx": 1.8179528415203094, "rowwise_fwd": 0.8649080991744995, "rowwise_bwd": 1.0746456682682037, "global_fwd": 0.8023083209991455, "global_bwd": 1.0471977293491364, "x_quantize_rowwise": 0.45225024223327637, "g_quantize_rowwise": 0.11286512017250061, "w_quantize_rowwise": 0.0252649188041687, "w_quantize_colwise_transpose": 0.14732033014297485, "w_quantize_global": 0.07537379860877991, "w_quantize_global_transpose": 0.0807642936706543, "time_standard": 6.268370896577835, "time_rowwise": 4.372753202915192, "time_global": 4.266258329153061}
{"repeat": 64, "batch_size": 65536, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 4.098430275917053, "standard_gw": 3.3501461148262024, "standard_gx": 5.560480058193207, "rowwise_fwd": 2.112947404384613, "rowwise_bwd": 1.605246216058731, "global_fwd": 2.0697638392448425, "global_bwd": 1.5953518450260162, "x_quantize_rowwise": 0.21921470761299133, "g_quantize_rowwise": 0.8956789970397949, "w_quantize_rowwise": 0.02710893750190735, "w_quantize_colwise_transpose": 0.14268234372138977, "w_quantize_global": 0.07259473204612732, "w_quantize_global_transpose": 0.07899105548858643, "time_standard": 13.009056448936462, "time_rowwise": 8.35302472114563, "time_global": 8.281741291284561}
{"repeat": 64, "batch_size": 65536, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 5.586959421634674, "standard_gw": 3.358360379934311, "standard_gx": 3.6434978246688843, "rowwise_fwd": 1.6269534826278687, "rowwise_bwd": 2.128206193447113, "global_fwd": 1.5950687229633331, "global_bwd": 2.0831897854804993, "x_quantize_rowwise": 0.8954145014286041, "g_quantize_rowwise": 0.21914392709732056, "w_quantize_rowwise": 0.026203691959381104, "w_quantize_colwise_transpose": 0.14658644795417786, "w_quantize_global": 0.07478520274162292, "w_quantize_global_transpose": 0.07964670658111572, "time_standard": 12.58881762623787, "time_rowwise": 8.400868624448776, "time_global": 8.305609226226807}
{"repeat": 64, "batch_size": 131072, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 8.229725062847137, "standard_gw": 6.791356950998306, "standard_gx": 6.806455552577972, "rowwise_fwd": 4.252471029758453, "rowwise_bwd": 3.2062679529190063, "global_fwd": 4.175614565610886, "global_bwd": 3.1837262213230133, "x_quantize_rowwise": 0.4321373999118805, "g_quantize_rowwise": 1.787092536687851, "w_quantize_rowwise": 0.0270158052444458, "w_quantize_colwise_transpose": 0.1424252986907959, "w_quantize_global": 0.07348507642745972, "w_quantize_global_transpose": 0.07829815149307251, "time_standard": 21.827537566423416, "time_rowwise": 16.63876697421074, "time_global": 16.52171090245247}
{"repeat": 64, "batch_size": 131072, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 11.279478669166565, "standard_gw": 6.7345499992370605, "standard_gx": 7.206875830888748, "rowwise_fwd": 3.209315240383148, "rowwise_bwd": 4.256397485733032, "global_fwd": 3.180190920829773, "global_bwd": 4.177983850240707, "x_quantize_rowwise": 1.7836056649684906, "g_quantize_rowwise": 0.4321075975894928, "w_quantize_rowwise": 0.03205239772796631, "w_quantize_colwise_transpose": 0.14675036072731018, "w_quantize_global": 0.09316205978393555, "w_quantize_global_transpose": 0.10086596012115479, "time_standard": 25.220904499292374, "time_rowwise": 16.5947787463665, "time_global": 16.502466052770615}
{"repeat": 64, "batch_size": 8192, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 0.5776733160018921, "standard_gw": 0.5300231277942657, "standard_gx": 0.6005913019180298, "rowwise_fwd": 0.33330172300338745, "rowwise_bwd": 0.2957060933113098, "global_fwd": 0.32876431941986084, "global_bwd": 0.29108673334121704, "x_quantize_rowwise": 0.03466755151748657, "g_quantize_rowwise": 0.12264400720596313, "w_quantize_rowwise": 0.033874064683914185, "w_quantize_colwise_transpose": 0.1775398850440979, "w_quantize_global": 0.09503215551376343, "w_quantize_global_transpose": 0.10617449879646301, "time_standard": 1.7082877457141876, "time_rowwise": 1.5277564525604248, "time_global": 1.5083923935890198}
{"repeat": 64, "batch_size": 8192, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.5164109170436859, "standard_gw": 0.5367249250411987, "standard_gx": 0.5876161158084869, "rowwise_fwd": 0.3132447600364685, "rowwise_bwd": 0.3396235406398773, "global_fwd": 0.2943649888038635, "global_bwd": 0.33209100365638733, "x_quantize_rowwise": 0.12357160449028015, "g_quantize_rowwise": 0.035997480154037476, "w_quantize_rowwise": 0.03213062882423401, "w_quantize_colwise_transpose": 0.17676874995231628, "w_quantize_global": 0.09861215949058533, "w_quantize_global_transpose": 0.0998862087726593, "time_standard": 1.6407519578933716, "time_rowwise": 1.5580616891384125, "time_global": 1.5212483704090118}
{"repeat": 64, "batch_size": 16384, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 1.2096501886844635, "standard_gw": 1.0663382709026337, "standard_gx": 1.0961703956127167, "rowwise_fwd": 0.6396733224391937, "rowwise_bwd": 0.5173943936824799, "global_fwd": 0.6296299397945404, "global_bwd": 0.5130060017108917, "x_quantize_rowwise": 0.06211921572685242, "g_quantize_rowwise": 0.2361498773097992, "w_quantize_rowwise": 0.03260001540184021, "w_quantize_colwise_transpose": 0.17679482698440552, "w_quantize_global": 0.09361281991004944, "w_quantize_global_transpose": 0.09913742542266846, "time_standard": 3.372158855199814, "time_rowwise": 2.7310699224472046, "time_global": 2.6999935507774353}
{"repeat": 64, "batch_size": 16384, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 1.1065565049648285, "standard_gw": 1.0664314031600952, "standard_gx": 1.1266544461250305, "rowwise_fwd": 0.5352050065994263, "rowwise_bwd": 0.6464086472988129, "global_fwd": 0.513765960931778, "global_bwd": 0.6284862756729126, "x_quantize_rowwise": 0.23620948195457458, "g_quantize_rowwise": 0.062271952629089355, "w_quantize_rowwise": 0.031460076570510864, "w_quantize_colwise_transpose": 0.17675384879112244, "w_quantize_global": 0.09486451745033264, "w_quantize_global_transpose": 0.09898096323013306, "time_standard": 3.2996423542499542, "time_rowwise": 2.7547404170036316, "time_global": 2.7010105550289154}
{"repeat": 64, "batch_size": 32768, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 2.4367496371269226, "standard_gw": 2.0806193351745605, "standard_gx": 2.19624862074852, "rowwise_fwd": 1.2554042041301727, "rowwise_bwd": 1.0227933526039124, "global_fwd": 1.2322552502155304, "global_bwd": 1.0152235627174377, "x_quantize_rowwise": 0.11792033910751343, "g_quantize_rowwise": 0.4639364778995514, "w_quantize_rowwise": 0.03241002559661865, "w_quantize_colwise_transpose": 0.17657503485679626, "w_quantize_global": 0.09655207395553589, "w_quantize_global_transpose": 0.09958073496818542, "time_standard": 6.713617593050003, "time_rowwise": 5.149658769369125, "time_global": 5.106087774038315}
{"repeat": 64, "batch_size": 32768, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 2.1935217082500458, "standard_gw": 2.0055584609508514, "standard_gx": 2.1882541477680206, "rowwise_fwd": 1.0396353900432587, "rowwise_bwd": 1.2542344629764557, "global_fwd": 1.0161921381950378, "global_bwd": 1.233428716659546, "x_quantize_rowwise": 0.4642195999622345, "g_quantize_rowwise": 0.11782720685005188, "w_quantize_rowwise": 0.033117830753326416, "w_quantize_colwise_transpose": 0.17696991562843323, "w_quantize_global": 0.09416043758392334, "w_quantize_global_transpose": 0.10101497173309326, "time_standard": 6.387334316968918, "time_rowwise": 5.091562867164612, "time_global": 5.032401531934738}
{"repeat": 64, "batch_size": 65536, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 4.804681986570358, "standard_gw": 4.763372242450714, "standard_gx": 4.064023494720459, "rowwise_fwd": 2.484843134880066, "rowwise_bwd": 1.9691288471221924, "global_fwd": 2.441786229610443, "global_bwd": 1.9574686884880066, "x_quantize_rowwise": 0.2294592559337616, "g_quantize_rowwise": 0.9196549654006958, "w_quantize_rowwise": 0.0313781201839447, "w_quantize_colwise_transpose": 0.1768544316291809, "w_quantize_global": 0.09644776582717896, "w_quantize_global_transpose": 0.09847059845924377, "time_standard": 13.632077723741531, "time_rowwise": 10.574690997600555, "time_global": 10.506659746170044}
{"repeat": 64, "batch_size": 65536, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 4.0907710790634155, "standard_gw": 3.9793066680431366, "standard_gx": 4.302978515625, "rowwise_fwd": 1.992940902709961, "rowwise_bwd": 2.4996213614940643, "global_fwd": 1.9551962614059448, "global_bwd": 2.457551658153534, "x_quantize_rowwise": 0.9200014173984528, "g_quantize_rowwise": 0.2293996512889862, "w_quantize_rowwise": 0.0313781201839447, "w_quantize_colwise_transpose": 0.17882883548736572, "w_quantize_global": 0.09540095925331116, "w_quantize_global_transpose": 0.09880587458610535, "time_standard": 12.373056262731552, "time_rowwise": 9.831476956605911, "time_global": 9.73566249012947}
{"repeat": 64, "batch_size": 131072, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 9.655728936195374, "standard_gw": 8.261296898126602, "standard_gx": 8.064884692430496, "rowwise_fwd": 5.007706582546234, "rowwise_bwd": 3.8615092635154724, "global_fwd": 4.920527338981628, "global_bwd": 3.8330331444740295, "x_quantize_rowwise": 0.45276060700416565, "g_quantize_rowwise": 1.8306002020835876, "w_quantize_rowwise": 0.031366944313049316, "w_quantize_colwise_transpose": 0.1766495406627655, "w_quantize_global": 0.09412690997123718, "w_quantize_global_transpose": 0.09780004620552063, "time_standard": 25.981910526752472, "time_rowwise": 19.621890038251877, "time_global": 19.49014514684677}
{"repeat": 64, "batch_size": 131072, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 8.033104240894318, "standard_gw": 8.2889124751091, "standard_gx": 8.622754365205765, "rowwise_fwd": 3.8747042417526245, "rowwise_bwd": 5.003921687602997, "global_fwd": 3.8315393030643463, "global_bwd": 4.9162134528160095, "x_quantize_rowwise": 1.8304847180843353, "g_quantize_rowwise": 0.4522763192653656, "w_quantize_rowwise": 0.03413110971450806, "w_quantize_colwise_transpose": 0.1771189272403717, "w_quantize_global": 0.09519979357719421, "w_quantize_global_transpose": 0.09930506348609924, "time_standard": 24.944771081209183, "time_rowwise": 19.661549478769302, "time_global": 19.51393112540245}
{"repeat": 64, "batch_size": 8192, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.7954612374305725, "standard_gw": 0.7456131279468536, "standard_gx": 0.8799619972705841, "rowwise_fwd": 0.43267011642456055, "rowwise_bwd": 0.34622475504875183, "global_fwd": 0.42615458369255066, "global_bwd": 0.344250351190567, "x_quantize_rowwise": 0.03748014569282532, "g_quantize_rowwise": 0.13304129242897034, "w_quantize_rowwise": 0.03294646739959717, "w_quantize_colwise_transpose": 0.2407953143119812, "w_quantize_global": 0.094633549451828, "w_quantize_global_transpose": 0.10305643081665039, "time_standard": 2.4210363626480103, "time_rowwise": 1.96877121925354, "time_global": 1.8842294812202454}
{"repeat": 64, "batch_size": 8192, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.7120333611965179, "standard_gw": 0.7622130215167999, "standard_gx": 0.8262209594249725, "rowwise_fwd": 0.3702230751514435, "rowwise_bwd": 0.4419572651386261, "global_fwd": 0.3479123115539551, "global_bwd": 0.4306286573410034, "x_quantize_rowwise": 0.13308599591255188, "g_quantize_rowwise": 0.037495046854019165, "w_quantize_rowwise": 0.03398209810256958, "w_quantize_colwise_transpose": 0.23782625794410706, "w_quantize_global": 0.09853765368461609, "w_quantize_global_transpose": 0.10247156023979187, "time_standard": 2.3004673421382904, "time_rowwise": 2.016782760620117, "time_global": 1.9123442471027374}
{"repeat": 64, "batch_size": 16384, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 1.6292817890644073, "standard_gw": 1.5109702944755554, "standard_gx": 1.482747495174408, "rowwise_fwd": 0.8386112749576569, "rowwise_bwd": 0.6844550371170044, "global_fwd": 0.8220970630645752, "global_bwd": 0.6802082061767578, "x_quantize_rowwise": 0.06883963942527771, "g_quantize_rowwise": 0.25641173124313354, "w_quantize_rowwise": 0.033054500818252563, "w_quantize_colwise_transpose": 0.24027004837989807, "w_quantize_global": 0.0967271625995636, "w_quantize_global_transpose": 0.102948397397995, "time_standard": 4.622999578714371, "time_rowwise": 3.6326125264167786, "time_global": 3.5382024943828583}
{"repeat": 64, "batch_size": 16384, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 1.4877021312713623, "standard_gw": 1.5015341341495514, "standard_gx": 1.529306173324585, "rowwise_fwd": 0.715944916009903, "rowwise_bwd": 0.8529908955097198, "global_fwd": 0.680088996887207, "global_bwd": 0.8224695920944214, "x_quantize_rowwise": 0.2568177878856659, "g_quantize_rowwise": 0.06864592432975769, "w_quantize_rowwise": 0.03343448042869568, "w_quantize_colwise_transpose": 0.23645907640457153, "w_quantize_global": 0.09399279952049255, "w_quantize_global_transpose": 0.10286271572113037, "time_standard": 4.518542438745499, "time_rowwise": 3.665827214717865, "time_global": 3.5264119505882263}
{"repeat": 64, "batch_size": 32768, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 3.261040896177292, "standard_gw": 2.8816498816013336, "standard_gx": 2.8357282280921936, "rowwise_fwd": 1.6594752669334412, "rowwise_bwd": 1.359265297651291, "global_fwd": 1.6287527978420258, "global_bwd": 1.3503879308700562, "x_quantize_rowwise": 0.13146549463272095, "g_quantize_rowwise": 0.5035959184169769, "w_quantize_rowwise": 0.03438442945480347, "w_quantize_colwise_transpose": 0.24086236953735352, "w_quantize_global": 0.0945068895816803, "w_quantize_global_transpose": 0.10332837700843811, "time_standard": 8.978419005870819, "time_rowwise": 6.8106986582279205, "time_global": 6.693687289953232}
{"repeat": 64, "batch_size": 32768, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 2.848360687494278, "standard_gw": 2.8955675661563873, "standard_gx": 3.0499882996082306, "rowwise_fwd": 1.3900883495807648, "rowwise_bwd": 1.6595833003520966, "global_fwd": 1.3514049351215363, "global_bwd": 1.629263162612915, "x_quantize_rowwise": 0.5036592483520508, "g_quantize_rowwise": 0.13118237257003784, "w_quantize_rowwise": 0.03438442945480347, "w_quantize_colwise_transpose": 0.23709610104560852, "w_quantize_global": 0.0951625406742096, "w_quantize_global_transpose": 0.10216236114501953, "time_standard": 8.793916553258896, "time_rowwise": 6.851561367511749, "time_global": 6.708402186632156}
{"repeat": 64, "batch_size": 65536, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 6.4978525042533875, "standard_gw": 6.462603807449341, "standard_gx": 5.5987648665905, "rowwise_fwd": 3.2996535301208496, "rowwise_bwd": 2.6320070028305054, "global_fwd": 3.2426007091999054, "global_bwd": 2.612769603729248, "x_quantize_rowwise": 0.2561397850513458, "g_quantize_rowwise": 0.9984448552131653, "w_quantize_rowwise": 0.033076852560043335, "w_quantize_colwise_transpose": 0.24232640862464905, "w_quantize_global": 0.09618699550628662, "w_quantize_global_transpose": 0.10257214307785034, "time_standard": 18.559221178293228, "time_rowwise": 13.9242522418499, "time_global": 13.771317899227142}
{"repeat": 64, "batch_size": 65536, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 5.5702440440654755, "standard_gw": 5.717620253562927, "standard_gx": 6.08203187584877, "rowwise_fwd": 2.649586647748947, "rowwise_bwd": 3.315173089504242, "global_fwd": 2.6132799685001373, "global_bwd": 3.257807344198227, "x_quantize_rowwise": 0.9980201721191406, "g_quantize_rowwise": 0.256560742855072, "w_quantize_rowwise": 0.03356859087944031, "w_quantize_colwise_transpose": 0.23729726672172546, "w_quantize_global": 0.09495764970779419, "w_quantize_global_transpose": 0.103779137134552, "time_standard": 17.369896173477173, "time_rowwise": 13.207826763391495, "time_global": 13.04202526807785}
{"repeat": 64, "batch_size": 131072, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 13.058379292488098, "standard_gw": 11.480242013931274, "standard_gx": 11.092845350503922, "rowwise_fwd": 6.637874990701675, "rowwise_bwd": 5.24790957570076, "global_fwd": 6.521012634038925, "global_bwd": 5.214303731918335, "x_quantize_rowwise": 0.5057565867900848, "g_quantize_rowwise": 1.989319920539856, "w_quantize_rowwise": 0.03439188003540039, "w_quantize_colwise_transpose": 0.24280324578285217, "w_quantize_global": 0.09520724415779114, "w_quantize_global_transpose": 0.10240450501441956, "time_standard": 35.631466656923294, "time_rowwise": 26.138298213481903, "time_global": 25.908246636390686}
{"repeat": 64, "batch_size": 131072, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 11.13397628068924, "standard_gw": 11.371888220310211, "standard_gx": 12.12756335735321, "rowwise_fwd": 5.2495077252388, "rowwise_bwd": 6.638709455728531, "global_fwd": 5.215313285589218, "global_bwd": 6.5222084522247314, "x_quantize_rowwise": 1.9870512187480927, "g_quantize_rowwise": 0.5058236420154572, "w_quantize_rowwise": 0.034634023904800415, "w_quantize_colwise_transpose": 0.23674964904785156, "w_quantize_global": 0.09457767009735107, "w_quantize_global_transpose": 0.10183081030845642, "time_standard": 34.63342785835266, "time_rowwise": 26.024363934993744, "time_global": 25.798693299293518}
{"repeat": 64, "batch_size": 8192, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 1.2125298380851746, "standard_gw": 1.1111274361610413, "standard_gx": 1.0840706527233124, "rowwise_fwd": 0.6057210266590118, "rowwise_bwd": 0.51865354180336, "global_fwd": 0.5952082574367523, "global_bwd": 0.5167685449123383, "x_quantize_rowwise": 0.045686960220336914, "g_quantize_rowwise": 0.15827640891075134, "w_quantize_rowwise": 0.04361197352409363, "w_quantize_colwise_transpose": 0.34067779779434204, "w_quantize_global": 0.13644620776176453, "w_quantize_global_transpose": 0.14925003051757812, "time_standard": 3.407727926969528, "time_rowwise": 2.823755145072937, "time_global": 2.7127638459205627}
{"repeat": 64, "batch_size": 8192, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 1.0731369256973267, "standard_gw": 1.1365897953510284, "standard_gx": 1.1498592793941498, "rowwise_fwd": 0.5573518574237823, "rowwise_bwd": 0.615488737821579, "global_fwd": 0.5220361053943634, "global_bwd": 0.5939789116382599, "x_quantize_rowwise": 0.15765801072120667, "g_quantize_rowwise": 0.04369020462036133, "w_quantize_rowwise": 0.047359615564346313, "w_quantize_colwise_transpose": 0.5526281893253326, "w_quantize_global": 0.13606995344161987, "w_quantize_global_transpose": 0.15017390251159668, "time_standard": 3.359586000442505, "time_rowwise": 3.1107664108276367, "time_global": 2.7401968836784363}
{"repeat": 64, "batch_size": 16384, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 2.4274885654449463, "standard_gw": 2.1799951791763306, "standard_gx": 2.1426528692245483, "rowwise_fwd": 1.195710152387619, "rowwise_bwd": 1.027170568704605, "global_fwd": 1.1747106909751892, "global_bwd": 1.0251589119434357, "x_quantize_rowwise": 0.08098781108856201, "g_quantize_rowwise": 0.3052949905395508, "w_quantize_rowwise": 0.043764710426330566, "w_quantize_colwise_transpose": 0.33987686038017273, "w_quantize_global": 0.13646483421325684, "w_quantize_global_transpose": 0.14739856123924255, "time_standard": 6.750136613845825, "time_rowwise": 5.172800272703171, "time_global": 5.050010979175568}
{"repeat": 64, "batch_size": 16384, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 2.1661892533302307, "standard_gw": 2.0948275923728943, "standard_gx": 2.306375652551651, "rowwise_fwd": 1.0587647557258606, "rowwise_bwd": 1.1999905109405518, "global_fwd": 1.0296404361724854, "global_bwd": 1.1749230325222015, "x_quantize_rowwise": 0.3054030239582062, "g_quantize_rowwise": 0.08077546954154968, "w_quantize_rowwise": 0.047225505113601685, "w_quantize_colwise_transpose": 0.600133091211319, "w_quantize_global": 0.13613328337669373, "w_quantize_global_transpose": 0.1484006643295288, "time_standard": 6.567392498254776, "time_rowwise": 5.387119948863983, "time_global": 4.97010350227356}
{"repeat": 64, "batch_size": 32768, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 4.807606339454651, "standard_gw": 4.170913249254227, "standard_gx": 4.117622971534729, "rowwise_fwd": 2.370934933423996, "rowwise_bwd": 1.9481778144836426, "global_fwd": 2.3383721709251404, "global_bwd": 1.9443817436695099, "x_quantize_rowwise": 0.1547597348690033, "g_quantize_rowwise": 0.6000511348247528, "w_quantize_rowwise": 0.04361942410469055, "w_quantize_colwise_transpose": 0.3403201699256897, "w_quantize_global": 0.13600289821624756, "w_quantize_global_transpose": 0.1474134624004364, "time_standard": 13.096142560243607, "time_rowwise": 9.628776460886002, "time_global": 9.491894394159317}
{"repeat": 64, "batch_size": 32768, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 4.1619837284088135, "standard_gw": 4.181284457445145, "standard_gx": 4.635505378246307, "rowwise_fwd": 1.9684135913848877, "rowwise_bwd": 2.3750364780426025, "global_fwd": 1.9445866346359253, "global_bwd": 2.3551955819129944, "x_quantize_rowwise": 0.6004162132740021, "g_quantize_rowwise": 0.15468522906303406, "w_quantize_rowwise": 0.04730746150016785, "w_quantize_colwise_transpose": 0.5999617278575897, "w_quantize_global": 0.1364201307296753, "w_quantize_global_transpose": 0.14847144484519958, "time_standard": 12.978773564100266, "time_rowwise": 9.927105158567429, "time_global": 9.521059691905975}
{"repeat": 64, "batch_size": 65536, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 9.52371209859848, "standard_gw": 8.354485034942627, "standard_gx": 8.69860127568245, "rowwise_fwd": 4.717472940683365, "rowwise_bwd": 3.8843750953674316, "global_fwd": 4.645414650440216, "global_bwd": 3.8761012256145477, "x_quantize_rowwise": 0.3024861216545105, "g_quantize_rowwise": 1.1897757649421692, "w_quantize_rowwise": 0.04366785287857056, "w_quantize_colwise_transpose": 0.33988431096076965, "w_quantize_global": 0.1359507441520691, "w_quantize_global_transpose": 0.14724582433700562, "time_standard": 26.576798409223557, "time_rowwise": 18.832147121429443, "time_global": 18.651459366083145}
{"repeat": 64, "batch_size": 65536, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 8.307881653308868, "standard_gw": 8.214320987462997, "standard_gx": 9.21182706952095, "rowwise_fwd": 3.8919784128665924, "rowwise_bwd": 4.72346693277359, "global_fwd": 3.8761794567108154, "global_bwd": 4.673641175031662, "x_quantize_rowwise": 1.1893920600414276, "g_quantize_rowwise": 0.3024972975254059, "w_quantize_rowwise": 0.04708021879196167, "w_quantize_colwise_transpose": 0.6039328873157501, "w_quantize_global": 0.13624504208564758, "w_quantize_global_transpose": 0.14867261052131653, "time_standard": 25.734029710292816, "time_rowwise": 18.972668796777725, "time_global": 18.540948629379272}
{"repeat": 64, "batch_size": 131072, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 19.30372044444084, "standard_gw": 16.480475664138794, "standard_gx": 17.61433482170105, "rowwise_fwd": 9.49602946639061, "rowwise_bwd": 7.768530398607254, "global_fwd": 9.3533955514431, "global_bwd": 7.749464362859726, "x_quantize_rowwise": 0.5977451801300049, "g_quantize_rowwise": 2.3684948682785034, "w_quantize_rowwise": 0.04375725984573364, "w_quantize_colwise_transpose": 0.34042075276374817, "w_quantize_global": 0.13628974556922913, "w_quantize_global_transpose": 0.14671683311462402, "time_standard": 53.398530930280685, "time_rowwise": 37.09545359015465, "time_global": 36.83258220553398}
{"repeat": 64, "batch_size": 131072, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 18.041003495454788, "standard_gw": 17.770148813724518, "standard_gx": 17.70009845495224, "rowwise_fwd": 7.756810635328293, "rowwise_bwd": 9.502101689577103, "global_fwd": 7.7384114265441895, "global_bwd": 9.36170294880867, "x_quantize_rowwise": 2.3686252534389496, "g_quantize_rowwise": 0.5980581045150757, "w_quantize_rowwise": 0.04723668098449707, "w_quantize_colwise_transpose": 0.6035342812538147, "w_quantize_global": 0.13603642582893372, "w_quantize_global_transpose": 0.1485198736190796, "time_standard": 53.511250764131546, "time_rowwise": 38.64651545882225, "time_global": 38.121502846479416}
{"repeat": 64, "batch_size": 8192, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 4.598241299390793, "standard_gw": 4.294309765100479, "standard_gx": 4.261095076799393, "rowwise_fwd": 2.0976848900318146, "rowwise_bwd": 1.9718967378139496, "global_fwd": 2.0763762295246124, "global_bwd": 1.9703581929206848, "x_quantize_rowwise": 0.08216872811317444, "g_quantize_rowwise": 0.4405900835990906, "w_quantize_rowwise": 0.1553371548652649, "w_quantize_colwise_transpose": 1.6110725700855255, "w_quantize_global": 0.481240451335907, "w_quantize_global_transpose": 0.5061514675617218, "time_standard": 13.153646141290665, "time_rowwise": 10.653059929609299, "time_global": 9.85119491815567}
{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 4.35885414481163, "standard_gw": 4.29583340883255, "standard_gx": 4.5370906591415405, "rowwise_fwd": 2.0015686750411987, "rowwise_bwd": 2.097565680742264, "global_fwd": 1.969795674085617, "global_bwd": 2.075403928756714, "x_quantize_rowwise": 0.43984130024909973, "g_quantize_rowwise": 0.08216127753257751, "w_quantize_rowwise": 0.22544339299201965, "w_quantize_colwise_transpose": 2.4342015385627747, "w_quantize_global": 0.48087164759635925, "w_quantize_global_transpose": 0.5099289119243622, "time_standard": 13.19177821278572, "time_rowwise": 11.576615273952484, "time_global": 9.85383614897728}
{"repeat": 64, "batch_size": 16384, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 9.09888744354248, "standard_gw": 8.230950683355331, "standard_gx": 8.465446531772614, "rowwise_fwd": 4.182614386081696, "rowwise_bwd": 3.747660666704178, "global_fwd": 4.138719290494919, "global_bwd": 3.74777615070343, "x_quantize_rowwise": 0.15515834093093872, "g_quantize_rowwise": 0.8699297904968262, "w_quantize_rowwise": 0.15544891357421875, "w_quantize_colwise_transpose": 1.6132444143295288, "w_quantize_global": 0.48100948333740234, "w_quantize_global_transpose": 0.5051903426647186, "time_standard": 25.795284658670425, "time_rowwise": 18.955007195472717, "time_global": 18.128734081983566}
{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 8.378107100725174, "standard_gw": 8.923027664422989, "standard_gx": 9.049762040376663, "rowwise_fwd": 3.765825182199478, "rowwise_bwd": 4.183519631624222, "global_fwd": 3.744799643754959, "global_bwd": 4.1590481996536255, "x_quantize_rowwise": 0.8693933486938477, "g_quantize_rowwise": 0.1553073525428772, "w_quantize_rowwise": 0.2258792519569397, "w_quantize_colwise_transpose": 2.4386271834373474, "w_quantize_global": 0.4811100661754608, "w_quantize_global_transpose": 0.5102269351482391, "time_standard": 26.350896805524826, "time_rowwise": 20.5615796148777, "time_global": 18.842913210392}
{"repeat": 64, "batch_size": 32768, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 18.266115337610245, "standard_gw": 17.671160399913788, "standard_gx": 17.10302010178566, "rowwise_fwd": 8.347474038600922, "rowwise_bwd": 7.514089345932007, "global_fwd": 8.263226598501205, "global_bwd": 7.487393915653229, "x_quantize_rowwise": 0.3021806478500366, "g_quantize_rowwise": 1.7319358885288239, "w_quantize_rowwise": 0.15519559383392334, "w_quantize_colwise_transpose": 1.6133114695549011, "w_quantize_global": 0.48247724771499634, "w_quantize_global_transpose": 0.506427139043808, "time_standard": 53.04029583930969, "time_rowwise": 37.3353473842144, "time_global": 36.44480183720589}
{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 17.73649826645851, "standard_gw": 16.359902918338776, "standard_gx": 18.0993489921093, "rowwise_fwd": 7.493957877159119, "rowwise_bwd": 8.352488279342651, "global_fwd": 7.486194372177124, "global_bwd": 8.28903540968895, "x_quantize_rowwise": 1.7313472926616669, "g_quantize_rowwise": 0.30205026268959045, "w_quantize_rowwise": 0.2255477011203766, "w_quantize_colwise_transpose": 2.4363920092582703, "w_quantize_global": 0.4815347492694855, "w_quantize_global_transpose": 0.5103759467601776, "time_standard": 52.195750176906586, "time_rowwise": 36.90168634057045, "time_global": 35.16044095158577}
{"repeat": 64, "batch_size": 65536, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 36.309611052274704, "standard_gw": 32.85098075866699, "standard_gx": 34.34552624821663, "rowwise_fwd": 16.74525812268257, "rowwise_bwd": 15.026237815618515, "global_fwd": 16.574162989854813, "global_bwd": 14.977734535932541, "x_quantize_rowwise": 0.5954466760158539, "g_quantize_rowwise": 3.4569576382637024, "w_quantize_rowwise": 0.15521422028541565, "w_quantize_colwise_transpose": 1.6133897006511688, "w_quantize_global": 0.4822872579097748, "w_quantize_global_transpose": 0.5065612494945526, "time_standard": 103.50611805915833, "time_rowwise": 70.44348493218422, "time_global": 69.44413110613823}
{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 35.40017828345299, "standard_gw": 33.037226647138596, "standard_gx": 36.30436211824417, "rowwise_fwd": 15.043705701828003, "rowwise_bwd": 16.756191849708557, "global_fwd": 15.011314302682877, "global_bwd": 16.580048948526382, "x_quantize_rowwise": 3.4548528492450714, "g_quantize_rowwise": 0.5951337516307831, "w_quantize_rowwise": 0.22584572434425354, "w_quantize_colwise_transpose": 2.4329908192157745, "w_quantize_global": 0.4813261330127716, "w_quantize_global_transpose": 0.5101598799228668, "time_standard": 104.74176704883575, "time_rowwise": 71.54594734311104, "time_global": 69.67006251215935}
{"repeat": 64, "batch_size": 131072, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 73.40333238244057, "standard_gw": 73.76311346888542, "standard_gx": 70.41774317622185, "rowwise_fwd": 33.37597846984863, "rowwise_bwd": 30.345775187015533, "global_fwd": 33.00366923213005, "global_bwd": 30.218638479709625, "x_quantize_rowwise": 1.1825822293758392, "g_quantize_rowwise": 6.902601569890976, "w_quantize_rowwise": 0.15529245138168335, "w_quantize_colwise_transpose": 1.6109198331832886, "w_quantize_global": 0.48149004578590393, "w_quantize_global_transpose": 0.5066059529781342, "time_standard": 217.58418902754784, "time_rowwise": 147.33626320958138, "time_global": 146.05870097875595}
{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 71.5160183608532, "standard_gw": 73.76786693930626, "standard_gx": 72.98104092478752, "rowwise_fwd": 30.291248112916946, "rowwise_bwd": 33.36654230952263, "global_fwd": 30.181586742401123, "global_bwd": 33.082425594329834, "x_quantize_rowwise": 6.902430206537247, "g_quantize_rowwise": 1.1815279722213745, "w_quantize_rowwise": 0.2262219786643982, "w_quantize_colwise_transpose": 2.4421699345111847, "w_quantize_global": 0.4816502332687378, "w_quantize_global_transpose": 0.5105249583721161, "time_standard": 218.26492622494698, "time_rowwise": 148.17800745368004, "time_global": 146.1080126464367}
138
benchmarking/switchback/make_plot_with_jsonl.py
Normal file

@@ -0,0 +1,138 @@
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os

import matplotlib.gridspec as gridspec

cmap=plt.get_cmap('cool')

if __name__ == '__main__':

    fig = plt.figure(tight_layout=True, figsize=(12,3.5))
    gs = gridspec.GridSpec(1, 2)

    dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]
    batch_size_for_plot1 = 32768
    batch_sizes_for_plot2 = [2**14, 2**15, 2**16, 2**17]
    dims_to_xtick = [1024, 2048, 4096]
    logscale_plot1 = True

    ax = fig.add_subplot(gs[0, 0])

    # TODO: change this to what you want.
    rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True)
    df = rdf[rdf.batch_size == batch_size_for_plot1]

    # first plot the time occupied by different operations
    for k, marker, ls, color, name in [
        ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'),
        ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'),

        ('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'),
        ('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'),
        ('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'),

        ('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'),
        ('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'),

        ('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'),
        ('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'),
        ('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'),
        ('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'),
    ]:
        xs = []
        ys = []
        for embed_dim in dims_to_consider:
            # average over dim -> 4*dim and 4*dim -> dim
            df_ = df[df.dim_in == embed_dim]
            df_ = df_[df_.dim_out == embed_dim * 4]
            xs.append(embed_dim)
            y_ = 0
            for k_ in k.split('+'):
                y_ += df_[k_].values[0]
            df_ = df[df.dim_in == embed_dim * 4]
            df_ = df_[df_.dim_out == embed_dim]
            for k_ in k.split('+'):
                y_ += df_[k_].values[0]
            ys.append(y_ * 0.5)

        ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.)

    ax.set_xlabel('dim', fontsize=13)
    ax.set_ylabel('time (ms)', fontsize=13)

    ax.grid()

    ax.set_xscale('log')
    if logscale_plot1:
        ax.set_yscale('log')

    ax.tick_params(axis='x', labelsize=11)
    ax.tick_params(axis='y', labelsize=11)

    ax.set_xticks(dims_to_xtick)
    ax.set_xticklabels(dims_to_xtick)
    ax.set_xticks([], minor=True)

    leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10)
    leg.get_texts()[0].set_fontweight('bold')
    leg.get_texts()[1].set_fontweight('bold')
    plt.subplots_adjust(left=0.1)
    ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20)

    ax = fig.add_subplot(gs[0, 1])

    # now plot the % speedup for different batch sizes
    for j, batch_size in enumerate(batch_sizes_for_plot2):
        all_xs, all_ys = [], []
        for k, marker, ls, color, name in [
            ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'),
            ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
        ]:

            xs, ys = [], []
            df = rdf[rdf.batch_size == batch_size]
            for embed_dim in dims_to_consider:
                df_ = df[df.dim_in == embed_dim]
                df_ = df_[df_.dim_out == embed_dim * 4]
                xs.append(embed_dim)
                y_ = 0
                for k_ in k.split('+'):
                    y_ += df_[k_].values[0]
                df_ = df[df.dim_in == embed_dim * 4]
                df_ = df_[df_.dim_out == embed_dim]
                for k_ in k.split('+'):
                    y_ += df_[k_].values[0]
                ys.append(y_ * 0.5)
            all_xs.append(xs)
            all_ys.append(ys)

        color = cmap(j * 0.25)
        real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))]
        markers = ['^', 'v', 'P', 'o']
        ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5)

    ax.legend()
    ax.set_xlabel('dim', fontsize=13)
    ax.set_xscale('log')
    ax.grid()
    ax.set_ylabel(r'% speedup', fontsize=13)

    ax.tick_params(axis='x', labelsize=11)
    ax.tick_params(axis='y', labelsize=11)

    ax.set_xticks(dims_to_xtick)
    ax.set_xticklabels(dims_to_xtick)
    ax.set_xticks([], minor=True)

    ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20)

    plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight')
BIN
benchmarking/switchback/plot_with_info.pdf
Normal file
Binary file not shown.
102
benchmarking/switchback/speed_benchmark.py
Normal file

@@ -0,0 +1,102 @@
import json

import time
import torch
import torch.nn as nn

from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze

# KNOWN ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.

def get_time(k, fn, info_dict):
    # `repeat` is read from the global scope set in __main__ below

    # warmup
    for _ in range(repeat // 2):
        fn()

    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeat):
        fn()

    torch.cuda.synchronize()
    end = time.time()
    ms = (end - start) / repeat * 1000
    print(f"time {k}: {ms:.3f} ms")
    info_dict[k] = ms

if __name__ == '__main__':
    torch.manual_seed(0)
    wm = 4
    for dim in [1024, 1280, 1408, 1664, 2048, 4096]:
        # note "batch_size" is actually "batch_size * embed_dim", which is why it's large
        for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:

            # switch switches dim_in and dim_out
            for switch in [False, True]:

                # hparams
                repeat = 64
                batch_size = batch_size
                dim_out = dim * wm
                dim_in = dim
                if switch:
                    dim_out = dim
                    dim_in = wm * dim

                dim_in = round(dim_in)
                dim_out = round(dim_out)

                # simulate forward pass
                x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda()
                g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda()
                w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda()

                x_int8 = x.clone().to(torch.int8)
                g_int8 = g.clone().to(torch.int8)
                w_int8 = w.clone().to(torch.int8)
                wt_int8 = w.t().contiguous().clone().to(torch.int8)
                state_x_rowwise = x.max(dim=1)[0]
                state_g_rowwise = g.max(dim=1)[0]
                state_w_columnwise = w.max(dim=0)[0]
                state_w_rowwise = w.max(dim=1)[0]
                state_w_global = w.max()

                info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}

                get_time('standard_fwd', lambda : x.matmul(w.t()), info)
                get_time('standard_gw', lambda : g.t().matmul(x), info)
                get_time('standard_gx', lambda : g.matmul(w), info)
                get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
                get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
                get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
                get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
                get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
                get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
                get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
                get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info)
                get_time('w_quantize_global', lambda : quantize_global(w), info)
                get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info)

                time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
                time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
                time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']

                print('TOTAL STANDARD', time_standard)
                print('TOTAL ROWWISE', time_rowwise)
                print('TOTAL GLOBAL', time_global)

                print('speedup', -100*(time_global - time_standard)/time_standard)

                info['time_standard'] = time_standard
                info['time_rowwise'] = time_rowwise
                info['time_global'] = time_global

                info_json = json.dumps(info)

                # TODO: change this to what you want.
                with open("speed_benchmark/info.jsonl", "a") as file:
                    file.write(info_json + "\n")
|
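For convenience, a small sketch (not part of this commit) of loading the `info.jsonl` records written above into the dataframe that the plotting script expects; it assumes pandas is installed and that the benchmark has been run at least once:

```python
import json
import pandas as pd

def load_benchmark_results(path="speed_benchmark/info.jsonl"):
    # each line is one JSON record appended by speed_benchmark.py
    with open(path) as f:
        records = [json.loads(line) for line in f if line.strip()]
    return pd.DataFrame(records)

rdf = load_benchmark_results()
print(rdf[["dim_in", "dim_out", "batch_size", "time_standard", "time_global"]].head())
```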
@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import cuda_setup, utils
from . import cuda_setup, utils, research
from .autograd._functions import (
    MatmulLtState,
    bmm_cublas,
|
|
|
@ -1,11 +1,82 @@
|
|||
import os
|
||||
import sys
|
||||
import shlex
|
||||
import subprocess
|
||||
|
||||
from warnings import warn
|
||||
from typing import Tuple
|
||||
from os.path import isdir
|
||||
|
||||
import torch
|
||||
|
||||
HEADER_WIDTH = 60
|
||||
|
||||
def execute_and_return(command_string: str) -> Tuple[str, str]:
|
||||
def _decode(subprocess_err_out_tuple):
|
||||
return tuple(
|
||||
to_decode.decode("UTF-8").strip()
|
||||
for to_decode in subprocess_err_out_tuple
|
||||
)
|
||||
|
||||
def execute_and_return_decoded_std_streams(command_string):
|
||||
return _decode(
|
||||
subprocess.Popen(
|
||||
shlex.split(command_string),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
).communicate()
|
||||
)
|
||||
|
||||
std_out, std_err = execute_and_return_decoded_std_streams(command_string)
|
||||
return std_out, std_err
|
||||
|
||||
def find_file_recursive(folder, filename):
|
||||
cmd = f'find {folder} -name {filename}'
|
||||
out, err = execute_and_return(cmd)
|
||||
if len(err) > 0:
|
||||
raise RuntimeError('Something went wrong when trying to find the file. Maybe you do not have a Linux system?')
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def generate_bug_report_information():
|
||||
print_header("")
|
||||
print_header("BUG REPORT INFORMATION")
|
||||
print_header("")
|
||||
print('')
|
||||
|
||||
if 'CONDA_PREFIX' in os.environ:
|
||||
paths = find_file_recursive(os.environ['CONDA_PREFIX'], '*cuda*so')
|
||||
print_header("ANACONDA CUDA PATHS")
|
||||
print(paths)
|
||||
print('')
|
||||
if isdir('/usr/local/'):
|
||||
paths = find_file_recursive('/usr/local', '*cuda*so')
|
||||
print_header("/usr/local CUDA PATHS")
|
||||
print(paths)
|
||||
print('')
|
||||
|
||||
if isdir(os.getcwd()):
|
||||
paths = find_file_recursive(os.getcwd(), '*cuda*so')
|
||||
print_header("WORKING DIRECTORY CUDA PATHS")
|
||||
print(paths)
|
||||
print('')
|
||||
|
||||
print_header("LD_LIBRARY CUDA PATHS")
|
||||
lib_path = os.environ['LD_LIBRARY_PATH'].strip()
|
||||
for path in set(lib_path.split(':')):
|
||||
try:
|
||||
if isdir(path):
|
||||
print_header(f"{path} CUDA PATHS")
|
||||
paths = find_file_recursive(path, '*cuda*so')
|
||||
print(paths)
|
||||
except:
|
||||
print(f'Could not read LD_LIBRARY_PATH: {path}')
|
||||
print('')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
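The report produced by `generate_bug_report_information()` is what `python -m bitsandbytes` prints; a small sketch (not part of this diff) for capturing that output programmatically:

```python
import subprocess
import sys

# equivalent to running `python -m bitsandbytes` in a shell
result = subprocess.run([sys.executable, "-m", "bitsandbytes"],
                        capture_output=True, text=True)
print(result.stdout)
```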
def print_header(
|
||||
txt: str, width: int = HEADER_WIDTH, filler: str = "+"
|
||||
|
@ -21,28 +92,16 @@ def print_debug_info() -> None:
|
|||
)
|
||||
|
||||
|
||||
print_header("")
|
||||
print_header("DEBUG INFORMATION")
|
||||
print_header("")
|
||||
print()
|
||||
generate_bug_report_information()
|
||||
|
||||
|
||||
from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL
|
||||
from .cuda_setup.env_vars import to_be_ignored
|
||||
from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle
|
||||
|
||||
print_header("POTENTIALLY LIBRARY-PATH-LIKE ENV VARS")
|
||||
for k, v in os.environ.items():
|
||||
if "/" in v and not to_be_ignored(k, v):
|
||||
print(f"'{k}': '{v}'")
|
||||
print_header("")
|
||||
|
||||
print(
|
||||
"\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n"
|
||||
)
|
||||
|
||||
print_header("OTHER")
|
||||
print(f"{COMPILED_WITH_CUDA = }")
|
||||
print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}")
|
||||
cuda = get_cuda_lib_handle()
|
||||
print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities(cuda)}")
|
||||
print_header("")
|
||||
|
@ -55,6 +114,7 @@ Running a quick check that:
|
|||
+ CUDA function is callable
|
||||
"""
|
||||
)
|
||||
print("\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n")
|
||||
|
||||
try:
|
||||
from bitsandbytes.optim import Adam
|
||||
|
@ -91,3 +151,4 @@ except Exception as e:
|
|||
print(e)
|
||||
print_debug_info()
|
||||
sys.exit(1)
|
||||
|
||||
|
|
|
@ -221,9 +221,20 @@ bmm_cublas = MatMul8bit.apply
|
|||
matmul_cublas = MatMul8bit.apply
|
||||
|
||||
|
||||
def supports_igemmlt(device: torch.device) -> bool:
|
||||
"""check if this device supports the optimized int8 kernel"""
|
||||
if torch.cuda.get_device_capability(device=device) < (7, 5):
|
||||
return False
|
||||
device_name = torch.cuda.get_device_name(device=device)
|
||||
nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series
|
||||
if any(model_name in device_name for model_name in nvidia16_models):
|
||||
return False # these devices are technically cuda 7.5-capable, but they lack tensor cores
|
||||
return True
|
||||
|
||||
|
||||
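A hedged usage sketch for the new `supports_igemmlt` check (the import path assumes the function stays in `bitsandbytes.autograd._functions`, where this hunk lives):

```python
import torch
from bitsandbytes.autograd._functions import supports_igemmlt

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    # False on GPUs below compute capability 7.5 and on GTX 16xx cards without tensor cores
    print("optimized int8 kernel available:", supports_igemmlt(device))
```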
@dataclass
|
||||
class MatmulLtState:
|
||||
tile_indices: Optional[torch.Tensor] = None
|
||||
_tile_indices: Optional[torch.Tensor] = None
|
||||
force_no_igemmlt: bool = False
|
||||
CB = None
|
||||
CxB = None
|
||||
|
@ -263,6 +274,15 @@ class MatmulLtState:
|
|||
), f"please find this assert and manually enter tile size for {self.formatB}"
|
||||
return (8, 32) if self.formatB == "col_turing" else (32, 32)
|
||||
|
||||
@property
|
||||
def tile_indices(self):
|
||||
if self._tile_indices is None:
|
||||
device = self.CxB.device
|
||||
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
|
||||
with torch.no_grad():
|
||||
self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
|
||||
return self._tile_indices
|
||||
|
||||
|
||||
class MatMul8bitLt(torch.autograd.Function):
|
||||
# forward is the same, but we added the fallback for pre-turing GPUs
|
||||
|
@ -270,7 +290,7 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
|
||||
using_igemmlt = torch.cuda.get_device_capability(device=A.device) >= (7, 5) and not state.force_no_igemmlt
|
||||
using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt
|
||||
# default of pytorch behavior if inputs are empty
|
||||
ctx.is_empty = False
|
||||
if prod(A.shape) == 0:
|
||||
|
@ -456,13 +476,6 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
|
||||
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
|
||||
elif state.CxB is not None:
|
||||
|
||||
if state.tile_indices is None:
|
||||
order, tile_size = state.formatB, state.get_tile_size()
|
||||
transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
|
||||
with torch.no_grad():
|
||||
state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device)
|
||||
|
||||
CB = (
|
||||
undo_layout(state.CxB, state.tile_indices)
|
||||
.to(ctx.dtype_A)
|
||||
|
|
|
@ -9,10 +9,8 @@ from bitsandbytes.cuda_setup.main import CUDASetup
|
|||
|
||||
|
||||
setup = CUDASetup.get_instance()
|
||||
if not setup.initialized:
|
||||
if setup.initialized != True:
|
||||
setup.run_cuda_setup()
|
||||
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
|
||||
setup.print_log_stack()
|
||||
|
||||
lib = setup.lib
|
||||
try:
|
||||
|
@ -20,15 +18,25 @@ try:
|
|||
CUDASetup.get_instance().generate_instructions()
|
||||
CUDASetup.get_instance().print_log_stack()
|
||||
raise RuntimeError('''
|
||||
CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment!
|
||||
If you cannot find any issues and suspect a bug, please open an issue with detals about your environment:
|
||||
https://github.com/TimDettmers/bitsandbytes/issues''')
|
||||
lib.cadam_8bit_blockwise_fp32
|
||||
CUDA Setup failed despite GPU being available. Please run the following command to get more information:
|
||||
|
||||
python -m bitsandbytes
|
||||
|
||||
Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
|
||||
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
|
||||
and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''')
|
||||
lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False
|
||||
lib.get_context.restype = ct.c_void_p
|
||||
lib.get_cusparse.restype = ct.c_void_p
|
||||
lib.cget_managed_ptr.restype = ct.c_void_p
|
||||
COMPILED_WITH_CUDA = True
|
||||
except AttributeError:
|
||||
except AttributeError as ex:
|
||||
warn("The installed version of bitsandbytes was compiled without GPU support. "
|
||||
"8-bit optimizers and GPU quantization are unavailable.")
|
||||
"8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.")
|
||||
COMPILED_WITH_CUDA = False
|
||||
print(str(ex))
|
||||
|
||||
|
||||
# print the setup details after checking for errors so we do not print twice
|
||||
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
|
||||
setup.print_log_stack()
|
||||
|
|
|
@ -11,6 +11,7 @@ def to_be_ignored(env_var: str, value: str) -> bool:
|
|||
"HOME", # Linux shell default
|
||||
"TMUX", # Terminal Multiplexer
|
||||
"XDG_DATA_DIRS", # XDG: Desktop environment stuff
|
||||
"XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff
|
||||
"XDG_RUNTIME_DIR",
|
||||
"MAIL", # something related to emails
|
||||
"SHELL", # binary for currently invoked shell
|
||||
|
|
|
@ -21,12 +21,21 @@ import os
|
|||
import errno
|
||||
import torch
|
||||
from warnings import warn
|
||||
from itertools import product
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Set, Union
|
||||
from .env_vars import get_potentially_lib_path_containing_env_vars
|
||||
|
||||
CUDA_RUNTIME_LIB: str = "libcudart.so"
|
||||
# these are the most common lib names
# libcudart.so is missing by default for a conda install with PyTorch 2.0; instead
# we have libcudart.so.11.0, which previously caused a lot of errors
# not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt
CUDA_RUNTIME_LIBS: list = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0']

# this is an ordered list of backup paths to search for CUDA in, if it cannot be found in the main environment paths
backup_paths = []
backup_paths.append('$CONDA_PREFIX/lib/libcudart.so.11.0')
|
||||
|
||||
class CUDASetup:
|
||||
_instance = None
|
||||
|
@ -102,6 +111,8 @@ class CUDASetup:
|
|||
package_dir = Path(__file__).parent.parent
|
||||
binary_path = package_dir / binary_name
|
||||
|
||||
print('bin', binary_path)
|
||||
|
||||
try:
|
||||
if not binary_path.exists():
|
||||
self.add_log_entry(f"CUDA SETUP: Required library version not found: {binary_name}. Maybe you need to compile it from source?")
|
||||
|
@ -121,7 +132,6 @@ class CUDASetup:
|
|||
self.add_log_entry('='*80)
|
||||
self.add_log_entry('')
|
||||
self.generate_instructions()
|
||||
self.print_log_stack()
|
||||
raise Exception('CUDA SETUP: Setup Failed!')
|
||||
self.lib = ct.cdll.LoadLibrary(binary_path)
|
||||
else:
|
||||
|
@ -129,7 +139,6 @@ class CUDASetup:
|
|||
self.lib = ct.cdll.LoadLibrary(binary_path)
|
||||
except Exception as ex:
|
||||
self.add_log_entry(str(ex))
|
||||
self.print_log_stack()
|
||||
|
||||
def add_log_entry(self, msg, is_warning=False):
|
||||
self.cuda_setup_log.append((msg, is_warning))
|
||||
|
@ -154,7 +163,7 @@ def is_cublasLt_compatible(cc):
|
|||
if cc is not None:
|
||||
cc_major, cc_minor = cc.split('.')
|
||||
if int(cc_major) < 7 or (int(cc_major) == 7 and int(cc_minor) < 5):
|
||||
cuda_setup.add_log_entry("WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!", is_warning=True)
|
||||
CUDASetup.get_instance().add_log_entry("WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!", is_warning=True)
|
||||
else:
|
||||
has_cublaslt = True
|
||||
return has_cublaslt
|
||||
|
@ -182,11 +191,12 @@ def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]:
|
|||
|
||||
|
||||
def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]:
|
||||
return {
|
||||
path / CUDA_RUNTIME_LIB
|
||||
for path in candidate_paths
|
||||
if (path / CUDA_RUNTIME_LIB).is_file()
|
||||
}
|
||||
paths = set()
|
||||
for libname in CUDA_RUNTIME_LIBS:
|
||||
for path in candidate_paths:
|
||||
if (path / libname).is_file():
|
||||
paths.add(path / libname)
|
||||
return paths
|
||||
|
||||
|
||||
def resolve_paths_list(paths_list_candidate: str) -> Set[Path]:
|
||||
|
@ -206,12 +216,12 @@ def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]:
|
|||
def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None:
|
||||
if len(results_paths) > 1:
|
||||
warning_msg = (
|
||||
f"Found duplicate {CUDA_RUNTIME_LIB} files: {results_paths}.. "
|
||||
f"Found duplicate {CUDA_RUNTIME_LIBS} files: {results_paths}.. "
|
||||
"We'll flip a coin and try one of these, in order to fail forward.\n"
|
||||
"Either way, this might cause trouble in the future:\n"
|
||||
"If you get `CUDA error: invalid device function` errors, the above "
|
||||
"might be the cause and the solution is to make sure only one "
|
||||
f"{CUDA_RUNTIME_LIB} in the paths that we search based on your env.")
|
||||
f"{CUDA_RUNTIME_LIBS} in the paths that we search based on your env.")
|
||||
CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True)
|
||||
|
||||
|
||||
|
@ -239,7 +249,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
|
|||
return next(iter(conda_cuda_libs))
|
||||
|
||||
CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain '
|
||||
f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True)
|
||||
f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True)
|
||||
|
||||
if "LD_LIBRARY_PATH" in candidate_env_vars:
|
||||
lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"])
|
||||
|
@ -249,7 +259,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
|
|||
warn_in_case_of_duplicates(lib_ld_cuda_libs)
|
||||
|
||||
CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain '
|
||||
f'{CUDA_RUNTIME_LIB} as expected! Searching further paths...', is_warning=True)
|
||||
f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True)
|
||||
|
||||
remaining_candidate_env_vars = {
|
||||
env_var: value for env_var, value in candidate_env_vars.items()
|
||||
|
@ -261,7 +271,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]:
|
|||
cuda_runtime_libs.update(find_cuda_lib_in(value))
|
||||
|
||||
if len(cuda_runtime_libs) == 0:
|
||||
CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching /usr/local/cuda/lib64...')
|
||||
CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching in backup paths...')
|
||||
cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64'))
|
||||
|
||||
warn_in_case_of_duplicates(cuda_runtime_libs)
|
||||
|
@ -367,9 +377,10 @@ def evaluate_cuda_setup():
|
|||
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
|
||||
print('')
|
||||
print('='*35 + 'BUG REPORT' + '='*35)
|
||||
print('Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')
|
||||
print(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'),
|
||||
('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues'))
|
||||
print('='*80)
|
||||
if not torch.cuda.is_available(): return 'libsbitsandbytes_cpu.so', None, None, None, None
|
||||
if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None, None
|
||||
|
||||
cuda_setup = CUDASetup.get_instance()
|
||||
cudart_path = determine_cuda_runtime_lib_path()
|
||||
|
|
|
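Per the check above, the welcome/bug-report banner is printed unless `BITSANDBYTES_NOWELCOME` is set to something other than `'0'`; a minimal sketch for silencing it:

```python
import os

os.environ["BITSANDBYTES_NOWELCOME"] = "1"  # must be set before the first import

import bitsandbytes as bnb  # no banner is printed now
```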
@ -28,59 +28,71 @@ name2qmap = {}
|
|||
if COMPILED_WITH_CUDA:
|
||||
"""C FUNCTIONS FOR OPTIMIZERS"""
|
||||
str2optimizer32bit = {}
|
||||
str2optimizer32bit["adam"] = (lib.cadam32bit_gfp32, lib.cadam32bit_gfp16, lib.cadam32bit_gbf16)
|
||||
str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16, lib.cadam32bit_grad_bf16)
|
||||
str2optimizer32bit["momentum"] = (
|
||||
lib.cmomentum32bit_g32,
|
||||
lib.cmomentum32bit_g16,
|
||||
lib.cmomentum32bit_grad_32,
|
||||
lib.cmomentum32bit_grad_16,
|
||||
)
|
||||
str2optimizer32bit["rmsprop"] = (
|
||||
lib.crmsprop32bit_g32,
|
||||
lib.crmsprop32bit_g16,
|
||||
lib.crmsprop32bit_grad_32,
|
||||
lib.crmsprop32bit_grad_16,
|
||||
)
|
||||
str2optimizer32bit["lion"] = (
|
||||
lib.clion32bit_grad_32,
|
||||
lib.clion32bit_grad_16,
|
||||
)
|
||||
str2optimizer32bit["adagrad"] = (
|
||||
lib.cadagrad32bit_g32,
|
||||
lib.cadagrad32bit_g16,
|
||||
lib.cadagrad32bit_grad_32,
|
||||
lib.cadagrad32bit_grad_16,
|
||||
)
|
||||
|
||||
str2optimizer8bit = {}
|
||||
str2optimizer8bit["adam"] = (
|
||||
lib.cadam_static_8bit_g32,
|
||||
lib.cadam_static_8bit_g16,
|
||||
lib.cadam_static_8bit_grad_32,
|
||||
lib.cadam_static_8bit_grad_16,
|
||||
)
|
||||
str2optimizer8bit["momentum"] = (
|
||||
lib.cmomentum_static_8bit_g32,
|
||||
lib.cmomentum_static_8bit_g16,
|
||||
lib.cmomentum_static_8bit_grad_32,
|
||||
lib.cmomentum_static_8bit_grad_16,
|
||||
)
|
||||
str2optimizer8bit["rmsprop"] = (
|
||||
lib.crmsprop_static_8bit_g32,
|
||||
lib.crmsprop_static_8bit_g16,
|
||||
lib.crmsprop_static_8bit_grad_32,
|
||||
lib.crmsprop_static_8bit_grad_16,
|
||||
)
|
||||
str2optimizer8bit["lion"] = (
|
||||
lib.clion_static_8bit_grad_32,
|
||||
lib.clion_static_8bit_grad_16,
|
||||
)
|
||||
str2optimizer8bit["lamb"] = (
|
||||
lib.cadam_static_8bit_g32,
|
||||
lib.cadam_static_8bit_g16,
|
||||
lib.cadam_static_8bit_grad_32,
|
||||
lib.cadam_static_8bit_grad_16,
|
||||
)
|
||||
str2optimizer8bit["lars"] = (
|
||||
lib.cmomentum_static_8bit_g32,
|
||||
lib.cmomentum_static_8bit_g16,
|
||||
lib.cmomentum_static_8bit_grad_32,
|
||||
lib.cmomentum_static_8bit_grad_16,
|
||||
)
|
||||
|
||||
str2optimizer8bit_blockwise = {}
|
||||
str2optimizer8bit_blockwise["adam"] = (
|
||||
lib.cadam_8bit_blockwise_fp32,
|
||||
lib.cadam_8bit_blockwise_fp16,
|
||||
lib.cadam_8bit_blockwise_bf16,
|
||||
lib.cadam_8bit_blockwise_grad_fp32,
|
||||
lib.cadam_8bit_blockwise_grad_fp16,
|
||||
lib.cadam_8bit_blockwise_grad_bf16,
|
||||
)
|
||||
str2optimizer8bit_blockwise["momentum"] = (
|
||||
lib.cmomentum_8bit_blockwise_fp32,
|
||||
lib.cmomentum_8bit_blockwise_fp16,
|
||||
lib.cmomentum_8bit_blockwise_grad_fp32,
|
||||
lib.cmomentum_8bit_blockwise_grad_fp16,
|
||||
)
|
||||
str2optimizer8bit_blockwise["rmsprop"] = (
|
||||
lib.crmsprop_8bit_blockwise_fp32,
|
||||
lib.crmsprop_8bit_blockwise_fp16,
|
||||
lib.crmsprop_8bit_blockwise_grad_fp32,
|
||||
lib.crmsprop_8bit_blockwise_grad_fp16,
|
||||
)
|
||||
str2optimizer8bit_blockwise["lion"] = (
|
||||
lib.clion_8bit_blockwise_grad_fp32,
|
||||
lib.clion_8bit_blockwise_grad_fp16,
|
||||
)
|
||||
str2optimizer8bit_blockwise["adagrad"] = (
|
||||
lib.cadagrad_8bit_blockwise_fp32,
|
||||
lib.cadagrad_8bit_blockwise_fp16,
|
||||
lib.cadagrad_8bit_blockwise_grad_fp32,
|
||||
lib.cadagrad_8bit_blockwise_grad_fp16,
|
||||
)
|
||||
|
||||
class GlobalPageManager:
|
||||
|
@ -327,7 +339,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
|
|||
values = []
|
||||
lst = list(itertools.product([0, 1], repeat=precision_bits))
|
||||
#for ev in evalues:
|
||||
bias = 2**(exponent_bits-1)-1
|
||||
bias = 2**(exponent_bits-1)
|
||||
for evalue in range(2**(exponent_bits)):
|
||||
for bit_pattern in lst:
|
||||
value = (1 if evalue != 0 else 0)
|
||||
|
@ -335,10 +347,10 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
|
|||
value += pval*(2**-(i+1))
|
||||
if evalue == 0:
|
||||
# subnormals
|
||||
value = value*2**-(bias-1)
|
||||
value = value*2**-(bias)
|
||||
else:
|
||||
# normals
|
||||
value = value*2**-(evalue-bias-2)
|
||||
value = value*2**-(evalue-bias-1)
|
||||
values.append(value)
|
||||
if signed:
|
||||
values.append(-value)
|
||||
|
@ -624,7 +636,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
|
|||
return out
|
||||
|
||||
|
||||
def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
|
||||
def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
|
||||
"""
|
||||
Quantize tensor A in blocks of size 4096 values.
|
||||
|
||||
|
@ -640,8 +652,6 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
The quantization map.
|
||||
absmax : torch.Tensor
|
||||
The absmax values.
|
||||
rand : torch.Tensor
|
||||
The tensor for stochastic rounding.
|
||||
out : torch.Tensor
|
||||
The output tensor (8-bit).
|
||||
|
||||
|
@ -673,30 +683,17 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
cblocksize = ct.c_int32(blocksize)
|
||||
prev_device = pre_call(A.device)
|
||||
code = code.to(A.device)
|
||||
if rand is not None:
|
||||
is_on_gpu([code, A, out, absmax, rand])
|
||||
assert blocksize==4096
|
||||
assert rand.numel() >= 1024
|
||||
rand_offset = random.randint(0, 1023)
|
||||
if A.dtype == torch.float32:
|
||||
lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
|
||||
elif A.dtype == torch.float16:
|
||||
lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
is_on_gpu([code, A, out, absmax])
|
||||
if A.dtype == torch.float32:
|
||||
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
elif A.dtype == torch.float16:
|
||||
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
else:
|
||||
is_on_gpu([code, A, out, absmax])
|
||||
if A.dtype == torch.float32:
|
||||
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
elif A.dtype == torch.float16:
|
||||
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
|
||||
else:
|
||||
# cpu
|
||||
code = code.cpu()
|
||||
assert rand is None
|
||||
lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
|
||||
|
||||
if nested:
|
||||
|
@ -754,13 +751,16 @@ def dequantize_blockwise(
|
|||
|
||||
if out is None:
|
||||
out = torch.zeros_like(A, dtype=torch.float32)
|
||||
|
||||
if quant_state is None:
|
||||
quant_state = (absmax, code, blocksize)
|
||||
quant_state = (absmax, code, blocksize)
|
||||
assert absmax is not None and out is not None
|
||||
else:
|
||||
absmax, code, blocksize, nested, offset, state2 = quant_state
|
||||
if nested:
|
||||
absmax = dequantize_blockwise(absmax, state2)
|
||||
absmax += offset
|
||||
absmax, code, blocksize, nested, offset, state2 = quant_state
|
||||
if nested:
|
||||
absmax = dequantize_blockwise(absmax, state2)
|
||||
absmax += offset
|
||||
|
||||
|
||||
if A.device.type != 'cpu':
|
||||
device = pre_call(A.device)
|
||||
|
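A short round-trip sketch with the updated `quantize_blockwise`/`dequantize_blockwise` pair (the stochastic-rounding `rand` argument is removed above); a CUDA device is assumed and the shapes are illustrative:

```python
import torch
import bitsandbytes.functional as F

A = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
qA, quant_state = F.quantize_blockwise(A, blocksize=4096)  # uint8 payload + quantization state
A_hat = F.dequantize_blockwise(qA, quant_state)
print((A.float() - A_hat.float()).abs().max())             # small blockwise quantization error
```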
@ -994,9 +994,11 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
|
|||
torch.Tensor:
|
||||
Quantized 8-bit tensor.
|
||||
'''
|
||||
prev_device = pre_call(A.device)
|
||||
if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
|
||||
is_on_gpu([A, out])
|
||||
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
|
||||
post_call(prev_device)
|
||||
return out
|
||||
|
||||
|
||||
|
@ -1021,9 +1023,11 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
|
|||
torch.Tensor:
|
||||
32-bit output tensor.
|
||||
'''
|
||||
prev_device = pre_call(A.device)
|
||||
if out is None: out = torch.zeros_like(A, dtype=torch.float32)
|
||||
is_on_gpu([code, A, out])
|
||||
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
|
||||
post_call(prev_device)
|
||||
return out
|
||||
|
||||
|
||||
|
@ -1196,6 +1200,8 @@ def optimizer_update_8bit(
|
|||
if max_unorm > 0.0:
|
||||
param_norm = torch.norm(p.data.float())
|
||||
|
||||
prev_device = pre_call(g.device)
|
||||
is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2])
|
||||
if g.dtype == torch.float32 and state1.dtype == torch.uint8:
|
||||
str2optimizer8bit[optimizer_name][0](
|
||||
get_ptr(p),
|
||||
|
@ -1248,6 +1254,7 @@ def optimizer_update_8bit(
|
|||
raise ValueError(
|
||||
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
|
||||
)
|
||||
post_call(prev_device)
|
||||
|
||||
|
||||
def optimizer_update_8bit_blockwise(
|
||||
|
@ -1271,6 +1278,8 @@ def optimizer_update_8bit_blockwise(
|
|||
) -> None:
|
||||
|
||||
optim_func = None
|
||||
prev_device = pre_call(g.device)
|
||||
is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
|
||||
if g.dtype == torch.float32 and state1.dtype == torch.uint8:
|
||||
optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
|
||||
elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
|
||||
|
@ -1282,6 +1291,7 @@ def optimizer_update_8bit_blockwise(
|
|||
raise ValueError(
|
||||
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
|
||||
)
|
||||
post_call(prev_device)
|
||||
|
||||
is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
|
||||
|
||||
|
@ -1320,6 +1330,7 @@ def percentile_clipping(
|
|||
The current optimiation steps (number of past gradient norms).
|
||||
|
||||
"""
|
||||
prev_device = pre_call(grad.device)
|
||||
is_on_gpu([grad, gnorm_vec])
|
||||
if grad.dtype == torch.float32:
|
||||
lib.cpercentile_clipping_g32(
|
||||
|
@ -1337,6 +1348,7 @@ def percentile_clipping(
|
|||
)
|
||||
else:
|
||||
raise ValueError(f"Gradient type {grad.dtype} not supported!")
|
||||
post_call(prev_device)
|
||||
|
||||
current_gnorm = torch.sqrt(gnorm_vec[step % 100])
|
||||
vals, idx = torch.sort(gnorm_vec)
|
||||
|
@ -2210,6 +2222,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
|
|||
(cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
|
||||
)
|
||||
nnz = cooA.nnz
|
||||
prev_device = pre_call(B.device)
|
||||
assert cooA.rowidx.numel() == nnz
|
||||
assert cooA.colidx.numel() == nnz
|
||||
assert cooA.values.numel() == nnz
|
||||
|
@ -2284,6 +2297,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
|
|||
ccolsB,
|
||||
)
|
||||
# else: assertion error
|
||||
post_call(prev_device)
|
||||
|
||||
return out
|
||||
|
||||
|
|
|
@ -2,4 +2,5 @@
|
|||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from .modules import Int8Params, Linear8bitLt, StableEmbedding, Linear4bit, LinearNF4, LinearFP4, Params4bit
|
||||
from .modules import Int8Params, Linear8bitLt, StableEmbedding, Linear4bit, LinearNF4, LinearFP4, Params4bit, OutlierAwareLinear, SwitchBackLinearBnb
|
||||
from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear
|
||||
|
|
|
@ -9,7 +9,10 @@ import torch.nn.functional as F
|
|||
from torch import Tensor, device, dtype, nn
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import bitsandbytes.functional
|
||||
from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
|
||||
from bitsandbytes.optim import GlobalOptimManager
|
||||
from bitsandbytes.utils import OutlierTracer, find_outlier_dims
|
||||
|
||||
T = TypeVar("T", bound="torch.nn.Module")
|
||||
|
||||
|
@ -320,6 +323,53 @@ class Linear8bitLt(nn.Linear):
|
|||
|
||||
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
|
||||
# reorder weight layout back from ampere/turing to row
|
||||
reorder_layout = True
|
||||
weight_clone = self.weight.data.clone()
|
||||
else:
|
||||
reorder_layout = False
|
||||
|
||||
try:
|
||||
if reorder_layout:
|
||||
self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
|
||||
|
||||
super()._save_to_state_dict(destination, prefix, keep_vars)
|
||||
|
||||
# we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
|
||||
weight_name = "SCB"
|
||||
|
||||
# case 1: .cuda was called, SCB is in self.weight
|
||||
param_from_weight = getattr(self.weight, weight_name)
|
||||
# case 2: self.init_8bit_state was called, SCB is in self.state
|
||||
param_from_state = getattr(self.state, weight_name)
|
||||
|
||||
key_name = prefix + f"{weight_name}"
|
||||
if param_from_weight is not None:
|
||||
destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
|
||||
elif not self.state.has_fp16_weights and param_from_state is not None:
|
||||
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
|
||||
finally:
|
||||
if reorder_layout:
|
||||
self.weight.data = weight_clone
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs):
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||
error_msgs)
|
||||
for key in unexpected_keys:
|
||||
input_name = key[len(prefix):]
|
||||
if input_name == "SCB":
|
||||
if self.weight.SCB is None:
|
||||
# buffers are not yet initialized; they cannot be filled directly before the module has been quantized (e.g. by calling .cuda())
|
||||
raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
|
||||
"not supported. Please call module.cuda() before module.load_state_dict()")
|
||||
|
||||
input_param = state_dict[key]
|
||||
self.weight.SCB.copy_(input_param)
|
||||
unexpected_keys.remove(key)
|
||||
|
||||
def init_8bit_state(self):
|
||||
self.state.CB = self.weight.CB
|
||||
self.state.SCB = self.weight.SCB
|
||||
|
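A minimal round-trip sketch based on the new `_save_to_state_dict`/`_load_from_state_dict` hooks above (sizes are arbitrary; a CUDA device is required, and the target module must be quantized via `.cuda()` before loading, as the error message states):

```python
import torch
import bitsandbytes as bnb

# quantization happens on .cuda(); the int8 weights plus the SCB scales get serialized
linear = bnb.nn.Linear8bitLt(1024, 1024, has_fp16_weights=False).cuda()
torch.save(linear.state_dict(), "int8_linear.pt")

restored = bnb.nn.Linear8bitLt(1024, 1024, has_fp16_weights=False).cuda()
restored.load_state_dict(torch.load("int8_linear.pt"))
```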
@ -336,6 +386,7 @@ class Linear8bitLt(nn.Linear):
|
|||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
|
||||
|
||||
if not self.state.has_fp16_weights:
|
||||
if self.state.CB is not None and self.state.CxB is not None:
|
||||
# we converted 8-bit row major to turing/ampere format in the first inference pass
|
||||
|
@ -343,3 +394,71 @@ class Linear8bitLt(nn.Linear):
|
|||
del self.state.CB
|
||||
self.weight.data = self.state.CxB
|
||||
return out
|
||||
|
||||
|
||||
class OutlierAwareLinear(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.outlier_dim = None
|
||||
self.is_quantized = False
|
||||
|
||||
def forward_with_outliers(self, x, outlier_idx):
|
||||
raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function')
|
||||
|
||||
def quantize_weight(self, w, outlier_idx):
|
||||
raise NotImplementedError('Please override the `quantize_weight(self, w, outlier_idx)` function')
|
||||
|
||||
def forward(self, x):
|
||||
if self.outlier_dim is None:
|
||||
tracer = OutlierTracer.get_instance()
|
||||
if not tracer.is_initialized():
|
||||
print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer')
|
||||
outlier_idx = tracer.get_outliers(self.weight)
|
||||
#print(outlier_idx, tracer.get_hvalue(self.weight))
|
||||
self.outlier_dim = outlier_idx
|
||||
|
||||
if not self.is_quantized:
|
||||
w = self.quantize_weight(self.weight, self.outlier_dim)
|
||||
self.weight.data.copy_(w)
|
||||
self.is_quantized = True
|
||||
|
||||
class SwitchBackLinearBnb(nn.Linear):
|
||||
def __init__(
|
||||
self,
|
||||
input_features,
|
||||
output_features,
|
||||
bias=True,
|
||||
has_fp16_weights=True,
|
||||
memory_efficient_backward=False,
|
||||
threshold=0.0,
|
||||
index=None,
|
||||
):
|
||||
super().__init__(
|
||||
input_features, output_features, bias
|
||||
)
|
||||
self.state = bnb.MatmulLtState()
|
||||
self.index = index
|
||||
|
||||
self.state.threshold = threshold
|
||||
self.state.has_fp16_weights = has_fp16_weights
|
||||
self.state.memory_efficient_backward = memory_efficient_backward
|
||||
if threshold > 0.0 and not has_fp16_weights:
|
||||
self.state.use_pool = True
|
||||
|
||||
self.weight = Int8Params(
|
||||
self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
|
||||
)
|
||||
|
||||
def init_8bit_state(self):
|
||||
self.state.CB = self.weight.CB
|
||||
self.state.SCB = self.weight.SCB
|
||||
self.weight.CB = None
|
||||
self.weight.SCB = None
|
||||
|
||||
def forward(self, x):
|
||||
self.state.is_training = self.training
|
||||
|
||||
if self.weight.CB is not None:
|
||||
self.init_8bit_state()
|
||||
|
||||
out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
|
||||
|
|
258 bitsandbytes/nn/triton_based_modules.py (new file)
|
@ -0,0 +1,258 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
|
||||
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
|
||||
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
|
||||
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
|
||||
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
|
||||
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
|
||||
|
||||
|
||||
class _switchback_global(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, X_3D, W, bias):
|
||||
# reshape input to [N * L, D]
|
||||
X = X_3D.view(-1, X_3D.size(-1))
|
||||
|
||||
# rowwise quantize for X, global quantize for W
|
||||
X_int8, state_X = quantize_rowwise(X)
|
||||
W_int8, state_W = quantize_global(W)
|
||||
|
||||
# save for backward.
|
||||
ctx.save_for_backward = X, W
|
||||
|
||||
# matmult, fused dequant and add bias
|
||||
# call "mixed" because we are mixing rowwise quantized and global quantized
|
||||
return int8_matmul_mixed_dequanitze(
|
||||
X_int8, W_int8.t(), state_X, state_W, bias
|
||||
).view(*X_3D.size()[:-1], -1)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, G_3D):
|
||||
# reshape input to [N_out * L, D]
|
||||
G = G_3D.reshape(-1, G_3D.size(-1))
|
||||
|
||||
grad_X = grad_W = grad_bias = None
|
||||
|
||||
X, W = ctx.save_for_backward
|
||||
if ctx.needs_input_grad[0]:
|
||||
# rowwise quantize for G, global quantize for W
|
||||
# for W, we also fuse the transpose operation because only A @ B^T is supported
|
||||
# so we transpose once then call .t() in the matmul
|
||||
G_int8, state_G = quantize_rowwise(G)
|
||||
W_int8, state_W = quantize_global_transpose(W)
|
||||
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
|
||||
*G_3D.size()[:-1], -1
|
||||
)
|
||||
if ctx.needs_input_grad[1]:
|
||||
# backward pass uses standard weight grad
|
||||
grad_W = torch.matmul(G.t(), X.to(G.dtype))
|
||||
if ctx.needs_input_grad[2]:
|
||||
grad_bias = G.sum(dim=0)
|
||||
|
||||
return grad_X, grad_W, grad_bias
|
||||
|
||||
class _switchback_vectorrize(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, X_3D, W, bias):
|
||||
# reshape input to [N * L, D]
|
||||
X = X_3D.view(-1, X_3D.size(-1))
|
||||
|
||||
ctx.save_for_backward = X, W
|
||||
# rowwise quantize for X
|
||||
# columnwise quantize for W (first rowwise, transpose later)
|
||||
X_int8, state_X = quantize_rowwise(X)
|
||||
W_int8, state_W = quantize_rowwise(W)
|
||||
|
||||
# matmult, fused dequant and add bias
|
||||
# call kernel which expects rowwise quantized X and W
|
||||
return int8_matmul_rowwise_dequantize(
|
||||
X_int8, W_int8.t(), state_X, state_W, bias
|
||||
).view(*X_3D.size()[:-1], -1)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, G_3D):
|
||||
X, W = ctx.save_for_backward
|
||||
|
||||
G = G_3D.reshape(-1, G_3D.size(-1))
|
||||
|
||||
grad_X = grad_W = grad_bias = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
# rowwise quantize for G, columnwise quantize for W and fused transpose
|
||||
# we call .t() for weight later because only A @ B^T is supported
|
||||
G_int8, state_G = quantize_rowwise(G)
|
||||
W_int8, state_W = quantize_columnwise_and_transpose(W)
|
||||
grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
|
||||
*G_3D.size()[:-1], -1
|
||||
)
|
||||
if ctx.needs_input_grad[1]:
|
||||
# backward pass uses standard weight grad
|
||||
grad_W = torch.matmul(G.t(), X.to(G.dtype))
|
||||
if ctx.needs_input_grad[2]:
|
||||
grad_bias = G.sum(dim=0)
|
||||
|
||||
return grad_X, grad_W, grad_bias
|
||||
|
||||
class _switchback_global_mem_efficient(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, X_3D, W, bias):
|
||||
# reshape input to [N * L, D]
|
||||
X = X_3D.view(-1, X_3D.size(-1))
|
||||
X_3D_sz = X_3D.size()
|
||||
|
||||
# rowwise quantize for X, global quantize for W
|
||||
X_int8, state_X = quantize_rowwise(X)
|
||||
del X
|
||||
W_int8, state_W = quantize_global(W)
|
||||
|
||||
# save for backward.
|
||||
ctx.save_for_backward = X_int8, state_X, W_int8, state_W
|
||||
|
||||
# matmult, fused dequant and add bias
|
||||
# call "mixed" because we are mixing rowwise quantized and global quantized
|
||||
return int8_matmul_mixed_dequanitze(
|
||||
X_int8, W_int8.t(), state_X, state_W, bias
|
||||
).view(*X_3D_sz[:-1], -1)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, G_3D):
|
||||
# reshape input to [N_out * L, D]
|
||||
G = G_3D.reshape(-1, G_3D.size(-1))
|
||||
G_3D_sz = G_3D.size()
|
||||
|
||||
grad_X = grad_W = grad_bias = None
|
||||
|
||||
X_int8, state_X, W_int8, state_W = ctx.save_for_backward
|
||||
if ctx.needs_input_grad[1]:
|
||||
real_X = dequantize_rowwise(X_int8, state_X)
|
||||
del X_int8
|
||||
grad_W = torch.matmul(G.t(), real_X.to(G.dtype))
|
||||
del real_X
|
||||
if ctx.needs_input_grad[2]:
|
||||
grad_bias = G.sum(dim=0)
|
||||
if ctx.needs_input_grad[0]:
|
||||
G_int8, state_G = quantize_rowwise(G)
|
||||
del G
|
||||
W_int8 = W_int8.t().contiguous()
|
||||
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
|
||||
*G_3D_sz[:-1], -1
|
||||
)
|
||||
|
||||
return grad_X, grad_W, grad_bias
|
||||
|
||||
class SwitchBackLinear(nn.Linear):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
vector_wise_quantization: bool = False,
|
||||
mem_efficient : bool = False,
|
||||
):
|
||||
super().__init__(in_features, out_features, bias, device, dtype)
|
||||
|
||||
if not is_triton_available:
|
||||
raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
|
||||
Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
|
||||
|
||||
# By default, we use the global quantization.
|
||||
self.vector_wise_quantization = vector_wise_quantization
|
||||
if self.vector_wise_quantization:
|
||||
self._fn = _switchback_vectorrize
|
||||
if mem_efficient:
|
||||
print('mem efficient is not supported for vector-wise quantization.')
|
||||
exit(1)
|
||||
else:
|
||||
if mem_efficient:
|
||||
self._fn = _switchback_global_mem_efficient
|
||||
else:
|
||||
self._fn = _switchback_global
|
||||
|
||||
def prepare_for_eval(self):
|
||||
# If we just want to do eval, we can pre-quantize the weights instead of doing it on the forward pass.
|
||||
# Note this is experimental and not tested thoroughly.
|
||||
# Note this needs to be explicitly called with something like
|
||||
# def cond_prepare(m):
|
||||
# if hasattr(m, "prepare_for_eval"):
|
||||
# m.prepare_for_eval()
|
||||
# model.apply(cond_prepare)
|
||||
print('=> preparing for eval.')
|
||||
if self.vector_wise_quantization:
|
||||
W_int8, state_W = quantize_rowwise(self.weight)
|
||||
else:
|
||||
W_int8, state_W = quantize_global(self.weight)
|
||||
|
||||
self.register_buffer("W_int8", W_int8)
|
||||
self.register_buffer("state_W", state_W)
|
||||
|
||||
del self.weight
|
||||
|
||||
def forward(self, x):
|
||||
if self.training:
|
||||
return self._fn.apply(x, self.weight, self.bias)
|
||||
else:
|
||||
# If it hasn't been "prepared for eval", run the standard forward pass.
|
||||
if not hasattr(self, "W_int8"):
|
||||
return self._fn.apply(x, self.weight, self.bias)
|
||||
|
||||
# Otherwise, use pre-computed weights.
|
||||
X = x.view(-1, x.size(-1))
|
||||
X_int8, state_X = quantize_rowwise(X)
|
||||
|
||||
if self.vector_wise_quantization:
|
||||
return int8_matmul_rowwise_dequantize(
|
||||
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
|
||||
).view(*x.size()[:-1], -1)
|
||||
else:
|
||||
return int8_matmul_mixed_dequanitze(
|
||||
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
|
||||
).view(*x.size()[:-1], -1)
|
||||
|
||||
SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
|
||||
SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
|
||||
SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)
|
||||
|
||||
# This is just the standard linear function.
|
||||
class StandardLinearFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, bias=None):
|
||||
X = input.view(-1, input.size(-1))
|
||||
|
||||
ctx.save_for_backward(X, weight, bias)
|
||||
output = input.matmul(weight.t())
|
||||
if bias is not None:
|
||||
output += bias.unsqueeze(0).expand_as(output)
|
||||
return output.view(*input.size()[:-1], -1)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output_3D):
|
||||
input, weight, bias = ctx.saved_tensors
|
||||
|
||||
grad_output = grad_output_3D.reshape(-1, grad_output_3D.size(-1))
|
||||
|
||||
grad_input = grad_weight = grad_bias = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = grad_output.matmul(weight.to(grad_output.dtype)).view(*grad_output_3D.size()[:-1], -1)
|
||||
if ctx.needs_input_grad[1]:
|
||||
grad_weight = grad_output.t().matmul(input.to(grad_output.dtype))
|
||||
if bias is not None and ctx.needs_input_grad[2]:
|
||||
grad_bias = grad_output.sum(0)
|
||||
|
||||
return grad_input, grad_weight, grad_bias
|
||||
|
||||
class StandardLinear(nn.Linear):
|
||||
|
||||
def forward(self, x):
|
||||
return StandardLinearFunction.apply(x, self.weight, self.bias)
|
|
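A hedged usage sketch for the `SwitchBackLinear` layer defined above (it is exported from `bitsandbytes.nn`, requires a CUDA GPU with triton installed, and the shapes here are illustrative):

```python
import torch
from bitsandbytes.nn import SwitchBackLinear

layer = SwitchBackLinear(1024, 4096).cuda().half()
x = torch.randn(8, 128, 1024, dtype=torch.float16, device="cuda")
out = layer(x)    # int8 matmul with fused dequantize
print(out.shape)  # torch.Size([8, 128, 4096])

# optional: pre-quantize the weight once for inference
layer.eval()
layer.prepare_for_eval()
```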
@ -12,4 +12,5 @@ from .lamb import LAMB, LAMB8bit, LAMB32bit
|
|||
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
|
||||
from .optimizer import GlobalOptimManager
|
||||
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
|
||||
from .lion import Lion, Lion8bit, Lion32bit
|
||||
from .sgd import SGD, SGD8bit, SGD32bit
|
||||
|
|
87 bitsandbytes/optim/lion.py (new file)
|
@ -0,0 +1,87 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from bitsandbytes.optim.optimizer import Optimizer1State
|
||||
|
||||
|
||||
class Lion(Optimizer1State):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-4,
|
||||
betas=(0.9, 0.99),
|
||||
weight_decay=0,
|
||||
optim_bits=32,
|
||||
args=None,
|
||||
min_8bit_size=4096,
|
||||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
0.,
|
||||
weight_decay,
|
||||
optim_bits,
|
||||
args,
|
||||
min_8bit_size,
|
||||
percentile_clipping,
|
||||
block_wise,
|
||||
)
|
||||
|
||||
|
||||
class Lion8bit(Optimizer1State):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-4,
|
||||
betas=(0.9, 0.99),
|
||||
weight_decay=0,
|
||||
args=None,
|
||||
min_8bit_size=4096,
|
||||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
0.,
|
||||
weight_decay,
|
||||
8,
|
||||
args,
|
||||
min_8bit_size,
|
||||
percentile_clipping,
|
||||
block_wise,
|
||||
)
|
||||
|
||||
|
||||
class Lion32bit(Optimizer1State):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-4,
|
||||
betas=(0.9, 0.99),
|
||||
weight_decay=0,
|
||||
args=None,
|
||||
min_8bit_size=4096,
|
||||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
0.,
|
||||
weight_decay,
|
||||
32,
|
||||
args,
|
||||
min_8bit_size,
|
||||
percentile_clipping,
|
||||
block_wise,
|
||||
)
|
|
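A minimal usage sketch for the Lion optimizers added in this file (the model and data below are placeholders, not part of the commit):

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(512, 512).cuda()
optimizer = bnb.optim.Lion8bit(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=1e-2)

x = torch.randn(16, 512, device="cuda")
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```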
@ -669,7 +669,7 @@ class Optimizer1State(Optimizer8bit):
|
|||
step,
|
||||
config["lr"],
|
||||
None,
|
||||
0.0,
|
||||
config['betas'][1],
|
||||
config["weight_decay"],
|
||||
gnorm_scale,
|
||||
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
|
||||
|
|
6 bitsandbytes/research/__init__.py (new file)
|
@ -0,0 +1,6 @@
|
|||
from . import nn
|
||||
from .autograd._functions import (
|
||||
switchback_bnb,
|
||||
matmul_fp8_global,
|
||||
matmul_fp8_mixed,
|
||||
)
|
0 bitsandbytes/research/autograd/__init__.py (new file)
411 bitsandbytes/research/autograd/_functions.py (new file)
|
@ -0,0 +1,411 @@
|
|||
import operator
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from functools import reduce # Required in Python 3
|
||||
|
||||
import torch
|
||||
|
||||
import bitsandbytes.functional as F
|
||||
|
||||
from bitsandbytes.autograd._functions import MatmulLtState, GlobalOutlierPooler
|
||||
|
||||
|
||||
# math.prod not compatible with python < 3.8
|
||||
def prod(iterable):
|
||||
return reduce(operator.mul, iterable, 1)
|
||||
|
||||
tensor = torch.Tensor
|
||||
|
||||
class MatMulFP8Mixed(torch.autograd.Function):
|
||||
# forward is the same, but we added the fallback for pre-turing GPUs
|
||||
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
|
||||
# default of pytorch behavior if inputs are empty
|
||||
ctx.is_empty = False
|
||||
if prod(A.shape) == 0:
|
||||
ctx.is_empty = True
|
||||
ctx.A = A
|
||||
ctx.B = B
|
||||
|
||||
B_shape = B.shape
|
||||
if A.shape[-1] == B_shape[0]:
|
||||
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
|
||||
else:
|
||||
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
|
||||
|
||||
# 1. Dequantize
|
||||
# 2. MatmulnN
|
||||
cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
|
||||
fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)
|
||||
|
||||
cB, state = F.quantize(B.float(), code=fw_code)
|
||||
fp8B = F.dequantize(cB, state).to(B.dtype)
|
||||
|
||||
output = torch.matmul(fp8A, fp8B)
|
||||
|
||||
# output is half
|
||||
|
||||
# 3. Save state
|
||||
ctx.fw_code = fw_code
|
||||
ctx.bw_code = bw_code
|
||||
ctx.bsz = bsz
|
||||
ctx.bsz2 = bsz2
|
||||
ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
|
||||
|
||||
if any(ctx.needs_input_grad[:2]):
|
||||
# NOTE: we send back A, and re-quant.
|
||||
ctx.tensors = (A, fp8B)
|
||||
else:
|
||||
ctx.tensors = (None, None)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.is_empty:
|
||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None
|
||||
|
||||
req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
|
||||
A, B = ctx.tensors
|
||||
|
||||
grad_A, grad_B = None, None
|
||||
|
||||
# TODO: Fix blocksize to be output_dim
|
||||
cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
|
||||
fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)
|
||||
|
||||
# cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
|
||||
# fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
|
||||
|
||||
# grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
|
||||
# fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
|
||||
# fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
|
||||
# fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
|
||||
|
||||
# not supported by PyTorch. TODO: create work-around
|
||||
if req_gradA:
|
||||
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
|
||||
|
||||
if req_gradB:
|
||||
if len(A.shape) == 3:
|
||||
At = A.transpose(2, 1).contiguous()
|
||||
else:
|
||||
At = A.transpose(1, 0).contiguous()
|
||||
# cA, state = F.quantize(At.float(), code=ctx.fw_code)
|
||||
# fp8At = F.dequantize(cA, state).to(A.dtype)
|
||||
grad_B = torch.matmul(At.to(grad_output.dtype), grad_output).to(B.dtype)
|
||||
|
||||
return grad_A, grad_B, None, None, None, None, None
|
||||
|
||||
|
||||
class MatMulFP8Global(torch.autograd.Function):
|
||||
# forward is the same, but we added the fallback for pre-turing GPUs
|
||||
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
|
||||
# default of pytorch behavior if inputs are empty
|
||||
ctx.is_empty = False
|
||||
if prod(A.shape) == 0:
|
||||
ctx.is_empty = True
|
||||
ctx.A = A
|
||||
ctx.B = B
|
||||
|
||||
B_shape = B.shape
|
||||
if A.shape[-1] == B_shape[0]:
|
||||
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
|
||||
else:
|
||||
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
|
||||
|
||||
# 1. Dequantize
|
||||
# 2. MatmulnN
|
||||
cA, state = F.quantize(A.float(), code=fw_code)
|
||||
fp8A = F.dequantize(cA, state).to(A.dtype)
|
||||
|
||||
cB, state = F.quantize(B.float(), code=fw_code)
|
||||
fp8B = F.dequantize(cB, state).to(B.dtype)
|
||||
|
||||
output = torch.matmul(fp8A, fp8B)
|
||||
|
||||
# output is half
|
||||
|
||||
# 3. Save state
|
||||
ctx.fw_code = fw_code
|
||||
ctx.bw_code = bw_code
|
||||
ctx.bsz = bsz
|
||||
ctx.bsz2 = bsz2
|
||||
ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
|
||||
|
||||
if any(ctx.needs_input_grad[:2]):
|
||||
# NOTE: we send back A, and re-quant.
|
||||
ctx.tensors = (A, fp8B)
|
||||
else:
|
||||
ctx.tensors = (None, None)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.is_empty:
|
||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None
|
||||
|
||||
req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
|
||||
A, B = ctx.tensors
|
||||
|
||||
grad_A, grad_B = None, None
|
||||
|
||||
# TODO: Fix blocksize to be output_dim
|
||||
cgrad_out, state = F.quantize(grad_output.float(), code=ctx.bw_code)
|
||||
fp8out = F.dequantize(cgrad_out, state).to(grad_output.dtype)
|
||||
|
||||
# cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
|
||||
# fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
|
||||
|
||||
# grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
|
||||
# fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
|
||||
# fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
|
||||
# fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
|
||||
|
||||
# not supported by PyTorch. TODO: create work-around
|
||||
if req_gradA:
|
||||
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
|
||||
|
||||
if req_gradB:
|
||||
if len(A.shape) == 3:
|
||||
At = A.transpose(2, 1).contiguous()
|
||||
else:
|
||||
At = A.transpose(1, 0).contiguous()
|
||||
cA, state = F.quantize(At.float(), code=ctx.fw_code)
|
||||
fp8At = F.dequantize(cA, state).to(A.dtype)
|
||||
grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
|
||||
|
||||
return grad_A, grad_B, None, None, None, None, None
|
||||
|
||||
|
||||
class SwitchBackBnb(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
|
||||
# default to pytorch behavior if inputs are empty
|
||||
ctx.is_empty = False
|
||||
if prod(A.shape) == 0:
|
||||
ctx.is_empty = True
|
||||
ctx.A = A
|
||||
ctx.B = B
|
||||
ctx.bias = bias
|
||||
if A.shape[-1] == B.shape[0]:
|
||||
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
|
||||
else:
|
||||
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
|
||||
|
||||
# 1. Quantize A
|
||||
# 2. Quantize B
|
||||
# 3. Matmul
|
||||
# 4. Mixed-precision decomposition matmul
|
||||
# 5. Save state
|
||||
formatB = state.formatB
|
||||
input_shape = A.shape
|
||||
if state.outlier_pool is None:
|
||||
state.outlier_pool = GlobalOutlierPooler.get_instance()
|
||||
|
||||
# Cast A to fp16
|
||||
if A.dtype != torch.float16:
|
||||
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
|
||||
|
||||
# 1. Quantize A
|
||||
if len(A.shape) == 3:
|
||||
A = A.view(-1, A.shape[-1]).contiguous()
|
||||
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
|
||||
A.to(torch.float16), threshold=state.threshold
|
||||
)
|
||||
|
||||
if state.threshold > 0.0 and coo_tensorA is not None:
|
||||
if state.has_fp16_weights:
|
||||
idx = torch.unique(coo_tensorA.colidx).long()
|
||||
CA[:, idx] = 0
|
||||
CAt[:, idx] = 0
|
||||
subA = A[:, idx]
|
||||
state.subB = B[:, idx].t().contiguous()
|
||||
state.idx = idx
|
||||
else:
|
||||
if state.CxB is None:
|
||||
# B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
|
||||
# we also need to convert it to the turing/ampere format
|
||||
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
||||
else:
|
||||
#print('A shape', A.shape)
|
||||
if not state.has_fp16_weights and state.CxB is None:
|
||||
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
||||
subA = None
|
||||
|
||||
# 2. Quantize B
|
||||
if state.has_fp16_weights:
|
||||
#print('B shape', B.shape)
|
||||
has_grad = True if (getattr(B, "grad", None) is not None) else False
|
||||
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
|
||||
if is_transposed:
|
||||
B = B.contiguous()
|
||||
|
||||
if (state.is_training and not has_grad) or state.CxB is None:
|
||||
state.reset_grads()
|
||||
(
|
||||
CB,
|
||||
state.CBt,
|
||||
state.SCB,
|
||||
state.SCBt,
|
||||
coo_tensorB,
|
||||
) = F.double_quant(B.to(torch.float16))
|
||||
state.CxB, state.SB = F.transform(CB, to_order=formatB)
|
||||
else:
|
||||
has_grad = False
|
||||
|
||||
if coo_tensorA is not None and not state.has_fp16_weights:
|
||||
# extract outliers
|
||||
|
||||
outlier_idx = torch.unique(coo_tensorA.colidx)
|
||||
state.idx = outlier_idx
|
||||
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
|
||||
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
|
||||
# # do not use pool for 2nd FFN layer
|
||||
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
|
||||
# else:
|
||||
# state.idx = outlier_idx
|
||||
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
|
||||
state.subB = (
|
||||
(outliers * state.SCB.view(-1, 1) / 127.0)
|
||||
.t()
|
||||
.contiguous()
|
||||
.to(A.dtype)
|
||||
)
|
||||
CA[:, state.idx.long()] = 0
|
||||
CAt[:, state.idx.long()] = 0
|
||||
subA = A[:, state.idx.long()]
|
||||
|
||||
shapeB = state.SB[0]
|
||||
|
||||
if len(input_shape) == 3:
|
||||
output_shape = (input_shape[0], input_shape[1], shapeB[0])
|
||||
else:
|
||||
output_shape = (input_shape[0], shapeB[0])
|
||||
|
||||
# 3. Matmul
|
||||
C32A, SA = F.transform(CA, "col32")
|
||||
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
|
||||
# we apply the fused bias here
|
||||
|
||||
if bias is None or bias.dtype == torch.float16:
|
||||
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
|
||||
output = output.to(A.dtype)
|
||||
else: # apply bias separately
|
||||
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
|
||||
output = output.to(A.dtype).add_(bias)
|
||||
|
||||
# 4. Mixed-precision decomposition matmul
|
||||
if coo_tensorA is not None and subA is not None:
|
||||
output += torch.matmul(subA, state.subB)
|
||||
|
||||
# 5. Save state
|
||||
ctx.state = state
|
||||
|
||||
ctx.formatB = formatB
|
||||
ctx.grad_shape = input_shape
|
||||
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
|
||||
|
||||
if any(ctx.needs_input_grad[:2]):
|
||||
ctx.tensors = (CAt, subA, A)
|
||||
ctx.tensor_states = (SCAt, state.idx)
|
||||
else:
|
||||
ctx.tensors = [None, None, None]
|
||||
ctx.tensor_states = (None, None)
|
||||
ctx.save_for_backward(None, None)
|
||||
|
||||
|
||||
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
|
||||
return clone_func(output.view(output_shape))
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.is_empty:
|
||||
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
|
||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
|
||||
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
|
||||
CAt, subA, A = ctx.tensors
|
||||
SCAt, idx = ctx.tensor_states
|
||||
formatB = ctx.formatB
|
||||
state = ctx.state
|
||||
grad_A = grad_B = grad_bias = None
|
||||
|
||||
if req_gradBias:
|
||||
# compute grad_bias first before changing grad_output dtype
|
||||
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
|
||||
|
||||
# Cast grad_output to fp16
|
||||
if len(grad_output.shape) == 3:
|
||||
grad_output = grad_output.reshape(
|
||||
-1, grad_output.shape[-1]
|
||||
).contiguous()
|
||||
|
||||
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
|
||||
|
||||
if req_gradB:
|
||||
# print('back A shape', A.shape)
|
||||
# print('grad output t shape', grad_output.t().shape)
|
||||
grad_B = torch.matmul(grad_output.t(), A)
|
||||
|
||||
if req_gradA:
|
||||
if state.CBt is not None:
|
||||
C32grad, Sgrad = F.transform(Cgrad, "col32")
|
||||
if state.CxBt is None:
|
||||
state.CxBt, state.SBt = F.transform(
|
||||
state.CBt, to_order=formatB, transpose=True
|
||||
)
|
||||
# print('back B shape', state.CxBt.shape)
|
||||
# print('back grad shape', C32grad.shape)
|
||||
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
|
||||
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
|
||||
|
||||
elif state.CB is not None:
|
||||
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
|
||||
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
|
||||
else:
|
||||
raise Exception('State must contain either CBt or CB matrix for backward')
|
||||
|
||||
return grad_A, grad_B, None, grad_bias, None
|
||||
|
||||
def get_block_sizes(input_matrix, weight_matrix):
|
||||
input_features = input_matrix.shape[-1]
|
||||
output_features = (weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1])
|
||||
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
||||
bsz, bsz2 = 1024, 1024
|
||||
for i, k in enumerate(array):
|
||||
if input_features > array[i + 1]:
|
||||
bsz = k
|
||||
break
|
||||
for i, k in enumerate(array):
|
||||
if output_features > array[i + 1]:
|
||||
bsz2 = k
|
||||
break
|
||||
|
||||
return bsz, bsz2
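For example, `get_block_sizes` walks the threshold table and picks the largest block size whose next-smaller threshold is still below the feature dimension. A small sketch with hypothetical shapes:

```python
import torch

# hypothetical shapes
x = torch.randn(4, 768)       # input_features  = 768
w = torch.randn(3072, 768)    # output_features = 3072
bsz, bsz2 = get_block_sizes(x, w)
# 768  > 512  -> bsz  = 1024
# 3072 > 2048 -> bsz2 = 4096
```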
|
||||
|
||||
def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
|
||||
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
|
||||
return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
|
||||
|
||||
def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
|
||||
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
|
||||
return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
|
||||
|
||||
def switchback_bnb(
|
||||
A: tensor,
|
||||
B: tensor,
|
||||
out: tensor = None,
|
||||
state: MatmulLtState = None,
|
||||
threshold=0.0,
|
||||
bias=None
|
||||
):
|
||||
state = state or MatmulLtState()
|
||||
if threshold > 0.0:
|
||||
state.threshold = threshold
|
||||
return SwitchBackBnb.apply(A, B, out, bias, state)
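The FP8 entry points above are thin wrappers around the autograd functions. A minimal usage sketch, assuming the wrappers are exposed as they are called from `modules.py` below (e.g. `bnb.research.matmul_fp8_mixed`; the exact import path and shapes here are assumptions):

```python
import torch
import bitsandbytes as bnb

A = torch.randn(4, 1024, device='cuda', dtype=torch.float16)
W = torch.randn(1024, 4096, device='cuda', dtype=torch.float16)

# forward/backward FP8 quantization maps, mirroring LinearFP8Mixed below
fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)  # E4M3-style forward map
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)  # E5M2-style backward map

# bsz/bsz2 default to -1, in which case get_block_sizes(A, W) picks them
out = bnb.research.matmul_fp8_mixed(A, W, fw_code=fw_code, bw_code=bw_code)
```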
|
1 bitsandbytes/research/nn/__init__.py Normal file
@@ -0,0 +1 @@
from .modules import LinearFP8Mixed, LinearFP8Global
|
64 bitsandbytes/research/nn/modules.py Normal file
@@ -0,0 +1,64 @@
from typing import Optional, TypeVar, Union, overload
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, device, dtype, nn
|
||||
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.optim import GlobalOptimManager
|
||||
from bitsandbytes.utils import OutlierTracer, find_outlier_dims
|
||||
|
||||
T = TypeVar("T", bound="torch.nn.Module")
|
||||
|
||||
|
||||
class LinearFP8Mixed(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.bw_code = None
|
||||
self.fw_code = None
|
||||
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
||||
for i, k in enumerate(array):
|
||||
if input_features > array[i + 1]:
|
||||
self.bsz = k
|
||||
break
|
||||
for i, k in enumerate(array):
|
||||
if output_features > array[i + 1]:
|
||||
self.bsz2 = k
|
||||
break
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.fw_code is None:
|
||||
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
|
||||
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
|
||||
|
||||
out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
||||
if self.bias is not None:
|
||||
out += self.bias
|
||||
|
||||
return out
|
||||
|
||||
class LinearFP8Global(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.bw_code = None
|
||||
self.fw_code = None
|
||||
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
||||
for i, k in enumerate(array):
|
||||
if input_features > array[i + 1]:
|
||||
self.bsz = k
|
||||
break
|
||||
for i, k in enumerate(array):
|
||||
if output_features > array[i + 1]:
|
||||
self.bsz2 = k
|
||||
break
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.fw_code is None:
|
||||
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
|
||||
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
|
||||
|
||||
out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
||||
if self.bias is not None:
|
||||
out += self.bias
|
||||
|
||||
return out
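A short sketch of how these research layers might be dropped into a model (the model below is hypothetical, and it assumes `bnb.research.nn` is importable as laid out in the `__init__.py` above):

```python
import torch
import bitsandbytes as bnb

# hypothetical two-layer MLP built from the research FP8 layers
mlp = torch.nn.Sequential(
    bnb.research.nn.LinearFP8Mixed(1024, 4096),
    torch.nn.GELU(),
    bnb.research.nn.LinearFP8Mixed(4096, 1024),
).cuda().half()

x = torch.randn(8, 1024, device='cuda', dtype=torch.float16)
y = mlp(x)          # forward uses matmul_fp8_mixed internally
y.sum().backward()  # backward quantizes grads with the E5M2-style bw_code
```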
|
0 bitsandbytes/triton/__init__.py Normal file
64 bitsandbytes/triton/dequantize_rowwise.py Normal file
@@ -0,0 +1,64 @@
import math
|
||||
import torch
|
||||
import time
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
|
||||
else:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# rowwise quantize
|
||||
|
||||
# TODO: autotune this better.
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_stages=1, num_warps=8),
|
||||
triton.Config({}, num_stages=2, num_warps=8),
|
||||
triton.Config({}, num_stages=4, num_warps=8),
|
||||
triton.Config({}, num_stages=8, num_warps=8),
|
||||
triton.Config({}, num_stages=1),
|
||||
triton.Config({}, num_stages=2),
|
||||
triton.Config({}, num_stages=4),
|
||||
triton.Config({}, num_stages=8),
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
],
|
||||
key=['n_elements']
|
||||
)
|
||||
@triton.jit
|
||||
def _dequantize_rowwise(
|
||||
x_ptr,
|
||||
state_x,
|
||||
output_ptr,
|
||||
inv_127,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
P2: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
arange = tl.arange(0, P2)
|
||||
offsets = block_start + arange
|
||||
row_mask = arange < BLOCK_SIZE
|
||||
x = tl.load(x_ptr + offsets, mask=row_mask)
|
||||
max_val = tl.load(state_x + pid)
|
||||
output = max_val * x * inv_127
|
||||
tl.store(output_ptr + offsets, output, mask=row_mask)
|
||||
|
||||
|
||||
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
|
||||
output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
|
||||
|
||||
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
|
||||
|
||||
assert x.is_cuda and output.is_cuda
|
||||
n_elements = output.numel()
|
||||
grid = lambda meta: (x.shape[0],)
|
||||
_dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
|
||||
return output
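For reference, the row-wise dequantization performed by the kernel is simply `max_val * x / 127` per row; a plain-PyTorch equivalent (a sketch, not the Triton path) would be:

```python
import torch

def dequantize_rowwise_reference(x_int8: torch.Tensor, state_x: torch.Tensor) -> torch.Tensor:
    # state_x holds the per-row absmax recorded during row-wise quantization
    return (x_int8.float() * (state_x.float().unsqueeze(1) / 127.0)).half()
```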
|
163 bitsandbytes/triton/int8_matmul_mixed_dequanitze.py Normal file
@@ -0,0 +1,163 @@
import torch
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
|
||||
else:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
|
||||
# This is a matmul kernel based on triton.ops.matmul
|
||||
# It is modified to support rowwise quantized input and global quantized weight
|
||||
# Its purpose is a fused matmul followed by dequantization
|
||||
# It does support bias.
|
||||
|
||||
def init_to_zero(name):
|
||||
return lambda nargs: nargs[name].zero_()
|
||||
|
||||
def get_configs_io_bound():
|
||||
configs = []
|
||||
for num_stages in [2, 3, 4, 5, 6]:
|
||||
for block_m in [16, 32]:
|
||||
for block_k in [32, 64]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
num_warps = 2 if block_n <= 64 else 4
|
||||
configs.append(
|
||||
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
# split_k
|
||||
for split_k in [2, 4, 8, 16]:
|
||||
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
return configs
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# basic configs for compute-bound matmuls
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
# good for int8
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
] + get_configs_io_bound(),
|
||||
key=['M', 'N', 'K'],
|
||||
prune_configs_by={
|
||||
'early_config_prune': early_config_prune,
|
||||
'perf_model': estimate_matmul_time,
|
||||
'top_k': 10
|
||||
},
|
||||
)
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@triton.jit
|
||||
def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr
|
||||
):
|
||||
# matrix multiplication
|
||||
pid = tl.program_id(0)
|
||||
pid_z = tl.program_id(1)
|
||||
grid_m = tl.cdiv(M, BLOCK_M)
|
||||
grid_n = tl.cdiv(N, BLOCK_N)
|
||||
# re-order program ID for better L2 performance
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // (group_size)
|
||||
# do matrix multiplication
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
||||
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
||||
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
# pointers
|
||||
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
|
||||
|
||||
# rematerialize rm and rn to save registers
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
w_factor = tl.load(state_w_ptr)
|
||||
x_factor = tl.load(state_x_ptr + ram)[:, None]
|
||||
|
||||
# acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
||||
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
||||
if EVEN_K:
|
||||
a = tl.load(A)
|
||||
b = tl.load(B)
|
||||
else:
|
||||
k_remaining = K - k * (BLOCK_K * SPLIT_K)
|
||||
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
|
||||
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
|
||||
acc += tl.dot(a, b)
|
||||
A += BLOCK_K * SPLIT_K * stride_ak
|
||||
B += BLOCK_K * SPLIT_K * stride_bk
|
||||
|
||||
acc = (w_factor * (x_factor * (acc * divfactor)))
|
||||
acc = acc.to(C.dtype.element_ty)
|
||||
|
||||
# conditionally add bias
|
||||
if has_bias:
|
||||
bias = tl.load(bias + rn).to(C.dtype.element_ty)
|
||||
acc = acc + bias[None, :]
|
||||
|
||||
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
# handles write-back with reduction-splitting
|
||||
if SPLIT_K == 1:
|
||||
tl.store(C, acc, mask=mask)
|
||||
else:
|
||||
tl.atomic_add(C, acc, mask=mask)
|
||||
|
||||
|
||||
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
|
||||
device = a.device
|
||||
divfactor = 1. / (127. * 127.)
|
||||
has_bias = 0 if bias is None else 1
|
||||
# handle non-contiguous inputs if necessary
|
||||
if a.stride(0) > 1 and a.stride(1) > 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(0) > 1 and b.stride(1) > 1:
|
||||
b = b.contiguous()
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
M, K = a.shape
|
||||
_, N = b.shape
|
||||
# allocates output
|
||||
c = torch.empty((M, N), device=device, dtype=torch.float16)
|
||||
# accumulator types
|
||||
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
|
||||
# launch int8_matmul_mixed_dequantize kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
GROUP_M=8, ACC_TYPE=ACC_TYPE)
|
||||
return c
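Numerically, `int8_matmul_mixed_dequanitze` is meant to match the following eager-mode reference (a sketch under the assumption that `a` is row-wise int8 with per-row absmax `state_x` and `b` is globally quantized int8 with a single absmax `state_w`):

```python
import torch

def int8_matmul_mixed_dequantize_reference(a, b, state_x, state_w, bias=None):
    divfactor = 1.0 / (127.0 * 127.0)
    out = (a.float() @ b.float()) * divfactor                     # int32 accumulate, then scale
    out = out * state_x.float().unsqueeze(1) * state_w.float()    # per-row and global absmax scales
    if bias is not None:
        out = out + bias.float()
    return out.half()
```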
|
164 bitsandbytes/triton/int8_matmul_rowwise_dequantize.py Normal file
@@ -0,0 +1,164 @@
import torch
|
||||
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None
|
||||
else:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# This is a matmul kernel based on triton.ops.matmul
|
||||
# It is modified to support rowwise quantized input and columnwise quantized weight
|
||||
# Its purpose is a fused matmul followed by dequantization
|
||||
# It does support bias.
|
||||
|
||||
def init_to_zero(name):
|
||||
return lambda nargs: nargs[name].zero_()
|
||||
|
||||
|
||||
def get_configs_io_bound():
|
||||
configs = []
|
||||
for num_stages in [2, 3, 4, 5, 6]:
|
||||
for block_m in [16, 32]:
|
||||
for block_k in [32, 64]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
num_warps = 2 if block_n <= 64 else 4
|
||||
configs.append(
|
||||
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
# split_k
|
||||
for split_k in [2, 4, 8, 16]:
|
||||
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
return configs
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# basic configs for compute-bound matmuls
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
# good for int8
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
] + get_configs_io_bound(),
|
||||
key=['M', 'N', 'K'],
|
||||
prune_configs_by={
|
||||
'early_config_prune': early_config_prune,
|
||||
'perf_model': estimate_matmul_time,
|
||||
'top_k': 10
|
||||
},
|
||||
)
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@triton.jit
|
||||
def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr
|
||||
):
|
||||
# matrix multiplication
|
||||
pid = tl.program_id(0)
|
||||
pid_z = tl.program_id(1)
|
||||
grid_m = tl.cdiv(M, BLOCK_M)
|
||||
grid_n = tl.cdiv(N, BLOCK_N)
|
||||
# re-order program ID for better L2 performance
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // (group_size)
|
||||
# do matrix multiplication
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
||||
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
||||
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
# pointers
|
||||
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
|
||||
|
||||
# rematerialize rm and rn to save registers
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
w_factor = tl.load(state_w_ptr + rbn)[None, :]
|
||||
x_factor = tl.load(state_x_ptr + ram)[:, None]
|
||||
|
||||
# acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
||||
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
||||
if EVEN_K:
|
||||
a = tl.load(A)
|
||||
b = tl.load(B)
|
||||
else:
|
||||
k_remaining = K - k * (BLOCK_K * SPLIT_K)
|
||||
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
|
||||
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
|
||||
acc += tl.dot(a, b)
|
||||
A += BLOCK_K * SPLIT_K * stride_ak
|
||||
B += BLOCK_K * SPLIT_K * stride_bk
|
||||
|
||||
acc = (w_factor * (x_factor * (acc * divfactor)))
|
||||
acc = acc.to(C.dtype.element_ty)
|
||||
|
||||
if has_bias:
|
||||
bias = tl.load(bias + rn).to(C.dtype.element_ty)
|
||||
acc = acc + bias[None, :]
|
||||
|
||||
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
# handles write-back with reduction-splitting
|
||||
if SPLIT_K == 1:
|
||||
tl.store(C, acc, mask=mask)
|
||||
else:
|
||||
tl.atomic_add(C, acc, mask=mask)
|
||||
|
||||
|
||||
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
|
||||
divfactor = 1. / (127. * 127.)
|
||||
|
||||
has_bias = 0 if bias is None else 1
|
||||
|
||||
device = a.device
|
||||
# handle non-contiguous inputs if necessary
|
||||
if a.stride(0) > 1 and a.stride(1) > 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(0) > 1 and b.stride(1) > 1:
|
||||
b = b.contiguous()
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
M, K = a.shape
|
||||
_, N = b.shape
|
||||
# allocates output
|
||||
c = torch.empty((M, N), device=device, dtype=torch.float16)
|
||||
# accumulator types
|
||||
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
|
||||
# launch int8_matmul_rowwise_dequantize kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
GROUP_M=8, ACC_TYPE=ACC_TYPE)
|
||||
return c
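Put together, these kernels compose into a SwitchBack-style int8 linear forward: quantize the activations and the weight row-wise, then call the fused matmul-dequantize. A sketch (the helper name is hypothetical; the actual SwitchBack layers live elsewhere in the library):

```python
import torch
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize

def switchback_linear_forward(x, weight, bias=None):
    # x: (tokens, in_features) fp16, weight: (out_features, in_features) fp16
    x_int8, x_absmax = quantize_rowwise(x)        # per-token scales
    w_int8, w_absmax = quantize_rowwise(weight)   # per-output-feature scales
    return int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), x_absmax, w_absmax, bias)
```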
|
74 bitsandbytes/triton/quantize_columnwise_and_transpose.py Normal file
@@ -0,0 +1,74 @@
import math
|
||||
import torch
|
||||
import time
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def quantize_columnwise_and_transpose(x: torch.Tensor): return None
|
||||
else:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# This kernel does fused columnwise quantization and transpose.
|
||||
|
||||
# TODO: autotune this better.
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_stages=1),
|
||||
triton.Config({}, num_stages=2),
|
||||
triton.Config({}, num_stages=4),
|
||||
triton.Config({}, num_stages=8),
|
||||
triton.Config({}, num_stages=16),
|
||||
triton.Config({}, num_stages=1, num_warps=8),
|
||||
triton.Config({}, num_stages=2, num_warps=8),
|
||||
triton.Config({}, num_stages=4, num_warps=8),
|
||||
triton.Config({}, num_stages=8, num_warps=8),
|
||||
triton.Config({}, num_stages=16, num_warps=8),
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
],
|
||||
key=['n_elements']
|
||||
)
|
||||
@triton.jit
|
||||
def _quantize_columnwise_and_transpose(
|
||||
x_ptr,
|
||||
output_ptr,
|
||||
output_maxs,
|
||||
n_elements,
|
||||
M : tl.constexpr, N : tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
P2: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid
|
||||
p2_arange = tl.arange(0, P2)
|
||||
p2_arange_mask = p2_arange < M
|
||||
arange = p2_arange * N
|
||||
offsets = block_start + arange
|
||||
x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
|
||||
abs_x = tl.abs(x)
|
||||
max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
|
||||
output = tl.libdevice.llrint(127. * (x / max_val))
|
||||
|
||||
new_start = pid * M
|
||||
new_offsets = new_start + p2_arange
|
||||
tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
|
||||
tl.store(output_maxs + pid, max_val)
|
||||
|
||||
def quantize_columnwise_and_transpose(x: torch.Tensor):
|
||||
M, N = x.shape
|
||||
output = torch.empty(N, M, device=x.device, dtype=torch.int8)
|
||||
output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)
|
||||
|
||||
P2 = int(2 ** (math.ceil(math.log2(M))))
|
||||
|
||||
assert x.is_cuda and output.is_cuda
|
||||
n_elements = output.numel()
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
_quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
|
||||
return output, output_maxs
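An eager-mode equivalent of the fused column-wise quantize + transpose (a reference sketch, not the Triton path):

```python
import torch

def quantize_columnwise_and_transpose_reference(x: torch.Tensor):
    # per-column absmax, then symmetric int8 quantization and transpose
    col_absmax = x.abs().max(dim=0).values                   # shape: (N,)
    q = torch.round(127.0 * (x / col_absmax.unsqueeze(0)))   # round-to-nearest, like llrint
    return q.t().contiguous().to(torch.int8), col_absmax.half()
```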
|
||||
|
107 bitsandbytes/triton/quantize_global.py Normal file
@@ -0,0 +1,107 @@
import math
|
||||
import torch
|
||||
import time
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def quantize_global_transpose(input): return None
|
||||
def quantize_global(x: torch.Tensor): return None
|
||||
else:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# global quantize
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
|
||||
|
||||
],
|
||||
key=['n_elements']
|
||||
)
|
||||
@triton.jit
|
||||
def _quantize_global(
|
||||
x_ptr,
|
||||
absmax_inv_ptr,
|
||||
output_ptr,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
absmax_inv = tl.load(absmax_inv_ptr)
|
||||
output = tl.libdevice.llrint(127. * (x * absmax_inv))
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
def quantize_global(x: torch.Tensor):
|
||||
absmax = x.abs().max().unsqueeze(0)
|
||||
absmax_inv = 1./ absmax
|
||||
output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)
|
||||
assert x.is_cuda and output.is_cuda
|
||||
n_elements = output.numel()
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
_quantize_global[grid](x, absmax_inv, output, n_elements)
|
||||
return output, absmax
|
||||
|
||||
|
||||
# global quantize and transpose
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
|
||||
|
||||
# ...
|
||||
],
|
||||
key=['M', 'N']
|
||||
)
|
||||
@triton.jit
|
||||
def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
|
||||
BLOCK_M : tl.constexpr,
|
||||
BLOCK_N : tl.constexpr,
|
||||
GROUP_M : tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
grid_m = (M + BLOCK_M - 1) // BLOCK_M
|
||||
grid_n = (N + BLOCK_N - 1) // BLOCK_N
|
||||
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // group_size
|
||||
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
a = tl.load(A, mask=mask)
|
||||
absmax_inv = tl.load(absmax_inv_ptr)
|
||||
|
||||
# rematerialize to save registers
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
|
||||
output = tl.libdevice.llrint(127. * (a * absmax_inv))
|
||||
|
||||
tl.store(B, output, mask=mask)
|
||||
|
||||
def quantize_global_transpose(input):
|
||||
absmax = input.abs().max().unsqueeze(0)
|
||||
absmax_inv = 1./ absmax
|
||||
M, N = input.shape
|
||||
out = torch.empty(N, M, device='cuda', dtype=torch.int8)
|
||||
|
||||
assert out.size(0) == N and out.size(1) == M
|
||||
assert input.stride(0) == 1 or input.stride(1) == 1
|
||||
assert out.stride(0) == 1 or out.stride(1) == 1
|
||||
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
|
||||
_quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
|
||||
return out, absmax
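The same two operations in plain PyTorch, for reference (a sketch; the Triton kernels fuse the scaling and, in the second case, the transpose):

```python
import torch

def quantize_global_reference(x: torch.Tensor):
    absmax = x.abs().max().unsqueeze(0)
    q = torch.round(127.0 * (x / absmax)).to(torch.int8)
    return q, absmax

def quantize_global_transpose_reference(x: torch.Tensor):
    q, absmax = quantize_global_reference(x)
    return q.t().contiguous(), absmax
```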
|
||||
|
68 bitsandbytes/triton/quantize_rowwise.py Normal file
@@ -0,0 +1,68 @@
import math
|
||||
import torch
|
||||
import time
|
||||
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def quantize_rowwise(x: torch.Tensor): return None
|
||||
else:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# rowwise quantize
|
||||
|
||||
# TODO: autotune this better.
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_stages=1, num_warps=8),
|
||||
triton.Config({}, num_stages=2, num_warps=8),
|
||||
triton.Config({}, num_stages=4, num_warps=8),
|
||||
triton.Config({}, num_stages=8, num_warps=8),
|
||||
triton.Config({}, num_stages=1),
|
||||
triton.Config({}, num_stages=2),
|
||||
triton.Config({}, num_stages=4),
|
||||
triton.Config({}, num_stages=8),
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
],
|
||||
key=['n_elements']
|
||||
)
|
||||
@triton.jit
|
||||
def _quantize_rowwise(
|
||||
x_ptr,
|
||||
output_ptr,
|
||||
output_maxs,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
P2: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
arange = tl.arange(0, P2)
|
||||
offsets = block_start + arange
|
||||
row_mask = arange < BLOCK_SIZE
|
||||
x = tl.load(x_ptr + offsets, mask=row_mask)
|
||||
|
||||
abs_x = tl.abs(x)
|
||||
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
|
||||
output = tl.libdevice.llrint(127. * (x / max_val))
|
||||
tl.store(output_ptr + offsets, output, mask=row_mask)
|
||||
tl.store(output_maxs + pid, max_val)
|
||||
|
||||
def quantize_rowwise(x: torch.Tensor):
|
||||
output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
|
||||
output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
|
||||
|
||||
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
|
||||
|
||||
assert x.is_cuda and output.is_cuda
|
||||
n_elements = output.numel()
|
||||
grid = lambda meta: (x.shape[0],)
|
||||
_quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
|
||||
return output, output_maxs
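A quick round-trip check pairing this kernel with `dequantize_rowwise` from above (a sketch; requires a CUDA device with triton installed):

```python
import torch
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise

x = torch.randn(16, 1024, device='cuda', dtype=torch.float16)
x_int8, x_absmax = quantize_rowwise(x)
x_hat = dequantize_rowwise(x_int8, x_absmax)
# the row-wise int8 quantization error should be small relative to the per-row absmax
print((x - x_hat).abs().max())
```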
|
||||
|
4 bitsandbytes/triton/triton_utils.py Normal file
@@ -0,0 +1,4 @@
import importlib

def is_triton_available():
    return importlib.util.find_spec("triton") is not None

@@ -1,7 +1,143 @@
import shlex
|
||||
import subprocess
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
def outlier_hook(module, input):
|
||||
assert isinstance(module, torch.nn.Linear)
|
||||
tracer = OutlierTracer.get_instance()
|
||||
hvalue = tracer.get_hvalue(module.weight)
|
||||
if hvalue not in tracer.hvalue2outlier_idx:
|
||||
outlier_idx = find_outlier_dims(module.weight)
|
||||
tracer.outliers.append(outlier_idx)
|
||||
tracer.hvalues.append(hvalue)
|
||||
if len(tracer.outliers) > 1:
|
||||
# assign the current layer the outlier idx found from the weight
|
||||
# of the previous linear layer
|
||||
if tracer.outliers[-1].numel() > 0:
|
||||
assert tracer.outliers[-1].max() < module.weight.shape[1]
|
||||
tracer.hvalue2outlier_idx[hvalue] = tracer.outliers[-1]
|
||||
|
||||
else:
|
||||
# first layer, we cannot use the weight for outlier detection
|
||||
# we follow a mixed approach:
|
||||
# (1) zscore test of std of hidden dimension
|
||||
# (2) magnitude > 6 test
|
||||
merged = input[0].view(-1, input[0].shape[-1])
|
||||
# (1) zscore test of std of hidden dimension
|
||||
outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3)
|
||||
# (2) magnitude > 6 test
|
||||
dims = (torch.abs(input[0])> 6).sum(dim=list(range(len(input[0].shape)-1)))
|
||||
outlier_idx2 = torch.where(dims > 0)[0]
|
||||
outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()
|
||||
tracer.hvalue2outlier_idx[hvalue] = outlier_idx
|
||||
else:
|
||||
for hook in tracer.hooks:
|
||||
hook.remove()
|
||||
|
||||
|
||||
class OutlierTracer(object):
|
||||
_instance = None
|
||||
|
||||
def __init__(self):
|
||||
raise RuntimeError("Call get_instance() instead")
|
||||
|
||||
def initialize(self, model):
|
||||
self.last_w = None
|
||||
self.current_outlier_dims = None
|
||||
self.hvalues = []
|
||||
self.outliers = []
|
||||
self.hvalue2outlier_idx = {}
|
||||
self.initialized = True
|
||||
self.hooks = []
|
||||
|
||||
for n, m in model.named_modules():
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
self.hooks.append(m.register_forward_pre_hook(outlier_hook))
|
||||
|
||||
def is_initialized(self):
|
||||
return getattr(self, 'initialized', False)
|
||||
|
||||
def get_hvalue(self, weight):
|
||||
return weight.data.storage().data_ptr()
|
||||
|
||||
def get_outliers(self, weight):
|
||||
if not self.is_initialized():
|
||||
print('Outlier tracer is not initialized...')
|
||||
return None
|
||||
hvalue = self.get_hvalue(weight)
|
||||
if hvalue in self.hvalue2outlier_idx:
|
||||
return self.hvalue2outlier_idx[hvalue]
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = cls.__new__(cls)
|
||||
return cls._instance
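A minimal usage sketch for the tracer (the model here is hypothetical; hooks are registered on every `nn.Linear` and removed once every weight has been seen):

```python
import torch
from bitsandbytes.utils import OutlierTracer

model = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.Linear(512, 512))
tracer = OutlierTracer.get_instance()
tracer.initialize(model)                    # registers forward pre-hooks
model(torch.randn(4, 512))                  # one forward pass populates hvalue2outlier_idx
idx = tracer.get_outliers(model[1].weight)  # may be None if no outlier dims were detected
```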
|
||||
|
||||
def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False):
|
||||
if rdm:
|
||||
return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long()
|
||||
|
||||
m = weight.mean(reduction_dim)
|
||||
mm = m.mean()
|
||||
mstd = m.std()
|
||||
zm = (m-mm)/mstd
|
||||
|
||||
std = weight.std(reduction_dim)
|
||||
stdm = std.mean()
|
||||
stdstd = std.std()
|
||||
|
||||
zstd = (std-stdm)/stdstd
|
||||
|
||||
if topk is not None:
|
||||
val, idx = torch.topk(std.abs(), k=topk, dim=0)
|
||||
else:
|
||||
idx = torch.where(zstd > zscore)[0]
|
||||
|
||||
return idx
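For illustration, `find_outlier_dims` flags dimensions whose standard deviation is a z-score outlier along the reduction dimension; a small synthetic example (hypothetical values):

```python
import torch
from bitsandbytes.utils import find_outlier_dims

w = torch.randn(4096, 1024)
w[:, [7, 123]] *= 20.0                       # plant two high-variance feature dimensions
idx = find_outlier_dims(w, reduction_dim=0, zscore=4.0)
print(idx)                                   # expected to contain 7 and 123
```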
|
||||
|
||||
def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
|
||||
"""
|
||||
Replace linear modules with a new Linear module.
|
||||
|
||||
Parameters:
|
||||
model (`torch.nn.Module`):
|
||||
Input model or `torch.nn.Module` as the function is run recursively.
|
||||
linear_replacement (`torch.nn.Module`):
|
||||
The linear module that replaces the old one. Only expects standard arguments.
|
||||
If other arguments need to be passed, use a lambda.
|
||||
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
|
||||
List of modules names not to convert. Defaults to `lm_head`.
|
||||
copy_weights (`bool`):
|
||||
Copy the weights from the old linear module to the new one
|
||||
post_processing_fun_name (`str`):
|
||||
A function name of the replacement linear class that is called
|
||||
after processing.
|
||||
"""
|
||||
for name, module in model.named_children():
|
||||
if len(list(module.children())) > 0:
|
||||
replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)
|
||||
|
||||
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
|
||||
old_module = model._modules[name]
|
||||
model._modules[name] = linear_replacement(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
)
|
||||
if copy_weights:
|
||||
model._modules[name].weight = old_module.weight
|
||||
model._modules[name].bias = old_module.bias
|
||||
|
||||
if post_processing_function is not None:
|
||||
func = getattr(module, post_processing_function, None)
|
||||
if func is not None: func(module)
|
||||
return model
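A usage sketch for `replace_linear`, swapping every `nn.Linear` (except those in `skip_modules`) for one of the FP8 research layers defined earlier (the model here is hypothetical):

```python
import torch
from bitsandbytes.utils import replace_linear
from bitsandbytes.research.nn import LinearFP8Mixed

model = torch.nn.Sequential(
    torch.nn.Linear(512, 2048),
    torch.nn.GELU(),
    torch.nn.Linear(2048, 512),
)
# replaces each nn.Linear in place and copies its weights into the new layer
model = replace_linear(model, LinearFP8Mixed, copy_weights=True)
```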
|
||||
|
||||
|
||||
|
||||
def execute_and_return(command_string: str) -> Tuple[str, str]:
|
||||
def _decode(subprocess_err_out_tuple):
|
||||
|
|
|
@ -1,20 +1,35 @@
|
|||
# Compiling from source
|
||||
|
||||
Basic steps.
|
||||
1. `make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cpuonly`
|
||||
2. `CUDA_VERSION=XXX python setup.py install`
|
||||
1. `CUDA_VERSION=XXX make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cuda12x, cpuonly`
|
||||
2. `python setup.py install`
|
||||
|
||||
To run these steps you will need to have the nvcc compiler installed that comes with a CUDA installation. If you use anaconda (recommended) then you can figure out which version of CUDA you are using with PyTorch via the command `conda list | grep cudatoolkit`. Then you can install the nvcc compiler by downloading and installing the same CUDA version from the [CUDA toolkit archive](https://developer.nvidia.com/cuda-toolkit-archive).
|
||||
|
||||
For your convenience, there is an installation script in the root directory that installs CUDA 11.1 locally and configures it automatically. After installing you should add the `bin` sub-directory to the `$PATH` variable to make the compiler visible to your system. To do this you can add this to your `.bashrc` by executing these commands:
|
||||
You can install CUDA locally without sudo by following the following steps:
|
||||
|
||||
```bash
|
||||
echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64/" >> ~/.bashrc
|
||||
echo "export PATH=$PATH:/usr/local/cuda/bin/" >> ~/.bashrc
|
||||
source ~/.bashrc
|
||||
wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh
|
||||
# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH
|
||||
# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121}
|
||||
# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True
|
||||
|
||||
# For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc
|
||||
bash cuda_install.sh 117 ~/local 1
|
||||
```
|
||||
|
||||
By default, the Makefile will look at your `CUDA_HOME` environmental variable to find your CUDA version for compiling the library. If this path is not set it is inferred from the path of your `nvcc` compiler.
|
||||
|
||||
Either `nvcc` needs to be on your path or the `CUDA_HOME` variable needs to be set to the CUDA directory root (e.g. `/usr/local/cuda`) in order for compilation to succeed.
|
||||
|
||||
If you type `nvcc` and it cannot be found, you might need to add it to your path or set the `CUDA_HOME` variable. You can run `python -m bitsandbytes` to find the path to CUDA. For example if `python -m bitsandbytes` shows you the following:
|
||||
```
|
||||
++++++++++++++++++ /usr/local CUDA PATHS +++++++++++++++++++
|
||||
/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudart.so
|
||||
```
|
||||
You can set `CUDA_HOME` to `/usr/local/cuda-11.7`. For example, you might be able to compile like this.
|
||||
|
||||
``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x``
|
||||
|
||||
|
||||
If you have problems compiling the library with these instructions from source, please open an issue.
|
||||
|
|
|
@ -329,6 +329,13 @@ __device__ unsigned char dQuantizeNF4(float x)
|
|||
else
|
||||
return 0b0000;
|
||||
}
|
||||
// sign function for lion
|
||||
// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
|
||||
|
||||
template <typename T> __device__ int sgn(T val)
|
||||
{
|
||||
return (T(0) < val) - (val < T(0));
|
||||
}
|
||||
|
||||
template <int STOCHASTIC>
|
||||
__device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
|
||||
|
@ -857,7 +864,6 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
|
|||
__syncthreads();
|
||||
LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128);
|
||||
|
||||
|
||||
switch(DATA_TYPE)
|
||||
{
|
||||
case General8bit:
|
||||
|
@ -1081,7 +1087,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
|
|||
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
|
||||
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
|
||||
float* state1, float *unorm,
|
||||
const float beta1, const float eps, const float weight_decay,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, const int n)
|
||||
{
|
||||
|
||||
|
@ -1128,6 +1134,9 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
|
|||
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update
|
||||
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
|
||||
break;
|
||||
case LION:
|
||||
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update
|
||||
break;
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
|
||||
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
|
||||
|
@ -1159,7 +1168,7 @@ template<typename T, int OPTIMIZER>
|
|||
__launch_bounds__(TH, 1)
|
||||
__global__ void kOptimizer32bit1State(T *g, T *p,
|
||||
float *state1, float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1, const float eps, const float weight_decay,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
|
||||
{
|
||||
|
||||
|
@ -1228,6 +1237,10 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
|
|||
|
||||
p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
|
||||
break;
|
||||
case LION:
|
||||
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
|
||||
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j]));
|
||||
break;
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
|
||||
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
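The LION branches added across these kernels implement the Lion update. In plain Python, the per-parameter rule (with momentum `m`, betas `beta1`/`beta2`, weight decay `wd`) is roughly the following; this is a sketch of the math, not the library's optimizer API:

```python
import torch

def lion_step(p, grad, m, lr, beta1, beta2, wd):
    # decoupled weight decay (the LION case above scales p by (1 - lr*wd))
    p.mul_(1.0 - lr * wd)
    # the parameter step uses only the sign of the beta1-interpolated gradient
    p.add_(torch.sign(m * beta1 + (1.0 - beta1) * grad), alpha=-lr)
    # the momentum is then updated with beta2
    m.mul_(beta2).add_(grad, alpha=1.0 - beta2)
    return p, m
```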
|
||||
|
@ -1496,7 +1509,7 @@ __global__ void
|
|||
__launch_bounds__(NUM_THREADS, 2)
|
||||
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
|
||||
float *unorm,
|
||||
const float beta1,
|
||||
const float beta1, const float beta2,
|
||||
const float eps, const int step,
|
||||
float* __restrict__ const quantiles1,
|
||||
float* max1, float* new_max1,
|
||||
|
@ -1557,6 +1570,9 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
|
|||
if(unorm != NULL)
|
||||
local_unorm += s1_vals[j]*s1_vals[j];
|
||||
break;
|
||||
case LION:
|
||||
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
|
||||
break;
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
|
||||
break;
|
||||
|
@ -1580,9 +1596,10 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
|
|||
|
||||
template<typename T, int OPTIMIZER>
|
||||
__global__ void
|
||||
__launch_bounds__(1024, 1)
|
||||
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
||||
const float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1,
|
||||
const float beta1, const float beta2,
|
||||
const float eps, const int step, const float lr,
|
||||
float* __restrict__ const quantiles1,
|
||||
float* max1, float* new_max1,
|
||||
|
@ -1645,8 +1662,19 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
|||
{
|
||||
g_val = float(g_vals[j]);
|
||||
g_val *= gnorm_scale;
|
||||
if(weight_decay > 0.0f)
|
||||
g_val += ((float)p_vals[j])*weight_decay;
|
||||
|
||||
if(weight_decay > 0.0f) {
|
||||
switch(OPTIMIZER) {
|
||||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
g_val += ((float)p_vals[j])*weight_decay;
|
||||
break;
|
||||
case LION:
|
||||
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];
|
||||
|
||||
switch(OPTIMIZER)
|
||||
|
@ -1659,6 +1687,10 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
|||
|
||||
p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
|
||||
break;
|
||||
case LION:
|
||||
p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
|
||||
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
|
||||
break;
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
|
||||
p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
|
||||
|
@ -1997,10 +2029,20 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
{
|
||||
g_val = float(g_vals[j]);
|
||||
g_val *= gnorm_scale;
|
||||
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
|
||||
{
|
||||
if(weight_decay > 0.0f)
|
||||
g_val += ((float)p_vals[j])*weight_decay;
|
||||
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
|
||||
{
|
||||
if(weight_decay > 0.0f) {
|
||||
switch(OPTIMIZER) {
|
||||
case MOMENTUM:
|
||||
case ADAGRAD:
|
||||
case RMSPROP:
|
||||
g_val += ((float)p_vals[j])*weight_decay;
|
||||
break;
|
||||
case LION:
|
||||
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
|
||||
|
||||
|
@ -2012,6 +2054,11 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
else
|
||||
s1_vals[j] = (s1_vals[j]*beta1) + g_val;
|
||||
break;
|
||||
case LION:
|
||||
// here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2
|
||||
g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val));
|
||||
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
|
||||
break;
|
||||
case RMSPROP:
|
||||
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
|
||||
break;
|
||||
|
@ -2049,6 +2096,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
case MOMENTUM:
|
||||
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
|
||||
break;
|
||||
case LION:
|
||||
p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
|
||||
break;
|
||||
case RMSPROP:
|
||||
g_val = g_vals[j];
|
||||
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
|
||||
|
@ -3607,24 +3657,28 @@ template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *c
|
|||
#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
|
||||
template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
|
||||
float* state1, float *unorm, \
|
||||
const float beta1, const float eps, const float weight_decay, \
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay, \
|
||||
const int step, const float lr, const float gnorm_scale, const int n); \
|
||||
|
||||
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
|
||||
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
|
||||
MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
|
||||
MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
|
||||
MAKE_PreconditionOptimizer32bit1State(LION, half)
|
||||
MAKE_PreconditionOptimizer32bit1State(LION, float)
|
||||
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
|
||||
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
|
||||
|
||||
#define MAKE_Optimizer32bit1State(oname, gtype) \
|
||||
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
|
||||
const float beta1, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
|
||||
|
||||
MAKE_Optimizer32bit1State(MOMENTUM, half)
|
||||
MAKE_Optimizer32bit1State(MOMENTUM, float)
|
||||
MAKE_Optimizer32bit1State(RMSPROP, half)
|
||||
MAKE_Optimizer32bit1State(RMSPROP, float)
|
||||
MAKE_Optimizer32bit1State(LION, half)
|
||||
MAKE_Optimizer32bit1State(LION, float)
|
||||
MAKE_Optimizer32bit1State(ADAGRAD, half)
|
||||
MAKE_Optimizer32bit1State(ADAGRAD, float)
|
||||
|
||||
|
@ -3649,6 +3703,7 @@ template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat1
|
|||
template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
|
||||
float *unorm, \
|
||||
const float beta1, \
|
||||
const float beta2, \
|
||||
const float eps, const int step, \
|
||||
float* __restrict__ const quantiles1, \
|
||||
float* max1, float* new_max1, \
|
||||
|
@ -3660,11 +3715,14 @@ MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
|
|||
MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
|
||||
MAKE_PreconditionStatic8bit1State(RMSPROP, half)
|
||||
MAKE_PreconditionStatic8bit1State(RMSPROP, float)
|
||||
MAKE_PreconditionStatic8bit1State(LION, half)
|
||||
MAKE_PreconditionStatic8bit1State(LION, float)
|
||||
|
||||
#define MAKE_optimizerStatic8bit1State(oname, gtype) \
|
||||
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
|
||||
const float *unorm, const float max_unorm, const float param_norm, \
|
||||
const float beta1, \
|
||||
const float beta2, \
|
||||
const float eps, const int step, const float lr, \
|
||||
float* __restrict__ const quantiles1, \
|
||||
float* max1, float* new_max1, \
|
||||
|
@ -3676,6 +3734,8 @@ MAKE_optimizerStatic8bit1State(MOMENTUM, half)
|
|||
MAKE_optimizerStatic8bit1State(MOMENTUM, float)
|
||||
MAKE_optimizerStatic8bit1State(RMSPROP, half)
|
||||
MAKE_optimizerStatic8bit1State(RMSPROP, float)
|
||||
MAKE_optimizerStatic8bit1State(LION, half)
|
||||
MAKE_optimizerStatic8bit1State(LION, float)
|
||||
|
||||
#define MAKE_PreconditionStatic8bit2State(oname, gtype) \
|
||||
template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \
|
||||
|
@ -3762,7 +3822,6 @@ template __global__ void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(fl
|
|||
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
|
||||
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
|
||||
|
||||
|
||||
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
|
||||
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
|
||||
const float beta1, const float beta2, \
|
||||
|
@ -3791,5 +3850,7 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
|
|||
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
|
||||
|
|
|
@ -34,20 +34,20 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
|
|||
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
|
||||
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
|
||||
float* state1, float *unorm,
|
||||
const float beta1, const float eps, const float weight_decay,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER>
|
||||
__global__ void kOptimizer32bit1State(T* g, T* p,
|
||||
float* state1, float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1, const float eps, const float weight_decay,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER>
|
||||
__global__ void
|
||||
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
|
||||
float *unorm,
|
||||
const float beta1,
|
||||
const float beta1, const float beta2,
|
||||
const float eps, const int step,
|
||||
float* __restrict__ const quantiles1,
|
||||
float* max1, float* new_max1,
|
||||
|
@ -59,7 +59,7 @@ template<typename T, int OPTIMIZER>
|
|||
__global__ void
|
||||
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
||||
const float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1,
|
||||
const float beta1, const float beta2,
|
||||
const float eps, const int step, const float lr,
|
||||
float* __restrict__ const quantiles1,
|
||||
float* max1, float* new_max1,
|
||||
|
|
40
csrc/ops.cu
40
csrc/ops.cu
|
@ -54,8 +54,6 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa
|
|||
{
|
||||
int num_blocks = n/blocksize;
|
||||
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
|
||||
if(STOCHASTIC == 1)
|
||||
assert(blocksize == 4096);
|
||||
|
||||
if(blocksize == 4096)
|
||||
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, 0><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
|
@ -121,17 +119,28 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
|||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
|
||||
if(max_unorm > 0.0f)
|
||||
{
|
||||
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
case LION:
|
||||
// in lion, the momentum update happens after the parameter update
|
||||
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
|
||||
if(max_unorm > 0.0f)
|
||||
{
|
||||
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
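Because Lion refreshes the momentum only after the parameter update, the host dispatch above launches the main update kernel first and computes the update norm (`unorm`) afterwards, the reverse of the Momentum/RMSprop/Adagrad branch. A rough sketch of that control flow; the two callables stand in for the CUDA kernel launches and are placeholders, not real API:

```python
def dispatch_1state(optimizer, max_unorm, launch_precondition, launch_update):
    # sketch of the launch order in the 1-state branch of optimizer32bit
    if optimizer in ("momentum", "rmsprop", "adagrad"):
        if max_unorm > 0.0:
            launch_precondition()   # update norm computed before the update
        launch_update()
    elif optimizer == "lion":
        launch_update()             # parameter update reads the old momentum
        if max_unorm > 0.0:
            launch_precondition()   # norm computed on the refreshed momentum
```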
|
@ -165,12 +174,22 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
|
|||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
|
||||
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
|
||||
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
case LION:
|
||||
// in lion, the momentum update happens after the parameter update
|
||||
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
|
||||
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -199,6 +218,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
|
|||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
case LION:
|
||||
num_blocks = n/BLOCKSIZE_1STATE;
|
||||
num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
|
||||
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
|
||||
|
@ -780,6 +800,8 @@ MAKE_optimizer32bit(MOMENTUM, half)
|
|||
MAKE_optimizer32bit(MOMENTUM, float)
|
||||
MAKE_optimizer32bit(RMSPROP, half)
|
||||
MAKE_optimizer32bit(RMSPROP, float)
|
||||
MAKE_optimizer32bit(LION, half)
|
||||
MAKE_optimizer32bit(LION, float)
|
||||
MAKE_optimizer32bit(ADAGRAD, half)
|
||||
MAKE_optimizer32bit(ADAGRAD, float)
|
||||
|
||||
|
@ -799,6 +821,8 @@ MAKE_optimizerStatic8bit(MOMENTUM, half)
|
|||
MAKE_optimizerStatic8bit(MOMENTUM, float)
|
||||
MAKE_optimizerStatic8bit(RMSPROP, half)
|
||||
MAKE_optimizerStatic8bit(RMSPROP, float)
|
||||
MAKE_optimizerStatic8bit(LION, half)
|
||||
MAKE_optimizerStatic8bit(LION, float)
|
||||
|
||||
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
|
||||
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
|
||||
|
@ -811,6 +835,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
|
|||
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
|
||||
MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
|
||||
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
|
||||
MAKE_optimizerStatic8bitBlockwise(half, LION);
|
||||
MAKE_optimizerStatic8bitBlockwise(float, LION);
|
||||
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
|
||||
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
|
||||
|
||||
|
|
|
@ -75,6 +75,7 @@ typedef enum Optimizer_t
|
|||
RMSPROP = 2,
|
||||
LARS = 3,
|
||||
ADAGRAD = 4,
|
||||
LION = 5,
|
||||
} Optimizer_t;
|
||||
|
||||
typedef enum Transform_t
|
||||
|
|
|
@ -38,7 +38,7 @@ MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
|
|||
|
||||
|
||||
#define MAKE_FUNC32(fname, oname, gtype, gbits) \
|
||||
void fname##32bit_g##gbits(gtype *g, gtype *p, \
|
||||
void fname##32bit_grad_##gbits(gtype *g, gtype *p, \
|
||||
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay, \
|
||||
const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \
|
||||
|
@ -51,11 +51,13 @@ MAKE_FUNC32(adam, ADAM, half, fp16)
|
|||
MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16)
|
||||
MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
|
||||
MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
|
||||
MAKE_FUNC32(lion, LION, float, 32)
|
||||
MAKE_FUNC32(lion, LION, half, 16)
|
||||
MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
|
||||
MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
|
||||
|
||||
#define MAKE_FUNC8(fname, oname, gtype, gbits) \
|
||||
void fname##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
|
||||
void fname##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
|
||||
float *unorm, float max_unorm, float param_norm, \
|
||||
float beta1, float beta2, \
|
||||
float eps, int step, float lr, \
|
||||
|
@ -73,9 +75,11 @@ MAKE_FUNC8(momentum, MOMENTUM, float, 32)
|
|||
MAKE_FUNC8(momentum, MOMENTUM, half, 16)
|
||||
MAKE_FUNC8(rmsprop, RMSPROP, float, 32)
|
||||
MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
|
||||
MAKE_FUNC8(lion, LION, float, 32)
|
||||
MAKE_FUNC8(lion, LION, half, 16)
|
||||
|
||||
#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
|
||||
void fname##_8bit_blockwise_##gbits(gtype* p, gtype* g, \
|
||||
void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
|
||||
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
|
||||
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
|
||||
{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
|
||||
|
@ -89,6 +93,8 @@ MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32)
|
|||
MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16)
|
||||
MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32)
|
||||
MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
|
||||
MAKE_BLOCKWISE8(lion, LION, half, fp16)
|
||||
MAKE_BLOCKWISE8(lion, LION, float, fp32)
|
||||
|
||||
|
||||
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
|
||||
|
@ -96,8 +102,6 @@ void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){
|
|||
|
||||
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1, General8bit>(code, A, absmax, out, rand, rand_offset, 4096, n); }
|
||||
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1, General8bit>(code, A, absmax, out, rand, rand_offset, 4096, n); }
|
||||
void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
|
@ -110,6 +114,7 @@ void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax,
|
|||
void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n); } \
|
||||
void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n); }
|
||||
|
||||
|
||||
#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
|
||||
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
|
||||
{ \
|
||||
|
@ -169,8 +174,6 @@ extern "C"
|
|||
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
|
||||
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); }
|
||||
void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); }
|
||||
|
||||
void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
||||
|
@ -185,11 +188,11 @@ extern "C"
|
|||
void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
|
||||
|
||||
#define MAKE_CFUNC32(name, gtype, gbits) \
|
||||
void c##name##32bit_g##gbits(gtype *g, gtype *p, \
|
||||
void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \
|
||||
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay, \
|
||||
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
|
||||
{ name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
|
||||
{ name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
|
||||
|
||||
MAKE_CFUNC32(adam, float, fp32)
|
||||
MAKE_CFUNC32(adam, half, fp16)
|
||||
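The `MAKE_CFUNC32` rename changes the exported C symbols from the old `c<name>32bit_g<bits>` pattern to `c<name>32bit_grad_<bits>`; the expansions around this macro yield names such as `cadam32bit_grad_fp32` and `clion32bit_grad_16`. A hypothetical sketch of resolving one of these symbols from Python with ctypes; the shared-library filename is an assumption, only the symbol name follows from the macro:

```python
import ctypes

# hypothetical: the .so name depends on how the library was built
lib = ctypes.cdll.LoadLibrary("./bitsandbytes/libbitsandbytes_cuda117.so")
# symbol name as produced by MAKE_CFUNC32(lion, half, 16) after the rename
clion32bit_grad_16 = getattr(lib, "clion32bit_grad_16")
```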
|
@ -198,11 +201,13 @@ extern "C"
|
|||
MAKE_CFUNC32(momentum, half, 16)
|
||||
MAKE_CFUNC32(rmsprop, float, 32)
|
||||
MAKE_CFUNC32(rmsprop, half, 16)
|
||||
MAKE_CFUNC32(lion, float, 32)
|
||||
MAKE_CFUNC32(lion, half, 16)
|
||||
MAKE_CFUNC32(adagrad, float, 32)
|
||||
MAKE_CFUNC32(adagrad, half, 16)
|
||||
|
||||
#define MAKE_CFUNC8(name, gtype, gbits) \
|
||||
void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
|
||||
void c##name##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
|
||||
float *unorm, float max_unorm, float param_norm, \
|
||||
float beta1, float beta2, \
|
||||
float eps, int step, float lr, \
|
||||
|
@ -210,7 +215,7 @@ extern "C"
|
|||
float* max1, float* max2, float* new_max1, float* new_max2, \
|
||||
float weight_decay, float gnorm_scale, int n) \
|
||||
{ \
|
||||
name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
|
||||
name##_static_8bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
|
||||
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
|
||||
} \
|
||||
|
||||
|
@ -220,12 +225,14 @@ extern "C"
|
|||
MAKE_CFUNC8(momentum, half, 16)
|
||||
MAKE_CFUNC8(rmsprop, float, 32)
|
||||
MAKE_CFUNC8(rmsprop, half, 16)
|
||||
MAKE_CFUNC8(lion, float, 32)
|
||||
MAKE_CFUNC8(lion, half, 16)
|
||||
|
||||
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
|
||||
void c##fname##_8bit_blockwise_##gbits(gtype* p, gtype* g, \
|
||||
void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
|
||||
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
|
||||
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
|
||||
{ fname##_8bit_blockwise_##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
|
||||
{ fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
|
||||
|
||||
MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
|
||||
MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
|
||||
|
@ -236,6 +243,8 @@ extern "C"
|
|||
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16)
|
||||
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32)
|
||||
MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
|
||||
MAKE_CBLOCKWISE8(lion, LION, half, fp16)
|
||||
MAKE_CBLOCKWISE8(lion, LION, float, fp32)
|
||||
|
||||
void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
|
||||
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
|
||||
|
|
|
@ -12,10 +12,12 @@ URL116=https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installer
|
|||
URL117=https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run
|
||||
URL118=https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run
|
||||
URL120=https://developer.download.nvidia.com/compute/cuda/12.0.0/local_installers/cuda_12.0.0_525.60.13_linux.run
|
||||
URL121=https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run
|
||||
|
||||
|
||||
CUDA_VERSION=$1
|
||||
BASE_PATH=$2
|
||||
EXPORT_BASHRC=$3
|
||||
|
||||
if [[ -n "$CUDA_VERSION" ]]; then
|
||||
if [[ "$CUDA_VERSION" -eq "92" ]]; then
|
||||
|
@ -60,11 +62,14 @@ if [[ -n "$CUDA_VERSION" ]]; then
|
|||
elif [[ "$CUDA_VERSION" -eq "120" ]]; then
|
||||
URL=$URL120
|
||||
FOLDER=cuda-12.0
|
||||
elif [[ "$CUDA_VERSION" -eq "121" ]]; then
|
||||
URL=$URL121
|
||||
FOLDER=cuda-12.1
|
||||
else
|
||||
echo "argument error: No cuda version passed as input. Choose among: {111, 115}"
|
||||
echo "argument error: No cuda version passed as input. Choose among versions 92 to 121"
|
||||
fi
|
||||
else
|
||||
echo "argument error: No cuda version passed as input. Choose among: {111, 115}"
|
||||
echo "argument error: No cuda version passed as input. Choose among versions 92 to 112"
|
||||
fi
|
||||
|
||||
FILE=$(basename $URL)
|
||||
|
@ -72,11 +77,13 @@ FILE=$(basename $URL)
|
|||
if [[ -n "$CUDA_VERSION" ]]; then
|
||||
echo $URL
|
||||
echo $FILE
|
||||
wget $URL
|
||||
#wget $URL
|
||||
bash $FILE --no-drm --no-man-page --override --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent
|
||||
echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64/" >> ~/.bashrc
|
||||
echo "export PATH=$PATH:$BASE_PATH/$FOLDER/bin/" >> ~/.bashrc
|
||||
source ~/.bashrc
|
||||
if [ "$EXPORT_BASHRC" -eq "1" ]; then
|
||||
echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64" >> ~/.bashrc
|
||||
echo "export PATH=\$PATH:$BASE_PATH/$FOLDER/bin" >> ~/.bashrc
|
||||
source ~/.bashrc
|
||||
fi
|
||||
else
|
||||
echo ""
|
||||
fi
|
||||
|
|
24
deploy.sh
24
deploy.sh
|
@ -10,8 +10,8 @@ if [[ ! -z "${LD_LIBRARY_PATH}" ]]; then
|
|||
fi
|
||||
|
||||
|
||||
module unload cuda
|
||||
module unload gcc
|
||||
module unload cuda || echo "no module function available. Probably not on a slurm cluster."
|
||||
module unload gcc || echo "no module function available. Probably not on a slurm cluster."
|
||||
|
||||
rm -rf dist build
|
||||
make cleaneggs
|
||||
|
@ -128,6 +128,16 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120.so" ]; then
|
|||
exit 64
|
||||
fi
|
||||
|
||||
make clean
|
||||
export CUDA_HOME=$BASE_PATH/cuda-12.1
|
||||
make cuda12x CUDA_VERSION=121
|
||||
|
||||
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then
|
||||
# Control will enter here if the compiled .so does not exist.
|
||||
echo "Compilation unsuccessul!" 1>&2
|
||||
exit 64
|
||||
fi
|
||||
|
||||
|
||||
make clean
|
||||
export CUDA_HOME=$BASE_PATH/cuda-10.2
|
||||
|
@ -241,5 +251,15 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120_nocublaslt.so" ]; then
|
|||
exit 64
|
||||
fi
|
||||
|
||||
make clean
|
||||
export CUDA_HOME=$BASE_PATH/cuda-12.1
|
||||
make cuda12x_nomatmul CUDA_VERSION=121
|
||||
|
||||
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121_nocublaslt.so" ]; then
|
||||
# Control will enter here if the compiled .so does not exist.
|
||||
echo "Compilation unsuccessul!" 1>&2
|
||||
exit 64
|
||||
fi
|
||||
|
||||
python -m build
|
||||
python -m twine upload dist/* --verbose
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# No kernel image available
|
||||
|
||||
This problem arises with the cuda version loaded by bitsandbytes is not supported by your GPU, or if you pytorch CUDA version mismatches. So solve this problem you need to debug ``$LD_LIBRARY_PATH``, ``$CUDA_HOME``, ``$PATH``. You can print these via ``echo $PATH``. You should look for multiple paths to different CUDA versions. This can include versions in your anaconda path, for example ``$HOME/anaconda3/lib``. You can check those versions via ``ls -l $HOME/anaconda3/lib/*cuda*`` or equivalent paths. Look at the CUDA versions of files in these paths. Does it match with ``nvidia-smi``?
|
||||
This problem arises when the CUDA version loaded by bitsandbytes is not supported by your GPU, or if your PyTorch CUDA version mismatches. To solve this problem you need to debug ``$LD_LIBRARY_PATH``, ``$CUDA_HOME``, and ``$PATH``. You can print these via ``echo $PATH``. You should look for multiple paths to different CUDA versions. This can include versions in your anaconda path, for example ``$HOME/anaconda3/lib``. You can check those versions via ``ls -l $HOME/anaconda3/lib/*cuda*`` or equivalent paths. Look at the CUDA versions of files in these paths. Do they match with ``nvidia-smi``?
|
||||
|
||||
If you are feeling lucky, you can also try to compile the library from source. This can still be problematic if your PATH variables contain multiple CUDA versions. As such, it is recommended to figure out path conflicts before you proceed with compilation.
|
||||
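A small helper along these lines can speed up that check; it only prints the environment variables and candidate CUDA runtime files mentioned above and makes no assumption about where they live:

```python
import os, glob

# variables that decide which CUDA runtime bitsandbytes ends up loading
for var in ("LD_LIBRARY_PATH", "CUDA_HOME", "PATH", "CONDA_PREFIX"):
    print(f"{var}={os.environ.get(var, '')}")

# candidate libcudart copies on LD_LIBRARY_PATH and in the conda env, if any
candidates = []
for d in os.environ.get("LD_LIBRARY_PATH", "").split(":"):
    if d:
        candidates += glob.glob(os.path.join(d, "*cudart*"))
conda = os.environ.get("CONDA_PREFIX")
if conda:
    candidates += glob.glob(os.path.join(conda, "lib", "*cuda*"))
print("\n".join(sorted(set(candidates))))
# compare the versions in these file names against the driver reported by nvidia-smi
```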
|
||||
|
|
27
examples/int8_inference_huggingface.py
Normal file
27
examples/int8_inference_huggingface.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
MAX_NEW_TOKENS = 128
|
||||
model_name = 'decapoda-research/llama-7b-hf'
|
||||
|
||||
text = 'Hamburg is in which country?\n'
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
input_ids = tokenizer(text, return_tensors="pt").input_ids
|
||||
|
||||
free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)
|
||||
max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
|
||||
|
||||
n_gpus = torch.cuda.device_count()
|
||||
max_memory = {i: max_memory for i in range(n_gpus)}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
device_map='auto',
|
||||
load_in_8bit=True,
|
||||
max_memory=max_memory
|
||||
)
|
||||
generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS)
|
||||
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
|
||||
|
||||
|
||||
|
|
@ -1 +1,2 @@
|
|||
lion-pytorch
|
||||
pytest
|
||||
|
|
2
setup.py
2
setup.py
|
@ -18,7 +18,7 @@ def read(fname):
|
|||
|
||||
setup(
|
||||
name=f"bitsandbytes",
|
||||
version=f"0.37.0",
|
||||
version=f"0.38.1",
|
||||
author="Tim Dettmers",
|
||||
author_email="dettmers@cs.washington.edu",
|
||||
description="8-bit optimizers and matrix multiplication routines.",
|
||||
|
|
|
@ -97,7 +97,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
|||
B.grad = None
|
||||
|
||||
if req_grad[0]:
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
gradA1, gradA2, atol=0.015, rtol=0.1
|
||||
)
|
||||
if req_grad[1]:
|
||||
|
@ -106,7 +106,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
|||
assert (idx == 0).sum().item() < n * 0.1
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
|
||||
assert (idx == 0).sum().item() < n * 0.02
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
gradB1, gradB2, atol=0.18, rtol=0.3
|
||||
)
|
||||
|
||||
|
@ -135,7 +135,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
|||
n = out_bnb.numel()
|
||||
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
|
||||
assert (idx == 0).sum().item() < n * 0.01
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
out_bnb, out_torch, atol=0.027, rtol=0.2
|
||||
)
|
||||
|
||||
|
@ -159,7 +159,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
|||
B.grad = None
|
||||
|
||||
if req_grad[0]:
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
gradA1, gradA2, atol=0.015, rtol=0.1
|
||||
)
|
||||
if req_grad[1]:
|
||||
|
@ -218,7 +218,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
|||
B.grad = None
|
||||
|
||||
if req_grad[0]:
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
gradA1, gradA2, atol=0.015, rtol=0.1
|
||||
)
|
||||
if req_grad[1]:
|
||||
|
@ -239,8 +239,8 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist()
|
|||
dim2.append(0)
|
||||
|
||||
decomp = [0.0, 6.0]
|
||||
funcs = [(torch.matmul, bnb.matmul)]
|
||||
str_funcs = ["matmul"]
|
||||
funcs = [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)]
|
||||
str_funcs = ["matmullt", 'switchback_bnb']
|
||||
req_grad = [(False, False), (True, False), (True, True), (False, True)]
|
||||
req_grad = list(product([True, False], repeat=3))
|
||||
req_grad_str = []
|
||||
|
@ -407,7 +407,7 @@ def test_matmullt(
|
|||
bias.grad = None
|
||||
|
||||
if req_grad[0]:
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
gradA1, gradA2, atol=0.015, rtol=0.1
|
||||
)
|
||||
if req_grad[1]:
|
||||
|
@ -423,12 +423,12 @@ def test_matmullt(
|
|||
assert (idx == 0).sum().item() <= n * 0.1
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
|
||||
assert (idx == 0).sum().item() <= n * 0.02
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
gradB1, gradB2, atol=0.18, rtol=0.3
|
||||
)
|
||||
|
||||
if req_grad[2]:
|
||||
torch.testing.assert_allclose(gradBias1, gradBias2)
|
||||
torch.testing.assert_close(gradBias1, gradBias2)
|
||||
|
||||
|
||||
n = 1
|
||||
|
@ -502,6 +502,7 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
|
|||
if n > 0:
|
||||
assert err < 0.115
|
||||
|
||||
#assert err < 0.20
|
||||
if any(req_grad):
|
||||
out_bnb.data.copy_(out_torch)
|
||||
torch.cuda.synchronize()
|
||||
|
@ -526,7 +527,100 @@ def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
|
|||
bias.grad = None
|
||||
|
||||
if req_grad[0]:
|
||||
torch.testing.assert_allclose( gradA1, gradA2, atol=0.015, rtol=0.1)
|
||||
torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1)
|
||||
|
||||
if req_grad[2]:
|
||||
torch.testing.assert_allclose(gradBias1, gradBias2)
|
||||
torch.testing.assert_close(gradBias1, gradBias2)
|
||||
|
||||
|
||||
funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)]
|
||||
str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global']
|
||||
req_grad = list(product([True, False], repeat=3))
|
||||
req_grad_str = []
|
||||
for c in req_grad:
|
||||
strval = ''
|
||||
for v in c:
|
||||
if v == True: strval += 'T'
|
||||
else: strval += 'F'
|
||||
req_grad_str.append(strval)
|
||||
|
||||
transpose = [(False, True), (False, False)]
|
||||
str_transpose = ["NT", "NN"]
|
||||
dtype = [torch.float16, torch.float32]
|
||||
has_fp16_weights = [True, False]
|
||||
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
|
||||
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose))
|
||||
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values]
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
||||
@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
|
||||
def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
||||
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
|
||||
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
|
||||
req_grad = list(req_grad)
|
||||
req_grad[2] = False
|
||||
|
||||
for i in range(k):
|
||||
# normal multiply
|
||||
if funcs[0] in [torch.mm, torch.matmul]:
|
||||
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
|
||||
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
|
||||
target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
|
||||
|
||||
torch.nn.init.xavier_uniform_(B)
|
||||
|
||||
fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
|
||||
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)
|
||||
|
||||
if not transpose[0] and transpose[1]:
|
||||
out_torch = funcs[0](A, B.t())
|
||||
out_bnb = funcs[1](A, B.t(), fw_code, bw_code)
|
||||
elif not transpose[0] and not transpose[1]:
|
||||
out_torch = funcs[0](A, B)
|
||||
out_bnb = funcs[1](A, B, fw_code, bw_code)
|
||||
|
||||
assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
|
||||
|
||||
n = out_bnb.numel()
|
||||
err = torch.abs(out_bnb - out_torch).float().mean().item()
|
||||
if n > 0:
|
||||
assert err < 0.115
|
||||
#assert err < 0.20
|
||||
if any(req_grad):
|
||||
out_bnb.data.copy_(out_torch)
|
||||
torch.cuda.synchronize()
|
||||
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
|
||||
loss_bnb.backward()
|
||||
gradA1 = A.grad
|
||||
gradB1 = B.grad
|
||||
A.grad = None
|
||||
B.grad = None
|
||||
|
||||
loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean()
|
||||
loss_torch.backward()
|
||||
gradA2 = A.grad
|
||||
gradB2 = B.grad
|
||||
A.grad = None
|
||||
B.grad = None
|
||||
|
||||
if req_grad[0]:
|
||||
torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1)
|
||||
|
||||
if req_grad[1]:
|
||||
n = gradB1.numel()
|
||||
if dim2 > 0:
|
||||
assert torch.abs(gradB1).sum() > 0.0
|
||||
assert torch.abs(gradB2).sum() > 0.0
|
||||
else:
|
||||
assert torch.abs(gradB1).sum() == 0.0
|
||||
assert torch.abs(gradB2).sum() == 0.0
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
|
||||
|
||||
assert (idx == 0).sum().item() <= n * 0.1
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
|
||||
assert (idx == 0).sum().item() <= n * 0.02
|
||||
grad_err = (gradB1-gradB2).abs().mean()
|
||||
assert grad_err.item() < 0.003
|
||||
torch.testing.assert_close(
|
||||
gradB1, gradB2, atol=0.18, rtol=0.3
|
||||
)
|
||||
|
||||
|
|
|
@ -5,95 +5,20 @@ import pytest
|
|||
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.cuda_setup.main import (
|
||||
CUDA_RUNTIME_LIB,
|
||||
determine_cuda_runtime_lib_path,
|
||||
evaluate_cuda_setup,
|
||||
extract_candidate_paths,
|
||||
)
|
||||
|
||||
"""
|
||||
'LD_LIBRARY_PATH': ':/mnt/D/titus/local/cuda-11.1/lib64/'
|
||||
'CONDA_EXE': '/mnt/D/titus/miniconda/bin/conda'
|
||||
'LESSCLOSE': '/usr/bin/lesspipe %s %s'
|
||||
'OLDPWD': '/mnt/D/titus/src'
|
||||
'CONDA_PREFIX': '/mnt/D/titus/miniconda/envs/8-bit'
|
||||
'SSH_AUTH_SOCK': '/mnt/D/titus/.ssh/ssh-agent.tim-uw.sock'
|
||||
'CONDA_PREFIX_1': '/mnt/D/titus/miniconda'
|
||||
'PWD': '/mnt/D/titus/src/8-bit'
|
||||
'HOME': '/mnt/D/titus'
|
||||
'CONDA_PYTHON_EXE': '/mnt/D/titus/miniconda/bin/python'
|
||||
'CUDA_HOME': '/mnt/D/titus/local/cuda-11.1/'
|
||||
'TMUX': '/tmp/tmux-1007/default,59286,1'
|
||||
'XDG_DATA_DIRS': '/usr/local/share:/usr/share:/var/lib/snapd/desktop'
|
||||
'SSH_TTY': '/dev/pts/0'
|
||||
'MAIL': '/var/mail/titus'
|
||||
'SHELL': '/bin/bash'
|
||||
'DBUS_SESSION_BUS_ADDRESS': 'unix:path=/run/user/1007/bus'
|
||||
'XDG_RUNTIME_DIR': '/run/user/1007'
|
||||
'PATH': '/mnt/D/titus/miniconda/envs/8-bit/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin:/mnt/D/titus/local/cuda-11.1/bin'
|
||||
'LESSOPEN': '| /usr/bin/lesspipe %s'
|
||||
'_': '/mnt/D/titus/miniconda/envs/8-bit/bin/python'
|
||||
# any that include 'CONDA' that are not 'CONDA_PREFIX'
|
||||
|
||||
# we search for
|
||||
'CUDA_HOME': '/mnt/D/titus/local/cuda-11.1/'
|
||||
"""
|
||||
|
||||
|
||||
class InputAndExpectedOutput(NamedTuple):
|
||||
input: str
|
||||
output: str
|
||||
|
||||
|
||||
HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
|
||||
(
|
||||
f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
|
||||
f"dir/with/{CUDA_RUNTIME_LIB}",
|
||||
),
|
||||
(
|
||||
f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}",
|
||||
f"dir/with/{CUDA_RUNTIME_LIB}",
|
||||
),
|
||||
(
|
||||
f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:",
|
||||
f"dir/with/{CUDA_RUNTIME_LIB}",
|
||||
),
|
||||
(
|
||||
f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}",
|
||||
f"dir/with/{CUDA_RUNTIME_LIB}",
|
||||
),
|
||||
(
|
||||
f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir",
|
||||
f"dir/with/{CUDA_RUNTIME_LIB}",
|
||||
),
|
||||
(
|
||||
f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
|
||||
f"dir/with/{CUDA_RUNTIME_LIB}",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(params=HAPPY_PATH__LD_LIB_TEST_PATHS)
|
||||
def happy_path_path_string(tmpdir, request):
|
||||
for path in extract_candidate_paths(request.param):
|
||||
test_dir.mkdir()
|
||||
if CUDA_RUNTIME_LIB in path:
|
||||
(test_input / CUDA_RUNTIME_LIB).touch()
|
||||
|
||||
UNHAPPY_PATH__LD_LIB_TEST_PATHS = [
|
||||
f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}",
|
||||
f"a/b/c/{CUDA_RUNTIME_LIB}:d/e/f/{CUDA_RUNTIME_LIB}:g/h/j/{CUDA_RUNTIME_LIB}",
|
||||
]
|
||||
|
||||
|
||||
def test_full_system():
|
||||
def test_cuda_full_system():
|
||||
## this only tests the cuda version and not compute capability
|
||||
|
||||
# if CONDA_PREFIX exists, it has priority before all other env variables
|
||||
# but it does not contain the library directly, so we need to look at a sub-folder
|
||||
version = ""
|
||||
if "CONDA_PREFIX" in os.environ:
|
||||
ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so')
|
||||
ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so.11.0')
|
||||
major, minor, revision = (ls_output.split(" ")[-1].replace("libcudart.so.", "").split("."))
|
||||
version = float(f"{major}.{minor}")
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True):
|
|||
if sumval > count:
|
||||
if throw:
|
||||
print(f"Too many values not close: assert {sumval} < {count}")
|
||||
torch.testing.assert_allclose(a, b, rtol, atol)
|
||||
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
return sumval
|
||||
|
||||
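These test changes migrate from the removed `torch.testing.assert_allclose` to `torch.testing.assert_close`. One detail worth noting: in `assert_close` the tolerances are keyword-only and must be given as a pair, so positional `rtol, atol` arguments have to become keywords. A tiny self-contained example:

```python
import torch

a = torch.randn(4)
b = a + 1e-4 * torch.randn(4)
# rtol/atol are keyword-only in assert_close and must be specified together
torch.testing.assert_close(a, b, rtol=1e-3, atol=1e-3)
```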
|
@ -100,7 +100,7 @@ def test_estimate_quantiles(dtype):
|
|||
code = F.estimate_quantiles(A)
|
||||
|
||||
percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
|
||||
torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)
|
||||
torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)
|
||||
|
||||
A = torch.randn(1024, 1024, device="cuda")
|
||||
A = A.to(dtype)
|
||||
|
@ -125,7 +125,7 @@ def test_quantile_quantization():
|
|||
C = F.quantize_no_absmax(A1, code)
|
||||
A2 = F.dequantize_no_absmax(C, code)
|
||||
diff = torch.abs(A1 - A2).mean().item()
|
||||
torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
|
||||
torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
|
||||
assert diff < 0.001
|
||||
|
||||
|
||||
|
@ -149,7 +149,7 @@ def test_dynamic_quantization():
|
|||
C, S = F.quantize(A1)
|
||||
A2 = F.dequantize(C, S)
|
||||
diff = torch.abs(A1 - A2).mean().item()
|
||||
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
|
||||
torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
|
||||
assert diff < 0.004
|
||||
|
||||
|
||||
|
@ -184,7 +184,7 @@ def test_dynamic_blockwise_quantization(nested, blocksize):
|
|||
reldiff = diff / torch.abs(A1 + 1e-8)
|
||||
diffs.append(diff.mean().item())
|
||||
reldiffs.append(reldiff.mean().item())
|
||||
#torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
|
||||
#torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
|
||||
abserr = sum(diffs)/len(diffs)
|
||||
relerr = sum(reldiffs)/len(reldiffs)
|
||||
assert abserr < 0.0035
|
||||
|
@ -193,22 +193,6 @@ def test_dynamic_blockwise_quantization(nested, blocksize):
|
|||
#print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
|
||||
|
||||
|
||||
def test_dynamic_blockwise_stochastic_quantization():
|
||||
diffs = []
|
||||
reldiffs = []
|
||||
rand = torch.rand(1024).cuda()
|
||||
for i in range(100):
|
||||
A1 = torch.randn(1024, 1024, device="cuda")
|
||||
C1, S1 = F.quantize_blockwise(A1, rand=rand)
|
||||
C2, S2 = F.quantize_blockwise(A1)
|
||||
# a maximum distance of quantized values of 1
|
||||
torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
|
||||
fraction_smaller = (C1 < C2).float().sum() / C1.numel()
|
||||
fraction_larger = (C1 > C2).float().sum() / C1.numel()
|
||||
torch.testing.assert_allclose(
|
||||
fraction_larger, fraction_smaller, atol=0.01, rtol=0
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gtype", [torch.float32, torch.float16], ids=["float", "half"]
|
||||
|
@ -236,9 +220,9 @@ def test_percentile_clipping(gtype):
|
|||
vals, idx = torch.sort(gnorm_vec1)
|
||||
clip1 = vals[percentile]
|
||||
|
||||
torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2))
|
||||
torch.testing.assert_allclose(clip1, clip2)
|
||||
torch.testing.assert_allclose(gnorm1, gnorm2)
|
||||
torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
|
||||
torch.testing.assert_close(clip1, clip2)
|
||||
torch.testing.assert_close(gnorm1, gnorm2)
|
||||
|
||||
|
||||
def quant(x):
|
||||
|
@ -332,7 +316,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
|
|||
B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
|
||||
maxA, Ac = quant_methods[0](A, 1)
|
||||
maxB, Bc = quant_methods[1](B, 0)
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
|
||||
)
|
||||
if batched:
|
||||
|
@ -403,7 +387,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
|
|||
out2 = torch.matmul(A.t().float(), B.t().float())
|
||||
out = F.igemm(A.t(), B.t())
|
||||
|
||||
torch.testing.assert_allclose(out.float(), out2)
|
||||
torch.testing.assert_close(out.float(), out2)
|
||||
|
||||
for i in range(k):
|
||||
shapeA = (batch_dim, seq_dim, hidden_dim)
|
||||
|
@ -421,7 +405,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
|
|||
out2 = torch.matmul(A.float(), B.t().float())
|
||||
out = F.igemm(A, B.t())
|
||||
|
||||
torch.testing.assert_allclose(out.float(), out2)
|
||||
torch.testing.assert_close(out.float(), out2)
|
||||
|
||||
|
||||
n = 3
|
||||
|
@ -452,7 +436,7 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
|
|||
)
|
||||
out = F.igemm(A, B, out=iout)
|
||||
|
||||
torch.testing.assert_allclose(out.float(), out2)
|
||||
torch.testing.assert_close(out.float(), out2)
|
||||
|
||||
|
||||
n = 2
|
||||
|
@ -577,7 +561,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
|
|||
A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
|
||||
)
|
||||
out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
|
||||
torch.testing.assert_allclose(out.float(), out2.float())
|
||||
torch.testing.assert_close(out.float(), out2.float())
|
||||
|
||||
|
||||
n = 1
|
||||
|
@ -635,9 +619,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
|
|||
out, S = F.nvidia_transform(A, to_order=orderOut)
|
||||
|
||||
if orderOut == "row":
|
||||
torch.testing.assert_allclose(A.flatten(), out.flatten())
|
||||
torch.testing.assert_close(A.flatten(), out.flatten())
|
||||
elif orderOut == "col":
|
||||
torch.testing.assert_allclose(A.t().flatten(), out.flatten())
|
||||
torch.testing.assert_close(A.t().flatten(), out.flatten())
|
||||
elif orderOut == "col32":
|
||||
if dims == 2:
|
||||
n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
|
||||
|
@ -670,14 +654,14 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
|
|||
|
||||
assert A.flatten()[i + j] == A[row, col]
|
||||
# assert A.flatten()[i+j] == out.flatten()[row2+col2]
|
||||
# torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
|
||||
# torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
|
||||
# torch.testing.assert_close(A.flatten()[i+j], A[row, col])
|
||||
# torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
|
||||
|
||||
if orderOut == "col32":
|
||||
out2, S = F.nvidia_transform(
|
||||
out, from_order=orderOut, to_order="row", state=S
|
||||
)
|
||||
torch.testing.assert_allclose(A, out2)
|
||||
torch.testing.assert_close(A, out2)
|
||||
|
||||
|
||||
n = 1
|
||||
|
@ -721,7 +705,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
|
|||
B2, SB = F.transform(B, "col_turing")
|
||||
C2, SC = F.igemmlt(A2, B2, SA, SB)
|
||||
C3, S = F.nvidia_transform(C2, "row", state=SC)
|
||||
torch.testing.assert_allclose(C1, C3.float())
|
||||
torch.testing.assert_close(C1, C3.float())
|
||||
|
||||
# transpose
|
||||
B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
|
||||
|
@ -732,7 +716,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
|
|||
B2t, SBt = F.transform(B, "col_turing", transpose=True)
|
||||
C2, SC = F.igemmlt(A2, B2t, SA, SBt)
|
||||
C3, S = F.nvidia_transform(C2, "row", state=SC)
|
||||
torch.testing.assert_allclose(C1, C3.float())
|
||||
torch.testing.assert_close(C1, C3.float())
|
||||
|
||||
|
||||
dim1 = [32]
|
||||
|
@ -778,7 +762,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
|
|||
# print(C1.flatten()[:10])
|
||||
# print(C2.flatten()[:10])
|
||||
|
||||
# torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
|
||||
# torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
|
||||
|
||||
# transpose
|
||||
# B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
|
||||
|
@ -787,7 +771,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
|
|||
# B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
|
||||
# C2, SC = F.igemmlt(A2, B2t, SA, SBt)
|
||||
# C3, S = F.transform(C2, 'row', state=SC)
|
||||
# torch.testing.assert_allclose(C1, C3.float())
|
||||
# torch.testing.assert_close(C1, C3.float())
|
||||
|
||||
|
||||
batch_size = 2
|
||||
|
@ -1006,7 +990,7 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
|
|||
#assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"
|
||||
|
||||
C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
|
||||
#torch.testing.assert_allclose(C5, C4, atol=0.015, rtol=0.1)
|
||||
#torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
|
||||
n = C5.numel()
|
||||
assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n))
|
||||
|
||||
|
@ -1056,16 +1040,16 @@ def test_colrow_absmax(dim1, dim2, dims):
|
|||
)
|
||||
nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)
|
||||
|
||||
torch.testing.assert_allclose(col_stats1_trunc, col_stats2)
|
||||
torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
|
||||
torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)
|
||||
torch.testing.assert_close(col_stats1_trunc, col_stats2)
|
||||
torch.testing.assert_close(row_stats1_trunc, row_stats2)
|
||||
torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2)
|
||||
|
||||
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
|
||||
A, threshold=0.0
|
||||
)
|
||||
|
||||
torch.testing.assert_allclose(col_stats1, col_stats2)
|
||||
torch.testing.assert_allclose(row_stats1, row_stats2)
|
||||
torch.testing.assert_close(col_stats1, col_stats2)
|
||||
torch.testing.assert_close(row_stats1, row_stats2)
|
||||
assert nnz_block_ptr2 is None
|
||||
|
||||
|
||||
|
@ -1089,8 +1073,8 @@ def test_double_quant(dim1, dim2):
|
|||
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
|
||||
|
||||
# max difference is 1 due to rounding differences
|
||||
torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
|
||||
torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)
|
||||
torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
|
||||
torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)
|
||||
|
||||
n = CAt.numel()
|
||||
num_not_close_rows = (
|
||||
|
@ -1113,8 +1097,8 @@ def test_double_quant(dim1, dim2):
|
|||
)
|
||||
assert False
|
||||
|
||||
torch.testing.assert_allclose(Srow.flatten(), statsA)
|
||||
torch.testing.assert_allclose(Scol.flatten(), statsAt)
|
||||
torch.testing.assert_close(Srow.flatten().float(), statsA)
|
||||
torch.testing.assert_close(Scol.flatten().float(), statsAt)
|
||||
|
||||
|
||||
n = 4
|
||||
|
@ -1139,10 +1123,10 @@ def test_integrated_igemmlt(dim1, dim4, inner):
|
|||
A1, maxA = F.vectorwise_quant(A, dim=1)
|
||||
B1, maxB = F.vectorwise_quant(B, dim=1)
|
||||
|
||||
torch.testing.assert_allclose(maxA.flatten(), stats1a)
|
||||
torch.testing.assert_allclose(maxB.flatten(), stats2a)
|
||||
torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
|
||||
torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)
|
||||
torch.testing.assert_close(maxA.flatten().float(), stats1a)
|
||||
torch.testing.assert_close(maxB.flatten().float(), stats2a)
|
||||
torch.testing.assert_close(C1a, A1, rtol=0, atol=1)
|
||||
torch.testing.assert_close(C2a, B1, rtol=0, atol=1)
|
||||
|
||||
A2, SA = F.nvidia_transform(C1a, "col32")
|
||||
B2, SB = F.nvidia_transform(C2a, "col_turing")
|
||||
|
@ -1344,7 +1328,7 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
|
|||
# print(out1)
|
||||
# print(out2)
|
||||
|
||||
torch.testing.assert_allclose(out1, out2)
|
||||
torch.testing.assert_close(out1, out2)
|
||||
|
||||
|
||||
n = 2
|
||||
|
@ -1406,11 +1390,11 @@ def test_coo_double_quant(dim1, dim2):
|
|||
A2[
|
||||
coo_tensor.rowidx.long(), coo_tensor.colidx.long()
|
||||
] = coo_tensor.values
|
||||
torch.testing.assert_allclose(A1, A2)
|
||||
torch.testing.assert_close(A1, A2)
|
||||
|
||||
A1 = A * (idx == 0)
|
||||
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
|
||||
)
|
||||
|
||||
|
@ -1618,7 +1602,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
|
|||
|
||||
idx_col = torch.randint(0, A2.shape[-1], size=(15,))
|
||||
|
||||
# torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)
|
||||
# torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001)
|
||||
|
||||
# Bt = torch.randn(dim2*4, dim2, device='cuda').half()
|
||||
# torch.cuda.synchronize()
|
||||
|
@ -1649,9 +1633,9 @@ def test_coo2csr():
|
|||
counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
|
||||
assert counts.numel() == A.shape[0]
|
||||
|
||||
torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
|
||||
torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
|
||||
idx = A2 != 0
|
||||
torch.testing.assert_allclose(A2[idx], csrA.values)
|
||||
torch.testing.assert_close(A2[idx], csrA.values)
|
||||
|
||||
|
||||
def test_coo2csc():
|
||||
|
@ -1669,10 +1653,10 @@ def test_coo2csc():
|
|||
counts = cscA.colptr[1:] - cscA.colptr[:-1]
|
||||
assert counts.numel() == A.shape[1]
|
||||
|
||||
torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
|
||||
torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
|
||||
# torch uses row-major -> use transpose to transfer to col-major
|
||||
idx = A2.t() != 0
|
||||
torch.testing.assert_allclose(A2.t()[idx], cscA.values)
|
||||
torch.testing.assert_close(A2.t()[idx], cscA.values)
|
||||
|
||||
|
||||
n = 2
|
||||
|
@ -1722,7 +1706,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
|
|||
max_count, max_idx = torch.sort(counts, descending=True)
|
||||
print(torch.median(max_count.float()))
|
||||
|
||||
torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)
|
||||
torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)
|
||||
|
||||
p = 200 / (2048 * 12288 * 4)
|
||||
n = out1.numel()
|
||||
|
@ -1793,13 +1777,13 @@ batch_size = 2
|
|||
seqdim = 2048
|
||||
values = []
|
||||
values.append((batch_size, seqdim, 768, 4 * 768))
|
||||
values.append((batch_size, seqdim, 1024, 4*1024))
|
||||
values.append((batch_size, seqdim, 1536, 4*1536))
|
||||
values.append((batch_size, seqdim, 2048, 4*2048))
|
||||
values.append((batch_size, seqdim, 2560, 4*2560))
|
||||
values.append((batch_size, seqdim, 4096, 4*4096))
|
||||
values.append((batch_size, seqdim, 5140, 4*5140))
|
||||
values.append((batch_size, seqdim, 12288, 4*12288))
|
||||
#values.append((batch_size, seqdim, 1024, 4*1024))
|
||||
#values.append((batch_size, seqdim, 1536, 4*1536))
|
||||
#values.append((batch_size, seqdim, 2048, 4*2048))
|
||||
#values.append((batch_size, seqdim, 2560, 4*2560))
|
||||
#values.append((batch_size, seqdim, 4096, 4*4096))
|
||||
#values.append((batch_size, seqdim, 5140, 4*5140))
|
||||
#values.append((batch_size, seqdim, 12288, 4*12288))
|
||||
names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
|
||||
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
|
||||
def test_bench_matmul(batch, seq, model, hidden):
|
||||
|
@ -2047,7 +2031,7 @@ def test_extract_outliers():
|
|||
assert outliers2.shape[0] == shapeA[0]
|
||||
assert outliers2.shape[1] == idx.numel()
|
||||
|
||||
torch.testing.assert_allclose(outliers1, outliers2)
|
||||
torch.testing.assert_close(outliers1, outliers2)
|
||||
|
||||
CA, SA = F.transform(A, "col_ampere")
|
||||
|
||||
|
@ -2056,7 +2040,7 @@ def test_extract_outliers():
|
|||
assert outliers2.shape[0] == shapeA[0]
|
||||
assert outliers2.shape[1] == idx.numel()
|
||||
|
||||
torch.testing.assert_allclose(outliers1, outliers2)
|
||||
torch.testing.assert_close(outliers1, outliers2)
|
||||
|
||||
|
||||
|
||||
|
@ -2186,7 +2170,7 @@ def test_few_bit_quant():
|
|||
#assert err2.mean() <= err1
|
||||
|
||||
else:
|
||||
torch.testing.assert_allclose(q1, q2)
|
||||
torch.testing.assert_close(q1, q2)
|
||||
#print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
|
||||
#assert False
|
||||
|
||||
|
@@ -2218,7 +2202,9 @@ def test_kbit_quantile_estimation():

def test_bench_dequantization():
a = torch.rand(1024, 1024, device='cuda').half()
qa, SA = F.quantize_blockwise(a)
code =F.create_fp8_map(True, 3, 0, 4).cuda()
qa, SA = F.quantize_blockwise(a, code=code)
print(qa.max())

max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
#print(max_theoretical_mu)

@@ -2489,6 +2475,7 @@ def test_gemm_4bit(dtype):
print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
print(dim, (max_err.item(), max_relerr.item()))

@pytest.mark.skip("Row scale has some bugs for ampere")
def test_managed():
n = 32*10
A = F.get_paged(n, n, dtype=torch.float32)

@@ -2523,4 +2510,4 @@ def test_managed():

# assert (A==17).sum().item() == n*n

# torch.testing.assert_allclose(A, torch.ones(A.shape)*289)
# torch.testing.assert_close(A, torch.ones(A.shape)*289)
@@ -1,11 +1,17 @@
import bitsandbytes as bnb
import os
from contextlib import nullcontext
from itertools import product
from tempfile import TemporaryDirectory

import pytest
import torch
from bitsandbytes import functional as F

import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
from bitsandbytes.nn.modules import Linear8bitLt


# contributed by Alex Borzunov, see:
# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py

@@ -26,6 +32,7 @@ def test_layout_exact_match():
assert restored_x.is_contiguous()
assert torch.all(torch.eq(restored_x, x))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
def test_linear_no_igemmlt():
linear = torch.nn.Linear(1024, 3072)
@@ -43,7 +50,7 @@ def test_linear_no_igemmlt():
linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False
).to(linear.weight.dtype)
linear_custom.bias = linear.bias
linear = linear_custom.cuda()
linear_custom = linear_custom.cuda()
linear = linear.half().cuda()

x_ref = x.clone().cuda().requires_grad_(True)
@@ -59,3 +66,78 @@ def test_linear_no_igemmlt():
assert not linear_custom.state.has_fp16_weights
assert linear_custom.state.CB is not None
assert linear_custom.state.CxB is None


@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt",
list(product([False, True], [False, True], [False, True], [False, True])))
def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt):
linear = torch.nn.Linear(32, 96)
x = torch.randn(3, 32, dtype=torch.half)

linear_custom = Linear8bitLt(
linear.in_features,
linear.out_features,
linear.bias is not None,
has_fp16_weights=has_fp16_weights,
threshold=6.0,
)
if force_no_igemmlt:
linear_custom.state.force_no_igemmlt = True

linear_custom.weight = bnb.nn.Int8Params(
linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights
)
linear_custom.bias = linear.bias
linear_custom = linear_custom.cuda()

if serialize_before_forward:
state_dict_8bit = linear_custom.state_dict()

x_first = x.clone().cuda().requires_grad_(True)
fx_first = linear_custom(x_first).float()
grad_proj = torch.randn_like(fx_first)
(fx_first * grad_proj).mean().backward()

if not serialize_before_forward:
state_dict_8bit = linear_custom.state_dict()

with TemporaryDirectory() as tmpdir:
state_path_8bit = os.path.join(tmpdir, "state_8bit.pth")
state_path = os.path.join(tmpdir, "state.pth")

torch.save(linear.state_dict(), state_path)
torch.save(state_dict_8bit, state_path_8bit)

if not has_fp16_weights:
assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path)

new_state_dict = torch.load(state_path_8bit)

new_linear_custom = Linear8bitLt(
linear.in_features,
linear.out_features,
linear.bias is not None,
has_fp16_weights=has_fp16_weights,
threshold=6.0,
)
if force_no_igemmlt:
new_linear_custom.state.force_no_igemmlt = True

if deserialize_before_cuda:
with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError):
new_linear_custom.load_state_dict(new_state_dict, strict=True)

new_linear_custom = new_linear_custom.cuda()

if not deserialize_before_cuda:
new_linear_custom.load_state_dict(new_state_dict, strict=True)

x_second = x.clone().cuda().requires_grad_(True)
fx_second = new_linear_custom(x_second).float()
(fx_second * grad_proj).mean().backward()

# if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised
if has_fp16_weights or not deserialize_before_cuda:
assert torch.allclose(fx_first, fx_second, atol=1e-5)
assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5)
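The new `test_linear_serialization` exercises saving and loading `Linear8bitLt` weights through an ordinary state dict. A rough sketch of that flow, mirroring the deserialize-after-`.cuda()` path above (the shapes and file name are illustrative, and a CUDA device is assumed):

```python
import torch
import bitsandbytes as bnb

# moving the layer to the GPU quantizes its weights to int8 (has_fp16_weights=False)
layer = bnb.nn.Linear8bitLt(32, 96, bias=True, has_fp16_weights=False, threshold=6.0).cuda()
torch.save(layer.state_dict(), "linear8bit.pth")

# restore into a fresh module: move it to the GPU first, then load the 8-bit state dict;
# loading before .cuda() is expected to raise RuntimeError, as the test above asserts
restored = bnb.nn.Linear8bitLt(32, 96, bias=True, has_fp16_weights=False, threshold=6.0).cuda()
restored.load_state_dict(torch.load("linear8bit.pth"))
```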
@@ -44,7 +44,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
sumval = (idx == 0).sum().item()
if sumval > count:
print(f"Too many values not close: assert {sumval} < {count}")
torch.testing.assert_allclose(a, b, rtol, atol)
torch.testing.assert_close(a, b, rtol, atol)


class LinearFunction(torch.autograd.Function):
@@ -353,6 +353,7 @@ def test_linear8bitlt_accumulated_gradient():
assert l1[0].state.CxB is not None
assert l1[1].state.CxB is not None

print(i)
if i > 0 and i % acc_steps == 0:
opt1.step()
opt1.zero_grad(True)

@@ -368,8 +369,8 @@ def test_linear8bitlt_accumulated_gradient():
l1[0].weight.data.copy_(l2[0].weight.data)
l1[1].weight.data.copy_(l2[1].weight.data)
else:
torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)
torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad)
torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad)


@pytest.mark.parametrize("threshold", [0.0, 2.0])
@@ -478,7 +479,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
scale = grad_ref.abs().mean()

torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
assert (idx == 0).sum().item() <= b1.numel() * 0.005
@@ -559,11 +560,11 @@ def test_kbit_backprop(module):
relerrs2.append(relerr2.mean().item())

if isinstance(module, bnb.nn.Linear8bitLt):
torch.testing.assert_allclose(grad1, grad2, atol=0.008, rtol=0.05)
torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.008, rtol=0.05)
torch.testing.assert_close(grad1, grad2, atol=0.008, rtol=0.05)
torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
else:
torch.testing.assert_allclose(grad1, grad2, atol=0.015, rtol=0.05)
torch.testing.assert_allclose(bgrad1, bgrad2, atol=0.02, rtol=0.05)
torch.testing.assert_close(grad1, grad2, atol=0.015, rtol=0.05)
torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
ref.zero_grad()
kbit.zero_grad()
@@ -574,4 +575,39 @@ def test_kbit_backprop(module):
print('rel out', sum(relerrs1)/len(relerrs1))
print('rel grad', sum(relerrs2)/len(relerrs2))

def test_fp8linear():

b = 10
h = 1024
inp = torch.randn(b, h).cuda()
fp32 = torch.nn.Linear(h, h*2).cuda()
fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda()
fp32b = torch.nn.Linear(h*2, h).cuda()
fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda()

fp8.weight.data.copy_(fp32.weight.data)
fp8.bias.data.copy_(fp32.bias.data)
fp8b.weight.data.copy_(fp32b.weight.data)
fp8b.bias.data.copy_(fp32b.bias.data)

a = fp32b(torch.nn.functional.gelu(fp32(inp)))
b = fp8b(torch.nn.functional.gelu(fp8(inp)))

err = (a-b).abs().mean()

a.mean().backward()
b.mean().backward()

graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean()
bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean()

assert err < 0.05
assert graderr < 0.00002
assert bgraderr < 0.00002
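The new `test_fp8linear` compares the research-only FP8 layer against a plain `torch.nn.Linear`. A minimal usage sketch of that layer (sizes are illustrative; a CUDA device is assumed):

```python
import torch
import bitsandbytes as bnb

# drop-in replacement for torch.nn.Linear from the research namespace used above
fp8 = bnb.research.nn.LinearFP8Mixed(1024, 2048).cuda()

x = torch.randn(10, 1024, device="cuda")
out = fp8(torch.nn.functional.gelu(x))
out.mean().backward()
```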
@@ -7,6 +7,8 @@ from itertools import product
from os.path import join

import pytest
from lion_pytorch import Lion

import torch

import bitsandbytes as bnb

@@ -16,6 +18,13 @@ import bitsandbytes.functional as F

k = 20

def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
idx = torch.isclose(a, b, rtol, atol)
error_count = (idx == 0).sum().item()
if error_count > max_error_count:
print(f"Too many values not close: assert {error_count} < {max_error_count}")
torch.testing.assert_close(a, b, rtol, atol)


def get_temp_dir():
path = f"/tmp/autoswap/{str(uuid.uuid4())}"
@@ -33,6 +42,7 @@ str2optimizers = {}
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
str2optimizers["momentum_pytorch"] = (
None,
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),

@@ -42,6 +52,7 @@ str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["momentum"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),

@@ -51,6 +62,7 @@ str2optimizers["rmsprop"] = (
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
str2optimizers["momentum8bit"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),

@@ -63,6 +75,7 @@ str2optimizers["rmsprop8bit"] = (
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["momentum8bit_blockwise"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),

@@ -76,6 +89,7 @@ str2statenames = {}
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["lion"] = [("exp_avg", "state1")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]
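These tables register the new Lion and paged optimizer variants against their PyTorch references. A rough usage sketch of the bnb side of those pairs (learning rate and sizes are illustrative; a CUDA device is assumed):

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()

# 8-bit Lion with block-wise quantized optimizer state, as registered above
opt = bnb.optim.Lion8bit(model.parameters(), lr=1e-4, block_wise=True)
# paged AdamW variant, also registered above
# opt = bnb.optim.PagedAdamW(model.parameters(), lr=1e-4)

loss = model(torch.randn(16, 1024, device="cuda")).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
```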
@@ -85,14 +99,16 @@ str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"
str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")]
str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]

dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam']
optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion']
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -121,6 +137,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
bnb_optimizer.step()
torch_optimizer.step()


for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_close(
torch_optimizer.state[p1][name1],

@@ -129,7 +146,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
rtol=rtol,
)

torch.testing.assert_close(p1, p2.float(), atol=atol, rtol=rtol)
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 10 errors for Lion
assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10)

if i % (k // 5) == 0 and i > 0:
path = get_temp_dir()

@@ -139,14 +158,15 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path)
torch.testing.assert_close(p1, p2.float(), atol=atol, rtol=rtol)
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 10 errors for Lion
assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10)
for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_close(
torch_optimizer.state[p1][name1],
bnb_optimizer.state[p2][name2],
atol=atol,
rtol=rtol,
)
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 10 errors for Lion
assert_most_approx_close(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2],
atol=atol, rtol=rtol,
max_error_count=10)

if gtype != torch.float32:
# the adam buffers should also be close because they are 32-bit
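The checkpoint round-trip in this hunk relies on the standard optimizer `state_dict()` / `load_state_dict()` API, which the bnb optimizers follow. A minimal sketch of that save-and-restore flow (the file name and parameter shapes are illustrative; a CUDA device is assumed):

```python
import torch
import bitsandbytes as bnb

params = [torch.nn.Parameter(torch.randn(64, 64, device="cuda"))]
opt = bnb.optim.Adam8bit(params)

# ... a few training steps, then checkpoint the optimizer state ...
torch.save(opt.state_dict(), "opt.pt")

# later: rebuild an optimizer over the same parameters and restore its state
opt = bnb.optim.Adam8bit(params)
opt.load_state_dict(torch.load("opt.pt"))
```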
@@ -218,9 +238,11 @@ dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16, torch.bfloat16]
optimizer_names = [
"adam8bit",
"lion8bit",
"momentum8bit",
"rmsprop8bit",
"adam8bit_blockwise",
"lion8bit_blockwise",
"momentum8bit_blockwise",
"rmsprop8bit_blockwise",
]
@@ -264,7 +286,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
bnb_optimizer.step()
torch_optimizer.step()

torch.testing.assert_close(p1, p2.float(), atol=patol, rtol=prtol)
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 5 errors for Lion
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5)

dequant_states = []
for name1, name2, qmap, max_val in str2statenames[optim_name]:

@@ -292,7 +316,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
dequant_states.append(s1.clone())

err = torch.abs(p1 - p2)
relerr = err / torch.abs(p1)
relerr = err / (torch.abs(p1)+1e-9)
if g.dtype == torch.bfloat16:
assert err.mean() < 0.00015
assert relerr.mean() < 0.0016

@@ -338,7 +362,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):

num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0)
assert num_not_close.sum().item() < 20
torch.testing.assert_close(p1, p2.float(), atol=patol, rtol=prtol)
# since Lion can have pretty noisy updates where things lie at the boundary
# allow up to 5 errors for Lion
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5)

# the parameters diverge quickly. Here we keep them close
# together so we can test against the Adam error
@@ -491,7 +517,7 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
print(optim_name, gtype, s / params)
# assert s < 3.9

dim1 = [10*1024]
dim1 = [2*1024]
gtype = [torch.float16]
#mode = ['torch', 'bnb']
mode = ['bnb']
59
tests/test_triton.py
Normal file
@@ -0,0 +1,59 @@
import pytest
import torch

from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
from bitsandbytes.nn import Linear8bitLt

@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
@pytest.mark.parametrize("vector_wise_quantization", [False, True])
def test_switchback(vector_wise_quantization):
for dim in [83]:
for batch in [13]:

standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half()
baseline = Linear8bitLt(dim, 4 * dim).cuda().half()
switchback.weight.data.copy_(standard.weight)
switchback.bias.data.copy_(standard.bias)
baseline.weight.data.copy_(standard.weight)
baseline.bias.data.copy_(standard.bias)

x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True)
x2 = x1.clone().detach().requires_grad_(True)
x3 = x1.clone().detach().requires_grad_(True)

out_standard = standard(x1)
(2**10 * out_standard.abs().mean()).backward()

print(x2.dtype)
out_sb = switchback(x2)
(2**10 * out_sb.abs().mean()).backward()

out_baseline = baseline(x3)
(2**10 * out_baseline.abs().mean()).backward()

err_sb = (out_standard - out_sb).abs().mean()
err_baseline = (out_standard - out_baseline).abs().mean()
print('OUT', err_sb, err_baseline)
assert err_sb < 2 * err_baseline

err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean()
err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean()

print('GW2', err_sb, err_baseline)
assert err_sb < 2 * err_baseline

err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean()
err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean()

print('GW1', err_sb, err_baseline)
assert err_sb < 2 * err_baseline

err_sb = (x1.grad - x2.grad).abs().mean()
err_baseline = (x1.grad - x3.grad).abs().mean()

print('GX1', err_sb, err_baseline)
assert err_sb < 2 * err_baseline
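The new Triton test compares `SwitchBackLinear` against both a standard linear layer and `Linear8bitLt`. A short usage sketch of the SwitchBack layer (sizes are illustrative; Triton and a GPU with compute capability >= 8.0 are assumed, as in the skip condition above):

```python
import torch
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear

# drop-in for torch.nn.Linear; runs in fp16 with int8 matmuls internally
layer = SwitchBackLinear(512, 2048, vector_wise_quantization=True).cuda().half()

x = torch.randn(13, 512, device="cuda", dtype=torch.half, requires_grad=True)
y = layer(x)
y.mean().backward()
```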