Added bf16 Adam.
This commit is contained in:
parent 8645d1f71c · commit c4cfe4fbdd
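In user terms, the commit lets 8-bit block-wise Adam run directly on bfloat16 parameters and gradients. A minimal usage sketch (it assumes a CUDA build of bitsandbytes that includes these kernels; the model, sizes, and hyperparameters below are illustrative, not part of the commit):

import torch
import bitsandbytes as bnb

# Hypothetical toy model; any bf16 CUDA module works the same way.
model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)  # block-wise 8-bit optimizer states

x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
model(x).sum().backward()
opt.step()  # for bf16 gradients this now dispatches to cadam_8bit_blockwise_bf16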
@@ -12,6 +12,7 @@ CUDA_VERSION:=
 endif

 NVCC := $(CUDA_HOME)/bin/nvcc

 ###########################################
@@ -59,9 +60,9 @@ CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
 CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90

-all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
-$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
-$(NVCC) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
+all: $(BUILD_DIR) env
+$(NVCC) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
+$(NVCC) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
 $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)

 cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
@@ -73,6 +73,7 @@ if COMPILED_WITH_CUDA:
     str2optimizer8bit_blockwise["adam"] = (
         lib.cadam_8bit_blockwise_fp32,
         lib.cadam_8bit_blockwise_fp16,
+        lib.cadam_8bit_blockwise_bf16,
     )
     str2optimizer8bit_blockwise["momentum"] = (
         lib.cmomentum_8bit_blockwise_fp32,
@@ -1125,51 +1126,42 @@ def optimizer_update_8bit_blockwise(
     skip_zeros=False,
 ) -> None:

+    optimizer_func = None
     if g.dtype == torch.float32 and state1.dtype == torch.uint8:
-        str2optimizer8bit_blockwise[optimizer_name][0](
-            get_ptr(p),
-            get_ptr(g),
-            get_ptr(state1),
-            get_ptr(state2),
-            ct.c_float(beta1),
-            ct.c_float(beta2),
-            ct.c_float(eps),
-            ct.c_int32(step),
-            ct.c_float(lr),
-            get_ptr(qmap1),
-            get_ptr(qmap2),
-            get_ptr(absmax1),
-            get_ptr(absmax2),
-            ct.c_float(weight_decay),
-            ct.c_float(gnorm_scale),
-            ct.c_bool(skip_zeros),
-            ct.c_int32(g.numel()),
-        )
+        optimizer_func = str2optimizer8bit_blockwise[optimizer_name][0]
     elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
-        str2optimizer8bit_blockwise[optimizer_name][1](
-            get_ptr(p),
-            get_ptr(g),
-            get_ptr(state1),
-            get_ptr(state2),
-            ct.c_float(beta1),
-            ct.c_float(beta2),
-            ct.c_float(eps),
-            ct.c_int32(step),
-            ct.c_float(lr),
-            get_ptr(qmap1),
-            get_ptr(qmap2),
-            get_ptr(absmax1),
-            get_ptr(absmax2),
-            ct.c_float(weight_decay),
-            ct.c_float(gnorm_scale),
-            ct.c_bool(skip_zeros),
-            ct.c_int32(g.numel()),
-        )
+        optimizer_func = str2optimizer8bit_blockwise[optimizer_name][1]
+    elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and
+          len(str2optimizer8bit_blockwise[optimizer_name]) == 3):
+        optimizer_func = str2optimizer8bit_blockwise[optimizer_name][2]
     else:
         raise ValueError(
             f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
         )
+
+    is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
+
+    prev_device = pre_call(g.device)
+    optimizer_func(
+        get_ptr(p),
+        get_ptr(g),
+        get_ptr(state1),
+        get_ptr(state2),
+        ct.c_float(beta1),
+        ct.c_float(beta2),
+        ct.c_float(eps),
+        ct.c_int32(step),
+        ct.c_float(lr),
+        get_ptr(qmap1),
+        get_ptr(qmap2),
+        get_ptr(absmax1),
+        get_ptr(absmax2),
+        ct.c_float(weight_decay),
+        ct.c_float(gnorm_scale),
+        ct.c_bool(skip_zeros),
+        ct.c_int32(g.numel()),
+    )
+    post_call(prev_device)


 def percentile_clipping(
     grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5
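The refactor above replaces two duplicated ctypes call sites with a single dispatch on the gradient dtype; bf16 is accepted only when the registry entry carries a third kernel, so every other optimizer still raises for bf16 gradients. A standalone sketch of that rule (the registry below uses placeholder strings rather than the real ctypes bindings):

import torch

# Placeholder mirror of str2optimizer8bit_blockwise: a 2-tuple means (fp32, fp16)
# kernels only; a 3-tuple also carries a bf16 kernel (only Adam after this commit).
registry = {
    "adam": ("adam_fp32", "adam_fp16", "adam_bf16"),
    "momentum": ("momentum_fp32", "momentum_fp16"),
}

def pick_kernel(optimizer_name, grad_dtype):
    fns = registry[optimizer_name]
    if grad_dtype == torch.float32:
        return fns[0]
    if grad_dtype == torch.float16:
        return fns[1]
    if grad_dtype == torch.bfloat16 and len(fns) == 3:
        return fns[2]
    raise ValueError(f"Gradient dtype {grad_dtype} not supported for {optimizer_name}")

print(pick_kernel("adam", torch.bfloat16))   # adam_bf16
# pick_kernel("momentum", torch.bfloat16)    # would raise ValueError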
@@ -2988,6 +2988,8 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block

 MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8)
 MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8)
+MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 2048, 8)

 #define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
 template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>( \
@@ -741,3 +741,5 @@ MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);

 template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
 template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
+
+MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM);
@@ -57,19 +57,20 @@ MAKE_FUNC8(rmsprop, RMSPROP, float, 32)
 MAKE_FUNC8(rmsprop, RMSPROP, half, 16)

 #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
-void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
+void fname##_8bit_blockwise_##gbits(gtype* p, gtype* g, \
 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
 float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
 { optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\

-MAKE_BLOCKWISE8(adam, ADAM, half, 16)
-MAKE_BLOCKWISE8(adam, ADAM, float, 32)
-MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
-MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32)
-MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16)
-MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32)
-MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16)
-MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)
+MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
+MAKE_BLOCKWISE8(adam, ADAM, float, fp32)
+MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16)
+MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32)
+MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16)
+MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32)
+MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16)
+MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32)
+MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)

 void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
@@ -194,20 +195,20 @@ extern "C"
 MAKE_CFUNC8(rmsprop, half, 16)

 #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
-void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
+void c##fname##_8bit_blockwise_##gbits(gtype* p, gtype* g, \
 unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
 float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
-{ fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
+{ fname##_8bit_blockwise_##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \

-MAKE_CBLOCKWISE8(adam, ADAM, half, 16)
-MAKE_CBLOCKWISE8(adam, ADAM, float, 32)
-MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, 16)
-MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
-MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
-MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
-MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16)
-MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32)
+MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
+MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
+MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16)
+MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32)
+MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16)
+MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32)
+MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16)
+MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32)
+MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)

 void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
 void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
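Because the suffix template changes from fp##gbits to ##gbits, the fp32/fp16 entry points keep their old names and only the bf16 symbols are new. A quick way to check the exports from Python (the library filename and CUDA suffix below are placeholders for whatever the Makefile actually produced):

import ctypes

lib = ctypes.cdll.LoadLibrary("./bitsandbytes/libbitsandbytes_cuda117.so")  # placeholder path

for sym in ("cadam_8bit_blockwise_fp32",
            "cadam_8bit_blockwise_fp16",
            "cadam_8bit_blockwise_bf16"):   # the bf16 export is new in this commit
    getattr(lib, sym)  # raises AttributeError if the symbol is missing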
@@ -26,6 +26,8 @@ def get_temp_dir():
 def rm_path(path):
     shutil.rmtree(path)

+str2bf16support = {}
+str2bf16support['adam8bit_blockwise'] = True

 str2optimizers = {}
 str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
@@ -238,7 +240,7 @@ def test_global_config(dim1, dim2, gtype):

 dim1 = [1024]
 dim2 = [32, 1024, 4097]
-gtype = [torch.float32, torch.float16]
+gtype = [torch.float32, torch.float16, torch.bfloat16]
 optimizer_names = [
     "adam8bit",
     "momentum8bit",
@@ -256,6 +258,7 @@ names = [

 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
 def test_optimizer8bit(dim1, dim2, gtype, optim_name):
+    if gtype == torch.bfloat16 and optim_name not in str2bf16support: return
     if dim1 == 1 and dim2 == 1:
         return
     p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
@@ -269,7 +272,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
     if gtype == torch.float32:
         atol, rtol = 3e-3, 1e-3
         patol, prtol = 1e-5, 1e-3
-
+    elif gtype == torch.bfloat16:
+        atol, rtol = 3e-3, 1e-3
+        patol, prtol = 1e-4, 1e-2
     else:
         atol, rtol = 3e-3, 1e-3
         patol, prtol = 1e-5, 1e-3
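The looser parameter tolerances for bf16 follow from its precision: bfloat16 keeps float32's exponent range but stores only 7 mantissa bits, so its machine epsilon is roughly eight times that of float16. A quick check in plain PyTorch, independent of this commit:

import torch

print(torch.finfo(torch.bfloat16).eps)  # 0.0078125     (2**-7)
print(torch.finfo(torch.float16).eps)   # 0.0009765625  (2**-10)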
@@ -314,8 +319,12 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):

         err = torch.abs(p1 - p2)
         relerr = err / torch.abs(p1)
-        assert err.mean() < 0.0001
-        assert relerr.mean() < 0.001
+        if g.dtype == torch.bfloat16:
+            assert err.mean() < 0.00015
+            assert relerr.mean() < 0.0015
+        else:
+            assert err.mean() < 0.0001
+            assert relerr.mean() < 0.001

         errors.append(err.mean().item())
         relerrors.append(relerr.mean().item())
@@ -335,12 +344,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
             bnb_optimizer = str2optimizers[optim_name][1]([p2])
             bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
             rm_path(path)
-            torch.testing.assert_allclose(
-                raws1cpy, bnb_optimizer.state[p2][name2]
-            )
-            torch.testing.assert_allclose(
-                qmap1, bnb_optimizer.state[p2][qmap]
-            )
+            torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
+            torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])

             if "blockwise" in optim_name:
                 s1 = F.dequantize_blockwise(
@@ -357,28 +362,16 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
                )
                torch.testing.assert_allclose(s1cpy, s1)

-               num_not_close = (
-                   torch.isclose(
-                       torch_optimizer.state[p1][name1],
-                       s1,
-                       atol=atol,
-                       rtol=rtol,
-                   )
-                   == 0
-               )
+               num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0)
                assert num_not_close.sum().item() < 20
-           torch.testing.assert_allclose(
-               p1, p2.float(), atol=patol, rtol=prtol
-           )
+           torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)

            # the parameters diverge quickly. Here we keep them close
            # together so we can test against the Adam error
            p1.data = p1.data.to(gtype).float()
            p2.copy_(p1.data)
            torch.testing.assert_allclose(p1.to(gtype), p2)
-           for (name1, name2, qmap, max_val), s in zip(
-               str2statenames[optim_name], dequant_states
-           ):
+           for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
                torch_optimizer.state[p1][name1].copy_(s.data)

    # print(sum(errors)/len(errors))