diff --git a/Makefile b/Makefile index 7bee7ef..e114160 100644 --- a/Makefile +++ b/Makefile @@ -12,6 +12,7 @@ CUDA_VERSION:= endif + NVCC := $(CUDA_HOME)/bin/nvcc ########################################### @@ -59,9 +60,9 @@ CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89 CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90 -all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env - $(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) - $(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o +all: $(BUILD_DIR) env + $(NVCC) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 969250a..8bfd668 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -73,6 +73,7 @@ if COMPILED_WITH_CUDA: str2optimizer8bit_blockwise["adam"] = ( lib.cadam_8bit_blockwise_fp32, lib.cadam_8bit_blockwise_fp16, + lib.cadam_8bit_blockwise_bf16, ) str2optimizer8bit_blockwise["momentum"] = ( lib.cmomentum_8bit_blockwise_fp32, @@ -1125,51 +1126,42 @@ def optimizer_update_8bit_blockwise( skip_zeros=False, ) -> None: + optim_func = None if g.dtype == torch.float32 and state1.dtype == torch.uint8: - str2optimizer8bit_blockwise[optimizer_name][0]( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(absmax1), - get_ptr(absmax2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + optimizer_func = str2optimizer8bit_blockwise[optimizer_name][0] elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - str2optimizer8bit_blockwise[optimizer_name][1]( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(absmax1), - get_ptr(absmax2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) + optimizer_func = str2optimizer8bit_blockwise[optimizer_name][1] + elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and + len(str2optimizer8bit_blockwise[optimizer_name])==3): + optimizer_func = str2optimizer8bit_blockwise[optimizer_name][2] else: raise ValueError( f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" ) + is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) + + prev_device = pre_call(g.device) + optimizer_func( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) + post_call(prev_device) def percentile_clipping( grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5 diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 8f33161..e7e57d7 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -2988,6 +2988,8 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise( \ diff --git a/csrc/ops.cu b/csrc/ops.cu index 8044c66..a5a23b5 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -741,3 +741,5 @@ MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); + +MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM); diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c index 6a4bb0d..a485a09 100644 --- a/csrc/pythonInterface.c +++ b/csrc/pythonInterface.c @@ -57,19 +57,20 @@ MAKE_FUNC8(rmsprop, RMSPROP, float, 32) MAKE_FUNC8(rmsprop, RMSPROP, half, 16) #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \ -void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ +void fname##_8bit_blockwise_##gbits(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\ { optimizerStatic8bitBlockwise(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\ -MAKE_BLOCKWISE8(adam, ADAM, half, 16) -MAKE_BLOCKWISE8(adam, ADAM, float, 32) -MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16) -MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32) -MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16) -MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32) -MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16) -MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32) +MAKE_BLOCKWISE8(adam, ADAM, half, fp16) +MAKE_BLOCKWISE8(adam, ADAM, float, fp32) +MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16) +MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32) +MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16) +MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32) +MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16) +MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32) +MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } @@ -194,20 +195,20 @@ extern "C" MAKE_CFUNC8(rmsprop, half, 16) #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ - void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \ + void c##fname##_8bit_blockwise_##gbits(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \ - { fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ - - MAKE_CBLOCKWISE8(adam, ADAM, half, 16) - MAKE_CBLOCKWISE8(adam, ADAM, float, 32) - MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, 16) - MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32) - MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16) - MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32) - MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16) - MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32) + { fname##_8bit_blockwise_##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ + MAKE_CBLOCKWISE8(adam, ADAM, half, fp16) + MAKE_CBLOCKWISE8(adam, ADAM, float, fp32) + MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16) + MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32) + MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16) + MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32) + MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } diff --git a/tests/test_optim.py b/tests/test_optim.py index 3df2dad..92e3ed2 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -26,6 +26,8 @@ def get_temp_dir(): def rm_path(path): shutil.rmtree(path) +str2bf16support = {} +str2bf16support['adam8bit_blockwise'] = True str2optimizers = {} str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) @@ -238,7 +240,7 @@ def test_global_config(dim1, dim2, gtype): dim1 = [1024] dim2 = [32, 1024, 4097] -gtype = [torch.float32, torch.float16] +gtype = [torch.float32, torch.float16, torch.bfloat16] optimizer_names = [ "adam8bit", "momentum8bit", @@ -256,6 +258,7 @@ names = [ @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) def test_optimizer8bit(dim1, dim2, gtype, optim_name): + if gtype == torch.bfloat16 and optim_name not in str2bf16support: return if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 @@ -269,7 +272,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): if gtype == torch.float32: atol, rtol = 3e-3, 1e-3 patol, prtol = 1e-5, 1e-3 - + elif gtype == torch.bfloat16: + atol, rtol = 3e-3, 1e-3 + patol, prtol = 1e-4, 1e-2 else: atol, rtol = 3e-3, 1e-3 patol, prtol = 1e-5, 1e-3 @@ -314,8 +319,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): err = torch.abs(p1 - p2) relerr = err / torch.abs(p1) - assert err.mean() < 0.0001 - assert relerr.mean() < 0.001 + if g.dtype == torch.bfloat16: + assert err.mean() < 0.00015 + assert relerr.mean() < 0.0015 + else: + assert err.mean() < 0.0001 + assert relerr.mean() < 0.001 errors.append(err.mean().item()) relerrors.append(relerr.mean().item()) @@ -335,12 +344,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) rm_path(path) - torch.testing.assert_allclose( - raws1cpy, bnb_optimizer.state[p2][name2] - ) - torch.testing.assert_allclose( - qmap1, bnb_optimizer.state[p2][qmap] - ) + torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2]) + torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap]) if "blockwise" in optim_name: s1 = F.dequantize_blockwise( @@ -357,28 +362,16 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): ) torch.testing.assert_allclose(s1cpy, s1) - num_not_close = ( - torch.isclose( - torch_optimizer.state[p1][name1], - s1, - atol=atol, - rtol=rtol, - ) - == 0 - ) + num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0) assert num_not_close.sum().item() < 20 - torch.testing.assert_allclose( - p1, p2.float(), atol=patol, rtol=prtol - ) + torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol) # the parameters diverge quickly. Here we keep them close # together so we can test against the Adam error p1.data = p1.data.to(gtype).float() p2.copy_(p1.data) torch.testing.assert_allclose(p1.to(gtype), p2) - for (name1, name2, qmap, max_val), s in zip( - str2statenames[optim_name], dequant_states - ): + for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): torch_optimizer.state[p1][name1].copy_(s.data) # print(sum(errors)/len(errors))