make sure interface is correct

This commit is contained in:
Phil Wang 2023-03-09 09:45:33 -08:00
parent 7247cb4554
commit d43ea9722c

View File

@ -9,30 +9,21 @@ class Lion(Optimizer1State):
def __init__(
self,
params,
lr=1e-2,
alpha=0.99,
eps=1e-8,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
momentum=0,
centered=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__(
"rmsprop",
params,
lr,
(alpha, momentum),
eps,
betas,
0.,
weight_decay,
optim_bits,
args,
@ -46,29 +37,20 @@ class Lion8bit(Optimizer1State):
def __init__(
self,
params,
lr=1e-2,
alpha=0.99,
eps=1e-8,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
momentum=0,
centered=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__(
"rmsprop",
params,
lr,
(alpha, momentum),
eps,
betas,
0.,
weight_decay,
8,
args,
@ -82,30 +64,20 @@ class Lion32bit(Optimizer1State):
def __init__(
self,
params,
lr=1e-2,
alpha=0.99,
eps=1e-8,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0,
momentum=0,
centered=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if alpha == 0:
raise NotImplementedError(
"RMSprop with alpha==0.0 is not supported!"
)
if centered:
raise NotImplementedError("Centered RMSprop is not supported!")
super().__init__(
"rmsprop",
params,
lr,
(alpha, momentum),
eps,
betas,
0.,
weight_decay,
32,
args,