Added PagedLion and bf16 Lion.

This commit is contained in:
Tim Dettmers 2023-05-23 19:37:38 -07:00
parent 2bce175d15
commit 1b8772a8f3
7 changed files with 46 additions and 97 deletions

View File

@ -37,10 +37,7 @@ if COMPILED_WITH_CUDA:
lib.crmsprop32bit_grad_32, lib.crmsprop32bit_grad_32,
lib.crmsprop32bit_grad_16, lib.crmsprop32bit_grad_16,
) )
str2optimizer32bit["lion"] = ( str2optimizer32bit["lion"] = (lib.clion32bit_grad_fp32, lib.clion32bit_grad_fp16, lib.clion32bit_grad_bf16)
lib.clion32bit_grad_32,
lib.clion32bit_grad_16,
)
str2optimizer32bit["adagrad"] = ( str2optimizer32bit["adagrad"] = (
lib.cadagrad32bit_grad_32, lib.cadagrad32bit_grad_32,
lib.cadagrad32bit_grad_16, lib.cadagrad32bit_grad_16,
@ -89,6 +86,7 @@ if COMPILED_WITH_CUDA:
str2optimizer8bit_blockwise["lion"] = ( str2optimizer8bit_blockwise["lion"] = (
lib.clion_8bit_blockwise_grad_fp32, lib.clion_8bit_blockwise_grad_fp32,
lib.clion_8bit_blockwise_grad_fp16, lib.clion_8bit_blockwise_grad_fp16,
lib.clion_8bit_blockwise_grad_bf16,
) )
str2optimizer8bit_blockwise["adagrad"] = ( str2optimizer8bit_blockwise["adagrad"] = (
lib.cadagrad_8bit_blockwise_grad_fp32, lib.cadagrad_8bit_blockwise_grad_fp32,

View File

@ -12,5 +12,5 @@ from .lamb import LAMB, LAMB8bit, LAMB32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .optimizer import GlobalOptimManager from .optimizer import GlobalOptimManager
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .lion import Lion, Lion8bit, Lion32bit from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit
from .sgd import SGD, SGD8bit, SGD32bit from .sgd import SGD, SGD8bit, SGD32bit

View File

@ -4,84 +4,27 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State from bitsandbytes.optim.optimizer import Optimizer1State
class Lion(Optimizer1State): class Lion(Optimizer1State):
def __init__( def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
self, super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"lion",
params,
lr,
betas,
0.,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class Lion8bit(Optimizer1State): class Lion8bit(Optimizer1State):
def __init__( def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
self, super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"lion",
params,
lr,
betas,
0.,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class Lion32bit(Optimizer1State): class Lion32bit(Optimizer1State):
def __init__( def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
self, super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
params,
lr=1e-4,
betas=(0.9, 0.99), class PagedLion(Optimizer1State):
weight_decay=0, def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
args=None, super().__init__("lion", params, lr, betas, 0., weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
min_8bit_size=4096,
percentile_clipping=100, class PagedLion8bit(Optimizer1State):
block_wise=True, def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
): super().__init__("lion", params, lr, betas, 0., weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
super().__init__(
"lion", class PagedLion32bit(Optimizer1State):
params, def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
lr, super().__init__("lion", params, lr, betas, 0., weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
betas,
0.,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)

View File

@ -3666,6 +3666,7 @@ MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
MAKE_PreconditionOptimizer32bit1State(LION, half) MAKE_PreconditionOptimizer32bit1State(LION, half)
MAKE_PreconditionOptimizer32bit1State(LION, float) MAKE_PreconditionOptimizer32bit1State(LION, float)
MAKE_PreconditionOptimizer32bit1State(LION, __nv_bfloat16)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
@ -3679,6 +3680,7 @@ MAKE_Optimizer32bit1State(RMSPROP, half)
MAKE_Optimizer32bit1State(RMSPROP, float) MAKE_Optimizer32bit1State(RMSPROP, float)
MAKE_Optimizer32bit1State(LION, half) MAKE_Optimizer32bit1State(LION, half)
MAKE_Optimizer32bit1State(LION, float) MAKE_Optimizer32bit1State(LION, float)
MAKE_Optimizer32bit1State(LION, __nv_bfloat16)
MAKE_Optimizer32bit1State(ADAGRAD, half) MAKE_Optimizer32bit1State(ADAGRAD, half)
MAKE_Optimizer32bit1State(ADAGRAD, float) MAKE_Optimizer32bit1State(ADAGRAD, float)
@ -3852,5 +3854,6 @@ MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8) MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)

View File

@ -802,6 +802,7 @@ MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float) MAKE_optimizer32bit(RMSPROP, float)
MAKE_optimizer32bit(LION, half) MAKE_optimizer32bit(LION, half)
MAKE_optimizer32bit(LION, float) MAKE_optimizer32bit(LION, float)
MAKE_optimizer32bit(LION, __nv_bfloat16)
MAKE_optimizer32bit(ADAGRAD, half) MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float) MAKE_optimizer32bit(ADAGRAD, float)
@ -837,6 +838,7 @@ MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(half, LION); MAKE_optimizerStatic8bitBlockwise(half, LION);
MAKE_optimizerStatic8bitBlockwise(float, LION); MAKE_optimizerStatic8bitBlockwise(float, LION);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION);
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);

View File

@ -51,8 +51,9 @@ MAKE_FUNC32(adam, ADAM, half, fp16)
MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16) MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16)
MAKE_FUNC32(rmsprop, RMSPROP, float, 32) MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
MAKE_FUNC32(rmsprop, RMSPROP, half, 16) MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
MAKE_FUNC32(lion, LION, float, 32) MAKE_FUNC32(lion, LION, float, fp32)
MAKE_FUNC32(lion, LION, half, 16) MAKE_FUNC32(lion, LION, half, fp16)
MAKE_FUNC32(lion, LION, __nv_bfloat16, bf16)
MAKE_FUNC32(adagrad, ADAGRAD, float, 32) MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
MAKE_FUNC32(adagrad, ADAGRAD, half, 16) MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
@ -95,6 +96,7 @@ MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32)
MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(lion, LION, half, fp16) MAKE_BLOCKWISE8(lion, LION, half, fp16)
MAKE_BLOCKWISE8(lion, LION, float, fp32) MAKE_BLOCKWISE8(lion, LION, float, fp32)
MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); } void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
@ -201,8 +203,9 @@ extern "C"
MAKE_CFUNC32(momentum, half, 16) MAKE_CFUNC32(momentum, half, 16)
MAKE_CFUNC32(rmsprop, float, 32) MAKE_CFUNC32(rmsprop, float, 32)
MAKE_CFUNC32(rmsprop, half, 16) MAKE_CFUNC32(rmsprop, half, 16)
MAKE_CFUNC32(lion, float, 32) MAKE_CFUNC32(lion, float, fp32)
MAKE_CFUNC32(lion, half, 16) MAKE_CFUNC32(lion, half, fp16)
MAKE_CFUNC32(lion, __nv_bfloat16, bf16)
MAKE_CFUNC32(adagrad, float, 32) MAKE_CFUNC32(adagrad, float, 32)
MAKE_CFUNC32(adagrad, half, 16) MAKE_CFUNC32(adagrad, half, 16)
@ -245,6 +248,7 @@ extern "C"
MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(lion, LION, half, fp16) MAKE_CBLOCKWISE8(lion, LION, half, fp16)
MAKE_CBLOCKWISE8(lion, LION, float, fp32) MAKE_CBLOCKWISE8(lion, LION, float, fp32)
MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }

View File

@ -19,11 +19,11 @@ import bitsandbytes.functional as F
k = 20 k = 20
def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0):
idx = torch.isclose(a, b, rtol, atol) idx = torch.isclose(a, b, rtol=rtol, atol=atol)
error_count = (idx == 0).sum().item() error_count = (idx == 0).sum().item()
if error_count > max_error_count: if error_count > max_error_count:
print(f"Too many values not close: assert {error_count} < {max_error_count}") print(f"Too many values not close: assert {error_count} < {max_error_count}")
torch.testing.assert_close(a, b, rtol, atol) torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
def get_temp_dir(): def get_temp_dir():
@ -35,13 +35,8 @@ def get_temp_dir():
def rm_path(path): def rm_path(path):
shutil.rmtree(path) shutil.rmtree(path)
str2bf16support = {}
str2bf16support['adam8bit_blockwise'] = True
str2optimizers = {} str2optimizers = {}
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion) str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
str2optimizers["momentum_pytorch"] = ( str2optimizers["momentum_pytorch"] = (
None, None,
@ -51,8 +46,8 @@ str2optimizers["momentum_pytorch"] = (
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam) str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW) str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam) str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers["lion"] = (Lion, bnb.optim.Lion) str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["paged_lion"] = (Lion, bnb.optim.PagedLion)
str2optimizers["momentum"] = ( str2optimizers["momentum"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
@ -76,6 +71,7 @@ str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.
str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True)) str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True)) str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True)) str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True))
str2optimizers["paged_lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.PagedLion8bit(pxx, block_wise=True))
str2optimizers["momentum8bit_blockwise"] = ( str2optimizers["momentum8bit_blockwise"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
@ -90,6 +86,7 @@ str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["lion"] = [("exp_avg", "state1")] str2statenames["lion"] = [("exp_avg", "state1")]
str2statenames["paged_lion"] = [("exp_avg", "state1")]
str2statenames["momentum"] = [("momentum_buffer", "state1")] str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")] str2statenames["rmsprop"] = [("square_avg", "state1")]
@ -104,15 +101,17 @@ str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")] str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")] str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]
dim1 = [1024] dim1 = [1024]
dim2 = [32, 1024, 4097, 1] dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16] gtype = [torch.float32, torch.float16, torch.bfloat16]
optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion'] optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion', 'paged_lion']
values = list(product(dim1, dim2, gtype, optimizer_names)) values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values] names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name): def test_optimizer32bit(dim1, dim2, gtype, optim_name):
if gtype == torch.bfloat16 and optim_name in ['momentum', 'rmsprop']: pytest.skip()
if dim1 == 1 and dim2 == 1: if dim1 == 1 and dim2 == 1:
return return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
@ -254,7 +253,7 @@ names = [
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name): def test_optimizer8bit(dim1, dim2, gtype, optim_name):
if gtype == torch.bfloat16 and optim_name not in str2bf16support: return if gtype == torch.bfloat16 and optim_name not in ['adam8bit_blockwise', 'lion8bit_blockwise']: pytest.skip()
if dim1 == 1 and dim2 == 1: if dim1 == 1 and dim2 == 1:
return return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
@ -485,7 +484,7 @@ gtype = [torch.float32, torch.float16]
# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch'] # optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
# optimizer_names = ['lamb_apex', 'lamb8bit'] # optimizer_names = ['lamb_apex', 'lamb8bit']
# optimizer_names = ['lars_apex', 'lars8bit'] # optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise'] optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise', 'paged_lion8bit_blockwise']
values = list(product(dim1, dim2, gtype, optimizer_names)) values = list(product(dim1, dim2, gtype, optimizer_names))
names = [ names = [
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values