Added PagedLion and bf16 Lion.
This commit is contained in:
parent
2bce175d15
commit
1b8772a8f3
|
@ -37,10 +37,7 @@ if COMPILED_WITH_CUDA:
|
|||
lib.crmsprop32bit_grad_32,
|
||||
lib.crmsprop32bit_grad_16,
|
||||
)
|
||||
str2optimizer32bit["lion"] = (
|
||||
lib.clion32bit_grad_32,
|
||||
lib.clion32bit_grad_16,
|
||||
)
|
||||
str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16)
|
||||
str2optimizer32bit["adagrad"] = (
|
||||
lib.cadagrad32bit_grad_32,
|
||||
lib.cadagrad32bit_grad_16,
|
||||
|
@ -89,6 +86,7 @@ if COMPILED_WITH_CUDA:
|
|||
str2optimizer8bit_blockwise["lion"] = (
|
||||
lib.clion_8bit_blockwise_grad_fp32,
|
||||
lib.clion_8bit_blockwise_grad_fp16,
|
||||
lib.clion_8bit_blockwise_grad_bf16,
|
||||
)
|
||||
str2optimizer8bit_blockwise["adagrad"] = (
|
||||
lib.cadagrad_8bit_blockwise_grad_fp32,
|
||||
|
|
|
@ -12,5 +12,5 @@ from .lamb import LAMB, LAMB8bit, LAMB32bit
|
|||
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
|
||||
from .optimizer import GlobalOptimManager
|
||||
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
|
||||
from .lion import Lion, Lion8bit, Lion32bit
|
||||
from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit
|
||||
from .sgd import SGD, SGD8bit, SGD32bit
|
||||
|
|
|
@ -4,84 +4,27 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
from bitsandbytes.optim.optimizer import Optimizer1State
|
||||
|
||||
|
||||
class Lion(Optimizer1State):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-4,
|
||||
betas=(0.9, 0.99),
|
||||
weight_decay=0,
|
||||
optim_bits=32,
|
||||
args=None,
|
||||
min_8bit_size=4096,
|
||||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
0.,
|
||||
weight_decay,
|
||||
optim_bits,
|
||||
args,
|
||||
min_8bit_size,
|
||||
percentile_clipping,
|
||||
block_wise,
|
||||
)
|
||||
|
||||
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
|
||||
super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
|
||||
|
||||
class Lion8bit(Optimizer1State):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-4,
|
||||
betas=(0.9, 0.99),
|
||||
weight_decay=0,
|
||||
args=None,
|
||||
min_8bit_size=4096,
|
||||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
0.,
|
||||
weight_decay,
|
||||
8,
|
||||
args,
|
||||
min_8bit_size,
|
||||
percentile_clipping,
|
||||
block_wise,
|
||||
)
|
||||
|
||||
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
|
||||
super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
|
||||
|
||||
class Lion32bit(Optimizer1State):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=1e-4,
|
||||
betas=(0.9, 0.99),
|
||||
weight_decay=0,
|
||||
args=None,
|
||||
min_8bit_size=4096,
|
||||
percentile_clipping=100,
|
||||
block_wise=True,
|
||||
):
|
||||
super().__init__(
|
||||
"lion",
|
||||
params,
|
||||
lr,
|
||||
betas,
|
||||
0.,
|
||||
weight_decay,
|
||||
32,
|
||||
args,
|
||||
min_8bit_size,
|
||||
percentile_clipping,
|
||||
block_wise,
|
||||
)
|
||||
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
|
||||
super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
|
||||
|
||||
|
||||
class PagedLion(Optimizer1State):
|
||||
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||
super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
|
||||
|
||||
class PagedLion8bit(Optimizer1State):
|
||||
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||
super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
|
||||
|
||||
class PagedLion32bit(Optimizer1State):
|
||||
def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||
super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
|
||||
|
|
|
@ -3666,6 +3666,7 @@ MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
|
|||
MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
|
||||
MAKE_PreconditionOptimizer32bit1State(LION, half)
|
||||
MAKE_PreconditionOptimizer32bit1State(LION, float)
|
||||
MAKE_PreconditionOptimizer32bit1State(LION, __nv_bfloat16)
|
||||
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
|
||||
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
|
||||
|
||||
|
@ -3679,6 +3680,7 @@ MAKE_Optimizer32bit1State(RMSPROP, half)
|
|||
MAKE_Optimizer32bit1State(RMSPROP, float)
|
||||
MAKE_Optimizer32bit1State(LION, half)
|
||||
MAKE_Optimizer32bit1State(LION, float)
|
||||
MAKE_Optimizer32bit1State(LION, __nv_bfloat16)
|
||||
MAKE_Optimizer32bit1State(ADAGRAD, half)
|
||||
MAKE_Optimizer32bit1State(ADAGRAD, float)
|
||||
|
||||
|
@ -3852,5 +3854,6 @@ MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
|
|||
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
|
||||
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
|
||||
|
|
|
@ -802,6 +802,7 @@ MAKE_optimizer32bit(RMSPROP, half)
|
|||
MAKE_optimizer32bit(RMSPROP, float)
|
||||
MAKE_optimizer32bit(LION, half)
|
||||
MAKE_optimizer32bit(LION, float)
|
||||
MAKE_optimizer32bit(LION, __nv_bfloat16)
|
||||
MAKE_optimizer32bit(ADAGRAD, half)
|
||||
MAKE_optimizer32bit(ADAGRAD, float)
|
||||
|
||||
|
@ -837,6 +838,7 @@ MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
|
|||
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
|
||||
MAKE_optimizerStatic8bitBlockwise(half, LION);
|
||||
MAKE_optimizerStatic8bitBlockwise(float, LION);
|
||||
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION);
|
||||
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
|
||||
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
|
||||
|
||||
|
|
|
@ -51,8 +51,9 @@ MAKE_FUNC32(adam, ADAM, half, fp16)
|
|||
MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16)
|
||||
MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
|
||||
MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
|
||||
MAKE_FUNC32(lion, LION, float, 32)
|
||||
MAKE_FUNC32(lion, LION, half, 16)
|
||||
MAKE_FUNC32(lion, LION, float, fp32)
|
||||
MAKE_FUNC32(lion, LION, half, fp16)
|
||||
MAKE_FUNC32(lion, LION, __nv_bfloat16, bf16)
|
||||
MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
|
||||
MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
|
||||
|
||||
|
@ -95,6 +96,7 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32)
|
|||
MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
|
||||
MAKE_BLOCKWISE8(lion, LION, half, fp16)
|
||||
MAKE_BLOCKWISE8(lion, LION, float, fp32)
|
||||
MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
|
||||
|
||||
|
||||
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
|
||||
|
@ -201,8 +203,9 @@ extern "C"
|
|||
MAKE_CFUNC32(momentum, half, 16)
|
||||
MAKE_CFUNC32(rmsprop, float, 32)
|
||||
MAKE_CFUNC32(rmsprop, half, 16)
|
||||
MAKE_CFUNC32(lion, float, 32)
|
||||
MAKE_CFUNC32(lion, half, 16)
|
||||
MAKE_CFUNC32(lion, float, fp32)
|
||||
MAKE_CFUNC32(lion, half, fp16)
|
||||
MAKE_CFUNC32(lion, __nv_bfloat16, bf16)
|
||||
MAKE_CFUNC32(adagrad, float, 32)
|
||||
MAKE_CFUNC32(adagrad, half, 16)
|
||||
|
||||
|
@ -245,6 +248,7 @@ extern "C"
|
|||
MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
|
||||
MAKE_CBLOCKWISE8(lion, LION, half, fp16)
|
||||
MAKE_CBLOCKWISE8(lion, LION, float, fp32)
|
||||
MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
|
||||
|
||||
void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
|
||||
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
|
||||
|
|
|
@ -19,11 +19,11 @@ import bitsandbytes.functional as F
|
|||
k = 20
|
||||
|
||||
def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
|
||||
idx = torch.isclose(a, b, rtol, atol)
|
||||
idx = torch.isclose(a, b, rtol=rtol, atol=atol)
|
||||
error_count = (idx == 0).sum().item()
|
||||
if error_count > max_error_count:
|
||||
print(f"Too many values not close: assert {error_count} < {max_error_count}")
|
||||
torch.testing.assert_close(a, b, rtol, atol)
|
||||
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def get_temp_dir():
|
||||
|
@ -35,13 +35,8 @@ def get_temp_dir():
|
|||
def rm_path(path):
|
||||
shutil.rmtree(path)
|
||||
|
||||
str2bf16support = {}
|
||||
str2bf16support['adam8bit_blockwise'] = True
|
||||
|
||||
str2optimizers = {}
|
||||
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
|
||||
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
|
||||
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
|
||||
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
|
||||
str2optimizers["momentum_pytorch"] = (
|
||||
None,
|
||||
|
@ -51,8 +46,8 @@ str2optimizers["momentum_pytorch"] = (
|
|||
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
|
||||
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
|
||||
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
|
||||
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
|
||||
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
|
||||
str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
|
||||
str2optimizers["momentum"] = (
|
||||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
|
||||
|
@ -76,6 +71,7 @@ str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.
|
|||
str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
|
||||
str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
|
||||
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
|
||||
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
|
||||
str2optimizers["momentum8bit_blockwise"] = (
|
||||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
|
||||
|
@ -90,6 +86,7 @@ str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
|||
str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
||||
str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
||||
str2statenames["lion"] = [("exp_avg", "state1")]
|
||||
str2statenames["paged_lion"] = [("exp_avg", "state1")]
|
||||
str2statenames["momentum"] = [("momentum_buffer", "state1")]
|
||||
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
||||
str2statenames["rmsprop"] = [("square_avg", "state1")]
|
||||
|
@ -104,15 +101,17 @@ str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1
|
|||
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
|
||||
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
|
||||
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
|
||||
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
|
||||
|
||||
dim1 = [1024]
|
||||
dim2 = [32, 1024, 4097, 1]
|
||||
gtype = [torch.float32, torch.float16]
|
||||
optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion']
|
||||
gtype = [torch.float32, torch.float16, torch.bfloat16]
|
||||
optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
|
||||
values = list(product(dim1, dim2, gtype, optimizer_names))
|
||||
names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
|
||||
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
|
||||
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
|
||||
if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip()
|
||||
if dim1 == 1 and dim2 == 1:
|
||||
return
|
||||
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
|
||||
|
@ -254,7 +253,7 @@ names = [
|
|||
|
||||
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
|
||||
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
|
||||
if gtype == torch.bfloat16 and optim_name not in str2bf16support: return
|
||||
if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip()
|
||||
if dim1 == 1 and dim2 == 1:
|
||||
return
|
||||
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
|
||||
|
@ -485,7 +484,7 @@ gtype = [torch.float32, torch.float16]
|
|||
# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
|
||||
# optimizer_names = ['lamb_apex', 'lamb8bit']
|
||||
# optimizer_names = ['lars_apex', 'lars8bit']
|
||||
optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise']
|
||||
optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise']
|
||||
values = list(product(dim1, dim2, gtype, optimizer_names))
|
||||
names = [
|
||||
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
|
||||
|
|
Loading…
Reference in New Issue
Block a user