Added 32-bit optimizer for bfloat16 gradients.
This commit is contained in:
parent
b8ea2b416d
commit
7dc198feb7
|
@ -23,7 +23,7 @@ try:
|
|||
CUDA Setup failed despite GPU being available. Inspect the CUDA SETUP outputs above to fix your environment!
|
||||
If you cannot find any issues and suspect a bug, please open an issue with detals about your environment:
|
||||
https://github.com/TimDettmers/bitsandbytes/issues''')
|
||||
lib.cadam32bit_g32
|
||||
lib.cadam_8bit_blockwise_fp32
|
||||
lib.get_context.restype = ct.c_void_p
|
||||
lib.get_cusparse.restype = ct.c_void_p
|
||||
COMPILED_WITH_CUDA = True
|
||||
|
|
|
@ -28,7 +28,7 @@ name2qmap = {}
|
|||
if COMPILED_WITH_CUDA:
|
||||
"""C FUNCTIONS FOR OPTIMIZERS"""
|
||||
str2optimizer32bit = {}
|
||||
str2optimizer32bit["adam"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
|
||||
str2optimizer32bit["adam"] = (lib.cadam32bit_gfp32, lib.cadam32bit_gfp16, lib.cadam32bit_gbf16)
|
||||
str2optimizer32bit["momentum"] = (
|
||||
lib.cmomentum32bit_g32,
|
||||
lib.cmomentum32bit_g16,
|
||||
|
@ -41,11 +41,6 @@ if COMPILED_WITH_CUDA:
|
|||
lib.cadagrad32bit_g32,
|
||||
lib.cadagrad32bit_g16,
|
||||
)
|
||||
str2optimizer32bit["lars"] = (
|
||||
lib.cmomentum32bit_g32,
|
||||
lib.cmomentum32bit_g16,
|
||||
)
|
||||
str2optimizer32bit["lamb"] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
|
||||
|
||||
str2optimizer8bit = {}
|
||||
str2optimizer8bit["adam"] = (
|
||||
|
@ -998,53 +993,37 @@ def optimizer_update_32bit(
|
|||
if max_unorm > 0.0:
|
||||
param_norm = torch.norm(p.data.float())
|
||||
|
||||
if optimizer_name not in str2optimizer32bit:
|
||||
raise NotImplementedError(
|
||||
f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
|
||||
)
|
||||
|
||||
if g.dtype == torch.float32 and state1.dtype == torch.float32:
|
||||
str2optimizer32bit[optimizer_name][0](
|
||||
get_ptr(g),
|
||||
get_ptr(p),
|
||||
get_ptr(state1),
|
||||
get_ptr(state2),
|
||||
get_ptr(unorm_vec),
|
||||
ct.c_float(max_unorm),
|
||||
ct.c_float(param_norm),
|
||||
ct.c_float(beta1),
|
||||
ct.c_float(beta2),
|
||||
ct.c_float(eps),
|
||||
ct.c_float(weight_decay),
|
||||
ct.c_int32(step),
|
||||
ct.c_float(lr),
|
||||
ct.c_float(gnorm_scale),
|
||||
ct.c_bool(skip_zeros),
|
||||
ct.c_int32(g.numel()),
|
||||
)
|
||||
elif g.dtype == torch.float16 and state1.dtype == torch.float32:
|
||||
str2optimizer32bit[optimizer_name][1](
|
||||
get_ptr(g),
|
||||
get_ptr(p),
|
||||
get_ptr(state1),
|
||||
get_ptr(state2),
|
||||
get_ptr(unorm_vec),
|
||||
ct.c_float(max_unorm),
|
||||
ct.c_float(param_norm),
|
||||
ct.c_float(beta1),
|
||||
ct.c_float(beta2),
|
||||
ct.c_float(eps),
|
||||
ct.c_float(weight_decay),
|
||||
ct.c_int32(step),
|
||||
ct.c_float(lr),
|
||||
ct.c_float(gnorm_scale),
|
||||
ct.c_bool(skip_zeros),
|
||||
ct.c_int32(g.numel()),
|
||||
)
|
||||
optim_func = None
|
||||
if g.dtype == torch.float32:
|
||||
optim_func = str2optimizer32bit[optimizer_name][0]
|
||||
elif g.dtype == torch.float16:
|
||||
optim_func = str2optimizer32bit[optimizer_name][1]
|
||||
elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3):
|
||||
optim_func = str2optimizer32bit[optimizer_name][2]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
|
||||
)
|
||||
raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}")
|
||||
|
||||
is_on_gpu([g, p, state1, state2, unorm_vec])
|
||||
prev_device = pre_call(g.device)
|
||||
optim_func(
|
||||
get_ptr(g),
|
||||
get_ptr(p),
|
||||
get_ptr(state1),
|
||||
get_ptr(state2),
|
||||
get_ptr(unorm_vec),
|
||||
ct.c_float(max_unorm),
|
||||
ct.c_float(param_norm),
|
||||
ct.c_float(beta1),
|
||||
ct.c_float(beta2),
|
||||
ct.c_float(eps),
|
||||
ct.c_float(weight_decay),
|
||||
ct.c_int32(step),
|
||||
ct.c_float(lr),
|
||||
ct.c_float(gnorm_scale),
|
||||
ct.c_bool(skip_zeros),
|
||||
ct.c_int32(g.numel()))
|
||||
post_call(prev_device)
|
||||
|
||||
|
||||
def optimizer_update_8bit(
|
||||
|
@ -1199,12 +1178,12 @@ def optimizer_update_8bit_blockwise(
|
|||
|
||||
optim_func = None
|
||||
if g.dtype == torch.float32 and state1.dtype == torch.uint8:
|
||||
optimizer_func = str2optimizer8bit_blockwise[optimizer_name][0]
|
||||
optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
|
||||
elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
|
||||
optimizer_func = str2optimizer8bit_blockwise[optimizer_name][1]
|
||||
optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
|
||||
elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and
|
||||
len(str2optimizer8bit_blockwise[optimizer_name])==3):
|
||||
optimizer_func = str2optimizer8bit_blockwise[optimizer_name][2]
|
||||
optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
|
||||
|
@ -1213,7 +1192,7 @@ def optimizer_update_8bit_blockwise(
|
|||
is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
|
||||
|
||||
prev_device = pre_call(g.device)
|
||||
optimizer_func(
|
||||
optim_func(
|
||||
get_ptr(p),
|
||||
get_ptr(g),
|
||||
get_ptr(state1),
|
||||
|
|
|
@ -178,6 +178,13 @@ class Params4bit(torch.nn.Parameter):
|
|||
s[0] = s[0].to(device)
|
||||
if self.compress_statistics:
|
||||
# TODO: refactor this. This is a nightmare
|
||||
# for 4-bit:
|
||||
# state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type]
|
||||
# state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type]
|
||||
#s[-2][0] = s[-2][0].to(device) # offset
|
||||
#s[-2][1][0] = s[-2][1][0].to(device) # nested absmax
|
||||
|
||||
# for 8-bit
|
||||
s[-2][0] = s[-2][0].to(device) # offset
|
||||
s[-2][1][0] = s[-2][1][0].to(device) # nested quantiation state statitics
|
||||
s[-2][1][1] = s[-2][1][1].to(device) # nested quantiation codebook
|
||||
|
|
|
@ -2981,12 +2981,15 @@ template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8
|
|||
const float beta1, const float beta2, const float eps, const float weight_decay, \
|
||||
const int step, const float lr, const float gnorm_scale, const int n); \
|
||||
|
||||
MAKE_PreconditionOptimizer32bit2State(ADAM, half)
|
||||
MAKE_PreconditionOptimizer32bit2State(ADAM, float)
|
||||
MAKE_PreconditionOptimizer32bit2State(ADAM, half)
|
||||
MAKE_PreconditionOptimizer32bit2State(ADAM, __nv_bfloat16)
|
||||
|
||||
template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
|
||||
template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
|
||||
template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
|
||||
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
|
||||
|
||||
#define MAKE_PreconditionStatic8bit1State(oname, gtype) \
|
||||
|
|
|
@ -703,6 +703,7 @@ template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
|
|||
|
||||
MAKE_optimizer32bit(ADAM, half)
|
||||
MAKE_optimizer32bit(ADAM, float)
|
||||
MAKE_optimizer32bit(ADAM, __nv_bfloat16)
|
||||
MAKE_optimizer32bit(MOMENTUM, half)
|
||||
MAKE_optimizer32bit(MOMENTUM, float)
|
||||
MAKE_optimizer32bit(RMSPROP, half)
|
||||
|
|
|
@ -29,8 +29,9 @@ void fname##32bit_g##gbits(gtype *g, gtype *p, \
|
|||
|
||||
MAKE_FUNC32(momentum, MOMENTUM, float, 32)
|
||||
MAKE_FUNC32(momentum, MOMENTUM, half, 16)
|
||||
MAKE_FUNC32(adam, ADAM, float, 32)
|
||||
MAKE_FUNC32(adam, ADAM, half, 16)
|
||||
MAKE_FUNC32(adam, ADAM, float, fp32)
|
||||
MAKE_FUNC32(adam, ADAM, half, fp16)
|
||||
MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16)
|
||||
MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
|
||||
MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
|
||||
MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
|
||||
|
@ -173,8 +174,9 @@ extern "C"
|
|||
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
|
||||
{ name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
|
||||
|
||||
MAKE_CFUNC32(adam, float, 32)
|
||||
MAKE_CFUNC32(adam, half, 16)
|
||||
MAKE_CFUNC32(adam, float, fp32)
|
||||
MAKE_CFUNC32(adam, half, fp16)
|
||||
MAKE_CFUNC32(adam, __nv_bfloat16, bf16)
|
||||
MAKE_CFUNC32(momentum, float, 32)
|
||||
MAKE_CFUNC32(momentum, half, 16)
|
||||
MAKE_CFUNC32(rmsprop, float, 32)
|
||||
|
|
|
@ -44,10 +44,6 @@ str2optimizers["momentum"] = (
|
|||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
|
||||
)
|
||||
str2optimizers["lars"] = (
|
||||
lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
|
||||
)
|
||||
str2optimizers["rmsprop"] = (
|
||||
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
|
||||
|
@ -64,10 +60,6 @@ str2optimizers["rmsprop8bit"] = (
|
|||
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
|
||||
)
|
||||
str2optimizers["lars8bit"] = (
|
||||
lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9),
|
||||
)
|
||||
|
||||
str2optimizers["adam8bit_blockwise"] = (
|
||||
torch.optim.Adam,
|
||||
|
@ -85,7 +77,6 @@ str2optimizers["rmsprop8bit_blockwise"] = (
|
|||
str2statenames = {}
|
||||
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
||||
str2statenames["momentum"] = [("momentum_buffer", "state1")]
|
||||
str2statenames["lars"] = [("momentum_buffer", "state1")]
|
||||
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
||||
str2statenames["rmsprop"] = [("square_avg", "state1")]
|
||||
str2statenames["adam8bit"] = [
|
||||
|
@ -106,7 +97,6 @@ str2statenames["momentum8bit"] = [
|
|||
str2statenames["momentum8bit_blockwise"] = [
|
||||
("momentum_buffer", "state1", "qmap1", "absmax1")
|
||||
]
|
||||
str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
|
||||
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
|
||||
str2statenames["rmsprop8bit_blockwise"] = [
|
||||
("square_avg", "state1", "qmap1", "absmax1")
|
||||
|
@ -114,14 +104,10 @@ str2statenames["rmsprop8bit_blockwise"] = [
|
|||
|
||||
dim1 = [1024]
|
||||
dim2 = [32, 1024, 4097, 1]
|
||||
gtype = [torch.float32, torch.float16]
|
||||
optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
|
||||
gtype = [torch.float32, torch.float16, torch.bfloat16]
|
||||
optimizer_names = ["adam", "momentum", "rmsprop"]
|
||||
values = list(product(dim1, dim2, gtype, optimizer_names))
|
||||
names = [
|
||||
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
|
||||
]
|
||||
|
||||
|
||||
names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
|
||||
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
|
||||
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
|
||||
if dim1 == 1 and dim2 == 1:
|
||||
|
@ -135,6 +121,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
|
|||
|
||||
if gtype == torch.float32:
|
||||
atol, rtol = 1e-6, 1e-5
|
||||
elif gtype == torch.bfloat16:
|
||||
atol, rtol = 1e-3, 1e-2
|
||||
else:
|
||||
atol, rtol = 1e-4, 1e-3
|
||||
|
||||
|
@ -173,14 +161,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
|
|||
rtol=rtol,
|
||||
)
|
||||
|
||||
if gtype == torch.float16:
|
||||
if gtype != torch.float32:
|
||||
# the adam buffers should also be close because they are 32-bit
|
||||
# but the paramters can diverge because they are 16-bit
|
||||
# the difference grow larger and larger with each update
|
||||
# --> copy the state to keep weights close
|
||||
p1.data = p1.data.half().float()
|
||||
p1.data = p1.data.to(p2.dtype).float()
|
||||
p2.copy_(p1.data)
|
||||
torch.testing.assert_allclose(p1.half(), p2)
|
||||
torch.testing.assert_allclose(p1.to(p2.dtype), p2)
|
||||
if optim_name in ["lars", "lamb"]:
|
||||
assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
|
||||
|
||||
|
@ -246,7 +234,6 @@ optimizer_names = [
|
|||
"momentum8bit",
|
||||
"rmsprop8bit",
|
||||
"adam8bit_blockwise",
|
||||
"lars8bit",
|
||||
"momentum8bit_blockwise",
|
||||
"rmsprop8bit_blockwise",
|
||||
]
|
||||
|
@ -321,10 +308,10 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
|
|||
relerr = err / torch.abs(p1)
|
||||
if g.dtype == torch.bfloat16:
|
||||
assert err.mean() < 0.00015
|
||||
assert relerr.mean() < 0.0015
|
||||
assert relerr.mean() < 0.0016
|
||||
else:
|
||||
assert err.mean() < 0.0001
|
||||
assert relerr.mean() < 0.001
|
||||
assert err.mean() < 0.00012
|
||||
assert relerr.mean() < 0.0012
|
||||
|
||||
errors.append(err.mean().item())
|
||||
relerrors.append(relerr.mean().item())
|
||||
|
|
Loading…
Reference in New Issue
Block a user