Merge branch 'main' into fix/libcuda-to-torch

commit a24aae30bf

26  CHANGELOG.md
@@ -221,3 +221,29 @@ Improvements:

Deprecated:

- Devices with compute capability 3.0 (GTX 700s, K10) and 3.2 (Tegra K1, Jetson TK1) are now deprecated, and support will be removed in 0.39.0.
- Support for CUDA 10.0 and 10.2 will be removed in bitsandbytes 0.39.0.


### 0.38.1

Features:

- Added Int8 SwitchBack layers (see the usage sketch below).
- Added Fake FP8 layers for research purposes (available under `bnb.research.nn. ...`).
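A minimal usage sketch for the SwitchBack layers. The export path `bnb.nn.SwitchBackLinear` is an assumption (not confirmed by this diff; check your installed version), and a CUDA GPU plus Triton are required:

```python
# Hedged sketch: swap a standard nn.Linear for an Int8 SwitchBack layer.
# `SwitchBackLinear` is assumed to be exported under bitsandbytes.nn.
import torch
import bitsandbytes as bnb

lin = bnb.nn.SwitchBackLinear(1024, 4096).cuda().half()  # drop-in for nn.Linear(1024, 4096)
x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
y = lin(x)  # forward/backward use int8 matmuls with row-wise/global quantization
```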
### 0.39.0

Features:

- 4-bit matrix multiplication for Float4 and NormalFloat4 data types (see the usage sketch below).
- Added 4-bit quantization routines.
- Double quantization routines for 4-bit quantization.
- Paged optimizers for Adam and Lion.
- bfloat16 gradient / weight support for Adam and Lion with 8 or 32-bit states.
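A hedged sketch of how the new 4-bit routines and paged optimizers might be used, assuming they are exposed as `bitsandbytes.functional.quantize_4bit`/`dequantize_4bit` and `bnb.optim.PagedAdamW` (names are assumptions, not confirmed by this diff):

```python
# Hedged sketch of the 0.39.0 features; API names are assumptions.
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

w = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")

# 4-bit (NormalFloat4) quantization and dequantization of a weight tensor.
w_4bit, quant_state = F.quantize_4bit(w, quant_type="nf4")
w_restored = F.dequantize_4bit(w_4bit, quant_state)

# Paged Adam-style optimizer: optimizer state can be paged out under GPU
# memory pressure instead of causing an out-of-memory error.
model = torch.nn.Linear(4096, 4096).cuda()
opt = bnb.optim.PagedAdamW(model.parameters(), lr=1e-4)
```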

Bug fixes:

- Fixed a bug where 8-bit models consumed twice the expected memory after serialization.

Deprecated:

- Kepler binaries (GTX 700s and Tesla K40/K80) are no longer provided via pip and need to be compiled from source. Kepler support might be fully removed in the future.

42  Makefile
@@ -2,6 +2,7 @@ MKFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST)))
ROOT_DIR := $(patsubst %/,%,$(dir $(MKFILE_PATH)))

GPP:= /usr/bin/g++
#GPP:= /sw/gcc/11.2.0/bin/g++
ifeq ($(CUDA_HOME),)
CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev)
endif
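As an aside, the `which nvcc | rev | cut -d'/' -f3- | rev` fallback above simply strips the trailing `/bin/nvcc` from the path of the nvcc binary. A small illustrative sketch in Python (not part of the Makefile):

```python
# Illustration only: equivalent of `which nvcc | rev | cut -d'/' -f3- | rev`.
import os
import shutil

nvcc = shutil.which("nvcc")                              # e.g. /usr/local/cuda-11.8/bin/nvcc
if nvcc is not None:
    cuda_home = os.path.dirname(os.path.dirname(nvcc))   # drop "bin/nvcc"
    print(cuda_home)                                      # e.g. /usr/local/cuda-11.8
```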
@@ -12,6 +13,7 @@ CUDA_VERSION:=
endif

NVCC := $(CUDA_HOME)/bin/nvcc

###########################################
@@ -23,8 +25,7 @@ FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu
FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c

INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include
INCLUDE_10x := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcusparse -L $(CONDA_PREFIX)/lib

# NVIDIA NVCC compilation flags
COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
@@ -32,17 +33,11 @@ COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta

CC_KEPLER := -gencode arch=compute_35,code=sm_35 # Kepler
CC_KEPLER += -gencode arch=compute_37,code=sm_37 # Kepler

# Later versions of CUDA support the new architectures
CC_CUDA10x += -gencode arch=compute_75,code=sm_75

CC_CUDA110 := -gencode arch=compute_75,code=sm_75
CC_CUDA110 += -gencode arch=compute_80,code=sm_80

CC_CUDA11x := -gencode arch=compute_75,code=sm_75
CC_CUDA11x += -gencode arch=compute_80,code=sm_80
CC_CUDA11x += -gencode arch=compute_86,code=sm_86
@@ -59,31 +54,32 @@ CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90


all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
	$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
	$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
all: $(BUILD_DIR) env
	$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
	$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)

cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)

cuda10x_nomatmul: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE_10x) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)

cuda110_nomatmul: $(BUILD_DIR) env
cuda110_nomatmul_kepler: $(BUILD_DIR) env
	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)

cuda11x_nomatmul: $(BUILD_DIR) env
cuda11x_nomatmul_kepler: $(BUILD_DIR) env
	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)


cuda110_nomatmul: $(BUILD_DIR) env
	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)

cuda11x_nomatmul: $(BUILD_DIR) env
	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)

cuda12x_nomatmul: $(BUILD_DIR) env
	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
	$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
4  benchmarking/switchback/README.md (new file)

@@ -0,0 +1,4 @@
Steps:

1. Run `python speed_benchmark/speed_benchmark.py`, which times the operations and writes their timings to `speed_benchmark/info_a100_py2.jsonl` (change the jsonl name to a different file for your own profiling runs).
2. Run `python speed_benchmark/make_plot_with_jsonl.py`, which produces `speed_benchmark/plot_with_info.pdf`. Again, make sure you change the jsonl that is being processed.
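A short sketch (not part of this commit) of how the resulting jsonl can be inspected; the column names below are taken from the records in `info_a100_py2.jsonl`:

```python
# Sketch: summarize the benchmark jsonl produced by speed_benchmark.py.
import pandas as pd

df = pd.read_json("speed_benchmark/info_a100_py2.jsonl", lines=True)

# Each record stores per-op times in ms plus aggregate totals.
cols = ["batch_size", "dim_in", "dim_out", "time_standard", "time_rowwise", "time_global"]
summary = df[cols].copy()
summary["speedup_global_%"] = (1 - summary.time_global / summary.time_standard) * 100
print(summary.sort_values("speedup_global_%", ascending=False).head())
```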
60  benchmarking/switchback/info_a100_py2.jsonl (new file)

@@ -0,0 +1,60 @@
{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.28139352798461914, "standard_gw": 0.2811811864376068, "standard_gx": 0.30258670449256897, "rowwise_fwd": 0.1994594931602478, "rowwise_bwd": 0.16159191727638245, "global_fwd": 0.19502267241477966, "global_bwd": 0.16080215573310852, "x_quantize_rowwise": 0.03306940197944641, "g_quantize_rowwise": 0.08210167288780212, "w_quantize_rowwise": 0.03385916352272034, "w_quantize_colwise_transpose": 0.08635595440864563, "w_quantize_global": 0.09237229824066162, "w_quantize_global_transpose": 0.10007619857788086, "time_standard": 0.8651614189147949, "time_rowwise": 0.8776187896728516, "time_global": 0.944625586271286}
{"repeat": 64, "batch_size": 8192, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.262625515460968, "standard_gw": 0.2806223928928375, "standard_gx": 0.31118839979171753, "rowwise_fwd": 0.1828707754611969, "rowwise_bwd": 0.21236762404441833, "global_fwd": 0.16665831208229065, "global_bwd": 0.19929558038711548, "x_quantize_rowwise": 0.08227676153182983, "g_quantize_rowwise": 0.03310292959213257, "w_quantize_rowwise": 0.032648444175720215, "w_quantize_colwise_transpose": 0.09015202522277832, "w_quantize_global": 0.0988692045211792, "w_quantize_global_transpose": 0.10057538747787476, "time_standard": 0.8544363081455231, "time_rowwise": 0.9140409529209137, "time_global": 0.96140056848526}
{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.5731917917728424, "standard_gw": 0.5709454417228699, "standard_gx": 0.5963630974292755, "rowwise_fwd": 0.37662312388420105, "rowwise_bwd": 0.281747430562973, "global_fwd": 0.36768242716789246, "global_bwd": 0.28043612837791443, "x_quantize_rowwise": 0.046547502279281616, "g_quantize_rowwise": 0.15532970428466797, "w_quantize_rowwise": 0.032436102628707886, "w_quantize_colwise_transpose": 0.08635222911834717, "w_quantize_global": 0.0947415828704834, "w_quantize_global_transpose": 0.10129809379577637, "time_standard": 1.7405003309249878, "time_rowwise": 1.5499815344810486, "time_global": 1.616980880498886}
{"repeat": 64, "batch_size": 16384, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.5341619253158569, "standard_gw": 0.5690865218639374, "standard_gx": 0.599835067987442, "rowwise_fwd": 0.3233291208744049, "rowwise_bwd": 0.41359663009643555, "global_fwd": 0.2831108868122101, "global_bwd": 0.37280842661857605, "x_quantize_rowwise": 0.15563145279884338, "g_quantize_rowwise": 0.046741217374801636, "w_quantize_rowwise": 0.03306940197944641, "w_quantize_colwise_transpose": 0.09020790457725525, "w_quantize_global": 0.0925213098526001, "w_quantize_global_transpose": 0.09945780038833618, "time_standard": 1.7030835151672363, "time_rowwise": 1.6316622495651245, "time_global": 1.6193576157093048}
{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 1.2199915945529938, "standard_gw": 1.1069811880588531, "standard_gx": 1.09761580824852, "rowwise_fwd": 0.738043338060379, "rowwise_bwd": 0.5549229681491852, "global_fwd": 0.7219798862934113, "global_bwd": 0.5512163043022156, "x_quantize_rowwise": 0.08748471736907959, "g_quantize_rowwise": 0.3023110330104828, "w_quantize_rowwise": 0.03182142972946167, "w_quantize_colwise_transpose": 0.08632615208625793, "w_quantize_global": 0.09445473551750183, "w_quantize_global_transpose": 0.10032951831817627, "time_standard": 3.424588590860367, "time_rowwise": 2.9078908264636993, "time_global": 2.9647573828697205}
{"repeat": 64, "batch_size": 32768, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 1.1040829122066498, "standard_gw": 1.1221766471862793, "standard_gx": 1.1548101902008057, "rowwise_fwd": 0.581938773393631, "rowwise_bwd": 0.7480122148990631, "global_fwd": 0.5537159740924835, "global_bwd": 0.7232688367366791, "x_quantize_rowwise": 0.30193477869033813, "g_quantize_rowwise": 0.08745118975639343, "w_quantize_rowwise": 0.03374740481376648, "w_quantize_colwise_transpose": 0.09068101644515991, "w_quantize_global": 0.09645149111747742, "w_quantize_global_transpose": 0.10189786553382874, "time_standard": 3.3810697495937347, "time_rowwise": 2.9659420251846313, "time_global": 2.9868967831134796}
{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 2.4533793330192566, "standard_gw": 2.1938569843769073, "standard_gx": 2.179361879825592, "rowwise_fwd": 1.4615543186664581, "rowwise_bwd": 1.0522231459617615, "global_fwd": 1.4288239181041718, "global_bwd": 1.0450035333633423, "x_quantize_rowwise": 0.1691766083240509, "g_quantize_rowwise": 0.5951300263404846, "w_quantize_rowwise": 0.03337860107421875, "w_quantize_colwise_transpose": 0.08653849363327026, "w_quantize_global": 0.0940859317779541, "w_quantize_global_transpose": 0.09976327419281006, "time_standard": 6.826598197221756, "time_rowwise": 5.5918581783771515, "time_global": 5.625840276479721}
{"repeat": 64, "batch_size": 65536, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 2.1698065102100372, "standard_gw": 2.1875128149986267, "standard_gx": 2.2887587547302246, "rowwise_fwd": 1.0762326419353485, "rowwise_bwd": 1.4638006687164307, "global_fwd": 1.0450668632984161, "global_bwd": 1.4308765530586243, "x_quantize_rowwise": 0.5953535437583923, "g_quantize_rowwise": 0.16899779438972473, "w_quantize_rowwise": 0.03240257501602173, "w_quantize_colwise_transpose": 0.09106099605560303, "w_quantize_global": 0.09546056389808655, "w_quantize_global_transpose": 0.09852275252342224, "time_standard": 6.6460780799388885, "time_rowwise": 5.615361034870148, "time_global": 5.621790885925293}
{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 4.858218133449554, "standard_gw": 4.3631307780742645, "standard_gx": 4.404045641422272, "rowwise_fwd": 2.9063820838928223, "rowwise_bwd": 2.094462513923645, "global_fwd": 2.8426870703697205, "global_bwd": 2.0792782306671143, "x_quantize_rowwise": 0.33241137862205505, "g_quantize_rowwise": 1.1817105114459991, "w_quantize_rowwise": 0.03374367952346802, "w_quantize_colwise_transpose": 0.08633732795715332, "w_quantize_global": 0.09231641888618469, "w_quantize_global_transpose": 0.100012868642807, "time_standard": 13.62539455294609, "time_rowwise": 10.998178273439407, "time_global": 10.991547256708145}
{"repeat": 64, "batch_size": 131072, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 4.246581345796585, "standard_gw": 4.42587211728096, "standard_gx": 4.581417888402939, "rowwise_fwd": 2.1114833652973175, "rowwise_bwd": 2.9050447046756744, "global_fwd": 2.0806826651096344, "global_bwd": 2.85966694355011, "x_quantize_rowwise": 1.1816024780273438, "g_quantize_rowwise": 0.33330172300338745, "w_quantize_rowwise": 0.033445656299591064, "w_quantize_colwise_transpose": 0.09065866470336914, "w_quantize_global": 0.09239837527275085, "w_quantize_global_transpose": 0.09984523057937622, "time_standard": 13.253871351480484, "time_rowwise": 11.081408709287643, "time_global": 11.073369532823563}
{"repeat": 64, "batch_size": 8192, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.4859529435634613, "standard_gw": 0.46338513493537903, "standard_gx": 0.42321905493736267, "rowwise_fwd": 0.2761557698249817, "rowwise_bwd": 0.20775198936462402, "global_fwd": 0.2713911235332489, "global_bwd": 0.20639970898628235, "x_quantize_rowwise": 0.033095479011535645, "g_quantize_rowwise": 0.11894106864929199, "w_quantize_rowwise": 0.03125518560409546, "w_quantize_colwise_transpose": 0.1424551010131836, "w_quantize_global": 0.07288157939910889, "w_quantize_global_transpose": 0.08071959018707275, "time_standard": 1.372557133436203, "time_rowwise": 1.2730397284030914, "time_global": 1.2468136847019196}
{"repeat": 64, "batch_size": 8192, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.3920421004295349, "standard_gw": 0.44424086809158325, "standard_gx": 0.4759356379508972, "rowwise_fwd": 0.23231282830238342, "rowwise_bwd": 0.28430670499801636, "global_fwd": 0.20883232355117798, "global_bwd": 0.2741999924182892, "x_quantize_rowwise": 0.12018159031867981, "g_quantize_rowwise": 0.03195926547050476, "w_quantize_rowwise": 0.026017427444458008, "w_quantize_colwise_transpose": 0.14733895659446716, "w_quantize_global": 0.07734447717666626, "w_quantize_global_transpose": 0.0788569450378418, "time_standard": 1.3122186064720154, "time_rowwise": 1.2863576412200928, "time_global": 1.235615462064743}
{"repeat": 64, "batch_size": 16384, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 1.0111741721630096, "standard_gw": 0.9267590939998627, "standard_gx": 0.8254274725914001, "rowwise_fwd": 0.5434826016426086, "rowwise_bwd": 0.4077926278114319, "global_fwd": 0.5318708717823029, "global_bwd": 0.40537863969802856, "x_quantize_rowwise": 0.059738755226135254, "g_quantize_rowwise": 0.2299174666404724, "w_quantize_rowwise": 0.02545863389968872, "w_quantize_colwise_transpose": 0.14269724488258362, "w_quantize_global": 0.07300823926925659, "w_quantize_global_transpose": 0.07878988981246948, "time_standard": 2.7633607387542725, "time_rowwise": 2.335846424102783, "time_global": 2.305462956428528}
{"repeat": 64, "batch_size": 16384, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.8095316588878632, "standard_gw": 0.8607134222984314, "standard_gx": 0.9204968810081482, "rowwise_fwd": 0.4275888204574585, "rowwise_bwd": 0.5485899746417999, "global_fwd": 0.41000545024871826, "global_bwd": 0.5317628383636475, "x_quantize_rowwise": 0.2301819622516632, "g_quantize_rowwise": 0.059254467487335205, "w_quantize_rowwise": 0.02466142177581787, "w_quantize_colwise_transpose": 0.14865398406982422, "w_quantize_global": 0.07582828402519226, "w_quantize_global_transpose": 0.08231401443481445, "time_standard": 2.5907419621944427, "time_rowwise": 2.2996440529823303, "time_global": 2.2500604391098022}
{"repeat": 64, "batch_size": 32768, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 2.0658522844314575, "standard_gw": 1.718364655971527, "standard_gx": 1.6660578548908234, "rowwise_fwd": 1.066897064447403, "rowwise_bwd": 0.8070804178714752, "global_fwd": 1.0473169386386871, "global_bwd": 0.8021742105484009, "x_quantize_rowwise": 0.11274218559265137, "g_quantize_rowwise": 0.4518181085586548, "w_quantize_rowwise": 0.026501715183258057, "w_quantize_colwise_transpose": 0.14259666204452515, "w_quantize_global": 0.07484853267669678, "w_quantize_global_transpose": 0.07976219058036804, "time_standard": 5.450274795293808, "time_rowwise": 4.326000809669495, "time_global": 4.287026822566986}
{"repeat": 64, "batch_size": 32768, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 2.7549192309379578, "standard_gw": 1.6954988241195679, "standard_gx": 1.8179528415203094, "rowwise_fwd": 0.8649080991744995, "rowwise_bwd": 1.0746456682682037, "global_fwd": 0.8023083209991455, "global_bwd": 1.0471977293491364, "x_quantize_rowwise": 0.45225024223327637, "g_quantize_rowwise": 0.11286512017250061, "w_quantize_rowwise": 0.0252649188041687, "w_quantize_colwise_transpose": 0.14732033014297485, "w_quantize_global": 0.07537379860877991, "w_quantize_global_transpose": 0.0807642936706543, "time_standard": 6.268370896577835, "time_rowwise": 4.372753202915192, "time_global": 4.266258329153061}
{"repeat": 64, "batch_size": 65536, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 4.098430275917053, "standard_gw": 3.3501461148262024, "standard_gx": 5.560480058193207, "rowwise_fwd": 2.112947404384613, "rowwise_bwd": 1.605246216058731, "global_fwd": 2.0697638392448425, "global_bwd": 1.5953518450260162, "x_quantize_rowwise": 0.21921470761299133, "g_quantize_rowwise": 0.8956789970397949, "w_quantize_rowwise": 0.02710893750190735, "w_quantize_colwise_transpose": 0.14268234372138977, "w_quantize_global": 0.07259473204612732, "w_quantize_global_transpose": 0.07899105548858643, "time_standard": 13.009056448936462, "time_rowwise": 8.35302472114563, "time_global": 8.281741291284561}
{"repeat": 64, "batch_size": 65536, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 5.586959421634674, "standard_gw": 3.358360379934311, "standard_gx": 3.6434978246688843, "rowwise_fwd": 1.6269534826278687, "rowwise_bwd": 2.128206193447113, "global_fwd": 1.5950687229633331, "global_bwd": 2.0831897854804993, "x_quantize_rowwise": 0.8954145014286041, "g_quantize_rowwise": 0.21914392709732056, "w_quantize_rowwise": 0.026203691959381104, "w_quantize_colwise_transpose": 0.14658644795417786, "w_quantize_global": 0.07478520274162292, "w_quantize_global_transpose": 0.07964670658111572, "time_standard": 12.58881762623787, "time_rowwise": 8.400868624448776, "time_global": 8.305609226226807}
{"repeat": 64, "batch_size": 131072, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 8.229725062847137, "standard_gw": 6.791356950998306, "standard_gx": 6.806455552577972, "rowwise_fwd": 4.252471029758453, "rowwise_bwd": 3.2062679529190063, "global_fwd": 4.175614565610886, "global_bwd": 3.1837262213230133, "x_quantize_rowwise": 0.4321373999118805, "g_quantize_rowwise": 1.787092536687851, "w_quantize_rowwise": 0.0270158052444458, "w_quantize_colwise_transpose": 0.1424252986907959, "w_quantize_global": 0.07348507642745972, "w_quantize_global_transpose": 0.07829815149307251, "time_standard": 21.827537566423416, "time_rowwise": 16.63876697421074, "time_global": 16.52171090245247}
{"repeat": 64, "batch_size": 131072, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 11.279478669166565, "standard_gw": 6.7345499992370605, "standard_gx": 7.206875830888748, "rowwise_fwd": 3.209315240383148, "rowwise_bwd": 4.256397485733032, "global_fwd": 3.180190920829773, "global_bwd": 4.177983850240707, "x_quantize_rowwise": 1.7836056649684906, "g_quantize_rowwise": 0.4321075975894928, "w_quantize_rowwise": 0.03205239772796631, "w_quantize_colwise_transpose": 0.14675036072731018, "w_quantize_global": 0.09316205978393555, "w_quantize_global_transpose": 0.10086596012115479, "time_standard": 25.220904499292374, "time_rowwise": 16.5947787463665, "time_global": 16.502466052770615}
{"repeat": 64, "batch_size": 8192, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 0.5776733160018921, "standard_gw": 0.5300231277942657, "standard_gx": 0.6005913019180298, "rowwise_fwd": 0.33330172300338745, "rowwise_bwd": 0.2957060933113098, "global_fwd": 0.32876431941986084, "global_bwd": 0.29108673334121704, "x_quantize_rowwise": 0.03466755151748657, "g_quantize_rowwise": 0.12264400720596313, "w_quantize_rowwise": 0.033874064683914185, "w_quantize_colwise_transpose": 0.1775398850440979, "w_quantize_global": 0.09503215551376343, "w_quantize_global_transpose": 0.10617449879646301, "time_standard": 1.7082877457141876, "time_rowwise": 1.5277564525604248, "time_global": 1.5083923935890198}
{"repeat": 64, "batch_size": 8192, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.5164109170436859, "standard_gw": 0.5367249250411987, "standard_gx": 0.5876161158084869, "rowwise_fwd": 0.3132447600364685, "rowwise_bwd": 0.3396235406398773, "global_fwd": 0.2943649888038635, "global_bwd": 0.33209100365638733, "x_quantize_rowwise": 0.12357160449028015, "g_quantize_rowwise": 0.035997480154037476, "w_quantize_rowwise": 0.03213062882423401, "w_quantize_colwise_transpose": 0.17676874995231628, "w_quantize_global": 0.09861215949058533, "w_quantize_global_transpose": 0.0998862087726593, "time_standard": 1.6407519578933716, "time_rowwise": 1.5580616891384125, "time_global": 1.5212483704090118}
{"repeat": 64, "batch_size": 16384, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 1.2096501886844635, "standard_gw": 1.0663382709026337, "standard_gx": 1.0961703956127167, "rowwise_fwd": 0.6396733224391937, "rowwise_bwd": 0.5173943936824799, "global_fwd": 0.6296299397945404, "global_bwd": 0.5130060017108917, "x_quantize_rowwise": 0.06211921572685242, "g_quantize_rowwise": 0.2361498773097992, "w_quantize_rowwise": 0.03260001540184021, "w_quantize_colwise_transpose": 0.17679482698440552, "w_quantize_global": 0.09361281991004944, "w_quantize_global_transpose": 0.09913742542266846, "time_standard": 3.372158855199814, "time_rowwise": 2.7310699224472046, "time_global": 2.6999935507774353}
{"repeat": 64, "batch_size": 16384, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 1.1065565049648285, "standard_gw": 1.0664314031600952, "standard_gx": 1.1266544461250305, "rowwise_fwd": 0.5352050065994263, "rowwise_bwd": 0.6464086472988129, "global_fwd": 0.513765960931778, "global_bwd": 0.6284862756729126, "x_quantize_rowwise": 0.23620948195457458, "g_quantize_rowwise": 0.062271952629089355, "w_quantize_rowwise": 0.031460076570510864, "w_quantize_colwise_transpose": 0.17675384879112244, "w_quantize_global": 0.09486451745033264, "w_quantize_global_transpose": 0.09898096323013306, "time_standard": 3.2996423542499542, "time_rowwise": 2.7547404170036316, "time_global": 2.7010105550289154}
{"repeat": 64, "batch_size": 32768, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 2.4367496371269226, "standard_gw": 2.0806193351745605, "standard_gx": 2.19624862074852, "rowwise_fwd": 1.2554042041301727, "rowwise_bwd": 1.0227933526039124, "global_fwd": 1.2322552502155304, "global_bwd": 1.0152235627174377, "x_quantize_rowwise": 0.11792033910751343, "g_quantize_rowwise": 0.4639364778995514, "w_quantize_rowwise": 0.03241002559661865, "w_quantize_colwise_transpose": 0.17657503485679626, "w_quantize_global": 0.09655207395553589, "w_quantize_global_transpose": 0.09958073496818542, "time_standard": 6.713617593050003, "time_rowwise": 5.149658769369125, "time_global": 5.106087774038315}
{"repeat": 64, "batch_size": 32768, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 2.1935217082500458, "standard_gw": 2.0055584609508514, "standard_gx": 2.1882541477680206, "rowwise_fwd": 1.0396353900432587, "rowwise_bwd": 1.2542344629764557, "global_fwd": 1.0161921381950378, "global_bwd": 1.233428716659546, "x_quantize_rowwise": 0.4642195999622345, "g_quantize_rowwise": 0.11782720685005188, "w_quantize_rowwise": 0.033117830753326416, "w_quantize_colwise_transpose": 0.17696991562843323, "w_quantize_global": 0.09416043758392334, "w_quantize_global_transpose": 0.10101497173309326, "time_standard": 6.387334316968918, "time_rowwise": 5.091562867164612, "time_global": 5.032401531934738}
{"repeat": 64, "batch_size": 65536, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 4.804681986570358, "standard_gw": 4.763372242450714, "standard_gx": 4.064023494720459, "rowwise_fwd": 2.484843134880066, "rowwise_bwd": 1.9691288471221924, "global_fwd": 2.441786229610443, "global_bwd": 1.9574686884880066, "x_quantize_rowwise": 0.2294592559337616, "g_quantize_rowwise": 0.9196549654006958, "w_quantize_rowwise": 0.0313781201839447, "w_quantize_colwise_transpose": 0.1768544316291809, "w_quantize_global": 0.09644776582717896, "w_quantize_global_transpose": 0.09847059845924377, "time_standard": 13.632077723741531, "time_rowwise": 10.574690997600555, "time_global": 10.506659746170044}
{"repeat": 64, "batch_size": 65536, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 4.0907710790634155, "standard_gw": 3.9793066680431366, "standard_gx": 4.302978515625, "rowwise_fwd": 1.992940902709961, "rowwise_bwd": 2.4996213614940643, "global_fwd": 1.9551962614059448, "global_bwd": 2.457551658153534, "x_quantize_rowwise": 0.9200014173984528, "g_quantize_rowwise": 0.2293996512889862, "w_quantize_rowwise": 0.0313781201839447, "w_quantize_colwise_transpose": 0.17882883548736572, "w_quantize_global": 0.09540095925331116, "w_quantize_global_transpose": 0.09880587458610535, "time_standard": 12.373056262731552, "time_rowwise": 9.831476956605911, "time_global": 9.73566249012947}
{"repeat": 64, "batch_size": 131072, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 9.655728936195374, "standard_gw": 8.261296898126602, "standard_gx": 8.064884692430496, "rowwise_fwd": 5.007706582546234, "rowwise_bwd": 3.8615092635154724, "global_fwd": 4.920527338981628, "global_bwd": 3.8330331444740295, "x_quantize_rowwise": 0.45276060700416565, "g_quantize_rowwise": 1.8306002020835876, "w_quantize_rowwise": 0.031366944313049316, "w_quantize_colwise_transpose": 0.1766495406627655, "w_quantize_global": 0.09412690997123718, "w_quantize_global_transpose": 0.09780004620552063, "time_standard": 25.981910526752472, "time_rowwise": 19.621890038251877, "time_global": 19.49014514684677}
{"repeat": 64, "batch_size": 131072, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 8.033104240894318, "standard_gw": 8.2889124751091, "standard_gx": 8.622754365205765, "rowwise_fwd": 3.8747042417526245, "rowwise_bwd": 5.003921687602997, "global_fwd": 3.8315393030643463, "global_bwd": 4.9162134528160095, "x_quantize_rowwise": 1.8304847180843353, "g_quantize_rowwise": 0.4522763192653656, "w_quantize_rowwise": 0.03413110971450806, "w_quantize_colwise_transpose": 0.1771189272403717, "w_quantize_global": 0.09519979357719421, "w_quantize_global_transpose": 0.09930506348609924, "time_standard": 24.944771081209183, "time_rowwise": 19.661549478769302, "time_global": 19.51393112540245}
{"repeat": 64, "batch_size": 8192, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.7954612374305725, "standard_gw": 0.7456131279468536, "standard_gx": 0.8799619972705841, "rowwise_fwd": 0.43267011642456055, "rowwise_bwd": 0.34622475504875183, "global_fwd": 0.42615458369255066, "global_bwd": 0.344250351190567, "x_quantize_rowwise": 0.03748014569282532, "g_quantize_rowwise": 0.13304129242897034, "w_quantize_rowwise": 0.03294646739959717, "w_quantize_colwise_transpose": 0.2407953143119812, "w_quantize_global": 0.094633549451828, "w_quantize_global_transpose": 0.10305643081665039, "time_standard": 2.4210363626480103, "time_rowwise": 1.96877121925354, "time_global": 1.8842294812202454}
{"repeat": 64, "batch_size": 8192, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.7120333611965179, "standard_gw": 0.7622130215167999, "standard_gx": 0.8262209594249725, "rowwise_fwd": 0.3702230751514435, "rowwise_bwd": 0.4419572651386261, "global_fwd": 0.3479123115539551, "global_bwd": 0.4306286573410034, "x_quantize_rowwise": 0.13308599591255188, "g_quantize_rowwise": 0.037495046854019165, "w_quantize_rowwise": 0.03398209810256958, "w_quantize_colwise_transpose": 0.23782625794410706, "w_quantize_global": 0.09853765368461609, "w_quantize_global_transpose": 0.10247156023979187, "time_standard": 2.3004673421382904, "time_rowwise": 2.016782760620117, "time_global": 1.9123442471027374}
{"repeat": 64, "batch_size": 16384, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 1.6292817890644073, "standard_gw": 1.5109702944755554, "standard_gx": 1.482747495174408, "rowwise_fwd": 0.8386112749576569, "rowwise_bwd": 0.6844550371170044, "global_fwd": 0.8220970630645752, "global_bwd": 0.6802082061767578, "x_quantize_rowwise": 0.06883963942527771, "g_quantize_rowwise": 0.25641173124313354, "w_quantize_rowwise": 0.033054500818252563, "w_quantize_colwise_transpose": 0.24027004837989807, "w_quantize_global": 0.0967271625995636, "w_quantize_global_transpose": 0.102948397397995, "time_standard": 4.622999578714371, "time_rowwise": 3.6326125264167786, "time_global": 3.5382024943828583}
{"repeat": 64, "batch_size": 16384, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 1.4877021312713623, "standard_gw": 1.5015341341495514, "standard_gx": 1.529306173324585, "rowwise_fwd": 0.715944916009903, "rowwise_bwd": 0.8529908955097198, "global_fwd": 0.680088996887207, "global_bwd": 0.8224695920944214, "x_quantize_rowwise": 0.2568177878856659, "g_quantize_rowwise": 0.06864592432975769, "w_quantize_rowwise": 0.03343448042869568, "w_quantize_colwise_transpose": 0.23645907640457153, "w_quantize_global": 0.09399279952049255, "w_quantize_global_transpose": 0.10286271572113037, "time_standard": 4.518542438745499, "time_rowwise": 3.665827214717865, "time_global": 3.5264119505882263}
{"repeat": 64, "batch_size": 32768, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 3.261040896177292, "standard_gw": 2.8816498816013336, "standard_gx": 2.8357282280921936, "rowwise_fwd": 1.6594752669334412, "rowwise_bwd": 1.359265297651291, "global_fwd": 1.6287527978420258, "global_bwd": 1.3503879308700562, "x_quantize_rowwise": 0.13146549463272095, "g_quantize_rowwise": 0.5035959184169769, "w_quantize_rowwise": 0.03438442945480347, "w_quantize_colwise_transpose": 0.24086236953735352, "w_quantize_global": 0.0945068895816803, "w_quantize_global_transpose": 0.10332837700843811, "time_standard": 8.978419005870819, "time_rowwise": 6.8106986582279205, "time_global": 6.693687289953232}
{"repeat": 64, "batch_size": 32768, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 2.848360687494278, "standard_gw": 2.8955675661563873, "standard_gx": 3.0499882996082306, "rowwise_fwd": 1.3900883495807648, "rowwise_bwd": 1.6595833003520966, "global_fwd": 1.3514049351215363, "global_bwd": 1.629263162612915, "x_quantize_rowwise": 0.5036592483520508, "g_quantize_rowwise": 0.13118237257003784, "w_quantize_rowwise": 0.03438442945480347, "w_quantize_colwise_transpose": 0.23709610104560852, "w_quantize_global": 0.0951625406742096, "w_quantize_global_transpose": 0.10216236114501953, "time_standard": 8.793916553258896, "time_rowwise": 6.851561367511749, "time_global": 6.708402186632156}
{"repeat": 64, "batch_size": 65536, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 6.4978525042533875, "standard_gw": 6.462603807449341, "standard_gx": 5.5987648665905, "rowwise_fwd": 3.2996535301208496, "rowwise_bwd": 2.6320070028305054, "global_fwd": 3.2426007091999054, "global_bwd": 2.612769603729248, "x_quantize_rowwise": 0.2561397850513458, "g_quantize_rowwise": 0.9984448552131653, "w_quantize_rowwise": 0.033076852560043335, "w_quantize_colwise_transpose": 0.24232640862464905, "w_quantize_global": 0.09618699550628662, "w_quantize_global_transpose": 0.10257214307785034, "time_standard": 18.559221178293228, "time_rowwise": 13.9242522418499, "time_global": 13.771317899227142}
{"repeat": 64, "batch_size": 65536, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 5.5702440440654755, "standard_gw": 5.717620253562927, "standard_gx": 6.08203187584877, "rowwise_fwd": 2.649586647748947, "rowwise_bwd": 3.315173089504242, "global_fwd": 2.6132799685001373, "global_bwd": 3.257807344198227, "x_quantize_rowwise": 0.9980201721191406, "g_quantize_rowwise": 0.256560742855072, "w_quantize_rowwise": 0.03356859087944031, "w_quantize_colwise_transpose": 0.23729726672172546, "w_quantize_global": 0.09495764970779419, "w_quantize_global_transpose": 0.103779137134552, "time_standard": 17.369896173477173, "time_rowwise": 13.207826763391495, "time_global": 13.04202526807785}
{"repeat": 64, "batch_size": 131072, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 13.058379292488098, "standard_gw": 11.480242013931274, "standard_gx": 11.092845350503922, "rowwise_fwd": 6.637874990701675, "rowwise_bwd": 5.24790957570076, "global_fwd": 6.521012634038925, "global_bwd": 5.214303731918335, "x_quantize_rowwise": 0.5057565867900848, "g_quantize_rowwise": 1.989319920539856, "w_quantize_rowwise": 0.03439188003540039, "w_quantize_colwise_transpose": 0.24280324578285217, "w_quantize_global": 0.09520724415779114, "w_quantize_global_transpose": 0.10240450501441956, "time_standard": 35.631466656923294, "time_rowwise": 26.138298213481903, "time_global": 25.908246636390686}
{"repeat": 64, "batch_size": 131072, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 11.13397628068924, "standard_gw": 11.371888220310211, "standard_gx": 12.12756335735321, "rowwise_fwd": 5.2495077252388, "rowwise_bwd": 6.638709455728531, "global_fwd": 5.215313285589218, "global_bwd": 6.5222084522247314, "x_quantize_rowwise": 1.9870512187480927, "g_quantize_rowwise": 0.5058236420154572, "w_quantize_rowwise": 0.034634023904800415, "w_quantize_colwise_transpose": 0.23674964904785156, "w_quantize_global": 0.09457767009735107, "w_quantize_global_transpose": 0.10183081030845642, "time_standard": 34.63342785835266, "time_rowwise": 26.024363934993744, "time_global": 25.798693299293518}
{"repeat": 64, "batch_size": 8192, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 1.2125298380851746, "standard_gw": 1.1111274361610413, "standard_gx": 1.0840706527233124, "rowwise_fwd": 0.6057210266590118, "rowwise_bwd": 0.51865354180336, "global_fwd": 0.5952082574367523, "global_bwd": 0.5167685449123383, "x_quantize_rowwise": 0.045686960220336914, "g_quantize_rowwise": 0.15827640891075134, "w_quantize_rowwise": 0.04361197352409363, "w_quantize_colwise_transpose": 0.34067779779434204, "w_quantize_global": 0.13644620776176453, "w_quantize_global_transpose": 0.14925003051757812, "time_standard": 3.407727926969528, "time_rowwise": 2.823755145072937, "time_global": 2.7127638459205627}
{"repeat": 64, "batch_size": 8192, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 1.0731369256973267, "standard_gw": 1.1365897953510284, "standard_gx": 1.1498592793941498, "rowwise_fwd": 0.5573518574237823, "rowwise_bwd": 0.615488737821579, "global_fwd": 0.5220361053943634, "global_bwd": 0.5939789116382599, "x_quantize_rowwise": 0.15765801072120667, "g_quantize_rowwise": 0.04369020462036133, "w_quantize_rowwise": 0.047359615564346313, "w_quantize_colwise_transpose": 0.5526281893253326, "w_quantize_global": 0.13606995344161987, "w_quantize_global_transpose": 0.15017390251159668, "time_standard": 3.359586000442505, "time_rowwise": 3.1107664108276367, "time_global": 2.7401968836784363}
{"repeat": 64, "batch_size": 16384, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 2.4274885654449463, "standard_gw": 2.1799951791763306, "standard_gx": 2.1426528692245483, "rowwise_fwd": 1.195710152387619, "rowwise_bwd": 1.027170568704605, "global_fwd": 1.1747106909751892, "global_bwd": 1.0251589119434357, "x_quantize_rowwise": 0.08098781108856201, "g_quantize_rowwise": 0.3052949905395508, "w_quantize_rowwise": 0.043764710426330566, "w_quantize_colwise_transpose": 0.33987686038017273, "w_quantize_global": 0.13646483421325684, "w_quantize_global_transpose": 0.14739856123924255, "time_standard": 6.750136613845825, "time_rowwise": 5.172800272703171, "time_global": 5.050010979175568}
{"repeat": 64, "batch_size": 16384, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 2.1661892533302307, "standard_gw": 2.0948275923728943, "standard_gx": 2.306375652551651, "rowwise_fwd": 1.0587647557258606, "rowwise_bwd": 1.1999905109405518, "global_fwd": 1.0296404361724854, "global_bwd": 1.1749230325222015, "x_quantize_rowwise": 0.3054030239582062, "g_quantize_rowwise": 0.08077546954154968, "w_quantize_rowwise": 0.047225505113601685, "w_quantize_colwise_transpose": 0.600133091211319, "w_quantize_global": 0.13613328337669373, "w_quantize_global_transpose": 0.1484006643295288, "time_standard": 6.567392498254776, "time_rowwise": 5.387119948863983, "time_global": 4.97010350227356}
{"repeat": 64, "batch_size": 32768, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 4.807606339454651, "standard_gw": 4.170913249254227, "standard_gx": 4.117622971534729, "rowwise_fwd": 2.370934933423996, "rowwise_bwd": 1.9481778144836426, "global_fwd": 2.3383721709251404, "global_bwd": 1.9443817436695099, "x_quantize_rowwise": 0.1547597348690033, "g_quantize_rowwise": 0.6000511348247528, "w_quantize_rowwise": 0.04361942410469055, "w_quantize_colwise_transpose": 0.3403201699256897, "w_quantize_global": 0.13600289821624756, "w_quantize_global_transpose": 0.1474134624004364, "time_standard": 13.096142560243607, "time_rowwise": 9.628776460886002, "time_global": 9.491894394159317}
{"repeat": 64, "batch_size": 32768, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 4.1619837284088135, "standard_gw": 4.181284457445145, "standard_gx": 4.635505378246307, "rowwise_fwd": 1.9684135913848877, "rowwise_bwd": 2.3750364780426025, "global_fwd": 1.9445866346359253, "global_bwd": 2.3551955819129944, "x_quantize_rowwise": 0.6004162132740021, "g_quantize_rowwise": 0.15468522906303406, "w_quantize_rowwise": 0.04730746150016785, "w_quantize_colwise_transpose": 0.5999617278575897, "w_quantize_global": 0.1364201307296753, "w_quantize_global_transpose": 0.14847144484519958, "time_standard": 12.978773564100266, "time_rowwise": 9.927105158567429, "time_global": 9.521059691905975}
{"repeat": 64, "batch_size": 65536, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 9.52371209859848, "standard_gw": 8.354485034942627, "standard_gx": 8.69860127568245, "rowwise_fwd": 4.717472940683365, "rowwise_bwd": 3.8843750953674316, "global_fwd": 4.645414650440216, "global_bwd": 3.8761012256145477, "x_quantize_rowwise": 0.3024861216545105, "g_quantize_rowwise": 1.1897757649421692, "w_quantize_rowwise": 0.04366785287857056, "w_quantize_colwise_transpose": 0.33988431096076965, "w_quantize_global": 0.1359507441520691, "w_quantize_global_transpose": 0.14724582433700562, "time_standard": 26.576798409223557, "time_rowwise": 18.832147121429443, "time_global": 18.651459366083145}
{"repeat": 64, "batch_size": 65536, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 8.307881653308868, "standard_gw": 8.214320987462997, "standard_gx": 9.21182706952095, "rowwise_fwd": 3.8919784128665924, "rowwise_bwd": 4.72346693277359, "global_fwd": 3.8761794567108154, "global_bwd": 4.673641175031662, "x_quantize_rowwise": 1.1893920600414276, "g_quantize_rowwise": 0.3024972975254059, "w_quantize_rowwise": 0.04708021879196167, "w_quantize_colwise_transpose": 0.6039328873157501, "w_quantize_global": 0.13624504208564758, "w_quantize_global_transpose": 0.14867261052131653, "time_standard": 25.734029710292816, "time_rowwise": 18.972668796777725, "time_global": 18.540948629379272}
{"repeat": 64, "batch_size": 131072, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 19.30372044444084, "standard_gw": 16.480475664138794, "standard_gx": 17.61433482170105, "rowwise_fwd": 9.49602946639061, "rowwise_bwd": 7.768530398607254, "global_fwd": 9.3533955514431, "global_bwd": 7.749464362859726, "x_quantize_rowwise": 0.5977451801300049, "g_quantize_rowwise": 2.3684948682785034, "w_quantize_rowwise": 0.04375725984573364, "w_quantize_colwise_transpose": 0.34042075276374817, "w_quantize_global": 0.13628974556922913, "w_quantize_global_transpose": 0.14671683311462402, "time_standard": 53.398530930280685, "time_rowwise": 37.09545359015465, "time_global": 36.83258220553398}
{"repeat": 64, "batch_size": 131072, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 18.041003495454788, "standard_gw": 17.770148813724518, "standard_gx": 17.70009845495224, "rowwise_fwd": 7.756810635328293, "rowwise_bwd": 9.502101689577103, "global_fwd": 7.7384114265441895, "global_bwd": 9.36170294880867, "x_quantize_rowwise": 2.3686252534389496, "g_quantize_rowwise": 0.5980581045150757, "w_quantize_rowwise": 0.04723668098449707, "w_quantize_colwise_transpose": 0.6035342812538147, "w_quantize_global": 0.13603642582893372, "w_quantize_global_transpose": 0.1485198736190796, "time_standard": 53.511250764131546, "time_rowwise": 38.64651545882225, "time_global": 38.121502846479416}
{"repeat": 64, "batch_size": 8192, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 4.598241299390793, "standard_gw": 4.294309765100479, "standard_gx": 4.261095076799393, "rowwise_fwd": 2.0976848900318146, "rowwise_bwd": 1.9718967378139496, "global_fwd": 2.0763762295246124, "global_bwd": 1.9703581929206848, "x_quantize_rowwise": 0.08216872811317444, "g_quantize_rowwise": 0.4405900835990906, "w_quantize_rowwise": 0.1553371548652649, "w_quantize_colwise_transpose": 1.6110725700855255, "w_quantize_global": 0.481240451335907, "w_quantize_global_transpose": 0.5061514675617218, "time_standard": 13.153646141290665, "time_rowwise": 10.653059929609299, "time_global": 9.85119491815567}
{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 4.35885414481163, "standard_gw": 4.29583340883255, "standard_gx": 4.5370906591415405, "rowwise_fwd": 2.0015686750411987, "rowwise_bwd": 2.097565680742264, "global_fwd": 1.969795674085617, "global_bwd": 2.075403928756714, "x_quantize_rowwise": 0.43984130024909973, "g_quantize_rowwise": 0.08216127753257751, "w_quantize_rowwise": 0.22544339299201965, "w_quantize_colwise_transpose": 2.4342015385627747, "w_quantize_global": 0.48087164759635925, "w_quantize_global_transpose": 0.5099289119243622, "time_standard": 13.19177821278572, "time_rowwise": 11.576615273952484, "time_global": 9.85383614897728}
{"repeat": 64, "batch_size": 16384, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 9.09888744354248, "standard_gw": 8.230950683355331, "standard_gx": 8.465446531772614, "rowwise_fwd": 4.182614386081696, "rowwise_bwd": 3.747660666704178, "global_fwd": 4.138719290494919, "global_bwd": 3.74777615070343, "x_quantize_rowwise": 0.15515834093093872, "g_quantize_rowwise": 0.8699297904968262, "w_quantize_rowwise": 0.15544891357421875, "w_quantize_colwise_transpose": 1.6132444143295288, "w_quantize_global": 0.48100948333740234, "w_quantize_global_transpose": 0.5051903426647186, "time_standard": 25.795284658670425, "time_rowwise": 18.955007195472717, "time_global": 18.128734081983566}
{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 8.378107100725174, "standard_gw": 8.923027664422989, "standard_gx": 9.049762040376663, "rowwise_fwd": 3.765825182199478, "rowwise_bwd": 4.183519631624222, "global_fwd": 3.744799643754959, "global_bwd": 4.1590481996536255, "x_quantize_rowwise": 0.8693933486938477, "g_quantize_rowwise": 0.1553073525428772, "w_quantize_rowwise": 0.2258792519569397, "w_quantize_colwise_transpose": 2.4386271834373474, "w_quantize_global": 0.4811100661754608, "w_quantize_global_transpose": 0.5102269351482391, "time_standard": 26.350896805524826, "time_rowwise": 20.5615796148777, "time_global": 18.842913210392}
{"repeat": 64, "batch_size": 32768, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 18.266115337610245, "standard_gw": 17.671160399913788, "standard_gx": 17.10302010178566, "rowwise_fwd": 8.347474038600922, "rowwise_bwd": 7.514089345932007, "global_fwd": 8.263226598501205, "global_bwd": 7.487393915653229, "x_quantize_rowwise": 0.3021806478500366, "g_quantize_rowwise": 1.7319358885288239, "w_quantize_rowwise": 0.15519559383392334, "w_quantize_colwise_transpose": 1.6133114695549011, "w_quantize_global": 0.48247724771499634, "w_quantize_global_transpose": 0.506427139043808, "time_standard": 53.04029583930969, "time_rowwise": 37.3353473842144, "time_global": 36.44480183720589}
{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 17.73649826645851, "standard_gw": 16.359902918338776, "standard_gx": 18.0993489921093, "rowwise_fwd": 7.493957877159119, "rowwise_bwd": 8.352488279342651, "global_fwd": 7.486194372177124, "global_bwd": 8.28903540968895, "x_quantize_rowwise": 1.7313472926616669, "g_quantize_rowwise": 0.30205026268959045, "w_quantize_rowwise": 0.2255477011203766, "w_quantize_colwise_transpose": 2.4363920092582703, "w_quantize_global": 0.4815347492694855, "w_quantize_global_transpose": 0.5103759467601776, "time_standard": 52.195750176906586, "time_rowwise": 36.90168634057045, "time_global": 35.16044095158577}
{"repeat": 64, "batch_size": 65536, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 36.309611052274704, "standard_gw": 32.85098075866699, "standard_gx": 34.34552624821663, "rowwise_fwd": 16.74525812268257, "rowwise_bwd": 15.026237815618515, "global_fwd": 16.574162989854813, "global_bwd": 14.977734535932541, "x_quantize_rowwise": 0.5954466760158539, "g_quantize_rowwise": 3.4569576382637024, "w_quantize_rowwise": 0.15521422028541565, "w_quantize_colwise_transpose": 1.6133897006511688, "w_quantize_global": 0.4822872579097748, "w_quantize_global_transpose": 0.5065612494945526, "time_standard": 103.50611805915833, "time_rowwise": 70.44348493218422, "time_global": 69.44413110613823}
{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 35.40017828345299, "standard_gw": 33.037226647138596, "standard_gx": 36.30436211824417, "rowwise_fwd": 15.043705701828003, "rowwise_bwd": 16.756191849708557, "global_fwd": 15.011314302682877, "global_bwd": 16.580048948526382, "x_quantize_rowwise": 3.4548528492450714, "g_quantize_rowwise": 0.5951337516307831, "w_quantize_rowwise": 0.22584572434425354, "w_quantize_colwise_transpose": 2.4329908192157745, "w_quantize_global": 0.4813261330127716, "w_quantize_global_transpose": 0.5101598799228668, "time_standard": 104.74176704883575, "time_rowwise": 71.54594734311104, "time_global": 69.67006251215935}
{"repeat": 64, "batch_size": 131072, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 73.40333238244057, "standard_gw": 73.76311346888542, "standard_gx": 70.41774317622185, "rowwise_fwd": 33.37597846984863, "rowwise_bwd": 30.345775187015533, "global_fwd": 33.00366923213005, "global_bwd": 30.218638479709625, "x_quantize_rowwise": 1.1825822293758392, "g_quantize_rowwise": 6.902601569890976, "w_quantize_rowwise": 0.15529245138168335, "w_quantize_colwise_transpose": 1.6109198331832886, "w_quantize_global": 0.48149004578590393, "w_quantize_global_transpose": 0.5066059529781342, "time_standard": 217.58418902754784, "time_rowwise": 147.33626320958138, "time_global": 146.05870097875595}
{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 71.5160183608532, "standard_gw": 73.76786693930626, "standard_gx": 72.98104092478752, "rowwise_fwd": 30.291248112916946, "rowwise_bwd": 33.36654230952263, "global_fwd": 30.181586742401123, "global_bwd": 33.082425594329834, "x_quantize_rowwise": 6.902430206537247, "g_quantize_rowwise": 1.1815279722213745, "w_quantize_rowwise": 0.2262219786643982, "w_quantize_colwise_transpose": 2.4421699345111847, "w_quantize_global": 0.4816502332687378, "w_quantize_global_transpose": 0.5105249583721161, "time_standard": 218.26492622494698, "time_rowwise": 148.17800745368004, "time_global": 146.1080126464367}
138  benchmarking/switchback/make_plot_with_jsonl.py (new file)

@@ -0,0 +1,138 @@
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os

import matplotlib.gridspec as gridspec

cmap=plt.get_cmap('cool')

if __name__ == '__main__':

    fig = plt.figure(tight_layout=True, figsize=(12,3.5))
    gs = gridspec.GridSpec(1, 2)

    dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]
    batch_size_for_plot1 = 32768
    batch_sizes_for_plot2 = [2**14, 2**15, 2**16, 2**17]
    dims_to_xtick = [1024, 2048, 4096]
    logscale_plot1 = True

    ax = fig.add_subplot(gs[0, 0])

    # TODO: change this to what you want.
    rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True)
    df = rdf[rdf.batch_size == batch_size_for_plot1]

    # first plot the time occupied by different operations
    for k, marker, ls, color, name in [
        ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'),
        ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'),

        ('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'),
        ('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'),
        ('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'),

        ('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'),
        ('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'),

        ('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'),
        ('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'),
        ('w_quantize_global', '.', '--', 'C4', 'Quantize global W (switchback)'),
        ('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize global and\ntranspose W (switchback)'),
    ]:
        xs = []
        ys = []
        for embed_dim in dims_to_consider:
            # average over dim -> 4*dim and 4*dim -> dim
            df_ = df[df.dim_in == embed_dim]
            df_ = df_[df_.dim_out == embed_dim * 4]
            xs.append(embed_dim)
            y_ = 0
            for k_ in k.split('+'):
                y_ += df_[k_].values[0]
            df_ = df[df.dim_in == embed_dim * 4]
            df_ = df_[df_.dim_out == embed_dim]
            for k_ in k.split('+'):
                y_ += df_[k_].values[0]
            ys.append(y_ * 0.5)

        ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.)

    ax.set_xlabel('dim', fontsize=13)
    ax.set_ylabel('time (ms)', fontsize=13)

    ax.grid()

    ax.set_xscale('log')
    if logscale_plot1:
        ax.set_yscale('log')

    ax.tick_params(axis='x', labelsize=11)
    ax.tick_params(axis='y', labelsize=11)

    ax.set_xticks(dims_to_xtick)
    ax.set_xticklabels(dims_to_xtick)
    ax.set_xticks([], minor=True)

    leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10)
    leg.get_texts()[0].set_fontweight('bold')
    leg.get_texts()[1].set_fontweight('bold')
    plt.subplots_adjust(left=0.1)
    ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20)

    ax = fig.add_subplot(gs[0, 1])

    # now plot the % speedup for different batch sizes
    for j, batch_size in enumerate(batch_sizes_for_plot2):
        all_xs, all_ys = [], []
        for k, marker, ls, color, name in [
            ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'),
            ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
        ]:

            xs, ys = [], []
            df = rdf[rdf.batch_size == batch_size]
            for embed_dim in dims_to_consider:
                df_ = df[df.dim_in == embed_dim]
                df_ = df_[df_.dim_out == embed_dim * 4]
                xs.append(embed_dim)
                y_ = 0
                for k_ in k.split('+'):
                    y_ += df_[k_].values[0]
                df_ = df[df.dim_in == embed_dim * 4]
                df_ = df_[df_.dim_out == embed_dim]
                for k_ in k.split('+'):
                    y_ += df_[k_].values[0]
                ys.append(y_ * 0.5)
            all_xs.append(xs)
            all_ys.append(ys)

        color = cmap(j * 0.25)
        real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))]
        markers = ['^', 'v', 'P', 'o']
        ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5)

    ax.legend()
    ax.set_xlabel('dim', fontsize=13)
    ax.set_xscale('log')
    ax.grid()
    ax.set_ylabel(r'% speedup', fontsize=13)

    ax.tick_params(axis='x', labelsize=11)
    ax.tick_params(axis='y', labelsize=11)

    ax.set_xticks(dims_to_xtick)
    ax.set_xticklabels(dims_to_xtick)
    ax.set_xticks([], minor=True)

    ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20)

    plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight')
BIN
benchmarking/switchback/plot_with_info.pdf
Normal file
Binary file not shown.
102
benchmarking/switchback/speed_benchmark.py
Normal file
@ -0,0 +1,102 @@
import json

import time
import torch
import torch.nn as nn

from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze

# KNOWN ISSUE: need to optimize "w_quantize_colwise_transpose" when embed_dim is too large.

def get_time(k, fn, info_dict):

    for _ in range(repeat // 2):
        fn()

    torch.cuda.synchronize()
    start = time.time()
    for _ in range(repeat):
        fn()

    torch.cuda.synchronize()
    end = time.time()
    ms = (end - start) / repeat * 1000
    print(f"time {k}: {ms:.3f} ms")
    info_dict[k] = ms

if __name__ == '__main__':
    torch.manual_seed(0)
    wm = 4
    for dim in [1024, 1280, 1408, 1664, 2048, 4096]:
        # note "batch_size" here is really batch size * sequence length (number of tokens), which is why it's large
        for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:

            # 'switch' swaps dim_in and dim_out
            for switch in [False, True]:

                # hparams
                repeat = 64
                batch_size = batch_size
                dim_out = dim * wm
                dim_in = dim
                if switch:
                    dim_out = dim
                    dim_in = wm * dim

                dim_in = round(dim_in)
                dim_out = round(dim_out)

                # simulate forward pass
                x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda()
                g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda()
                w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda()

                x_int8 = x.clone().to(torch.int8)
                g_int8 = g.clone().to(torch.int8)
                w_int8 = w.clone().to(torch.int8)
                wt_int8 = w.t().contiguous().clone().to(torch.int8)
                state_x_rowwise = x.max(dim=1)[0]
                state_g_rowwise = g.max(dim=1)[0]
                state_w_columnwise = w.max(dim=0)[0]
                state_w_rowwise = w.max(dim=1)[0]
                state_w_global = w.max()

                info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}

                get_time('standard_fwd', lambda : x.matmul(w.t()), info)
                get_time('standard_gw', lambda : g.t().matmul(x), info)
                get_time('standard_gx', lambda : g.matmul(w), info)
                get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
                get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
                get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
                get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
                get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
                get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
                get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
                get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info)
                get_time('w_quantize_global', lambda : quantize_global(w), info)
                get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info)

                time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
                time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
                time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']

                print('TOTAL STANDARD', time_standard)
                print('TOTAL ROWWISE', time_rowwise)
                print('TOTAL GLOBAL', time_global)

                print('speedup', -100*(time_global - time_standard)/time_standard)

                info['time_standard'] = time_standard
                info['time_rowwise'] = time_rowwise
                info['time_global'] = time_global

                info_json = json.dumps(info)

                # TODO: change this to what you want.
                with open("speed_benchmark/info.jsonl", "a") as file:
                    file.write(info_json + "\n")
@ -3,13 +3,14 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import cuda_setup, utils
from . import cuda_setup, utils, research
from .autograd._functions import (
    MatmulLtState,
    bmm_cublas,
    matmul,
    matmul_cublas,
    mm_cublas,
    matmul_4bit
)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules
@ -2,7 +2,7 @@ import operator
import warnings
from dataclasses import dataclass
from functools import reduce  # Required in Python 3
from typing import Tuple, Optional
from typing import Tuple, Optional, List

import torch

@ -232,6 +232,19 @@ def supports_igemmlt(device: torch.device) -> bool:
    return True


def _get_tile_size(format):
    assert format in (
        "col_turing",
        "col_ampere",
    ), f"please find this assert and manually enter tile size for {format}"
    return (8, 32) if format == "col_turing" else (32, 32)


def get_tile_inds(format, device):
    transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
    with torch.no_grad():
        return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)

@dataclass
class MatmulLtState:
    _tile_indices: Optional[torch.Tensor] = None
@ -267,20 +280,10 @@ class MatmulLtState:
|
|||
self.SBt = None
|
||||
self.CBt = None
|
||||
|
||||
def get_tile_size(self):
|
||||
assert self.formatB in (
|
||||
"col_turing",
|
||||
"col_ampere",
|
||||
), f"please find this assert and manually enter tile size for {self.formatB}"
|
||||
return (8, 32) if self.formatB == "col_turing" else (32, 32)
|
||||
|
||||
@property
|
||||
def tile_indices(self):
|
||||
if self._tile_indices is None:
|
||||
device = self.CxB.device
|
||||
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
|
||||
with torch.no_grad():
|
||||
self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
|
||||
self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
|
||||
return self._tile_indices
|
||||
|
||||
|
||||
|
@ -424,10 +427,10 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
|
||||
|
||||
if any(ctx.needs_input_grad[:2]):
|
||||
ctx.tensors = (CAt, subA)
|
||||
ctx.tensors = (CAt, subA, A)
|
||||
ctx.tensor_states = (SCAt, state.idx)
|
||||
else:
|
||||
ctx.tensors = [None, None]
|
||||
ctx.tensors = [None, None, A]
|
||||
ctx.tensor_states = (None, None)
|
||||
ctx.save_for_backward(None, None)
|
||||
|
||||
|
@ -440,7 +443,7 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
|
||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
|
||||
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
|
||||
CAt, subA = ctx.tensors
|
||||
CAt, subA, A = ctx.tensors
|
||||
SCAt, idx = ctx.tensor_states
|
||||
formatB = ctx.formatB
|
||||
state = ctx.state
|
||||
|
@ -487,6 +490,64 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
return grad_A, grad_B, None, grad_bias, None
|
||||
|
||||
|
||||
class MatMul4Bit(torch.autograd.Function):
|
||||
# forward is the same, but we added the fallback for pre-turing GPUs
|
||||
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, bias=None, state=None):
|
||||
# default of pytorch behavior if inputs are empty
|
||||
ctx.is_empty = False
|
||||
if prod(A.shape) == 0:
|
||||
ctx.is_empty = True
|
||||
ctx.A = A
|
||||
ctx.B = B
|
||||
ctx.bias = bias
|
||||
B_shape = state[1]
|
||||
if A.shape[-1] == B_shape[0]:
|
||||
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
|
||||
else:
|
||||
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
|
||||
|
||||
|
||||
# 1. Dequantize
|
||||
# 2. MatmulnN
|
||||
output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias)
|
||||
|
||||
# 3. Save state
|
||||
ctx.state = state
|
||||
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
|
||||
|
||||
if any(ctx.needs_input_grad[:2]):
|
||||
ctx.tensors = (A, B)
|
||||
else:
|
||||
ctx.tensors = (None, None)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.is_empty:
|
||||
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
|
||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
|
||||
|
||||
req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad
|
||||
A, B = ctx.tensors
|
||||
state = ctx.state
|
||||
|
||||
grad_A, grad_B, grad_bias = None, None, None
|
||||
|
||||
if req_gradBias:
|
||||
# compute grad_bias first before changing grad_output dtype
|
||||
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
|
||||
|
||||
# not supported by PyTorch. TODO: create work-around
|
||||
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
|
||||
if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(grad_output.dtype).t())
|
||||
|
||||
return grad_A, grad_B, None, grad_bias, None
|
||||
|
||||
|
||||
def matmul(
|
||||
A: tensor,
|
||||
B: tensor,
|
||||
|
@ -499,3 +560,8 @@ def matmul(
|
|||
if threshold > 0.0:
|
||||
state.threshold = threshold
|
||||
return MatMul8bitLt.apply(A, B, out, bias, state)
|
||||
|
||||
|
||||
def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
|
||||
assert quant_state is not None
|
||||
return MatMul4Bit.apply(A, B, out, bias, quant_state)
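
# Usage sketch (illustrative; the tensor shapes, dtypes and fp16 weight below are
# assumptions, not taken from this patch). matmul_4bit expects the packed 4-bit weight
# together with the quant_state returned by the 4-bit quantization routines:
#
#   import torch
#   import bitsandbytes as bnb
#   import bitsandbytes.functional as F
#
#   x = torch.randn(8, 1024, dtype=torch.float16, device='cuda')     # activations
#   w = torch.randn(4096, 1024, dtype=torch.float16, device='cuda')  # fp16 weight
#   w4, quant_state = F.quantize_fp4(w)               # packed 4-bit weight + state
#   y = bnb.matmul_4bit(x, w4.t(), quant_state)       # ~ x @ w.t(), shape (8, 4096)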
|
||||
|
|
|
@ -18,17 +18,24 @@ try:
|
|||
CUDASetup.get_instance().generate_instructions()
|
||||
CUDASetup.get_instance().print_log_stack()
|
||||
raise RuntimeError('''
|
||||
CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment!
|
||||
If you cannot find any issues and suspect a bug, please open an issue with detals about your environment:
|
||||
https://github.com/TimDettmers/bitsandbytes/issues''')
|
||||
lib.cadam32bit_g32
|
||||
CUDA Setup failed despite GPU being available. Please run the following command to get more information:
|
||||
|
||||
python -m bitsandbytes
|
||||
|
||||
Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
|
||||
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
|
||||
and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''')
|
||||
lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False
|
||||
lib.get_context.restype = ct.c_void_p
|
||||
lib.get_cusparse.restype = ct.c_void_p
|
||||
lib.cget_managed_ptr.restype = ct.c_void_p
|
||||
COMPILED_WITH_CUDA = True
|
||||
except AttributeError:
|
||||
except AttributeError as ex:
|
||||
warn("The installed version of bitsandbytes was compiled without GPU support. "
|
||||
"8-bit optimizers and GPU quantization are unavailable.")
|
||||
"8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.")
|
||||
COMPILED_WITH_CUDA = False
|
||||
print(str(ex))
|
||||
|
||||
|
||||
# print the setup details after checking for errors so we do not print twice
|
||||
if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0':
|
||||
|
|
|
@ -44,6 +44,9 @@ class CUDASetup:
|
|||
raise RuntimeError("Call get_instance() instead")
|
||||
|
||||
def generate_instructions(self):
|
||||
if getattr(self, 'error', False): return
|
||||
print(self.error)
|
||||
self.error = True
|
||||
if self.cuda is None:
|
||||
self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.')
|
||||
self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.')
|
||||
|
@ -93,6 +96,7 @@ class CUDASetup:
|
|||
self.has_printed = False
|
||||
self.lib = None
|
||||
self.initialized = False
|
||||
self.error = False
|
||||
|
||||
def run_cuda_setup(self):
|
||||
self.initialized = True
|
||||
|
|
|
@ -9,6 +9,8 @@ import random
|
|||
import torch
|
||||
import itertools
|
||||
import math
|
||||
from scipy.stats import norm
|
||||
import numpy as np
|
||||
|
||||
from functools import reduce # Required in Python 3
|
||||
from typing import Tuple
|
||||
|
@ -26,77 +28,95 @@ name2qmap = {}
|
|||
if COMPILED_WITH_CUDA:
|
||||
"""C FUNCTIONS FOR OPTIMIZERS"""
|
||||
str2optimizer32bit = {}
|
||||
str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
|
||||
str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16, lib.cadam32bit_grad_bf16)
|
||||
str2optimizer32bit["momentum"] = (
|
||||
lib.cmomentum32bit_g32,
|
||||
lib.cmomentum32bit_g16,
|
||||
lib.cmomentum32bit_grad_32,
|
||||
lib.cmomentum32bit_grad_16,
|
||||
)
|
||||
str2optimizer32bit["rmsprop"] = (
|
||||
lib.crmsprop32bit_g32,
|
||||
lib.crmsprop32bit_g16,
|
||||
)
|
||||
str2optimizer32bit["lion"] = (
|
||||
lib.clion32bit_g32,
|
||||
lib.clion32bit_g16,
|
||||
lib.crmsprop32bit_grad_32,
|
||||
lib.crmsprop32bit_grad_16,
|
||||
)
|
||||
str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16)
|
||||
str2optimizer32bit["adagrad"] = (
|
||||
lib.cadagrad32bit_g32,
|
||||
lib.cadagrad32bit_g16,
|
||||
lib.cadagrad32bit_grad_32,
|
||||
lib.cadagrad32bit_grad_16,
|
||||
)
|
||||
str2optimizer32bit["lars"] = (
|
||||
lib.cmomentum32bit_g32,
|
||||
lib.cmomentum32bit_g16,
|
||||
)
|
||||
str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
|
||||
|
||||
str2optimizer8bit = {}
|
||||
str2optimizer8bit["adam"] = (
|
||||
lib.cadam_static_8bit_g32,
|
||||
lib.cadam_static_8bit_g16,
|
||||
lib.cadam_static_8bit_grad_32,
|
||||
lib.cadam_static_8bit_grad_16,
|
||||
)
|
||||
str2optimizer8bit["momentum"] = (
|
||||
lib.cmomentum_static_8bit_g32,
|
||||
lib.cmomentum_static_8bit_g16,
|
||||
lib.cmomentum_static_8bit_grad_32,
|
||||
lib.cmomentum_static_8bit_grad_16,
|
||||
)
|
||||
str2optimizer8bit["rmsprop"] = (
|
||||
lib.crmsprop_static_8bit_g32,
|
||||
lib.crmsprop_static_8bit_g16,
|
||||
lib.crmsprop_static_8bit_grad_32,
|
||||
lib.crmsprop_static_8bit_grad_16,
|
||||
)
|
||||
str2optimizer8bit["lion"] = (
|
||||
lib.clion_static_8bit_g32,
|
||||
lib.clion_static_8bit_g16,
|
||||
lib.clion_static_8bit_grad_32,
|
||||
lib.clion_static_8bit_grad_16,
|
||||
)
|
||||
str2optimizer8bit["lamb"] = (
|
||||
lib.cadam_static_8bit_g32,
|
||||
lib.cadam_static_8bit_g16,
|
||||
lib.cadam_static_8bit_grad_32,
|
||||
lib.cadam_static_8bit_grad_16,
|
||||
)
|
||||
str2optimizer8bit["lars"] = (
|
||||
lib.cmomentum_static_8bit_g32,
|
||||
lib.cmomentum_static_8bit_g16,
|
||||
lib.cmomentum_static_8bit_grad_32,
|
||||
lib.cmomentum_static_8bit_grad_16,
|
||||
)
|
||||
|
||||
str2optimizer8bit_blockwise = {}
|
||||
str2optimizer8bit_blockwise["adam"] = (
|
||||
lib.cadam_8bit_blockwise_fp32,
|
||||
lib.cadam_8bit_blockwise_fp16,
|
||||
lib.cadam_8bit_blockwise_grad_fp32,
|
||||
lib.cadam_8bit_blockwise_grad_fp16,
|
||||
lib.cadam_8bit_blockwise_grad_bf16,
|
||||
)
|
||||
str2optimizer8bit_blockwise["momentum"] = (
|
||||
lib.cmomentum_8bit_blockwise_fp32,
|
||||
lib.cmomentum_8bit_blockwise_fp16,
|
||||
lib.cmomentum_8bit_blockwise_grad_fp32,
|
||||
lib.cmomentum_8bit_blockwise_grad_fp16,
|
||||
)
|
||||
str2optimizer8bit_blockwise["rmsprop"] = (
|
||||
lib.crmsprop_8bit_blockwise_fp32,
|
||||
lib.crmsprop_8bit_blockwise_fp16,
|
||||
lib.crmsprop_8bit_blockwise_grad_fp32,
|
||||
lib.crmsprop_8bit_blockwise_grad_fp16,
|
||||
)
|
||||
str2optimizer8bit_blockwise["lion"] = (
|
||||
lib.clion_8bit_blockwise_fp32,
|
||||
lib.clion_8bit_blockwise_fp16,
|
||||
lib.clion_8bit_blockwise_grad_fp32,
|
||||
lib.clion_8bit_blockwise_grad_fp16,
|
||||
lib.clion_8bit_blockwise_grad_bf16,
|
||||
)
|
||||
str2optimizer8bit_blockwise["adagrad"] = (
|
||||
lib.cadagrad_8bit_blockwise_fp32,
|
||||
lib.cadagrad_8bit_blockwise_fp16,
|
||||
lib.cadagrad_8bit_blockwise_grad_fp32,
|
||||
lib.cadagrad_8bit_blockwise_grad_fp16,
|
||||
)
|
||||
|
||||
class GlobalPageManager:
|
||||
_instance = None
|
||||
|
||||
def __init__(self):
|
||||
raise RuntimeError("Call get_instance() instead")
|
||||
|
||||
def initialize(self):
|
||||
self.paged_tensors = []
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = cls.__new__(cls)
|
||||
cls._instance.initialize()
|
||||
return cls._instance
|
||||
|
||||
    def prefetch_all(self, to_cpu=False):
        # assume the tensors added first will be the
        # ones that are used first, so swap them in last
        # in case they are evicted again
        for t in self.paged_tensors[::-1]:
            prefetch_tensor(t, to_cpu)
|
||||
|
||||
|
||||
class CUBLAS_Context:
|
||||
_instance = None
|
||||
|
@ -106,11 +126,6 @@ class CUBLAS_Context:
|
|||
|
||||
def initialize(self):
|
||||
self.context = {}
|
||||
# prev_device = torch.cuda.current_device()
|
||||
# for i in range(torch.cuda.device_count()):
|
||||
# torch.cuda.set_device(torch.device('cuda', i))
|
||||
# self.context.append(ct.c_void_p(lib.get_context()))
|
||||
# torch.cuda.set_device(prev_device)
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
|
@ -144,6 +159,61 @@ class Cusparse_Context:
|
|||
cls._instance.initialize()
|
||||
return cls._instance
|
||||
|
||||
dtype2bytes = {}
|
||||
dtype2bytes[torch.float32] = 4
|
||||
dtype2bytes[torch.float16] = 2
|
||||
dtype2bytes[torch.bfloat16] = 2
|
||||
dtype2bytes[torch.uint8] = 1
|
||||
dtype2bytes[torch.int8] = 1
|
||||
|
||||
def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)):
|
||||
num_bytes = dtype2bytes[dtype]*prod(shape)
|
||||
cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
|
||||
c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
|
||||
new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
|
||||
out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)).view(shape)
|
||||
out.is_paged = True
|
||||
out.page_deviceid = device.index
|
||||
return out
|
||||
|
||||
def prefetch_tensor(A, to_cpu=False):
|
||||
assert A.is_paged, 'Only paged tensors can be prefetched!'
|
||||
if to_cpu:
|
||||
deviceid = -1
|
||||
else:
|
||||
deviceid = A.page_deviceid
|
||||
|
||||
num_bytes = dtype2bytes[A.dtype]*A.numel()
|
||||
lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid))
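
# Usage sketch (illustrative; the shape and dtype below are assumptions):
#   buf = get_paged(1024, 1024, dtype=torch.float32)  # CUDA managed-memory ("paged") tensor
#   assert buf.is_paged
#   prefetch_tensor(buf)               # move its pages onto buf.page_deviceid ahead of use
#   prefetch_tensor(buf, to_cpu=True)  # or stream the pages back to CPU memory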
|
||||
|
||||
def elementwise_func(func_name, A, B, value, prefetch=True):
|
||||
func = None
|
||||
if A.dtype == torch.float32:
|
||||
func = getattr(lib, f'c{func_name}_fp32', None)
|
||||
cvalue = ct.c_float(value)
|
||||
elif A.dtype == torch.uint8:
|
||||
func = getattr(lib, f'c{func_name}_uint8', None)
|
||||
cvalue = ct.c_uint8(value)
|
||||
|
||||
if func is None: raise NotImplementedError(f'Function not implemented: {func_name}')
|
||||
|
||||
is_managed = getattr(A, 'is_managed', False)
|
||||
if is_managed and prefetch:
|
||||
prefetch_tensor(A)
|
||||
if B is not None: prefetch_tensor(B)
|
||||
|
||||
    func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel()))
    if A.is_paged or B.is_paged:
        # paged functions are fully asynchronous
        # if we return from this function, we want the tensor
        # to be in the correct state, that is the final state after the
        # operation occurred. So we synchronize.
        torch.cuda.synchronize()
|
||||
def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value)
|
||||
def arange(A, device=None): elementwise_func('arange', A, None, 0)
|
||||
def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0)
|
||||
|
||||
|
||||
def create_linear_map(signed=True, total_bits=8, add_zero=True):
|
||||
sign = (-1.0 if signed else 0.0)
|
||||
|
@ -161,9 +231,27 @@ def create_linear_map(signed=True, total_bits=8, add_zero=True):
|
|||
return values
|
||||
else:
|
||||
l = values.numel()//2
|
||||
#return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist())
|
||||
return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())
|
||||
|
||||
def create_normal_map(offset=0.9677083, use_extra_value=True):
|
||||
|
||||
if use_extra_value:
|
||||
# one more positive value, this is an asymmetric type
|
||||
v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist()
|
||||
v2 = [0]*(256-15) ## we have 15 non-zero values in this data type
|
||||
v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
|
||||
v = v1 + v2 + v3
|
||||
else:
|
||||
v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist()
|
||||
v2 = [0]*(256-14) ## we have 14 non-zero values in this data type
|
||||
v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist()
|
||||
v = v1 + v2 + v3
|
||||
|
||||
values = torch.Tensor(v)
|
||||
values = values.sort().values
|
||||
values /= values.max()
|
||||
assert values.numel() == 256
|
||||
return values
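
# Note (illustrative): with use_extra_value=True the map built above holds 15 distinct
# non-zero quantiles plus zero, padded with zeros to 256 entries and scaled into [-1, 1];
# this asymmetric codebook is the NormalFloat (NF4) data type used by the 4-bit routines:
#   code = create_normal_map()
#   assert code.numel() == 256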
|
||||
|
||||
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
|
||||
e = exponent_bits
|
||||
|
@ -180,7 +268,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
|
|||
values = []
|
||||
lst = list(itertools.product([0, 1], repeat=precision_bits))
|
||||
#for ev in evalues:
|
||||
bias = 2**(exponent_bits-1)-1
|
||||
bias = 2**(exponent_bits-1)
|
||||
for evalue in range(2**(exponent_bits)):
|
||||
for bit_pattern in lst:
|
||||
value = (1 if evalue != 0 else 0)
|
||||
|
@ -188,10 +276,10 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
|
|||
value += pval*(2**-(i+1))
|
||||
if evalue == 0:
|
||||
# subnormals
|
||||
value = value*2**-(bias-1)
|
||||
value = value*2**-(bias)
|
||||
else:
|
||||
# normals
|
||||
value = value*2**-(evalue-bias-2)
|
||||
value = value*2**-(evalue-bias-1)
|
||||
values.append(value)
|
||||
if signed:
|
||||
values.append(-value)
|
||||
|
@ -289,9 +377,17 @@ def get_special_format_str():
|
|||
|
||||
def is_on_gpu(tensors):
|
||||
on_gpu = True
|
||||
gpu_ids = set()
|
||||
for t in tensors:
|
||||
if t is None: continue # NULL pointers are fine
|
||||
on_gpu &= t.device.type == 'cuda'
|
||||
is_paged = getattr(t, 'is_paged', False)
|
||||
on_gpu &= (t.device.type == 'cuda' or is_paged)
|
||||
if not is_paged:
|
||||
gpu_ids.add(t.device.index)
|
||||
if not on_gpu:
|
||||
raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}')
|
||||
if len(gpu_ids) > 1:
|
||||
raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}')
|
||||
return on_gpu
|
||||
|
||||
def get_ptr(A: Tensor) -> ct.c_void_p:
|
||||
|
@ -469,7 +565,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n
|
|||
return out
|
||||
|
||||
|
||||
def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096) -> Tensor:
|
||||
def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor:
|
||||
"""
|
||||
Quantize tensor A in blocks of size 4096 values.
|
||||
|
||||
|
@ -485,8 +581,6 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
The quantization map.
|
||||
absmax : torch.Tensor
|
||||
The absmax values.
|
||||
rand : torch.Tensor
|
||||
The tensor for stochastic rounding.
|
||||
out : torch.Tensor
|
||||
The output tensor (8-bit).
|
||||
|
||||
|
@ -518,33 +612,30 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
|
|||
cblocksize = ct.c_int32(blocksize)
|
||||
prev_device = pre_call(A.device)
|
||||
code = code.to(A.device)
|
||||
if rand is not None:
|
||||
is_on_gpu([code, A, out, absmax, rand])
|
||||
assert blocksize==4096
|
||||
assert rand.numel() >= 1024
|
||||
rand_offset = random.randint(0, 1023)
|
||||
if A.dtype == torch.float32:
|
||||
lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
|
||||
elif A.dtype == torch.float16:
|
||||
lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
is_on_gpu([code, A, out, absmax])
|
||||
if A.dtype == torch.float32:
|
||||
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
elif A.dtype == torch.float16:
|
||||
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
else:
|
||||
is_on_gpu([code, A, out, absmax])
|
||||
if A.dtype == torch.float32:
|
||||
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
elif A.dtype == torch.float16:
|
||||
lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
|
||||
else:
|
||||
# cpu
|
||||
code = code.cpu()
|
||||
assert rand is None
|
||||
lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))
|
||||
|
||||
return out, (absmax, code)
|
||||
if nested:
|
||||
offset = absmax.mean()
|
||||
absmax -= offset
|
||||
qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False)
|
||||
state = [qabsmax, code, blocksize, nested, offset, state2]
|
||||
else:
|
||||
state = [absmax, code, blocksize, nested, None, None]
|
||||
|
||||
|
||||
|
||||
return out, state
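
# Usage sketch (illustrative): with nested=True the per-block absmax statistics are
# themselves blockwise-quantized after subtracting their mean, and the returned state
# carries both levels:
#   q, state = quantize_blockwise(t, nested=True)
#   # state = [qabsmax, code, blocksize, nested, offset, state2]
#   t_hat = dequantize_blockwise(q, state)   # approximately recovers t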
|
||||
|
||||
|
||||
def dequantize_blockwise(
|
||||
|
@ -554,6 +645,7 @@ def dequantize_blockwise(
|
|||
code: Tensor = None,
|
||||
out: Tensor = None,
|
||||
blocksize: int = 4096,
|
||||
nested=False
|
||||
) -> Tensor:
|
||||
"""
|
||||
Dequantizes blockwise quantized values.
|
||||
|
@ -588,10 +680,15 @@ def dequantize_blockwise(
|
|||
|
||||
if out is None:
|
||||
out = torch.zeros_like(A, dtype=torch.float32)
|
||||
|
||||
if quant_state is None:
|
||||
quant_state = (absmax, code)
|
||||
quant_state = (absmax, code, blocksize)
|
||||
assert absmax is not None and out is not None
|
||||
else:
|
||||
absmax, code = quant_state
|
||||
absmax, code, blocksize, nested, offset, state2 = quant_state
|
||||
if nested:
|
||||
absmax = dequantize_blockwise(absmax, state2)
|
||||
absmax += offset
|
||||
|
||||
|
||||
if A.device.type != 'cpu':
|
||||
|
@ -599,7 +696,7 @@ def dequantize_blockwise(
|
|||
code = code.to(A.device)
|
||||
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
|
||||
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
|
||||
is_on_gpu([A, out])
|
||||
is_on_gpu([A, absmax, out])
|
||||
if out.dtype == torch.float32:
|
||||
lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
|
||||
elif out.dtype == torch.float16:
|
||||
|
@ -613,6 +710,164 @@ def dequantize_blockwise(
|
|||
|
||||
return out
|
||||
|
||||
def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
|
||||
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4')
|
||||
|
||||
def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False):
|
||||
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4')
|
||||
|
||||
def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor:
|
||||
"""
|
||||
Quantize tensor A in blocks of 4-bit values.
|
||||
|
||||
Quantizes tensor A by dividing it into blocks which are independently quantized to FP4.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
A : torch.Tensor
|
||||
The input tensor.
|
||||
absmax : torch.Tensor
|
||||
The absmax values.
|
||||
out : torch.Tensor
|
||||
The output tensor (8-bit).
|
||||
blocksize : int
|
||||
The blocksize used in quantization.
|
||||
quant_type : str
|
||||
The 4-bit quantization data type {fp4, nf4}
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor:
|
||||
The 8-bit tensor with packed 4-bit values.
|
||||
tuple(torch.Tensor, torch.Size, torch.dtype, int):
|
||||
The quantization state to undo the quantization.
|
||||
"""
|
||||
if A.device.type != 'cuda':
|
||||
raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}')
|
||||
if quant_type not in ['fp4', 'nf4']:
|
||||
raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.')
|
||||
|
||||
n = A.numel()
|
||||
input_shape = A.shape
|
||||
|
||||
if absmax is None:
|
||||
blocks = n // blocksize
|
||||
blocks += 1 if n % blocksize > 0 else 0
|
||||
absmax = torch.zeros((blocks,), device=A.device)
|
||||
|
||||
|
||||
if out is None:
|
||||
out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device)
|
||||
|
||||
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
|
||||
|
||||
prev_device = pre_call(A.device)
|
||||
is_on_gpu([A, out, absmax])
|
||||
|
||||
if A.dtype == torch.float32:
|
||||
if quant_type == 'fp4':
|
||||
lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
else:
|
||||
lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
elif A.dtype == torch.float16:
|
||||
if quant_type == 'fp4':
|
||||
lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
else:
|
||||
lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
|
||||
|
||||
if compress_statistics:
|
||||
offset = absmax.mean()
|
||||
absmax -= offset
|
||||
#code = create_custom_map().to(absmax.device)
|
||||
#qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
|
||||
qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
|
||||
del absmax
|
||||
state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
|
||||
else:
|
||||
state = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
|
||||
|
||||
return out, state
|
||||
|
||||
def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
|
||||
return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4')
|
||||
|
||||
def dequantize_nf4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor:
|
||||
return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4')
|
||||
|
||||
def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor:
|
||||
"""
|
||||
Dequantizes FP4 blockwise quantized values.
|
||||
|
||||
Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
A : torch.Tensor
|
||||
The input 8-bit tensor (packed 4-bit values).
|
||||
quant_state : tuple(torch.Tensor, torch.Size, torch.dtype)
|
||||
Tuple of absmax values, original tensor shape and original dtype.
|
||||
absmax : torch.Tensor
|
||||
The absmax values.
|
||||
out : torch.Tensor
|
||||
Dequantized output tensor.
|
||||
blocksize : int
|
||||
The blocksize used in quantization.
|
||||
quant_type : str
|
||||
The 4-bit quantization data type {fp4, nf4}
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor:
|
||||
Dequantized tensor.
|
||||
"""
|
||||
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
|
||||
raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
|
||||
if quant_type not in ['fp4', 'nf4']:
|
||||
raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.')
|
||||
|
||||
if quant_state is None:
|
||||
assert absmax is not None and out is not None
|
||||
shape = out.shape
|
||||
dtype = out.dtype
|
||||
else:
|
||||
absmax, shape, dtype, blocksize, compressed_stats, quant_type = quant_state
|
||||
|
||||
|
||||
if compressed_stats is not None:
|
||||
offset, state2 = compressed_stats
|
||||
absmax = dequantize_blockwise(absmax, state2)
|
||||
absmax += offset
|
||||
|
||||
if out is None:
|
||||
out = torch.empty(shape, dtype=dtype, device=A.device)
|
||||
|
||||
n = out.numel()
|
||||
|
||||
|
||||
device = pre_call(A.device)
|
||||
is_on_gpu([A, absmax, out])
|
||||
if out.dtype == torch.float32:
|
||||
if quant_type == 'fp4':
|
||||
lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
else:
|
||||
lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
elif out.dtype == torch.float16:
|
||||
if quant_type == 'fp4':
|
||||
lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
else:
|
||||
lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n))
|
||||
else:
|
||||
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
|
||||
post_call(A.device)
|
||||
|
||||
is_transposed = (True if A.shape[0] == 1 else False)
|
||||
if is_transposed: return out.t()
|
||||
else: return out
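
# Usage sketch (illustrative; shapes and blocksize are assumptions): a blockwise NF4
# round trip on a CUDA fp16 tensor:
#   w = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')
#   w4, state = quantize_nf4(w, blocksize=64)      # uint8 output, two 4-bit values per byte
#   w_hat = dequantize_nf4(w4, quant_state=state)  # fp16, same shape as w, approximate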
|
||||
|
||||
|
||||
def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
|
||||
if code is None:
|
||||
|
@ -765,55 +1020,36 @@ def optimizer_update_32bit(
|
|||
if max_unorm > 0.0:
|
||||
param_norm = torch.norm(p.data.float())
|
||||
|
||||
if optimizer_name not in str2optimizer32bit:
|
||||
raise NotImplementedError(
|
||||
f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
|
||||
)
|
||||
|
||||
prev_device = pre_call(g.device)
|
||||
is_on_gpu([g, p, state1, state2, unorm_vec])
|
||||
if g.dtype == torch.float32 and state1.dtype == torch.float32:
|
||||
str2optimizer32bit[optimizer_name][0](
|
||||
get_ptr(g),
|
||||
get_ptr(p),
|
||||
get_ptr(state1),
|
||||
get_ptr(state2),
|
||||
get_ptr(unorm_vec),
|
||||
ct.c_float(max_unorm),
|
||||
ct.c_float(param_norm),
|
||||
ct.c_float(beta1),
|
||||
ct.c_float(beta2),
|
||||
ct.c_float(eps),
|
||||
ct.c_float(weight_decay),
|
||||
ct.c_int32(step),
|
||||
ct.c_float(lr),
|
||||
ct.c_float(gnorm_scale),
|
||||
ct.c_bool(skip_zeros),
|
||||
ct.c_int32(g.numel()),
|
||||
)
|
||||
elif g.dtype == torch.float16 and state1.dtype == torch.float32:
|
||||
str2optimizer32bit[optimizer_name][1](
|
||||
get_ptr(g),
|
||||
get_ptr(p),
|
||||
get_ptr(state1),
|
||||
get_ptr(state2),
|
||||
get_ptr(unorm_vec),
|
||||
ct.c_float(max_unorm),
|
||||
ct.c_float(param_norm),
|
||||
ct.c_float(beta1),
|
||||
ct.c_float(beta2),
|
||||
ct.c_float(eps),
|
||||
ct.c_float(weight_decay),
|
||||
ct.c_int32(step),
|
||||
ct.c_float(lr),
|
||||
ct.c_float(gnorm_scale),
|
||||
ct.c_bool(skip_zeros),
|
||||
ct.c_int32(g.numel()),
|
||||
)
|
||||
optim_func = None
|
||||
if g.dtype == torch.float32:
|
||||
optim_func = str2optimizer32bit[optimizer_name][0]
|
||||
elif g.dtype == torch.float16:
|
||||
optim_func = str2optimizer32bit[optimizer_name][1]
|
||||
elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3):
|
||||
optim_func = str2optimizer32bit[optimizer_name][2]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
|
||||
)
|
||||
raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}")
|
||||
|
||||
is_on_gpu([g, p, state1, state2, unorm_vec])
|
||||
prev_device = pre_call(g.device)
|
||||
optim_func(
|
||||
get_ptr(g),
|
||||
get_ptr(p),
|
||||
get_ptr(state1),
|
||||
get_ptr(state2),
|
||||
get_ptr(unorm_vec),
|
||||
ct.c_float(max_unorm),
|
||||
ct.c_float(param_norm),
|
||||
ct.c_float(beta1),
|
||||
ct.c_float(beta2),
|
||||
ct.c_float(eps),
|
||||
ct.c_float(weight_decay),
|
||||
ct.c_int32(step),
|
||||
ct.c_float(lr),
|
||||
ct.c_float(gnorm_scale),
|
||||
ct.c_bool(skip_zeros),
|
||||
ct.c_int32(g.numel()))
|
||||
post_call(prev_device)
|
||||
|
||||
|
||||
|
@ -970,54 +1206,45 @@ def optimizer_update_8bit_blockwise(
|
|||
skip_zeros=False,
|
||||
) -> None:
|
||||
|
||||
optim_func = None
|
||||
prev_device = pre_call(g.device)
|
||||
is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
|
||||
if g.dtype == torch.float32 and state1.dtype == torch.uint8:
|
||||
str2optimizer8bit_blockwise[optimizer_name][0](
|
||||
get_ptr(p),
|
||||
get_ptr(g),
|
||||
get_ptr(state1),
|
||||
get_ptr(state2),
|
||||
ct.c_float(beta1),
|
||||
ct.c_float(beta2),
|
||||
ct.c_float(eps),
|
||||
ct.c_int32(step),
|
||||
ct.c_float(lr),
|
||||
get_ptr(qmap1),
|
||||
get_ptr(qmap2),
|
||||
get_ptr(absmax1),
|
||||
get_ptr(absmax2),
|
||||
ct.c_float(weight_decay),
|
||||
ct.c_float(gnorm_scale),
|
||||
ct.c_bool(skip_zeros),
|
||||
ct.c_int32(g.numel()),
|
||||
)
|
||||
optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
|
||||
elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
|
||||
str2optimizer8bit_blockwise[optimizer_name][1](
|
||||
get_ptr(p),
|
||||
get_ptr(g),
|
||||
get_ptr(state1),
|
||||
get_ptr(state2),
|
||||
ct.c_float(beta1),
|
||||
ct.c_float(beta2),
|
||||
ct.c_float(eps),
|
||||
ct.c_int32(step),
|
||||
ct.c_float(lr),
|
||||
get_ptr(qmap1),
|
||||
get_ptr(qmap2),
|
||||
get_ptr(absmax1),
|
||||
get_ptr(absmax2),
|
||||
ct.c_float(weight_decay),
|
||||
ct.c_float(gnorm_scale),
|
||||
ct.c_bool(skip_zeros),
|
||||
ct.c_int32(g.numel()),
|
||||
)
|
||||
optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
|
||||
elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and
|
||||
len(str2optimizer8bit_blockwise[optimizer_name])==3):
|
||||
optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
|
||||
)
|
||||
post_call(prev_device)
|
||||
|
||||
is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
|
||||
|
||||
prev_device = pre_call(g.device)
|
||||
optim_func(
|
||||
get_ptr(p),
|
||||
get_ptr(g),
|
||||
get_ptr(state1),
|
||||
get_ptr(state2),
|
||||
ct.c_float(beta1),
|
||||
ct.c_float(beta2),
|
||||
ct.c_float(eps),
|
||||
ct.c_int32(step),
|
||||
ct.c_float(lr),
|
||||
get_ptr(qmap1),
|
||||
get_ptr(qmap2),
|
||||
get_ptr(absmax1),
|
||||
get_ptr(absmax2),
|
||||
ct.c_float(weight_decay),
|
||||
ct.c_float(gnorm_scale),
|
||||
ct.c_bool(skip_zeros),
|
||||
ct.c_int32(g.numel()),
|
||||
)
|
||||
post_call(prev_device)
|
||||
|
||||
def percentile_clipping(
|
||||
grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5
|
||||
|
@ -1171,6 +1398,123 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8
|
|||
|
||||
return sout
|
||||
|
||||
def cutlass3_gemm(
|
||||
A: Tensor,
|
||||
B: Tensor,
|
||||
out: Tensor = None,
|
||||
transposed_A=False,
|
||||
transposed_B=False,
|
||||
state=None
|
||||
):
|
||||
#sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
|
||||
if state is None:
|
||||
Bshape = B.shape
|
||||
bout = Bshape[1]
|
||||
else:
|
||||
Bshape = state[1]
|
||||
bout = Bshape[0]
|
||||
if out is None:
|
||||
out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)
|
||||
|
||||
sA = A.shape
|
||||
sB = B.shape
|
||||
if transposed_A and len(sA) == 2:
|
||||
sA = (sA[1], sA[0])
|
||||
elif transposed_A and len(sA) == 3:
|
||||
sA = (sA[0], sA[2], sA[0])
|
||||
if transposed_B and len(sB) == 2:
|
||||
sB = (sB[1], sB[0])
|
||||
elif transposed_B and len(sB) == 3:
|
||||
sB = (sB[0], sB[2], sB[0])
|
||||
    # this is a mess: cuBLAS expects column major, but PyTorch is row major.
    # So to perform the matrix multiplication, we have to treat A, B, and C as
    # transposed matrices (the transpose of a row-major matrix is its column-major layout).
    # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these
    # matrices in the input arguments for cuBLAS
    # column major: A @ B = C: [m, k] @ [k, n] = [m, n]
    # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
    # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
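    # e.g. (illustrative numbers, not from this patch): for row-major A: [4, 64] and
    # B: [64, 32] we keep the same buffers but describe them to cuBLAS as the
    # column-major B^T: [32, 64] and A^T: [64, 4]; the column-major result C^T: [32, 4]
    # is then read back as the row-major C: [4, 32].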
|
||||
if len(sB) == 2:
|
||||
if B.stride()[0] == B.shape[1]:
|
||||
transposed_B = False
|
||||
elif B.stride()[1] == B.shape[0]:
|
||||
transposed_B = True
|
||||
if len(A.shape) == 2:
|
||||
if A.stride()[0] == A.shape[1]:
|
||||
transposed_A = False
|
||||
elif A.stride()[1] == A.shape[0]:
|
||||
transposed_A = True
|
||||
else:
|
||||
if A.stride()[1] == A.shape[2]:
|
||||
transposed_A = False
|
||||
elif A.stride()[2] == A.shape[1]:
|
||||
transposed_A = True
|
||||
|
||||
if len(sA) == 2:
|
||||
n = sA[0]
|
||||
ldb = A.stride()[1 if transposed_A else 0]
|
||||
elif len(sA) == 3 and len(sB) == 2:
|
||||
n = sA[0] * sA[1]
|
||||
ldb = sA[2]
|
||||
|
||||
m = sB[1]
|
||||
k = sB[0]
|
||||
lda = B.stride()[0]
|
||||
ldc = sB[1]
|
||||
elif len(sB) == 3:
|
||||
# special case
|
||||
assert len(sA) == 3
|
||||
if not (sA[0] == sB[0] and sA[1] == sB[1]):
|
||||
raise ValueError(
|
||||
f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
|
||||
)
|
||||
|
||||
transposed_A = True
|
||||
transposed_B = False
|
||||
|
||||
m = sB[2]
|
||||
n = sA[2]
|
||||
k = sB[0] * sB[1]
|
||||
|
||||
lda = n
|
||||
ldb = sA[2]
|
||||
ldc = m
|
||||
|
||||
ptr = CUBLAS_Context.get_instance().get_context(A.device)
|
||||
|
||||
# B^T @ A^T = C^T
|
||||
# [km, nk -> mn]
|
||||
#lda = ldb = ldc = 1
|
||||
#lda = 1
|
||||
if state is not None:
|
||||
m = Bshape[0]
|
||||
k = Bshape[1]
|
||||
lda = Bshape[0]
|
||||
ldc = Bshape[0]
|
||||
ldb = (ldb+1)//2
|
||||
#print(m, n, k, lda, ldb, ldc)
|
||||
is_on_gpu([B, A, out])
|
||||
m = ct.c_int32(m)
|
||||
n = ct.c_int32(n)
|
||||
k = ct.c_int32(k)
|
||||
lda = ct.c_int32(lda)
|
||||
ldb = ct.c_int32(ldb)
|
||||
ldc = ct.c_int32(ldc)
|
||||
|
||||
if B.dtype == torch.uint8:
|
||||
lib.cgemm_4bit_inference(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
|
||||
elif A.dtype == torch.float32:
|
||||
lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
|
||||
elif A.dtype == torch.float16:
|
||||
lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc)
|
||||
else:
|
||||
raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
|
||||
def igemm(
|
||||
A: Tensor,
|
||||
|
@ -1845,8 +2189,6 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
|
|||
ccolsB = ct.c_int32(B.shape[1])
|
||||
cldb = ct.c_int32(ldb)
|
||||
cldc = ct.c_int32(ldc)
|
||||
# print(cooA.rowidx[:64])
|
||||
# print(cooA.colidx[:64].sort()[0])
|
||||
|
||||
is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
|
||||
if B.dtype == torch.float16:
|
||||
|
@ -2044,3 +2386,8 @@ def extract_outliers(A, SA, idx):
|
|||
post_call(prev_device)
|
||||
|
||||
return out
|
||||
|
||||
def pipeline_test(A, batch_size):
|
||||
out = torch.zeros_like(A)
|
||||
lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size))
|
||||
return out
|
||||
|
|
|
@ -2,4 +2,5 @@
|
|||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from .modules import Int8Params, Linear8bitLt, StableEmbedding
|
||||
from .modules import Int8Params, Linear8bitLt, StableEmbedding, Linear4bit, LinearNF4, LinearFP4, Params4bit, OutlierAwareLinear, SwitchBackLinearBnb
|
||||
from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear
|
||||
|
|
|
@ -10,8 +10,9 @@ from torch import Tensor, device, dtype, nn
|
|||
|
||||
import bitsandbytes as bnb
|
||||
import bitsandbytes.functional
|
||||
from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout
|
||||
from bitsandbytes.autograd._functions import undo_layout, get_tile_inds
|
||||
from bitsandbytes.optim import GlobalOptimManager
|
||||
from bitsandbytes.utils import OutlierTracer, find_outlier_dims
|
||||
|
||||
T = TypeVar("T", bound="torch.nn.Module")
|
||||
|
||||
|
@ -135,6 +136,101 @@ class Embedding(torch.nn.Embedding):
|
|||
|
||||
return emb
|
||||
|
||||
class Params4bit(torch.nn.Parameter):
|
||||
def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'):
|
||||
if data is None:
|
||||
data = torch.empty(0)
|
||||
|
||||
self = torch.Tensor._make_subclass(cls, data, requires_grad)
|
||||
self.blocksize = blocksize
|
||||
self.compress_statistics = compress_statistics
|
||||
self.quant_type = quant_type
|
||||
self.quant_state = quant_state
|
||||
self.data = data
|
||||
return self
|
||||
|
||||
def cuda(self, device):
|
||||
w = self.data.contiguous().half().cuda(device)
|
||||
w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type)
|
||||
self.data = w_4bit
|
||||
self.quant_state = quant_state
|
||||
|
||||
return self
|
||||
|
||||
@overload
|
||||
def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T:
|
||||
...
|
||||
|
||||
@overload
|
||||
def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
|
||||
...
|
||||
|
||||
@overload
|
||||
def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
|
||||
...
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
||||
|
||||
if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"):
|
||||
return self.cuda(device)
|
||||
else:
|
||||
s = self.quant_state
|
||||
if s is not None:
|
||||
# make sure the quantization state is on the right device
|
||||
s[0] = s[0].to(device)
|
||||
if self.compress_statistics:
|
||||
# TODO: refactor this. This is a nightmare
|
||||
# for 4-bit:
|
||||
# state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
|
||||
# state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
|
||||
#s[-2][0] = s[-2][0].to(device) # offset
|
||||
#s[-2][1][0] = s[-2][1][0].to(device) # nested absmax
|
||||
|
||||
                    # for 8-bit
                    s[-2][0] = s[-2][0].to(device) # offset
                    s[-2][1][0] = s[-2][1][0].to(device) # nested quantization state statistics
                    s[-2][1][1] = s[-2][1][1].to(device) # nested quantization codebook
new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking),
|
||||
requires_grad=self.requires_grad, quant_state=self.quant_state,
|
||||
blocksize=self.blocksize, compress_statistics=self.compress_statistics,
|
||||
quant_type=self.quant_type)
|
||||
|
||||
return new_param
|
||||
|
||||
class Linear4bit(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4'):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
|
||||
self.compute_dtype = compute_dtype
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# weights are cast automatically as Int8Params, but the bias has to be cast manually
|
||||
if self.bias is not None and self.bias.dtype != x.dtype:
|
||||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
if getattr(self.weight, 'quant_state', None) is None:
|
||||
print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
|
||||
inp_dtype = x.dtype
|
||||
if self.compute_dtype is not None:
|
||||
x = x.to(self.compute_dtype)
|
||||
|
||||
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
|
||||
out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
|
||||
|
||||
out = out.to(inp_dtype)
|
||||
|
||||
return out
|
||||
|
||||
class LinearFP4(Linear4bit):
|
||||
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
|
||||
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4')
|
||||
|
||||
class LinearNF4(Linear4bit):
|
||||
def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
|
||||
super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4')
|
||||
|
||||
|
||||
|
||||
class Int8Params(torch.nn.Parameter):
|
||||
def __new__(
|
||||
|
@ -210,6 +306,18 @@ class Int8Params(torch.nn.Parameter):
|
|||
return new_param
|
||||
|
||||
|
||||
def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||
weight = state_dict.get(f"{prefix}weight")
|
||||
if weight is None:
|
||||
# if the state dict has no weights for this layer (e.g., LoRA finetuning), do nothing
|
||||
return
|
||||
weight_format = state_dict.pop(f"{prefix}weight_format", "row")
|
||||
|
||||
if weight_format != "row":
|
||||
tile_indices = get_tile_inds(weight_format, weight.device)
|
||||
state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
|
||||
|
||||
|
||||
class Linear8bitLt(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
|
||||
memory_efficient_backward=False, threshold=0.0, index=None):
|
||||
|
@ -225,52 +333,55 @@ class Linear8bitLt(nn.Linear):
|
|||
self.state.use_pool = True
|
||||
|
||||
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights)
|
||||
self._register_load_state_dict_pre_hook(maybe_rearrange_weight)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None:
|
||||
# reorder weight layout back from ampere/turing to row
|
||||
reorder_layout = True
|
||||
weight_clone = self.weight.data.clone()
|
||||
else:
|
||||
reorder_layout = False
|
||||
super()._save_to_state_dict(destination, prefix, keep_vars)
|
||||
|
||||
try:
|
||||
if reorder_layout:
|
||||
self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)
|
||||
# we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
|
||||
scb_name = "SCB"
|
||||
|
||||
super()._save_to_state_dict(destination, prefix, keep_vars)
|
||||
# case 1: .cuda was called, SCB is in self.weight
|
||||
param_from_weight = getattr(self.weight, scb_name)
|
||||
# case 2: self.init_8bit_state was called, SCB is in self.state
|
||||
param_from_state = getattr(self.state, scb_name)
|
||||
# case 3: SCB is in self.state, weight layout reordered after first forward()
|
||||
layout_reordered = self.state.CxB is not None
|
||||
|
||||
# we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data
|
||||
weight_name = "SCB"
|
||||
key_name = prefix + f"{scb_name}"
|
||||
format_name = prefix + "weight_format"
|
||||
|
||||
# case 1: .cuda was called, SCB is in self.weight
|
||||
param_from_weight = getattr(self.weight, weight_name)
|
||||
# case 2: self.init_8bit_state was called, SCB is in self.state
|
||||
param_from_state = getattr(self.state, weight_name)
|
||||
|
||||
key_name = prefix + f"{weight_name}"
|
||||
if not self.state.has_fp16_weights:
|
||||
if param_from_weight is not None:
|
||||
destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
|
||||
elif not self.state.has_fp16_weights and param_from_state is not None:
|
||||
destination[format_name] = "row"
|
||||
elif param_from_state is not None and not layout_reordered:
|
||||
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
|
||||
finally:
|
||||
if reorder_layout:
|
||||
self.weight.data = weight_clone
|
||||
destination[format_name] = "row"
|
||||
elif param_from_state is not None:
|
||||
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
|
||||
destination[format_name] = self.state.formatB
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs):
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||
error_msgs)
|
||||
for key in unexpected_keys:
|
||||
unexpected_copy = list(unexpected_keys)
|
||||
|
||||
for key in unexpected_copy:
|
||||
input_name = key[len(prefix):]
|
||||
if input_name == "SCB":
|
||||
if self.weight.SCB is None:
|
||||
# buffers not yet initialized, can't call them directly without
|
||||
# buffers not yet initialized, can't access them directly without quantizing first
|
||||
raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is "
|
||||
"not supported. Please call module.cuda() before module.load_state_dict()")
|
||||
|
||||
input_param = state_dict[key]
|
||||
self.weight.SCB.copy_(input_param)
|
||||
|
||||
if self.state.SCB is not None:
|
||||
self.state.SCB = self.weight.SCB
|
||||
|
||||
unexpected_keys.remove(key)
|
||||
|
||||
def init_8bit_state(self):
|
||||
|
@ -289,6 +400,7 @@ class Linear8bitLt(nn.Linear):
|
|||
self.bias.data = self.bias.data.to(x.dtype)
|
||||
|
||||
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
|
||||
|
||||
if not self.state.has_fp16_weights:
|
||||
if self.state.CB is not None and self.state.CxB is not None:
|
||||
# we converted 8-bit row major to turing/ampere format in the first inference pass
|
||||
|
@ -296,3 +408,71 @@ class Linear8bitLt(nn.Linear):
|
|||
del self.state.CB
|
||||
self.weight.data = self.state.CxB
|
||||
return out
|
||||
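
For readers skimming the diff: the hooks above mean a quantized Linear8bitLt can be checkpointed and restored like a regular module, provided the destination module is quantized (moved to the GPU) before load_state_dict is called. A minimal sketch, assuming a CUDA device; the layer sizes and file name are illustrative only.

import torch
import bitsandbytes as bnb

# has_fp16_weights=False keeps only the int8 weight; .cuda() triggers quantization,
# so SCB lives on the weight and is written to the state dict as an extra key.
layer = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False, threshold=6.0).cuda()
torch.save(layer.state_dict(), "linear8bit.pt")  # illustrative path

restored = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False, threshold=6.0)
restored = restored.cuda()  # quantize first, otherwise loading raises the RuntimeError above
restored.load_state_dict(torch.load("linear8bit.pt"))
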

class OutlierAwareLinear(nn.Linear):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__(input_features, output_features, bias)
        self.outlier_dim = None
        self.is_quantized = False

    def forward_with_outliers(self, x, outlier_idx):
        raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function')

    def quantize_weight(self, w, outlier_idx):
        raise NotImplementedError('Please override the `quantize_weight(self, w, outlier_idx)` function')

    def forward(self, x):
        if self.outlier_dim is None:
            tracer = OutlierTracer.get_instance()
            if not tracer.is_initialized():
                print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer')
            outlier_idx = tracer.get_outliers(self.weight)
            #print(outlier_idx, tracer.get_hvalue(self.weight))
            self.outlier_dim = outlier_idx

        if not self.is_quantized:
            w = self.quantize_weight(self.weight, self.outlier_dim)
            self.weight.data.copy_(w)
            self.is_quantized = True

class SwitchBackLinearBnb(nn.Linear):
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        super().__init__(
            input_features, output_features, bias
        )
        self.state = bnb.MatmulLtState()
        self.index = index

        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(
            self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
        )

    def init_8bit_state(self):
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x):
        self.state.is_training = self.training

        if self.weight.CB is not None:
            self.init_8bit_state()

        out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias

258
bitsandbytes/nn/triton_based_modules.py
Normal file
|
@ -0,0 +1,258 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
|
||||
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
|
||||
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
|
||||
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
|
||||
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
|
||||
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
|
||||
|
||||
|
||||
class _switchback_global(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, X_3D, W, bias):
|
||||
# reshape input to [N * L, D]
|
||||
X = X_3D.view(-1, X_3D.size(-1))
|
||||
|
||||
# rowwise quantize for X, global quantize for W
|
||||
X_int8, state_X = quantize_rowwise(X)
|
||||
W_int8, state_W = quantize_global(W)
|
||||
|
||||
# save for backward.
|
||||
ctx.save_for_backward = X, W
|
||||
|
||||
# matmul, fused dequant and add bias
|
||||
# call "mixed" because we are mixing rowwise quantized and global quantized
|
||||
return int8_matmul_mixed_dequanitze(
|
||||
X_int8, W_int8.t(), state_X, state_W, bias
|
||||
).view(*X_3D.size()[:-1], -1)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, G_3D):
|
||||
# reshape input to [N_out * L, D]
|
||||
G = G_3D.reshape(-1, G_3D.size(-1))
|
||||
|
||||
grad_X = grad_W = grad_bias = None
|
||||
|
||||
X, W = ctx.save_for_backward
|
||||
if ctx.needs_input_grad[0]:
|
||||
# rowwise quantize for G, global quantize for W
|
||||
# for W, we also fuse the transpose operation because only A @ B^T is supported
|
||||
# so we transpose once then call .t() in the matmul
|
||||
G_int8, state_G = quantize_rowwise(G)
|
||||
W_int8, state_W = quantize_global_transpose(W)
|
||||
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
|
||||
*G_3D.size()[:-1], -1
|
||||
)
|
||||
if ctx.needs_input_grad[1]:
|
||||
# backward pass uses standard weight grad
|
||||
grad_W = torch.matmul(G.t(), X.to(G.dtype))
|
||||
if ctx.needs_input_grad[2]:
|
||||
grad_bias = G.sum(dim=0)
|
||||
|
||||
return grad_X, grad_W, grad_bias
|
||||
|
||||
class _switchback_vectorrize(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, X_3D, W, bias):
|
||||
# reshape input to [N * L, D]
|
||||
X = X_3D.view(-1, X_3D.size(-1))
|
||||
|
||||
ctx.save_for_backward = X, W
|
||||
# rowwise quantize for X
|
||||
# columnwise quantize for W (first rowwise, transpose later)
|
||||
X_int8, state_X = quantize_rowwise(X)
|
||||
W_int8, state_W = quantize_rowwise(W)
|
||||
|
||||
# matmul, fused dequant and add bias
|
||||
# call kernel which expects rowwise quantized X and W
|
||||
return int8_matmul_rowwise_dequantize(
|
||||
X_int8, W_int8.t(), state_X, state_W, bias
|
||||
).view(*X_3D.size()[:-1], -1)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, G_3D):
|
||||
X, W = ctx.save_for_backward
|
||||
|
||||
G = G_3D.reshape(-1, G_3D.size(-1))
|
||||
|
||||
grad_X = grad_W = grad_bias = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
# rowwise quantize for G, columnwise quantize for W and fused transpose
|
||||
# we call .t() for weight later because only A @ B^T is supported
|
||||
G_int8, state_G = quantize_rowwise(G)
|
||||
W_int8, state_W = quantize_columnwise_and_transpose(W)
|
||||
grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
|
||||
*G_3D.size()[:-1], -1
|
||||
)
|
||||
if ctx.needs_input_grad[1]:
|
||||
# backward pass uses standard weight grad
|
||||
grad_W = torch.matmul(G.t(), X.to(G.dtype))
|
||||
if ctx.needs_input_grad[2]:
|
||||
grad_bias = G.sum(dim=0)
|
||||
|
||||
return grad_X, grad_W, grad_bias
|
||||
|
||||
class _switchback_global_mem_efficient(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, X_3D, W, bias):
|
||||
# reshape input to [N * L, D]
|
||||
X = X_3D.view(-1, X_3D.size(-1))
|
||||
X_3D_sz = X_3D.size()
|
||||
|
||||
# rowwise quantize for X, global quantize for W
|
||||
X_int8, state_X = quantize_rowwise(X)
|
||||
del X
|
||||
W_int8, state_W = quantize_global(W)
|
||||
|
||||
# save for backward.
|
||||
ctx.save_for_backward = X_int8, state_X, W_int8, state_W
|
||||
|
||||
# matmul, fused dequant and add bias
|
||||
# call "mixed" because we are mixing rowwise quantized and global quantized
|
||||
return int8_matmul_mixed_dequanitze(
|
||||
X_int8, W_int8.t(), state_X, state_W, bias
|
||||
).view(*X_3D_sz[:-1], -1)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, G_3D):
|
||||
# reshape input to [N_out * L, D]
|
||||
G = G_3D.reshape(-1, G_3D.size(-1))
|
||||
G_3D_sz = G_3D.size()
|
||||
|
||||
grad_X = grad_W = grad_bias = None
|
||||
|
||||
X_int8, state_X, W_int8, state_W = ctx.save_for_backward
|
||||
if ctx.needs_input_grad[1]:
|
||||
real_X = dequantize_rowwise(X_int8, state_X)
|
||||
del X_int8
|
||||
grad_W = torch.matmul(G.t(), real_X.to(G.dtype))
|
||||
del real_X
|
||||
if ctx.needs_input_grad[2]:
|
||||
grad_bias = G.sum(dim=0)
|
||||
if ctx.needs_input_grad[0]:
|
||||
G_int8, state_G = quantize_rowwise(G)
|
||||
del G
|
||||
W_int8 = W_int8.t().contiguous()
|
||||
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
|
||||
*G_3D_sz[:-1], -1
|
||||
)
|
||||
|
||||
return grad_X, grad_W, grad_bias
|
||||
|
||||
class SwitchBackLinear(nn.Linear):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
vector_wise_quantization: bool = False,
|
||||
mem_efficient : bool = False,
|
||||
):
|
||||
super().__init__(in_features, out_features, bias, device, dtype)
|
||||
|
||||
if not is_triton_available:
|
||||
raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
|
||||
Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')
|
||||
|
||||
# By default, we use the global quantization.
|
||||
self.vector_wise_quantization = vector_wise_quantization
|
||||
if self.vector_wise_quantization:
|
||||
self._fn = _switchback_vectorrize
|
||||
if mem_efficient:
|
||||
print('mem efficient is not supported for vector-wise quantization.')
|
||||
exit(1)
|
||||
else:
|
||||
if mem_efficient:
|
||||
self._fn = _switchback_global_mem_efficient
|
||||
else:
|
||||
self._fn = _switchback_global
|
||||
|
||||
def prepare_for_eval(self):
|
||||
# If we just want to do eval, we can pre-quantize the weights instead of doing it on the forward pass.
|
||||
# Note this is experimental and not tested thoroughly.
|
||||
# Note this needs to be explicitly called with something like
|
||||
# def cond_prepare(m):
|
||||
# if hasattr(m, "prepare_for_eval"):
|
||||
# m.prepare_for_eval()
|
||||
# model.apply(cond_prepare)
|
||||
print('=> preparing for eval.')
|
||||
if self.vector_wise_quantization:
|
||||
W_int8, state_W = quantize_rowwise(self.weight)
|
||||
else:
|
||||
W_int8, state_W = quantize_global(self.weight)
|
||||
|
||||
self.register_buffer("W_int8", W_int8)
|
||||
self.register_buffer("state_W", state_W)
|
||||
|
||||
del self.weight
|
||||
|
||||
def forward(self, x):
|
||||
if self.training:
|
||||
return self._fn.apply(x, self.weight, self.bias)
|
||||
else:
|
||||
# If it hasn't been "prepared for eval", run the standard forward pass.
|
||||
if not hasattr(self, "W_int8"):
|
||||
return self._fn.apply(x, self.weight, self.bias)
|
||||
|
||||
# Otherwise, use pre-computed weights.
|
||||
X = x.view(-1, x.size(-1))
|
||||
X_int8, state_X = quantize_rowwise(X)
|
||||
|
||||
if self.vector_wise_quantization:
|
||||
return int8_matmul_rowwise_dequantize(
|
||||
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
|
||||
).view(*x.size()[:-1], -1)
|
||||
else:
|
||||
return int8_matmul_mixed_dequanitze(
|
||||
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
|
||||
).view(*x.size()[:-1], -1)
|
||||
|
||||
SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)
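
As a usage sketch (not part of the diff): the SwitchBack layers are meant as drop-in replacements for nn.Linear. The snippet assumes a CUDA device, a working triton install, and fp16 activations; sizes are illustrative.

import torch
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear

layer = SwitchBackLinear(1024, 4096).cuda().half()       # global quantization by default
x = torch.randn(4, 16, 1024, dtype=torch.half, device="cuda", requires_grad=True)
y = layer(x)                                             # int8 matmul + fused dequantize
y.float().sum().backward()                               # gradients also run through the int8 kernels

# optional: quantize the weight once before inference-only use
layer.eval()
layer.prepare_for_eval()
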
|
||||
|
||||
# This is just the standard linear function.
|
||||
class StandardLinearFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, bias=None):
|
||||
X = input.view(-1, input.size(-1))
|
||||
|
||||
ctx.save_for_backward(X, weight, bias)
|
||||
output = input.matmul(weight.t())
|
||||
if bias is not None:
|
||||
output += bias.unsqueeze(0).expand_as(output)
|
||||
return output.view(*input.size()[:-1], -1)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output_3D):
|
||||
input, weight, bias = ctx.saved_tensors
|
||||
|
||||
grad_output = grad_output_3D.reshape(-1, grad_output_3D.size(-1))
|
||||
|
||||
grad_input = grad_weight = grad_bias = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad_input = grad_output.matmul(weight.to(grad_output.dtype)).view(*grad_output_3D.size()[:-1], -1)
|
||||
if ctx.needs_input_grad[1]:
|
||||
grad_weight = grad_output.t().matmul(input.to(grad_output.dtype))
|
||||
if bias is not None and ctx.needs_input_grad[2]:
|
||||
grad_bias = grad_output.sum(0)
|
||||
|
||||
return grad_input, grad_weight, grad_bias
|
||||
|
||||
class StandardLinear(nn.Linear):
|
||||
|
||||
def forward(self, x):
|
||||
return StandardLinearFunction.apply(x, self.weight, self.bias)
|
|
@ -6,11 +6,11 @@
|
|||
from bitsandbytes.cextension import COMPILED_WITH_CUDA
|
||||
|
||||
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
|
||||
from .adam import Adam, Adam8bit, Adam32bit
|
||||
from .adamw import AdamW, AdamW8bit, AdamW32bit
|
||||
from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit
|
||||
from .adamw import AdamW, AdamW8bit, AdamW32bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit
|
||||
from .lamb import LAMB, LAMB8bit, LAMB32bit
|
||||
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
|
||||
from .optimizer import GlobalOptimManager
|
||||
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
|
||||
from .lion import Lion, Lion8bit, Lion32bit
|
||||
from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit
|
||||
from .sgd import SGD, SGD8bit, SGD32bit
|
||||
|
|
|
@ -14,92 +14,34 @@ from bitsandbytes.optim.optimizer import Optimizer2State
|
|||
|
||||
|
||||
class Adam(Optimizer2State):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-3,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
weight_decay=0,
|
||||
amsgrad=False,
|
||||
optim_bits=32,
|
||||
args=None,
|
||||
min_8bit_size=4096,
|
||||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"adam",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
eps,
|
||||
weight_decay,
|
||||
optim_bits,
|
||||
args,
|
||||
min_8bit_size,
|
||||
percentile_clipping,
|
||||
block_wise,
|
||||
)
|
||||
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
|
||||
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
|
||||
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
|
||||
|
||||
class Adam8bit(Optimizer2State):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-3,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
weight_decay=0,
|
||||
amsgrad=False,
|
||||
args=None,
|
||||
min_8bit_size=4096,
|
||||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"adam",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
eps,
|
||||
weight_decay,
|
||||
8,
|
||||
args,
|
||||
min_8bit_size,
|
||||
percentile_clipping,
|
||||
block_wise,
|
||||
)
|
||||
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
|
||||
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
|
||||
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
|
||||
|
||||
class Adam32bit(Optimizer2State):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-3,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
weight_decay=0,
|
||||
amsgrad=False,
|
||||
args=None,
|
||||
min_8bit_size=4096,
|
||||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"adam",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
eps,
|
||||
weight_decay,
|
||||
32,
|
||||
args,
|
||||
min_8bit_size,
|
||||
percentile_clipping,
|
||||
block_wise,
|
||||
)
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
|
||||
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
|
||||
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)


class PagedAdam(Optimizer2State):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
                 args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
        super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)

class PagedAdam8bit(Optimizer2State):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
                 args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
        super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)

class PagedAdam32bit(Optimizer2State):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
                 args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
        super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
|
||||
|
||||
class AnalysisAdam(torch.optim.Optimizer):
|
||||
"""Adam that performs 8-bit vs 32-bit error analysis.
|
||||
|
|
|
@ -5,89 +5,35 @@
|
|||
from bitsandbytes.optim.optimizer import Optimizer2State
|
||||
|
||||
|
||||
class AdamW(Optimizer2State):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-3,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
weight_decay=1e-2,
|
||||
amsgrad=False,
|
||||
optim_bits=32,
|
||||
args=None,
|
||||
min_8bit_size=4096,
|
||||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"adam",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
eps,
|
||||
weight_decay,
|
||||
optim_bits,
|
||||
args,
|
||||
min_8bit_size,
|
||||
percentile_clipping,
|
||||
block_wise,
|
||||
)
|
||||
|
||||
class AdamW(Optimizer2State):
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
|
||||
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
|
||||
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )
|
||||
|
||||
class AdamW8bit(Optimizer2State):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-3,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
weight_decay=1e-2,
|
||||
amsgrad=False,
|
||||
args=None,
|
||||
min_8bit_size=4096,
|
||||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"adam",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
eps,
|
||||
weight_decay,
|
||||
8,
|
||||
args,
|
||||
min_8bit_size,
|
||||
percentile_clipping,
|
||||
block_wise,
|
||||
)
|
||||
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
|
||||
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
|
||||
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )
|
||||
|
||||
class AdamW32bit(Optimizer2State):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-3,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
weight_decay=1e-2,
|
||||
amsgrad=False,
|
||||
args=None,
|
||||
min_8bit_size=4096,
|
||||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"adam",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
eps,
|
||||
weight_decay,
|
||||
32,
|
||||
args,
|
||||
min_8bit_size,
|
||||
percentile_clipping,
|
||||
block_wise,
|
||||
)
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
|
||||
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
|
||||
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
|
||||
|
||||
|
||||
class PagedAdamW(Optimizer2State):
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
|
||||
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
|
||||
|
||||
class PagedAdamW8bit(Optimizer2State):
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
|
||||
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
|
||||
|
||||
class PagedAdamW32bit(Optimizer2State):
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
|
||||
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
|
||||
|
||||
|
|
|
@ -4,84 +4,27 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
from bitsandbytes.optim.optimizer import Optimizer1State
|
||||
|
||||
|
||||
class Lion(Optimizer1State):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-4,
|
||||
betas=(0.9, 0.99),
|
||||
weight_decay=0,
|
||||
optim_bits=32,
|
||||
args=None,
|
||||
min_8bit_size=4096,
|
||||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
0.,
|
||||
weight_decay,
|
||||
optim_bits,
|
||||
args,
|
||||
min_8bit_size,
|
||||
percentile_clipping,
|
||||
block_wise,
|
||||
)
|
||||
|
||||
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
|
||||
super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
|
||||
|
||||
class Lion8bit(Optimizer1State):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-4,
|
||||
betas=(0.9, 0.99),
|
||||
weight_decay=0,
|
||||
args=None,
|
||||
min_8bit_size=4096,
|
||||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
0.,
|
||||
weight_decay,
|
||||
8,
|
||||
args,
|
||||
min_8bit_size,
|
||||
percentile_clipping,
|
||||
block_wise,
|
||||
)
|
||||
|
||||
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
|
||||
super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
|
||||
|
||||
class Lion32bit(Optimizer1State):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-4,
|
||||
betas=(0.9, 0.99),
|
||||
weight_decay=0,
|
||||
args=None,
|
||||
min_8bit_size=4096,
|
||||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
0.,
|
||||
weight_decay,
|
||||
32,
|
||||
args,
|
||||
min_8bit_size,
|
||||
percentile_clipping,
|
||||
block_wise,
|
||||
)
|
||||
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
|
||||
super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
|
||||
|
||||
|
||||
class PagedLion(Optimizer1State):
|
||||
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||
super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
|
||||
|
||||
class PagedLion8bit(Optimizer1State):
|
||||
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||
super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
|
||||
|
||||
class PagedLion32bit(Optimizer1State):
|
||||
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||
super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
|
||||
|
|
|
@ -92,10 +92,12 @@ class GlobalOptimManager:
|
|||
|
||||
|
||||
class Optimizer8bit(torch.optim.Optimizer):
|
||||
def __init__(self, params, defaults, optim_bits=32):
|
||||
def __init__(self, params, defaults, optim_bits=32, is_paged=False):
|
||||
super().__init__(params, defaults)
|
||||
self.initialized = False
|
||||
self.name2qmap = {}
|
||||
self.is_paged = is_paged
|
||||
self.page_mng = F.GlobalPageManager.get_instance()
|
||||
|
||||
self.mng = GlobalOptimManager.get_instance()
|
||||
self.non_castable_tensor_keys = {
|
||||
|
@ -207,7 +209,9 @@ class Optimizer8bit(torch.optim.Optimizer):
|
|||
values = self.state[p]
|
||||
for k, v in values.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
self.state[p][k] = v.to(p.device)
|
||||
is_paged = getattr(v, 'is_paged', False)
|
||||
if not is_paged:
|
||||
self.state[p][k] = v.to(p.device)
|
||||
|
||||
def check_overrides(self):
|
||||
for module, attr, config in self.mng.module_weight_config_triple:
|
||||
|
@ -252,6 +256,7 @@ class Optimizer8bit(torch.optim.Optimizer):
|
|||
self.to_gpu() # needed for fairseq pure fp16 training
|
||||
self.initialized = True
|
||||
|
||||
#if self.is_paged: self.page_mng.prefetch_all()
|
||||
for gindex, group in enumerate(self.param_groups):
|
||||
for pindex, p in enumerate(group["params"]):
|
||||
if p.grad is None:
|
||||
|
@ -260,7 +265,14 @@ class Optimizer8bit(torch.optim.Optimizer):
|
|||
if len(state) == 0:
|
||||
self.init_state(group, p, gindex, pindex)
|
||||
|
||||
self.prefetch_state(p)
|
||||
self.update_step(group, p, gindex, pindex)
|
||||
torch.cuda.synchronize()
|
||||
if self.is_paged:
|
||||
# all paged operation are asynchronous, we need
|
||||
# to sync to make sure all tensors are in the right state
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
return loss
|
||||
|
||||
|
@ -289,6 +301,26 @@ class Optimizer8bit(torch.optim.Optimizer):
|
|||
"The update_step method needs to be overridden"
|
||||
)
|
||||
|
||||
def get_state_buffer(self, p, dtype=torch.float32):
|
||||
if not self.is_paged or p.numel() < 1e5:
|
||||
return torch.zeros_like(p, dtype=dtype, device=p.device)
|
||||
else:
|
||||
# > 1 MB
|
||||
buff = F.get_paged(*p.shape, dtype=dtype, device=p.device)
|
||||
F.fill(buff, 0)
|
||||
self.page_mng.paged_tensors.append(buff)
|
||||
return buff
|
||||
|
||||
def prefetch_state(self, p):
|
||||
if self.is_paged:
|
||||
state = self.state[p]
|
||||
s1 = state['state1']
|
||||
is_paged = getattr(s1, 'is_paged', False)
|
||||
if is_paged:
|
||||
F.prefetch_tensor(state['state1'])
|
||||
if 'state2' in state:
|
||||
F.prefetch_tensor(state['state2'])
|
||||
|
||||
|
||||
class Optimizer2State(Optimizer8bit):
|
||||
def __init__(
|
||||
|
@ -306,6 +338,7 @@ class Optimizer2State(Optimizer8bit):
|
|||
block_wise=True,
|
||||
max_unorm=0.0,
|
||||
skip_zeros=False,
|
||||
is_paged=False
|
||||
):
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError(f"Invalid learning rate: {lr}")
|
||||
|
@ -325,7 +358,7 @@ class Optimizer2State(Optimizer8bit):
|
|||
f"Invalid weight_decay value: {weight_decay}"
|
||||
)
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
super().__init__(params, defaults, optim_bits)
|
||||
super().__init__(params, defaults, optim_bits, is_paged)
|
||||
|
||||
if args is None:
|
||||
args = {}
|
||||
|
@ -365,18 +398,8 @@ class Optimizer2State(Optimizer8bit):
|
|||
if dtype == torch.float32 or (
|
||||
dtype == torch.uint8 and p.numel() < 4096
|
||||
):
|
||||
state["state1"] = torch.zeros_like(
|
||||
p,
|
||||
memory_format=torch.preserve_format,
|
||||
dtype=torch.float32,
|
||||
device=p.device,
|
||||
)
|
||||
state["state2"] = torch.zeros_like(
|
||||
p,
|
||||
memory_format=torch.preserve_format,
|
||||
dtype=torch.float32,
|
||||
device=p.device,
|
||||
)
|
||||
state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
|
||||
state["state2"] = self.get_state_buffer(p, dtype=torch.float32)
|
||||
elif dtype == torch.uint8:
|
||||
if state["step"] == 0:
|
||||
if "dynamic" not in self.name2qmap:
|
||||
|
@ -388,20 +411,10 @@ class Optimizer2State(Optimizer8bit):
|
|||
p.device
|
||||
)
|
||||
|
||||
state["state1"] = torch.zeros_like(
|
||||
p,
|
||||
memory_format=torch.preserve_format,
|
||||
dtype=torch.uint8,
|
||||
device=p.device,
|
||||
)
|
||||
state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
|
||||
state["qmap1"] = self.name2qmap["dynamic"]
|
||||
|
||||
state["state2"] = torch.zeros_like(
|
||||
p,
|
||||
memory_format=torch.preserve_format,
|
||||
dtype=torch.uint8,
|
||||
device=p.device,
|
||||
)
|
||||
state["state2"] = self.get_state_buffer(p, dtype=torch.uint8)
|
||||
state["qmap2"] = self.name2qmap["udynamic"]
|
||||
|
||||
if config["block_wise"]:
|
||||
|
@ -538,6 +551,7 @@ class Optimizer1State(Optimizer8bit):
|
|||
block_wise=True,
|
||||
max_unorm=0.0,
|
||||
skip_zeros=False,
|
||||
is_paged=False
|
||||
):
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError(f"Invalid learning rate: {lr}")
|
||||
|
@ -553,7 +567,7 @@ class Optimizer1State(Optimizer8bit):
|
|||
f"Invalid weight_decay value: {weight_decay}"
|
||||
)
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
|
||||
super().__init__(params, defaults, optim_bits)
|
||||
super().__init__(params, defaults, optim_bits, is_paged)
|
||||
|
||||
if args is None:
|
||||
args = {}
|
||||
|
@ -593,12 +607,7 @@ class Optimizer1State(Optimizer8bit):
|
|||
if dtype == torch.float32 or (
|
||||
dtype == torch.uint8 and p.numel() < 4096
|
||||
):
|
||||
state["state1"] = torch.zeros_like(
|
||||
p,
|
||||
memory_format=torch.preserve_format,
|
||||
dtype=torch.float32,
|
||||
device=p.device,
|
||||
)
|
||||
state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
|
||||
elif dtype == torch.uint8:
|
||||
if state["step"] == 0:
|
||||
if "dynamic" not in self.name2qmap:
|
||||
|
@ -607,12 +616,7 @@ class Optimizer1State(Optimizer8bit):
|
|||
p.device
|
||||
)
|
||||
|
||||
state["state1"] = torch.zeros_like(
|
||||
p,
|
||||
memory_format=torch.preserve_format,
|
||||
dtype=torch.uint8,
|
||||
device=p.device,
|
||||
)
|
||||
state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
|
||||
state["qmap1"] = self.name2qmap["dynamic"]
|
||||
|
||||
if config["block_wise"]:
|
||||
|
|
6
bitsandbytes/research/__init__.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
from . import nn
|
||||
from .autograd._functions import (
|
||||
switchback_bnb,
|
||||
matmul_fp8_global,
|
||||
matmul_fp8_mixed,
|
||||
)
|
0
bitsandbytes/research/autograd/__init__.py
Normal file
411
bitsandbytes/research/autograd/_functions.py
Normal file
|
@ -0,0 +1,411 @@
|
|||
import operator
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from functools import reduce # Required in Python 3
|
||||
|
||||
import torch
|
||||
|
||||
import bitsandbytes.functional as F
|
||||
|
||||
from bitsandbytes.autograd._functions import MatmulLtState, GlobalOutlierPooler
|
||||
|
||||
|
||||
# math.prod not compatible with python < 3.8
|
||||
def prod(iterable):
|
||||
return reduce(operator.mul, iterable, 1)
|
||||
|
||||
tensor = torch.Tensor
|
||||
|
||||
class MatMulFP8Mixed(torch.autograd.Function):
|
||||
# forward is the same, but we added the fallback for pre-turing GPUs
|
||||
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
|
||||
# default to pytorch behavior if inputs are empty
|
||||
ctx.is_empty = False
|
||||
if prod(A.shape) == 0:
|
||||
ctx.is_empty = True
|
||||
ctx.A = A
|
||||
ctx.B = B
|
||||
|
||||
B_shape = B.shape
|
||||
if A.shape[-1] == B_shape[0]:
|
||||
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
|
||||
else:
|
||||
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
|
||||
|
||||
# 1. Dequantize
|
||||
# 2. MatmulnN
|
||||
cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
|
||||
fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)
|
||||
|
||||
cB, state = F.quantize(B.float(), code=fw_code)
|
||||
fp8B = F.dequantize(cB, state).to(B.dtype)
|
||||
|
||||
output = torch.matmul(fp8A, fp8B)
|
||||
|
||||
# output is half
|
||||
|
||||
# 3. Save state
|
||||
ctx.fw_code = fw_code
|
||||
ctx.bw_code = bw_code
|
||||
ctx.bsz = bsz
|
||||
ctx.bsz2 = bsz2
|
||||
ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
|
||||
|
||||
if any(ctx.needs_input_grad[:2]):
|
||||
# NOTE: we send back A, and re-quant.
|
||||
ctx.tensors = (A, fp8B)
|
||||
else:
|
||||
ctx.tensors = (None, None)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.is_empty:
|
||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None
|
||||
|
||||
req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
|
||||
A, B = ctx.tensors
|
||||
|
||||
grad_A, grad_B = None, None
|
||||
|
||||
# TODO: Fix blocksize to be output_dim
|
||||
cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
|
||||
fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)
|
||||
|
||||
# cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
|
||||
# fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
|
||||
|
||||
# grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
|
||||
# fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
|
||||
# fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
|
||||
# fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
|
||||
|
||||
# not supported by PyTorch. TODO: create work-around
|
||||
if req_gradA:
|
||||
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
|
||||
|
||||
if req_gradB:
|
||||
if len(A.shape) == 3:
|
||||
At = A.transpose(2, 1).contiguous()
|
||||
else:
|
||||
At = A.transpose(1, 0).contiguous()
|
||||
# cA, state = F.quantize(At.float(), code=ctx.fw_code)
|
||||
# fp8At = F.dequantize(cA, state).to(A.dtype)
|
||||
grad_B = torch.matmul(At.to(grad_output.dtype), grad_output).to(B.dtype)
|
||||
|
||||
return grad_A, grad_B, None, None, None, None, None
|
||||
|
||||
|
||||
class MatMulFP8Global(torch.autograd.Function):
|
||||
# forward is the same, but we added the fallback for pre-turing GPUs
|
||||
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
|
||||
# default to pytorch behavior if inputs are empty
|
||||
ctx.is_empty = False
|
||||
if prod(A.shape) == 0:
|
||||
ctx.is_empty = True
|
||||
ctx.A = A
|
||||
ctx.B = B
|
||||
|
||||
B_shape = B.shape
|
||||
if A.shape[-1] == B_shape[0]:
|
||||
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
|
||||
else:
|
||||
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
|
||||
|
||||
# 1. Dequantize
|
||||
# 2. MatmulnN
|
||||
cA, state = F.quantize(A.float(), code=fw_code)
|
||||
fp8A = F.dequantize(cA, state).to(A.dtype)
|
||||
|
||||
cB, state = F.quantize(B.float(), code=fw_code)
|
||||
fp8B = F.dequantize(cB, state).to(B.dtype)
|
||||
|
||||
output = torch.matmul(fp8A, fp8B)
|
||||
|
||||
# output is half
|
||||
|
||||
# 3. Save state
|
||||
ctx.fw_code = fw_code
|
||||
ctx.bw_code = bw_code
|
||||
ctx.bsz = bsz
|
||||
ctx.bsz2 = bsz2
|
||||
ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
|
||||
|
||||
if any(ctx.needs_input_grad[:2]):
|
||||
# NOTE: we send back A, and re-quant.
|
||||
ctx.tensors = (A, fp8B)
|
||||
else:
|
||||
ctx.tensors = (None, None)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.is_empty:
|
||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None
|
||||
|
||||
req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
|
||||
A, B = ctx.tensors
|
||||
|
||||
grad_A, grad_B = None, None
|
||||
|
||||
# TODO: Fix blocksize to be output_dim
|
||||
cgrad_out, state = F.quantize(grad_output.float(), code=ctx.bw_code)
|
||||
fp8out = F.dequantize(cgrad_out, state).to(grad_output.dtype)
|
||||
|
||||
# cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
|
||||
# fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
|
||||
|
||||
# grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
|
||||
# fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
|
||||
# fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
|
||||
# fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
|
||||
|
||||
# not supported by PyTorch. TODO: create work-around
|
||||
if req_gradA:
|
||||
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
|
||||
|
||||
if req_gradB:
|
||||
if len(A.shape) == 3:
|
||||
At = A.transpose(2, 1).contiguous()
|
||||
else:
|
||||
At = A.transpose(1, 0).contiguous()
|
||||
cA, state = F.quantize(At.float(), code=ctx.fw_code)
|
||||
fp8At = F.dequantize(cA, state).to(A.dtype)
|
||||
grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
|
||||
|
||||
return grad_A, grad_B, None, None, None, None, None
|
||||
|
||||
|
||||
class SwitchBackBnb(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
|
||||
# default to pytorch behavior if inputs are empty
|
||||
ctx.is_empty = False
|
||||
if prod(A.shape) == 0:
|
||||
ctx.is_empty = True
|
||||
ctx.A = A
|
||||
ctx.B = B
|
||||
ctx.bias = bias
|
||||
if A.shape[-1] == B.shape[0]:
|
||||
return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device)
|
||||
else:
|
||||
return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device)
|
||||
|
||||
# 1. Quantize A
|
||||
# 2. Quantize B
|
||||
# 3. Matmul
|
||||
# 4. Mixed-precision decomposition matmul
|
||||
# 5. Save state
|
||||
formatB = state.formatB
|
||||
input_shape = A.shape
|
||||
if state.outlier_pool is None:
|
||||
state.outlier_pool = GlobalOutlierPooler.get_instance()
|
||||
|
||||
# Cast A to fp16
|
||||
if A.dtype != torch.float16:
|
||||
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
|
||||
|
||||
# 1. Quantize A
|
||||
if len(A.shape) == 3:
|
||||
A = A.view(-1, A.shape[-1]).contiguous()
|
||||
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
|
||||
A.to(torch.float16), threshold=state.threshold
|
||||
)
|
||||
|
||||
if state.threshold > 0.0 and coo_tensorA is not None:
|
||||
if state.has_fp16_weights:
|
||||
idx = torch.unique(coo_tensorA.colidx).long()
|
||||
CA[:, idx] = 0
|
||||
CAt[:, idx] = 0
|
||||
subA = A[:, idx]
|
||||
state.subB = B[:, idx].t().contiguous()
|
||||
state.idx = idx
|
||||
else:
|
||||
if state.CxB is None:
|
||||
# B is in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
|
||||
# we also need to convert it to the turing/ampere format
|
||||
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
||||
else:
|
||||
#print('A shape', A.shape)
|
||||
if not state.has_fp16_weights and state.CxB is None:
|
||||
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
||||
subA = None
|
||||
|
||||
# 2. Quantize B
|
||||
if state.has_fp16_weights:
|
||||
#print('B shape', B.shape)
|
||||
has_grad = True if (getattr(B, "grad", None) is not None) else False
|
||||
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
|
||||
if is_transposed:
|
||||
B = B.contiguous()
|
||||
|
||||
if (state.is_training and not has_grad) or state.CxB is None:
|
||||
state.reset_grads()
|
||||
(
|
||||
CB,
|
||||
state.CBt,
|
||||
state.SCB,
|
||||
state.SCBt,
|
||||
coo_tensorB,
|
||||
) = F.double_quant(B.to(torch.float16))
|
||||
state.CxB, state.SB = F.transform(CB, to_order=formatB)
|
||||
else:
|
||||
has_grad = False
|
||||
|
||||
if coo_tensorA is not None and not state.has_fp16_weights:
|
||||
# extract outliers
|
||||
|
||||
outlier_idx = torch.unique(coo_tensorA.colidx)
|
||||
state.idx = outlier_idx
|
||||
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
|
||||
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
|
||||
# # do not use pool for 2nd FFN layer
|
||||
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
|
||||
# else:
|
||||
# state.idx = outlier_idx
|
||||
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
|
||||
state.subB = (
|
||||
(outliers * state.SCB.view(-1, 1) / 127.0)
|
||||
.t()
|
||||
.contiguous()
|
||||
.to(A.dtype)
|
||||
)
|
||||
CA[:, state.idx.long()] = 0
|
||||
CAt[:, state.idx.long()] = 0
|
||||
subA = A[:, state.idx.long()]
|
||||
|
||||
shapeB = state.SB[0]
|
||||
|
||||
if len(input_shape) == 3:
|
||||
output_shape = (input_shape[0], input_shape[1], shapeB[0])
|
||||
else:
|
||||
output_shape = (input_shape[0], shapeB[0])
|
||||
|
||||
# 3. Matmul
|
||||
C32A, SA = F.transform(CA, "col32")
|
||||
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
|
||||
# we apply the fused bias here
|
||||
|
||||
if bias is None or bias.dtype == torch.float16:
|
||||
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
|
||||
output = output.to(A.dtype)
|
||||
else: # apply bias separately
|
||||
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
|
||||
output = output.to(A.dtype).add_(bias)
|
||||
|
||||
# 4. Mixed-precision decomposition matmul
|
||||
if coo_tensorA is not None and subA is not None:
|
||||
output += torch.matmul(subA, state.subB)
|
||||
|
||||
# 5. Save state
|
||||
ctx.state = state
|
||||
|
||||
ctx.formatB = formatB
|
||||
ctx.grad_shape = input_shape
|
||||
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
|
||||
|
||||
if any(ctx.needs_input_grad[:2]):
|
||||
ctx.tensors = (CAt, subA, A)
|
||||
ctx.tensor_states = (SCAt, state.idx)
|
||||
else:
|
||||
ctx.tensors = [None, None, None]
|
||||
ctx.tensor_states = (None, None)
|
||||
ctx.save_for_backward(None, None)
|
||||
|
||||
|
||||
clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
|
||||
return clone_func(output.view(output_shape))
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.is_empty:
|
||||
bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
|
||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
|
||||
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
|
||||
CAt, subA, A = ctx.tensors
|
||||
SCAt, idx = ctx.tensor_states
|
||||
formatB = ctx.formatB
|
||||
state = ctx.state
|
||||
grad_A = grad_B = grad_bias = None
|
||||
|
||||
if req_gradBias:
|
||||
# compute grad_bias first before changing grad_output dtype
|
||||
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
|
||||
|
||||
# Cast grad_output to fp16
|
||||
if len(grad_output.shape) == 3:
|
||||
grad_output = grad_output.reshape(
|
||||
-1, grad_output.shape[-1]
|
||||
).contiguous()
|
||||
|
||||
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
|
||||
|
||||
if req_gradB:
|
||||
# print('back A shape', A.shape)
|
||||
# print('grad output t shape', grad_output.t().shape)
|
||||
grad_B = torch.matmul(grad_output.t(), A)
|
||||
|
||||
if req_gradA:
|
||||
if state.CBt is not None:
|
||||
C32grad, Sgrad = F.transform(Cgrad, "col32")
|
||||
if state.CxBt is None:
|
||||
state.CxBt, state.SBt = F.transform(
|
||||
state.CBt, to_order=formatB, transpose=True
|
||||
)
|
||||
# print('back B shape', state.CxBt.shape)
|
||||
# print('back grad shape', C32grad.shape)
|
||||
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
|
||||
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
|
||||
|
||||
elif state.CB is not None:
|
||||
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
|
||||
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
|
||||
else:
|
||||
raise Exception('State must contain either CBt or CB matrix for backward')
|
||||
|
||||
return grad_A, grad_B, None, grad_bias, None
|
||||
|
||||
def get_block_sizes(input_matrix, weight_matrix):
|
||||
input_features = input_matrix.shape[-1]
|
||||
output_features = (weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1])
|
||||
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
||||
bsz, bsz2 = 1024, 1024
|
||||
for i, k in enumerate(array):
|
||||
if input_features > array[i + 1]:
|
||||
bsz = k
|
||||
break
|
||||
for i, k in enumerate(array):
|
||||
if output_features > array[i + 1]:
|
||||
bsz2 = k
|
||||
break
|
||||
|
||||
return bsz, bsz2
|
||||
|
||||
def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
    if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
    return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)

def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
    if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
    return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
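
A hedged sketch of calling these wrappers directly; the fw/bw quantization maps mirror what the LinearFP8 layers further down create lazily. A CUDA device is assumed and the shapes are illustrative.

import torch
import bitsandbytes as bnb

x = torch.randn(4, 16, 1024, dtype=torch.half, device="cuda", requires_grad=True)
w = torch.randn(4096, 1024, dtype=torch.half, device="cuda", requires_grad=True)

fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)  # e4m3-style forward map
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)  # e5m2-style backward map

out = bnb.research.matmul_fp8_mixed(x, w.t(), fw_code=fw_code, bw_code=bw_code)
out.float().sum().backward()   # bsz/bsz2 default to get_block_sizes(x, w.t()) when left at -1
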

def switchback_bnb(
    A: tensor,
    B: tensor,
    out: tensor = None,
    state: MatmulLtState = None,
    threshold=0.0,
    bias=None
):
    state = state or MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    return SwitchBackBnb.apply(A, B, out, bias, state)
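
For completeness, a hedged example of this functional entry point (SwitchBackLinearBnb in bitsandbytes/nn wraps the same call); it assumes an int8-capable GPU and fp16 inputs, and the shapes are illustrative.

import torch
import bitsandbytes as bnb

A = torch.randn(8, 1024, dtype=torch.half, device="cuda")
W = torch.randn(4096, 1024, dtype=torch.half, device="cuda")
out = bnb.research.switchback_bnb(A, W, threshold=6.0)   # out has shape (8, 4096)
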
|
1
bitsandbytes/research/nn/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
from .modules import LinearFP8Mixed, LinearFP8Global
|
64
bitsandbytes/research/nn/modules.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
from typing import Optional, TypeVar, Union, overload
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, device, dtype, nn
|
||||
|
||||
import bitsandbytes as bnb
|
||||
from bitsandbytes.optim import GlobalOptimManager
|
||||
from bitsandbytes.utils import OutlierTracer, find_outlier_dims
|
||||
|
||||
T = TypeVar("T", bound="torch.nn.Module")
|
||||
|
||||
|
||||
class LinearFP8Mixed(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.bw_code = None
|
||||
self.fw_code = None
|
||||
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
||||
for i, k in enumerate(array):
|
||||
if input_features > array[i + 1]:
|
||||
self.bsz = k
|
||||
break
|
||||
for i, k in enumerate(array):
|
||||
if output_features > array[i + 1]:
|
||||
self.bsz2 = k
|
||||
break
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.fw_code is None:
|
||||
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
|
||||
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
|
||||
|
||||
out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
||||
if self.bias is not None:
|
||||
out += self.bias
|
||||
|
||||
return out
|
||||
|
||||
class LinearFP8Global(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.bw_code = None
|
||||
self.fw_code = None
|
||||
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
||||
for i, k in enumerate(array):
|
||||
if input_features > array[i + 1]:
|
||||
self.bsz = k
|
||||
break
|
||||
for i, k in enumerate(array):
|
||||
if output_features > array[i + 1]:
|
||||
self.bsz2 = k
|
||||
break
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.fw_code is None:
|
||||
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
|
||||
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
|
||||
|
||||
out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
||||
if self.bias is not None:
|
||||
out += self.bias
|
||||
|
||||
return out
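
A short usage sketch (assumptions: CUDA device, fp16 activations; sizes illustrative) showing the research FP8 layers as drop-in nn.Linear replacements:

import torch
import bitsandbytes as bnb

fp8_linear = bnb.research.nn.LinearFP8Mixed(1024, 4096).cuda().half()
x = torch.randn(4, 16, 1024, dtype=torch.half, device="cuda")
y = fp8_linear(x)   # forward/backward use the FP8 quantization maps created on the first call
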
|
0
bitsandbytes/triton/__init__.py
Normal file
64
bitsandbytes/triton/dequantize_rowwise.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
import math
|
||||
import torch
|
||||
import time
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
|
||||
else:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# rowwise quantize
|
||||
|
||||
# TODO: autotune this better.
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_stages=1, num_warps=8),
|
||||
triton.Config({}, num_stages=2, num_warps=8),
|
||||
triton.Config({}, num_stages=4, num_warps=8),
|
||||
triton.Config({}, num_stages=8, num_warps=8),
|
||||
triton.Config({}, num_stages=1),
|
||||
triton.Config({}, num_stages=2),
|
||||
triton.Config({}, num_stages=4),
|
||||
triton.Config({}, num_stages=8),
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
],
|
||||
key=['n_elements']
|
||||
)
|
||||
@triton.jit
|
||||
def _dequantize_rowwise(
|
||||
x_ptr,
|
||||
state_x,
|
||||
output_ptr,
|
||||
inv_127,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
P2: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
arange = tl.arange(0, P2)
|
||||
offsets = block_start + arange
|
||||
row_mask = arange < BLOCK_SIZE
|
||||
x = tl.load(x_ptr + offsets, mask=row_mask)
|
||||
max_val = tl.load(state_x + pid)
|
||||
output = max_val * x * inv_127
|
||||
tl.store(output_ptr + offsets, output, mask=row_mask)
|
||||
|
||||
|
||||
def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
    output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)

    P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))

    assert x.is_cuda and output.is_cuda
    n_elements = output.numel()
    grid = lambda meta: (x.shape[0],)
    _dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
    return output
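
A quick hedged sanity check of the rowwise kernels in this file (assumes triton and a CUDA device; the shape is illustrative):

import torch
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise

x = torch.randn(16, 1024, dtype=torch.float16, device="cuda")
x_int8, state_x = quantize_rowwise(x)            # per-row absmax kept in state_x
x_restored = dequantize_rowwise(x_int8, state_x)
print((x - x_restored).abs().max())              # small rowwise quantization error
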
|
163
bitsandbytes/triton/int8_matmul_mixed_dequanitze.py
Normal file
|
@ -0,0 +1,163 @@
|
|||
import torch
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
|
||||
else:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
|
||||
# This is a matmul kernel based on triton.ops.matmul
|
||||
# It is modified to support rowwise quantized input and global quantized weight
|
||||
# Its purpose is fused matmul then dequantize
|
||||
# It does support bias.
|
||||
|
||||
def init_to_zero(name):
|
||||
return lambda nargs: nargs[name].zero_()
|
||||
|
||||
def get_configs_io_bound():
|
||||
configs = []
|
||||
for num_stages in [2, 3, 4, 5, 6]:
|
||||
for block_m in [16, 32]:
|
||||
for block_k in [32, 64]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
num_warps = 2 if block_n <= 64 else 4
|
||||
configs.append(
|
||||
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
# split_k
|
||||
for split_k in [2, 4, 8, 16]:
|
||||
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
return configs
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# basic configs for compute-bound matmuls
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
# good for int8
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
] + get_configs_io_bound(),
|
||||
key=['M', 'N', 'K'],
|
||||
prune_configs_by={
|
||||
'early_config_prune': early_config_prune,
|
||||
'perf_model': estimate_matmul_time,
|
||||
'top_k': 10
|
||||
},
|
||||
)
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@triton.jit
|
||||
def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr
|
||||
):
|
||||
# matrix multiplication
|
||||
pid = tl.program_id(0)
|
||||
pid_z = tl.program_id(1)
|
||||
grid_m = tl.cdiv(M, BLOCK_M)
|
||||
grid_n = tl.cdiv(N, BLOCK_N)
|
||||
# re-order program ID for better L2 performance
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // (group_size)
|
||||
# do matrix multiplication
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
||||
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
||||
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
# pointers
|
||||
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
|
||||
|
||||
# rematerialize rm and rn to save registers
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
w_factor = tl.load(state_w_ptr)
|
||||
x_factor = tl.load(state_x_ptr + ram)[:, None]
|
||||
|
||||
# acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
||||
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
||||
if EVEN_K:
|
||||
a = tl.load(A)
|
||||
b = tl.load(B)
|
||||
else:
|
||||
k_remaining = K - k * (BLOCK_K * SPLIT_K)
|
||||
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
|
||||
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
|
||||
acc += tl.dot(a, b)
|
||||
A += BLOCK_K * SPLIT_K * stride_ak
|
||||
B += BLOCK_K * SPLIT_K * stride_bk
|
||||
|
||||
acc = (w_factor * (x_factor * (acc * divfactor)))
|
||||
acc = acc.to(C.dtype.element_ty)
|
||||
|
||||
# conditionally add bias
|
||||
if has_bias:
|
||||
bias = tl.load(bias + rn).to(C.dtype.element_ty)
|
||||
acc = acc + bias[None, :]
|
||||
|
||||
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
# handles write-back with reduction-splitting
|
||||
if SPLIT_K == 1:
|
||||
tl.store(C, acc, mask=mask)
|
||||
else:
|
||||
tl.atomic_add(C, acc, mask=mask)
|
||||
|
||||
|
||||
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
|
||||
device = a.device
|
||||
divfactor = 1. / (127. * 127.)
|
||||
has_bias = 0 if bias is None else 1
|
||||
# handle non-contiguous inputs if necessary
|
||||
if a.stride(0) > 1 and a.stride(1) > 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(0) > 1 and b.stride(1) > 1:
|
||||
b = b.contiguous()
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
M, K = a.shape
|
||||
_, N = b.shape
|
||||
# allocates output
|
||||
c = torch.empty((M, N), device=device, dtype=torch.float16)
|
||||
# accumulator types
|
||||
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
|
||||
# launch int8_matmul_mixed_dequantize kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
GROUP_M=8, ACC_TYPE=ACC_TYPE)
|
||||
return c
|
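A hedged usage sketch for the fused kernel above: the activation is rowwise quantized, the weight is globally quantized (a single absmax), and the result comes back dequantized in fp16. The imports assume the file layout introduced in this diff (including the `dequanitze` spelling of the new module), and the shapes follow the `nn.Linear` convention; this is not part of the change.

```python
import torch
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_global import quantize_global
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze

X = torch.randn(8, 512, device='cuda', dtype=torch.float16)    # (M, K) activation
W = torch.randn(256, 512, device='cuda', dtype=torch.float16)  # (N, K) weight, as in nn.Linear
bias = torch.zeros(256, device='cuda', dtype=torch.float16)

X_int8, state_x = quantize_rowwise(X)   # per-row absmax of X
W_int8, state_w = quantize_global(W)    # one absmax for the whole weight
out = int8_matmul_mixed_dequanitze(X_int8, W_int8.t(), state_x, state_w, bias)
# out is fp16 and approximates X @ W.t() + bias
```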
164
bitsandbytes/triton/int8_matmul_rowwise_dequantize.py
Normal file
|
@ -0,0 +1,164 @@
|
|||
import torch
|
||||
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None
|
||||
else:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# This is a matmul kernel based on triton.ops.matmul
|
||||
# It is modified to support rowwise quantized input and columnwise quantized weight
|
||||
# Its purpose is a fused matmul followed by dequantization
|
||||
# It does support bias.
|
||||
|
||||
def init_to_zero(name):
|
||||
return lambda nargs: nargs[name].zero_()
|
||||
|
||||
|
||||
def get_configs_io_bound():
|
||||
configs = []
|
||||
for num_stages in [2, 3, 4, 5, 6]:
|
||||
for block_m in [16, 32]:
|
||||
for block_k in [32, 64]:
|
||||
for block_n in [32, 64, 128, 256]:
|
||||
num_warps = 2 if block_n <= 64 else 4
|
||||
configs.append(
|
||||
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
|
||||
num_stages=num_stages, num_warps=num_warps))
|
||||
# split_k
|
||||
for split_k in [2, 4, 8, 16]:
|
||||
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
|
||||
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
|
||||
return configs
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
# basic configs for compute-bound matmuls
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
# good for int8
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
|
||||
] + get_configs_io_bound(),
|
||||
key=['M', 'N', 'K'],
|
||||
prune_configs_by={
|
||||
'early_config_prune': early_config_prune,
|
||||
'perf_model': estimate_matmul_time,
|
||||
'top_k': 10
|
||||
},
|
||||
)
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
|
||||
})
|
||||
@triton.jit
|
||||
def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
|
||||
stride_am, stride_ak,
|
||||
stride_bk, stride_bn,
|
||||
stride_cm, stride_cn,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
||||
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
|
||||
ACC_TYPE: tl.constexpr
|
||||
):
|
||||
# matrix multiplication
|
||||
pid = tl.program_id(0)
|
||||
pid_z = tl.program_id(1)
|
||||
grid_m = tl.cdiv(M, BLOCK_M)
|
||||
grid_n = tl.cdiv(N, BLOCK_N)
|
||||
# re-order program ID for better L2 performance
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // (group_size)
|
||||
# do matrix multiplication
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
|
||||
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
|
||||
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
# pointers
|
||||
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
|
||||
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
|
||||
|
||||
# rematerialize rm and rn to save registers
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
|
||||
w_factor = tl.load(state_w_ptr + rbn)[None, :]
|
||||
x_factor = tl.load(state_x_ptr + ram)[:, None]
|
||||
|
||||
# acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
|
||||
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
||||
if EVEN_K:
|
||||
a = tl.load(A)
|
||||
b = tl.load(B)
|
||||
else:
|
||||
k_remaining = K - k * (BLOCK_K * SPLIT_K)
|
||||
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
|
||||
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
|
||||
acc += tl.dot(a, b)
|
||||
A += BLOCK_K * SPLIT_K * stride_ak
|
||||
B += BLOCK_K * SPLIT_K * stride_bk
|
||||
|
||||
acc = (w_factor * (x_factor * (acc * divfactor)))
|
||||
acc = acc.to(C.dtype.element_ty)
|
||||
|
||||
if has_bias:
|
||||
bias = tl.load(bias + rn).to(C.dtype.element_ty)
|
||||
acc = acc + bias[None, :]
|
||||
|
||||
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
# handles write-back with reduction-splitting
|
||||
if SPLIT_K == 1:
|
||||
tl.store(C, acc, mask=mask)
|
||||
else:
|
||||
tl.atomic_add(C, acc, mask=mask)
|
||||
|
||||
|
||||
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
|
||||
divfactor = 1. / (127. * 127.)
|
||||
|
||||
has_bias = 0 if bias is None else 1
|
||||
|
||||
device = a.device
|
||||
# handle non-contiguous inputs if necessary
|
||||
if a.stride(0) > 1 and a.stride(1) > 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(0) > 1 and b.stride(1) > 1:
|
||||
b = b.contiguous()
|
||||
# checks constraints
|
||||
assert a.shape[1] == b.shape[0], "incompatible dimensions"
|
||||
M, K = a.shape
|
||||
_, N = b.shape
|
||||
# allocates output
|
||||
c = torch.empty((M, N), device=device, dtype=torch.float16)
|
||||
# accumulator types
|
||||
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
|
||||
# launch int8_matmul_rowwise_dequantize kernel
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
|
||||
_int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
|
||||
a.stride(0), a.stride(1),
|
||||
b.stride(0), b.stride(1),
|
||||
c.stride(0), c.stride(1),
|
||||
GROUP_M=8, ACC_TYPE=ACC_TYPE)
|
||||
return c
|
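A hedged usage sketch for the rowwise variant: here both operands carry their own scales, a per-row absmax for the activation and a per-column absmax for the (transposed) weight. One plausible way to produce the weight's scales is to rowwise-quantize the `(N, K)` weight and pass its transpose, which is what the sketch below does under the same assumptions as the previous example.

```python
import torch
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize

X = torch.randn(8, 512, device='cuda', dtype=torch.float16)    # (M, K)
W = torch.randn(256, 512, device='cuda', dtype=torch.float16)  # (N, K)
bias = torch.zeros(256, device='cuda', dtype=torch.float16)

X_int8, state_x = quantize_rowwise(X)   # per-row absmax of X, length M
W_int8, state_w = quantize_rowwise(W)   # per-row absmax of W, length N

# W_int8.t() is (K, N); its columns are scaled by state_w, matching the kernel's indexing
out = int8_matmul_rowwise_dequantize(X_int8, W_int8.t(), state_x, state_w, bias)
# out is fp16 and approximates X @ W.t() + bias
```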
74
bitsandbytes/triton/quantize_columnwise_and_transpose.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
import math
|
||||
import torch
|
||||
import time
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def quantize_columnwise_and_transpose(x: torch.Tensor): return None
|
||||
else:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# This kernel does fused columnwise quantization and transpose.
|
||||
|
||||
# TODO: autotune this better.
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_stages=1),
|
||||
triton.Config({}, num_stages=2),
|
||||
triton.Config({}, num_stages=4),
|
||||
triton.Config({}, num_stages=8),
|
||||
triton.Config({}, num_stages=16),
|
||||
triton.Config({}, num_stages=1, num_warps=8),
|
||||
triton.Config({}, num_stages=2, num_warps=8),
|
||||
triton.Config({}, num_stages=4, num_warps=8),
|
||||
triton.Config({}, num_stages=8, num_warps=8),
|
||||
triton.Config({}, num_stages=16, num_warps=8),
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
],
|
||||
key=['n_elements']
|
||||
)
|
||||
@triton.jit
|
||||
def _quantize_columnwise_and_transpose(
|
||||
x_ptr,
|
||||
output_ptr,
|
||||
output_maxs,
|
||||
n_elements,
|
||||
M : tl.constexpr, N : tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
P2: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid
|
||||
p2_arange = tl.arange(0, P2)
|
||||
p2_arange_mask = p2_arange < M
|
||||
arange = p2_arange * N
|
||||
offsets = block_start + arange
|
||||
x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
|
||||
abs_x = tl.abs(x)
|
||||
max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
|
||||
output = tl.libdevice.llrint(127. * (x / max_val))
|
||||
|
||||
new_start = pid * M
|
||||
new_offsets = new_start + p2_arange
|
||||
tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
|
||||
tl.store(output_maxs + pid, max_val)
|
||||
|
||||
def quantize_columnwise_and_transpose(x: torch.Tensor):
|
||||
M, N = x.shape
|
||||
output = torch.empty(N, M, device=x.device, dtype=torch.int8)
|
||||
output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)
|
||||
|
||||
P2 = int(2 ** (math.ceil(math.log2(M))))
|
||||
|
||||
assert x.is_cuda and output.is_cuda
|
||||
n_elements = output.numel()
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
_quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
|
||||
return output, output_maxs
|
||||
|
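A small sketch of what this wrapper returns, assuming Triton is available: an int8 tensor that is already transposed, plus one fp16 absmax per column of the input. Shapes here are illustrative.

```python
import torch
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose

# quantize each column of an (M, N) fp16 tensor and emit it transposed
g = torch.randn(1024, 4096, device='cuda', dtype=torch.float16)
g_int8_t, col_maxs = quantize_columnwise_and_transpose(g)
print(g_int8_t.shape, col_maxs.shape)  # torch.Size([4096, 1024]) torch.Size([4096])
```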
107
bitsandbytes/triton/quantize_global.py
Normal file
|
@ -0,0 +1,107 @@
|
|||
import math
|
||||
import torch
|
||||
import time
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def quantize_global_transpose(input): return None
|
||||
def quantize_global(x: torch.Tensor): return None
|
||||
else:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# global quantize
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
|
||||
triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
|
||||
|
||||
],
|
||||
key=['n_elements']
|
||||
)
|
||||
@triton.jit
|
||||
def _quantize_global(
|
||||
x_ptr,
|
||||
absmax_inv_ptr,
|
||||
output_ptr,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
absmax_inv = tl.load(absmax_inv_ptr)
|
||||
output = tl.libdevice.llrint(127. * (x * absmax_inv))
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
def quantize_global(x: torch.Tensor):
|
||||
absmax = x.abs().max().unsqueeze(0)
|
||||
absmax_inv = 1./ absmax
|
||||
output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)
|
||||
assert x.is_cuda and output.is_cuda
|
||||
n_elements = output.numel()
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
_quantize_global[grid](x, absmax_inv, output, n_elements)
|
||||
return output, absmax
|
||||
|
||||
|
||||
# global quantize and transpose
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
|
||||
|
||||
# ...
|
||||
],
|
||||
key=['M', 'N']
|
||||
)
|
||||
@triton.jit
|
||||
def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
|
||||
BLOCK_M : tl.constexpr,
|
||||
BLOCK_N : tl.constexpr,
|
||||
GROUP_M : tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
grid_m = (M + BLOCK_M - 1) // BLOCK_M
|
||||
grid_n = (N + BLOCK_N - 1) // BLOCK_N
|
||||
|
||||
width = GROUP_M * grid_n
|
||||
group_id = pid // width
|
||||
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
||||
pid_m = group_id * GROUP_M + (pid % group_size)
|
||||
pid_n = (pid % width) // group_size
|
||||
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
a = tl.load(A, mask=mask)
|
||||
absmax_inv = tl.load(absmax_inv_ptr)
|
||||
|
||||
# rematerialize to save registers
|
||||
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
|
||||
mask = (rm < M)[:, None] & (rn < N)[None, :]
|
||||
|
||||
output = tl.libdevice.llrint(127. * (a * absmax_inv))
|
||||
|
||||
tl.store(B, output, mask=mask)
|
||||
|
||||
def quantize_global_transpose(input):
|
||||
absmax = input.abs().max().unsqueeze(0)
|
||||
absmax_inv = 1./ absmax
|
||||
M, N = input.shape
|
||||
out = torch.empty(N, M, device='cuda', dtype=torch.int8)
|
||||
|
||||
assert out.size(0) == N and out.size(1) == M
|
||||
assert input.stride(0) == 1 or input.stride(1) == 1
|
||||
assert out.stride(0) == 1 or out.stride(1) == 1
|
||||
|
||||
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
|
||||
_quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
|
||||
return out, absmax
|
||||
|
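A hedged sketch of the two global-quantization entry points above: both use a single absmax for the whole tensor, and the transpose variant fuses the transpose into the quantization pass. Shapes are illustrative; the functions allocate their output on the current CUDA device as written above.

```python
import torch
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose

W = torch.randn(256, 512, device='cuda', dtype=torch.float16)

W_int8, absmax = quantize_global(W)                  # same shape, one absmax for the tensor
W_int8_t, absmax_t = quantize_global_transpose(W)    # (512, 256), quantized and transposed in one pass
print(W_int8.shape, absmax.shape, W_int8_t.shape)
```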
68
bitsandbytes/triton/quantize_rowwise.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
import math
|
||||
import torch
|
||||
import time
|
||||
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
|
||||
if not is_triton_available():
|
||||
def quantize_rowwise(x: torch.Tensor): return None
|
||||
else:
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
|
||||
|
||||
# rowwise quantize
|
||||
|
||||
# TODO: autotune this better.
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_stages=1, num_warps=8),
|
||||
triton.Config({}, num_stages=2, num_warps=8),
|
||||
triton.Config({}, num_stages=4, num_warps=8),
|
||||
triton.Config({}, num_stages=8, num_warps=8),
|
||||
triton.Config({}, num_stages=1),
|
||||
triton.Config({}, num_stages=2),
|
||||
triton.Config({}, num_stages=4),
|
||||
triton.Config({}, num_stages=8),
|
||||
triton.Config({}, num_warps=1),
|
||||
triton.Config({}, num_warps=2),
|
||||
triton.Config({}, num_warps=4),
|
||||
triton.Config({}, num_warps=8),
|
||||
],
|
||||
key=['n_elements']
|
||||
)
|
||||
@triton.jit
|
||||
def _quantize_rowwise(
|
||||
x_ptr,
|
||||
output_ptr,
|
||||
output_maxs,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
P2: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
arange = tl.arange(0, P2)
|
||||
offsets = block_start + arange
|
||||
row_mask = arange < BLOCK_SIZE
|
||||
x = tl.load(x_ptr + offsets, mask=row_mask)
|
||||
|
||||
abs_x = tl.abs(x)
|
||||
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
|
||||
output = tl.libdevice.llrint(127. * (x / max_val))
|
||||
tl.store(output_ptr + offsets, output, mask=row_mask)
|
||||
tl.store(output_maxs + pid, max_val)
|
||||
|
||||
def quantize_rowwise(x: torch.Tensor):
|
||||
output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
|
||||
output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
|
||||
|
||||
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
|
||||
|
||||
assert x.is_cuda and output.is_cuda
|
||||
n_elements = output.numel()
|
||||
grid = lambda meta: (x.shape[0],)
|
||||
_quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
|
||||
return output, output_maxs
|
||||
|
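To make the scaling convention of this kernel concrete, the sketch below quantizes a small tensor and manually reverses the per-row scaling (`value * row_max / 127`, mirroring the dequantize kernel earlier in this diff). It assumes a CUDA device with Triton and is purely illustrative.

```python
import torch
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise

x = torch.randn(4, 1024, device='cuda', dtype=torch.float16)
x_int8, row_maxs = quantize_rowwise(x)                  # int8 values, per-row absmax (fp16)
x_approx = x_int8.half() * row_maxs[:, None] / 127.0    # undo the rowwise scaling
print((x - x_approx).abs().max())                       # small, bounded by the int8 step size
```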
4
bitsandbytes/triton/triton_utils.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
import importlib
|
||||
|
||||
def is_triton_available():
|
||||
return importlib.util.find_spec("triton") is not None
|
|
@ -1,7 +1,143 @@
|
|||
import shlex
|
||||
import subprocess
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
def outlier_hook(module, input):
|
||||
assert isinstance(module, torch.nn.Linear)
|
||||
tracer = OutlierTracer.get_instance()
|
||||
hvalue = tracer.get_hvalue(module.weight)
|
||||
if hvalue not in tracer.hvalue2outlier_idx:
|
||||
outlier_idx = find_outlier_dims(module.weight)
|
||||
tracer.outliers.append(outlier_idx)
|
||||
tracer.hvalues.append(hvalue)
|
||||
if len(tracer.outliers) > 1:
|
||||
# assign the current layer the outlier idx found from the weight
|
||||
# of the previous linear layer
|
||||
if tracer.outliers[-1].numel() > 0:
|
||||
assert tracer.outliers[-1].max() < module.weight.shape[1]
|
||||
tracer.hvalue2outlier_idx[hvalue] = tracer.outliers[-1]
|
||||
|
||||
else:
|
||||
# first layer, we cannot use the weight for outlier detection
|
||||
# we follow a mixed approach:
|
||||
# (1) zscore test of std of hidden dimension
|
||||
# (2) magnitude > 6 test
|
||||
merged = input[0].view(-1, input[0].shape[-1])
|
||||
# (1) zscore test of std of hidden dimension
|
||||
outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3)
|
||||
# (2) magnitude > 6 test
|
||||
dims = (torch.abs(input[0])> 6).sum(dim=list(range(len(input[0].shape)-1)))
|
||||
outlier_idx2 = torch.where(dims > 0)[0]
|
||||
outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()
|
||||
tracer.hvalue2outlier_idx[hvalue] = outlier_idx
|
||||
else:
|
||||
for hook in tracer.hooks:
|
||||
hook.remove()
|
||||
|
||||
|
||||
class OutlierTracer(object):
|
||||
_instance = None
|
||||
|
||||
def __init__(self):
|
||||
raise RuntimeError("Call get_instance() instead")
|
||||
|
||||
def initialize(self, model):
|
||||
self.last_w = None
|
||||
self.current_outlier_dims = None
|
||||
self.hvalues = []
|
||||
self.outliers = []
|
||||
self.hvalue2outlier_idx = {}
|
||||
self.initialized = True
|
||||
self.hooks = []
|
||||
|
||||
for n, m in model.named_modules():
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
self.hooks.append(m.register_forward_pre_hook(outlier_hook))
|
||||
|
||||
def is_initialized(self):
|
||||
return getattr(self, 'initialized', False)
|
||||
|
||||
def get_hvalue(self, weight):
|
||||
return weight.data.storage().data_ptr()
|
||||
|
||||
def get_outliers(self, weight):
|
||||
if not self.is_initialized():
|
||||
print('Outlier tracer is not initialized...')
|
||||
return None
|
||||
hvalue = self.get_hvalue(weight)
|
||||
if hvalue in self.hvalue2outlier_idx:
|
||||
return self.hvalue2outlier_idx[hvalue]
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = cls.__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False):
|
||||
if rdm:
|
||||
return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long()
|
||||
|
||||
m = weight.mean(reduction_dim)
|
||||
mm = m.mean()
|
||||
mstd = m.std()
|
||||
zm = (m-mm)/mstd
|
||||
|
||||
std = weight.std(reduction_dim)
|
||||
stdm = std.mean()
|
||||
stdstd = std.std()
|
||||
|
||||
zstd = (std-stdm)/stdstd
|
||||
|
||||
if topk is not None:
|
||||
val, idx = torch.topk(std.abs(), k=topk, dim=0)
|
||||
else:
|
||||
idx = torch.where(zstd > zscore)[0]
|
||||
|
||||
return idx
|
||||
|
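A quick sanity-check sketch for `find_outlier_dims`: with `reduction_dim=0` it flags columns whose standard deviation is a z-score outlier relative to the other columns. The import assumes this hunk lands in `bitsandbytes/utils.py`, as the surrounding diff suggests; the scaled column index is made up for illustration.

```python
import torch
from bitsandbytes.utils import find_outlier_dims

# a weight whose 7th input dimension has an unusually large spread
W = torch.randn(1024, 4096)
W[:, 7] *= 20.0
idx = find_outlier_dims(W, reduction_dim=0, zscore=4.0)
print(idx)  # expected to contain 7
```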
||||
def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
|
||||
"""
|
||||
Replace linear modules with a new Linear module.
|
||||
|
||||
Parameters:
|
||||
model (`torch.nn.Module`):
|
||||
Input model or `torch.nn.Module` as the function is run recursively.
|
||||
linear_replacement (`torch.nn.Module`):
|
||||
The linear module that replaces the old one. Only expects standard arguments.
|
||||
If other arguments need to be passed, use a lambda.
|
||||
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
|
||||
List of module names not to convert. Defaults to `lm_head`.
|
||||
copy_weights (`bool`):
|
||||
Copy the weights from the old linear module to the new one
|
||||
post_processing_function (`str`):
|
||||
A function name of the replacement linear class that is called
|
||||
after processing.
|
||||
"""
|
||||
for name, module in model.named_children():
|
||||
if len(list(module.children())) > 0:
|
||||
replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)
|
||||
|
||||
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
|
||||
old_module = model._modules[name]
|
||||
model._modules[name] = linear_replacement(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
)
|
||||
if copy_weights:
|
||||
model._modules[name].weight = old_module.weight
|
||||
model._modules[name].bias = old_module.bias
|
||||
|
||||
if post_processing_function is not None:
|
||||
func = getattr(module, post_processing_function, None)
|
||||
if func is not None: func(module)
|
||||
return model
|
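As a hedged sketch of how `replace_linear` might be used, the snippet below swaps every `torch.nn.Linear` in a toy model for `bnb.nn.Linear8bitLt`. The choice of `Linear8bitLt` is just one example of a compatible replacement class (it accepts the standard in/out/bias arguments), and whether copied fp16 weights are re-quantized correctly depends on that class, not on this helper.

```python
import torch
import bitsandbytes as bnb
from bitsandbytes.utils import replace_linear  # assuming this hunk is bitsandbytes/utils.py

model = torch.nn.Sequential(
    torch.nn.Linear(512, 512),
    torch.nn.ReLU(),
    torch.nn.Linear(512, 512),
)

# replace every nn.Linear (except modules named in skip_modules) with an 8-bit layer
model = replace_linear(model, bnb.nn.Linear8bitLt, skip_modules=["lm_head"], copy_weights=True)
print(model)
```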
||||
|
||||
|
||||
|
||||
def execute_and_return(command_string: str) -> Tuple[str, str]:
|
||||
def _decode(subprocess_err_out_tuple):
|
||||
|
@ -21,3 +157,43 @@ def execute_and_return(command_string: str) -> Tuple[str, str]:
|
|||
|
||||
std_out, std_err = execute_and_return_decoded_std_streams(command_string)
|
||||
return std_out, std_err
|
||||
|
||||
|
||||
|
||||
def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
|
||||
"""
|
||||
Replace linear modules with a new Linear module.
|
||||
Parameters:
|
||||
model (`torch.nn.Module`):
|
||||
Input model or `torch.nn.Module` as the function is run recursively.
|
||||
linear_replacement (`torch.nn.Module`):
|
||||
The linear module that replaces the old one. Only expects standard arguments.
|
||||
If other arguments need to be passed, use a lambda.
|
||||
skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
|
||||
List of module names not to convert. Defaults to `lm_head`.
|
||||
copy_weights (`bool`):
|
||||
Copy the weights from the old linear module to the new one
|
||||
post_processing_function (`str`):
|
||||
A function name of the replacement linear class that is called
|
||||
after processing.
|
||||
"""
|
||||
for name, module in model.named_children():
|
||||
if len(list(module.children())) > 0:
|
||||
replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)
|
||||
|
||||
if isinstance(module, torch.nn.Linear) and name not in skip_modules:
|
||||
old_module = model._modules[name]
|
||||
model._modules[name] = linear_replacement(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
)
|
||||
if copy_weights:
|
||||
model._modules[name].weight = old_module.weight
|
||||
model._modules[name].bias = old_module.bias
|
||||
|
||||
if post_processing_function is not None:
|
||||
func = getattr(module, post_processing_function, None)
|
||||
if func is not None: func(module)
|
||||
return model
|
||||
|
||||
|
|
|
@ -33,3 +33,8 @@ You can set `CUDA_HOME` to `/usr/local/cuda-11.7`. For example, you might be abl
|
|||
|
||||
|
||||
If you have problems compiling the library with these instructions from source, please open an issue.
|
||||
|
||||
## Compilation with Kepler
|
||||
|
||||
Since 0.39.1, bitsandbytes installed via pip no longer provides Kepler binaries; these need to be compiled from source. Follow the steps above, but use the `cuda11x_nomatmul_kepler` targets instead of `cuda11x_nomatmul` etc.
|
||||
|
||||
|
|
1120
csrc/kernels.cu
File diff suppressed because it is too large
|
@ -9,13 +9,15 @@
|
|||
#ifndef kernels
|
||||
#define kernels
|
||||
|
||||
//template <int QUANT_TYPE, typename INP_TYPE, typename COMP_TYPE, typename OUT_TYPE>__global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB);
|
||||
|
||||
template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n);
|
||||
|
||||
__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n);
|
||||
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);
|
||||
|
||||
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n);
|
||||
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE> __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
|
||||
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
|
||||
|
@ -120,4 +122,9 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
|
|||
|
||||
template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
|
||||
|
||||
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
|
||||
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
|
||||
|
||||
template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n);
|
||||
|
||||
#endif
|
||||
|
|
137
csrc/ops.cu
|
@ -50,54 +50,53 @@ void dequantize(float *code, unsigned char *A, float *out, int n)
|
|||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
|
||||
template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
|
||||
{
|
||||
int num_blocks = n/blocksize;
|
||||
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
|
||||
if(STOCHASTIC == 1)
|
||||
assert(blocksize == 4096);
|
||||
|
||||
if(blocksize == 4096)
|
||||
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, 0><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 2048)
|
||||
kQuantizeBlockwise<T, 2048, 4, 0><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
kQuantizeBlockwise<T, 2048, 4, 0, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 1024)
|
||||
kQuantizeBlockwise<T, 1024, 4, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
kQuantizeBlockwise<T, 1024, 4, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 512)
|
||||
kQuantizeBlockwise<T, 512, 2, 0><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
kQuantizeBlockwise<T, 512, 2, 0, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 256)
|
||||
kQuantizeBlockwise<T, 256, 2, 0><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
kQuantizeBlockwise<T, 256, 2, 0, DATA_TYPE><<<num_blocks, 128>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 128)
|
||||
kQuantizeBlockwise<T, 128, 2, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
else if(blocksize == 64)
|
||||
kQuantizeBlockwise<T, 64, 1, 0><<<num_blocks, 64>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
|
||||
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
|
||||
template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
|
||||
{
|
||||
int num_blocks = n/blocksize;
|
||||
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
|
||||
if(blocksize == 4096)
|
||||
kDequantizeBlockwise<T, 4096, 1024, 4><<<num_blocks, 4096/4>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 2048)
|
||||
kDequantizeBlockwise<T, 2048, 512, 4><<<num_blocks, 2048/4>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 1024)
|
||||
kDequantizeBlockwise<T, 1024, 256, 4><<<num_blocks, 1024/4>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 512)
|
||||
kDequantizeBlockwise<T, 512, 256, 2><<<num_blocks, 512/2>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 256)
|
||||
kDequantizeBlockwise<T, 256, 128, 2><<<num_blocks, 256/2>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 128)
|
||||
kDequantizeBlockwise<T, 128, 64, 2><<<num_blocks, 128/2>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 64)
|
||||
kDequantizeBlockwise<T, 64, 64, 1><<<num_blocks, 64/1>>>(code, A, absmax, out, n);
|
||||
int tile_size = (DATA_TYPE > 0) ? 1024 : 512;
|
||||
|
||||
if(DATA_TYPE > 0)
|
||||
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n);
|
||||
else
|
||||
kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE><<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n);
|
||||
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
||||
//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB)
|
||||
//{
|
||||
// int num_blocks = (colsB+32-1)/32;
|
||||
// kMatmul_inference_4bit<NF4, half, half, half><<<num_blocks, 256>>>(A, B, out, lda, ldb, rowsA, colsA, colsB);
|
||||
// CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
//}
|
||||
|
||||
|
||||
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
||||
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||
|
@ -683,10 +682,73 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
|
|||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
|
||||
{
|
||||
|
||||
int num_blocks = (m+31)/32;
|
||||
|
||||
//cout << num_blocks << endl;
|
||||
//cout << lda << endl;
|
||||
//cout << ldb << endl;
|
||||
//cout << ldc << endl;
|
||||
|
||||
//cout << m << endl;
|
||||
//cout << n << endl;
|
||||
//cout << k << endl;
|
||||
//if(bits == 32)
|
||||
//gemm_device<T, 32, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||
//gemm_device<T, 32, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||
if(bits == 16)
|
||||
//gemm_device<T, 16, 256><<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||
gemm_device<T, 16, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||
//gemm_device<T, 16, 128><<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||
//gemm_device<T, 16, 96><<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||
//gemm_device<T, 16, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||
//gemm_device<T, 16, 64><<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc);
|
||||
}
|
||||
|
||||
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{
|
||||
|
||||
int num_blocks = (m+31)/32;
|
||||
|
||||
//cout << num_blocks << endl;
|
||||
//cout << lda << endl;
|
||||
//cout << ldb << endl;
|
||||
//cout << ldc << endl;
|
||||
|
||||
//cout << m << endl;
|
||||
//cout << n << endl;
|
||||
//cout << k << endl;
|
||||
kgemm_4bit_inference<T, 160><<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
|
||||
//kgemm_4bit_inference<T, 32><<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
|
||||
}
|
||||
|
||||
template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
|
||||
{
|
||||
int threads = 512;
|
||||
int blocks = n/threads;
|
||||
blocks = n % threads == 0 ? blocks : blocks + 1;
|
||||
blocks = blocks > 65535 ? 65535 : blocks;
|
||||
kfunc<T, FUNC><<<blocks, 512>>>(A, B, value, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
//==============================================================
|
||||
// TEMPLATE DEFINITIONS
|
||||
//==============================================================
|
||||
|
||||
template void func<float, FILL>(float *A, float *B, float value, long n);
|
||||
template void func<unsigned char, FILL>(unsigned char *A, unsigned char *B, unsigned char value, long n);
|
||||
template void func<float, ARANGE>(float *A, float *B, float value, long n);
|
||||
template void func<float, _MUL>(float *A, float *B, float value, long n);
|
||||
|
||||
template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
|
||||
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
|
||||
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
|
||||
template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
|
||||
template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
|
||||
|
||||
|
@ -710,12 +772,20 @@ template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows,
|
|||
template void estimateQuantiles(half *A, float *code, float offset, int n);
|
||||
template void estimateQuantiles(float *A, float *code, float offset, int n);
|
||||
|
||||
template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 1, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 1, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 0, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 0, FP4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, FP4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<half, 0, NF4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void quantizeBlockwise<float, 0, NF4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
template void dequantizeBlockwise<float, NF4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
|
||||
#define MAKE_optimizer32bit(name, gtype) \
|
||||
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
|
||||
|
@ -725,12 +795,14 @@ template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
|
|||
|
||||
MAKE_optimizer32bit(ADAM, half)
|
||||
MAKE_optimizer32bit(ADAM, float)
|
||||
MAKE_optimizer32bit(ADAM, __nv_bfloat16)
|
||||
MAKE_optimizer32bit(MOMENTUM, half)
|
||||
MAKE_optimizer32bit(MOMENTUM, float)
|
||||
MAKE_optimizer32bit(RMSPROP, half)
|
||||
MAKE_optimizer32bit(RMSPROP, float)
|
||||
MAKE_optimizer32bit(LION, half)
|
||||
MAKE_optimizer32bit(LION, float)
|
||||
MAKE_optimizer32bit(LION, __nv_bfloat16)
|
||||
MAKE_optimizer32bit(ADAGRAD, half)
|
||||
MAKE_optimizer32bit(ADAGRAD, float)
|
||||
|
||||
|
@ -766,8 +838,11 @@ MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
|
|||
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
|
||||
MAKE_optimizerStatic8bitBlockwise(half, LION);
|
||||
MAKE_optimizerStatic8bitBlockwise(float, LION);
|
||||
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION);
|
||||
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
|
||||
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
|
||||
|
||||
template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
|
||||
template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
|
||||
|
||||
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM);
|
||||
|
|
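The new `__nv_bfloat16` instantiations above (32-bit and blockwise 8-bit Adam/Lion) are what enable bf16 weights and gradients on the Python side. A hedged sketch of the intended usage, assuming the existing `bnb.optim.Adam8bit` wrapper dispatches to these kernels for bf16 parameters:

```python
import torch
import bitsandbytes as bnb

# bf16 parameters with 8-bit blockwise Adam states
model = torch.nn.Linear(1024, 1024, device='cuda', dtype=torch.bfloat16)
opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)

out = model(torch.randn(8, 1024, device='cuda', dtype=torch.bfloat16))
out.sum().backward()
opt.step()
```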
30
csrc/ops.cuh
|
@ -20,6 +20,11 @@
|
|||
#include <vector>
|
||||
#include <functional>
|
||||
|
||||
#include <thrust/host_vector.h>
|
||||
#include <thrust/device_vector.h>
|
||||
|
||||
|
||||
|
||||
#define CUDA_CHECK_RETURN(value) { \
|
||||
cudaError_t _m_cudaStat = value; \
|
||||
if (_m_cudaStat != cudaSuccess) { \
|
||||
|
@ -82,6 +87,20 @@ typedef enum Transform_t
|
|||
COL_AMPERE = 4,
|
||||
} Transform_t;
|
||||
|
||||
typedef enum DataType_t
|
||||
{
|
||||
General8bit = 0,
|
||||
FP4 = 1,
|
||||
NF4 = 2,
|
||||
} DataType_t;
|
||||
|
||||
typedef enum Funcs_t
|
||||
{
|
||||
FILL = 0,
|
||||
ARANGE = 1,
|
||||
_MUL = 2,
|
||||
} Funcs_t;
|
||||
|
||||
class Context
|
||||
{
|
||||
public:
|
||||
|
@ -129,8 +148,8 @@ template <typename T> void estimateQuantiles(T *A, float *code, float offset, in
|
|||
|
||||
void quantize(float *code, float *A, unsigned char *out, int n);
|
||||
void dequantize(float *code, unsigned char *A, float *out, int n);
|
||||
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
|
||||
template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
|
||||
template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
|
||||
|
||||
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
||||
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
|
||||
|
@ -177,4 +196,11 @@ template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count,
|
|||
|
||||
template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols);
|
||||
|
||||
void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB);
|
||||
|
||||
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
|
||||
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
|
||||
|
||||
template <typename T, int FUNC> void func(T *A, T *B, T value, long n);
|
||||
|
||||
#endif
|
||||
|
|
|
@ -20,8 +20,25 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat
|
|||
void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }
|
||||
|
||||
|
||||
//void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
|
||||
//{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); }
|
||||
void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
|
||||
{ gemm_host<half>(M, N, K, A, B, out, lda, ldb, ldc, 16); }
|
||||
|
||||
void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{ gemm_4bit_inference<half>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
||||
|
||||
#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
|
||||
void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func<ctype, FUNC>(A, B, value, n); } \
|
||||
|
||||
MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
|
||||
MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
|
||||
MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
|
||||
MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
|
||||
|
||||
|
||||
#define MAKE_FUNC32(fname, oname, gtype, gbits) \
|
||||
void fname##32bit_g##gbits(gtype *g, gtype *p, \
|
||||
void fname##32bit_grad_##gbits(gtype *g, gtype *p, \
|
||||
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay, \
|
||||
const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \
|
||||
|
@ -29,17 +46,19 @@ void fname##32bit_g##gbits(gtype *g, gtype *p, \
|
|||
|
||||
MAKE_FUNC32(momentum, MOMENTUM, float, 32)
|
||||
MAKE_FUNC32(momentum, MOMENTUM, half, 16)
|
||||
MAKE_FUNC32(adam, ADAM, float, 32)
|
||||
MAKE_FUNC32(adam, ADAM, half, 16)
|
||||
MAKE_FUNC32(adam, ADAM, float, fp32)
|
||||
MAKE_FUNC32(adam, ADAM, half, fp16)
|
||||
MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16)
|
||||
MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
|
||||
MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
|
||||
MAKE_FUNC32(lion, LION, float, 32)
|
||||
MAKE_FUNC32(lion, LION, half, 16)
|
||||
MAKE_FUNC32(lion, LION, float, fp32)
|
||||
MAKE_FUNC32(lion, LION, half, fp16)
|
||||
MAKE_FUNC32(lion, LION, __nv_bfloat16, bf16)
|
||||
MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
|
||||
MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
|
||||
|
||||
#define MAKE_FUNC8(fname, oname, gtype, gbits) \
|
||||
void fname##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
|
||||
void fname##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
|
||||
float *unorm, float max_unorm, float param_norm, \
|
||||
float beta1, float beta2, \
|
||||
float eps, int step, float lr, \
|
||||
|
@ -61,33 +80,42 @@ MAKE_FUNC8(lion, LION, float, 32)
|
|||
MAKE_FUNC8(lion, LION, half, 16)
|
||||
|
||||
#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
|
||||
void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
|
||||
void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
|
||||
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
|
||||
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
|
||||
{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
|
||||
|
||||
MAKE_BLOCKWISE8(adam, ADAM, half, 16)
|
||||
MAKE_BLOCKWISE8(adam, ADAM, float, 32)
|
||||
MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
|
||||
MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32)
|
||||
MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16)
|
||||
MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32)
|
||||
MAKE_BLOCKWISE8(lion, LION, half, 16)
|
||||
MAKE_BLOCKWISE8(lion, LION, float, 32)
|
||||
MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16)
|
||||
MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
|
||||
MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
|
||||
MAKE_BLOCKWISE8(adam, ADAM, float, fp32)
|
||||
MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16)
|
||||
MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32)
|
||||
MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16)
|
||||
MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32)
|
||||
MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16)
|
||||
MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32)
|
||||
MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
|
||||
MAKE_BLOCKWISE8(lion, LION, half, fp16)
|
||||
MAKE_BLOCKWISE8(lion, LION, float, fp32)
|
||||
MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
|
||||
|
||||
|
||||
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
|
||||
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
|
||||
|
||||
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
|
||||
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, 4096, n); }
|
||||
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, General8bit>(code, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, FP4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<half, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise<float, 0, NF4>(NULL, A, absmax, out, NULL, 0, blocksize, n); }
|
||||
|
||||
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, General8bit>(code, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, General8bit>(code, A, absmax, out, blocksize, n); }
void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, FP4>(NULL, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, FP4>(NULL, A, absmax, out, blocksize, n); }
void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half, NF4>(NULL, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float, NF4>(NULL, A, absmax, out, blocksize, n); }

void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }

#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
@ -148,32 +176,41 @@ extern "C"
|
|||
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
|
||||
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); }
|
||||
void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); }
void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
|
||||
void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); }
#define MAKE_CFUNC32(name, gtype, gbits) \
|
||||
void c##name##32bit_g##gbits(gtype *g, gtype *p, \
|
||||
void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \
|
||||
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay, \
|
||||
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
|
||||
{ name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
|
||||
{ name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
MAKE_CFUNC32(adam, float, 32)
|
||||
MAKE_CFUNC32(adam, half, 16)
|
||||
MAKE_CFUNC32(adam, float, fp32)
|
||||
MAKE_CFUNC32(adam, half, fp16)
|
||||
MAKE_CFUNC32(adam, __nv_bfloat16, bf16)
|
||||
MAKE_CFUNC32(momentum, float, 32)
|
||||
MAKE_CFUNC32(momentum, half, 16)
|
||||
MAKE_CFUNC32(rmsprop, float, 32)
|
||||
MAKE_CFUNC32(rmsprop, half, 16)
|
||||
MAKE_CFUNC32(lion, float, 32)
|
||||
MAKE_CFUNC32(lion, half, 16)
|
||||
MAKE_CFUNC32(lion, float, fp32)
|
||||
MAKE_CFUNC32(lion, half, fp16)
|
||||
MAKE_CFUNC32(lion, __nv_bfloat16, bf16)
|
||||
MAKE_CFUNC32(adagrad, float, 32)
|
||||
MAKE_CFUNC32(adagrad, half, 16)
#define MAKE_CFUNC8(name, gtype, gbits) \
|
||||
void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
|
||||
void c##name##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
|
||||
float *unorm, float max_unorm, float param_norm, \
|
||||
float beta1, float beta2, \
|
||||
float eps, int step, float lr, \
|
||||
|
@ -181,7 +218,7 @@ extern "C"
|
|||
float* max1, float* max2, float* new_max1, float* new_max2, \
|
||||
float weight_decay, float gnorm_scale, int n) \
|
||||
{ \
|
||||
name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
|
||||
name##_static_8bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
|
||||
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
|
||||
} \
|
||||
|
||||
|
@ -195,22 +232,23 @@ extern "C"
|
|||
MAKE_CFUNC8(lion, half, 16)
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
|
||||
void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
|
||||
void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
|
||||
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
|
||||
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
|
||||
{ fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
|
||||
|
||||
MAKE_CBLOCKWISE8(adam, ADAM, half, 16)
|
||||
MAKE_CBLOCKWISE8(adam, ADAM, float, 32)
|
||||
MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, 16)
|
||||
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
|
||||
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
|
||||
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
|
||||
MAKE_CBLOCKWISE8(lion, LION, half, 16)
|
||||
MAKE_CBLOCKWISE8(lion, LION, float, 32)
|
||||
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16)
|
||||
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32)
|
||||
{ fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
|
||||
MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
|
||||
MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16)
|
||||
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32)
|
||||
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16)
|
||||
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32)
|
||||
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16)
|
||||
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32)
|
||||
MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
|
||||
MAKE_CBLOCKWISE8(lion, LION, half, fp16)
|
||||
MAKE_CBLOCKWISE8(lion, LION, float, fp32)
|
||||
MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
|
||||
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
|
||||
|
@ -298,6 +336,38 @@ extern "C"
|
|||
void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
|
||||
void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
//void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
|
||||
//{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }
|
||||
|
||||
void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc)
|
||||
{ gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); }
|
||||
|
||||
void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize)
|
||||
{ gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); }
|
||||
|
||||
void *cget_managed_ptr(size_t bytes)
|
||||
{
|
||||
void *ptr;
|
||||
CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost));
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
|
||||
return ptr;
|
||||
}
|
||||
|
||||
void cprefetch(void *ptr, size_t bytes, int device)
|
||||
{
|
||||
CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0));
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
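`cget_managed_ptr` and `cprefetch` wrap `cudaMallocManaged`/`cudaMemPrefetchAsync`, which is the mechanism behind the paged optimizers listed in the changelog. A hedged sketch of the intended Python-level usage follows; the exact optimizer class name (`PagedAdam8bit`) is an assumption on my part, so check `bnb.optim` for the actual exports:

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.float16)

# Paged optimizers keep their state in CUDA managed (unified) memory, so state
# pages can spill to host RAM under memory pressure instead of raising OOM.
opt = bnb.optim.PagedAdam8bit(model.parameters(), lr=1e-4)

out = model(torch.randn(8, 4096, device="cuda", dtype=torch.float16))
out.float().pow(2).mean().backward()
opt.step()
opt.zero_grad()
```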
#define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \
|
||||
void c##fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ fname##_##type_name(A, B, value, n); } \
|
||||
|
||||
CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL)
|
||||
CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL)
|
||||
CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE)
|
||||
CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL)
|
||||
|
||||
#endif
|
||||
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
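The two CPU entry points above mirror the CUDA blockwise routines for float32 tensors. A small sketch of how they surface through the Python API, assuming `quantize_blockwise`/`dequantize_blockwise` route CPU tensors to the `*_cpu_fp32` symbols (that routing is my assumption; verify against `bitsandbytes.functional`):

```python
import torch
import bitsandbytes.functional as F

A = torch.randn(4096, 4096)            # float32 tensor on the CPU

C, state = F.quantize_blockwise(A)      # 8-bit codes plus per-block absmax
A2 = F.dequantize_blockwise(C, state)

print((A - A2).abs().mean())            # small blockwise quantization error
```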

11
deploy.sh

@ -139,17 +139,6 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then
fi
|
||||
|
||||
|
||||
make clean
|
||||
export CUDA_HOME=$BASE_PATH/cuda-10.2
|
||||
make cuda10x_nomatmul CUDA_VERSION=102
|
||||
|
||||
if [ ! -f "./bitsandbytes/libbitsandbytes_cuda102_nocublaslt.so" ]; then
|
||||
# Control will enter here if $DIRECTORY doesn't exist.
|
||||
echo "Compilation unsuccessul!" 1>&2
|
||||
exit 64
|
||||
fi
|
||||
|
||||
|
||||
make clean
|
||||
export CUDA_HOME=$BASE_PATH/cuda-11.0
|
||||
make cuda110_nomatmul CUDA_VERSION=110
|
||||

27
examples/int8_inference_huggingface.py
Normal file

@ -0,0 +1,27 @@
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
MAX_NEW_TOKENS = 128
|
||||
model_name = 'decapoda-research/llama-7b-hf'
|
||||
|
||||
text = 'Hamburg is in which country?\n'
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
input_ids = tokenizer(text, return_tensors="pt").input_ids
|
||||
|
||||
free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)
|
||||
max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
|
||||
|
||||
n_gpus = torch.cuda.device_count()
|
||||
max_memory = {i: max_memory for i in range(n_gpus)}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
device_map='auto',
|
||||
load_in_8bit=True,
|
||||
max_memory=max_memory
|
||||
)
|
||||
generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS)
|
||||
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
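The new example above covers 8-bit loading. For the 4-bit (FP4/NF4) path added in this release, the analogous load would look roughly like the sketch below; `BitsAndBytesConfig` and its `bnb_4bit_*` arguments live on the transformers side and their exact names are an assumption here rather than something this diff adds:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = 'decapoda-research/llama-7b-hf'

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',            # or 'fp4'
    bnb_4bit_use_double_quant=True,       # compressed quantization statistics
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map='auto',
    quantization_config=quant_config,
)

input_ids = tokenizer('Hamburg is in which country?\n', return_tensors='pt').input_ids.to(model.device)
out = model.generate(input_ids, max_new_tokens=128)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```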

4
setup.py

@ -18,10 +18,10 @@ def read(fname):
|
||||
setup(
|
||||
name=f"bitsandbytes",
|
||||
version=f"0.38.0",
|
||||
version=f"0.39.1",
|
||||
author="Tim Dettmers",
|
||||
author_email="dettmers@cs.washington.edu",
|
||||
description="8-bit optimizers and matrix multiplication routines.",
|
||||
description="k-bit optimizers and matrix multiplication routines.",
|
||||
license="MIT",
|
||||
keywords="gpu optimizers optimization 8-bit quantization compression",
|
||||
url="https://github.com/TimDettmers/bitsandbytes",
|
||||
|
|
|
@ -97,7 +97,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
|||
B.grad = None
|
||||
|
||||
if req_grad[0]:
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
gradA1, gradA2, atol=0.015, rtol=0.1
|
||||
)
|
||||
if req_grad[1]:
|
||||
|
@ -106,7 +106,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
|||
assert (idx == 0).sum().item() < n * 0.1
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
|
||||
assert (idx == 0).sum().item() < n * 0.02
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
gradB1, gradB2, atol=0.18, rtol=0.3
|
||||
)
|
||||
|
||||
|
@ -135,7 +135,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
|||
n = out_bnb.numel()
|
||||
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
|
||||
assert (idx == 0).sum().item() < n * 0.01
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
out_bnb, out_torch, atol=0.027, rtol=0.2
|
||||
)
|
||||
|
||||
|
@ -159,7 +159,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
|||
B.grad = None
|
||||
|
||||
if req_grad[0]:
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
gradA1, gradA2, atol=0.015, rtol=0.1
|
||||
)
|
||||
if req_grad[1]:
|
||||
|
@ -218,7 +218,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
|||
B.grad = None
|
||||
|
||||
if req_grad[0]:
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
gradA1, gradA2, atol=0.015, rtol=0.1
|
||||
)
|
||||
if req_grad[1]:
|
||||
|
@ -239,8 +239,8 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist()
|
|||
dim2.append(0)
|
||||
|
||||
decomp = [0.0, 6.0]
|
||||
funcs = [(torch.matmul, bnb.matmul)]
|
||||
str_funcs = ["matmul"]
|
||||
funcs = [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)]
|
||||
str_funcs = ["matmullt", 'switchback_bnb']
|
||||
req_grad = [(False, False), (True, False), (True, True), (False, True)]
|
||||
req_grad = list(product([True, False], repeat=3))
|
||||
req_grad_str = []
|
||||
|
@ -407,7 +407,7 @@ def test_matmullt(
|
|||
bias.grad = None
|
||||
|
||||
if req_grad[0]:
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
gradA1, gradA2, atol=0.015, rtol=0.1
|
||||
)
|
||||
if req_grad[1]:
|
||||
|
@ -423,9 +423,204 @@ def test_matmullt(
|
|||
assert (idx == 0).sum().item() <= n * 0.1
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
|
||||
assert (idx == 0).sum().item() <= n * 0.02
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
gradB1, gradB2, atol=0.18, rtol=0.3
|
||||
)
|
||||
|
||||
if req_grad[2]:
|
||||
torch.testing.assert_allclose(gradBias1, gradBias2)
|
||||
torch.testing.assert_close(gradBias1, gradBias2)
|
||||
|
||||
|
||||
n = 1
|
||||
k = 3
|
||||
dim1 = torch.randint(16, 64, size=(n,)).tolist()
|
||||
dim2 = torch.randint(32, 96, size=(n,)).tolist()
|
||||
dim3 = torch.randint(32, 96, size=(n,)).tolist()
|
||||
dim4 = torch.randint(32, 96, size=(n,)).tolist()
|
||||
|
||||
dim2.append(0)
|
||||
|
||||
funcs = [(torch.matmul, bnb.matmul_4bit)]
|
||||
str_funcs = ["matmul"]
|
||||
req_grad = list(product([True, False], repeat=3))
|
||||
req_grad_str = []
|
||||
for c in req_grad:
|
||||
strval = ''
|
||||
for v in c:
|
||||
if v == True: strval += 'T'
|
||||
else: strval += 'F'
|
||||
req_grad_str.append(strval)
|
||||
|
||||
transpose = [(False, True), (False, False)]
|
||||
str_transpose = ["NT", "NN"]
|
||||
dtype = [torch.float16, torch.float32]
|
||||
compress_statistics = [False, True]
|
||||
has_fp16_weights = [True, False]
|
||||
has_bias = [True, False]
|
||||
quant_type = ['fp4', 'nf4']
|
||||
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type))
|
||||
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics, quant_type))
|
||||
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}".format(*vals) for vals in str_values]
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
||||
@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type", values, ids=names)
|
||||
def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type):
|
||||
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
|
||||
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
|
||||
if has_bias == False:
|
||||
req_grad = list(req_grad)
|
||||
req_grad[2] = False
|
||||
|
||||
for i in range(k):
|
||||
# normal multiply
|
||||
if funcs[0] in [torch.mm, torch.matmul]:
|
||||
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
|
||||
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
|
||||
target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
|
||||
bias = None
|
||||
bias2 = None
|
||||
if has_bias:
|
||||
bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
|
||||
bias2 = bias.clone()
|
||||
torch.nn.init.xavier_uniform_(B)
|
||||
|
||||
B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type)
|
||||
|
||||
if not transpose[0] and transpose[1]:
|
||||
out_torch = funcs[0](A, B.t())
|
||||
out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)
|
||||
elif not transpose[0] and not transpose[1]:
|
||||
out_torch = funcs[0](A, B)
|
||||
out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
|
||||
|
||||
if has_bias:
|
||||
out_torch += bias
|
||||
|
||||
assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
|
||||
|
||||
n = out_bnb.numel()
|
||||
err = torch.abs(out_bnb - out_torch).float().mean().item()
|
||||
if n > 0:
|
||||
assert err < 0.115
|
||||
|
||||
#assert err < 0.20
|
||||
if any(req_grad):
|
||||
out_bnb.data.copy_(out_torch)
|
||||
torch.cuda.synchronize()
|
||||
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
|
||||
loss_bnb.backward()
|
||||
gradA1 = A.grad
|
||||
gradB1 = B.grad
|
||||
A.grad = None
|
||||
B.grad = None
|
||||
if has_bias:
|
||||
gradBias1 = bias.grad
|
||||
bias.grad = None
|
||||
|
||||
loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean()
|
||||
loss_torch.backward()
|
||||
gradA2 = A.grad
|
||||
gradB2 = B.grad
|
||||
A.grad = None
|
||||
B.grad = None
|
||||
if has_bias:
|
||||
gradBias2 = bias.grad
|
||||
bias.grad = None
|
||||
|
||||
if req_grad[0]:
|
||||
torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1)
|
||||
|
||||
if req_grad[2]:
|
||||
torch.testing.assert_close(gradBias1, gradBias2)
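The test above exercises the core 4-bit inference path: quantize the weight once with `quantize_4bit`, then pass the packed tensor plus its `quant_state` to `bnb.matmul_4bit`. Stripped of the parametrization, the same flow looks like this (shapes are illustrative):

```python
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(16, 64, device='cuda', dtype=torch.float16)   # activations
B = torch.randn(32, 64, device='cuda', dtype=torch.float16)   # weight (out_features x in_features)

# Quantize the weight once; quant_state carries absmax, blocksize, quant type, etc.
B4, quant_state = F.quantize_4bit(B, quant_type='nf4')

out = bnb.matmul_4bit(A, B4.t(), quant_state)                 # approximates A @ B.t()
ref = A @ B.t()
print((out - ref).abs().float().mean())
```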
funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)]
|
||||
str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global']
|
||||
req_grad = list(product([True, False], repeat=3))
|
||||
req_grad_str = []
|
||||
for c in req_grad:
|
||||
strval = ''
|
||||
for v in c:
|
||||
if v == True: strval += 'T'
|
||||
else: strval += 'F'
|
||||
req_grad_str.append(strval)
|
||||
|
||||
transpose = [(False, True), (False, False)]
|
||||
str_transpose = ["NT", "NN"]
|
||||
dtype = [torch.float16, torch.float32]
|
||||
has_fp16_weights = [True, False]
|
||||
values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
|
||||
str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose))
|
||||
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values]
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
||||
@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
|
||||
def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
||||
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
|
||||
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
|
||||
req_grad = list(req_grad)
|
||||
req_grad[2] = False
|
||||
|
||||
for i in range(k):
|
||||
# normal multiply
|
||||
if funcs[0] in [torch.mm, torch.matmul]:
|
||||
A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype)
|
||||
B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype)
|
||||
target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype)
|
||||
|
||||
torch.nn.init.xavier_uniform_(B)
|
||||
|
||||
fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device)
|
||||
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device)
|
||||
|
||||
if not transpose[0] and transpose[1]:
|
||||
out_torch = funcs[0](A, B.t())
|
||||
out_bnb = funcs[1](A, B.t(), fw_code, bw_code)
|
||||
elif not transpose[0] and not transpose[1]:
|
||||
out_torch = funcs[0](A, B)
|
||||
out_bnb = funcs[1](A, B, fw_code, bw_code)
|
||||
|
||||
assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}"
|
||||
|
||||
n = out_bnb.numel()
|
||||
err = torch.abs(out_bnb - out_torch).float().mean().item()
|
||||
if n > 0:
|
||||
assert err < 0.115
|
||||
#assert err < 0.20
|
||||
if any(req_grad):
|
||||
out_bnb.data.copy_(out_torch)
|
||||
torch.cuda.synchronize()
|
||||
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
|
||||
loss_bnb.backward()
|
||||
gradA1 = A.grad
|
||||
gradB1 = B.grad
|
||||
A.grad = None
|
||||
B.grad = None
|
||||
|
||||
loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean()
|
||||
loss_torch.backward()
|
||||
gradA2 = A.grad
|
||||
gradB2 = B.grad
|
||||
A.grad = None
|
||||
B.grad = None
|
||||
|
||||
if req_grad[0]:
|
||||
torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1)
|
||||
|
||||
if req_grad[1]:
|
||||
n = gradB1.numel()
|
||||
if dim2 > 0:
|
||||
assert torch.abs(gradB1).sum() > 0.0
|
||||
assert torch.abs(gradB2).sum() > 0.0
|
||||
else:
|
||||
assert torch.abs(gradB1).sum() == 0.0
|
||||
assert torch.abs(gradB2).sum() == 0.0
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
|
||||
|
||||
assert (idx == 0).sum().item() <= n * 0.1
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
|
||||
assert (idx == 0).sum().item() <= n * 0.02
|
||||
grad_err = (gradB1-gradB2).abs().mean()
|
||||
assert grad_err.item() < 0.003
|
||||
torch.testing.assert_close(
|
||||
gradB1, gradB2, atol=0.18, rtol=0.3
|
||||
)
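For the research FP8 matmuls, the forward and backward quantization maps are built with `create_fp8_map(signed, exponent_bits, precision_bits, total_bits)` and handed to the research matmul, exactly as in the test above. A condensed sketch of that call pattern:

```python
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

A = torch.randn(16, 64, device='cuda', dtype=torch.float16, requires_grad=True)
B = torch.randn(32, 64, device='cuda', dtype=torch.float16, requires_grad=True)

fw_code = F.create_fp8_map(True, 4, 3, 8).to(A.device)   # E4M3-style map for the forward pass
bw_code = F.create_fp8_map(True, 5, 2, 8).to(A.device)   # E5M2-style map for the backward pass

out = bnb.research.matmul_fp8_mixed(A, B.t(), fw_code, bw_code)
out.float().pow(2).mean().backward()                      # gradients flow back to A and B
```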
@ -18,12 +18,15 @@ torch.set_printoptions(
|
|||
k = 20
|
||||
|
||||
|
||||
def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0):
|
||||
def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True):
|
||||
idx = torch.isclose(a, b, rtol, atol)
|
||||
sumval = (idx == 0).sum().item()
|
||||
if sumval > count:
|
||||
print(f"Too many values not close: assert {sumval} < {count}")
|
||||
torch.testing.assert_allclose(a, b, rtol, atol)
|
||||
if throw:
|
||||
print(f"Too many values not close: assert {sumval} < {count}")
|
||||
torch.testing.assert_close(a, b, rtol, atol)
|
||||
|
||||
return sumval
|
||||
|
||||
|
||||
class FFN(torch.nn.Module):
|
||||
|
@ -97,7 +100,7 @@ def test_estimate_quantiles(dtype):
|
|||
code = F.estimate_quantiles(A)
|
||||
|
||||
percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device)
|
||||
torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)
|
||||
torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2)
|
||||
|
||||
A = torch.randn(1024, 1024, device="cuda")
|
||||
A = A.to(dtype)
|
||||
|
@ -122,7 +125,7 @@ def test_quantile_quantization():
|
|||
C = F.quantize_no_absmax(A1, code)
|
||||
A2 = F.dequantize_no_absmax(C, code)
|
||||
diff = torch.abs(A1 - A2).mean().item()
|
||||
torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
|
||||
torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0)
|
||||
assert diff < 0.001
|
||||
|
||||
|
||||
|
@ -146,63 +149,49 @@ def test_dynamic_quantization():
|
|||
C, S = F.quantize(A1)
|
||||
A2 = F.dequantize(C, S)
|
||||
diff = torch.abs(A1 - A2).mean().item()
|
||||
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
|
||||
torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
|
||||
assert diff < 0.004
|
||||
|
||||
|
||||
def test_dynamic_blockwise_quantization():
|
||||
|
||||
@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"])
|
||||
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
|
||||
def test_dynamic_blockwise_quantization(nested, blocksize):
|
||||
#print('')
|
||||
for blocksize in [4096, 2048, 1024, 512]:
|
||||
diffs = []
|
||||
reldiffs = []
|
||||
for i in range(100):
|
||||
A1 = torch.randn(1024, 1024, device="cuda")
|
||||
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
|
||||
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff / torch.abs(A1 + 1e-8)
|
||||
diffs.append(diff.mean().item())
|
||||
reldiffs.append(reldiff.mean().item())
|
||||
abserr = sum(diffs)/len(diffs)
|
||||
relerr = sum(reldiffs)/len(reldiffs)
|
||||
assert abserr < 0.011
|
||||
assert relerr < 0.018
|
||||
#print('randn', blocksize, sum(diffs)/len(diffs))
|
||||
#print('randn', blocksize, sum(reldiffs)/len(reldiffs))
|
||||
|
||||
diffs = []
|
||||
for i in range(100):
|
||||
A1 = torch.rand(1024, 1024, device="cuda")
|
||||
C, S = F.quantize_blockwise(A1, blocksize=blocksize)
|
||||
A2 = F.dequantize_blockwise(C, S, blocksize=blocksize)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff / torch.abs(A1 + 1e-8)
|
||||
diffs.append(diff.mean().item())
|
||||
reldiffs.append(reldiff.mean().item())
|
||||
#torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
|
||||
abserr = sum(diffs)/len(diffs)
|
||||
relerr = sum(reldiffs)/len(reldiffs)
|
||||
assert abserr < 0.0035
|
||||
assert relerr < 0.015
|
||||
#print('rand', blocksize, sum(diffs)/len(diffs))
|
||||
#print('rand', blocksize, sum(reldiffs)/len(reldiffs))
|
||||
|
||||
|
||||
def test_dynamic_blockwise_stochastic_quantization():
|
||||
diffs = []
|
||||
reldiffs = []
|
||||
rand = torch.rand(1024).cuda()
|
||||
for i in range(100):
|
||||
A1 = torch.randn(1024, 1024, device="cuda")
|
||||
C1, S1 = F.quantize_blockwise(A1, rand=rand)
|
||||
C2, S2 = F.quantize_blockwise(A1)
|
||||
# a maximum distance of quantized values of 1
|
||||
torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
|
||||
fraction_smaller = (C1 < C2).float().sum() / C1.numel()
|
||||
fraction_larger = (C1 > C2).float().sum() / C1.numel()
|
||||
torch.testing.assert_allclose(
|
||||
fraction_larger, fraction_smaller, atol=0.01, rtol=0
|
||||
)
|
||||
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
|
||||
A2 = F.dequantize_blockwise(C, S)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff / torch.abs(A1 + 1e-8)
|
||||
diffs.append(diff.mean().item())
|
||||
reldiffs.append(reldiff.mean().item())
|
||||
abserr = sum(diffs)/len(diffs)
|
||||
relerr = sum(reldiffs)/len(reldiffs)
|
||||
assert abserr < 0.011
|
||||
assert relerr < 0.018
|
||||
#print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs))
|
||||
#print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs))
|
||||
|
||||
diffs = []
|
||||
for i in range(100):
|
||||
A1 = torch.rand(1024, 1024, device="cuda")
|
||||
C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested)
|
||||
A2 = F.dequantize_blockwise(C, S)
|
||||
diff = torch.abs(A1 - A2)
|
||||
reldiff = diff / torch.abs(A1 + 1e-8)
|
||||
diffs.append(diff.mean().item())
|
||||
reldiffs.append(reldiff.mean().item())
|
||||
#torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0)
|
||||
abserr = sum(diffs)/len(diffs)
|
||||
relerr = sum(reldiffs)/len(reldiffs)
|
||||
assert abserr < 0.0035
|
||||
assert relerr < 0.015
|
||||
#print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
|
||||
#print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
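The reworked test parametrizes over blocksize and the new `nested` option, which additionally quantizes the per-block statistics themselves (the "doubled quantization" from the changelog). Minimal usage, mirroring the test:

```python
import torch
import bitsandbytes.functional as F

A = torch.randn(1024, 1024, device='cuda')

# nested=True additionally quantizes the per-block absmax values, shrinking the
# size of the quantization state at a small cost in accuracy.
C, state = F.quantize_blockwise(A, blocksize=256, nested=True)
A2 = F.dequantize_blockwise(C, state)

print((A - A2).abs().mean())   # on the order of 1e-2 for randn inputs, per the bounds above
```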
@pytest.mark.parametrize(
|
||||
|
@ -231,9 +220,9 @@ def test_percentile_clipping(gtype):
|
|||
vals, idx = torch.sort(gnorm_vec1)
|
||||
clip1 = vals[percentile]
|
||||
|
||||
torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2))
|
||||
torch.testing.assert_allclose(clip1, clip2)
|
||||
torch.testing.assert_allclose(gnorm1, gnorm2)
|
||||
torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2))
|
||||
torch.testing.assert_close(clip1, clip2)
|
||||
torch.testing.assert_close(gnorm1, gnorm2)
|
||||
|
||||
|
||||
def quant(x):
|
||||
|
@ -315,7 +304,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
|
|||
dim2 = dim2 - (dim2 % 32)
|
||||
errors = []
|
||||
relerrors = []
|
||||
print("")
|
||||
#print("")
|
||||
for i in range(5):
|
||||
if batched:
|
||||
A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
|
||||
|
@ -327,7 +316,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
|
|||
B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
|
||||
maxA, Ac = quant_methods[0](A, 1)
|
||||
maxB, Bc = quant_methods[1](B, 0)
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05
|
||||
)
|
||||
if batched:
|
||||
|
@ -344,8 +333,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
|
|||
relerr = err / torch.abs(out2)
|
||||
errors.append(err.mean().item())
|
||||
relerrors.append(relerr.mean().item())
|
||||
print(mean(errors))
|
||||
print(mean(relerrors))
|
||||
#print(mean(errors))
|
||||
#print(mean(relerrors))
|
||||
|
||||
|
||||
def test_stable_embedding():
|
||||
|
@ -398,7 +387,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
|
|||
out2 = torch.matmul(A.t().float(), B.t().float())
|
||||
out = F.igemm(A.t(), B.t())
|
||||
|
||||
torch.testing.assert_allclose(out.float(), out2)
|
||||
torch.testing.assert_close(out.float(), out2)
|
||||
|
||||
for i in range(k):
|
||||
shapeA = (batch_dim, seq_dim, hidden_dim)
|
||||
|
@ -416,7 +405,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim):
|
|||
out2 = torch.matmul(A.float(), B.t().float())
|
||||
out = F.igemm(A, B.t())
|
||||
|
||||
torch.testing.assert_allclose(out.float(), out2)
|
||||
torch.testing.assert_close(out.float(), out2)
|
||||
|
||||
|
||||
n = 3
|
||||
|
@ -447,7 +436,7 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim):
|
|||
)
|
||||
out = F.igemm(A, B, out=iout)
|
||||
|
||||
torch.testing.assert_allclose(out.float(), out2)
|
||||
torch.testing.assert_close(out.float(), out2)
|
||||
|
||||
|
||||
n = 2
|
||||
|
@ -572,7 +561,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose):
|
|||
A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float()
|
||||
)
|
||||
out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
|
||||
torch.testing.assert_allclose(out.float(), out2.float())
|
||||
torch.testing.assert_close(out.float(), out2.float())
|
||||
|
||||
|
||||
n = 1
|
||||
|
@ -630,9 +619,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
|
|||
out, S = F.nvidia_transform(A, to_order=orderOut)
|
||||
|
||||
if orderOut == "row":
|
||||
torch.testing.assert_allclose(A.flatten(), out.flatten())
|
||||
torch.testing.assert_close(A.flatten(), out.flatten())
|
||||
elif orderOut == "col":
|
||||
torch.testing.assert_allclose(A.t().flatten(), out.flatten())
|
||||
torch.testing.assert_close(A.t().flatten(), out.flatten())
|
||||
elif orderOut == "col32":
|
||||
if dims == 2:
|
||||
n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32)))
|
||||
|
@ -665,14 +654,14 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans
|
|||
|
||||
assert A.flatten()[i + j] == A[row, col]
|
||||
# assert A.flatten()[i+j] == out.flatten()[row2+col2]
|
||||
# torch.testing.assert_allclose(A.flatten()[i+j], A[row, col])
|
||||
# torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
|
||||
# torch.testing.assert_close(A.flatten()[i+j], A[row, col])
|
||||
# torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset])
|
||||
|
||||
if orderOut == "col32":
|
||||
out2, S = F.nvidia_transform(
|
||||
out, from_order=orderOut, to_order="row", state=S
|
||||
)
|
||||
torch.testing.assert_allclose(A, out2)
|
||||
torch.testing.assert_close(A, out2)
|
||||
|
||||
|
||||
n = 1
|
||||
|
@ -716,7 +705,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
|
|||
B2, SB = F.transform(B, "col_turing")
|
||||
C2, SC = F.igemmlt(A2, B2, SA, SB)
|
||||
C3, S = F.nvidia_transform(C2, "row", state=SC)
|
||||
torch.testing.assert_allclose(C1, C3.float())
|
||||
torch.testing.assert_close(C1, C3.float())
|
||||
|
||||
# transpose
|
||||
B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to(
|
||||
|
@ -727,7 +716,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb):
|
|||
B2t, SBt = F.transform(B, "col_turing", transpose=True)
|
||||
C2, SC = F.igemmlt(A2, B2t, SA, SBt)
|
||||
C3, S = F.nvidia_transform(C2, "row", state=SC)
|
||||
torch.testing.assert_allclose(C1, C3.float())
|
||||
torch.testing.assert_close(C1, C3.float())
|
||||
|
||||
|
||||
dim1 = [32]
|
||||
|
@ -773,7 +762,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
|
|||
# print(C1.flatten()[:10])
|
||||
# print(C2.flatten()[:10])
|
||||
|
||||
# torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
|
||||
# torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05)
|
||||
|
||||
# transpose
|
||||
# B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8)
|
||||
|
@ -782,7 +771,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims):
|
|||
# B2t, SBt = F.transform2(B, 'col_turing', transpose=True)
|
||||
# C2, SC = F.igemmlt(A2, B2t, SA, SBt)
|
||||
# C3, S = F.transform(C2, 'row', state=SC)
|
||||
# torch.testing.assert_allclose(C1, C3.float())
|
||||
# torch.testing.assert_close(C1, C3.float())
|
||||
|
||||
|
||||
batch_size = 2
|
||||
|
@ -1001,7 +990,7 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
|
|||
#assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}"
|
||||
|
||||
C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias)
|
||||
#torch.testing.assert_allclose(C5, C4, atol=0.015, rtol=0.1)
|
||||
#torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1)
|
||||
n = C5.numel()
|
||||
assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n))
|
||||
|
||||
|
@ -1051,16 +1040,16 @@ def test_colrow_absmax(dim1, dim2, dims):
|
|||
)
|
||||
nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0)
|
||||
|
||||
torch.testing.assert_allclose(col_stats1_trunc, col_stats2)
|
||||
torch.testing.assert_allclose(row_stats1_trunc, row_stats2)
|
||||
torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2)
|
||||
torch.testing.assert_close(col_stats1_trunc, col_stats2)
|
||||
torch.testing.assert_close(row_stats1_trunc, row_stats2)
|
||||
torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2)
|
||||
|
||||
row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax(
|
||||
A, threshold=0.0
|
||||
)
|
||||
|
||||
torch.testing.assert_allclose(col_stats1, col_stats2)
|
||||
torch.testing.assert_allclose(row_stats1, row_stats2)
|
||||
torch.testing.assert_close(col_stats1, col_stats2)
|
||||
torch.testing.assert_close(row_stats1, row_stats2)
|
||||
assert nnz_block_ptr2 is None
|
||||
|
||||
|
||||
|
@ -1084,8 +1073,8 @@ def test_double_quant(dim1, dim2):
|
|||
CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
|
||||
|
||||
# max difference is 1 due to rounding differences
|
||||
torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0)
|
||||
torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0)
|
||||
torch.testing.assert_close(CA, out_row1, atol=1, rtol=0)
|
||||
torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0)
|
||||
|
||||
n = CAt.numel()
|
||||
num_not_close_rows = (
|
||||
|
@ -1108,8 +1097,8 @@ def test_double_quant(dim1, dim2):
|
|||
)
|
||||
assert False
|
||||
|
||||
torch.testing.assert_allclose(Srow.flatten(), statsA)
|
||||
torch.testing.assert_allclose(Scol.flatten(), statsAt)
|
||||
torch.testing.assert_close(Srow.flatten().float(), statsA)
|
||||
torch.testing.assert_close(Scol.flatten().float(), statsAt)
|
||||
|
||||
|
||||
n = 4
|
||||
|
@ -1134,10 +1123,10 @@ def test_integrated_igemmlt(dim1, dim4, inner):
|
|||
A1, maxA = F.vectorwise_quant(A, dim=1)
|
||||
B1, maxB = F.vectorwise_quant(B, dim=1)
|
||||
|
||||
torch.testing.assert_allclose(maxA.flatten(), stats1a)
|
||||
torch.testing.assert_allclose(maxB.flatten(), stats2a)
|
||||
torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1)
|
||||
torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1)
|
||||
torch.testing.assert_close(maxA.flatten().float(), stats1a)
|
||||
torch.testing.assert_close(maxB.flatten().float(), stats2a)
|
||||
torch.testing.assert_close(C1a, A1, rtol=0, atol=1)
|
||||
torch.testing.assert_close(C2a, B1, rtol=0, atol=1)
|
||||
|
||||
A2, SA = F.nvidia_transform(C1a, "col32")
|
||||
B2, SB = F.nvidia_transform(C2a, "col_turing")
|
||||
|
@ -1339,7 +1328,7 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose):
|
|||
# print(out1)
|
||||
# print(out2)
|
||||
|
||||
torch.testing.assert_allclose(out1, out2)
|
||||
torch.testing.assert_close(out1, out2)
|
||||
|
||||
|
||||
n = 2
|
||||
|
@ -1401,11 +1390,11 @@ def test_coo_double_quant(dim1, dim2):
|
|||
A2[
|
||||
coo_tensor.rowidx.long(), coo_tensor.colidx.long()
|
||||
] = coo_tensor.values
|
||||
torch.testing.assert_allclose(A1, A2)
|
||||
torch.testing.assert_close(A1, A2)
|
||||
|
||||
A1 = A * (idx == 0)
|
||||
A2 = (CA.float() * statsA.unsqueeze(1) / 127).half()
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
A * (idx == 0), A2, rtol=0.05, atol=1.5e-2
|
||||
)
|
||||
|
||||
|
@ -1613,7 +1602,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func):
|
|||
|
||||
idx_col = torch.randint(0, A2.shape[-1], size=(15,))
|
||||
|
||||
# torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001)
|
||||
# torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001)
|
||||
|
||||
# Bt = torch.randn(dim2*4, dim2, device='cuda').half()
|
||||
# torch.cuda.synchronize()
|
||||
|
@ -1644,9 +1633,9 @@ def test_coo2csr():
|
|||
counts = csrA.rowptr[1:] - csrA.rowptr[:-1]
|
||||
assert counts.numel() == A.shape[0]
|
||||
|
||||
torch.testing.assert_allclose(counts, (A2 != 0).sum(1))
|
||||
torch.testing.assert_close(counts.long(), (A2 != 0).sum(1))
|
||||
idx = A2 != 0
|
||||
torch.testing.assert_allclose(A2[idx], csrA.values)
|
||||
torch.testing.assert_close(A2[idx], csrA.values)
|
||||
|
||||
|
||||
def test_coo2csc():
|
||||
|
@ -1664,10 +1653,10 @@ def test_coo2csc():
|
|||
counts = cscA.colptr[1:] - cscA.colptr[:-1]
|
||||
assert counts.numel() == A.shape[1]
|
||||
|
||||
torch.testing.assert_allclose(counts, (A2 != 0).sum(0))
|
||||
torch.testing.assert_close(counts.long(), (A2 != 0).sum(0))
|
||||
# torch uses row-major -> use transpose to transfer to col-major
|
||||
idx = A2.t() != 0
|
||||
torch.testing.assert_allclose(A2.t()[idx], cscA.values)
|
||||
torch.testing.assert_close(A2.t()[idx], cscA.values)
|
||||
|
||||
|
||||
n = 2
|
||||
|
@ -1717,7 +1706,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
|
|||
max_count, max_idx = torch.sort(counts, descending=True)
|
||||
print(torch.median(max_count.float()))
|
||||
|
||||
torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001)
|
||||
torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001)
|
||||
|
||||
p = 200 / (2048 * 12288 * 4)
|
||||
n = out1.numel()
|
||||
|
@ -1787,38 +1776,43 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
|
|||
batch_size = 1
|
||||
seqdim = 1
|
||||
values = []
|
||||
values.append((batch_size, seqdim, 768, 4 * 768))
|
||||
# values.append((batch_size, seqdim, 1024, 4*1024))
|
||||
# values.append((batch_size, seqdim, 1536, 4*1536))
|
||||
# values.append((batch_size, seqdim, 2048, 4*2048))
|
||||
# values.append((batch_size, seqdim, 2560, 4*2560))
|
||||
# values.append((batch_size, seqdim, 4096, 4*4096))
|
||||
# values.append((batch_size, seqdim, 5140, 4*5140))
|
||||
#values.append((batch_size, seqdim, 768, 4 * 768))
|
||||
#values.append((batch_size, seqdim, 1024, 4*1024))
|
||||
#values.append((batch_size, seqdim, 1536, 4*1536))
|
||||
#values.append((batch_size, seqdim, 2048, 4*2048))
|
||||
#values.append((batch_size, seqdim, 2560, 4*2560))
|
||||
values.append((batch_size, seqdim, 4096, 4*4096))
|
||||
values.append((batch_size, seqdim, 5120, 4*5120))
|
||||
values.append((batch_size, seqdim, 6656, 4*6656))
|
||||
values.append((batch_size, seqdim, 8192, 4*8192))
|
||||
#values.append((batch_size, seqdim, 5140, 4*5140))
|
||||
#values.append((batch_size, seqdim, 12288, 4*12288))
|
||||
names = [
|
||||
"batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values
|
||||
]
|
||||
|
||||
|
||||
names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
|
||||
@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
|
||||
def test_bench_matmul(batch, seq, model, hidden):
|
||||
iters = 128
|
||||
iters = 80
|
||||
formatB = F.get_special_format_str()
|
||||
|
||||
A = torch.randn(batch, seq, model, device="cuda").half()
|
||||
B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
|
||||
torch.nn.init.xavier_uniform_(B)
|
||||
|
||||
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
|
||||
B_fp4, state = F.quantize_fp4(B)
|
||||
B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
|
||||
|
||||
B_nf4, state_nf4= F.quantize_nf4(B)
|
||||
|
||||
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
|
||||
linear8bit.eval()
|
||||
|
||||
outliers = torch.randint(0, model, size=(5,)).cuda()
|
||||
A[:, :, outliers] = 8.0
|
||||
|
||||
linearMixedBit = (
|
||||
bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
|
||||
)
|
||||
linearMixedBit.eval()
|
||||
linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half())
|
||||
#linearMixedBit.eval()
|
||||
|
||||
linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
|
||||
linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
|
||||
|
||||
# warmup
|
||||
for i in range(iters):
|
||||
|
@ -1831,61 +1825,80 @@ def test_bench_matmul(batch, seq, model, hidden):
|
|||
for i in range(iters):
|
||||
torch.matmul(A, B.t())
|
||||
torch.cuda.synchronize()
|
||||
print(
|
||||
f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
|
||||
)
|
||||
print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
|
||||
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(iters):
|
||||
bnb.matmul(A, B)
|
||||
bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
|
||||
torch.cuda.synchronize()
|
||||
print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||
print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
|
||||
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(iters):
|
||||
bnb.matmul(A, B, threshold=6.0)
|
||||
bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
|
||||
torch.cuda.synchronize()
|
||||
print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||
print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
|
||||
|
||||
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
|
||||
C32A, SA = F.transform(CA, "col32")
|
||||
CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
|
||||
CxB, SB = F.transform(CB, to_order=formatB)
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(iters):
|
||||
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
|
||||
bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
|
||||
torch.cuda.synchronize()
|
||||
print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||
print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
|
||||
|
||||
BA, statsB = F.vectorwise_quant(B, dim=1)
|
||||
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(iters):
|
||||
A2 = A.view(-1, A.shape[-1]).contiguous()
|
||||
CA, statsA = F.vectorwise_quant(A2, dim=1)
|
||||
C32A, SA = F.nvidia_transform(CA, "col32")
|
||||
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
|
||||
Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
|
||||
F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
|
||||
torch.cuda.synchronize()
|
||||
#torch.cuda.synchronize()
|
||||
#t0 = time.time()
|
||||
#for i in range(iters):
|
||||
# bnb.matmul(A, B)
|
||||
#torch.cuda.synchronize()
|
||||
#print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||
|
||||
#torch.cuda.synchronize()
|
||||
#t0 = time.time()
|
||||
#for i in range(iters):
|
||||
# bnb.matmul(A, B, threshold=6.0)
|
||||
#torch.cuda.synchronize()
|
||||
#print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||
|
||||
#CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
|
||||
#C32A, SA = F.transform(CA, "col32")
|
||||
#CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
|
||||
#CxB, SB = F.transform(CB, to_order=formatB)
|
||||
#torch.cuda.synchronize()
|
||||
#t0 = time.time()
|
||||
#for i in range(iters):
|
||||
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
|
||||
#torch.cuda.synchronize()
|
||||
#print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||
|
||||
#BA, statsB = F.vectorwise_quant(B, dim=1)
|
||||
#CxB, SB = F.nvidia_transform(CB, to_order=formatB)
|
||||
#torch.cuda.synchronize()
|
||||
#t0 = time.time()
|
||||
#for i in range(iters):
|
||||
# A2 = A.view(-1, A.shape[-1]).contiguous()
|
||||
# CA, statsA = F.vectorwise_quant(A2, dim=1)
|
||||
# C32A, SA = F.nvidia_transform(CA, "col32")
|
||||
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
|
||||
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
|
||||
# F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
|
||||
#torch.cuda.synchronize()
|
||||
#print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||
|
||||
BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
|
||||
CxB, SB = F.nvidia_transform(CB, to_order=formatB)
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(iters):
|
||||
A2 = A.view(-1, A.shape[-1]).contiguous()
|
||||
CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
|
||||
C32A, SA = F.nvidia_transform(CA, "col32")
|
||||
out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
|
||||
Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
|
||||
out = Cout * statsB * statsA * (1.0 / (127 * 127))
|
||||
torch.cuda.synchronize()
|
||||
#BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
|
||||
#CxB, SB = F.nvidia_transform(CB, to_order=formatB)
|
||||
#torch.cuda.synchronize()
|
||||
#t0 = time.time()
|
||||
#for i in range(iters):
|
||||
# A2 = A.view(-1, A.shape[-1]).contiguous()
|
||||
# CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
|
||||
# C32A, SA = F.nvidia_transform(CA, "col32")
|
||||
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
|
||||
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
|
||||
# out = Cout * statsB * statsA * (1.0 / (127 * 127))
|
||||
#torch.cuda.synchronize()
|
||||
#print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||
|
||||
linear8bit(A)
|
||||
|
@ -1894,9 +1907,7 @@ def test_bench_matmul(batch, seq, model, hidden):
|
|||
for i in range(iters):
|
||||
linear8bit(A)
|
||||
torch.cuda.synchronize()
|
||||
print(
|
||||
f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
|
||||
)
|
||||
print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||
|
||||
linearMixedBit(A)
|
||||
torch.cuda.synchronize()
|
||||
|
@ -1904,9 +1915,23 @@ def test_bench_matmul(batch, seq, model, hidden):
|
|||
for i in range(iters):
|
||||
linearMixedBit(A)
|
||||
torch.cuda.synchronize()
|
||||
print(
|
||||
f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
|
||||
)
|
||||
print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||
|
||||
#linear8bit_train(A)
|
||||
#torch.cuda.synchronize()
|
||||
#t0 = time.time()
|
||||
#for i in range(iters):
|
||||
# linear8bit_train(A)
|
||||
#torch.cuda.synchronize()
|
||||
#print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||
|
||||
#linear8bit_train_thresh(A)
|
||||
#torch.cuda.synchronize()
|
||||
#t0 = time.time()
|
||||
#for i in range(iters):
|
||||
# linear8bit_train(A)
|
||||
#torch.cuda.synchronize()
|
||||
#print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
|
||||
|
||||
def test_zeropoint():
|
||||
def quant_zp(x):
|
||||
|
@ -2009,7 +2034,7 @@ def test_extract_outliers():
|
|||
assert outliers2.shape[0] == shapeA[0]
|
||||
assert outliers2.shape[1] == idx.numel()
|
||||
|
||||
torch.testing.assert_allclose(outliers1, outliers2)
|
||||
torch.testing.assert_close(outliers1, outliers2)
|
||||
|
||||
CA, SA = F.transform(A, "col_ampere")
|
||||
|
||||
|
@ -2018,7 +2043,7 @@ def test_extract_outliers():
|
|||
assert outliers2.shape[0] == shapeA[0]
|
||||
assert outliers2.shape[1] == idx.numel()
|
||||
|
||||
torch.testing.assert_allclose(outliers1, outliers2)
|
||||
torch.testing.assert_close(outliers1, outliers2)
|
||||
|
||||
|
||||
|
||||
|
@ -2050,7 +2075,6 @@ def test_fp8_quant():
|
|||
p_bits = 7-e_bits
|
||||
code = F.create_fp8_map(True, e_bits, p_bits).cuda()
|
||||
|
||||
print(e_bits, p_bits)
|
||||
abserr = []
|
||||
relerr = []
|
||||
for i in range(100):
|
||||
|
@ -2149,7 +2173,7 @@ def test_few_bit_quant():
|
|||
#assert err2.mean() <= err1
|
||||
|
||||
else:
|
||||
torch.testing.assert_allclose(q1, q2)
|
||||
torch.testing.assert_close(q1, q2)
|
||||
#print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs))
|
||||
#assert False
|
||||
|
||||
|
@ -2181,7 +2205,9 @@ def test_kbit_quantile_estimation():
|
|||
|
||||
def test_bench_dequantization():
|
||||
a = torch.rand(1024, 1024, device='cuda').half()
|
||||
qa, SA = F.quantize_blockwise(a)
|
||||
code =F.create_fp8_map(True, 3, 0, 4).cuda()
|
||||
qa, SA = F.quantize_blockwise(a, code=code)
|
||||
print(qa.max())
|
||||
|
||||
max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
|
||||
#print(max_theoretical_mu)
|
||||
|
@ -2189,7 +2215,302 @@ def test_bench_dequantization():
|
|||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(100):
|
||||
F.dequantize_blockwise(qa, SA, blocksize=2048)
|
||||
qa, SA = F.quantize_blockwise(a)
|
||||
torch.cuda.synchronize()
|
||||
#print((time.time()-t0)/1e6)
|
||||
|
||||
|
||||
|
||||
def test_fp4_quant():
|
||||
vals = list(product([0, 1], repeat=4))
|
||||
|
||||
code = {}
|
||||
for bits in vals:
|
||||
result = 0
|
||||
bias = 3
|
||||
sign, e1, e2, p1 = bits
|
||||
idx = sign*8 + e1*4 + e2*2 + p1*1
|
||||
sign = -1.0 if sign else 1.0
|
||||
exp = e1*2 + e2*1
|
||||
if exp == 0:
|
||||
# sub-normal
|
||||
if p1 == 0: result = 0
|
||||
else: result = sign*0.0625
|
||||
else:
|
||||
# normal
|
||||
exp = 2**(-exp + bias + 1)
|
||||
frac = 1.5 if p1 else 1.0
|
||||
result = sign*exp*frac
|
||||
code[idx] = result
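The loop above enumerates all 16 FP4 bit patterns (sign bit s, exponent bits e1 e2 forming e, fraction bit p1) with bias 3. Written out, the decoded value the test reconstructs is:

```latex
x(s, e, p_1) =
\begin{cases}
(-1)^{s} \cdot 0.0625 \, p_1, & e = 0 \ \text{(sub-normal)} \\
(-1)^{s} \cdot 2^{\,\text{bias} + 1 - e} \cdot \left(1 + 0.5\, p_1\right), & e \in \{1, 2, 3\},\ \text{bias} = 3
\end{cases}
```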
A1 = torch.randn(1024, 1024, device='cuda').half()
|
||||
qa, SA = F.quantize_fp4(A1, blocksize=64)
|
||||
A2 = F.dequantize_fp4(qa, SA)
|
||||
|
||||
err = (A1 - A2).abs().float()
|
||||
relerr = (err/A1.abs().float()).mean()
|
||||
idx = err > 1.0
|
||||
err = err.mean()
|
||||
|
||||
|
||||
assert err.item() < 0.1
|
||||
assert relerr.item() < 0.28
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
||||
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
|
||||
def test_4bit_compressed_stats(quant_type):
|
||||
for blocksize in [128, 64]:
|
||||
errs1 = []
|
||||
errs2 = []
|
||||
for i in range(10):
|
||||
A1 = torch.randn(1024, 1024, device='cuda').half()
|
||||
q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
|
||||
q3, SA3= F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type)
|
||||
A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type)
|
||||
A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type)
|
||||
|
||||
|
||||
err = (A1 - A2).abs().float()
|
||||
relerr = (err/(A1.abs().float()+1e-15)).mean()
|
||||
err = err.mean()
|
||||
|
||||
errs1.append(err.item())
|
||||
|
||||
|
||||
assert err.item() < 0.11
|
||||
assert relerr.item() < 0.28
|
||||
|
||||
err = (A1 - A3).abs().float()
|
||||
relerr = (err/(A1.abs().float()+1e-15)).mean()
|
||||
err = err.mean()
|
||||
|
||||
errs2.append(err.item())
|
||||
|
||||
assert err.item() < 0.11
|
||||
assert relerr.item() < 0.28
|
||||
|
||||
#print(sum(errs1)/len(errs1), blocksize, quant_type)
|
||||
#print(sum(errs2)/len(errs2), blocksize, quant_type)
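The test checks that compressing the quantization statistics (a second-level quantization of the per-block absmax values) does not meaningfully change the reconstruction error. Standalone usage of the same API looks like:

```python
import torch
import bitsandbytes.functional as F

A = torch.randn(1024, 1024, device='cuda').half()

# Plain NF4 quantization versus NF4 with compressed (second-level) statistics.
q1, s1 = F.quantize_4bit(A, blocksize=64, quant_type='nf4')
q2, s2 = F.quantize_4bit(A, blocksize=64, compress_statistics=True, quant_type='nf4')

err1 = (A - F.dequantize_4bit(q1, s1, quant_type='nf4')).abs().float().mean()
err2 = (A - F.dequantize_4bit(q2, s2, quant_type='nf4')).abs().float().mean()
print(err1.item(), err2.item())   # both stay below ~0.11 per the bounds above
```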
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
||||
@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
|
||||
def test_bench_4bit_dequant(quant_type):
|
||||
blocksize = 256
|
||||
a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
|
||||
qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type)
|
||||
|
||||
input_size = a.numel()/2
|
||||
output_size = a.numel()*2
|
||||
num_bytes = input_size+output_size
|
||||
GB = num_bytes/1e9
|
||||
max_theoretical_s = GB/768
|
||||
#print(max_theoretical_s*1e6)
|
||||
b = torch.randn(128, 1024*12, device='cuda').half()
|
||||
|
||||
iters = 5
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for i in range(iters):
|
||||
F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
|
||||
#b.copy_(a)
|
||||
torch.cuda.synchronize()
|
||||
#print((time.time()-t0)/iters*1e6)
|
||||
|
||||
#torch.cuda.synchronize()
|
||||
#t0 = time.time()
|
||||
#for i in range(iters):
|
||||
# torch.matmul(b, a.t())
|
||||
#torch.cuda.synchronize()
|
||||
#print((time.time()-t0)/iters*1e6)
|
||||
|
||||
|
||||
|
||||
def test_normal_map_tree():
|
||||
code = F.create_normal_map()
|
||||
values = code[:8].tolist() + code[-8:].tolist()
|
||||
num_pivots = 1
|
||||
print(values)
|
||||
while num_pivots < 16:
|
||||
idx = list(range(16//num_pivots//2, 16, 16//num_pivots))
|
||||
print(idx)
|
||||
num_pivots *= 2
|
||||
pivots = []
|
||||
for i in idx:
|
||||
pivots.append((values[i-1]+values[i])/2)
|
||||
print(pivots)
|
||||
|
||||
|
||||
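# Hedged sketch (not part of the original test): the pivots printed above are the
# midpoints a balanced binary search over the 16-entry code compares against, so
# any input can be mapped to its nearest code index in four comparisons, assuming
# the code values are sorted ascending. The helper name below is illustrative.
def _nearest_code_index(code_values, x):
    lo, hi = 0, len(code_values) - 1
    while lo < hi:
        mid = (lo + hi) // 2
        # move right only if x lies past the midpoint between neighbouring codes
        if x > (code_values[mid] + code_values[mid + 1]) / 2:
            lo = mid + 1
        else:
            hi = mid
    return lo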
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
|
||||
def test_cutlass3_gemm(dtype):
|
||||
debug = True
|
||||
#for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
|
||||
#for dim in [4096, 5120, 6656, 8192]:
|
||||
for dim in [4096]:
|
||||
#for dim in [128+1]:
|
||||
errs = []
|
||||
relerrs = []
|
||||
max_err = 0
|
||||
max_relerr = 0
|
||||
for i in range(100):
|
||||
A = torch.randn(1, dim, dtype=dtype, device='cuda')
|
||||
B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
|
||||
#B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim)
|
||||
|
||||
#print('')
|
||||
#print(A)
|
||||
#print(B.t())
|
||||
#A[:, :-1] = 0
|
||||
#B[:, :-1] = 0
|
||||
|
||||
|
||||
C1 = torch.matmul(A, B.t())
|
||||
C2 = F.cutlass3_gemm(A, B.t())
|
||||
|
||||
# tensor cores are non-deterministic
|
||||
# so we need to analyze errors around the mean
|
||||
# to test our implementation
|
||||
err = torch.abs(C1-C2)
|
||||
mag = torch.abs(C1)+1e-8
|
||||
relerr = err/mag
|
||||
max_err = max(err.max(), max_err)
|
||||
max_relerr = max(relerr.max(), max_relerr)
|
||||
err = err.mean().item()
|
||||
relerr = relerr.mean().item()
|
||||
|
||||
errs.append(err)
|
||||
relerrs.append(relerr)
|
||||
|
||||
#if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
|
||||
# print('')
|
||||
# print(i, err, relerr)
|
||||
# print(A.flatten()[-6:])
|
||||
# print(B.flatten()[-6:])
|
||||
# out = A.flatten()[-6:]*B.flatten()[-6:]
|
||||
# print(out)
|
||||
# print(out[:-1].sum())
|
||||
# print('='*80)
|
||||
# print(C1.flatten()[-6:])
|
||||
# print(C2.flatten()[-6:])
|
||||
# #assert False, 'ERROR'
|
||||
|
||||
c = int(C1.numel()*0.0014*(dim/256))+1
|
||||
|
||||
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug)
|
||||
#print(c/math.sqrt(dim))
|
||||
print('')
|
||||
print(dim, sum(errs)/len(errs)/math.sqrt(dim))
|
||||
print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
|
||||
print(dim, (max_err.item(), max_relerr.item()))
|
||||
|
||||
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16'])
|
||||
def test_gemm_4bit(dtype):
|
||||
#for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]:
|
||||
#for dim in [4096, 5120, 6656, 8192]:
|
||||
#for dim in [32]:
|
||||
for dim in [4096]:
|
||||
errs = []
|
||||
relerrs = []
|
||||
max_err = 0
|
||||
max_relerr = 0
|
||||
for i in range(1):
|
||||
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
|
||||
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
|
||||
#A = torch.rand(1, 4096, dtype=dtype, device='cuda')
|
||||
#B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda')
|
||||
A = torch.randn(1, dim+0, dtype=dtype, device='cuda')
|
||||
B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim)
|
||||
|
||||
#print('')
|
||||
#print(A)
|
||||
#print(B.t())
|
||||
#A[:, :-1] = 0
|
||||
#B[:, :-1] = 0
|
||||
|
||||
qB, state = F.quantize_nf4(B)
|
||||
F.dequantize_nf4(qB, state)
|
||||
|
||||
C3 = torch.matmul(A, B.t())
|
||||
C2 = F.cutlass3_gemm(A, qB.t(), state=state)
|
||||
C1 = bnb.matmul_4bit(A, qB.t(), state)
|
||||
C2 = F.cutlass3_gemm(A, qB.t(), state=state)
|
||||
|
||||
print(C1.shape, C2.shape)
|
||||
|
||||
# tensor cores are non-deterministic
|
||||
# so we need to analyze errors around the mean
|
||||
# to test our implementation
|
||||
err = torch.abs(C1-C2)
|
||||
mag = torch.abs(C1)+1e-8
|
||||
relerr = err/mag
|
||||
max_err = max(err.max(), max_err)
|
||||
max_relerr = max(relerr.max(), max_relerr)
|
||||
err = err.mean().item()
|
||||
relerr = relerr.mean().item()
|
||||
|
||||
errs.append(err)
|
||||
relerrs.append(relerr)
|
||||
|
||||
if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5:
|
||||
print('')
|
||||
print(i, err, relerr)
|
||||
print(A.flatten()[-6:])
|
||||
print(B.flatten()[-6:])
|
||||
out = A.flatten()[-6:]*B.flatten()[-6:]
|
||||
print(out)
|
||||
print(out[:-1].sum())
|
||||
print('='*80)
|
||||
print(C1.flatten()[-6:])
|
||||
print(C2.flatten()[-6:])
|
||||
#assert False, 'ERROR'
|
||||
|
||||
c = int(C1.numel()*0.0014*(dim/256))+1
|
||||
|
||||
c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False)
|
||||
#print(c/math.sqrt(dim))
|
||||
print('')
|
||||
print(dim, sum(errs)/len(errs)/math.sqrt(dim))
|
||||
print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim))
|
||||
print(dim, (max_err.item(), max_relerr.item()))
|
||||
|
||||
@pytest.mark.skip("Row scale has some bugs for ampere")
|
||||
def test_managed():
|
||||
n = 32*10
|
||||
A = F.get_paged(n, n, dtype=torch.float32)
|
||||
B = F.get_paged(n, n, dtype=torch.uint8)
|
||||
B2 = F.get_paged(n, n, dtype=torch.float32)
|
||||
assert A.is_paged
|
||||
assert B.is_paged
|
||||
assert A.page_deviceid==0
|
||||
assert B.page_deviceid==0
|
||||
F.fill(A, 17.0)
|
||||
F.fill(B, 17)
|
||||
F.fill(B2, 2)
|
||||
assert (A==17).sum().item() == n*n
|
||||
assert (B==17).sum().item() == n*n
|
||||
C = A*B.float()
|
||||
assert (C==289).sum().item() == n*n
|
||||
F._mul(A, B2)
|
||||
F._mul(A, B2)
|
||||
F._mul(A, B2)
|
||||
assert (A==17*(2**3)).sum().item() == n*n
|
||||
# F.prefetch_tensor(A)
|
||||
# F.prefetch_tensor(B)
|
||||
|
||||
|
||||
# F.fill(B2, 17.0)
|
||||
# F._mul(A, B2)
|
||||
|
||||
# F.prefetch_tensor(A, to_cpu=True)
|
||||
# F.prefetch_tensor(B, to_cpu=True)
|
||||
# F.prefetch_tensor(B2, to_cpu=True)
|
||||
# torch.cuda.synchronize()
|
||||
|
||||
# assert (A==17).sum().item() == n*n
|
||||
|
||||
# torch.testing.assert_close(A, torch.ones(A.shape)*289)
|
||||
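# Hedged note on the test above: F.get_paged is expected to allocate the tensor
# in CUDA managed (unified) memory, which is why it can be filled and multiplied
# like a normal CUDA tensor while is_paged/page_deviceid record its placement;
# the commented-out prefetch_tensor calls would migrate it explicitly between
# host and device.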
|
|
|
@ -44,7 +44,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
|
|||
sumval = (idx == 0).sum().item()
|
||||
if sumval > count:
|
||||
print(f"Too many values not close: assert {sumval} < {count}")
|
||||
torch.testing.assert_allclose(a, b, rtol, atol)
|
||||
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
class LinearFunction(torch.autograd.Function):
|
||||
|
@ -330,18 +330,15 @@ def test_linear8bitlt_inference(threshold):
|
|||
|
||||
|
||||
def test_linear8bitlt_accumulated_gradient():
|
||||
l1 = torch.nn.Sequential(
|
||||
*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
|
||||
)
|
||||
l2 = torch.nn.Sequential(
|
||||
*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]
|
||||
)
|
||||
l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
|
||||
l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
|
||||
l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
|
||||
l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
|
||||
opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
|
||||
opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
|
||||
l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)])
|
||||
l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
|
||||
l1[0].weight.data.copy_(l2[0].weight.data)
|
||||
l1[1].weight.data.copy_(l2[1].weight.data)
|
||||
l1[0].bias.data.copy_(l2[0].bias.data)
|
||||
l1[1].bias.data.copy_(l2[1].bias.data)
|
||||
|
||||
opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001)
|
||||
opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001)
|
||||
|
||||
acc_steps = 10
|
||||
|
||||
|
@ -371,26 +368,17 @@ def test_linear8bitlt_accumulated_gradient():
|
|||
# we do this copy because otherwise we have small divergences over time that add up
|
||||
l1[0].weight.data.copy_(l2[0].weight.data)
|
||||
l1[1].weight.data.copy_(l2[1].weight.data)
|
||||
l1[0].bias.data.copy_(l2[0].bias.data)
|
||||
l1[1].bias.data.copy_(l2[1].bias.data)
|
||||
else:
|
||||
torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
|
||||
torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)
|
||||
torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3)
|
||||
torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
threshold = [0.0, 2.0]
|
||||
values = threshold
|
||||
names = [f"threshold_{vals}" for vals in values]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("threshold", values, ids=names)
|
||||
@pytest.mark.parametrize("threshold", [0.0, 2.0])
|
||||
@pytest.mark.parametrize("memory_efficient_backward", [False])
|
||||
def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
|
||||
l1 = (
|
||||
bnb.nn.Linear8bitLt(
|
||||
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
|
||||
)
|
||||
.cuda()
|
||||
.half()
|
||||
)
|
||||
l1 = (bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half())
|
||||
assert l1.weight.dtype == torch.int8
|
||||
|
||||
l1.eval()
|
||||
|
@ -446,13 +434,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
|
|||
assert mlp.fc1.weight.dtype == torch.int8
|
||||
assert mlp.fc2.weight.dtype == torch.int8
|
||||
|
||||
mlp = (
|
||||
MLP8bit(
|
||||
32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward
|
||||
)
|
||||
.half()
|
||||
.to("cuda")
|
||||
)
|
||||
mlp = (MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda"))
|
||||
|
||||
for i in range(100):
|
||||
b1 = torch.randn(16, 8, 32, device="cuda").half()
|
||||
|
@ -499,15 +481,16 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
|
|||
grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half()
|
||||
scale = grad_ref.abs().mean()
|
||||
|
||||
torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
|
||||
torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale)
|
||||
idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1)
|
||||
assert (idx == 0).sum().item() <= b1.numel() * 0.005
|
||||
|
||||
|
||||
def test_linear8bitlt_fp32_bias():
|
||||
@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
|
||||
def test_linear_kbit_fp32_bias(module):
|
||||
# casts model to fp16 -> int8 automatically
|
||||
l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda()
|
||||
assert l1.weight.dtype == torch.int8
|
||||
l1 = module(32, 64).cuda()
|
||||
assert l1.weight.dtype in [torch.int8, torch.uint8]
|
||||
assert l1.bias.dtype == torch.float32
|
||||
|
||||
for i in range(100):
|
||||
|
@ -517,11 +500,116 @@ def test_linear8bitlt_fp32_bias():
|
|||
assert l1.bias.dtype == torch.float16
|
||||
|
||||
# casts model to fp16 -> int8 automatically
|
||||
l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda()
|
||||
assert l1.weight.dtype == torch.int8
|
||||
l1 = module(32, 64, bias=False).cuda()
|
||||
assert l1.weight.dtype in [torch.int8, torch.uint8]
|
||||
assert l1.bias is None
|
||||
|
||||
for i in range(100):
|
||||
b1 = torch.randn(16, 8, 32, device="cuda").half()
|
||||
o1 = l1(b1)
|
||||
assert l1.bias is None
|
||||
|
||||
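# Hedged reading of the checks above: the k-bit Linear layers keep an fp32 bias
# after .cuda() and only cast it to the input dtype (fp16 here) on the first
# forward pass, so loading fp32 checkpoints does not require manually halving
# the bias.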
modules = []
|
||||
modules.append(bnb.nn.Linear8bitLt)
|
||||
modules.append(bnb.nn.Linear4bit)
|
||||
modules.append(bnb.nn.LinearFP4)
|
||||
modules.append(bnb.nn.LinearNF4)
|
||||
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True))
|
||||
modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True))
|
||||
names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C']
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
||||
@pytest.mark.parametrize("module", modules, ids=names)
|
||||
def test_kbit_backprop(module):
|
||||
b = 17
|
||||
dim1 = 37
|
||||
dim2 = 83
|
||||
|
||||
ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)])
|
||||
ref[1].weight.requires_grad = False
|
||||
torch.nn.init.kaiming_normal_(ref[0].weight)
|
||||
torch.nn.init.kaiming_normal_(ref[1].weight)
|
||||
kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)])
|
||||
kbit[0].weight.detach().copy_(ref[0].weight)
|
||||
kbit[1].weight.detach().copy_(ref[1].weight)
|
||||
kbit[0].bias.detach().copy_(ref[0].bias)
|
||||
kbit[1].bias.detach().copy_(ref[1].bias)
|
||||
ref = ref.half().cuda()
|
||||
kbit = kbit.half().cuda()
|
||||
|
||||
errs1 = []
|
||||
errs2 = []
|
||||
relerrs1 = []
|
||||
relerrs2 = []
|
||||
for i in range(100):
|
||||
batch = torch.randn(b, dim1).half().cuda()
|
||||
out1 = ref(batch)
|
||||
out2 = kbit(batch)
|
||||
out1.mean().backward()
|
||||
out2.mean().backward()
|
||||
|
||||
grad1 = ref[0].weight.grad
|
||||
grad2 = kbit[0].weight.grad
|
||||
bgrad1 = ref[0].bias.grad
|
||||
bgrad2 = kbit[0].bias.grad
|
||||
|
||||
err1 = (out1-out2).abs().float()
|
||||
err2 = (grad1-grad2).abs().float()
|
||||
relerr1 = (err1/(out1.abs().float()+1e-9))
|
||||
relerr2 = (err2/(grad1.abs().float()+1e-9))
|
||||
errs1.append(err1.mean().item())
|
||||
errs2.append(err2.mean().item())
|
||||
relerrs1.append(relerr1.mean().item())
|
||||
relerrs2.append(relerr2.mean().item())
|
||||
|
||||
if module == bnb.nn.Linear8bitLt:
|
||||
torch.testing.assert_close(grad1, grad2, atol=0.008, rtol=0.05)
|
||||
torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
|
||||
else:
|
||||
torch.testing.assert_close(grad1, grad2, atol=0.015, rtol=0.05)
|
||||
torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
|
||||
ref.zero_grad()
|
||||
kbit.zero_grad()
|
||||
|
||||
assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0
|
||||
assert kbit[0].bias.grad is None or kbit[0].bias.grad.sum().item() == 0
|
||||
print('out', sum(errs1)/len(errs1))
|
||||
print('grad', sum(errs2)/len(errs2))
|
||||
print('rel out', sum(relerrs1)/len(relerrs1))
|
||||
print('rel grad', sum(relerrs2)/len(relerrs2))
|
||||
|
||||
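# Hedged usage sketch for the parametrized modules above: each entry behaves as
# a drop-in replacement for torch.nn.Linear, e.g.
#     layer = bnb.nn.LinearNF4(dim2, 10).half().cuda()
#     out = layer(batch)
# with the weight quantized to 4 bits (or int8 for Linear8bitLt) when the module
# is moved to the GPU.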
def test_fp8linear():
|
||||
|
||||
b = 10
|
||||
h = 1024
|
||||
inp = torch.randn(b, h).cuda()
|
||||
fp32 = torch.nn.Linear(h, h*2).cuda()
|
||||
fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda()
|
||||
fp32b = torch.nn.Linear(h*2, h).cuda()
|
||||
fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda()
|
||||
|
||||
fp8.weight.data.copy_(fp32.weight.data)
|
||||
fp8.bias.data.copy_(fp32.bias.data)
|
||||
fp8b.weight.data.copy_(fp32b.weight.data)
|
||||
fp8b.bias.data.copy_(fp32b.bias.data)
|
||||
|
||||
a = fp32b(torch.nn.functional.gelu(fp32(inp)))
|
||||
b = fp8b(torch.nn.functional.gelu(fp8(inp)))
|
||||
|
||||
err = (a-b).abs().mean()
|
||||
|
||||
a.mean().backward()
|
||||
b.mean().backward()
|
||||
|
||||
graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean()
|
||||
bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean()
|
||||
|
||||
assert err < 0.05
|
||||
assert graderr < 0.00002
|
||||
assert bgraderr < 0.00002
|
||||
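# Hedged note on the test above: LinearFP8Mixed is the research layer under
# bnb.research.nn, and the three tolerances only check that its output and its
# weight/bias gradients stay close to the fp32 Linear pair it was copied from.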
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -19,11 +19,11 @@ import bitsandbytes.functional as F
|
|||
k = 20
|
||||
|
||||
def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
|
||||
idx = torch.isclose(a, b, rtol, atol)
|
||||
idx = torch.isclose(a, b, rtol=rtol, atol=atol)
|
||||
error_count = (idx == 0).sum().item()
|
||||
if error_count > max_error_count:
|
||||
print(f"Too many values not close: assert {error_count} < {max_error_count}")
|
||||
torch.testing.assert_allclose(a, b, rtol, atol)
|
||||
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def get_temp_dir():
|
||||
|
@ -35,11 +35,8 @@ def get_temp_dir():
|
|||
def rm_path(path):
|
||||
shutil.rmtree(path)
|
||||
|
||||
|
||||
str2optimizers = {}
|
||||
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
|
||||
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
|
||||
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
|
||||
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
|
||||
str2optimizers["momentum_pytorch"] = (
|
||||
None,
|
||||
|
@ -47,28 +44,20 @@ str2optimizers["momentum_pytorch"] = (
|
|||
bnb.optim.Adam,
|
||||
)
|
||||
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
|
||||
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
|
||||
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
|
||||
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
|
||||
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
|
||||
str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
|
||||
str2optimizers["momentum"] = (
|
||||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
|
||||
)
|
||||
str2optimizers["lars"] = (
|
||||
lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
|
||||
)
|
||||
str2optimizers["rmsprop"] = (
|
||||
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
|
||||
)
|
||||
str2optimizers["adam8bit"] = (
|
||||
torch.optim.Adam,
|
||||
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
|
||||
)
|
||||
str2optimizers["lion8bit"] = (
|
||||
Lion,
|
||||
lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False),
|
||||
)
|
||||
str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
|
||||
str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
|
||||
str2optimizers["momentum8bit"] = (
|
||||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
|
||||
|
@ -77,19 +66,12 @@ str2optimizers["rmsprop8bit"] = (
|
|||
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
|
||||
)
|
||||
str2optimizers["lars8bit"] = (
|
||||
lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9),
|
||||
)
|
||||
|
||||
str2optimizers["adam8bit_blockwise"] = (
|
||||
torch.optim.Adam,
|
||||
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
|
||||
)
|
||||
str2optimizers["lion8bit_blockwise"] = (
|
||||
Lion,
|
||||
lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True),
|
||||
)
|
||||
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
|
||||
str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
|
||||
str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
|
||||
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
|
||||
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
|
||||
str2optimizers["momentum8bit_blockwise"] = (
|
||||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
|
||||
|
@ -101,53 +83,35 @@ str2optimizers["rmsprop8bit_blockwise"] = (
|
|||
|
||||
str2statenames = {}
|
||||
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
||||
str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
||||
str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
||||
str2statenames["lion"] = [("exp_avg", "state1")]
|
||||
str2statenames["paged_lion"] = [("exp_avg", "state1")]
|
||||
str2statenames["momentum"] = [("momentum_buffer", "state1")]
|
||||
str2statenames["lars"] = [("momentum_buffer", "state1")]
|
||||
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
||||
str2statenames["rmsprop"] = [("square_avg", "state1")]
|
||||
str2statenames["adam8bit"] = [
|
||||
("exp_avg", "state1", "qmap1", "max1"),
|
||||
("exp_avg_sq", "state2", "qmap2", "max2"),
|
||||
]
|
||||
str2statenames["lion8bit"] = [
|
||||
("exp_avg", "state1", "qmap1", "max1")
|
||||
]
|
||||
str2statenames["lamb8bit"] = [
|
||||
("exp_avg", "state1", "qmap1", "max1"),
|
||||
("exp_avg_sq", "state2", "qmap2", "max2"),
|
||||
]
|
||||
str2statenames["adam8bit_blockwise"] = [
|
||||
("exp_avg", "state1", "qmap1", "absmax1"),
|
||||
("exp_avg_sq", "state2", "qmap2", "absmax2"),
|
||||
]
|
||||
str2statenames["lion8bit_blockwise"] = [
|
||||
("exp_avg", "state1", "qmap1", "absmax1")
|
||||
]
|
||||
str2statenames["momentum8bit"] = [
|
||||
("momentum_buffer", "state1", "qmap1", "max1")
|
||||
]
|
||||
str2statenames["momentum8bit_blockwise"] = [
|
||||
("momentum_buffer", "state1", "qmap1", "absmax1")
|
||||
]
|
||||
str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
|
||||
str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
|
||||
str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
|
||||
str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
|
||||
str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
|
||||
str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
|
||||
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
|
||||
str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")]
|
||||
str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
|
||||
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
|
||||
str2statenames["rmsprop8bit_blockwise"] = [
|
||||
("square_avg", "state1", "qmap1", "absmax1")
|
||||
]
|
||||
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
|
||||
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
|
||||
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
|
||||
|
||||
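# Hedged note on the mapping above: each tuple pairs a torch optimizer state key
# with the matching bitsandbytes buffer ("exp_avg" <-> "state1"), plus, for the
# 8-bit optimizers, the quantization map and (abs)max buffer used to decode it.
# A sketch of how test_optimizer8bit below reads such a state back:
#     state = bnb_optimizer.state[p2]
#     exp_avg = F.dequantize_blockwise(code=state["qmap1"], absmax=state["absmax1"], A=state["state1"])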
dim1 = [1024]
|
||||
dim2 = [32, 1024, 4097, 1]
|
||||
gtype = [torch.float32, torch.float16]
|
||||
optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lion"]
|
||||
gtype = [torch.float32, torch.float16, torch.bfloat16]
|
||||
optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
|
||||
values = list(product(dim1, dim2, gtype, optimizer_names))
|
||||
names = [
|
||||
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
|
||||
]
|
||||
|
||||
|
||||
names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
|
||||
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
|
||||
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
|
||||
if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip()
|
||||
if dim1 == 1 and dim2 == 1:
|
||||
return
|
||||
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
|
||||
|
@ -159,6 +123,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
|
|||
|
||||
if gtype == torch.float32:
|
||||
atol, rtol = 1e-6, 1e-5
|
||||
elif gtype == torch.bfloat16:
|
||||
atol, rtol = 1e-3, 1e-2
|
||||
else:
|
||||
atol, rtol = 1e-4, 1e-3
|
||||
|
||||
|
@ -172,9 +138,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
|
|||
|
||||
|
||||
for name1, name2 in str2statenames[optim_name]:
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
torch_optimizer.state[p1][name1],
|
||||
bnb_optimizer.state[p2][name2],
|
||||
bnb_optimizer.state[p2][name2].cuda(),
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
|
@ -201,14 +167,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
|
|||
atol=atol, rtol=rtol,
|
||||
max_error_count=10)
|
||||
|
||||
if gtype == torch.float16:
|
||||
if gtype != torch.float32:
|
||||
# the adam buffers should also be close because they are 32-bit
|
||||
# but the parameters can diverge because they are 16-bit
|
||||
# the differences grow larger and larger with each update
|
||||
# --> copy the state to keep weights close
|
||||
p1.data = p1.data.half().float()
|
||||
p1.data = p1.data.to(p2.dtype).float()
|
||||
p2.copy_(p1.data)
|
||||
torch.testing.assert_allclose(p1.half(), p2)
|
||||
torch.testing.assert_close(p1.to(p2.dtype), p2)
|
||||
if optim_name in ["lars", "lamb"]:
|
||||
assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
|
||||
|
||||
|
@ -268,7 +234,7 @@ def test_global_config(dim1, dim2, gtype):
|
|||
|
||||
dim1 = [1024]
|
||||
dim2 = [32, 1024, 4097]
|
||||
gtype = [torch.float32, torch.float16]
|
||||
gtype = [torch.float32, torch.float16, torch.bfloat16]
|
||||
optimizer_names = [
|
||||
"adam8bit",
|
||||
"lion8bit",
|
||||
|
@ -276,7 +242,6 @@ optimizer_names = [
|
|||
"rmsprop8bit",
|
||||
"adam8bit_blockwise",
|
||||
"lion8bit_blockwise",
|
||||
"lars8bit",
|
||||
"momentum8bit_blockwise",
|
||||
"rmsprop8bit_blockwise",
|
||||
]
|
||||
|
@ -288,6 +253,7 @@ names = [
|
|||
|
||||
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
|
||||
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
|
||||
if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip()
|
||||
if dim1 == 1 and dim2 == 1:
|
||||
return
|
||||
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
|
||||
|
@ -301,7 +267,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
|
|||
if gtype == torch.float32:
|
||||
atol, rtol = 3e-3, 1e-3
|
||||
patol, prtol = 1e-5, 1e-3
|
||||
|
||||
elif gtype == torch.bfloat16:
|
||||
atol, rtol = 3e-3, 1e-3
|
||||
patol, prtol = 1e-4, 1e-2
|
||||
else:
|
||||
atol, rtol = 3e-3, 1e-3
|
||||
patol, prtol = 1e-5, 1e-3
|
||||
|
@ -309,7 +277,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
|
|||
errors = []
|
||||
relerrors = []
|
||||
|
||||
for i in range(50):
|
||||
for i in range(100):
|
||||
g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
|
||||
p1.grad = g.clone().float()
|
||||
p2.grad = g.clone()
|
||||
|
@ -343,13 +311,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
|
|||
)
|
||||
== 0
|
||||
)
|
||||
assert num_not_close.sum().item() < 20
|
||||
#assert num_not_close.sum().item() < 20
|
||||
dequant_states.append(s1.clone())
|
||||
|
||||
err = torch.abs(p1 - p2)
|
||||
relerr = err / (torch.abs(p1)+1e-9)
|
||||
assert err.mean() < 0.0001
|
||||
assert relerr.mean() < 0.001
|
||||
if g.dtype == torch.bfloat16:
|
||||
assert err.mean() < 0.00015
|
||||
assert relerr.mean() < 0.0016
|
||||
else:
|
||||
assert err.mean() < 0.00012
|
||||
assert relerr.mean() < 0.0012
|
||||
|
||||
errors.append(err.mean().item())
|
||||
relerrors.append(relerr.mean().item())
|
||||
|
@ -369,12 +341,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
|
|||
bnb_optimizer = str2optimizers[optim_name][1]([p2])
|
||||
bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
|
||||
rm_path(path)
|
||||
torch.testing.assert_allclose(
|
||||
raws1cpy, bnb_optimizer.state[p2][name2]
|
||||
)
|
||||
torch.testing.assert_allclose(
|
||||
qmap1, bnb_optimizer.state[p2][qmap]
|
||||
)
|
||||
torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2])
|
||||
torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap])
|
||||
|
||||
if "blockwise" in optim_name:
|
||||
s1 = F.dequantize_blockwise(
|
||||
|
@ -389,17 +357,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
|
|||
absmax=bnb_optimizer.state[p2][max_val],
|
||||
A=bnb_optimizer.state[p2][name2],
|
||||
)
|
||||
torch.testing.assert_allclose(s1cpy, s1)
|
||||
torch.testing.assert_close(s1cpy, s1)
|
||||
|
||||
num_not_close = (
|
||||
torch.isclose(
|
||||
torch_optimizer.state[p1][name1],
|
||||
s1,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
== 0
|
||||
)
|
||||
num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0)
|
||||
assert num_not_close.sum().item() < 20
|
||||
# since Lion can have pretty noisy updates where things lie at the boundary
|
||||
# allow up to 5 errors for Lion
|
||||
|
@ -409,10 +369,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
|
|||
# together so we can test against the Adam error
|
||||
p1.data = p1.data.to(gtype).float()
|
||||
p2.copy_(p1.data)
|
||||
torch.testing.assert_allclose(p1.to(gtype), p2)
|
||||
for (name1, name2, qmap, max_val), s in zip(
|
||||
str2statenames[optim_name], dequant_states
|
||||
):
|
||||
torch.testing.assert_close(p1.to(gtype), p2)
|
||||
for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
|
||||
torch_optimizer.state[p1][name1].copy_(s.data)
|
||||
|
||||
# print(sum(errors)/len(errors))
|
||||
|
@ -473,28 +431,28 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
|
|||
|
||||
# gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
|
||||
if optim_bits == 32:
|
||||
torch.testing.assert_allclose(p1, p2)
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(p1, p2)
|
||||
torch.testing.assert_close(
|
||||
adam1.state[p1]["state1"],
|
||||
adam2.state[p2]["state1"],
|
||||
atol=5e-5,
|
||||
rtol=1e-4,
|
||||
)
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
adam1.state[p1]["state2"],
|
||||
adam2.state[p2]["state2"],
|
||||
atol=5e-5,
|
||||
rtol=1e-4,
|
||||
)
|
||||
elif optim_bits == 8:
|
||||
torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3)
|
||||
torch.testing.assert_close(
|
||||
adam1.state[p1]["state1"],
|
||||
adam2.state[p2]["state1"],
|
||||
atol=2,
|
||||
rtol=1e-3,
|
||||
)
|
||||
torch.testing.assert_allclose(
|
||||
torch.testing.assert_close(
|
||||
adam1.state[p1]["state2"],
|
||||
adam2.state[p2]["state2"],
|
||||
atol=2,
|
||||
|
@ -526,7 +484,7 @@ gtype = [torch.float32, torch.float16]
|
|||
# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
|
||||
# optimizer_names = ['lamb_apex', 'lamb8bit']
|
||||
# optimizer_names = ['lars_apex', 'lars8bit']
|
||||
optimizer_names = ["adam8bit_blockwise"]
|
||||
optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise']
|
||||
values = list(product(dim1, dim2, gtype, optimizer_names))
|
||||
names = [
|
||||
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
|
||||
|
@ -557,3 +515,47 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
|
|||
params = (k - k // 5) * dim1 * dim2
|
||||
print(optim_name, gtype, s / params)
|
||||
# assert s < 3.9
|
||||
|
||||
dim1 = [2*1024]
|
||||
gtype = [torch.float16]
|
||||
#mode = ['torch', 'bnb']
|
||||
mode = ['bnb']
|
||||
optimizer_names = ['paged_adamw']
|
||||
#optimizer_names = ['paged_adamw8bit_blockwise']
|
||||
values = list(product(dim1,gtype, optimizer_names, mode))
|
||||
names = ['dim1_{0}_gtype_{1}_optim_{2}_mode_{3}'.format(*vals) for vals in values]
|
||||
@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names)
|
||||
def test_stream_optimizer_bench(dim1, gtype, optim_name, mode):
|
||||
layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)]))
|
||||
layers1 = layers1.to(gtype)
|
||||
layers1 = layers1.cuda()
|
||||
|
||||
large_tensor = None
|
||||
if mode == 'torch':
|
||||
optim = str2optimizers[optim_name][0](layers1.parameters())
|
||||
else:
|
||||
optim = str2optimizers[optim_name][1](layers1.parameters())
|
||||
# ~18 GB of filler (4.5e9 float32 elements) to crowd out free VRAM
|
||||
large_tensor = torch.empty((int(4.5e9),), device='cuda')
|
||||
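# Hedged note: this filler allocation is meant to leave too little free VRAM for
# the optimizer state, so the paged optimizer has to keep its state in unified
# memory and page it in on demand during the benchmark loop.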
|
||||
torch.cuda.synchronize()
|
||||
time.sleep(5)
|
||||
|
||||
num_batches = 5
|
||||
batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype)
|
||||
lbls = torch.randint(0, 10, size=(num_batches,128)).cuda()
|
||||
|
||||
for i in range(num_batches):
|
||||
print(i)
|
||||
b = batches[i]
|
||||
if i ==2:
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
|
||||
out1 = layers1(b)
|
||||
|
||||
loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean()
|
||||
loss1.backward()
|
||||
optim.step()
|
||||
torch.cuda.synchronize()
|
||||
print(mode, time.time() - t0)
|
||||
|
|
59
tests/test_triton.py
Normal file
59
tests/test_triton.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from bitsandbytes.triton.triton_utils import is_triton_available
|
||||
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
|
||||
from bitsandbytes.nn import Linear8bitLt
|
||||
|
||||
@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
|
||||
reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
|
||||
@pytest.mark.parametrize("vector_wise_quantization", [False, True])
|
||||
def test_switchback(vector_wise_quantization):
|
||||
for dim in [83]:
|
||||
for batch in [13]:
|
||||
|
||||
standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
|
||||
switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half()
|
||||
baseline = Linear8bitLt(dim, 4 * dim).cuda().half()
|
||||
switchback.weight.data.copy_(standard.weight)
|
||||
switchback.bias.data.copy_(standard.bias)
|
||||
baseline.weight.data.copy_(standard.weight)
|
||||
baseline.bias.data.copy_(standard.bias)
|
||||
|
||||
x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True)
|
||||
x2 = x1.clone().detach().requires_grad_(True)
|
||||
x3 = x1.clone().detach().requires_grad_(True)
|
||||
|
||||
out_standard = standard(x1)
|
||||
(2**10 * out_standard.abs().mean()).backward()
|
||||
|
||||
print(x2.dtype)
|
||||
out_sb = switchback(x2)
|
||||
(2**10 * out_sb.abs().mean()).backward()
|
||||
|
||||
out_baseline = baseline(x3)
|
||||
(2**10 * out_baseline.abs().mean()).backward()
|
||||
|
||||
err_sb = (out_standard - out_sb).abs().mean()
|
||||
err_baseline = (out_standard - out_baseline).abs().mean()
|
||||
print('OUT', err_sb, err_baseline)
|
||||
assert err_sb < 2 * err_baseline
|
||||
|
||||
err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean()
|
||||
err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean()
|
||||
|
||||
print('GW2', err_sb, err_baseline)
|
||||
assert err_sb < 2 * err_baseline
|
||||
|
||||
err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean()
|
||||
err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean()
|
||||
|
||||
print('GW1', err_sb, err_baseline)
|
||||
assert err_sb < 2 * err_baseline
|
||||
|
||||
err_sb = (x1.grad - x2.grad).abs().mean()
|
||||
err_baseline = (x1.grad - x3.grad).abs().mean()
|
||||
|
||||
print('GX1', err_sb, err_baseline)
|
||||
assert err_sb < 2 * err_baseline
|
||||
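# Hedged reading of the criteria above: SwitchBackLinear passes whenever its
# output and gradient errors against the fp16 reference stay within 2x of the
# Linear8bitLt baseline errors, for both quantization modes.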
|