diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index e2ca978..8adca93 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -23,7 +23,7 @@ try:
         CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment!
         If you cannot find any issues and suspect a bug, please open an issue with detals about your environment:
         https://github.com/TimDettmers/bitsandbytes/issues''')
-    lib.cadam32bit_g32
+    lib.cadam_8bit_blockwise_fp32
     lib.get_context.restype = ct.c_void_p
     lib.get_cusparse.restype = ct.c_void_p
     COMPILED_WITH_CUDA = True
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index b168606..ff0eb7e 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -28,7 +28,7 @@ name2qmap = {}
 if COMPILED_WITH_CUDA:
     """C FUNCTIONS FOR OPTIMIZERS"""
     str2optimizer32bit = {}
-    str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
+    str2optimizer32bit["adam"] = (lib.cadam32bit_gfp32, lib.cadam32bit_gfp16, lib.cadam32bit_gbf16)
     str2optimizer32bit["momentum"] = (
         lib.cmomentum32bit_g32,
         lib.cmomentum32bit_g16,
@@ -41,11 +41,6 @@ if COMPILED_WITH_CUDA:
         lib.cadagrad32bit_g32,
         lib.cadagrad32bit_g16,
     )
-    str2optimizer32bit["lars"] = (
-        lib.cmomentum32bit_g32,
-        lib.cmomentum32bit_g16,
-    )
-    str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
 
     str2optimizer8bit = {}
     str2optimizer8bit["adam"] = (
@@ -998,53 +993,37 @@ def optimizer_update_32bit(
     if max_unorm > 0.0:
         param_norm = torch.norm(p.data.float())
 
-    if optimizer_name not in str2optimizer32bit:
-        raise NotImplementedError(
-            f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
-        )
-
-    if g.dtype == torch.float32 and state1.dtype == torch.float32:
-        str2optimizer32bit[optimizer_name][0](
-            get_ptr(g),
-            get_ptr(p),
-            get_ptr(state1),
-            get_ptr(state2),
-            get_ptr(unorm_vec),
-            ct.c_float(max_unorm),
-            ct.c_float(param_norm),
-            ct.c_float(beta1),
-            ct.c_float(beta2),
-            ct.c_float(eps),
-            ct.c_float(weight_decay),
-            ct.c_int32(step),
-            ct.c_float(lr),
-            ct.c_float(gnorm_scale),
-            ct.c_bool(skip_zeros),
-            ct.c_int32(g.numel()),
-        )
-    elif g.dtype == torch.float16 and state1.dtype == torch.float32:
-        str2optimizer32bit[optimizer_name][1](
-            get_ptr(g),
-            get_ptr(p),
-            get_ptr(state1),
-            get_ptr(state2),
-            get_ptr(unorm_vec),
-            ct.c_float(max_unorm),
-            ct.c_float(param_norm),
-            ct.c_float(beta1),
-            ct.c_float(beta2),
-            ct.c_float(eps),
-            ct.c_float(weight_decay),
-            ct.c_int32(step),
-            ct.c_float(lr),
-            ct.c_float(gnorm_scale),
-            ct.c_bool(skip_zeros),
-            ct.c_int32(g.numel()),
-        )
+    optim_func = None
+    if g.dtype == torch.float32:
+        optim_func = str2optimizer32bit[optimizer_name][0]
+    elif g.dtype == torch.float16:
+        optim_func = str2optimizer32bit[optimizer_name][1]
+    elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3):
+        optim_func = str2optimizer32bit[optimizer_name][2]
     else:
-        raise ValueError(
-            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
-        )
+        raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}")
+
+    is_on_gpu([g, p, state1, state2, unorm_vec])
+    prev_device = pre_call(g.device)
+    optim_func(
+        get_ptr(g),
+        get_ptr(p),
+        get_ptr(state1),
+        get_ptr(state2),
+        get_ptr(unorm_vec),
+        ct.c_float(max_unorm),
+        ct.c_float(param_norm),
+        ct.c_float(beta1),
+        ct.c_float(beta2),
+        ct.c_float(eps),
+        ct.c_float(weight_decay),
+        ct.c_int32(step),
+        ct.c_float(lr),
+        ct.c_float(gnorm_scale),
+        ct.c_bool(skip_zeros),
+        ct.c_int32(g.numel()))
+    post_call(prev_device)
 
 
 def optimizer_update_8bit(
@@ -1199,12 +1178,12 @@ def optimizer_update_8bit_blockwise(
 
     optim_func = None
     if g.dtype == torch.float32 and state1.dtype == torch.uint8:
-        optimizer_func = str2optimizer8bit_blockwise[optimizer_name][0]
+        optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
     elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
-        optimizer_func = str2optimizer8bit_blockwise[optimizer_name][1]
+        optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
     elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8
           and len(str2optimizer8bit_blockwise[optimizer_name])==3):
-        optimizer_func = str2optimizer8bit_blockwise[optimizer_name][2]
+        optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
     else:
         raise ValueError(
             f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
@@ -1213,7 +1192,7 @@ def optimizer_update_8bit_blockwise(
 
     is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
     prev_device = pre_call(g.device)
-    optimizer_func(
+    optim_func(
         get_ptr(p),
         get_ptr(g),
         get_ptr(state1),
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index ab16e01..24f5070 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -178,6 +178,13 @@ class Params4bit(torch.nn.Parameter):
             s[0] = s[0].to(device)
             if self.compress_statistics:
                 # TODO: refactor this. This is a nightmare
+                # for 4-bit:
+                # state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
+                # state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
+                #s[-2][0] = s[-2][0].to(device) # offset
+                #s[-2][1][0] = s[-2][1][0].to(device) # nested absmax
+
+                # for 8-bit
                 s[-2][0] = s[-2][0].to(device) # offset
                 s[-2][1][0] = s[-2][1][0].to(device) # nested quantiation state statitics
                 s[-2][1][1] = s[-2][1][1].to(device) # nested quantiation codebook
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index c35acc8..2d940be 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -2981,12 +2981,15 @@ template __global__ void kPreconditionOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+    const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
     const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
-template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
+template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
     const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
 
 #define MAKE_PreconditionStatic8bit1State(oname, gtype) \
diff --git a/csrc/ops.cu b/csrc/ops.cu
index de14039..76777ae 100644
--- a/csrc/ops.cu
+++ b/csrc/ops.cu
@@ -703,6 +703,7 @@ template void optimizer32bit(gtype* g, gtype* p, \
 MAKE_optimizer32bit(ADAM, half)
 MAKE_optimizer32bit(ADAM, float)
+MAKE_optimizer32bit(ADAM, __nv_bfloat16)
 MAKE_optimizer32bit(MOMENTUM, half)
 MAKE_optimizer32bit(MOMENTUM, float)
 MAKE_optimizer32bit(RMSPROP, half)
diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c
index d169178..0e9106c 100644
--- a/csrc/pythonInterface.c
+++ b/csrc/pythonInterface.c
@@ -29,8 +29,9 @@ void fname##32bit_g##gbits(gtype *g, gtype *p, \
 
 MAKE_FUNC32(momentum, MOMENTUM, float, 32)
 MAKE_FUNC32(momentum, MOMENTUM, half, 16)
-MAKE_FUNC32(adam, ADAM, float, 32)
-MAKE_FUNC32(adam, ADAM, half, 16)
+MAKE_FUNC32(adam, ADAM, float, fp32)
+MAKE_FUNC32(adam, ADAM, half, fp16)
+MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16)
 MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
 MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
 MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
@@ -173,8 +174,9 @@ extern "C"
                   const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
   { name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
 
-  MAKE_CFUNC32(adam, float, 32)
-  MAKE_CFUNC32(adam, half, 16)
+  MAKE_CFUNC32(adam, float, fp32)
+  MAKE_CFUNC32(adam, half, fp16)
+  MAKE_CFUNC32(adam, __nv_bfloat16, bf16)
   MAKE_CFUNC32(momentum, float, 32)
   MAKE_CFUNC32(momentum, half, 16)
   MAKE_CFUNC32(rmsprop, float, 32)
diff --git a/tests/test_optim.py b/tests/test_optim.py
index 83390a4..a13b332 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -44,10 +44,6 @@ str2optimizers["momentum"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
 )
-str2optimizers["lars"] = (
-    lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
-    lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
-)
 str2optimizers["rmsprop"] = (
     lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
@@ -64,10 +60,6 @@ str2optimizers["rmsprop8bit"] = (
     lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
 )
-str2optimizers["lars8bit"] = (
-    lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
-    lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9),
-)
 
 str2optimizers["adam8bit_blockwise"] = (
     torch.optim.Adam,
@@ -85,7 +77,6 @@ str2optimizers["rmsprop8bit_blockwise"] = (
 str2statenames = {}
 str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["momentum"] = [("momentum_buffer", "state1")]
-str2statenames["lars"] = [("momentum_buffer", "state1")]
 str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
 str2statenames["rmsprop"] = [("square_avg", "state1")]
 str2statenames["adam8bit"] = [
@@ -106,7 +97,6 @@ str2statenames["momentum8bit"] = [
 str2statenames["momentum8bit_blockwise"] = [
     ("momentum_buffer", "state1", "qmap1", "absmax1")
 ]
-str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
 str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
 str2statenames["rmsprop8bit_blockwise"] = [
     ("square_avg", "state1", "qmap1", "absmax1")
 ]
@@ -114,14 +104,10 @@ str2statenames["rmsprop8bit_blockwise"] = [
 
 dim1 = [1024]
 dim2 = [32, 1024, 4097, 1]
-gtype = [torch.float32, torch.float16]
-optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
+gtype = [torch.float32, torch.float16, torch.bfloat16]
+optimizer_names = ["adam", "momentum", "rmsprop"]
 values = list(product(dim1, dim2, gtype, optimizer_names))
-names = [
-    "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
-]
-
-
+names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
 def test_optimizer32bit(dim1, dim2, gtype, optim_name):
     if dim1 == 1 and dim2 == 1:
@@ -135,6 +121,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
 
     if gtype == torch.float32:
         atol, rtol = 1e-6, 1e-5
+    elif gtype == torch.bfloat16:
+        atol, rtol = 1e-3, 1e-2
     else:
         atol, rtol = 1e-4, 1e-3
 
@@ -173,14 +161,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
                 rtol=rtol,
             )
 
-        if gtype == torch.float16:
+        if gtype != torch.float32:
             # the adam buffers should also be close because they are 32-bit
             # but the paramters can diverge because they are 16-bit
             # the difference grow larger and larger with each update
             # --> copy the state to keep weights close
-            p1.data = p1.data.half().float()
+            p1.data = p1.data.to(p2.dtype).float()
             p2.copy_(p1.data)
-            torch.testing.assert_allclose(p1.half(), p2)
+            torch.testing.assert_allclose(p1.to(p2.dtype), p2)
 
         if optim_name in ["lars", "lamb"]:
             assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
@@ -246,7 +234,6 @@ optimizer_names = [
     "momentum8bit",
     "rmsprop8bit",
     "adam8bit_blockwise",
-    "lars8bit",
     "momentum8bit_blockwise",
     "rmsprop8bit_blockwise",
 ]
@@ -321,10 +308,10 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
             relerr = err / torch.abs(p1)
             if g.dtype == torch.bfloat16:
                 assert err.mean() < 0.00015
-                assert relerr.mean() < 0.0015
+                assert relerr.mean() < 0.0016
             else:
-                assert err.mean() < 0.00012
-                assert relerr.mean() < 0.0012
+                assert err.mean() < 0.00012
+                assert relerr.mean() < 0.0012
 
             errors.append(err.mean().item())
             relerrors.append(relerr.mean().item())
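
As a quick sanity check of the new bfloat16 path, the sketch below mirrors what test_optimizer32bit now exercises. It is illustrative only and not part of the patch: it assumes a CUDA-enabled bitsandbytes build that includes the cadam32bit_gbf16 entry point added above, and the tensor shape and hyperparameter values are arbitrary.

# Minimal usage sketch (assumes this patch is applied and compiled with CUDA support).
import torch
import bitsandbytes as bnb

# bf16 parameter and gradient; the 32-bit Adam state itself stays in fp32.
p = torch.nn.Parameter(torch.randn(1024, 32, dtype=torch.bfloat16, device="cuda"))
optimizer = bnb.optim.Adam([p], lr=0.01, betas=(0.9, 0.995), optim_bits=32)

p.grad = torch.randn_like(p) * 0.01
optimizer.step()       # optimizer_update_32bit picks str2optimizer32bit["adam"][2] (cadam32bit_gbf16) for bf16 grads
optimizer.zero_grad()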