make sure interface is correct

This commit is contained in:
Phil Wang 2023-03-09 09:45:33 -08:00
parent 7247cb4554
commit d43ea9722c

View File

@ -9,30 +9,21 @@ class Lion(Optimizer1State):
def __init__( def __init__(
self, self,
params, params,
lr=1e-2, lr=1e-4,
alpha=0.99, betas=(0.9, 0.99),
eps=1e-8,
weight_decay=0, weight_decay=0,
momentum=0,
centered=False,
optim_bits=32, optim_bits=32,
args=None, args=None,
min_8bit_size=4096, min_8bit_size=4096,
percentile_clipping=100, percentile_clipping=100,
block_wise=True, block_wise=True,
): ):
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__( super().__init__(
"rmsprop", "rmsprop",
params, params,
lr, lr,
(alpha, momentum), betas,
eps, 0.,
weight_decay, weight_decay,
optim_bits, optim_bits,
args, args,
@ -46,29 +37,20 @@ class Lion8bit(Optimizer1State):
def __init__( def __init__(
self, self,
params, params,
lr=1e-2, lr=1e-4,
alpha=0.99, betas=(0.9, 0.99),
eps=1e-8,
weight_decay=0, weight_decay=0,
momentum=0,
centered=False,
args=None, args=None,
min_8bit_size=4096, min_8bit_size=4096,
percentile_clipping=100, percentile_clipping=100,
block_wise=True, block_wise=True,
): ):
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__( super().__init__(
"rmsprop", "rmsprop",
params, params,
lr, lr,
(alpha, momentum), betas,
eps, 0.,
weight_decay, weight_decay,
8, 8,
args, args,
@ -82,30 +64,20 @@ class Lion32bit(Optimizer1State):
def __init__( def __init__(
self, self,
params, params,
lr=1e-2, lr=1e-4,
alpha=0.99, betas=(0.9, 0.99),
eps=1e-8,
weight_decay=0, weight_decay=0,
momentum=0,
centered=False,
args=None, args=None,
min_8bit_size=4096, min_8bit_size=4096,
percentile_clipping=100, percentile_clipping=100,
block_wise=True, block_wise=True,
): ):
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__( super().__init__(
"rmsprop", "rmsprop",
params, params,
lr, lr,
(alpha, momentum), betas,
eps, 0.,
weight_decay, weight_decay,
32, 32,
args, args,