diff --git a/CHANGELOG.md b/CHANGELOG.md index 5399c02..7c75b24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -221,3 +221,29 @@ Improvements: Deprecated: - Devices with compute capability 3.0 (GTX 700s, K10) and 3.2 (Tegra K1, Jetson TK1) are now deprecated and support will be removed in 0.39.0. - Support for CUDA 10.0 and 10.2 will be removed in bitsandbytes 0.39.0 + + +### 0.38.1 + +Features: + - Added Int8 SwitchBack layers + - Added Fake FP8 layers for research purposes (available under `bnb.research.nn. ...`) + + +### 0.39.0 + + +Features: + - 4-bit matrix multiplication for Float4 and NormalFloat4 data types. + - Added 4-bit quantization routines + - Doubled quantization routines for 4-bit quantization + - Paged optimizers for Adam and Lion. + - bfloat16 gradient / weight support for Adam and Lion with 8 or 32-bit states. + +Bug fixes: + - Fixed a bug where 8-bit models consumed twice the memory as expected after serialization + +Deprecated: + - Kepler binaries (GTX 700s and Tesla K40/K80) are not longer provided via pip and need to be compiled from source. Kepler support might be fully removed in the future. + + diff --git a/Makefile b/Makefile index 7bee7ef..19b5b91 100644 --- a/Makefile +++ b/Makefile @@ -2,6 +2,7 @@ MKFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST))) ROOT_DIR := $(patsubst %/,%,$(dir $(MKFILE_PATH))) GPP:= /usr/bin/g++ +#GPP:= /sw/gcc/11.2.0/bin/g++ ifeq ($(CUDA_HOME),) CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev) endif @@ -12,6 +13,7 @@ CUDA_VERSION:= endif + NVCC := $(CUDA_HOME)/bin/nvcc ########################################### @@ -23,8 +25,7 @@ FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include -INCLUDE_10x := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include -LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib +LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcusparse -L $(CONDA_PREFIX)/lib # NVIDIA NVCC compilation flags COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell @@ -32,17 +33,11 @@ COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta -COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta CC_KEPLER := -gencode arch=compute_35,code=sm_35 # Kepler CC_KEPLER += -gencode arch=compute_37,code=sm_37 # Kepler # Later versions of CUDA support the new architectures -CC_CUDA10x += -gencode arch=compute_75,code=sm_75 - -CC_CUDA110 := -gencode arch=compute_75,code=sm_75 -CC_CUDA110 += -gencode arch=compute_80,code=sm_80 - CC_CUDA11x := -gencode arch=compute_75,code=sm_75 CC_CUDA11x += -gencode arch=compute_80,code=sm_80 CC_CUDA11x += -gencode arch=compute_86,code=sm_86 @@ -59,31 +54,32 @@ CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89 CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90 -all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) - $(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o +all: $(BUILD_DIR) env + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) -cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) - -cuda10x_nomatmul: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE_10x) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT - $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o - $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) - -cuda110_nomatmul: $(BUILD_DIR) env +cuda110_nomatmul_kepler: $(BUILD_DIR) env $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) -cuda11x_nomatmul: $(BUILD_DIR) env +cuda11x_nomatmul_kepler: $(BUILD_DIR) env $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) + +cuda110_nomatmul: $(BUILD_DIR) env + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) + +cuda11x_nomatmul: $(BUILD_DIR) env + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) + cuda12x_nomatmul: $(BUILD_DIR) env $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o diff --git a/benchmarking/switchback/README.md b/benchmarking/switchback/README.md new file mode 100644 index 0000000..bb33b5b --- /dev/null +++ b/benchmarking/switchback/README.md @@ -0,0 +1,4 @@ +Steps: + +1. Run `python speed_benchmark/speed_benchmark.py` which times operations and writes their time to `speed_benchmark/info_a100_py2.jsonl` (change the name of the jsonl to a different name for your profiling). +2. Run `python speed_benchmark/make_plot_with_jsonl.py`, which produces the `speed_benchmark/plot_with_info.pdf`. Again make sure you change the jsonl which is being processed. \ No newline at end of file diff --git a/benchmarking/switchback/info_a100_py2.jsonl b/benchmarking/switchback/info_a100_py2.jsonl new file mode 100644 index 0000000..53cda62 --- /dev/null +++ b/benchmarking/switchback/info_a100_py2.jsonl @@ -0,0 +1,60 @@ +{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.28139352798461914, "standard_gw": 0.2811811864376068, "standard_gx": 0.30258670449256897, "rowwise_fwd": 0.1994594931602478, "rowwise_bwd": 0.16159191727638245, "global_fwd": 0.19502267241477966, "global_bwd": 0.16080215573310852, "x_quantize_rowwise": 0.03306940197944641, "g_quantize_rowwise": 0.08210167288780212, "w_quantize_rowwise": 0.03385916352272034, "w_quantize_colwise_transpose": 0.08635595440864563, "w_quantize_global": 0.09237229824066162, "w_quantize_global_transpose": 0.10007619857788086, "time_standard": 0.8651614189147949, "time_rowwise": 0.8776187896728516, "time_global": 0.944625586271286} +{"repeat": 64, "batch_size": 8192, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.262625515460968, "standard_gw": 0.2806223928928375, "standard_gx": 0.31118839979171753, "rowwise_fwd": 0.1828707754611969, "rowwise_bwd": 0.21236762404441833, "global_fwd": 0.16665831208229065, "global_bwd": 0.19929558038711548, "x_quantize_rowwise": 0.08227676153182983, "g_quantize_rowwise": 0.03310292959213257, "w_quantize_rowwise": 0.032648444175720215, "w_quantize_colwise_transpose": 0.09015202522277832, "w_quantize_global": 0.0988692045211792, "w_quantize_global_transpose": 0.10057538747787476, "time_standard": 0.8544363081455231, "time_rowwise": 0.9140409529209137, "time_global": 0.96140056848526} +{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.5731917917728424, "standard_gw": 0.5709454417228699, "standard_gx": 0.5963630974292755, "rowwise_fwd": 0.37662312388420105, "rowwise_bwd": 0.281747430562973, "global_fwd": 0.36768242716789246, "global_bwd": 0.28043612837791443, "x_quantize_rowwise": 0.046547502279281616, "g_quantize_rowwise": 0.15532970428466797, "w_quantize_rowwise": 0.032436102628707886, "w_quantize_colwise_transpose": 0.08635222911834717, "w_quantize_global": 0.0947415828704834, "w_quantize_global_transpose": 0.10129809379577637, "time_standard": 1.7405003309249878, "time_rowwise": 1.5499815344810486, "time_global": 1.616980880498886} +{"repeat": 64, "batch_size": 16384, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.5341619253158569, "standard_gw": 0.5690865218639374, "standard_gx": 0.599835067987442, "rowwise_fwd": 0.3233291208744049, "rowwise_bwd": 0.41359663009643555, "global_fwd": 0.2831108868122101, "global_bwd": 0.37280842661857605, "x_quantize_rowwise": 0.15563145279884338, "g_quantize_rowwise": 0.046741217374801636, "w_quantize_rowwise": 0.03306940197944641, "w_quantize_colwise_transpose": 0.09020790457725525, "w_quantize_global": 0.0925213098526001, "w_quantize_global_transpose": 0.09945780038833618, "time_standard": 1.7030835151672363, "time_rowwise": 1.6316622495651245, "time_global": 1.6193576157093048} +{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 1.2199915945529938, "standard_gw": 1.1069811880588531, "standard_gx": 1.09761580824852, "rowwise_fwd": 0.738043338060379, "rowwise_bwd": 0.5549229681491852, "global_fwd": 0.7219798862934113, "global_bwd": 0.5512163043022156, "x_quantize_rowwise": 0.08748471736907959, "g_quantize_rowwise": 0.3023110330104828, "w_quantize_rowwise": 0.03182142972946167, "w_quantize_colwise_transpose": 0.08632615208625793, "w_quantize_global": 0.09445473551750183, "w_quantize_global_transpose": 0.10032951831817627, "time_standard": 3.424588590860367, "time_rowwise": 2.9078908264636993, "time_global": 2.9647573828697205} +{"repeat": 64, "batch_size": 32768, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 1.1040829122066498, "standard_gw": 1.1221766471862793, "standard_gx": 1.1548101902008057, "rowwise_fwd": 0.581938773393631, "rowwise_bwd": 0.7480122148990631, "global_fwd": 0.5537159740924835, "global_bwd": 0.7232688367366791, "x_quantize_rowwise": 0.30193477869033813, "g_quantize_rowwise": 0.08745118975639343, "w_quantize_rowwise": 0.03374740481376648, "w_quantize_colwise_transpose": 0.09068101644515991, "w_quantize_global": 0.09645149111747742, "w_quantize_global_transpose": 0.10189786553382874, "time_standard": 3.3810697495937347, "time_rowwise": 2.9659420251846313, "time_global": 2.9868967831134796} +{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 2.4533793330192566, "standard_gw": 2.1938569843769073, "standard_gx": 2.179361879825592, "rowwise_fwd": 1.4615543186664581, "rowwise_bwd": 1.0522231459617615, "global_fwd": 1.4288239181041718, "global_bwd": 1.0450035333633423, "x_quantize_rowwise": 0.1691766083240509, "g_quantize_rowwise": 0.5951300263404846, "w_quantize_rowwise": 0.03337860107421875, "w_quantize_colwise_transpose": 0.08653849363327026, "w_quantize_global": 0.0940859317779541, "w_quantize_global_transpose": 0.09976327419281006, "time_standard": 6.826598197221756, "time_rowwise": 5.5918581783771515, "time_global": 5.625840276479721} +{"repeat": 64, "batch_size": 65536, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 2.1698065102100372, "standard_gw": 2.1875128149986267, "standard_gx": 2.2887587547302246, "rowwise_fwd": 1.0762326419353485, "rowwise_bwd": 1.4638006687164307, "global_fwd": 1.0450668632984161, "global_bwd": 1.4308765530586243, "x_quantize_rowwise": 0.5953535437583923, "g_quantize_rowwise": 0.16899779438972473, "w_quantize_rowwise": 0.03240257501602173, "w_quantize_colwise_transpose": 0.09106099605560303, "w_quantize_global": 0.09546056389808655, "w_quantize_global_transpose": 0.09852275252342224, "time_standard": 6.6460780799388885, "time_rowwise": 5.615361034870148, "time_global": 5.621790885925293} +{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 4.858218133449554, "standard_gw": 4.3631307780742645, "standard_gx": 4.404045641422272, "rowwise_fwd": 2.9063820838928223, "rowwise_bwd": 2.094462513923645, "global_fwd": 2.8426870703697205, "global_bwd": 2.0792782306671143, "x_quantize_rowwise": 0.33241137862205505, "g_quantize_rowwise": 1.1817105114459991, "w_quantize_rowwise": 0.03374367952346802, "w_quantize_colwise_transpose": 0.08633732795715332, "w_quantize_global": 0.09231641888618469, "w_quantize_global_transpose": 0.100012868642807, "time_standard": 13.62539455294609, "time_rowwise": 10.998178273439407, "time_global": 10.991547256708145} +{"repeat": 64, "batch_size": 131072, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 4.246581345796585, "standard_gw": 4.42587211728096, "standard_gx": 4.581417888402939, "rowwise_fwd": 2.1114833652973175, "rowwise_bwd": 2.9050447046756744, "global_fwd": 2.0806826651096344, "global_bwd": 2.85966694355011, "x_quantize_rowwise": 1.1816024780273438, "g_quantize_rowwise": 0.33330172300338745, "w_quantize_rowwise": 0.033445656299591064, "w_quantize_colwise_transpose": 0.09065866470336914, "w_quantize_global": 0.09239837527275085, "w_quantize_global_transpose": 0.09984523057937622, "time_standard": 13.253871351480484, "time_rowwise": 11.081408709287643, "time_global": 11.073369532823563} +{"repeat": 64, "batch_size": 8192, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.4859529435634613, "standard_gw": 0.46338513493537903, "standard_gx": 0.42321905493736267, "rowwise_fwd": 0.2761557698249817, "rowwise_bwd": 0.20775198936462402, "global_fwd": 0.2713911235332489, "global_bwd": 0.20639970898628235, "x_quantize_rowwise": 0.033095479011535645, "g_quantize_rowwise": 0.11894106864929199, "w_quantize_rowwise": 0.03125518560409546, "w_quantize_colwise_transpose": 0.1424551010131836, "w_quantize_global": 0.07288157939910889, "w_quantize_global_transpose": 0.08071959018707275, "time_standard": 1.372557133436203, "time_rowwise": 1.2730397284030914, "time_global": 1.2468136847019196} +{"repeat": 64, "batch_size": 8192, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.3920421004295349, "standard_gw": 0.44424086809158325, "standard_gx": 0.4759356379508972, "rowwise_fwd": 0.23231282830238342, "rowwise_bwd": 0.28430670499801636, "global_fwd": 0.20883232355117798, "global_bwd": 0.2741999924182892, "x_quantize_rowwise": 0.12018159031867981, "g_quantize_rowwise": 0.03195926547050476, "w_quantize_rowwise": 0.026017427444458008, "w_quantize_colwise_transpose": 0.14733895659446716, "w_quantize_global": 0.07734447717666626, "w_quantize_global_transpose": 0.0788569450378418, "time_standard": 1.3122186064720154, "time_rowwise": 1.2863576412200928, "time_global": 1.235615462064743} +{"repeat": 64, "batch_size": 16384, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 1.0111741721630096, "standard_gw": 0.9267590939998627, "standard_gx": 0.8254274725914001, "rowwise_fwd": 0.5434826016426086, "rowwise_bwd": 0.4077926278114319, "global_fwd": 0.5318708717823029, "global_bwd": 0.40537863969802856, "x_quantize_rowwise": 0.059738755226135254, "g_quantize_rowwise": 0.2299174666404724, "w_quantize_rowwise": 0.02545863389968872, "w_quantize_colwise_transpose": 0.14269724488258362, "w_quantize_global": 0.07300823926925659, "w_quantize_global_transpose": 0.07878988981246948, "time_standard": 2.7633607387542725, "time_rowwise": 2.335846424102783, "time_global": 2.305462956428528} +{"repeat": 64, "batch_size": 16384, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.8095316588878632, "standard_gw": 0.8607134222984314, "standard_gx": 0.9204968810081482, "rowwise_fwd": 0.4275888204574585, "rowwise_bwd": 0.5485899746417999, "global_fwd": 0.41000545024871826, "global_bwd": 0.5317628383636475, "x_quantize_rowwise": 0.2301819622516632, "g_quantize_rowwise": 0.059254467487335205, "w_quantize_rowwise": 0.02466142177581787, "w_quantize_colwise_transpose": 0.14865398406982422, "w_quantize_global": 0.07582828402519226, "w_quantize_global_transpose": 0.08231401443481445, "time_standard": 2.5907419621944427, "time_rowwise": 2.2996440529823303, "time_global": 2.2500604391098022} +{"repeat": 64, "batch_size": 32768, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 2.0658522844314575, "standard_gw": 1.718364655971527, "standard_gx": 1.6660578548908234, "rowwise_fwd": 1.066897064447403, "rowwise_bwd": 0.8070804178714752, "global_fwd": 1.0473169386386871, "global_bwd": 0.8021742105484009, "x_quantize_rowwise": 0.11274218559265137, "g_quantize_rowwise": 0.4518181085586548, "w_quantize_rowwise": 0.026501715183258057, "w_quantize_colwise_transpose": 0.14259666204452515, "w_quantize_global": 0.07484853267669678, "w_quantize_global_transpose": 0.07976219058036804, "time_standard": 5.450274795293808, "time_rowwise": 4.326000809669495, "time_global": 4.287026822566986} +{"repeat": 64, "batch_size": 32768, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 2.7549192309379578, "standard_gw": 1.6954988241195679, "standard_gx": 1.8179528415203094, "rowwise_fwd": 0.8649080991744995, "rowwise_bwd": 1.0746456682682037, "global_fwd": 0.8023083209991455, "global_bwd": 1.0471977293491364, "x_quantize_rowwise": 0.45225024223327637, "g_quantize_rowwise": 0.11286512017250061, "w_quantize_rowwise": 0.0252649188041687, "w_quantize_colwise_transpose": 0.14732033014297485, "w_quantize_global": 0.07537379860877991, "w_quantize_global_transpose": 0.0807642936706543, "time_standard": 6.268370896577835, "time_rowwise": 4.372753202915192, "time_global": 4.266258329153061} +{"repeat": 64, "batch_size": 65536, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 4.098430275917053, "standard_gw": 3.3501461148262024, "standard_gx": 5.560480058193207, "rowwise_fwd": 2.112947404384613, "rowwise_bwd": 1.605246216058731, "global_fwd": 2.0697638392448425, "global_bwd": 1.5953518450260162, "x_quantize_rowwise": 0.21921470761299133, "g_quantize_rowwise": 0.8956789970397949, "w_quantize_rowwise": 0.02710893750190735, "w_quantize_colwise_transpose": 0.14268234372138977, "w_quantize_global": 0.07259473204612732, "w_quantize_global_transpose": 0.07899105548858643, "time_standard": 13.009056448936462, "time_rowwise": 8.35302472114563, "time_global": 8.281741291284561} +{"repeat": 64, "batch_size": 65536, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 5.586959421634674, "standard_gw": 3.358360379934311, "standard_gx": 3.6434978246688843, "rowwise_fwd": 1.6269534826278687, "rowwise_bwd": 2.128206193447113, "global_fwd": 1.5950687229633331, "global_bwd": 2.0831897854804993, "x_quantize_rowwise": 0.8954145014286041, "g_quantize_rowwise": 0.21914392709732056, "w_quantize_rowwise": 0.026203691959381104, "w_quantize_colwise_transpose": 0.14658644795417786, "w_quantize_global": 0.07478520274162292, "w_quantize_global_transpose": 0.07964670658111572, "time_standard": 12.58881762623787, "time_rowwise": 8.400868624448776, "time_global": 8.305609226226807} +{"repeat": 64, "batch_size": 131072, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 8.229725062847137, "standard_gw": 6.791356950998306, "standard_gx": 6.806455552577972, "rowwise_fwd": 4.252471029758453, "rowwise_bwd": 3.2062679529190063, "global_fwd": 4.175614565610886, "global_bwd": 3.1837262213230133, "x_quantize_rowwise": 0.4321373999118805, "g_quantize_rowwise": 1.787092536687851, "w_quantize_rowwise": 0.0270158052444458, "w_quantize_colwise_transpose": 0.1424252986907959, "w_quantize_global": 0.07348507642745972, "w_quantize_global_transpose": 0.07829815149307251, "time_standard": 21.827537566423416, "time_rowwise": 16.63876697421074, "time_global": 16.52171090245247} +{"repeat": 64, "batch_size": 131072, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 11.279478669166565, "standard_gw": 6.7345499992370605, "standard_gx": 7.206875830888748, "rowwise_fwd": 3.209315240383148, "rowwise_bwd": 4.256397485733032, "global_fwd": 3.180190920829773, "global_bwd": 4.177983850240707, "x_quantize_rowwise": 1.7836056649684906, "g_quantize_rowwise": 0.4321075975894928, "w_quantize_rowwise": 0.03205239772796631, "w_quantize_colwise_transpose": 0.14675036072731018, "w_quantize_global": 0.09316205978393555, "w_quantize_global_transpose": 0.10086596012115479, "time_standard": 25.220904499292374, "time_rowwise": 16.5947787463665, "time_global": 16.502466052770615} +{"repeat": 64, "batch_size": 8192, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 0.5776733160018921, "standard_gw": 0.5300231277942657, "standard_gx": 0.6005913019180298, "rowwise_fwd": 0.33330172300338745, "rowwise_bwd": 0.2957060933113098, "global_fwd": 0.32876431941986084, "global_bwd": 0.29108673334121704, "x_quantize_rowwise": 0.03466755151748657, "g_quantize_rowwise": 0.12264400720596313, "w_quantize_rowwise": 0.033874064683914185, "w_quantize_colwise_transpose": 0.1775398850440979, "w_quantize_global": 0.09503215551376343, "w_quantize_global_transpose": 0.10617449879646301, "time_standard": 1.7082877457141876, "time_rowwise": 1.5277564525604248, "time_global": 1.5083923935890198} +{"repeat": 64, "batch_size": 8192, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.5164109170436859, "standard_gw": 0.5367249250411987, "standard_gx": 0.5876161158084869, "rowwise_fwd": 0.3132447600364685, "rowwise_bwd": 0.3396235406398773, "global_fwd": 0.2943649888038635, "global_bwd": 0.33209100365638733, "x_quantize_rowwise": 0.12357160449028015, "g_quantize_rowwise": 0.035997480154037476, "w_quantize_rowwise": 0.03213062882423401, "w_quantize_colwise_transpose": 0.17676874995231628, "w_quantize_global": 0.09861215949058533, "w_quantize_global_transpose": 0.0998862087726593, "time_standard": 1.6407519578933716, "time_rowwise": 1.5580616891384125, "time_global": 1.5212483704090118} +{"repeat": 64, "batch_size": 16384, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 1.2096501886844635, "standard_gw": 1.0663382709026337, "standard_gx": 1.0961703956127167, "rowwise_fwd": 0.6396733224391937, "rowwise_bwd": 0.5173943936824799, "global_fwd": 0.6296299397945404, "global_bwd": 0.5130060017108917, "x_quantize_rowwise": 0.06211921572685242, "g_quantize_rowwise": 0.2361498773097992, "w_quantize_rowwise": 0.03260001540184021, "w_quantize_colwise_transpose": 0.17679482698440552, "w_quantize_global": 0.09361281991004944, "w_quantize_global_transpose": 0.09913742542266846, "time_standard": 3.372158855199814, "time_rowwise": 2.7310699224472046, "time_global": 2.6999935507774353} +{"repeat": 64, "batch_size": 16384, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 1.1065565049648285, "standard_gw": 1.0664314031600952, "standard_gx": 1.1266544461250305, "rowwise_fwd": 0.5352050065994263, "rowwise_bwd": 0.6464086472988129, "global_fwd": 0.513765960931778, "global_bwd": 0.6284862756729126, "x_quantize_rowwise": 0.23620948195457458, "g_quantize_rowwise": 0.062271952629089355, "w_quantize_rowwise": 0.031460076570510864, "w_quantize_colwise_transpose": 0.17675384879112244, "w_quantize_global": 0.09486451745033264, "w_quantize_global_transpose": 0.09898096323013306, "time_standard": 3.2996423542499542, "time_rowwise": 2.7547404170036316, "time_global": 2.7010105550289154} +{"repeat": 64, "batch_size": 32768, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 2.4367496371269226, "standard_gw": 2.0806193351745605, "standard_gx": 2.19624862074852, "rowwise_fwd": 1.2554042041301727, "rowwise_bwd": 1.0227933526039124, "global_fwd": 1.2322552502155304, "global_bwd": 1.0152235627174377, "x_quantize_rowwise": 0.11792033910751343, "g_quantize_rowwise": 0.4639364778995514, "w_quantize_rowwise": 0.03241002559661865, "w_quantize_colwise_transpose": 0.17657503485679626, "w_quantize_global": 0.09655207395553589, "w_quantize_global_transpose": 0.09958073496818542, "time_standard": 6.713617593050003, "time_rowwise": 5.149658769369125, "time_global": 5.106087774038315} +{"repeat": 64, "batch_size": 32768, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 2.1935217082500458, "standard_gw": 2.0055584609508514, "standard_gx": 2.1882541477680206, "rowwise_fwd": 1.0396353900432587, "rowwise_bwd": 1.2542344629764557, "global_fwd": 1.0161921381950378, "global_bwd": 1.233428716659546, "x_quantize_rowwise": 0.4642195999622345, "g_quantize_rowwise": 0.11782720685005188, "w_quantize_rowwise": 0.033117830753326416, "w_quantize_colwise_transpose": 0.17696991562843323, "w_quantize_global": 0.09416043758392334, "w_quantize_global_transpose": 0.10101497173309326, "time_standard": 6.387334316968918, "time_rowwise": 5.091562867164612, "time_global": 5.032401531934738} +{"repeat": 64, "batch_size": 65536, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 4.804681986570358, "standard_gw": 4.763372242450714, "standard_gx": 4.064023494720459, "rowwise_fwd": 2.484843134880066, "rowwise_bwd": 1.9691288471221924, "global_fwd": 2.441786229610443, "global_bwd": 1.9574686884880066, "x_quantize_rowwise": 0.2294592559337616, "g_quantize_rowwise": 0.9196549654006958, "w_quantize_rowwise": 0.0313781201839447, "w_quantize_colwise_transpose": 0.1768544316291809, "w_quantize_global": 0.09644776582717896, "w_quantize_global_transpose": 0.09847059845924377, "time_standard": 13.632077723741531, "time_rowwise": 10.574690997600555, "time_global": 10.506659746170044} +{"repeat": 64, "batch_size": 65536, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 4.0907710790634155, "standard_gw": 3.9793066680431366, "standard_gx": 4.302978515625, "rowwise_fwd": 1.992940902709961, "rowwise_bwd": 2.4996213614940643, "global_fwd": 1.9551962614059448, "global_bwd": 2.457551658153534, "x_quantize_rowwise": 0.9200014173984528, "g_quantize_rowwise": 0.2293996512889862, "w_quantize_rowwise": 0.0313781201839447, "w_quantize_colwise_transpose": 0.17882883548736572, "w_quantize_global": 0.09540095925331116, "w_quantize_global_transpose": 0.09880587458610535, "time_standard": 12.373056262731552, "time_rowwise": 9.831476956605911, "time_global": 9.73566249012947} +{"repeat": 64, "batch_size": 131072, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 9.655728936195374, "standard_gw": 8.261296898126602, "standard_gx": 8.064884692430496, "rowwise_fwd": 5.007706582546234, "rowwise_bwd": 3.8615092635154724, "global_fwd": 4.920527338981628, "global_bwd": 3.8330331444740295, "x_quantize_rowwise": 0.45276060700416565, "g_quantize_rowwise": 1.8306002020835876, "w_quantize_rowwise": 0.031366944313049316, "w_quantize_colwise_transpose": 0.1766495406627655, "w_quantize_global": 0.09412690997123718, "w_quantize_global_transpose": 0.09780004620552063, "time_standard": 25.981910526752472, "time_rowwise": 19.621890038251877, "time_global": 19.49014514684677} +{"repeat": 64, "batch_size": 131072, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 8.033104240894318, "standard_gw": 8.2889124751091, "standard_gx": 8.622754365205765, "rowwise_fwd": 3.8747042417526245, "rowwise_bwd": 5.003921687602997, "global_fwd": 3.8315393030643463, "global_bwd": 4.9162134528160095, "x_quantize_rowwise": 1.8304847180843353, "g_quantize_rowwise": 0.4522763192653656, "w_quantize_rowwise": 0.03413110971450806, "w_quantize_colwise_transpose": 0.1771189272403717, "w_quantize_global": 0.09519979357719421, "w_quantize_global_transpose": 0.09930506348609924, "time_standard": 24.944771081209183, "time_rowwise": 19.661549478769302, "time_global": 19.51393112540245} +{"repeat": 64, "batch_size": 8192, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.7954612374305725, "standard_gw": 0.7456131279468536, "standard_gx": 0.8799619972705841, "rowwise_fwd": 0.43267011642456055, "rowwise_bwd": 0.34622475504875183, "global_fwd": 0.42615458369255066, "global_bwd": 0.344250351190567, "x_quantize_rowwise": 0.03748014569282532, "g_quantize_rowwise": 0.13304129242897034, "w_quantize_rowwise": 0.03294646739959717, "w_quantize_colwise_transpose": 0.2407953143119812, "w_quantize_global": 0.094633549451828, "w_quantize_global_transpose": 0.10305643081665039, "time_standard": 2.4210363626480103, "time_rowwise": 1.96877121925354, "time_global": 1.8842294812202454} +{"repeat": 64, "batch_size": 8192, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.7120333611965179, "standard_gw": 0.7622130215167999, "standard_gx": 0.8262209594249725, "rowwise_fwd": 0.3702230751514435, "rowwise_bwd": 0.4419572651386261, "global_fwd": 0.3479123115539551, "global_bwd": 0.4306286573410034, "x_quantize_rowwise": 0.13308599591255188, "g_quantize_rowwise": 0.037495046854019165, "w_quantize_rowwise": 0.03398209810256958, "w_quantize_colwise_transpose": 0.23782625794410706, "w_quantize_global": 0.09853765368461609, "w_quantize_global_transpose": 0.10247156023979187, "time_standard": 2.3004673421382904, "time_rowwise": 2.016782760620117, "time_global": 1.9123442471027374} +{"repeat": 64, "batch_size": 16384, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 1.6292817890644073, "standard_gw": 1.5109702944755554, "standard_gx": 1.482747495174408, "rowwise_fwd": 0.8386112749576569, "rowwise_bwd": 0.6844550371170044, "global_fwd": 0.8220970630645752, "global_bwd": 0.6802082061767578, "x_quantize_rowwise": 0.06883963942527771, "g_quantize_rowwise": 0.25641173124313354, "w_quantize_rowwise": 0.033054500818252563, "w_quantize_colwise_transpose": 0.24027004837989807, "w_quantize_global": 0.0967271625995636, "w_quantize_global_transpose": 0.102948397397995, "time_standard": 4.622999578714371, "time_rowwise": 3.6326125264167786, "time_global": 3.5382024943828583} +{"repeat": 64, "batch_size": 16384, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 1.4877021312713623, "standard_gw": 1.5015341341495514, "standard_gx": 1.529306173324585, "rowwise_fwd": 0.715944916009903, "rowwise_bwd": 0.8529908955097198, "global_fwd": 0.680088996887207, "global_bwd": 0.8224695920944214, "x_quantize_rowwise": 0.2568177878856659, "g_quantize_rowwise": 0.06864592432975769, "w_quantize_rowwise": 0.03343448042869568, "w_quantize_colwise_transpose": 0.23645907640457153, "w_quantize_global": 0.09399279952049255, "w_quantize_global_transpose": 0.10286271572113037, "time_standard": 4.518542438745499, "time_rowwise": 3.665827214717865, "time_global": 3.5264119505882263} +{"repeat": 64, "batch_size": 32768, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 3.261040896177292, "standard_gw": 2.8816498816013336, "standard_gx": 2.8357282280921936, "rowwise_fwd": 1.6594752669334412, "rowwise_bwd": 1.359265297651291, "global_fwd": 1.6287527978420258, "global_bwd": 1.3503879308700562, "x_quantize_rowwise": 0.13146549463272095, "g_quantize_rowwise": 0.5035959184169769, "w_quantize_rowwise": 0.03438442945480347, "w_quantize_colwise_transpose": 0.24086236953735352, "w_quantize_global": 0.0945068895816803, "w_quantize_global_transpose": 0.10332837700843811, "time_standard": 8.978419005870819, "time_rowwise": 6.8106986582279205, "time_global": 6.693687289953232} +{"repeat": 64, "batch_size": 32768, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 2.848360687494278, "standard_gw": 2.8955675661563873, "standard_gx": 3.0499882996082306, "rowwise_fwd": 1.3900883495807648, "rowwise_bwd": 1.6595833003520966, "global_fwd": 1.3514049351215363, "global_bwd": 1.629263162612915, "x_quantize_rowwise": 0.5036592483520508, "g_quantize_rowwise": 0.13118237257003784, "w_quantize_rowwise": 0.03438442945480347, "w_quantize_colwise_transpose": 0.23709610104560852, "w_quantize_global": 0.0951625406742096, "w_quantize_global_transpose": 0.10216236114501953, "time_standard": 8.793916553258896, "time_rowwise": 6.851561367511749, "time_global": 6.708402186632156} +{"repeat": 64, "batch_size": 65536, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 6.4978525042533875, "standard_gw": 6.462603807449341, "standard_gx": 5.5987648665905, "rowwise_fwd": 3.2996535301208496, "rowwise_bwd": 2.6320070028305054, "global_fwd": 3.2426007091999054, "global_bwd": 2.612769603729248, "x_quantize_rowwise": 0.2561397850513458, "g_quantize_rowwise": 0.9984448552131653, "w_quantize_rowwise": 0.033076852560043335, "w_quantize_colwise_transpose": 0.24232640862464905, "w_quantize_global": 0.09618699550628662, "w_quantize_global_transpose": 0.10257214307785034, "time_standard": 18.559221178293228, "time_rowwise": 13.9242522418499, "time_global": 13.771317899227142} +{"repeat": 64, "batch_size": 65536, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 5.5702440440654755, "standard_gw": 5.717620253562927, "standard_gx": 6.08203187584877, "rowwise_fwd": 2.649586647748947, "rowwise_bwd": 3.315173089504242, "global_fwd": 2.6132799685001373, "global_bwd": 3.257807344198227, "x_quantize_rowwise": 0.9980201721191406, "g_quantize_rowwise": 0.256560742855072, "w_quantize_rowwise": 0.03356859087944031, "w_quantize_colwise_transpose": 0.23729726672172546, "w_quantize_global": 0.09495764970779419, "w_quantize_global_transpose": 0.103779137134552, "time_standard": 17.369896173477173, "time_rowwise": 13.207826763391495, "time_global": 13.04202526807785} +{"repeat": 64, "batch_size": 131072, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 13.058379292488098, "standard_gw": 11.480242013931274, "standard_gx": 11.092845350503922, "rowwise_fwd": 6.637874990701675, "rowwise_bwd": 5.24790957570076, "global_fwd": 6.521012634038925, "global_bwd": 5.214303731918335, "x_quantize_rowwise": 0.5057565867900848, "g_quantize_rowwise": 1.989319920539856, "w_quantize_rowwise": 0.03439188003540039, "w_quantize_colwise_transpose": 0.24280324578285217, "w_quantize_global": 0.09520724415779114, "w_quantize_global_transpose": 0.10240450501441956, "time_standard": 35.631466656923294, "time_rowwise": 26.138298213481903, "time_global": 25.908246636390686} +{"repeat": 64, "batch_size": 131072, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 11.13397628068924, "standard_gw": 11.371888220310211, "standard_gx": 12.12756335735321, "rowwise_fwd": 5.2495077252388, "rowwise_bwd": 6.638709455728531, "global_fwd": 5.215313285589218, "global_bwd": 6.5222084522247314, "x_quantize_rowwise": 1.9870512187480927, "g_quantize_rowwise": 0.5058236420154572, "w_quantize_rowwise": 0.034634023904800415, "w_quantize_colwise_transpose": 0.23674964904785156, "w_quantize_global": 0.09457767009735107, "w_quantize_global_transpose": 0.10183081030845642, "time_standard": 34.63342785835266, "time_rowwise": 26.024363934993744, "time_global": 25.798693299293518} +{"repeat": 64, "batch_size": 8192, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 1.2125298380851746, "standard_gw": 1.1111274361610413, "standard_gx": 1.0840706527233124, "rowwise_fwd": 0.6057210266590118, "rowwise_bwd": 0.51865354180336, "global_fwd": 0.5952082574367523, "global_bwd": 0.5167685449123383, "x_quantize_rowwise": 0.045686960220336914, "g_quantize_rowwise": 0.15827640891075134, "w_quantize_rowwise": 0.04361197352409363, "w_quantize_colwise_transpose": 0.34067779779434204, "w_quantize_global": 0.13644620776176453, "w_quantize_global_transpose": 0.14925003051757812, "time_standard": 3.407727926969528, "time_rowwise": 2.823755145072937, "time_global": 2.7127638459205627} +{"repeat": 64, "batch_size": 8192, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 1.0731369256973267, "standard_gw": 1.1365897953510284, "standard_gx": 1.1498592793941498, "rowwise_fwd": 0.5573518574237823, "rowwise_bwd": 0.615488737821579, "global_fwd": 0.5220361053943634, "global_bwd": 0.5939789116382599, "x_quantize_rowwise": 0.15765801072120667, "g_quantize_rowwise": 0.04369020462036133, "w_quantize_rowwise": 0.047359615564346313, "w_quantize_colwise_transpose": 0.5526281893253326, "w_quantize_global": 0.13606995344161987, "w_quantize_global_transpose": 0.15017390251159668, "time_standard": 3.359586000442505, "time_rowwise": 3.1107664108276367, "time_global": 2.7401968836784363} +{"repeat": 64, "batch_size": 16384, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 2.4274885654449463, "standard_gw": 2.1799951791763306, "standard_gx": 2.1426528692245483, "rowwise_fwd": 1.195710152387619, "rowwise_bwd": 1.027170568704605, "global_fwd": 1.1747106909751892, "global_bwd": 1.0251589119434357, "x_quantize_rowwise": 0.08098781108856201, "g_quantize_rowwise": 0.3052949905395508, "w_quantize_rowwise": 0.043764710426330566, "w_quantize_colwise_transpose": 0.33987686038017273, "w_quantize_global": 0.13646483421325684, "w_quantize_global_transpose": 0.14739856123924255, "time_standard": 6.750136613845825, "time_rowwise": 5.172800272703171, "time_global": 5.050010979175568} +{"repeat": 64, "batch_size": 16384, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 2.1661892533302307, "standard_gw": 2.0948275923728943, "standard_gx": 2.306375652551651, "rowwise_fwd": 1.0587647557258606, "rowwise_bwd": 1.1999905109405518, "global_fwd": 1.0296404361724854, "global_bwd": 1.1749230325222015, "x_quantize_rowwise": 0.3054030239582062, "g_quantize_rowwise": 0.08077546954154968, "w_quantize_rowwise": 0.047225505113601685, "w_quantize_colwise_transpose": 0.600133091211319, "w_quantize_global": 0.13613328337669373, "w_quantize_global_transpose": 0.1484006643295288, "time_standard": 6.567392498254776, "time_rowwise": 5.387119948863983, "time_global": 4.97010350227356} +{"repeat": 64, "batch_size": 32768, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 4.807606339454651, "standard_gw": 4.170913249254227, "standard_gx": 4.117622971534729, "rowwise_fwd": 2.370934933423996, "rowwise_bwd": 1.9481778144836426, "global_fwd": 2.3383721709251404, "global_bwd": 1.9443817436695099, "x_quantize_rowwise": 0.1547597348690033, "g_quantize_rowwise": 0.6000511348247528, "w_quantize_rowwise": 0.04361942410469055, "w_quantize_colwise_transpose": 0.3403201699256897, "w_quantize_global": 0.13600289821624756, "w_quantize_global_transpose": 0.1474134624004364, "time_standard": 13.096142560243607, "time_rowwise": 9.628776460886002, "time_global": 9.491894394159317} +{"repeat": 64, "batch_size": 32768, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 4.1619837284088135, "standard_gw": 4.181284457445145, "standard_gx": 4.635505378246307, "rowwise_fwd": 1.9684135913848877, "rowwise_bwd": 2.3750364780426025, "global_fwd": 1.9445866346359253, "global_bwd": 2.3551955819129944, "x_quantize_rowwise": 0.6004162132740021, "g_quantize_rowwise": 0.15468522906303406, "w_quantize_rowwise": 0.04730746150016785, "w_quantize_colwise_transpose": 0.5999617278575897, "w_quantize_global": 0.1364201307296753, "w_quantize_global_transpose": 0.14847144484519958, "time_standard": 12.978773564100266, "time_rowwise": 9.927105158567429, "time_global": 9.521059691905975} +{"repeat": 64, "batch_size": 65536, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 9.52371209859848, "standard_gw": 8.354485034942627, "standard_gx": 8.69860127568245, "rowwise_fwd": 4.717472940683365, "rowwise_bwd": 3.8843750953674316, "global_fwd": 4.645414650440216, "global_bwd": 3.8761012256145477, "x_quantize_rowwise": 0.3024861216545105, "g_quantize_rowwise": 1.1897757649421692, "w_quantize_rowwise": 0.04366785287857056, "w_quantize_colwise_transpose": 0.33988431096076965, "w_quantize_global": 0.1359507441520691, "w_quantize_global_transpose": 0.14724582433700562, "time_standard": 26.576798409223557, "time_rowwise": 18.832147121429443, "time_global": 18.651459366083145} +{"repeat": 64, "batch_size": 65536, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 8.307881653308868, "standard_gw": 8.214320987462997, "standard_gx": 9.21182706952095, "rowwise_fwd": 3.8919784128665924, "rowwise_bwd": 4.72346693277359, "global_fwd": 3.8761794567108154, "global_bwd": 4.673641175031662, "x_quantize_rowwise": 1.1893920600414276, "g_quantize_rowwise": 0.3024972975254059, "w_quantize_rowwise": 0.04708021879196167, "w_quantize_colwise_transpose": 0.6039328873157501, "w_quantize_global": 0.13624504208564758, "w_quantize_global_transpose": 0.14867261052131653, "time_standard": 25.734029710292816, "time_rowwise": 18.972668796777725, "time_global": 18.540948629379272} +{"repeat": 64, "batch_size": 131072, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 19.30372044444084, "standard_gw": 16.480475664138794, "standard_gx": 17.61433482170105, "rowwise_fwd": 9.49602946639061, "rowwise_bwd": 7.768530398607254, "global_fwd": 9.3533955514431, "global_bwd": 7.749464362859726, "x_quantize_rowwise": 0.5977451801300049, "g_quantize_rowwise": 2.3684948682785034, "w_quantize_rowwise": 0.04375725984573364, "w_quantize_colwise_transpose": 0.34042075276374817, "w_quantize_global": 0.13628974556922913, "w_quantize_global_transpose": 0.14671683311462402, "time_standard": 53.398530930280685, "time_rowwise": 37.09545359015465, "time_global": 36.83258220553398} +{"repeat": 64, "batch_size": 131072, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 18.041003495454788, "standard_gw": 17.770148813724518, "standard_gx": 17.70009845495224, "rowwise_fwd": 7.756810635328293, "rowwise_bwd": 9.502101689577103, "global_fwd": 7.7384114265441895, "global_bwd": 9.36170294880867, "x_quantize_rowwise": 2.3686252534389496, "g_quantize_rowwise": 0.5980581045150757, "w_quantize_rowwise": 0.04723668098449707, "w_quantize_colwise_transpose": 0.6035342812538147, "w_quantize_global": 0.13603642582893372, "w_quantize_global_transpose": 0.1485198736190796, "time_standard": 53.511250764131546, "time_rowwise": 38.64651545882225, "time_global": 38.121502846479416} +{"repeat": 64, "batch_size": 8192, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 4.598241299390793, "standard_gw": 4.294309765100479, "standard_gx": 4.261095076799393, "rowwise_fwd": 2.0976848900318146, "rowwise_bwd": 1.9718967378139496, "global_fwd": 2.0763762295246124, "global_bwd": 1.9703581929206848, "x_quantize_rowwise": 0.08216872811317444, "g_quantize_rowwise": 0.4405900835990906, "w_quantize_rowwise": 0.1553371548652649, "w_quantize_colwise_transpose": 1.6110725700855255, "w_quantize_global": 0.481240451335907, "w_quantize_global_transpose": 0.5061514675617218, "time_standard": 13.153646141290665, "time_rowwise": 10.653059929609299, "time_global": 9.85119491815567} +{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 4.35885414481163, "standard_gw": 4.29583340883255, "standard_gx": 4.5370906591415405, "rowwise_fwd": 2.0015686750411987, "rowwise_bwd": 2.097565680742264, "global_fwd": 1.969795674085617, "global_bwd": 2.075403928756714, "x_quantize_rowwise": 0.43984130024909973, "g_quantize_rowwise": 0.08216127753257751, "w_quantize_rowwise": 0.22544339299201965, "w_quantize_colwise_transpose": 2.4342015385627747, "w_quantize_global": 0.48087164759635925, "w_quantize_global_transpose": 0.5099289119243622, "time_standard": 13.19177821278572, "time_rowwise": 11.576615273952484, "time_global": 9.85383614897728} +{"repeat": 64, "batch_size": 16384, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 9.09888744354248, "standard_gw": 8.230950683355331, "standard_gx": 8.465446531772614, "rowwise_fwd": 4.182614386081696, "rowwise_bwd": 3.747660666704178, "global_fwd": 4.138719290494919, "global_bwd": 3.74777615070343, "x_quantize_rowwise": 0.15515834093093872, "g_quantize_rowwise": 0.8699297904968262, "w_quantize_rowwise": 0.15544891357421875, "w_quantize_colwise_transpose": 1.6132444143295288, "w_quantize_global": 0.48100948333740234, "w_quantize_global_transpose": 0.5051903426647186, "time_standard": 25.795284658670425, "time_rowwise": 18.955007195472717, "time_global": 18.128734081983566} +{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 8.378107100725174, "standard_gw": 8.923027664422989, "standard_gx": 9.049762040376663, "rowwise_fwd": 3.765825182199478, "rowwise_bwd": 4.183519631624222, "global_fwd": 3.744799643754959, "global_bwd": 4.1590481996536255, "x_quantize_rowwise": 0.8693933486938477, "g_quantize_rowwise": 0.1553073525428772, "w_quantize_rowwise": 0.2258792519569397, "w_quantize_colwise_transpose": 2.4386271834373474, "w_quantize_global": 0.4811100661754608, "w_quantize_global_transpose": 0.5102269351482391, "time_standard": 26.350896805524826, "time_rowwise": 20.5615796148777, "time_global": 18.842913210392} +{"repeat": 64, "batch_size": 32768, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 18.266115337610245, "standard_gw": 17.671160399913788, "standard_gx": 17.10302010178566, "rowwise_fwd": 8.347474038600922, "rowwise_bwd": 7.514089345932007, "global_fwd": 8.263226598501205, "global_bwd": 7.487393915653229, "x_quantize_rowwise": 0.3021806478500366, "g_quantize_rowwise": 1.7319358885288239, "w_quantize_rowwise": 0.15519559383392334, "w_quantize_colwise_transpose": 1.6133114695549011, "w_quantize_global": 0.48247724771499634, "w_quantize_global_transpose": 0.506427139043808, "time_standard": 53.04029583930969, "time_rowwise": 37.3353473842144, "time_global": 36.44480183720589} +{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 17.73649826645851, "standard_gw": 16.359902918338776, "standard_gx": 18.0993489921093, "rowwise_fwd": 7.493957877159119, "rowwise_bwd": 8.352488279342651, "global_fwd": 7.486194372177124, "global_bwd": 8.28903540968895, "x_quantize_rowwise": 1.7313472926616669, "g_quantize_rowwise": 0.30205026268959045, "w_quantize_rowwise": 0.2255477011203766, "w_quantize_colwise_transpose": 2.4363920092582703, "w_quantize_global": 0.4815347492694855, "w_quantize_global_transpose": 0.5103759467601776, "time_standard": 52.195750176906586, "time_rowwise": 36.90168634057045, "time_global": 35.16044095158577} +{"repeat": 64, "batch_size": 65536, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 36.309611052274704, "standard_gw": 32.85098075866699, "standard_gx": 34.34552624821663, "rowwise_fwd": 16.74525812268257, "rowwise_bwd": 15.026237815618515, "global_fwd": 16.574162989854813, "global_bwd": 14.977734535932541, "x_quantize_rowwise": 0.5954466760158539, "g_quantize_rowwise": 3.4569576382637024, "w_quantize_rowwise": 0.15521422028541565, "w_quantize_colwise_transpose": 1.6133897006511688, "w_quantize_global": 0.4822872579097748, "w_quantize_global_transpose": 0.5065612494945526, "time_standard": 103.50611805915833, "time_rowwise": 70.44348493218422, "time_global": 69.44413110613823} +{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 35.40017828345299, "standard_gw": 33.037226647138596, "standard_gx": 36.30436211824417, "rowwise_fwd": 15.043705701828003, "rowwise_bwd": 16.756191849708557, "global_fwd": 15.011314302682877, "global_bwd": 16.580048948526382, "x_quantize_rowwise": 3.4548528492450714, "g_quantize_rowwise": 0.5951337516307831, "w_quantize_rowwise": 0.22584572434425354, "w_quantize_colwise_transpose": 2.4329908192157745, "w_quantize_global": 0.4813261330127716, "w_quantize_global_transpose": 0.5101598799228668, "time_standard": 104.74176704883575, "time_rowwise": 71.54594734311104, "time_global": 69.67006251215935} +{"repeat": 64, "batch_size": 131072, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 73.40333238244057, "standard_gw": 73.76311346888542, "standard_gx": 70.41774317622185, "rowwise_fwd": 33.37597846984863, "rowwise_bwd": 30.345775187015533, "global_fwd": 33.00366923213005, "global_bwd": 30.218638479709625, "x_quantize_rowwise": 1.1825822293758392, "g_quantize_rowwise": 6.902601569890976, "w_quantize_rowwise": 0.15529245138168335, "w_quantize_colwise_transpose": 1.6109198331832886, "w_quantize_global": 0.48149004578590393, "w_quantize_global_transpose": 0.5066059529781342, "time_standard": 217.58418902754784, "time_rowwise": 147.33626320958138, "time_global": 146.05870097875595} +{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 71.5160183608532, "standard_gw": 73.76786693930626, "standard_gx": 72.98104092478752, "rowwise_fwd": 30.291248112916946, "rowwise_bwd": 33.36654230952263, "global_fwd": 30.181586742401123, "global_bwd": 33.082425594329834, "x_quantize_rowwise": 6.902430206537247, "g_quantize_rowwise": 1.1815279722213745, "w_quantize_rowwise": 0.2262219786643982, "w_quantize_colwise_transpose": 2.4421699345111847, "w_quantize_global": 0.4816502332687378, "w_quantize_global_transpose": 0.5105249583721161, "time_standard": 218.26492622494698, "time_rowwise": 148.17800745368004, "time_global": 146.1080126464367} diff --git a/benchmarking/switchback/make_plot_with_jsonl.py b/benchmarking/switchback/make_plot_with_jsonl.py new file mode 100644 index 0000000..8897564 --- /dev/null +++ b/benchmarking/switchback/make_plot_with_jsonl.py @@ -0,0 +1,138 @@ +import matplotlib.pyplot as plt +import pandas as pd +import numpy as np +import os + +import matplotlib.gridspec as gridspec + +cmap=plt.get_cmap('cool') + +if __name__ == '__main__': + + fig = plt.figure(tight_layout=True, figsize=(12,3.5)) + gs = gridspec.GridSpec(1, 2) + + dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096] + batch_size_for_plot1 = 32768 + batch_sizes_for_plot2 = [2**14, 2**15, 2**16, 2**17] + dims_to_xtick = [1024, 2048, 4096] + logscale_plot1 = True + + ax = fig.add_subplot(gs[0, 0]) + + # TODO: change this to what you want. + rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True) + df = rdf[rdf.batch_size == batch_size_for_plot1] + + # first plot the time occupied by different operations + for k, marker, ls, color, name in [ + ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'), + ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'), + + ('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'), + ('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'), + ('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'), + + ('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'), + ('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'), + + ('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'), + ('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'), + ('w_quantize_global', '.', '--', 'C4', 'Quatnize global W (switchback)'), + ('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize gloabl and\ntranspose W (switchback)'), + ]: + xs = [] + ys = [] + for embed_dim in dims_to_consider: + # average over dim -> 4*dim and 4*dim -> dim + df_ = df[df.dim_in == embed_dim] + df_ = df_[df_.dim_out == embed_dim * 4] + xs.append(embed_dim) + y_ = 0 + for k_ in k.split('+'): + y_ += df_[k_].values[0] + df_ = df[df.dim_in == embed_dim * 4] + df_ = df_[df_.dim_out == embed_dim] + for k_ in k.split('+'): + y_ += df_[k_].values[0] + ys.append(y_ * 0.5) + + + ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.) + + + ax.set_xlabel('dim', fontsize=13) + ax.set_ylabel('time (ms)', fontsize=13) + + ax.grid() + + ax.set_xscale('log') + if logscale_plot1: + ax.set_yscale('log') + + ax.tick_params(axis='x', labelsize=11) + ax.tick_params(axis='y', labelsize=11) + + ax.set_xticks(dims_to_xtick) + ax.set_xticklabels(dims_to_xtick) + ax.set_xticks([], minor=True) + + leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10) + leg.get_texts()[0].set_fontweight('bold') + leg.get_texts()[1].set_fontweight('bold') + plt.subplots_adjust(left=0.1) + ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20) + + + ax = fig.add_subplot(gs[0, 1]) + + # now plot the % speedup for different batch sizes + for j, batch_size in enumerate(batch_sizes_for_plot2): + all_xs, all_ys = [], [] + for k, marker, ls, color, name in [ + ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'), + ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'), + ]: + + xs, ys = [], [] + df = rdf[rdf.batch_size == batch_size] + for embed_dim in dims_to_consider: + df_ = df[df.dim_in == embed_dim] + df_ = df_[df_.dim_out == embed_dim * 4] + xs.append(embed_dim) + y_ = 0 + for k_ in k.split('+'): + y_ += df_[k_].values[0] + df_ = df[df.dim_in == embed_dim * 4] + df_ = df_[df_.dim_out == embed_dim] + for k_ in k.split('+'): + y_ += df_[k_].values[0] + ys.append(y_ * 0.5) + all_xs.append(xs) + all_ys.append(ys) + + color = cmap(j * 0.25) + real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))] + markers = ['^', 'v', 'P', 'o'] + ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5) + + ax.legend() + ax.set_xlabel('dim', fontsize=13) + ax.set_xscale('log') + ax.grid() + ax.set_ylabel(r'% speedup', fontsize=13) + + + ax.tick_params(axis='x', labelsize=11) + ax.tick_params(axis='y', labelsize=11) + + ax.set_xticks(dims_to_xtick) + ax.set_xticklabels(dims_to_xtick) + ax.set_xticks([], minor=True) + + ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20) + + + + plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight') + diff --git a/benchmarking/switchback/plot_with_info.pdf b/benchmarking/switchback/plot_with_info.pdf new file mode 100644 index 0000000..d186e91 Binary files /dev/null and b/benchmarking/switchback/plot_with_info.pdf differ diff --git a/benchmarking/switchback/speed_benchmark.py b/benchmarking/switchback/speed_benchmark.py new file mode 100644 index 0000000..9ad9911 --- /dev/null +++ b/benchmarking/switchback/speed_benchmark.py @@ -0,0 +1,102 @@ +import json + +import time +import torch +import torch.nn as nn + +from bitsandbytes.triton.quantize_rowwise import quantize_rowwise +from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose +from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize +from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose +from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze + +# KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large. + +def get_time(k, fn, info_dict): + + for _ in range(repeat // 2): + fn() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + fn() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info_dict[k] = ms + +if __name__ == '__main__': + torch.manual_seed(0) + wm = 4 + for dim in [1024, 1280, 1408, 1664, 2048, 4096]: + # note "batch_size" is actually "batch_size * embed_dim", which is why it's large + for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]: + + # switch switches dim_in and dim_out + for switch in [False, True]: + + # hparams + repeat = 64 + batch_size = batch_size + dim_out = dim * wm + dim_in = dim + if switch: + dim_out = dim + dim_in = wm * dim + + dim_in = round(dim_in) + dim_out = round(dim_out) + + # simulate forward pass + x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda() + g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda() + w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda() + + x_int8 = x.clone().to(torch.int8) + g_int8 = g.clone().to(torch.int8) + w_int8 = w.clone().to(torch.int8) + wt_int8 = w.t().contiguous().clone().to(torch.int8) + state_x_rowwise = x.max(dim=1)[0] + state_g_rowwise = g.max(dim=1)[0] + state_w_columnwise = w.max(dim=0)[0] + state_w_rowwise = w.max(dim=1)[0] + state_w_global = w.max() + + info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch} + + get_time('standard_fwd', lambda : x.matmul(w.t()), info) + get_time('standard_gw', lambda : g.t().matmul(x), info) + get_time('standard_gx', lambda : g.matmul(w), info) + get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info) + get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info) + get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info) + get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info) + get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info) + get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info) + get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info) + get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info) + get_time('w_quantize_global', lambda : quantize_global(w), info) + get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info) + + time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw'] + time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd'] + time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd'] + + print('TOTAL STANDARD', time_standard) + print('TOTAL ROWWISE', time_rowwise) + print('TOTAL GLOBAL', time_global) + + print('speedup', -100*(time_global - time_standard)/time_standard) + + info['time_standard'] = time_standard + info['time_rowwise'] = time_rowwise + info['time_global'] = time_global + + info_json = json.dumps(info) + + # TODO: change this to what you want. + with open("speed_benchmark/info.jsonl", "a") as file: + file.write(info_json + "\n") diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 041df4b..f35a3b5 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,13 +3,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import cuda_setup, utils +from . import cuda_setup, utils, research from .autograd._functions import ( MatmulLtState, bmm_cublas, matmul, matmul_cublas, mm_cublas, + matmul_4bit ) from .cextension import COMPILED_WITH_CUDA from .nn import modules diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index cdac2ae..c2298c8 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -2,7 +2,7 @@ import operator import warnings from dataclasses import dataclass from functools import reduce # Required in Python 3 -from typing import Tuple, Optional +from typing import Tuple, Optional, List import torch @@ -232,6 +232,19 @@ def supports_igemmlt(device: torch.device) -> bool: return True +def _get_tile_size(format): + assert format in ( + "col_turing", + "col_ampere", + ), f"please find this assert and manually enter tile size for {format}" + return (8, 32) if format == "col_turing" else (32, 32) + + +def get_tile_inds(format, device): + transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device) + with torch.no_grad(): + return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device) + @dataclass class MatmulLtState: _tile_indices: Optional[torch.Tensor] = None @@ -267,20 +280,10 @@ class MatmulLtState: self.SBt = None self.CBt = None - def get_tile_size(self): - assert self.formatB in ( - "col_turing", - "col_ampere", - ), f"please find this assert and manually enter tile size for {self.formatB}" - return (8, 32) if self.formatB == "col_turing" else (32, 32) - @property def tile_indices(self): if self._tile_indices is None: - device = self.CxB.device - transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device) - with torch.no_grad(): - self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device) + self._tile_indices = get_tile_inds(self.formatB, self.CxB.device) return self._tile_indices @@ -424,10 +427,10 @@ class MatMul8bitLt(torch.autograd.Function): ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype if any(ctx.needs_input_grad[:2]): - ctx.tensors = (CAt, subA) + ctx.tensors = (CAt, subA, A) ctx.tensor_states = (SCAt, state.idx) else: - ctx.tensors = [None, None] + ctx.tensors = [None, None, A] ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) @@ -440,7 +443,7 @@ class MatMul8bitLt(torch.autograd.Function): bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad - CAt, subA = ctx.tensors + CAt, subA, A = ctx.tensors SCAt, idx = ctx.tensor_states formatB = ctx.formatB state = ctx.state @@ -487,6 +490,64 @@ class MatMul8bitLt(torch.autograd.Function): return grad_A, grad_B, None, grad_bias, None +class MatMul4Bit(torch.autograd.Function): + # forward is the same, but we added the fallback for pre-turing GPUs + # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") + + @staticmethod + def forward(ctx, A, B, out=None, bias=None, state=None): + # default of pytorch behavior if inputs are empty + ctx.is_empty = False + if prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + ctx.bias = bias + B_shape = state[1] + if A.shape[-1] == B_shape[0]: + return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) + else: + return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) + + + # 1. Dequantize + # 2. MatmulnN + output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias) + + # 3. Save state + ctx.state = state + ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype + + if any(ctx.needs_input_grad[:2]): + ctx.tensors = (A, B) + else: + ctx.tensors = (None, None) + + return output + + @staticmethod + def backward(ctx, grad_output): + if ctx.is_empty: + bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None + + req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad + A, B = ctx.tensors + state = ctx.state + + grad_A, grad_B, grad_bias = None, None, None + + if req_gradBias: + # compute grad_bias first before changing grad_output dtype + grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) + + # not supported by PyTorch. TODO: create work-around + #if req_gradB: grad_B = torch.matmul(grad_output.t(), A) + if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(grad_output.dtype).t()) + + return grad_A, grad_B, None, grad_bias, None + + def matmul( A: tensor, B: tensor, @@ -499,3 +560,8 @@ def matmul( if threshold > 0.0: state.threshold = threshold return MatMul8bitLt.apply(A, B, out, bias, state) + + +def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None): + assert quant_state is not None + return MatMul4Bit.apply(A, B, out, bias, quant_state) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 85bef00..131edc5 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -18,17 +18,24 @@ try: CUDASetup.get_instance().generate_instructions() CUDASetup.get_instance().print_log_stack() raise RuntimeError(''' - CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment! - If you cannot find any issues and suspect a bug, please open an issue with detals about your environment: - https://github.com/TimDettmers/bitsandbytes/issues''') - lib.cadam32bit_g32 + CUDA Setup failed despite GPU being available. Please run the following command to get more information: + + python -m bitsandbytes + + Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them + to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes + and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''') + lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False lib.get_context.restype = ct.c_void_p lib.get_cusparse.restype = ct.c_void_p + lib.cget_managed_ptr.restype = ct.c_void_p COMPILED_WITH_CUDA = True -except AttributeError: +except AttributeError as ex: warn("The installed version of bitsandbytes was compiled without GPU support. " - "8-bit optimizers and GPU quantization are unavailable.") + "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.") COMPILED_WITH_CUDA = False + print(str(ex)) + # print the setup details after checking for errors so we do not print twice if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py index ac0fb02..4cedf62 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/cuda_setup/main.py @@ -44,6 +44,9 @@ class CUDASetup: raise RuntimeError("Call get_instance() instead") def generate_instructions(self): + if getattr(self, 'error', False): return + print(self.error) + self.error = True if self.cuda is None: self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.') self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.') @@ -93,6 +96,7 @@ class CUDASetup: self.has_printed = False self.lib = None self.initialized = False + self.error = False def run_cuda_setup(self): self.initialized = True diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 8d95789..afa346e 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -9,6 +9,8 @@ import random import torch import itertools import math +from scipy.stats import norm +import numpy as np from functools import reduce # Required in Python 3 from typing import Tuple @@ -26,77 +28,95 @@ name2qmap = {} if COMPILED_WITH_CUDA: """C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = {} - str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16) + str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16, lib.cadam32bit_grad_bf16) str2optimizer32bit["momentum"] = ( - lib.cmomentum32bit_g32, - lib.cmomentum32bit_g16, + lib.cmomentum32bit_grad_32, + lib.cmomentum32bit_grad_16, ) str2optimizer32bit["rmsprop"] = ( - lib.crmsprop32bit_g32, - lib.crmsprop32bit_g16, - ) - str2optimizer32bit["lion"] = ( - lib.clion32bit_g32, - lib.clion32bit_g16, + lib.crmsprop32bit_grad_32, + lib.crmsprop32bit_grad_16, ) + str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16) str2optimizer32bit["adagrad"] = ( - lib.cadagrad32bit_g32, - lib.cadagrad32bit_g16, + lib.cadagrad32bit_grad_32, + lib.cadagrad32bit_grad_16, ) - str2optimizer32bit["lars"] = ( - lib.cmomentum32bit_g32, - lib.cmomentum32bit_g16, - ) - str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16) str2optimizer8bit = {} str2optimizer8bit["adam"] = ( - lib.cadam_static_8bit_g32, - lib.cadam_static_8bit_g16, + lib.cadam_static_8bit_grad_32, + lib.cadam_static_8bit_grad_16, ) str2optimizer8bit["momentum"] = ( - lib.cmomentum_static_8bit_g32, - lib.cmomentum_static_8bit_g16, + lib.cmomentum_static_8bit_grad_32, + lib.cmomentum_static_8bit_grad_16, ) str2optimizer8bit["rmsprop"] = ( - lib.crmsprop_static_8bit_g32, - lib.crmsprop_static_8bit_g16, + lib.crmsprop_static_8bit_grad_32, + lib.crmsprop_static_8bit_grad_16, ) str2optimizer8bit["lion"] = ( - lib.clion_static_8bit_g32, - lib.clion_static_8bit_g16, + lib.clion_static_8bit_grad_32, + lib.clion_static_8bit_grad_16, ) str2optimizer8bit["lamb"] = ( - lib.cadam_static_8bit_g32, - lib.cadam_static_8bit_g16, + lib.cadam_static_8bit_grad_32, + lib.cadam_static_8bit_grad_16, ) str2optimizer8bit["lars"] = ( - lib.cmomentum_static_8bit_g32, - lib.cmomentum_static_8bit_g16, + lib.cmomentum_static_8bit_grad_32, + lib.cmomentum_static_8bit_grad_16, ) str2optimizer8bit_blockwise = {} str2optimizer8bit_blockwise["adam"] = ( - lib.cadam_8bit_blockwise_fp32, - lib.cadam_8bit_blockwise_fp16, + lib.cadam_8bit_blockwise_grad_fp32, + lib.cadam_8bit_blockwise_grad_fp16, + lib.cadam_8bit_blockwise_grad_bf16, ) str2optimizer8bit_blockwise["momentum"] = ( - lib.cmomentum_8bit_blockwise_fp32, - lib.cmomentum_8bit_blockwise_fp16, + lib.cmomentum_8bit_blockwise_grad_fp32, + lib.cmomentum_8bit_blockwise_grad_fp16, ) str2optimizer8bit_blockwise["rmsprop"] = ( - lib.crmsprop_8bit_blockwise_fp32, - lib.crmsprop_8bit_blockwise_fp16, + lib.crmsprop_8bit_blockwise_grad_fp32, + lib.crmsprop_8bit_blockwise_grad_fp16, ) str2optimizer8bit_blockwise["lion"] = ( - lib.clion_8bit_blockwise_fp32, - lib.clion_8bit_blockwise_fp16, + lib.clion_8bit_blockwise_grad_fp32, + lib.clion_8bit_blockwise_grad_fp16, + lib.clion_8bit_blockwise_grad_bf16, ) str2optimizer8bit_blockwise["adagrad"] = ( - lib.cadagrad_8bit_blockwise_fp32, - lib.cadagrad_8bit_blockwise_fp16, + lib.cadagrad_8bit_blockwise_grad_fp32, + lib.cadagrad_8bit_blockwise_grad_fp16, ) +class GlobalPageManager: + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def initialize(self): + self.paged_tensors = [] + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + def prefetch_all(self, to_cpu=False): + # assume the first added, will be hte + # ones that are used first, so swap them in last + # in the case they are evicted again + for t in self.paged_tensors[::-1]: + prefetch_tensor(t, to_cpu) + + class CUBLAS_Context: _instance = None @@ -106,11 +126,6 @@ class CUBLAS_Context: def initialize(self): self.context = {} - # prev_device = torch.cuda.current_device() - # for i in range(torch.cuda.device_count()): - # torch.cuda.set_device(torch.device('cuda', i)) - # self.context.append(ct.c_void_p(lib.get_context())) - # torch.cuda.set_device(prev_device) @classmethod def get_instance(cls): @@ -144,6 +159,61 @@ class Cusparse_Context: cls._instance.initialize() return cls._instance +dtype2bytes = {} +dtype2bytes[torch.float32] = 4 +dtype2bytes[torch.float16] = 2 +dtype2bytes[torch.bfloat16] = 2 +dtype2bytes[torch.uint8] = 1 +dtype2bytes[torch.int8] = 1 + +def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)): + num_bytes = dtype2bytes[dtype]*prod(shape) + cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) + c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) + new_array = np.ctypeslib.as_array(c_ptr, shape=shape) + out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)).view(shape) + out.is_paged = True + out.page_deviceid = device.index + return out + +def prefetch_tensor(A, to_cpu=False): + assert A.is_paged, 'Only paged tensors can be prefetched!' + if to_cpu: + deviceid = -1 + else: + deviceid = A.page_deviceid + + num_bytes = dtype2bytes[A.dtype]*A.numel() + lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid)) + +def elementwise_func(func_name, A, B, value, prefetch=True): + func = None + if A.dtype == torch.float32: + func = getattr(lib, f'c{func_name}_fp32', None) + cvalue = ct.c_float(value) + elif A.dtype == torch.uint8: + func = getattr(lib, f'c{func_name}_uint8', None) + cvalue = ct.c_uint8(value) + + if func is None: raise NotImplementedError(f'Function not implemented: {func_name}') + + is_managed = getattr(A, 'is_managed', False) + if is_managed and prefetch: + prefetch_tensor(A) + if B is not None: prefetch_tensor(B) + + func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel())) + if A.is_paged or B.is_paged: + # paged function are fully asynchronous + # if we return from this function, we want to the tensor + # to be in the correct state, that is the final state after the + # operation occured. So we synchronize. + torch.cuda.synchronize() + +def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value) +def arange(A, device=None): elementwise_func('arange', A, None, 0) +def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0) + def create_linear_map(signed=True, total_bits=8, add_zero=True): sign = (-1.0 if signed else 0.0) @@ -161,9 +231,27 @@ def create_linear_map(signed=True, total_bits=8, add_zero=True): return values else: l = values.numel()//2 - #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist()) return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) +def create_normal_map(offset=0.9677083, use_extra_value=True): + + if use_extra_value: + # one more positive value, this is an asymmetric type + v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist() + v2 = [0]*(256-15) ## we have 15 non-zero values in this data type + v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() + v = v1 + v2 + v3 + else: + v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist() + v2 = [0]*(256-14) ## we have 14 non-zero values in this data type + v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() + v = v1 + v2 + v3 + + values = torch.Tensor(v) + values = values.sort().values + values /= values.max() + assert values.numel() == 256 + return values def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): e = exponent_bits @@ -180,7 +268,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) values = [] lst = list(itertools.product([0, 1], repeat=precision_bits)) #for ev in evalues: - bias = 2**(exponent_bits-1)-1 + bias = 2**(exponent_bits-1) for evalue in range(2**(exponent_bits)): for bit_pattern in lst: value = (1 if evalue != 0 else 0) @@ -188,10 +276,10 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8) value += pval*(2**-(i+1)) if evalue == 0: # subnormals - value = value*2**-(bias-1) + value = value*2**-(bias) else: # normals - value = value*2**-(evalue-bias-2) + value = value*2**-(evalue-bias-1) values.append(value) if signed: values.append(-value) @@ -289,9 +377,17 @@ def get_special_format_str(): def is_on_gpu(tensors): on_gpu = True + gpu_ids = set() for t in tensors: if t is None: continue # NULL pointers are fine - on_gpu &= t.device.type == 'cuda' + is_paged = getattr(t, 'is_paged', False) + on_gpu &= (t.device.type == 'cuda' or is_paged) + if not is_paged: + gpu_ids.add(t.device.index) + if not on_gpu: + raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}') + if len(gpu_ids) > 1: + raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}') return on_gpu def get_ptr(A: Tensor) -> ct.c_void_p: @@ -469,7 +565,7 @@ def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, n return out -def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096) -> Tensor: +def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor: """ Quantize tensor A in blocks of size 4096 values. @@ -485,8 +581,6 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra The quantization map. absmax : torch.Tensor The absmax values. - rand : torch.Tensor - The tensor for stochastic rounding. out : torch.Tensor The output tensor (8-bit). @@ -518,33 +612,30 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra cblocksize = ct.c_int32(blocksize) prev_device = pre_call(A.device) code = code.to(A.device) - if rand is not None: - is_on_gpu([code, A, out, absmax, rand]) - assert blocksize==4096 - assert rand.numel() >= 1024 - rand_offset = random.randint(0, 1023) - if A.dtype == torch.float32: - lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel())) - elif A.dtype == torch.float16: - lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel())) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + is_on_gpu([code, A, out, absmax]) + if A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + elif A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) else: - is_on_gpu([code, A, out, absmax]) - if A.dtype == torch.float32: - lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) - elif A.dtype == torch.float16: - lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) else: # cpu code = code.cpu() - assert rand is None lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) - return out, (absmax, code) + if nested: + offset = absmax.mean() + absmax -= offset + qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False) + state = [qabsmax, code, blocksize, nested, offset, state2] + else: + state = [absmax, code, blocksize, nested, None, None] + + + + return out, state def dequantize_blockwise( @@ -554,6 +645,7 @@ def dequantize_blockwise( code: Tensor = None, out: Tensor = None, blocksize: int = 4096, + nested=False ) -> Tensor: """ Dequantizes blockwise quantized values. @@ -588,10 +680,15 @@ def dequantize_blockwise( if out is None: out = torch.zeros_like(A, dtype=torch.float32) + if quant_state is None: - quant_state = (absmax, code) + quant_state = (absmax, code, blocksize) + assert absmax is not None and out is not None else: - absmax, code = quant_state + absmax, code, blocksize, nested, offset, state2 = quant_state + if nested: + absmax = dequantize_blockwise(absmax, state2) + absmax += offset if A.device.type != 'cpu': @@ -599,7 +696,7 @@ def dequantize_blockwise( code = code.to(A.device) if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") - is_on_gpu([A, out]) + is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) elif out.dtype == torch.float16: @@ -613,6 +710,164 @@ def dequantize_blockwise( return out +def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4') + +def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4') + +def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: + """ + Quantize tensor A in blocks of 4-bit values. + + Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + The output tensor (8-bit). + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4} + + Returns + ------- + torch.Tensor: + The 8-bit tensor with packed 4-bit values. + tuple(torch.Tensor, torch.Size, torch.dtype, int): + The quantization state to undo the quantization. + """ + if A.device.type != 'cuda': + raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') + if quant_type not in ['fp4', 'nf4']: + raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + + n = A.numel() + input_shape = A.shape + + if absmax is None: + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + absmax = torch.zeros((blocks,), device=A.device) + + + if out is None: + out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device) + + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + + prev_device = pre_call(A.device) + is_on_gpu([A, out, absmax]) + + if A.dtype == torch.float32: + if quant_type == 'fp4': + lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: + lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + elif A.dtype == torch.float16: + if quant_type == 'fp4': + lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: + lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) + + if compress_statistics: + offset = absmax.mean() + absmax -= offset + #code = create_custom_map().to(absmax.device) + #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256) + qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) + del absmax + state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] + else: + state = [absmax, input_shape, A.dtype, blocksize, None, quant_type] + + return out, state + +def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4') + +def dequantize_nf4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') + +def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor: + """ + Dequantizes FP4 blockwise quantized values. + + Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. + + Parameters + ---------- + A : torch.Tensor + The input 8-bit tensor (packed 4-bit values). + quant_state : tuple(torch.Tensor, torch.Size, torch.dtype) + Tuple of absmax values, original tensor shape and original dtype. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + Dequantized output tensor. + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4} + + + Returns + ------- + torch.Tensor: + Dequantized tensor. + """ + if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: + raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") + if quant_type not in ['fp4', 'nf4']: + raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + + if quant_state is None: + assert absmax is not None and out is not None + shape = out.shape + dtype = out.dtype + else: + absmax, shape, dtype, blocksize, compressed_stats, quant_type = quant_state + + + if compressed_stats is not None: + offset, state2 = compressed_stats + absmax = dequantize_blockwise(absmax, state2) + absmax += offset + + if out is None: + out = torch.empty(shape, dtype=dtype, device=A.device) + + n = out.numel() + + + device = pre_call(A.device) + is_on_gpu([A, absmax, out]) + if out.dtype == torch.float32: + if quant_type == 'fp4': + lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + else: + lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + elif out.dtype == torch.float16: + if quant_type == 'fp4': + lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + else: + lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) + + is_transposed = (True if A.shape[0] == 1 else False) + if is_transposed: return out.t() + else: return out + def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: if code is None: @@ -765,55 +1020,36 @@ def optimizer_update_32bit( if max_unorm > 0.0: param_norm = torch.norm(p.data.float()) - if optimizer_name not in str2optimizer32bit: - raise NotImplementedError( - f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}' - ) - prev_device = pre_call(g.device) - is_on_gpu([g, p, state1, state2, unorm_vec]) - if g.dtype == torch.float32 and state1.dtype == torch.float32: - str2optimizer32bit[optimizer_name][0]( - get_ptr(g), - get_ptr(p), - get_ptr(state1), - get_ptr(state2), - get_ptr(unorm_vec), - ct.c_float(max_unorm), - ct.c_float(param_norm), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_float(weight_decay), - ct.c_int32(step), - ct.c_float(lr), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) - elif g.dtype == torch.float16 and state1.dtype == torch.float32: - str2optimizer32bit[optimizer_name][1]( - get_ptr(g), - get_ptr(p), - get_ptr(state1), - get_ptr(state2), - get_ptr(unorm_vec), - ct.c_float(max_unorm), - ct.c_float(param_norm), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_float(weight_decay), - ct.c_int32(step), - ct.c_float(lr), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + optim_func = None + if g.dtype == torch.float32: + optim_func = str2optimizer32bit[optimizer_name][0] + elif g.dtype == torch.float16: + optim_func = str2optimizer32bit[optimizer_name][1] + elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3): + optim_func = str2optimizer32bit[optimizer_name][2] else: - raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" - ) + raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}") + + is_on_gpu([g, p, state1, state2, unorm_vec]) + prev_device = pre_call(g.device) + optim_func( + get_ptr(g), + get_ptr(p), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_float(weight_decay), + ct.c_int32(step), + ct.c_float(lr), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel())) post_call(prev_device) @@ -970,54 +1206,45 @@ def optimizer_update_8bit_blockwise( skip_zeros=False, ) -> None: + optim_func = None prev_device = pre_call(g.device) is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) if g.dtype == torch.float32 and state1.dtype == torch.uint8: - str2optimizer8bit_blockwise[optimizer_name][0]( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(absmax1), - get_ptr(absmax2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + optim_func = str2optimizer8bit_blockwise[optimizer_name][0] elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - str2optimizer8bit_blockwise[optimizer_name][1]( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(absmax1), - get_ptr(absmax2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + optim_func = str2optimizer8bit_blockwise[optimizer_name][1] + elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and + len(str2optimizer8bit_blockwise[optimizer_name])==3): + optim_func = str2optimizer8bit_blockwise[optimizer_name][2] else: raise ValueError( f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" ) post_call(prev_device) + is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) + + prev_device = pre_call(g.device) + optim_func( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) + post_call(prev_device) def percentile_clipping( grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5 @@ -1171,6 +1398,123 @@ def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8 return sout +def cutlass3_gemm( + A: Tensor, + B: Tensor, + out: Tensor = None, + transposed_A=False, + transposed_B=False, + state=None +): + #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) + if state is None: + Bshape = B.shape + bout = Bshape[1] + else: + Bshape = state[1] + bout = Bshape[0] + if out is None: + out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) + + sA = A.shape + sB = B.shape + if transposed_A and len(sA) == 2: + sA = (sA[1], sA[0]) + elif transposed_A and len(sA) == 3: + sA = (sA[0], sA[2], sA[0]) + if transposed_B and len(sB) == 2: + sB = (sB[1], sB[0]) + elif transposed_B and len(sB) == 3: + sB = (sB[0], sB[2], sB[0]) + # this is a mess: cuBLAS expect column major, but PyTorch is row major. + # So to perform the matrix multiplication, we have to treat A, B, and C matrices + # (transpose of row major is column major) + # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these + + # matrices in the input arguments for cuBLAS + # column major: A @ B = C: [m, k] @ [k, n] = [m, n] + # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n] + # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m] + if len(sB) == 2: + if B.stride()[0] == B.shape[1]: + transposed_B = False + elif B.stride()[1] == B.shape[0]: + transposed_B = True + if len(A.shape) == 2: + if A.stride()[0] == A.shape[1]: + transposed_A = False + elif A.stride()[1] == A.shape[0]: + transposed_A = True + else: + if A.stride()[1] == A.shape[2]: + transposed_A = False + elif A.stride()[2] == A.shape[1]: + transposed_A = True + + if len(sA) == 2: + n = sA[0] + ldb = A.stride()[1 if transposed_A else 0] + elif len(sA) == 3 and len(sB) == 2: + n = sA[0] * sA[1] + ldb = sA[2] + + m = sB[1] + k = sB[0] + lda = B.stride()[0] + ldc = sB[1] + elif len(sB) == 3: + # special case + assert len(sA) == 3 + if not (sA[0] == sB[0] and sA[1] == sB[1]): + raise ValueError( + f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}" + ) + + transposed_A = True + transposed_B = False + + m = sB[2] + n = sA[2] + k = sB[0] * sB[1] + + lda = n + ldb = sA[2] + ldc = m + + ptr = CUBLAS_Context.get_instance().get_context(A.device) + + # B^T @ A^T = C^T + # [km, nk -> mn] + #lda = ldb = ldc = 1 + #lda = 1 + if state is not None: + m = Bshape[0] + k = Bshape[1] + lda = Bshape[0] + ldc = Bshape[0] + ldb = (ldb+1)//2 + #print(m, n, k, lda, ldb, ldc) + is_on_gpu([B, A, out]) + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + + if B.dtype == torch.uint8: + lib.cgemm_4bit_inference(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) + elif A.dtype == torch.float32: + lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) + elif A.dtype == torch.float16: + lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) + else: + raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') + + return out + + + def igemm( A: Tensor, @@ -1845,8 +2189,6 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): ccolsB = ct.c_int32(B.shape[1]) cldb = ct.c_int32(ldb) cldc = ct.c_int32(ldc) - # print(cooA.rowidx[:64]) - # print(cooA.colidx[:64].sort()[0]) is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats]) if B.dtype == torch.float16: @@ -2044,3 +2386,8 @@ def extract_outliers(A, SA, idx): post_call(prev_device) return out + +def pipeline_test(A, batch_size): + out = torch.zeros_like(A) + lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) + return out diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index edc595a..49d7b5c 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -2,4 +2,5 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .modules import Int8Params, Linear8bitLt, StableEmbedding +from .modules import Int8Params, Linear8bitLt, StableEmbedding, Linear4bit, LinearNF4, LinearFP4, Params4bit, OutlierAwareLinear, SwitchBackLinearBnb +from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 8c4d688..b10d45a 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -10,8 +10,9 @@ from torch import Tensor, device, dtype, nn import bitsandbytes as bnb import bitsandbytes.functional -from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout +from bitsandbytes.autograd._functions import undo_layout, get_tile_inds from bitsandbytes.optim import GlobalOptimManager +from bitsandbytes.utils import OutlierTracer, find_outlier_dims T = TypeVar("T", bound="torch.nn.Module") @@ -135,6 +136,101 @@ class Embedding(torch.nn.Embedding): return emb +class Params4bit(torch.nn.Parameter): + def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'): + if data is None: + data = torch.empty(0) + + self = torch.Tensor._make_subclass(cls, data, requires_grad) + self.blocksize = blocksize + self.compress_statistics = compress_statistics + self.quant_type = quant_type + self.quant_state = quant_state + self.data = data + return self + + def cuda(self, device): + w = self.data.contiguous().half().cuda(device) + w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) + self.data = w_4bit + self.quant_state = quant_state + + return self + + @overload + def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T: + ... + + @overload + def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: + ... + + @overload + def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: + ... + + def to(self, *args, **kwargs): + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + + if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"): + return self.cuda(device) + else: + s = self.quant_state + if s is not None: + # make sure the quantization state is on the right device + s[0] = s[0].to(device) + if self.compress_statistics: + # TODO: refactor this. This is a nightmare + # for 4-bit: + # state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] + # state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] + #s[-2][0] = s[-2][0].to(device) # offset + #s[-2][1][0] = s[-2][1][0].to(device) # nested absmax + + # for 8-bit + s[-2][0] = s[-2][0].to(device) # offset + s[-2][1][0] = s[-2][1][0].to(device) # nested quantiation state statitics + s[-2][1][1] = s[-2][1][1].to(device) # nested quantiation codebook + new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking), + requires_grad=self.requires_grad, quant_state=self.quant_state, + blocksize=self.blocksize, compress_statistics=self.compress_statistics, + quant_type=self.quant_type) + + return new_param + +class Linear4bit(nn.Linear): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4'): + super().__init__(input_features, output_features, bias) + self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type) + self.compute_dtype = compute_dtype + + def forward(self, x: torch.Tensor): + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + if getattr(self.weight, 'quant_state', None) is None: + print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.') + inp_dtype = x.dtype + if self.compute_dtype is not None: + x = x.to(self.compute_dtype) + + bias = None if self.bias is None else self.bias.to(self.compute_dtype) + out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state) + + out = out.to(inp_dtype) + + return out + +class LinearFP4(Linear4bit): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True): + super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4') + +class LinearNF4(Linear4bit): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True): + super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4') + + class Int8Params(torch.nn.Parameter): def __new__( @@ -210,6 +306,18 @@ class Int8Params(torch.nn.Parameter): return new_param +def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + weight = state_dict.get(f"{prefix}weight") + if weight is None: + # if the state dict has no weights for this layer (e.g., LoRA finetuning), do nothing + return + weight_format = state_dict.pop(f"{prefix}weight_format", "row") + + if weight_format != "row": + tile_indices = get_tile_inds(weight_format, weight.device) + state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices) + + class Linear8bitLt(nn.Linear): def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0, index=None): @@ -225,52 +333,55 @@ class Linear8bitLt(nn.Linear): self.state.use_pool = True self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights) + self._register_load_state_dict_pre_hook(maybe_rearrange_weight) def _save_to_state_dict(self, destination, prefix, keep_vars): - if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None: - # reorder weight layout back from ampere/turing to row - reorder_layout = True - weight_clone = self.weight.data.clone() - else: - reorder_layout = False + super()._save_to_state_dict(destination, prefix, keep_vars) - try: - if reorder_layout: - self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices) + # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data + scb_name = "SCB" - super()._save_to_state_dict(destination, prefix, keep_vars) + # case 1: .cuda was called, SCB is in self.weight + param_from_weight = getattr(self.weight, scb_name) + # case 2: self.init_8bit_state was called, SCB is in self.state + param_from_state = getattr(self.state, scb_name) + # case 3: SCB is in self.state, weight layout reordered after first forward() + layout_reordered = self.state.CxB is not None - # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data - weight_name = "SCB" + key_name = prefix + f"{scb_name}" + format_name = prefix + "weight_format" - # case 1: .cuda was called, SCB is in self.weight - param_from_weight = getattr(self.weight, weight_name) - # case 2: self.init_8bit_state was called, SCB is in self.state - param_from_state = getattr(self.state, weight_name) - - key_name = prefix + f"{weight_name}" + if not self.state.has_fp16_weights: if param_from_weight is not None: destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach() - elif not self.state.has_fp16_weights and param_from_state is not None: + destination[format_name] = "row" + elif param_from_state is not None and not layout_reordered: destination[key_name] = param_from_state if keep_vars else param_from_state.detach() - finally: - if reorder_layout: - self.weight.data = weight_clone + destination[format_name] = "row" + elif param_from_state is not None: + destination[key_name] = param_from_state if keep_vars else param_from_state.detach() + destination[format_name] = self.state.formatB def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - for key in unexpected_keys: + unexpected_copy = list(unexpected_keys) + + for key in unexpected_copy: input_name = key[len(prefix):] if input_name == "SCB": if self.weight.SCB is None: - # buffers not yet initialized, can't call them directly without + # buffers not yet initialized, can't access them directly without quantizing first raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is " "not supported. Please call module.cuda() before module.load_state_dict()") input_param = state_dict[key] self.weight.SCB.copy_(input_param) + + if self.state.SCB is not None: + self.state.SCB = self.weight.SCB + unexpected_keys.remove(key) def init_8bit_state(self): @@ -289,6 +400,7 @@ class Linear8bitLt(nn.Linear): self.bias.data = self.bias.data.to(x.dtype) out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) + if not self.state.has_fp16_weights: if self.state.CB is not None and self.state.CxB is not None: # we converted 8-bit row major to turing/ampere format in the first inference pass @@ -296,3 +408,71 @@ class Linear8bitLt(nn.Linear): del self.state.CB self.weight.data = self.state.CxB return out + + +class OutlierAwareLinear(nn.Linear): + def __init__(self, input_features, output_features, bias=True): + super().__init__(input_features, output_features, bias) + self.outlier_dim = None + self.is_quantized = False + + def forward_with_outliers(self, x, outlier_idx): + raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function') + + def quantize_weight(self, w, outlier_idx): + raise NotImplementedError('Please override the `quantize_weights(self, w, outlier_idx)` function') + + def forward(self, x): + if self.outlier_dim is None: + tracer = OutlierTracer.get_instance() + if not tracer.is_initialized(): + print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer') + outlier_idx = tracer.get_outliers(self.weight) + #print(outlier_idx, tracer.get_hvalue(self.weight)) + self.outlier_dim = outlier_idx + + if not self.is_quantized: + w = self.quantize_weight(self.weight, self.outlier_dim) + self.weight.data.copy_(w) + self.is_quantized = True + +class SwitchBackLinearBnb(nn.Linear): + def __init__( + self, + input_features, + output_features, + bias=True, + has_fp16_weights=True, + memory_efficient_backward=False, + threshold=0.0, + index=None, + ): + super().__init__( + input_features, output_features, bias + ) + self.state = bnb.MatmulLtState() + self.index = index + + self.state.threshold = threshold + self.state.has_fp16_weights = has_fp16_weights + self.state.memory_efficient_backward = memory_efficient_backward + if threshold > 0.0 and not has_fp16_weights: + self.state.use_pool = True + + self.weight = Int8Params( + self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights + ) + + def init_8bit_state(self): + self.state.CB = self.weight.CB + self.state.SCB = self.weight.SCB + self.weight.CB = None + self.weight.SCB = None + + def forward(self, x): + self.state.is_training = self.training + + if self.weight.CB is not None: + self.init_8bit_state() + + out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py new file mode 100644 index 0000000..6fbf583 --- /dev/null +++ b/bitsandbytes/nn/triton_based_modules.py @@ -0,0 +1,258 @@ +import torch +import torch.nn as nn +import time +from functools import partial + +from bitsandbytes.triton.triton_utils import is_triton_available + +from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise +from bitsandbytes.triton.quantize_rowwise import quantize_rowwise +from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose +from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize +from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose +from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze + + +class _switchback_global(torch.autograd.Function): + + @staticmethod + def forward(ctx, X_3D, W, bias): + # reshape input to [N * L, D] + X = X_3D.view(-1, X_3D.size(-1)) + + # rowwise quantize for X, global quantize for W + X_int8, state_X = quantize_rowwise(X) + W_int8, state_W = quantize_global(W) + + # save for backward. + ctx.save_for_backward = X, W + + # matmult, fused dequant and add bias + # call "mixed" because we are mixing rowwise quantized and global quantized + return int8_matmul_mixed_dequanitze( + X_int8, W_int8.t(), state_X, state_W, bias + ).view(*X_3D.size()[:-1], -1) + + @staticmethod + def backward(ctx, G_3D): + # reshape input to [N_out * L, D] + G = G_3D.reshape(-1, G_3D.size(-1)) + + grad_X = grad_W = grad_bias = None + + X, W = ctx.save_for_backward + if ctx.needs_input_grad[0]: + # rowwise quantize for G, global quantize for W + # for W, we also fuse the transpose operation because only A @ B^T is supported + # so we transpose once then call .t() in the matmul + G_int8, state_G = quantize_rowwise(G) + W_int8, state_W = quantize_global_transpose(W) + grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view( + *G_3D.size()[:-1], -1 + ) + if ctx.needs_input_grad[1]: + # backward pass uses standard weight grad + grad_W = torch.matmul(G.t(), X.to(G.dtype)) + if ctx.needs_input_grad[2]: + grad_bias = G.sum(dim=0) + + return grad_X, grad_W, grad_bias + +class _switchback_vectorrize(torch.autograd.Function): + + @staticmethod + def forward(ctx, X_3D, W, bias): + # reshape input to [N * L, D] + X = X_3D.view(-1, X_3D.size(-1)) + + ctx.save_for_backward = X, W + # rowwise quantize for X + # columnwise quantize for W (first rowwise, transpose later) + X_int8, state_X = quantize_rowwise(X) + W_int8, state_W = quantize_rowwise(W) + + # matmult, fused dequant and add bias + # call kernel which expects rowwise quantized X and W + return int8_matmul_rowwise_dequantize( + X_int8, W_int8.t(), state_X, state_W, bias + ).view(*X_3D.size()[:-1], -1) + + @staticmethod + def backward(ctx, G_3D): + X, W = ctx.save_for_backward + + G = G_3D.reshape(-1, G_3D.size(-1)) + + grad_X = grad_W = grad_bias = None + + if ctx.needs_input_grad[0]: + # rowwise quantize for G, columnwise quantize for W and fused transpose + # we call .t() for weight later because only A @ B^T is supported + G_int8, state_G = quantize_rowwise(G) + W_int8, state_W = quantize_columnwise_and_transpose(W) + grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( + *G_3D.size()[:-1], -1 + ) + if ctx.needs_input_grad[1]: + # backward pass uses standard weight grad + grad_W = torch.matmul(G.t(), X.to(G.dtype)) + if ctx.needs_input_grad[2]: + grad_bias = G.sum(dim=0) + + return grad_X, grad_W, grad_bias + +class _switchback_global_mem_efficient(torch.autograd.Function): + + @staticmethod + def forward(ctx, X_3D, W, bias): + # reshape input to [N * L, D] + X = X_3D.view(-1, X_3D.size(-1)) + X_3D_sz = X_3D.size() + + # rowwise quantize for X, global quantize for W + X_int8, state_X = quantize_rowwise(X) + del X + W_int8, state_W = quantize_global(W) + + # save for backward. + ctx.save_for_backward = X_int8, state_X, W_int8, state_W + + # matmult, fused dequant and add bias + # call "mixed" because we are mixing rowwise quantized and global quantized + return int8_matmul_mixed_dequanitze( + X_int8, W_int8.t(), state_X, state_W, bias + ).view(*X_3D_sz[:-1], -1) + + @staticmethod + def backward(ctx, G_3D): + # reshape input to [N_out * L, D] + G = G_3D.reshape(-1, G_3D.size(-1)) + G_3D_sz = G_3D.size() + + grad_X = grad_W = grad_bias = None + + X_int8, state_X, W_int8, state_W = ctx.save_for_backward + if ctx.needs_input_grad[1]: + real_X = dequantize_rowwise(X_int8, state_X) + del X_int8 + grad_W = torch.matmul(G.t(), real_X.to(G.dtype)) + del real_X + if ctx.needs_input_grad[2]: + grad_bias = G.sum(dim=0) + if ctx.needs_input_grad[0]: + G_int8, state_G = quantize_rowwise(G) + del G + W_int8 = W_int8.t().contiguous() + grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view( + *G_3D_sz[:-1], -1 + ) + + return grad_X, grad_W, grad_bias + +class SwitchBackLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + vector_wise_quantization: bool = False, + mem_efficient : bool = False, + ): + super().__init__(in_features, out_features, bias, device, dtype) + + if not is_triton_available: + raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear. + Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''') + + # By default, we use the global quantization. + self.vector_wise_quantization = vector_wise_quantization + if self.vector_wise_quantization: + self._fn = _switchback_vectorrize + if mem_efficient: + print('mem efficient is not supported for vector-wise quantization.') + exit(1) + else: + if mem_efficient: + self._fn = _switchback_global_mem_efficient + else: + self._fn = _switchback_global + + def prepare_for_eval(self): + # If we just want to do eval, we can pre-quantize the weights instead of doing it on the forward pass. + # Note this is experimental and not tested thoroughly. + # Note this needs to be explicitly called with something like + # def cond_prepare(m): + # if hasattr(m, "prepare_for_eval"): + # m.prepare_for_eval() + # model.apply(cond_prepare) + print('=> preparing for eval.') + if self.vector_wise_quantization: + W_int8, state_W = quantize_rowwise(self.weight) + else: + W_int8, state_W = quantize_global(self.weight) + + self.register_buffer("W_int8", W_int8) + self.register_buffer("state_W", state_W) + + del self.weight + + def forward(self, x): + if self.training: + return self._fn.apply(x, self.weight, self.bias) + else: + # If it hasn't been "prepared for eval", run the standard forward pass. + if not hasattr(self, "W_int8"): + return self._fn.apply(x, self.weight, self.bias) + + # Otherwise, use pre-computed weights. + X = x.view(-1, x.size(-1)) + X_int8, state_X = quantize_rowwise(X) + + if self.vector_wise_quantization: + return int8_matmul_rowwise_dequantize( + X_int8, self.W_int8.t(), state_X, self.state_W, self.bias + ).view(*x.size()[:-1], -1) + else: + return int8_matmul_mixed_dequanitze( + X_int8, self.W_int8.t(), state_X, self.state_W, self.bias + ).view(*x.size()[:-1], -1) + +SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False) +SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True) +SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True) + +# This is just the standard linear function. +class StandardLinearFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias=None): + X = input.view(-1, input.size(-1)) + + ctx.save_for_backward(X, weight, bias) + output = input.matmul(weight.t()) + if bias is not None: + output += bias.unsqueeze(0).expand_as(output) + return output.view(*input.size()[:-1], -1) + + @staticmethod + def backward(ctx, grad_output_3D): + input, weight, bias = ctx.saved_tensors + + grad_output = grad_output_3D.reshape(-1, grad_output_3D.size(-1)) + + grad_input = grad_weight = grad_bias = None + + if ctx.needs_input_grad[0]: + grad_input = grad_output.matmul(weight.to(grad_output.dtype)).view(*grad_output_3D.size()[:-1], -1) + if ctx.needs_input_grad[1]: + grad_weight = grad_output.t().matmul(input.to(grad_output.dtype)) + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(0) + + return grad_input, grad_weight, grad_bias + +class StandardLinear(nn.Linear): + + def forward(self, x): + return StandardLinearFunction.apply(x, self.weight, self.bias) diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index 53533ee..83a57bd 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -6,11 +6,11 @@ from bitsandbytes.cextension import COMPILED_WITH_CUDA from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit -from .adam import Adam, Adam8bit, Adam32bit -from .adamw import AdamW, AdamW8bit, AdamW32bit +from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit +from .adamw import AdamW, AdamW8bit, AdamW32bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit from .lamb import LAMB, LAMB8bit, LAMB32bit from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS from .optimizer import GlobalOptimManager from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit -from .lion import Lion, Lion8bit, Lion32bit +from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit from .sgd import SGD, SGD8bit, SGD32bit diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index 396aeb8..86981eb 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -14,92 +14,34 @@ from bitsandbytes.optim.optimizer import Optimizer2State class Adam(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - optim_bits, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Adam8bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Adam32bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 32, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) +class PagedAdam(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdam8bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdam32bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class AnalysisAdam(torch.optim.Optimizer): """Adam that performs 8-bit vs 32-bit error analysis. diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index 022e64c..21077f1 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -5,89 +5,35 @@ from bitsandbytes.optim.optimizer import Optimizer2State -class AdamW(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - optim_bits, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) +class AdamW(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) class AdamW8bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) class AdamW32bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 32, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + + +class PagedAdamW(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdamW8bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdamW32bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py index 2551b68..2bde1a4 100644 --- a/bitsandbytes/optim/lion.py +++ b/bitsandbytes/optim/lion.py @@ -4,84 +4,27 @@ # LICENSE file in the root directory of this source tree. from bitsandbytes.optim.optimizer import Optimizer1State - class Lion(Optimizer1State): - def __init__( - self, - params, - lr=1e-4, - betas=(0.9, 0.99), - weight_decay=0, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "lion", - params, - lr, - betas, - 0., - weight_decay, - optim_bits, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Lion8bit(Optimizer1State): - def __init__( - self, - params, - lr=1e-4, - betas=(0.9, 0.99), - weight_decay=0, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "lion", - params, - lr, - betas, - 0., - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Lion32bit(Optimizer1State): - def __init__( - self, - params, - lr=1e-4, - betas=(0.9, 0.99), - weight_decay=0, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "lion", - params, - lr, - betas, - 0., - weight_decay, - 32, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + + +class PagedLion(Optimizer1State): + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedLion8bit(Optimizer1State): + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedLion32bit(Optimizer1State): + def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 1adf5d4..fb83edd 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -92,10 +92,12 @@ class GlobalOptimManager: class Optimizer8bit(torch.optim.Optimizer): - def __init__(self, params, defaults, optim_bits=32): + def __init__(self, params, defaults, optim_bits=32, is_paged=False): super().__init__(params, defaults) self.initialized = False self.name2qmap = {} + self.is_paged = is_paged + self.page_mng = F.GlobalPageManager.get_instance() self.mng = GlobalOptimManager.get_instance() self.non_castable_tensor_keys = { @@ -207,7 +209,9 @@ class Optimizer8bit(torch.optim.Optimizer): values = self.state[p] for k, v in values.items(): if isinstance(v, torch.Tensor): - self.state[p][k] = v.to(p.device) + is_paged = getattr(v, 'is_paged', False) + if not is_paged: + self.state[p][k] = v.to(p.device) def check_overrides(self): for module, attr, config in self.mng.module_weight_config_triple: @@ -252,6 +256,7 @@ class Optimizer8bit(torch.optim.Optimizer): self.to_gpu() # needed for fairseq pure fp16 training self.initialized = True + #if self.is_paged: self.page_mng.prefetch_all() for gindex, group in enumerate(self.param_groups): for pindex, p in enumerate(group["params"]): if p.grad is None: @@ -260,7 +265,14 @@ class Optimizer8bit(torch.optim.Optimizer): if len(state) == 0: self.init_state(group, p, gindex, pindex) + self.prefetch_state(p) self.update_step(group, p, gindex, pindex) + torch.cuda.synchronize() + if self.is_paged: + # all paged operation are asynchronous, we need + # to sync to make sure all tensors are in the right state + torch.cuda.synchronize() + return loss @@ -289,6 +301,26 @@ class Optimizer8bit(torch.optim.Optimizer): "The update_step method needs to be overridden" ) + def get_state_buffer(self, p, dtype=torch.float32): + if not self.is_paged or p.numel() < 1e5: + return torch.zeros_like(p, dtype=dtype, device=p.device) + else: + # > 1 MB + buff = F.get_paged(*p.shape, dtype=dtype, device=p.device) + F.fill(buff, 0) + self.page_mng.paged_tensors.append(buff) + return buff + + def prefetch_state(self, p): + if self.is_paged: + state = self.state[p] + s1 = state['state1'] + is_paged = getattr(s1, 'is_paged', False) + if is_paged: + F.prefetch_tensor(state['state1']) + if 'state2' in state: + F.prefetch_tensor(state['state2']) + class Optimizer2State(Optimizer8bit): def __init__( @@ -306,6 +338,7 @@ class Optimizer2State(Optimizer8bit): block_wise=True, max_unorm=0.0, skip_zeros=False, + is_paged=False ): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") @@ -325,7 +358,7 @@ class Optimizer2State(Optimizer8bit): f"Invalid weight_decay value: {weight_decay}" ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - super().__init__(params, defaults, optim_bits) + super().__init__(params, defaults, optim_bits, is_paged) if args is None: args = {} @@ -365,18 +398,8 @@ class Optimizer2State(Optimizer8bit): if dtype == torch.float32 or ( dtype == torch.uint8 and p.numel() < 4096 ): - state["state1"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.float32, - device=p.device, - ) - state["state2"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.float32, - device=p.device, - ) + state["state1"] = self.get_state_buffer(p, dtype=torch.float32) + state["state2"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: if "dynamic" not in self.name2qmap: @@ -388,20 +411,10 @@ class Optimizer2State(Optimizer8bit): p.device ) - state["state1"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.uint8, - device=p.device, - ) + state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap1"] = self.name2qmap["dynamic"] - state["state2"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.uint8, - device=p.device, - ) + state["state2"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap2"] = self.name2qmap["udynamic"] if config["block_wise"]: @@ -538,6 +551,7 @@ class Optimizer1State(Optimizer8bit): block_wise=True, max_unorm=0.0, skip_zeros=False, + is_paged=False ): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") @@ -553,7 +567,7 @@ class Optimizer1State(Optimizer8bit): f"Invalid weight_decay value: {weight_decay}" ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - super().__init__(params, defaults, optim_bits) + super().__init__(params, defaults, optim_bits, is_paged) if args is None: args = {} @@ -593,12 +607,7 @@ class Optimizer1State(Optimizer8bit): if dtype == torch.float32 or ( dtype == torch.uint8 and p.numel() < 4096 ): - state["state1"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.float32, - device=p.device, - ) + state["state1"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: if "dynamic" not in self.name2qmap: @@ -607,12 +616,7 @@ class Optimizer1State(Optimizer8bit): p.device ) - state["state1"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.uint8, - device=p.device, - ) + state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap1"] = self.name2qmap["dynamic"] if config["block_wise"]: diff --git a/bitsandbytes/research/__init__.py b/bitsandbytes/research/__init__.py new file mode 100644 index 0000000..47b720d --- /dev/null +++ b/bitsandbytes/research/__init__.py @@ -0,0 +1,6 @@ +from . import nn +from .autograd._functions import ( + switchback_bnb, + matmul_fp8_global, + matmul_fp8_mixed, +) diff --git a/bitsandbytes/research/autograd/__init__.py b/bitsandbytes/research/autograd/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py new file mode 100644 index 0000000..0dff351 --- /dev/null +++ b/bitsandbytes/research/autograd/_functions.py @@ -0,0 +1,411 @@ +import operator +import warnings +from dataclasses import dataclass +from functools import reduce # Required in Python 3 + +import torch + +import bitsandbytes.functional as F + +from bitsandbytes.autograd._functions import MatmulLtState, GlobalOutlierPooler + + +# math.prod not compatible with python < 3.8 +def prod(iterable): + return reduce(operator.mul, iterable, 1) + +tensor = torch.Tensor + +class MatMulFP8Mixed(torch.autograd.Function): + # forward is the same, but we added the fallback for pre-turing GPUs + # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") + + @staticmethod + def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024): + # default of pytorch behavior if inputs are empty + ctx.is_empty = False + if prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + + B_shape = B.shape + if A.shape[-1] == B_shape[0]: + return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) + else: + return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) + + # 1. Dequantize + # 2. MatmulnN + cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz) + fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype) + + cB, state = F.quantize(B.float(), code=fw_code) + fp8B = F.dequantize(cB, state).to(B.dtype) + + output = torch.matmul(fp8A, fp8B) + + # output is half + + # 3. Save state + ctx.fw_code = fw_code + ctx.bw_code = bw_code + ctx.bsz = bsz + ctx.bsz2 = bsz2 + ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype + + if any(ctx.needs_input_grad[:2]): + # NOTE: we send back A, and re-quant. + ctx.tensors = (A, fp8B) + else: + ctx.tensors = (None, None) + + return output + + @staticmethod + def backward(ctx, grad_output): + if ctx.is_empty: + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None + + req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad + A, B = ctx.tensors + + grad_A, grad_B = None, None + + # TODO: Fix blocksize to be output_dim + cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2) + fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype) + + # cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code) + # fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype) + + # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() + # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector') + # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose + # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2]) + + # not supported by PyTorch. TODO: create work-around + if req_gradA: + grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype) + + if req_gradB: + if len(A.shape) == 3: + At = A.transpose(2, 1).contiguous() + else: + At = A.transpose(1, 0).contiguous() + # cA, state = F.quantize(At.float(), code=ctx.fw_code) + # fp8At = F.dequantize(cA, state).to(A.dtype) + grad_B = torch.matmul(At.to(grad_output.dtype), grad_output).to(B.dtype) + + return grad_A, grad_B, None, None, None, None, None + + +class MatMulFP8Global(torch.autograd.Function): + # forward is the same, but we added the fallback for pre-turing GPUs + # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") + + @staticmethod + def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024): + # default of pytorch behavior if inputs are empty + ctx.is_empty = False + if prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + + B_shape = B.shape + if A.shape[-1] == B_shape[0]: + return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) + else: + return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) + + # 1. Dequantize + # 2. MatmulnN + cA, state = F.quantize(A.float(), code=fw_code) + fp8A = F.dequantize(cA, state).to(A.dtype) + + cB, state = F.quantize(B.float(), code=fw_code) + fp8B = F.dequantize(cB, state).to(B.dtype) + + output = torch.matmul(fp8A, fp8B) + + # output is half + + # 3. Save state + ctx.fw_code = fw_code + ctx.bw_code = bw_code + ctx.bsz = bsz + ctx.bsz2 = bsz2 + ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype + + if any(ctx.needs_input_grad[:2]): + # NOTE: we send back A, and re-quant. + ctx.tensors = (A, fp8B) + else: + ctx.tensors = (None, None) + + return output + + @staticmethod + def backward(ctx, grad_output): + if ctx.is_empty: + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None + + req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad + A, B = ctx.tensors + + grad_A, grad_B = None, None + + # TODO: Fix blocksize to be output_dim + cgrad_out, state = F.quantize(grad_output.float(), code=ctx.bw_code) + fp8out = F.dequantize(cgrad_out, state).to(grad_output.dtype) + + # cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code) + # fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype) + + # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() + # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector') + # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose + # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2]) + + # not supported by PyTorch. TODO: create work-around + if req_gradA: + grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype) + + if req_gradB: + if len(A.shape) == 3: + At = A.transpose(2, 1).contiguous() + else: + At = A.transpose(1, 0).contiguous() + cA, state = F.quantize(At.float(), code=ctx.fw_code) + fp8At = F.dequantize(cA, state).to(A.dtype) + grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype) + + return grad_A, grad_B, None, None, None, None, None + + +class SwitchBackBnb(torch.autograd.Function): + @staticmethod + def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): + # default to pytorch behavior if inputs are empty + ctx.is_empty = False + if prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + ctx.bias = bias + if A.shape[-1] == B.shape[0]: + return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device) + else: + return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device) + + # 1. Quantize A + # 2. Quantize B + # 3. Matmul + # 4. Mixed-precision decomposition matmul + # 5. Save state + formatB = state.formatB + input_shape = A.shape + if state.outlier_pool is None: + state.outlier_pool = GlobalOutlierPooler.get_instance() + + # Cast A to fp16 + if A.dtype != torch.float16: + warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") + + # 1. Quantize A + if len(A.shape) == 3: + A = A.view(-1, A.shape[-1]).contiguous() + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( + A.to(torch.float16), threshold=state.threshold + ) + + if state.threshold > 0.0 and coo_tensorA is not None: + if state.has_fp16_weights: + idx = torch.unique(coo_tensorA.colidx).long() + CA[:, idx] = 0 + CAt[:, idx] = 0 + subA = A[:, idx] + state.subB = B[:, idx].t().contiguous() + state.idx = idx + else: + if state.CxB is None: + # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions + # we also need to convert it to the turing/ampere format + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + else: + #print('A shape', A.shape) + if not state.has_fp16_weights and state.CxB is None: + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + subA = None + + # 2. Quantize B + if state.has_fp16_weights: + #print('B shape', B.shape) + has_grad = True if (getattr(B, "grad", None) is not None) else False + is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) + if is_transposed: + B = B.contiguous() + + if (state.is_training and not has_grad) or state.CxB is None: + state.reset_grads() + ( + CB, + state.CBt, + state.SCB, + state.SCBt, + coo_tensorB, + ) = F.double_quant(B.to(torch.float16)) + state.CxB, state.SB = F.transform(CB, to_order=formatB) + else: + has_grad = False + + if coo_tensorA is not None and not state.has_fp16_weights: + # extract outliers + + outlier_idx = torch.unique(coo_tensorA.colidx) + state.idx = outlier_idx + # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) + # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: + # # do not use pool for 2nd FFN layer + # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) + # else: + # state.idx = outlier_idx + outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) + state.subB = ( + (outliers * state.SCB.view(-1, 1) / 127.0) + .t() + .contiguous() + .to(A.dtype) + ) + CA[:, state.idx.long()] = 0 + CAt[:, state.idx.long()] = 0 + subA = A[:, state.idx.long()] + + shapeB = state.SB[0] + + if len(input_shape) == 3: + output_shape = (input_shape[0], input_shape[1], shapeB[0]) + else: + output_shape = (input_shape[0], shapeB[0]) + + # 3. Matmul + C32A, SA = F.transform(CA, "col32") + out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) + # we apply the fused bias here + + if bias is None or bias.dtype == torch.float16: + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) + output = output.to(A.dtype) + else: # apply bias separately + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) + output = output.to(A.dtype).add_(bias) + + # 4. Mixed-precision decomposition matmul + if coo_tensorA is not None and subA is not None: + output += torch.matmul(subA, state.subB) + + # 5. Save state + ctx.state = state + + ctx.formatB = formatB + ctx.grad_shape = input_shape + ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype + + if any(ctx.needs_input_grad[:2]): + ctx.tensors = (CAt, subA, A) + ctx.tensor_states = (SCAt, state.idx) + else: + ctx.tensors = [None, None, None] + ctx.tensor_states = (None, None) + ctx.save_for_backward(None, None) + + + clone_func = torch.clone if len(output_shape) == 3 else lambda x : x + return clone_func(output.view(output_shape)) + + @staticmethod + def backward(ctx, grad_output): + if ctx.is_empty: + bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None + req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad + CAt, subA, A = ctx.tensors + SCAt, idx = ctx.tensor_states + formatB = ctx.formatB + state = ctx.state + grad_A = grad_B = grad_bias = None + + if req_gradBias: + # compute grad_bias first before changing grad_output dtype + grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) + + # Cast grad_output to fp16 + if len(grad_output.shape) == 3: + grad_output = grad_output.reshape( + -1, grad_output.shape[-1] + ).contiguous() + + Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) + + if req_gradB: + # print('back A shape', A.shape) + # print('grad output t shape', grad_output.t().shape) + grad_B = torch.matmul(grad_output.t(), A) + + if req_gradA: + if state.CBt is not None: + C32grad, Sgrad = F.transform(Cgrad, "col32") + if state.CxBt is None: + state.CxBt, state.SBt = F.transform( + state.CBt, to_order=formatB, transpose=True + ) + # print('back B shape', state.CxBt.shape) + # print('back grad shape', C32grad.shape) + gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) + + elif state.CB is not None: + CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0)) + grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) + else: + raise Exception('State must contain either CBt or CB matrix for backward') + + return grad_A, grad_B, None, grad_bias, None + +def get_block_sizes(input_matrix, weight_matrix): + input_features = input_matrix.shape[-1] + output_features = (weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1]) + array = [4096, 2048, 1024, 512, 256, 128, 64, 0] + bsz, bsz2 = 1024, 1024 + for i, k in enumerate(array): + if input_features > array[i + 1]: + bsz = k + break + for i, k in enumerate(array): + if output_features > array[i + 1]: + bsz2 = k + break + + return bsz, bsz2 + +def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1): + if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) + return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2) + +def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1): + if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) + return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2) + +def switchback_bnb( + A: tensor, + B: tensor, + out: tensor = None, + state: MatmulLtState = None, + threshold=0.0, + bias=None +): + state = state or MatmulLtState() + if threshold > 0.0: + state.threshold = threshold + return SwitchBackBnb.apply(A, B, out, bias, state) diff --git a/bitsandbytes/research/nn/__init__.py b/bitsandbytes/research/nn/__init__.py new file mode 100644 index 0000000..8faec10 --- /dev/null +++ b/bitsandbytes/research/nn/__init__.py @@ -0,0 +1 @@ +from .modules import LinearFP8Mixed, LinearFP8Global diff --git a/bitsandbytes/research/nn/modules.py b/bitsandbytes/research/nn/modules.py new file mode 100644 index 0000000..2a46b40 --- /dev/null +++ b/bitsandbytes/research/nn/modules.py @@ -0,0 +1,64 @@ +from typing import Optional, TypeVar, Union, overload + +import torch +import torch.nn.functional as F +from torch import Tensor, device, dtype, nn + +import bitsandbytes as bnb +from bitsandbytes.optim import GlobalOptimManager +from bitsandbytes.utils import OutlierTracer, find_outlier_dims + +T = TypeVar("T", bound="torch.nn.Module") + + +class LinearFP8Mixed(nn.Linear): + def __init__(self, input_features, output_features, bias=True): + super().__init__(input_features, output_features, bias) + self.bw_code = None + self.fw_code = None + array = [4096, 2048, 1024, 512, 256, 128, 64, 0] + for i, k in enumerate(array): + if input_features > array[i + 1]: + self.bsz = k + break + for i, k in enumerate(array): + if output_features > array[i + 1]: + self.bsz2 = k + break + + def forward(self, x: torch.Tensor): + if self.fw_code is None: + self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) + self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) + + out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) + if self.bias is not None: + out += self.bias + + return out + +class LinearFP8Global(nn.Linear): + def __init__(self, input_features, output_features, bias=True): + super().__init__(input_features, output_features, bias) + self.bw_code = None + self.fw_code = None + array = [4096, 2048, 1024, 512, 256, 128, 64, 0] + for i, k in enumerate(array): + if input_features > array[i + 1]: + self.bsz = k + break + for i, k in enumerate(array): + if output_features > array[i + 1]: + self.bsz2 = k + break + + def forward(self, x: torch.Tensor): + if self.fw_code is None: + self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) + self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) + + out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) + if self.bias is not None: + out += self.bias + + return out diff --git a/bitsandbytes/triton/__init__.py b/bitsandbytes/triton/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bitsandbytes/triton/dequantize_rowwise.py b/bitsandbytes/triton/dequantize_rowwise.py new file mode 100644 index 0000000..e092680 --- /dev/null +++ b/bitsandbytes/triton/dequantize_rowwise.py @@ -0,0 +1,64 @@ +import math +import torch +import time +from bitsandbytes.triton.triton_utils import is_triton_available + +if not is_triton_available(): + def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None +else: + + import triton + import triton.language as tl + from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + # rowwise quantize + + # TODO: autotune this better. + @triton.autotune( + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] + ) + @triton.jit + def _dequantize_rowwise( + x_ptr, + state_x, + output_ptr, + inv_127, + n_elements, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + arange = tl.arange(0, P2) + offsets = block_start + arange + row_mask = arange < BLOCK_SIZE + x = tl.load(x_ptr + offsets, mask=row_mask) + max_val = tl.load(state_x + pid) + output = max_val * x * inv_127 + tl.store(output_ptr + offsets, output, mask=row_mask) + + + def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): + output = torch.empty(*x.shape, device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (x.shape[0],) + _dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) + return output diff --git a/bitsandbytes/triton/int8_matmul_mixed_dequanitze.py b/bitsandbytes/triton/int8_matmul_mixed_dequanitze.py new file mode 100644 index 0000000..60a56e6 --- /dev/null +++ b/bitsandbytes/triton/int8_matmul_mixed_dequanitze.py @@ -0,0 +1,163 @@ +import torch +from bitsandbytes.triton.triton_utils import is_triton_available + +if not is_triton_available(): + def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None +else: + + import triton + import triton.language as tl + from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + + # This is a matmul kernel based on triton.ops.matmul + # It is modified to support rowwise quantized input and global quantized weight + # It's purpose is fused matmul then dequantize + # It does support bias. + + def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + + @triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 + }, + ) + @triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, + }) + @triton.jit + def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + w_factor = tl.load(state_w_ptr) + x_factor = tl.load(state_x_ptr + ram)[:, None] + + # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + acc = (w_factor * (x_factor * (acc * divfactor))) + acc = acc.to(C.dtype.element_ty) + + # conditionally add bias + if has_bias: + bias = tl.load(bias + rn).to(C.dtype.element_ty) + acc = acc + bias[None, :] + + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + + def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): + device = a.device + divfactor = 1. / (127. * 127.) + has_bias = 0 if bias is None else 1 + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c = torch.empty((M, N), device=device, dtype=torch.float16) + # accumulator types + ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # launch int8_matmul_mixed_dequantize kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + GROUP_M=8, ACC_TYPE=ACC_TYPE) + return c diff --git a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py new file mode 100644 index 0000000..33f4d13 --- /dev/null +++ b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py @@ -0,0 +1,164 @@ +import torch + +from bitsandbytes.triton.triton_utils import is_triton_available + +if not is_triton_available(): + def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None +else: + import triton + import triton.language as tl + from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + # This is a matmul kernel based on triton.ops.matmul + # It is modified to support rowwise quantized input and columnwise quantized weight + # It's purpose is fused matmul then dequantize + # It does support bias. + + def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + + def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + + @triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 + }, + ) + @triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, + }) + @triton.jit + def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + w_factor = tl.load(state_w_ptr + rbn)[None, :] + x_factor = tl.load(state_x_ptr + ram)[:, None] + + # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + acc = (w_factor * (x_factor * (acc * divfactor))) + acc = acc.to(C.dtype.element_ty) + + if has_bias: + bias = tl.load(bias + rn).to(C.dtype.element_ty) + acc = acc + bias[None, :] + + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + + def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): + divfactor = 1. / (127. * 127.) + + has_bias = 0 if bias is None else 1 + + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c = torch.empty((M, N), device=device, dtype=torch.float16) + # accumulator types + ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # launch int8_matmul_rowwise_dequantize kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + GROUP_M=8, ACC_TYPE=ACC_TYPE) + return c diff --git a/bitsandbytes/triton/quantize_columnwise_and_transpose.py b/bitsandbytes/triton/quantize_columnwise_and_transpose.py new file mode 100644 index 0000000..54220d9 --- /dev/null +++ b/bitsandbytes/triton/quantize_columnwise_and_transpose.py @@ -0,0 +1,74 @@ +import math +import torch +import time +from bitsandbytes.triton.triton_utils import is_triton_available + +if not is_triton_available(): + def quantize_columnwise_and_transpose(x: torch.Tensor): return None +else: + + import triton + import triton.language as tl + from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + # This kernel does fused columnwise quantization and transpose. + + # TODO: autotune this better. + @triton.autotune( + configs=[ + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_stages=16), + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=16, num_warps=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] + ) + @triton.jit + def _quantize_columnwise_and_transpose( + x_ptr, + output_ptr, + output_maxs, + n_elements, + M : tl.constexpr, N : tl.constexpr, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid + p2_arange = tl.arange(0, P2) + p2_arange_mask = p2_arange < M + arange = p2_arange * N + offsets = block_start + arange + x = tl.load(x_ptr + offsets, mask=p2_arange_mask) + abs_x = tl.abs(x) + max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0) + output = tl.libdevice.llrint(127. * (x / max_val)) + + new_start = pid * M + new_offsets = new_start + p2_arange + tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask) + tl.store(output_maxs + pid, max_val) + + def quantize_columnwise_and_transpose(x: torch.Tensor): + M, N = x.shape + output = torch.empty(N, M, device=x.device, dtype=torch.int8) + output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(M)))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2) + return output, output_maxs + diff --git a/bitsandbytes/triton/quantize_global.py b/bitsandbytes/triton/quantize_global.py new file mode 100644 index 0000000..845db6e --- /dev/null +++ b/bitsandbytes/triton/quantize_global.py @@ -0,0 +1,107 @@ +import math +import torch +import time +from bitsandbytes.triton.triton_utils import is_triton_available + +if not is_triton_available(): + def quantize_global_transpose(input): return None + def quantize_global(x: torch.Tensor): return None +else: + + import triton + import triton.language as tl + from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + # global quantize + @triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4), + triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1), + + ], + key=['n_elements'] + ) + @triton.jit + def _quantize_global( + x_ptr, + absmax_inv_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + absmax_inv = tl.load(absmax_inv_ptr) + output = tl.libdevice.llrint(127. * (x * absmax_inv)) + tl.store(output_ptr + offsets, output, mask=mask) + + def quantize_global(x: torch.Tensor): + absmax = x.abs().max().unsqueeze(0) + absmax_inv = 1./ absmax + output = torch.empty(*x.shape, device='cuda', dtype=torch.int8) + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _quantize_global[grid](x, absmax_inv, output, n_elements) + return output, absmax + + + # global quantize and transpose + @triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), + + # ... + ], + key=['M', 'N'] + ) + @triton.jit + def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N, + BLOCK_M : tl.constexpr, + BLOCK_N : tl.constexpr, + GROUP_M : tl.constexpr): + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // group_size + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an) + mask = (rm < M)[:, None] & (rn < N)[None, :] + a = tl.load(A, mask=mask) + absmax_inv = tl.load(absmax_inv_ptr) + + # rematerialize to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + + output = tl.libdevice.llrint(127. * (a * absmax_inv)) + + tl.store(B, output, mask=mask) + + def quantize_global_transpose(input): + absmax = input.abs().max().unsqueeze(0) + absmax_inv = 1./ absmax + M, N = input.shape + out = torch.empty(N, M, device='cuda', dtype=torch.int8) + + assert out.size(0) == N and out.size(1) == M + assert input.stride(0) == 1 or input.stride(1) == 1 + assert out.stride(0) == 1 or out.stride(1) == 1 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) + _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N) + return out, absmax + diff --git a/bitsandbytes/triton/quantize_rowwise.py b/bitsandbytes/triton/quantize_rowwise.py new file mode 100644 index 0000000..26d2183 --- /dev/null +++ b/bitsandbytes/triton/quantize_rowwise.py @@ -0,0 +1,68 @@ +import math +import torch +import time + +from bitsandbytes.triton.triton_utils import is_triton_available + +if not is_triton_available(): + def quantize_rowwise(x: torch.Tensor): return None +else: + + import triton + import triton.language as tl + from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + # rowwise quantize + + # TODO: autotune this better. + @triton.autotune( + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] + ) + @triton.jit + def _quantize_rowwise( + x_ptr, + output_ptr, + output_maxs, + n_elements, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + arange = tl.arange(0, P2) + offsets = block_start + arange + row_mask = arange < BLOCK_SIZE + x = tl.load(x_ptr + offsets, mask=row_mask) + + abs_x = tl.abs(x) + max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) + output = tl.libdevice.llrint(127. * (x / max_val)) + tl.store(output_ptr + offsets, output, mask=row_mask) + tl.store(output_maxs + pid, max_val) + + def quantize_rowwise(x: torch.Tensor): + output = torch.empty(*x.shape, device=x.device, dtype=torch.int8) + output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (x.shape[0],) + _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) + return output, output_maxs + diff --git a/bitsandbytes/triton/triton_utils.py b/bitsandbytes/triton/triton_utils.py new file mode 100644 index 0000000..c74c239 --- /dev/null +++ b/bitsandbytes/triton/triton_utils.py @@ -0,0 +1,4 @@ +import importlib + +def is_triton_available(): + return importlib.util.find_spec("triton") is not None diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 1cd90e3..6729f7c 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -1,7 +1,143 @@ import shlex import subprocess +import torch from typing import Tuple +def outlier_hook(module, input): + assert isinstance(module, torch.nn.Linear) + tracer = OutlierTracer.get_instance() + hvalue = tracer.get_hvalue(module.weight) + if hvalue not in tracer.hvalue2outlier_idx: + outlier_idx = find_outlier_dims(module.weight) + tracer.outliers.append(outlier_idx) + tracer.hvalues.append(hvalue) + if len(tracer.outliers) > 1: + # assign the current layer the outlier idx found from the weight + # of the previous linear layer + if tracer.outliers[-1].numel() > 0: + assert tracer.outliers[-1].max() < module.weight.shape[1] + tracer.hvalue2outlier_idx[hvalue] = tracer.outliers[-1] + + else: + # first layer, we cannot use the weight for outlier detection + # we follow a mixed approach: + # (1) zscore test of std of hidden dimension + # (2) magnitude > 6 test + merged = input[0].view(-1, input[0].shape[-1]) + # (1) zscore test of std of hidden dimension + outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3) + # (2) magnitude > 6 test + dims = (torch.abs(input[0])> 6).sum(dim=list(range(len(input[0].shape)-1))) + outlier_idx2 = torch.where(dims > 0)[0] + outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique() + tracer.hvalue2outlier_idx[hvalue] = outlier_idx + else: + for hook in tracer.hooks: + hook.remove() + + +class OutlierTracer(object): + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def initialize(self, model): + self.last_w = None + self.current_outlier_dims = None + self.hvalues = [] + self.outliers = [] + self.hvalue2outlier_idx = {} + self.initialized = True + self.hooks = [] + + for n, m in model.named_modules(): + if isinstance(m, torch.nn.Linear): + self.hooks.append(m.register_forward_pre_hook(outlier_hook)) + + def is_initialized(self): + return getattr(self, 'initialized', False) + + def get_hvalue(self, weight): + return weight.data.storage().data_ptr() + + def get_outliers(self, weight): + if not self.is_initialized(): + print('Outlier tracer is not initialized...') + return None + hvalue = self.get_hvalue(weight) + if hvalue in self.hvalue2outlier_idx: + return self.hvalue2outlier_idx[hvalue] + else: + return None + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + return cls._instance + +def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False): + if rdm: + return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long() + + m = weight.mean(reduction_dim) + mm = m.mean() + mstd = m.std() + zm = (m-mm)/mstd + + std = weight.std(reduction_dim) + stdm = std.mean() + stdstd = std.std() + + zstd = (std-stdm)/stdstd + + if topk is not None: + val, idx = torch.topk(std.abs(), k=topk, dim=0) + else: + idx = torch.where(zstd > zscore)[0] + + return idx + +def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None): + """ + Replace linear modules with a new Linear module. + + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + linear_replacement (`torch.nn.Module`): + The linear module that replaces the old one. Only expects standard arguments. + If other arguments need to be passed, use a lambda. + skip_modules (`List[str]`, *optional*, defaults to `lm_head`): + List of modules names not to convert. Defaults to `lm_head`. + copy_weights (`bool`): + Copy the weights from the old linear module to the new one + post_processing_fun_name (`str`): + A function name of the replacement linear class that is called + after processing. + """ + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function) + + if isinstance(module, torch.nn.Linear) and name not in skip_modules: + old_module = model._modules[name] + model._modules[name] = linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + ) + if copy_weights: + model._modules[name].weight = old_module.weight + model._modules[name].bias = old_module.bias + + if post_processing_function is not None: + func = getattr(module, post_processing_function, None) + if func is not None: func(module) + return model + + def execute_and_return(command_string: str) -> Tuple[str, str]: def _decode(subprocess_err_out_tuple): @@ -21,3 +157,43 @@ def execute_and_return(command_string: str) -> Tuple[str, str]: std_out, std_err = execute_and_return_decoded_std_streams(command_string) return std_out, std_err + + + +def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None): + """ + Replace linear modules with a new Linear module. + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + linear_replacement (`torch.nn.Module`): + The linear module that replaces the old one. Only expects standard arguments. + If other arguments need to be passed, use a lambda. + skip_modules (`List[str]`, *optional*, defaults to `lm_head`): + List of modules names not to convert. Defaults to `lm_head`. + copy_weights (`bool`): + Copy the weights from the old linear module to the new one + post_processing_fun_name (`str`): + A function name of the replacement linear class that is called + after processing. + """ + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function) + + if isinstance(module, torch.nn.Linear) and name not in skip_modules: + old_module = model._modules[name] + model._modules[name] = linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + ) + if copy_weights: + model._modules[name].weight = old_module.weight + model._modules[name].bias = old_module.bias + + if post_processing_function is not None: + func = getattr(module, post_processing_function, None) + if func is not None: func(module) + return model + diff --git a/compile_from_source.md b/compile_from_source.md index 9d4f89d..f5de4db 100644 --- a/compile_from_source.md +++ b/compile_from_source.md @@ -33,3 +33,8 @@ You can set `CUDA_HOME` to `/usr/local/cuda-11.7`. For example, you might be abl If you have problems compiling the library with these instructions from source, please open an issue. + +## Compilation with Kepler + +Since 0.39.1 bitsandbytes installed via pip no longer provides Kepler binaries and these need to be compiled from source. Follow the steps above and instead of `cuda11x_nomatmul` etc use `cuda11x_nomatmul_kepler` + diff --git a/csrc/kernels.cu b/csrc/kernels.cu index e1ec00d..ab12c37 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -12,12 +12,17 @@ #include #include #include +#include +#include +#include + #define HLF_MAX 65504 #define TH 1024 #define NUM 4 #define NUM_BLOCK 4096 + // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda __device__ float atomicMax(float* address, float val) { int* address_as_i = reinterpret_cast(address); @@ -43,11 +48,289 @@ __device__ float atomicMin(float* address, float val) { return __int_as_float(old); } +__device__ float dDequantizeFP4(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0110) == 0) + { + // subnormal + if((val & 0b0001) == 0) + return 0.0f; + else + return sign*0.0625f*absmax; + } + else + { + // normal + float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f); + float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f; + + return sign*exponent*fraction*absmax; + } +} + +__device__ float d2DequantizeFP4(unsigned char val) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0110) == 0) + { + // subnormal + if((val & 0b0001) == 0) + return 0.0f; + else + return sign*0.0625f; + } + else + { + // normal + float exponent = ((val & 0b0100) == 4 ? 2.0f : 8.0f) + ((val & 0b0010) == 2 ? 0.0f : 2.0f); + float fraction = (val & 0b0001) == 1 ? 1.5f : 1.0f; + + return sign*exponent*fraction; + } +} + +__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) +{ + float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f; + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 111 + return 0.25000000f*absmax*sign; // 1111 + else + return 0.16666667f*absmax*sign; // 1110 + else + if((val & 0b0001) == 1) // 110 + return 0.50000000f*absmax*sign; // 1101 + else + return 0.33333333f*absmax*sign; // 1100 + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 1.00000000f*absmax*sign; // 1011 + else + return 0.66666667f*absmax*sign; // 1010 + else + if((val & 0b0001) == 1) // 100 + return 5.208333333e-03f*absmax*sign; // 1001 + else + return 0.00000000f*absmax*sign; // 1000 +} + +__device__ unsigned char dQuantizeFP4(float x) +{ + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assum input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to noice if you add an extra + // zero somewhere! + + int sign = x < 0 ? 0b1000 : 0b0000; + x = fabsf(x); + if(x > 0.29166667f) + if( x > 0.583333f) + if( x > 0.8333333f) + return 0b0011+sign; + else + return 0b0010+sign; + else + if(x > 0.4166667f) + return 0b101+sign; + else + return 0b100+sign; + else + if(x > 0.0859375f) + if(x > 0.20833333f) + return 0b0111+sign; + else + return 0b0110+sign; + else + if(x > 0.00260417f) + return 0b0001+sign; + else + return 0b0000+sign; +} + +__device__ half dhDequantizeNF4(unsigned char val) +{ + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +__device__ float dDequantizeNF4(unsigned char val) +{ + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if((val & 0b1000) == 8) + if((val & 0b0100) == 4) // 1 + if((val & 0b0010) == 2) // 11 + if((val & 0b0001) == 1) // 111 + return 1.0f; + else + return 0.7229568362236023f; + else + if((val & 0b0001) == 1) // 110 + return 0.5626170039176941f; + else + return 0.44070982933044434f; + else + if((val & 0b0010) == 2) //10 + if((val & 0b0001) == 1) // 101 + return 0.33791524171829224f; + else + return 0.24611230194568634f; + else + if((val & 0b0001) == 1) // 100 + return 0.16093020141124725f; + else + return 0.07958029955625534f; + + else + if((val & 0b0100) == 4) // 0 + if((val & 0b0010) == 2) //01 + if((val & 0b0001) == 1) // 011 + return 0.0f; + else + return -0.09105003625154495f; + else + if((val & 0b0001) == 1) // 010 + return -0.18477343022823334f; + else + return -0.28444138169288635f; + else + if((val & 0b0010) == 2) //00 + if((val & 0b0001) == 1) // 001 + return -0.39491748809814453f; + else + return -0.5250730514526367f; + else + if((val & 0b0001) == 1) // 000 + return -0.6961928009986877f; + else + return -1.0f; + +} + +__device__ unsigned char dQuantizeNF4(float x) +{ + + // the values for this tree was generated by test_normal_map_tree + // in the file tests/test_functional.py + if(x > 0.03979014977812767f) + if(x > 0.3893125355243683f) // 1 + if(x > 0.6427869200706482f) // 11 + if(x > 0.8614784181118011f) // 111 + return 0b1111; + else + return 0b1110; + else + if(x > 0.5016634166240692f) // 110 + return 0b1101; + else + return 0b1100; + else + if(x > 0.2035212516784668f) // 10 + if(x > 0.2920137718319893f) // 101 + return 0b1011; + else + return 0b1010; + else + if(x > 0.1202552504837513f) // 100 + return 0b1001; + else + return 0b1000; + else + if(x > -0.33967943489551544f) // 0 + if(x > -0.13791173323988914f) // 01 + if(x > -0.045525018125772476f) // 011 + return 0b0111; + else + return 0b0110; + else + if(x > -0.23460740596055984f) // 010 + return 0b0101; + else + return 0b0100; + else + if(x > -0.6106329262256622f) // 00 + if(x > -0.4599952697753906f) // 001 + return 0b0011; + else + return 0b0010; + else + if(x > -0.8480964004993439f) // 000 + return 0b0001; + else + return 0b0000; +} // sign function for lion // taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA -template -__device__ int sgn(T val) { +template __device__ int sgn(T val) +{ return (T(0) < val) - (val < T(0)); } @@ -435,7 +718,7 @@ __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned c } } -template +template //__launch_bounds__(TH, 4) __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) { @@ -445,13 +728,13 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float T vals[NUM_PER_TH]; float rand_vals[NUM_PER_TH]; - unsigned char qvals[NUM_PER_TH]; + unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH]; //float local_abs_max = -FLT_MAX; float local_abs_max = 0.0f; int local_rand_idx = 0; typedef cub::BlockLoad LoadT; - typedef cub::BlockStore StoreChar; + typedef cub::BlockStore 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar; typedef cub::BlockReduce BlockReduce; typedef cub::BlockLoad LoadFloat; @@ -462,8 +745,9 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float __shared__ float smem_code[256]; __shared__ float smem_absmax_value[1]; - for(int i = threadIdx.x; i < 256; i+=blockDim.x) - smem_code[i] = code[i]; + if(DATA_TYPE == General8bit) + for(int i = threadIdx.x; i < 256; i+=blockDim.x) + smem_code[i] = code[i]; for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) { @@ -503,62 +787,111 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); } - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) + unsigned char packed_4bit = 0; + switch(DATA_TYPE) { - if(!STOCHASTIC) - qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); - else - qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + case General8bit: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + if(!STOCHASTIC) + qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); + else + qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + } + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; + packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); + qvals[j] = packed_4bit; + } + break; } __syncthreads(); - StoreChar(storec).Store(&(out[i]), qvals, valid_items); + StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items); } } -template -__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n) +template +__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n) { - const int n_full = gridDim.x * BLOCK_SIZE; - int valid_items = 0; - const int base_idx = (blockIdx.x * BLOCK_SIZE); + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); - T vals[NUM_PER_TH]; + T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; unsigned char qvals[NUM_PER_TH]; float local_abs_max = -FLT_MAX; typedef cub::BlockLoad LoadChar; - typedef cub::BlockStore StoreT; + typedef cub::BlockStore 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; __shared__ typename LoadChar::TempStorage loadchar; __shared__ typename StoreT::TempStorage storet; - //__shared__ float smem_code[256]; - //float local_code[16]; - //if(threadIdx.x < 256) - //smem_code[threadIdx.x] = code[threadIdx.x]; - - for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) { - valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; - local_abs_max = absmax[i/BLOCK_SIZE]; + if(DATA_TYPE > 0) + { + valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i; + valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2; + } + else + { + valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i; + valid_items_store = n - i > TILE_SIZE ? TILE_SIZE : n - i; + } + local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]); - __syncthreads(); - LoadChar(loadchar).Load(&(A[i]), qvals, valid_items, 128); + __syncthreads(); + LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); - // load code through read-only cache via __ldg - #pragma unroll NUM_PER_TH - for(int j = 0; j < NUM_PER_TH; j++) - vals[j] = __ldg(&code[qvals[j]])*local_abs_max; + switch(DATA_TYPE) + { + case General8bit: + // load code through read-only cache via __ldg + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + vals[j] = __ldg(&code[qvals[j]])*local_abs_max; + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); + vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max; + vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max; + } + break; + } - __syncthreads(); - StoreT(storet).Store(&(out[i]), vals, valid_items); + __syncthreads(); + StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store); } } - __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n) { const unsigned int numThreads = blockDim.x * gridDim.x; @@ -1460,6 +1793,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char unsigned char c1s[N_PER_TH]; unsigned char c2s[N_PER_TH]; T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; typedef cub::BlockLoad LoadT; typedef cub::BlockLoad LoadChar; @@ -1521,16 +1855,24 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { - g_val = float(g_vals[j]); - g_val *= gnorm_scale; - if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) { + s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; + g_val = g_vals[j]; + //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); + //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; + g_val *= gnorm_scale; + + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); - - s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; - s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); } + else + { + s1_vals[j] = 0.0f; + s2_vals[j] = 0.0f; + } new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); @@ -1561,22 +1903,23 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char } __syncthreads(); - LoadT(temp_storage.loadh).Load(&(p[i]), g_vals, valid_items, (T)0.0f); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); // reduce: 2.67/1.69 -> 2.67/1.70 # pragma unroll N_PER_TH for(unsigned int j = 0; j < N_PER_TH; j++) { - if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + //if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) { - g_vals[j] = (T)(((float)g_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); + p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); if(weight_decay > 0.0f) - g_vals[j] = ((float)g_vals[j])*(1.0f-(lr*weight_decay)); + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); } } // store: 0.85/1.44 -> 2.48/1.57 __syncthreads(); - StoreT(temp_storage.storeh).Store(&(p[i]), g_vals, valid_items); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); // quantizaztion: 2.67/1.70 -> 3.4/3.3 # pragma unroll N_PER_TH @@ -2496,7 +2839,7 @@ template @@ -2575,7 +2918,6 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o // 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; if(idx >= colsB){ break; } - //printf("%i %i\n", (row_offset+idx) % num_items, row_offset+idx); if((idx+num_items < colsB)) { if(BITS == 8) @@ -2595,15 +2937,13 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o #pragma unroll num_items for(int k = 0; k < num_items; k++) { - //if((float)local_valsB[k] != 0.0) - // printf("%f %i %i %i\n", (float)local_valsB[k], k, idx, colsB); if(BITS == 8 && dequant_stats != NULL) // we do texture cache reads (__ldg) on dequant_stats which should be super fast { float valB = local_valsB[k]; float valA = local_valA[i]; if(valB != 0.0 && valA != 0.0) - local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*C*valB*valA; + local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA; } else local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i]; @@ -2709,10 +3049,587 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * } } + +//template __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB) +//{ +//// element-wise kernel +//// 1. Load batch x k into registers +//// 2. Load k x k into registers +//// 3. dequantize and store in second pair of k x k +//// 4. matmul +//// 5. sum with cub +//// 6. store outputs +//// TC kernel +//// use k warps per thread block +//// 1. threadblock use read-only cache to read in register tile for A into shared memory +//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments +//// 3. each warp reads a segment of values 16x32 from B +//// 4. do dequantization from register of B into second pair of registers +//// 5. store (4) into fragment +//// 6. matmul aggregate into fragment C +//// 7. aggreecate files of C into shared memroy block C +//// 8. sum (7) +//// 9. write outputs to matmul output matrix +//} + +template __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f) +{ + if(limit_base + ITEMS <= limit) + reinterpret_cast(local)[0] = reinterpret_cast(buffer)[idx/ITEMS]; + else + { + for(int k = 0; k < ITEMS; k++) + { + if(limit_base + k < limit) + local[k] = buffer[idx+k]; + else + local[k] = (T)zero_value; + } + } +} + +#define WARPS 5 +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) +{ + +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + const int val_per_iter = blockDim.x-32; + + T local_A[4]; + T local_B[128]; + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + //__shared__ T smem_C[8*32]; + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + + __syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + //local_A[0] = A[idx]; + + //#pragma unroll 32 + //for(int col = 0; col < 32; col++) + // local_B[col] = B[(col_offset+col)*ldb+idx]; + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_A[warp_lane]; +#endif +} + +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + +#if __CUDA_ARCH__ >= 750 + using namespace nvcuda; + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + + T local_A[2]; + T local_B[64]; + unsigned char local_B_4bit[32]; + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + //__shared__ T smem_C[8*32]; + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); + local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); + } + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + + __syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx]; + } + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + int absidx = (idx + col_offset)/blocksize; + half local_absmax = __ldg(&(absmax[absidx])); + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); + local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); + } + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_A[warp_lane]; +#endif +} + +//#define ROWS 2 +//template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc) +//{ +//// 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp +//// 1. Load dataB into register +//// 2. Dequantize B +//// 3. Fetch data from A and multiply +// +// typedef cub::BlockLoad LoadA; +// //__shared__ typename LoadA::TempStorage loada; +// typedef cub::BlockLoad LoadB; +// //__shared__ typename LoadB::TempStorage loadb; +// typedef cub::BlockReduce BlockReduce; +// // Allocate shared memory for BlockReduce +// //__shared__ typename BlockReduce::TempStorage reduce; +// +// __shared__ union { +// typename BlockReduce::TempStorage reduce; +// typename LoadB::TempStorage loadb; +// typename LoadA::TempStorage loada; +// } temp_storage; +// +// +// T dataA[ITEMS]; +// T local_B[ITEMS]; +// T local_accC[ROWS]; +// int valid_items = 0; +// const int col_offset = blockIdx.x * 8; +// +// __shared__ T tileA[ROWS*THREADS*ITEMS]; +// __shared__ T accumulatorC[ROWS*8]; +// +// //#pragma unroll 8 +// //for(int i = 0; i < 8; i++) +// // tileA[threadIdx.x + (i*256)] = 0.0f; +// //__syncthreads(); +// if(threadIdx.x < 64) +// accumulatorC[threadIdx.x] = 0.0f; +// __syncthreads(); +// +// +// for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS) +// { +// valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx; +// int baserow = 0; +// for(int row = baserow; row < (baserow+ROWS) && row < N; row++) +// { +// LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f); +// +// #pragma unroll ITEMS +// for(int k = 0; k < ITEMS; k++) +// tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k]; +// +// __syncthreads(); +// } +// baserow += ROWS; +// +// // load 16 columns from B at a time. B is transposed, so its like loading rows +// // each warp loads one row +// // each thread loads 128 byte +// +// // col: inner_idx + warp_lane +// // row: ldb*(offset + warp_id) +// for(int col = 0; col < 8 && (col_offset + col) < M; col++) +// { +// int colB = col_offset + col; +// +// for(int k = 0; k < ROWS; k++) +// local_accC[k] = 0.0f; +// +// int base_idxB = ldb*colB; +// valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx; +// LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f); +// __syncthreads(); +// +// for(int row = 0; row < ROWS && row < N; row++) +// { +// #pragma unroll ITEMS +// for(int k = 0; k < ITEMS; k++) +// { +// int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k); +// local_accC[row] += tileA[idxA]*local_B[k]; +// } +// +// local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum()); +// if(threadIdx.x == 0) +// atomicAdd(&accumulatorC[row*8 + col], local_accC[row]); +// } +// } +// } +// +// for(int row = 0; row < ROWS && row < N; row++) +// { +// int out_idx = ldc*row + col_offset; +// +// //if(threadIdx.x < 8) +// // if(accumulatorC[row*8 + threadIdx.x] != 0.0) +// // printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x); +// +// if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M) +// { +// //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx); +// out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x]; +// } +// } +// +// +// +//} + + +template __global__ void kfunc(T *A, T *B, T value, long n) +{ + for(long i = (blockDim.x*blockIdx.x) + threadIdx.x; i < n; i+=(blockDim.x*gridDim.x)) + { + switch(FUNC) + { + case FILL: + A[i] = (T)value; + break; + case ARANGE: + A[i] = (T)i; + break; + case _MUL: + A[i] = A[i]*B[i]; + break; + } + } +} + + //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(unsigned char *A, unsigned char *B, unsigned char value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); + +// these are not used and make no sense, but the compiler needs them +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +// these are not used and make no sense, but the compiler needs them + +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); + +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); + template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); @@ -2753,6 +3670,7 @@ MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) MAKE_PreconditionOptimizer32bit1State(LION, half) MAKE_PreconditionOptimizer32bit1State(LION, float) +MAKE_PreconditionOptimizer32bit1State(LION, __nv_bfloat16) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) @@ -2766,6 +3684,7 @@ MAKE_Optimizer32bit1State(RMSPROP, half) MAKE_Optimizer32bit1State(RMSPROP, float) MAKE_Optimizer32bit1State(LION, half) MAKE_Optimizer32bit1State(LION, float) +MAKE_Optimizer32bit1State(LION, __nv_bfloat16) MAKE_Optimizer32bit1State(ADAGRAD, half) MAKE_Optimizer32bit1State(ADAGRAD, float) @@ -2775,12 +3694,15 @@ template __global__ void kPreconditionOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, +template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); #define MAKE_PreconditionStatic8bit1State(oname, gtype) \ @@ -2851,39 +3773,60 @@ MAKE_optimizerStatic8bit2State(ADAM, float) template __global__ void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n); template __global__ void kPercentileClipping(half * __restrict__ g, float *gnorm_vec, int step, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kQuantizeBlockwise(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); - -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int n); +#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ +template __global__ void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \ +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ template __global__ void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ @@ -2896,6 +3839,8 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise( \ @@ -2913,5 +3858,6 @@ MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8) diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index a8aa3fc..30faf4a 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -9,13 +9,15 @@ #ifndef kernels #define kernels +//template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); + template__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n); __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); -template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); -template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int n); +template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); template __global__ void kPreconditionOptimizer32bit2State(T* g, T* p, @@ -120,4 +122,9 @@ template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kfunc(T *A, T *B, T value, long n); + #endif diff --git a/csrc/ops.cu b/csrc/ops.cu index 94d5f2e..9c042fa 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -50,54 +50,53 @@ void dequantize(float *code, unsigned char *A, float *out, int n) CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) { int num_blocks = n/blocksize; num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; - if(STOCHASTIC == 1) - assert(blocksize == 4096); if(blocksize == 4096) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 2048) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 1024) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 512) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 256) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 128) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); else if(blocksize == 64) - kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) { int num_blocks = n/blocksize; num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; - if(blocksize == 4096) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 2048) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 1024) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 512) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 256) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 128) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); - else if(blocksize == 64) - kDequantizeBlockwise<<>>(code, A, absmax, out, n); + int tile_size = (DATA_TYPE > 0) ? 1024 : 512; + + if(DATA_TYPE > 0) + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n); + else + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } + +//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB) +//{ +// int num_blocks = (colsB+32-1)/32; +// kMatmul_inference_4bit<<>>(A, B, out, lda, ldb, rowsA, colsA, colsB); +// CUDA_CHECK_RETURN(cudaPeekAtLastError()); +//} + + template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, const float beta1, const float beta2, const float eps, const float weight_decay, @@ -683,10 +682,73 @@ template void extractOutliers(char * A, int *idx, char *out, int id CUDA_CHECK_RETURN(cudaPeekAtLastError()); } + + + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) +{ + + int num_blocks = (m+31)/32; + + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + //if(bits == 32) + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + if(bits == 16) + //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + gemm_device<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); +} + +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+31)/32; + + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); +} + +template void func(T *A, T *B, T value, long n) +{ + int threads = 512; + int blocks = n/threads; + blocks = n % threads == 0 ? blocks : blocks + 1; + blocks = blocks > 65535 ? 65535 : blocks; + kfunc<<>>(A, B, value, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + //============================================================== // TEMPLATE DEFINITIONS //============================================================== +template void func(float *A, float *B, float value, long n); +template void func(unsigned char *A, unsigned char *B, unsigned char value, long n); +template void func(float *A, float *B, float value, long n); +template void func(float *A, float *B, float value, long n); + +template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +//template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); +template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); @@ -710,12 +772,20 @@ template void transformRowToFormat(char * A, char *out, int rows, template void estimateQuantiles(half *A, float *code, float offset, int n); template void estimateQuantiles(float *A, float *code, float offset, int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); #define MAKE_optimizer32bit(name, gtype) \ template void optimizer32bit(gtype* g, gtype* p, \ @@ -725,12 +795,14 @@ template void optimizer32bit(gtype* g, gtype* p, \ MAKE_optimizer32bit(ADAM, half) MAKE_optimizer32bit(ADAM, float) +MAKE_optimizer32bit(ADAM, __nv_bfloat16) MAKE_optimizer32bit(MOMENTUM, half) MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, float) MAKE_optimizer32bit(LION, half) MAKE_optimizer32bit(LION, float) +MAKE_optimizer32bit(LION, __nv_bfloat16) MAKE_optimizer32bit(ADAGRAD, half) MAKE_optimizer32bit(ADAGRAD, float) @@ -766,8 +838,11 @@ MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); MAKE_optimizerStatic8bitBlockwise(half, LION); MAKE_optimizerStatic8bitBlockwise(float, LION); +MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION); MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); + +MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 9f06435..5b9a32b 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -20,6 +20,11 @@ #include #include +#include +#include + + + #define CUDA_CHECK_RETURN(value) { \ cudaError_t _m_cudaStat = value; \ if (_m_cudaStat != cudaSuccess) { \ @@ -82,6 +87,20 @@ typedef enum Transform_t COL_AMPERE = 4, } Transform_t; +typedef enum DataType_t +{ + General8bit = 0, + FP4 = 1, + NF4 = 2, +} DataType_t; + +typedef enum Funcs_t +{ + FILL = 0, + ARANGE = 1, + _MUL = 2, +} Funcs_t; + class Context { public: @@ -129,8 +148,8 @@ template void estimateQuantiles(T *A, float *code, float offset, in void quantize(float *code, float *A, unsigned char *out, int n); void dequantize(float *code, unsigned char *A, float *out, int n); -template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); template void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, @@ -177,4 +196,11 @@ template void spmm_coo_very_sparse_naive(int *max_count, template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); +void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); + +template void func(T *A, T *B, T value, long n); + #endif diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 4caa7e8..23a0364 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -20,8 +20,25 @@ void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimat void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } +//void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) +//{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 32); } +void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc) +{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 16); } + +void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) +{ gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + +#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ +void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func(A, B, value, n); } \ + +MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL) +MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL) +MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) +MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) + + #define MAKE_FUNC32(fname, oname, gtype, gbits) \ -void fname##32bit_g##gbits(gtype *g, gtype *p, \ +void fname##32bit_grad_##gbits(gtype *g, gtype *p, \ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \ @@ -29,17 +46,19 @@ void fname##32bit_g##gbits(gtype *g, gtype *p, \ MAKE_FUNC32(momentum, MOMENTUM, float, 32) MAKE_FUNC32(momentum, MOMENTUM, half, 16) -MAKE_FUNC32(adam, ADAM, float, 32) -MAKE_FUNC32(adam, ADAM, half, 16) +MAKE_FUNC32(adam, ADAM, float, fp32) +MAKE_FUNC32(adam, ADAM, half, fp16) +MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16) MAKE_FUNC32(rmsprop, RMSPROP, float, 32) MAKE_FUNC32(rmsprop, RMSPROP, half, 16) -MAKE_FUNC32(lion, LION, float, 32) -MAKE_FUNC32(lion, LION, half, 16) +MAKE_FUNC32(lion, LION, float, fp32) +MAKE_FUNC32(lion, LION, half, fp16) +MAKE_FUNC32(lion, LION, __nv_bfloat16, bf16) MAKE_FUNC32(adagrad, ADAGRAD, float, 32) MAKE_FUNC32(adagrad, ADAGRAD, half, 16) #define MAKE_FUNC8(fname, oname, gtype, gbits) \ -void fname##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ +void fname##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ float *unorm, float max_unorm, float param_norm, \ float beta1, float beta2, \ float eps, int step, float lr, \ @@ -61,33 +80,42 @@ MAKE_FUNC8(lion, LION, float, 32) MAKE_FUNC8(lion, LION, half, 16) #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \ -void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ +void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\ { optimizerStatic8bitBlockwise(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\ -MAKE_BLOCKWISE8(adam, ADAM, half, 16) -MAKE_BLOCKWISE8(adam, ADAM, float, 32) -MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16) -MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32) -MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16) -MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32) -MAKE_BLOCKWISE8(lion, LION, half, 16) -MAKE_BLOCKWISE8(lion, LION, float, 32) -MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16) -MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32) +MAKE_BLOCKWISE8(adam, ADAM, half, fp16) +MAKE_BLOCKWISE8(adam, ADAM, float, fp32) +MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16) +MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32) +MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16) +MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32) +MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16) +MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32) +MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) +MAKE_BLOCKWISE8(lion, LION, half, fp16) +MAKE_BLOCKWISE8(lion, LION, float, fp32) +MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16) void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } -void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } -void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } -void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } -void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise(code, A, absmax, out, rand, rand_offset, 4096, n); } +void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } + +void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } -void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ -void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } #define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \ @@ -148,32 +176,41 @@ extern "C" void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); } void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } - void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); } - void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); } void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } + #define MAKE_CFUNC32(name, gtype, gbits) \ - void c##name##32bit_g##gbits(gtype *g, gtype *p, \ + void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \ float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \ - { name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ + { name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ - MAKE_CFUNC32(adam, float, 32) - MAKE_CFUNC32(adam, half, 16) + MAKE_CFUNC32(adam, float, fp32) + MAKE_CFUNC32(adam, half, fp16) + MAKE_CFUNC32(adam, __nv_bfloat16, bf16) MAKE_CFUNC32(momentum, float, 32) MAKE_CFUNC32(momentum, half, 16) MAKE_CFUNC32(rmsprop, float, 32) MAKE_CFUNC32(rmsprop, half, 16) - MAKE_CFUNC32(lion, float, 32) - MAKE_CFUNC32(lion, half, 16) + MAKE_CFUNC32(lion, float, fp32) + MAKE_CFUNC32(lion, half, fp16) + MAKE_CFUNC32(lion, __nv_bfloat16, bf16) MAKE_CFUNC32(adagrad, float, 32) MAKE_CFUNC32(adagrad, half, 16) #define MAKE_CFUNC8(name, gtype, gbits) \ - void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ + void c##name##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ float *unorm, float max_unorm, float param_norm, \ float beta1, float beta2, \ float eps, int step, float lr, \ @@ -181,7 +218,7 @@ extern "C" float* max1, float* max2, float* new_max1, float* new_max2, \ float weight_decay, float gnorm_scale, int n) \ { \ - name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ + name##_static_8bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ } \ @@ -195,22 +232,23 @@ extern "C" MAKE_CFUNC8(lion, half, 16) #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ - void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ + void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \ - { fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ - - MAKE_CBLOCKWISE8(adam, ADAM, half, 16) - MAKE_CBLOCKWISE8(adam, ADAM, float, 32) - MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, 16) - MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32) - MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16) - MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32) - MAKE_CBLOCKWISE8(lion, LION, half, 16) - MAKE_CBLOCKWISE8(lion, LION, float, 32) - MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16) - MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32) + { fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ + MAKE_CBLOCKWISE8(adam, ADAM, half, fp16) + MAKE_CBLOCKWISE8(adam, ADAM, float, fp32) + MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16) + MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32) + MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16) + MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32) + MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) + MAKE_CBLOCKWISE8(lion, LION, half, fp16) + MAKE_CBLOCKWISE8(lion, LION, float, fp32) + MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16) void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } @@ -298,6 +336,38 @@ extern "C" void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); } void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); } + //void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) + //{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); } + + void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc) + { gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); } + + void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) + { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + + void *cget_managed_ptr(size_t bytes) + { + void *ptr; + CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost)); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + + return ptr; + } + + void cprefetch(void *ptr, size_t bytes, int device) + { + CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0)); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + + #define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ + void c##fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ fname##_##type_name(A, B, value, n); } \ + + CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL) + CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL) + CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) + CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) + #endif void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } diff --git a/deploy.sh b/deploy.sh index 24d6cbf..a2257a2 100644 --- a/deploy.sh +++ b/deploy.sh @@ -139,17 +139,6 @@ if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then fi -make clean -export CUDA_HOME=$BASE_PATH/cuda-10.2 -make cuda10x_nomatmul CUDA_VERSION=102 - -if [ ! -f "./bitsandbytes/libbitsandbytes_cuda102_nocublaslt.so" ]; then - # Control will enter here if $DIRECTORY doesn't exist. - echo "Compilation unsuccessul!" 1>&2 - exit 64 -fi - - make clean export CUDA_HOME=$BASE_PATH/cuda-11.0 make cuda110_nomatmul CUDA_VERSION=110 diff --git a/examples/int8_inference_huggingface.py b/examples/int8_inference_huggingface.py new file mode 100644 index 0000000..dc80a44 --- /dev/null +++ b/examples/int8_inference_huggingface.py @@ -0,0 +1,27 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +MAX_NEW_TOKENS = 128 +model_name = 'decapoda-research/llama-7b-hf' + +text = 'Hamburg is in which country?\n' +tokenizer = AutoTokenizer.from_pretrained(model_name) +input_ids = tokenizer(text, return_tensors="pt").input_ids + +free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3) +max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' + +n_gpus = torch.cuda.device_count() +max_memory = {i: max_memory for i in range(n_gpus)} + +model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map='auto', + load_in_8bit=True, + max_memory=max_memory +) +generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS) +print(tokenizer.decode(generated_ids[0], skip_special_tokens=True)) + + + diff --git a/setup.py b/setup.py index b023c0b..51e747c 100644 --- a/setup.py +++ b/setup.py @@ -18,10 +18,10 @@ def read(fname): setup( name=f"bitsandbytes", - version=f"0.38.0", + version=f"0.39.1", author="Tim Dettmers", author_email="dettmers@cs.washington.edu", - description="8-bit optimizers and matrix multiplication routines.", + description="k-bit optimizers and matrix multiplication routines.", license="MIT", keywords="gpu optimizers optimization 8-bit quantization compression", url="https://github.com/TimDettmers/bitsandbytes", diff --git a/tests/test_autograd.py b/tests/test_autograd.py index c67126d..803fde1 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -97,7 +97,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_allclose( + torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1 ) if req_grad[1]: @@ -106,7 +106,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): assert (idx == 0).sum().item() < n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() < n * 0.02 - torch.testing.assert_allclose( + torch.testing.assert_close( gradB1, gradB2, atol=0.18, rtol=0.3 ) @@ -135,7 +135,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): n = out_bnb.numel() idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) assert (idx == 0).sum().item() < n * 0.01 - torch.testing.assert_allclose( + torch.testing.assert_close( out_bnb, out_torch, atol=0.027, rtol=0.2 ) @@ -159,7 +159,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_allclose( + torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1 ) if req_grad[1]: @@ -218,7 +218,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): B.grad = None if req_grad[0]: - torch.testing.assert_allclose( + torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1 ) if req_grad[1]: @@ -239,8 +239,8 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist() dim2.append(0) decomp = [0.0, 6.0] -funcs = [(torch.matmul, bnb.matmul)] -str_funcs = ["matmul"] +funcs = [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)] +str_funcs = ["matmullt", 'switchback_bnb'] req_grad = [(False, False), (True, False), (True, True), (False, True)] req_grad = list(product([True, False], repeat=3)) req_grad_str = [] @@ -407,7 +407,7 @@ def test_matmullt( bias.grad = None if req_grad[0]: - torch.testing.assert_allclose( + torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1 ) if req_grad[1]: @@ -423,9 +423,204 @@ def test_matmullt( assert (idx == 0).sum().item() <= n * 0.1 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) assert (idx == 0).sum().item() <= n * 0.02 - torch.testing.assert_allclose( + torch.testing.assert_close( gradB1, gradB2, atol=0.18, rtol=0.3 ) if req_grad[2]: - torch.testing.assert_allclose(gradBias1, gradBias2) + torch.testing.assert_close(gradBias1, gradBias2) + + +n = 1 +k = 3 +dim1 = torch.randint(16, 64, size=(n,)).tolist() +dim2 = torch.randint(32, 96, size=(n,)).tolist() +dim3 = torch.randint(32, 96, size=(n,)).tolist() +dim4 = torch.randint(32, 96, size=(n,)).tolist() + +dim2.append(0) + +funcs = [(torch.matmul, bnb.matmul_4bit)] +str_funcs = ["matmul"] +req_grad = list(product([True, False], repeat=3)) +req_grad_str = [] +for c in req_grad: + strval = '' + for v in c: + if v == True: strval += 'T' + else: strval += 'F' + req_grad_str.append(strval) + +transpose = [(False, True), (False, False)] +str_transpose = ["NT", "NN"] +dtype = [torch.float16, torch.float32] +compress_statistics = [False, True] +has_fp16_weights = [True, False] +has_bias = [True, False] +quant_type = ['fp4', 'nf4'] +values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type)) +str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics, quant_type)) +names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}".format(*vals) for vals in str_values] +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type", values, ids=names) +def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type): + dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) + dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) + if has_bias == False: + req_grad = list(req_grad) + req_grad[2] = False + + for i in range(k): + # normal multiply + if funcs[0] in [torch.mm, torch.matmul]: + A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) + B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype) + target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype) + bias = None + bias2 = None + if has_bias: + bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) + bias2 = bias.clone() + torch.nn.init.xavier_uniform_(B) + + B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type) + + if not transpose[0] and transpose[1]: + out_torch = funcs[0](A, B.t()) + out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2) + elif not transpose[0] and not transpose[1]: + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B2, quant_state, bias=bias2) + + if has_bias: + out_torch += bias + + assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}" + + n = out_bnb.numel() + err = torch.abs(out_bnb - out_torch).float().mean().item() + if n > 0: + assert err < 0.115 + + #assert err < 0.20 + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + if has_bias: + gradBias1 = bias.grad + bias.grad = None + + loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + if has_bias: + gradBias2 = bias.grad + bias.grad = None + + if req_grad[0]: + torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1) + + if req_grad[2]: + torch.testing.assert_close(gradBias1, gradBias2) + + +funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)] +str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global'] +req_grad = list(product([True, False], repeat=3)) +req_grad_str = [] +for c in req_grad: + strval = '' + for v in c: + if v == True: strval += 'T' + else: strval += 'F' + req_grad_str.append(strval) + +transpose = [(False, True), (False, False)] +str_transpose = ["NT", "NN"] +dtype = [torch.float16, torch.float32] +has_fp16_weights = [True, False] +values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)) +str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose)) +names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values] +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names) +def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): + dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) + dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) + req_grad = list(req_grad) + req_grad[2] = False + + for i in range(k): + # normal multiply + if funcs[0] in [torch.mm, torch.matmul]: + A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) + B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype) + target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype) + + torch.nn.init.xavier_uniform_(B) + + fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device) + bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device) + + if not transpose[0] and transpose[1]: + out_torch = funcs[0](A, B.t()) + out_bnb = funcs[1](A, B.t(), fw_code, bw_code) + elif not transpose[0] and not transpose[1]: + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B, fw_code, bw_code) + + assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}" + + n = out_bnb.numel() + err = torch.abs(out_bnb - out_torch).float().mean().item() + if n > 0: + assert err < 0.115 + #assert err < 0.20 + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + + loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + + if req_grad[0]: + torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1) + + if req_grad[1]: + n = gradB1.numel() + if dim2 > 0: + assert torch.abs(gradB1).sum() > 0.0 + assert torch.abs(gradB2).sum() > 0.0 + else: + assert torch.abs(gradB1).sum() == 0.0 + assert torch.abs(gradB2).sum() == 0.0 + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + + assert (idx == 0).sum().item() <= n * 0.1 + idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) + assert (idx == 0).sum().item() <= n * 0.02 + grad_err = (gradB1-gradB2).abs().mean() + assert grad_err.item() < 0.003 + torch.testing.assert_close( + gradB1, gradB2, atol=0.18, rtol=0.3 + ) + diff --git a/tests/test_functional.py b/tests/test_functional.py index 69c200a..cc58324 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -18,12 +18,15 @@ torch.set_printoptions( k = 20 -def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0): +def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True): idx = torch.isclose(a, b, rtol, atol) sumval = (idx == 0).sum().item() if sumval > count: - print(f"Too many values not close: assert {sumval} < {count}") - torch.testing.assert_allclose(a, b, rtol, atol) + if throw: + print(f"Too many values not close: assert {sumval} < {count}") + torch.testing.assert_close(a, b, rtol, atol) + + return sumval class FFN(torch.nn.Module): @@ -97,7 +100,7 @@ def test_estimate_quantiles(dtype): code = F.estimate_quantiles(A) percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device) - torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2) + torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2) A = torch.randn(1024, 1024, device="cuda") A = A.to(dtype) @@ -122,7 +125,7 @@ def test_quantile_quantization(): C = F.quantize_no_absmax(A1, code) A2 = F.dequantize_no_absmax(C, code) diff = torch.abs(A1 - A2).mean().item() - torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0) + torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0) assert diff < 0.001 @@ -146,63 +149,49 @@ def test_dynamic_quantization(): C, S = F.quantize(A1) A2 = F.dequantize(C, S) diff = torch.abs(A1 - A2).mean().item() - torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) + torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) assert diff < 0.004 -def test_dynamic_blockwise_quantization(): + +@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"]) +@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) +def test_dynamic_blockwise_quantization(nested, blocksize): #print('') - for blocksize in [4096, 2048, 1024, 512]: - diffs = [] - reldiffs = [] - for i in range(100): - A1 = torch.randn(1024, 1024, device="cuda") - C, S = F.quantize_blockwise(A1, blocksize=blocksize) - A2 = F.dequantize_blockwise(C, S, blocksize=blocksize) - diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) - diffs.append(diff.mean().item()) - reldiffs.append(reldiff.mean().item()) - abserr = sum(diffs)/len(diffs) - relerr = sum(reldiffs)/len(reldiffs) - assert abserr < 0.011 - assert relerr < 0.018 - #print('randn', blocksize, sum(diffs)/len(diffs)) - #print('randn', blocksize, sum(reldiffs)/len(reldiffs)) - - diffs = [] - for i in range(100): - A1 = torch.rand(1024, 1024, device="cuda") - C, S = F.quantize_blockwise(A1, blocksize=blocksize) - A2 = F.dequantize_blockwise(C, S, blocksize=blocksize) - diff = torch.abs(A1 - A2) - reldiff = diff / torch.abs(A1 + 1e-8) - diffs.append(diff.mean().item()) - reldiffs.append(reldiff.mean().item()) - #torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0) - abserr = sum(diffs)/len(diffs) - relerr = sum(reldiffs)/len(reldiffs) - assert abserr < 0.0035 - assert relerr < 0.015 - #print('rand', blocksize, sum(diffs)/len(diffs)) - #print('rand', blocksize, sum(reldiffs)/len(reldiffs)) - - -def test_dynamic_blockwise_stochastic_quantization(): diffs = [] reldiffs = [] - rand = torch.rand(1024).cuda() for i in range(100): A1 = torch.randn(1024, 1024, device="cuda") - C1, S1 = F.quantize_blockwise(A1, rand=rand) - C2, S2 = F.quantize_blockwise(A1) - # a maximunm distance of quantized values of 1 - torch.testing.assert_allclose(C1, C2, atol=1, rtol=0) - fraction_smaller = (C1 < C2).float().sum() / C1.numel() - fraction_larger = (C1 > C2).float().sum() / C1.numel() - torch.testing.assert_allclose( - fraction_larger, fraction_smaller, atol=0.01, rtol=0 - ) + C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + abserr = sum(diffs)/len(diffs) + relerr = sum(reldiffs)/len(reldiffs) + assert abserr < 0.011 + assert relerr < 0.018 + #print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs)) + #print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs)) + + diffs = [] + for i in range(100): + A1 = torch.rand(1024, 1024, device="cuda") + C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) + abserr = sum(diffs)/len(diffs) + relerr = sum(reldiffs)/len(reldiffs) + assert abserr < 0.0035 + assert relerr < 0.015 + #print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) + #print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) + @pytest.mark.parametrize( @@ -231,9 +220,9 @@ def test_percentile_clipping(gtype): vals, idx = torch.sort(gnorm_vec1) clip1 = vals[percentile] - torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2)) - torch.testing.assert_allclose(clip1, clip2) - torch.testing.assert_allclose(gnorm1, gnorm2) + torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2)) + torch.testing.assert_close(clip1, clip2) + torch.testing.assert_close(gnorm1, gnorm2) def quant(x): @@ -315,7 +304,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): dim2 = dim2 - (dim2 % 32) errors = [] relerrors = [] - print("") + #print("") for i in range(5): if batched: A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda") @@ -327,7 +316,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda") maxA, Ac = quant_methods[0](A, 1) maxB, Bc = quant_methods[1](B, 0) - torch.testing.assert_allclose( + torch.testing.assert_close( quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05 ) if batched: @@ -344,8 +333,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): relerr = err / torch.abs(out2) errors.append(err.mean().item()) relerrors.append(relerr.mean().item()) - print(mean(errors)) - print(mean(relerrors)) + #print(mean(errors)) + #print(mean(relerrors)) def test_stable_embedding(): @@ -398,7 +387,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): out2 = torch.matmul(A.t().float(), B.t().float()) out = F.igemm(A.t(), B.t()) - torch.testing.assert_allclose(out.float(), out2) + torch.testing.assert_close(out.float(), out2) for i in range(k): shapeA = (batch_dim, seq_dim, hidden_dim) @@ -416,7 +405,7 @@ def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): out2 = torch.matmul(A.float(), B.t().float()) out = F.igemm(A, B.t()) - torch.testing.assert_allclose(out.float(), out2) + torch.testing.assert_close(out.float(), out2) n = 3 @@ -447,7 +436,7 @@ def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): ) out = F.igemm(A, B, out=iout) - torch.testing.assert_allclose(out.float(), out2) + torch.testing.assert_close(out.float(), out2) n = 2 @@ -572,7 +561,7 @@ def test_ibmm(dim1, dim2, dim3, dim4, transpose): A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float() ) out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) - torch.testing.assert_allclose(out.float(), out2.float()) + torch.testing.assert_close(out.float(), out2.float()) n = 1 @@ -630,9 +619,9 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans out, S = F.nvidia_transform(A, to_order=orderOut) if orderOut == "row": - torch.testing.assert_allclose(A.flatten(), out.flatten()) + torch.testing.assert_close(A.flatten(), out.flatten()) elif orderOut == "col": - torch.testing.assert_allclose(A.t().flatten(), out.flatten()) + torch.testing.assert_close(A.t().flatten(), out.flatten()) elif orderOut == "col32": if dims == 2: n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32))) @@ -665,14 +654,14 @@ def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, trans assert A.flatten()[i + j] == A[row, col] # assert A.flatten()[i+j] == out.flatten()[row2+col2] - # torch.testing.assert_allclose(A.flatten()[i+j], A[row, col]) - # torch.testing.assert_allclose(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) + # torch.testing.assert_close(A.flatten()[i+j], A[row, col]) + # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) if orderOut == "col32": out2, S = F.nvidia_transform( out, from_order=orderOut, to_order="row", state=S ) - torch.testing.assert_allclose(A, out2) + torch.testing.assert_close(A, out2) n = 1 @@ -716,7 +705,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): B2, SB = F.transform(B, "col_turing") C2, SC = F.igemmlt(A2, B2, SA, SB) C3, S = F.nvidia_transform(C2, "row", state=SC) - torch.testing.assert_allclose(C1, C3.float()) + torch.testing.assert_close(C1, C3.float()) # transpose B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to( @@ -727,7 +716,7 @@ def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): B2t, SBt = F.transform(B, "col_turing", transpose=True) C2, SC = F.igemmlt(A2, B2t, SA, SBt) C3, S = F.nvidia_transform(C2, "row", state=SC) - torch.testing.assert_allclose(C1, C3.float()) + torch.testing.assert_close(C1, C3.float()) dim1 = [32] @@ -773,7 +762,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): # print(C1.flatten()[:10]) # print(C2.flatten()[:10]) - # torch.testing.assert_allclose(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) + # torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) # transpose # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8) @@ -782,7 +771,7 @@ def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): # B2t, SBt = F.transform2(B, 'col_turing', transpose=True) # C2, SC = F.igemmlt(A2, B2t, SA, SBt) # C3, S = F.transform(C2, 'row', state=SC) - # torch.testing.assert_allclose(C1, C3.float()) + # torch.testing.assert_close(C1, C3.float()) batch_size = 2 @@ -1001,7 +990,7 @@ def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): #assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias) - #torch.testing.assert_allclose(C5, C4, atol=0.015, rtol=0.1) + #torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1) n = C5.numel() assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n)) @@ -1051,16 +1040,16 @@ def test_colrow_absmax(dim1, dim2, dims): ) nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0) - torch.testing.assert_allclose(col_stats1_trunc, col_stats2) - torch.testing.assert_allclose(row_stats1_trunc, row_stats2) - torch.testing.assert_allclose(nnz_block_ptr1, nnz_block_ptr2) + torch.testing.assert_close(col_stats1_trunc, col_stats2) + torch.testing.assert_close(row_stats1_trunc, row_stats2) + torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2) row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( A, threshold=0.0 ) - torch.testing.assert_allclose(col_stats1, col_stats2) - torch.testing.assert_allclose(row_stats1, row_stats2) + torch.testing.assert_close(col_stats1, col_stats2) + torch.testing.assert_close(row_stats1, row_stats2) assert nnz_block_ptr2 is None @@ -1084,8 +1073,8 @@ def test_double_quant(dim1, dim2): CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) # max difference is 1 due to rounding differences - torch.testing.assert_allclose(CA, out_row1, atol=1, rtol=0) - torch.testing.assert_allclose(CAt, out_col1, atol=1, rtol=0) + torch.testing.assert_close(CA, out_row1, atol=1, rtol=0) + torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0) n = CAt.numel() num_not_close_rows = ( @@ -1108,8 +1097,8 @@ def test_double_quant(dim1, dim2): ) assert False - torch.testing.assert_allclose(Srow.flatten(), statsA) - torch.testing.assert_allclose(Scol.flatten(), statsAt) + torch.testing.assert_close(Srow.flatten().float(), statsA) + torch.testing.assert_close(Scol.flatten().float(), statsAt) n = 4 @@ -1134,10 +1123,10 @@ def test_integrated_igemmlt(dim1, dim4, inner): A1, maxA = F.vectorwise_quant(A, dim=1) B1, maxB = F.vectorwise_quant(B, dim=1) - torch.testing.assert_allclose(maxA.flatten(), stats1a) - torch.testing.assert_allclose(maxB.flatten(), stats2a) - torch.testing.assert_allclose(C1a, A1, rtol=0, atol=1) - torch.testing.assert_allclose(C2a, B1, rtol=0, atol=1) + torch.testing.assert_close(maxA.flatten().float(), stats1a) + torch.testing.assert_close(maxB.flatten().float(), stats2a) + torch.testing.assert_close(C1a, A1, rtol=0, atol=1) + torch.testing.assert_close(C2a, B1, rtol=0, atol=1) A2, SA = F.nvidia_transform(C1a, "col32") B2, SB = F.nvidia_transform(C2a, "col_turing") @@ -1339,7 +1328,7 @@ def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): # print(out1) # print(out2) - torch.testing.assert_allclose(out1, out2) + torch.testing.assert_close(out1, out2) n = 2 @@ -1401,11 +1390,11 @@ def test_coo_double_quant(dim1, dim2): A2[ coo_tensor.rowidx.long(), coo_tensor.colidx.long() ] = coo_tensor.values - torch.testing.assert_allclose(A1, A2) + torch.testing.assert_close(A1, A2) A1 = A * (idx == 0) A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() - torch.testing.assert_allclose( + torch.testing.assert_close( A * (idx == 0), A2, rtol=0.05, atol=1.5e-2 ) @@ -1613,7 +1602,7 @@ def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): idx_col = torch.randint(0, A2.shape[-1], size=(15,)) - # torch.testing.assert_allclose(out1, out2.half(), rtol=0.05, atol=0.001) + # torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001) # Bt = torch.randn(dim2*4, dim2, device='cuda').half() # torch.cuda.synchronize() @@ -1644,9 +1633,9 @@ def test_coo2csr(): counts = csrA.rowptr[1:] - csrA.rowptr[:-1] assert counts.numel() == A.shape[0] - torch.testing.assert_allclose(counts, (A2 != 0).sum(1)) + torch.testing.assert_close(counts.long(), (A2 != 0).sum(1)) idx = A2 != 0 - torch.testing.assert_allclose(A2[idx], csrA.values) + torch.testing.assert_close(A2[idx], csrA.values) def test_coo2csc(): @@ -1664,10 +1653,10 @@ def test_coo2csc(): counts = cscA.colptr[1:] - cscA.colptr[:-1] assert counts.numel() == A.shape[1] - torch.testing.assert_allclose(counts, (A2 != 0).sum(0)) + torch.testing.assert_close(counts.long(), (A2 != 0).sum(0)) # torch uses row-major -> use transpose to transfer to col-major idx = A2.t() != 0 - torch.testing.assert_allclose(A2.t()[idx], cscA.values) + torch.testing.assert_close(A2.t()[idx], cscA.values) n = 2 @@ -1717,7 +1706,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): max_count, max_idx = torch.sort(counts, descending=True) print(torch.median(max_count.float())) - torch.testing.assert_allclose(out2, out3, rtol=0.05, atol=0.001) + torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001) p = 200 / (2048 * 12288 * 4) n = out1.numel() @@ -1787,38 +1776,43 @@ def test_spmm_coo_dequant(dim1, dim2, dtype): batch_size = 1 seqdim = 1 values = [] -values.append((batch_size, seqdim, 768, 4 * 768)) -# values.append((batch_size, seqdim, 1024, 4*1024)) -# values.append((batch_size, seqdim, 1536, 4*1536)) -# values.append((batch_size, seqdim, 2048, 4*2048)) -# values.append((batch_size, seqdim, 2560, 4*2560)) -# values.append((batch_size, seqdim, 4096, 4*4096)) -# values.append((batch_size, seqdim, 5140, 4*5140)) +#values.append((batch_size, seqdim, 768, 4 * 768)) +#values.append((batch_size, seqdim, 1024, 4*1024)) +#values.append((batch_size, seqdim, 1536, 4*1536)) +#values.append((batch_size, seqdim, 2048, 4*2048)) +#values.append((batch_size, seqdim, 2560, 4*2560)) +values.append((batch_size, seqdim, 4096, 4*4096)) +values.append((batch_size, seqdim, 5120, 4*5120)) +values.append((batch_size, seqdim, 6656, 4*6656)) +values.append((batch_size, seqdim, 8192, 4*8192)) +#values.append((batch_size, seqdim, 5140, 4*5140)) #values.append((batch_size, seqdim, 12288, 4*12288)) -names = [ - "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values -] - - +names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values] @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) def test_bench_matmul(batch, seq, model, hidden): - iters = 128 + iters = 80 formatB = F.get_special_format_str() A = torch.randn(batch, seq, model, device="cuda").half() B = torch.empty(hidden, model, dtype=torch.float16, device="cuda") torch.nn.init.xavier_uniform_(B) - linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() + B_fp4, state = F.quantize_fp4(B) + B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True) + + B_nf4, state_nf4= F.quantize_nf4(B) + + linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half() linear8bit.eval() outliers = torch.randint(0, model, size=(5,)).cuda() A[:, :, outliers] = 8.0 - linearMixedBit = ( - bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() - ) - linearMixedBit.eval() + linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half()) + #linearMixedBit.eval() + + linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() + linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() # warmup for i in range(iters): @@ -1831,61 +1825,80 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(iters): torch.matmul(A, B.t()) torch.cuda.synchronize() - print( - f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" - ) + print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) torch.cuda.synchronize() t0 = time.time() for i in range(iters): - bnb.matmul(A, B) + bnb.matmul_4bit(A, B_fp4.t(), quant_state=state) torch.cuda.synchronize() - print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) torch.cuda.synchronize() t0 = time.time() for i in range(iters): - bnb.matmul(A, B, threshold=6.0) + bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c) torch.cuda.synchronize() - print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) - C32A, SA = F.transform(CA, "col32") - CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) - CxB, SB = F.transform(CB, to_order=formatB) torch.cuda.synchronize() t0 = time.time() for i in range(iters): - out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) torch.cuda.synchronize() - print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) - BA, statsB = F.vectorwise_quant(B, dim=1) - CxB, SB = F.nvidia_transform(CB, to_order=formatB) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - A2 = A.view(-1, A.shape[-1]).contiguous() - CA, statsA = F.vectorwise_quant(A2, dim=1) - C32A, SA = F.nvidia_transform(CA, "col32") - out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) - F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) - torch.cuda.synchronize() + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # bnb.matmul(A, B) + #torch.cuda.synchronize() + #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # bnb.matmul(A, B, threshold=6.0) + #torch.cuda.synchronize() + #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) + #C32A, SA = F.transform(CA, "col32") + #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) + #CxB, SB = F.transform(CB, to_order=formatB) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + #torch.cuda.synchronize() + #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #BA, statsB = F.vectorwise_quant(B, dim=1) + #CxB, SB = F.nvidia_transform(CB, to_order=formatB) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # A2 = A.view(-1, A.shape[-1]).contiguous() + # CA, statsA = F.vectorwise_quant(A2, dim=1) + # C32A, SA = F.nvidia_transform(CA, "col32") + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) + # F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) + #torch.cuda.synchronize() #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") - BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") - CxB, SB = F.nvidia_transform(CB, to_order=formatB) - torch.cuda.synchronize() - t0 = time.time() - for i in range(iters): - A2 = A.view(-1, A.shape[-1]).contiguous() - CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear") - C32A, SA = F.nvidia_transform(CA, "col32") - out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) - Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) - out = Cout * statsB * statsA * (1.0 / (127 * 127)) - torch.cuda.synchronize() + #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") + #CxB, SB = F.nvidia_transform(CB, to_order=formatB) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # A2 = A.view(-1, A.shape[-1]).contiguous() + # CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear") + # C32A, SA = F.nvidia_transform(CA, "col32") + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) + # out = Cout * statsB * statsA * (1.0 / (127 * 127)) + #torch.cuda.synchronize() #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") linear8bit(A) @@ -1894,9 +1907,7 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(iters): linear8bit(A) torch.cuda.synchronize() - print( - f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" - ) + print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") linearMixedBit(A) torch.cuda.synchronize() @@ -1904,9 +1915,23 @@ def test_bench_matmul(batch, seq, model, hidden): for i in range(iters): linearMixedBit(A) torch.cuda.synchronize() - print( - f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" - ) + print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #linear8bit_train(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # linear8bit_train(A) + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #linear8bit_train_thresh(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # linear8bit_train(A) + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") def test_zeropoint(): def quant_zp(x): @@ -2009,7 +2034,7 @@ def test_extract_outliers(): assert outliers2.shape[0] == shapeA[0] assert outliers2.shape[1] == idx.numel() - torch.testing.assert_allclose(outliers1, outliers2) + torch.testing.assert_close(outliers1, outliers2) CA, SA = F.transform(A, "col_ampere") @@ -2018,7 +2043,7 @@ def test_extract_outliers(): assert outliers2.shape[0] == shapeA[0] assert outliers2.shape[1] == idx.numel() - torch.testing.assert_allclose(outliers1, outliers2) + torch.testing.assert_close(outliers1, outliers2) @@ -2050,7 +2075,6 @@ def test_fp8_quant(): p_bits = 7-e_bits code = F.create_fp8_map(True, e_bits, p_bits).cuda() - print(e_bits, p_bits) abserr = [] relerr = [] for i in range(100): @@ -2149,7 +2173,7 @@ def test_few_bit_quant(): #assert err2.mean() <= err1 else: - torch.testing.assert_allclose(q1, q2) + torch.testing.assert_close(q1, q2) #print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) #assert False @@ -2181,7 +2205,9 @@ def test_kbit_quantile_estimation(): def test_bench_dequantization(): a = torch.rand(1024, 1024, device='cuda').half() - qa, SA = F.quantize_blockwise(a) + code =F.create_fp8_map(True, 3, 0, 4).cuda() + qa, SA = F.quantize_blockwise(a, code=code) + print(qa.max()) max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000 #print(max_theoretical_mu) @@ -2189,7 +2215,302 @@ def test_bench_dequantization(): torch.cuda.synchronize() t0 = time.time() for i in range(100): - F.dequantize_blockwise(qa, SA, blocksize=2048) + qa, SA = F.quantize_blockwise(a) torch.cuda.synchronize() #print((time.time()-t0)/1e6) + + +def test_fp4_quant(): + vals = list(product([0, 1], repeat=4)) + + code = {} + for bits in vals: + result = 0 + bias = 3 + sign, e1, e2, p1 = bits + idx = sign*8 + e1*4 + e2*2 + p1*1 + sign = -1.0 if sign else 1.0 + exp = e1*2 + e2*1 + if exp == 0: + # sub-normal + if p1 == 0: result = 0 + else: result = sign*0.0625 + else: + # normal + exp = 2**(-exp + bias + 1) + frac = 1.5 if p1 else 1.0 + result = sign*exp*frac + code[idx] = result + + A1 = torch.randn(1024, 1024, device='cuda').half() + qa, SA = F.quantize_fp4(A1, blocksize=64) + A2 = F.dequantize_fp4(qa, SA) + + err = (A1 - A2).abs().float() + relerr = (err/A1.abs().float()).mean() + idx = err > 1.0 + err = err.mean() + + + assert err.item() < 0.1 + assert relerr.item() < 0.28 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) +def test_4bit_compressed_stats(quant_type): + for blocksize in [128, 64]: + errs1 = [] + errs2 = [] + for i in range(10): + A1 = torch.randn(1024, 1024, device='cuda').half() + q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) + q3, SA3= F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) + A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type) + A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type) + + + err = (A1 - A2).abs().float() + relerr = (err/(A1.abs().float()+1e-15)).mean() + err = err.mean() + + errs1.append(err.item()) + + + assert err.item() < 0.11 + assert relerr.item() < 0.28 + + err = (A1 - A3).abs().float() + relerr = (err/(A1.abs().float()+1e-15)).mean() + err = err.mean() + + errs2.append(err.item()) + + assert err.item() < 0.11 + assert relerr.item() < 0.28 + + #print(sum(errs1)/len(errs1), blocksize, quant_type) + #print(sum(errs2)/len(errs2), blocksize, quant_type) + + + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) +def test_bench_4bit_dequant(quant_type): + blocksize = 256 + a = torch.rand(1024*12*4, 1024*12, device='cuda').half() + qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type) + + input_size = a.numel()/2 + output_size = a.numel()*2 + num_bytes = input_size+output_size + GB = num_bytes/1e9 + max_theoretical_s = GB/768 + #print(max_theoretical_s*1e6) + b = torch.randn(128, 1024*12, device='cuda').half() + + iters = 5 + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) + #b.copy_(a) + torch.cuda.synchronize() + #print((time.time()-t0)/iters*1e6) + + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # torch.matmul(b, a.t()) + #torch.cuda.synchronize() + #print((time.time()-t0)/iters*1e6) + + + +def test_normal_map_tree(): + code = F.create_normal_map() + values =code[:8].tolist() + code[-8:].tolist() + num_pivots = 1 + print(values) + while num_pivots <16: + idx = list(range(16//num_pivots//2, 16, 16//num_pivots)) + print(idx) + num_pivots *= 2 + pivots = [] + for i in idx: + pivots.append((values[i-1]+values[i])/2) + print(pivots) + + +#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) +@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) +def test_cutlass3_gemm(dtype): + debug = True + #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: + #for dim in [4096, 5120, 6656, 8192]: + for dim in [4096]: + #for dim in [128+1]: + errs = [] + relerrs = [] + max_err = 0 + max_relerr = 0 + for i in range(100): + A = torch.randn(1, dim, dtype=dtype, device='cuda') + B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim) + #B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim) + + #print('') + #print(A) + #print(B.t()) + #A[:, :-1] = 0 + #B[:, :-1] = 0 + + + C1 = torch.matmul(A, B.t()) + C2 = F.cutlass3_gemm(A, B.t()) + + # tensor cores are non-deterministic + # so we need to analyze errors around the mean + # to test our implementation + err = torch.abs(C1-C2) + mag = torch.abs(C1)+1e-8 + relerr = err/mag + max_err = max(err.max(), max_err) + max_relerr = max(relerr.max(), max_relerr) + err = err.mean().item() + relerr = relerr.mean().item() + + errs.append(err) + relerrs.append(relerr) + + #if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: + # print('') + # print(i, err, relerr) + # print(A.flatten()[-6:]) + # print(B.flatten()[-6:]) + # out = A.flatten()[-6:]*B.flatten()[-6:] + # print(out) + # print(out[:-1].sum()) + # print('='*80) + # print(C1.flatten()[-6:]) + # print(C2.flatten()[-6:]) + # #assert False, 'ERROR' + + c = int(C1.numel()*0.0014*(dim/256))+1 + + c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug) + #print(c/math.sqrt(dim)) + print('') + print(dim, sum(errs)/len(errs)/math.sqrt(dim)) + print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) + print(dim, (max_err.item(), max_relerr.item())) + +#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) +@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) +def test_gemm_4bit(dtype): + #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: + #for dim in [4096, 5120, 6656, 8192]: + #for dim in [32]: + for dim in [4096]: + errs = [] + relerrs = [] + max_err = 0 + max_relerr = 0 + for i in range(1): + #A = torch.rand(2, 4092, dtype=dtype, device='cuda') + #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') + #A = torch.rand(1, 4096, dtype=dtype, device='cuda') + #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') + A = torch.randn(1, dim+0, dtype=dtype, device='cuda') + B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim) + + #print('') + #print(A) + #print(B.t()) + #A[:, :-1] = 0 + #B[:, :-1] = 0 + + qB, state = F.quantize_nf4(B) + F.dequantize_nf4(qB, state) + + C3 = torch.matmul(A, B.t()) + C2 = F.cutlass3_gemm(A, qB.t(), state=state) + C1 = bnb.matmul_4bit(A, qB.t(), state) + C2 = F.cutlass3_gemm(A, qB.t(), state=state) + + print(C1.shape, C2.shape) + + # tensor cores are non-deterministic + # so we need to analyze errors around the mean + # to test our implementation + err = torch.abs(C1-C2) + mag = torch.abs(C1)+1e-8 + relerr = err/mag + max_err = max(err.max(), max_err) + max_relerr = max(relerr.max(), max_relerr) + err = err.mean().item() + relerr = relerr.mean().item() + + errs.append(err) + relerrs.append(relerr) + + if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: + print('') + print(i, err, relerr) + print(A.flatten()[-6:]) + print(B.flatten()[-6:]) + out = A.flatten()[-6:]*B.flatten()[-6:] + print(out) + print(out[:-1].sum()) + print('='*80) + print(C1.flatten()[-6:]) + print(C2.flatten()[-6:]) + #assert False, 'ERROR' + + c = int(C1.numel()*0.0014*(dim/256))+1 + + c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) + #print(c/math.sqrt(dim)) + print('') + print(dim, sum(errs)/len(errs)/math.sqrt(dim)) + print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) + print(dim, (max_err.item(), max_relerr.item())) + +@pytest.mark.skip("Row scale has some bugs for ampere") +def test_managed(): + n = 32*10 + A = F.get_paged(n, n, dtype=torch.float32) + B = F.get_paged(n, n, dtype=torch.uint8) + B2 = F.get_paged(n, n, dtype=torch.float32) + assert A.is_paged + assert B.is_paged + assert A.page_deviceid==0 + assert B.page_deviceid==0 + F.fill(A, 17.0) + F.fill(B, 17) + F.fill(B2, 2) + assert (A==17).sum().item() == n*n + assert (B==17).sum().item() == n*n + C = A*B.float() + assert (C==289).sum().item() == n*n + F._mul(A, B2) + F._mul(A, B2) + F._mul(A, B2) + assert (A==17*(2**3)).sum().item() == n*n + # F.prefetch_tensor(A) + # F.prefetch_tensor(B) + + + # F.fill(B2, 17.0) + # F._mul(A, B2) + + # F.prefetch_tensor(A, to_cpu=True) + # F.prefetch_tensor(B, to_cpu=True) + # F.prefetch_tensor(B2, to_cpu=True) + # torch.cuda.synchronize() + + # assert (A==17).sum().item() == n*n + + # torch.testing.assert_close(A, torch.ones(A.shape)*289) diff --git a/tests/test_modules.py b/tests/test_modules.py index d78f0c9..d0a9051 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -44,7 +44,7 @@ def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): sumval = (idx == 0).sum().item() if sumval > count: print(f"Too many values not close: assert {sumval} < {count}") - torch.testing.assert_allclose(a, b, rtol, atol) + torch.testing.assert_close(a, b, rtol, atol) class LinearFunction(torch.autograd.Function): @@ -330,18 +330,15 @@ def test_linear8bitlt_inference(threshold): def test_linear8bitlt_accumulated_gradient(): - l1 = torch.nn.Sequential( - *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)] - ) - l2 = torch.nn.Sequential( - *[torch.nn.Linear(32, 32).cuda().half() for i in range(2)] - ) - l2[0].weight = torch.nn.Parameter(l1[0].weight.clone()) - l2[0].bias = torch.nn.Parameter(l1[0].bias.clone()) - l2[1].weight = torch.nn.Parameter(l1[1].weight.clone()) - l2[1].bias = torch.nn.Parameter(l1[1].bias.clone()) - opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001) - opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001) + l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]) + l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]) + l1[0].weight.data.copy_(l2[0].weight.data) + l1[1].weight.data.copy_(l2[1].weight.data) + l1[0].bias.data.copy_(l2[0].bias.data) + l1[1].bias.data.copy_(l2[1].bias.data) + + opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001) + opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001) acc_steps = 10 @@ -371,26 +368,17 @@ def test_linear8bitlt_accumulated_gradient(): # we do this copy because otherwise we have small divergences over time that add up l1[0].weight.data.copy_(l2[0].weight.data) l1[1].weight.data.copy_(l2[1].weight.data) + l1[0].bias.data.copy_(l2[0].bias.data) + l1[1].bias.data.copy_(l2[1].bias.data) else: - torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad) - torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad) + torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3) -threshold = [0.0, 2.0] -values = threshold -names = [f"threshold_{vals}" for vals in values] - - -@pytest.mark.parametrize("threshold", values, ids=names) +@pytest.mark.parametrize("threshold", [0.0, 2.0]) @pytest.mark.parametrize("memory_efficient_backward", [False]) def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): - l1 = ( - bnb.nn.Linear8bitLt( - 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward - ) - .cuda() - .half() - ) + l1 = (bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half()) assert l1.weight.dtype == torch.int8 l1.eval() @@ -446,13 +434,7 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8 - mlp = ( - MLP8bit( - 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward - ) - .half() - .to("cuda") - ) + mlp = ( MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda")) for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() @@ -499,15 +481,16 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half() scale = grad_ref.abs().mean() - torch.testing.assert_allclose(b1.grad, grad_ref, rtol=0, atol=0.05 * scale) + torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale) idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1) assert (idx == 0).sum().item() <= b1.numel() * 0.005 -def test_linear8bitlt_fp32_bias(): +@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4']) +def test_linear_kbit_fp32_bias(module): # casts model to fp16 -> int8 automatically - l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False).cuda() - assert l1.weight.dtype == torch.int8 + l1 = module(32, 64).cuda() + assert l1.weight.dtype in [torch.int8, torch.uint8] assert l1.bias.dtype == torch.float32 for i in range(100): @@ -517,11 +500,116 @@ def test_linear8bitlt_fp32_bias(): assert l1.bias.dtype == torch.float16 # casts model to fp16 -> int8 automatically - l1 = bnb.nn.Linear8bitLt(32, 64, has_fp16_weights=False, bias=False).cuda() - assert l1.weight.dtype == torch.int8 + l1 = module(32, 64, bias=False).cuda() + assert l1.weight.dtype in [torch.int8, torch.uint8] assert l1.bias is None for i in range(100): b1 = torch.randn(16, 8, 32, device="cuda").half() o1 = l1(b1) assert l1.bias is None + +modules = [] +modules.append(bnb.nn.Linear8bitLt) +modules.append(bnb.nn.Linear4bit) +modules.append(bnb.nn.LinearFP4) +modules.append(bnb.nn.LinearNF4) +modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True)) +modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True)) +names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C'] +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize("module", modules, ids=names) +def test_kbit_backprop(module): + b = 17 + dim1 = 37 + dim2 = 83 + + ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)]) + ref[1].weight.requires_grad = False + torch.nn.init.kaiming_normal_(ref[0].weight) + torch.nn.init.kaiming_normal_(ref[1].weight) + kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)]) + kbit[0].weight.detach().copy_(ref[0].weight) + kbit[1].weight.detach().copy_(ref[1].weight) + kbit[0].bias.detach().copy_(ref[0].bias) + kbit[1].bias.detach().copy_(ref[1].bias) + ref = ref.half().cuda() + kbit = kbit.half().cuda() + + errs1 = [] + errs2 = [] + relerrs1 = [] + relerrs2 = [] + for i in range(100): + batch = torch.randn(b, dim1).half().cuda() + out1 = ref(batch) + out2 = kbit(batch) + out1.mean().backward() + out2.mean().backward() + + grad1 = ref[0].weight.grad + grad2 = kbit[0].weight.grad + bgrad1 = ref[0].bias.grad + bgrad2 = kbit[0].bias.grad + + err1 = (out1-out2).abs().float() + err2 = (grad1-grad2).abs().float() + relerr1 = (err1/(out1.abs().float()+1e-9)) + relerr2 = (err2/(grad1.abs().float()+1e-9)) + errs1.append(err1.mean().item()) + errs2.append(err2.mean().item()) + relerrs1.append(relerr1.mean().item()) + relerrs2.append(relerr2.mean().item()) + + if isinstance(module, bnb.nn.Linear8bitLt): + torch.testing.assert_close(grad1, grad2, atol=0.008, rtol=0.05) + torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05) + else: + torch.testing.assert_close(grad1, grad2, atol=0.015, rtol=0.05) + torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05) + ref.zero_grad() + kbit.zero_grad() + + assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0 + assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0 + print('out', sum(errs1)/len(errs1)) + print('grad', sum(errs2)/len(errs2)) + print('rel out', sum(relerrs1)/len(relerrs1)) + print('rel grad', sum(relerrs2)/len(relerrs2)) + +def test_fp8linear(): + + b = 10 + h = 1024 + inp = torch.randn(b, h).cuda() + fp32 = torch.nn.Linear(h, h*2).cuda() + fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda() + fp32b = torch.nn.Linear(h*2, h).cuda() + fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda() + + fp8.weight.data.copy_(fp32.weight.data) + fp8.bias.data.copy_(fp32.bias.data) + fp8b.weight.data.copy_(fp32b.weight.data) + fp8b.bias.data.copy_(fp32b.bias.data) + + a = fp32b(torch.nn.functional.gelu(fp32(inp))) + b = fp8b(torch.nn.functional.gelu(fp8(inp))) + + err = (a-b).abs().mean() + + a.mean().backward() + b.mean().backward() + + graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean() + bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean() + + assert err < 0.05 + assert graderr < 0.00002 + assert bgraderr < 0.00002 + + + + + + + diff --git a/tests/test_optim.py b/tests/test_optim.py index 839f80c..9e90083 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -19,11 +19,11 @@ import bitsandbytes.functional as F k = 20 def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): - idx = torch.isclose(a, b, rtol, atol) + idx = torch.isclose(a, b, rtol=rtol, atol=atol) error_count = (idx == 0).sum().item() if error_count > max_error_count: print(f"Too many values not close: assert {error_count} < {max_error_count}") - torch.testing.assert_allclose(a, b, rtol, atol) + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) def get_temp_dir(): @@ -35,11 +35,8 @@ def get_temp_dir(): def rm_path(path): shutil.rmtree(path) - str2optimizers = {} str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) -# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) -# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion) str2optimizers["momentum_pytorch"] = ( None, @@ -47,28 +44,20 @@ str2optimizers["momentum_pytorch"] = ( bnb.optim.Adam, ) str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam) -# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) +str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW) +str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam) str2optimizers["lion"] = (Lion, bnb.optim.Lion) +str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion) str2optimizers["momentum"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False), ) -str2optimizers["lars"] = ( - lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), - lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9), -) str2optimizers["rmsprop"] = ( lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), ) -str2optimizers["adam8bit"] = ( - torch.optim.Adam, - lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False), -) -str2optimizers["lion8bit"] = ( - Lion, - lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False), -) +str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False)) +str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False)) str2optimizers["momentum8bit"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False), @@ -77,19 +66,12 @@ str2optimizers["rmsprop8bit"] = ( lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False), ) -str2optimizers["lars8bit"] = ( - lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), - lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9), -) -str2optimizers["adam8bit_blockwise"] = ( - torch.optim.Adam, - lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True), -) -str2optimizers["lion8bit_blockwise"] = ( - Lion, - lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True), -) +str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) +str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True)) +str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True)) +str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True)) +str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True)) str2optimizers["momentum8bit_blockwise"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True), @@ -101,53 +83,35 @@ str2optimizers["rmsprop8bit_blockwise"] = ( str2statenames = {} str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["lion"] = [("exp_avg", "state1")] +str2statenames["paged_lion"] = [("exp_avg", "state1")] str2statenames["momentum"] = [("momentum_buffer", "state1")] -str2statenames["lars"] = [("momentum_buffer", "state1")] str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["rmsprop"] = [("square_avg", "state1")] -str2statenames["adam8bit"] = [ - ("exp_avg", "state1", "qmap1", "max1"), - ("exp_avg_sq", "state2", "qmap2", "max2"), -] -str2statenames["lion8bit"] = [ - ("exp_avg", "state1", "qmap1", "max1") -] -str2statenames["lamb8bit"] = [ - ("exp_avg", "state1", "qmap1", "max1"), - ("exp_avg_sq", "state2", "qmap2", "max2"), -] -str2statenames["adam8bit_blockwise"] = [ - ("exp_avg", "state1", "qmap1", "absmax1"), - ("exp_avg_sq", "state2", "qmap2", "absmax2"), -] -str2statenames["lion8bit_blockwise"] = [ - ("exp_avg", "state1", "qmap1", "absmax1") -] -str2statenames["momentum8bit"] = [ - ("momentum_buffer", "state1", "qmap1", "max1") -] -str2statenames["momentum8bit_blockwise"] = [ - ("momentum_buffer", "state1", "qmap1", "absmax1") -] -str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] +str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] +str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] +str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] +str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")] +str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")] str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")] -str2statenames["rmsprop8bit_blockwise"] = [ - ("square_avg", "state1", "qmap1", "absmax1") -] +str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")] +str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] +str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] dim1 = [1024] dim2 = [32, 1024, 4097, 1] -gtype = [torch.float32, torch.float16] -optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lion"] +gtype = [torch.float32, torch.float16, torch.bfloat16] +optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion'] values = list(product(dim1, dim2, gtype, optimizer_names)) -names = [ - "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values -] - - +names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) def test_optimizer32bit(dim1, dim2, gtype, optim_name): + if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip() if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 @@ -159,6 +123,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): if gtype == torch.float32: atol, rtol = 1e-6, 1e-5 + elif gtype == torch.bfloat16: + atol, rtol = 1e-3, 1e-2 else: atol, rtol = 1e-4, 1e-3 @@ -172,9 +138,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): for name1, name2 in str2statenames[optim_name]: - torch.testing.assert_allclose( + torch.testing.assert_close( torch_optimizer.state[p1][name1], - bnb_optimizer.state[p2][name2], + bnb_optimizer.state[p2][name2].cuda(), atol=atol, rtol=rtol, ) @@ -201,14 +167,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): atol=atol, rtol=rtol, max_error_count=10) - if gtype == torch.float16: + if gtype != torch.float32: # the adam buffers should also be close because they are 32-bit # but the paramters can diverge because they are 16-bit # the difference grow larger and larger with each update # --> copy the state to keep weights close - p1.data = p1.data.half().float() + p1.data = p1.data.to(p2.dtype).float() p2.copy_(p1.data) - torch.testing.assert_allclose(p1.half(), p2) + torch.testing.assert_close(p1.to(p2.dtype), p2) if optim_name in ["lars", "lamb"]: assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0 @@ -268,7 +234,7 @@ def test_global_config(dim1, dim2, gtype): dim1 = [1024] dim2 = [32, 1024, 4097] -gtype = [torch.float32, torch.float16] +gtype = [torch.float32, torch.float16, torch.bfloat16] optimizer_names = [ "adam8bit", "lion8bit", @@ -276,7 +242,6 @@ optimizer_names = [ "rmsprop8bit", "adam8bit_blockwise", "lion8bit_blockwise", - "lars8bit", "momentum8bit_blockwise", "rmsprop8bit_blockwise", ] @@ -288,6 +253,7 @@ names = [ @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) def test_optimizer8bit(dim1, dim2, gtype, optim_name): + if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip() if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 @@ -301,7 +267,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): if gtype == torch.float32: atol, rtol = 3e-3, 1e-3 patol, prtol = 1e-5, 1e-3 - + elif gtype == torch.bfloat16: + atol, rtol = 3e-3, 1e-3 + patol, prtol = 1e-4, 1e-2 else: atol, rtol = 3e-3, 1e-3 patol, prtol = 1e-5, 1e-3 @@ -309,7 +277,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): errors = [] relerrors = [] - for i in range(50): + for i in range(100): g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 p1.grad = g.clone().float() p2.grad = g.clone() @@ -343,13 +311,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ) == 0 ) - assert num_not_close.sum().item() < 20 + #assert num_not_close.sum().item() < 20 dequant_states.append(s1.clone()) err = torch.abs(p1 - p2) relerr = err / (torch.abs(p1)+1e-9) - assert err.mean() < 0.0001 - assert relerr.mean() < 0.001 + if g.dtype == torch.bfloat16: + assert err.mean() < 0.00015 + assert relerr.mean() < 0.0016 + else: + assert err.mean() < 0.00012 + assert relerr.mean() < 0.0012 errors.append(err.mean().item()) relerrors.append(relerr.mean().item()) @@ -369,12 +341,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) rm_path(path) - torch.testing.assert_allclose( - raws1cpy, bnb_optimizer.state[p2][name2] - ) - torch.testing.assert_allclose( - qmap1, bnb_optimizer.state[p2][qmap] - ) + torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2]) + torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap]) if "blockwise" in optim_name: s1 = F.dequantize_blockwise( @@ -389,17 +357,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], ) - torch.testing.assert_allclose(s1cpy, s1) + torch.testing.assert_close(s1cpy, s1) - num_not_close = ( - torch.isclose( - torch_optimizer.state[p1][name1], - s1, - atol=atol, - rtol=rtol, - ) - == 0 - ) + num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0) assert num_not_close.sum().item() < 20 # since Lion can have pretty noisy updates where things lie at the boundary # allow up to 5 errors for Lion @@ -409,10 +369,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): # together so we can test against the Adam error p1.data = p1.data.to(gtype).float() p2.copy_(p1.data) - torch.testing.assert_allclose(p1.to(gtype), p2) - for (name1, name2, qmap, max_val), s in zip( - str2statenames[optim_name], dequant_states - ): + torch.testing.assert_close(p1.to(gtype), p2) + for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): torch_optimizer.state[p1][name1].copy_(s.data) # print(sum(errors)/len(errors)) @@ -473,28 +431,28 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state if optim_bits == 32: - torch.testing.assert_allclose(p1, p2) - torch.testing.assert_allclose( + torch.testing.assert_close(p1, p2) + torch.testing.assert_close( adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=5e-5, rtol=1e-4, ) - torch.testing.assert_allclose( + torch.testing.assert_close( adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=5e-5, rtol=1e-4, ) elif optim_bits == 8: - torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3) - torch.testing.assert_allclose( + torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3) + torch.testing.assert_close( adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3, ) - torch.testing.assert_allclose( + torch.testing.assert_close( adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, @@ -526,7 +484,7 @@ gtype = [torch.float32, torch.float16] # optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch'] # optimizer_names = ['lamb_apex', 'lamb8bit'] # optimizer_names = ['lars_apex', 'lars8bit'] -optimizer_names = ["adam8bit_blockwise"] +optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise'] values = list(product(dim1, dim2, gtype, optimizer_names)) names = [ "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values @@ -557,3 +515,47 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): params = (k - k // 5) * dim1 * dim2 print(optim_name, gtype, s / params) # assert s < 3.9 + +dim1 = [2*1024] +gtype = [torch.float16] +#mode = ['torch', 'bnb'] +mode = ['bnb'] +optimizer_names = ['paged_adamw'] +#optimizer_names = ['paged_adamw8bit_blockwise'] +values = list(product(dim1,gtype, optimizer_names, mode)) +names = ['dim1_{0}_gtype_{1}_optim_{2}_mode_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names) +def test_stream_optimizer_bench(dim1, gtype, optim_name, mode): + layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)])) + layers1 = layers1.to(gtype) + layers1 = layers1.cuda() + + large_tensor = None + if mode == 'torch': + optim = str2optimizers[optim_name][0](layers1.parameters()) + else: + optim = str2optimizers[optim_name][1](layers1.parameters()) + # 12 GB + large_tensor = torch.empty((int(4.5e9),), device='cuda') + + torch.cuda.synchronize() + time.sleep(5) + + num_batches = 5 + batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype) + lbls = torch.randint(0, 10, size=(num_batches,128)).cuda() + + for i in range(num_batches): + print(i) + b = batches[i] + if i ==2: + torch.cuda.synchronize() + t0 = time.time() + + out1 = layers1(b) + + loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean() + loss1.backward() + optim.step() + torch.cuda.synchronize() + print(mode, time.time() - t0) diff --git a/tests/test_triton.py b/tests/test_triton.py new file mode 100644 index 0000000..e18c7a9 --- /dev/null +++ b/tests/test_triton.py @@ -0,0 +1,59 @@ +import pytest +import torch + +from bitsandbytes.triton.triton_utils import is_triton_available +from bitsandbytes.nn.triton_based_modules import SwitchBackLinear +from bitsandbytes.nn import Linear8bitLt + +@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, + reason="This test requires triton and a GPU with compute capability 8.0 or higher.") +@pytest.mark.parametrize("vector_wise_quantization", [False, True]) +def test_switchback(vector_wise_quantization): + for dim in [83]: + for batch in [13]: + + standard = torch.nn.Linear(dim, 4 * dim).cuda().half() + switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half() + baseline = Linear8bitLt(dim, 4 * dim).cuda().half() + switchback.weight.data.copy_(standard.weight) + switchback.bias.data.copy_(standard.bias) + baseline.weight.data.copy_(standard.weight) + baseline.bias.data.copy_(standard.bias) + + x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True) + x2 = x1.clone().detach().requires_grad_(True) + x3 = x1.clone().detach().requires_grad_(True) + + out_standard = standard(x1) + (2**10 * out_standard.abs().mean()).backward() + + print(x2.dtype) + out_sb = switchback(x2) + (2**10 * out_sb.abs().mean()).backward() + + out_baseline = baseline(x3) + (2**10 * out_baseline.abs().mean()).backward() + + err_sb = (out_standard - out_sb).abs().mean() + err_baseline = (out_standard - out_baseline).abs().mean() + print('OUT', err_sb, err_baseline) + assert err_sb < 2 * err_baseline + + err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean() + err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean() + + print('GW2', err_sb, err_baseline) + assert err_sb < 2 * err_baseline + + err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean() + err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean() + + print('GW1', err_sb, err_baseline) + assert err_sb < 2 * err_baseline + + err_sb = (x1.grad - x2.grad).abs().mean() + err_baseline = (x1.grad - x3.grad).abs().mean() + + print('GX1', err_sb, err_baseline) + assert err_sb < 2 * err_baseline +