Added 32-bit optimizer for bfloat16 gradients.

This commit is contained in:
Tim Dettmers 2023-04-17 18:01:49 -07:00
parent b8ea2b416d
commit 7dc198feb7
7 changed files with 65 additions and 86 deletions

View File

@ -23,7 +23,7 @@ try:
CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment! CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment!
If you cannot find any issues and suspect a bug, please open an issue with detals about your environment: If you cannot find any issues and suspect a bug, please open an issue with detals about your environment:
https://github.com/TimDettmers/bitsandbytes/issues''') https://github.com/TimDettmers/bitsandbytes/issues''')
lib.cadam32bit_g32 lib.cadam_8bit_blockwise_fp32
lib.get_context.restype = ct.c_void_p lib.get_context.restype = ct.c_void_p
lib.get_cusparse.restype = ct.c_void_p lib.get_cusparse.restype = ct.c_void_p
COMPILED_WITH_CUDA = True COMPILED_WITH_CUDA = True

View File

@ -28,7 +28,7 @@ name2qmap = {}
if COMPILED_WITH_CUDA: if COMPILED_WITH_CUDA:
"""C FUNCTIONS FOR OPTIMIZERS""" """C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {} str2optimizer32bit = {}
str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16) str2optimizer32bit["adam"] = (lib.cadam32bit_gfp32, lib.cadam32bit_gfp16, lib.cadam32bit_gbf16)
str2optimizer32bit["momentum"] = ( str2optimizer32bit["momentum"] = (
lib.cmomentum32bit_g32, lib.cmomentum32bit_g32,
lib.cmomentum32bit_g16, lib.cmomentum32bit_g16,
@ -41,11 +41,6 @@ if COMPILED_WITH_CUDA:
lib.cadagrad32bit_g32, lib.cadagrad32bit_g32,
lib.cadagrad32bit_g16, lib.cadagrad32bit_g16,
) )
str2optimizer32bit["lars"] = (
lib.cmomentum32bit_g32,
lib.cmomentum32bit_g16,
)
str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
str2optimizer8bit = {} str2optimizer8bit = {}
str2optimizer8bit["adam"] = ( str2optimizer8bit["adam"] = (
@ -998,53 +993,37 @@ def optimizer_update_32bit(
if max_unorm > 0.0: if max_unorm > 0.0:
param_norm = torch.norm(p.data.float()) param_norm = torch.norm(p.data.float())
if optimizer_name not in str2optimizer32bit:
raise NotImplementedError(
f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
)
if g.dtype == torch.float32 and state1.dtype == torch.float32: optim_func = None
str2optimizer32bit[optimizer_name][0]( if g.dtype == torch.float32:
get_ptr(g), optim_func = str2optimizer32bit[optimizer_name][0]
get_ptr(p), elif g.dtype == torch.float16:
get_ptr(state1), optim_func = str2optimizer32bit[optimizer_name][1]
get_ptr(state2), elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3):
get_ptr(unorm_vec), optim_func = str2optimizer32bit[optimizer_name][2]
ct.c_float(max_unorm),
ct.c_float(param_norm),
ct.c_float(beta1),
ct.c_float(beta2),
ct.c_float(eps),
ct.c_float(weight_decay),
ct.c_int32(step),
ct.c_float(lr),
ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros),
ct.c_int32(g.numel()),
)
elif g.dtype == torch.float16 and state1.dtype == torch.float32:
str2optimizer32bit[optimizer_name][1](
get_ptr(g),
get_ptr(p),
get_ptr(state1),
get_ptr(state2),
get_ptr(unorm_vec),
ct.c_float(max_unorm),
ct.c_float(param_norm),
ct.c_float(beta1),
ct.c_float(beta2),
ct.c_float(eps),
ct.c_float(weight_decay),
ct.c_int32(step),
ct.c_float(lr),
ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros),
ct.c_int32(g.numel()),
)
else: else:
raise ValueError( raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}")
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
) is_on_gpu([g, p, state1, state2, unorm_vec])
prev_device = pre_call(g.device)
optim_func(
get_ptr(g),
get_ptr(p),
get_ptr(state1),
get_ptr(state2),
get_ptr(unorm_vec),
ct.c_float(max_unorm),
ct.c_float(param_norm),
ct.c_float(beta1),
ct.c_float(beta2),
ct.c_float(eps),
ct.c_float(weight_decay),
ct.c_int32(step),
ct.c_float(lr),
ct.c_float(gnorm_scale),
ct.c_bool(skip_zeros),
ct.c_int32(g.numel()))
post_call(prev_device)
def optimizer_update_8bit( def optimizer_update_8bit(
@ -1199,12 +1178,12 @@ def optimizer_update_8bit_blockwise(
optim_func = None optim_func = None
if g.dtype == torch.float32 and state1.dtype == torch.uint8: if g.dtype == torch.float32 and state1.dtype == torch.uint8:
optimizer_func = str2optimizer8bit_blockwise[optimizer_name][0] optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
elif g.dtype == torch.float16 and state1.dtype == torch.uint8: elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
optimizer_func = str2optimizer8bit_blockwise[optimizer_name][1] optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and
len(str2optimizer8bit_blockwise[optimizer_name])==3): len(str2optimizer8bit_blockwise[optimizer_name])==3):
optimizer_func = str2optimizer8bit_blockwise[optimizer_name][2] optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
else: else:
raise ValueError( raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
@ -1213,7 +1192,7 @@ def optimizer_update_8bit_blockwise(
is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
prev_device = pre_call(g.device) prev_device = pre_call(g.device)
optimizer_func( optim_func(
get_ptr(p), get_ptr(p),
get_ptr(g), get_ptr(g),
get_ptr(state1), get_ptr(state1),

View File

@ -178,6 +178,13 @@ class Params4bit(torch.nn.Parameter):
s[0] = s[0].to(device) s[0] = s[0].to(device)
if self.compress_statistics: if self.compress_statistics:
# TODO: refactor this. This is a nightmare # TODO: refactor this. This is a nightmare
# for 4-bit:
# state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
# state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
#s[-2][0] = s[-2][0].to(device) # offset
#s[-2][1][0] = s[-2][1][0].to(device) # nested absmax
# for 8-bit
s[-2][0] = s[-2][0].to(device) # offset s[-2][0] = s[-2][0].to(device) # offset
s[-2][1][0] = s[-2][1][0].to(device) # nested quantiation state statitics s[-2][1][0] = s[-2][1][0].to(device) # nested quantiation state statitics
s[-2][1][1] = s[-2][1][1].to(device) # nested quantiation codebook s[-2][1][1] = s[-2][1][1].to(device) # nested quantiation codebook

View File

@ -2981,12 +2981,15 @@ template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8
const float beta1, const float beta2, const float eps, const float weight_decay, \ const float beta1, const float beta2, const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const int n); \ const int step, const float lr, const float gnorm_scale, const int n); \
MAKE_PreconditionOptimizer32bit2State(ADAM, half)
MAKE_PreconditionOptimizer32bit2State(ADAM, float) MAKE_PreconditionOptimizer32bit2State(ADAM, float)
MAKE_PreconditionOptimizer32bit2State(ADAM, half)
MAKE_PreconditionOptimizer32bit2State(ADAM, __nv_bfloat16)
template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
#define MAKE_PreconditionStatic8bit1State(oname, gtype) \ #define MAKE_PreconditionStatic8bit1State(oname, gtype) \

View File

@ -703,6 +703,7 @@ template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
MAKE_optimizer32bit(ADAM, half) MAKE_optimizer32bit(ADAM, half)
MAKE_optimizer32bit(ADAM, float) MAKE_optimizer32bit(ADAM, float)
MAKE_optimizer32bit(ADAM, __nv_bfloat16)
MAKE_optimizer32bit(MOMENTUM, half) MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, half)

View File

@ -29,8 +29,9 @@ void fname##32bit_g##gbits(gtype *g, gtype *p, \
MAKE_FUNC32(momentum, MOMENTUM, float, 32) MAKE_FUNC32(momentum, MOMENTUM, float, 32)
MAKE_FUNC32(momentum, MOMENTUM, half, 16) MAKE_FUNC32(momentum, MOMENTUM, half, 16)
MAKE_FUNC32(adam, ADAM, float, 32) MAKE_FUNC32(adam, ADAM, float, fp32)
MAKE_FUNC32(adam, ADAM, half, 16) MAKE_FUNC32(adam, ADAM, half, fp16)
MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16)
MAKE_FUNC32(rmsprop, RMSPROP, float, 32) MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
MAKE_FUNC32(rmsprop, RMSPROP, half, 16) MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
MAKE_FUNC32(adagrad, ADAGRAD, float, 32) MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
@ -173,8 +174,9 @@ extern "C"
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \ const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
{ name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ { name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
MAKE_CFUNC32(adam, float, 32) MAKE_CFUNC32(adam, float, fp32)
MAKE_CFUNC32(adam, half, 16) MAKE_CFUNC32(adam, half, fp16)
MAKE_CFUNC32(adam, __nv_bfloat16, bf16)
MAKE_CFUNC32(momentum, float, 32) MAKE_CFUNC32(momentum, float, 32)
MAKE_CFUNC32(momentum, half, 16) MAKE_CFUNC32(momentum, half, 16)
MAKE_CFUNC32(rmsprop, float, 32) MAKE_CFUNC32(rmsprop, float, 32)

View File

@ -44,10 +44,6 @@ str2optimizers["momentum"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
) )
str2optimizers["lars"] = (
lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
)
str2optimizers["rmsprop"] = ( str2optimizers["rmsprop"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
@ -64,10 +60,6 @@ str2optimizers["rmsprop8bit"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
) )
str2optimizers["lars8bit"] = (
lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9),
)
str2optimizers["adam8bit_blockwise"] = ( str2optimizers["adam8bit_blockwise"] = (
torch.optim.Adam, torch.optim.Adam,
@ -85,7 +77,6 @@ str2optimizers["rmsprop8bit_blockwise"] = (
str2statenames = {} str2statenames = {}
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["momentum"] = [("momentum_buffer", "state1")] str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lars"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")] str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [ str2statenames["adam8bit"] = [
@ -106,7 +97,6 @@ str2statenames["momentum8bit"] = [
str2statenames["momentum8bit_blockwise"] = [ str2statenames["momentum8bit_blockwise"] = [
("momentum_buffer", "state1", "qmap1", "absmax1") ("momentum_buffer", "state1", "qmap1", "absmax1")
] ]
str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")] str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [ str2statenames["rmsprop8bit_blockwise"] = [
("square_avg", "state1", "qmap1", "absmax1") ("square_avg", "state1", "qmap1", "absmax1")
@ -114,14 +104,10 @@ str2statenames["rmsprop8bit_blockwise"] = [
dim1 = [1024] dim1 = [1024]
dim2 = [32, 1024, 4097, 1] dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16] gtype = [torch.float32, torch.float16, torch.bfloat16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars"] optimizer_names = ["adam", "momentum", "rmsprop"]
values = list(product(dim1, dim2, gtype, optimizer_names)) values = list(product(dim1, dim2, gtype, optimizer_names))
names = [ names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name): def test_optimizer32bit(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1: if dim1 == 1 and dim2 == 1:
@ -135,6 +121,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
if gtype == torch.float32: if gtype == torch.float32:
atol, rtol = 1e-6, 1e-5 atol, rtol = 1e-6, 1e-5
elif gtype == torch.bfloat16:
atol, rtol = 1e-3, 1e-2
else: else:
atol, rtol = 1e-4, 1e-3 atol, rtol = 1e-4, 1e-3
@ -173,14 +161,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
rtol=rtol, rtol=rtol,
) )
if gtype == torch.float16: if gtype != torch.float32:
# the adam buffers should also be close because they are 32-bit # the adam buffers should also be close because they are 32-bit
# but the paramters can diverge because they are 16-bit # but the paramters can diverge because they are 16-bit
# the difference grow larger and larger with each update # the difference grow larger and larger with each update
# --> copy the state to keep weights close # --> copy the state to keep weights close
p1.data = p1.data.half().float() p1.data = p1.data.to(p2.dtype).float()
p2.copy_(p1.data) p2.copy_(p1.data)
torch.testing.assert_allclose(p1.half(), p2) torch.testing.assert_allclose(p1.to(p2.dtype), p2)
if optim_name in ["lars", "lamb"]: if optim_name in ["lars", "lamb"]:
assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0 assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
@ -246,7 +234,6 @@ optimizer_names = [
"momentum8bit", "momentum8bit",
"rmsprop8bit", "rmsprop8bit",
"adam8bit_blockwise", "adam8bit_blockwise",
"lars8bit",
"momentum8bit_blockwise", "momentum8bit_blockwise",
"rmsprop8bit_blockwise", "rmsprop8bit_blockwise",
] ]
@ -321,10 +308,10 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
relerr = err / torch.abs(p1) relerr = err / torch.abs(p1)
if g.dtype == torch.bfloat16: if g.dtype == torch.bfloat16:
assert err.mean() < 0.00015 assert err.mean() < 0.00015
assert relerr.mean() < 0.0015 assert relerr.mean() < 0.0016
else: else:
assert err.mean() < 0.0001 assert err.mean() < 0.00012
assert relerr.mean() < 0.001 assert relerr.mean() < 0.0012
errors.append(err.mean().item()) errors.append(err.mean().item())
relerrors.append(relerr.mean().item()) relerrors.append(relerr.mean().item())