bitsandbytes-rocm/bitsandbytes/optim/rmsprop.py

116 lines
2.7 KiB
Python
Raw Normal View History

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State
2021-10-06 02:16:20 +00:00
class RMSprop(Optimizer1State):
    """RMSprop optimizer registered with ``Optimizer1State`` as ``"rmsprop"``.

    The bit width handed to the base class is caller-selectable through
    ``optim_bits``.
    """

    def __init__(
        self,
        params,
        lr=1e-2,
        alpha=0.99,
        eps=1e-8,
        weight_decay=0,
        momentum=0,
        centered=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        """Validate the RMSprop options and delegate to ``Optimizer1State``.

        Args:
            params: Forwarded unchanged to the base constructor.
            lr: Forwarded unchanged to the base constructor.
            alpha: First element of the ``(alpha, momentum)`` pair passed to
                the base constructor; must not compare equal to 0.
            eps: Forwarded unchanged to the base constructor.
            weight_decay: Forwarded unchanged to the base constructor.
            momentum: Second element of the ``(alpha, momentum)`` pair.
            centered: Must be falsy; the centered variant is rejected.
            optim_bits: Forwarded unchanged to the base constructor
                (presumably the optimizer-state bit width; confirm against
                ``Optimizer1State``).
            args: Forwarded unchanged to the base constructor.
            min_8bit_size: Forwarded unchanged to the base constructor.
            percentile_clipping: Forwarded unchanged to the base constructor.
            block_wise: Forwarded unchanged to the base constructor.

        Raises:
            NotImplementedError: If ``alpha == 0`` or ``centered`` is truthy.
        """
        # The alpha check runs first, so it wins when both options are invalid.
        if alpha == 0:
            raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
        if centered:
            raise NotImplementedError("Centered RMSprop is not supported!")

        # The base class receives the two coefficients as a single tuple.
        coefficients = (alpha, momentum)
        super().__init__(
            "rmsprop",
            params,
            lr,
            coefficients,
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
2021-10-06 02:16:20 +00:00
class RMSprop8bit(Optimizer1State):
    """RMSprop optimizer registered with ``Optimizer1State`` as ``"rmsprop"``.

    Identical in shape to the generic RMSprop constructor call, except that
    the bit-width argument passed to the base class is fixed to ``8``.
    """

    def __init__(
        self,
        params,
        lr=1e-2,
        alpha=0.99,
        eps=1e-8,
        weight_decay=0,
        momentum=0,
        centered=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        """Validate the RMSprop options and delegate to ``Optimizer1State``.

        Args:
            params: Forwarded unchanged to the base constructor.
            lr: Forwarded unchanged to the base constructor.
            alpha: First element of the ``(alpha, momentum)`` pair passed to
                the base constructor; must not compare equal to 0.
            eps: Forwarded unchanged to the base constructor.
            weight_decay: Forwarded unchanged to the base constructor.
            momentum: Second element of the ``(alpha, momentum)`` pair.
            centered: Must be falsy; the centered variant is rejected.
            args: Forwarded unchanged to the base constructor.
            min_8bit_size: Forwarded unchanged to the base constructor.
            percentile_clipping: Forwarded unchanged to the base constructor.
            block_wise: Forwarded unchanged to the base constructor.

        Raises:
            NotImplementedError: If ``alpha == 0`` or ``centered`` is truthy.
        """
        # The alpha check runs first, so it wins when both options are invalid.
        if alpha == 0:
            raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
        if centered:
            raise NotImplementedError("Centered RMSprop is not supported!")

        # The base class receives the two coefficients as a single tuple.
        coefficients = (alpha, momentum)
        super().__init__(
            "rmsprop",
            params,
            lr,
            coefficients,
            eps,
            weight_decay,
            8,  # fixed bit width for this variant
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
2021-10-06 02:16:20 +00:00
class RMSprop32bit(Optimizer1State):
    """RMSprop optimizer registered with ``Optimizer1State`` as ``"rmsprop"``.

    Identical in shape to the generic RMSprop constructor call, except that
    the bit-width argument passed to the base class is fixed to ``32``.
    """

    def __init__(
        self,
        params,
        lr=1e-2,
        alpha=0.99,
        eps=1e-8,
        weight_decay=0,
        momentum=0,
        centered=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        """Validate the RMSprop options and delegate to ``Optimizer1State``.

        Args:
            params: Forwarded unchanged to the base constructor.
            lr: Forwarded unchanged to the base constructor.
            alpha: First element of the ``(alpha, momentum)`` pair passed to
                the base constructor; must not compare equal to 0.
            eps: Forwarded unchanged to the base constructor.
            weight_decay: Forwarded unchanged to the base constructor.
            momentum: Second element of the ``(alpha, momentum)`` pair.
            centered: Must be falsy; the centered variant is rejected.
            args: Forwarded unchanged to the base constructor.
            min_8bit_size: Forwarded unchanged to the base constructor.
            percentile_clipping: Forwarded unchanged to the base constructor.
            block_wise: Forwarded unchanged to the base constructor.

        Raises:
            NotImplementedError: If ``alpha == 0`` or ``centered`` is truthy.
        """
        # The alpha check runs first, so it wins when both options are invalid.
        if alpha == 0:
            raise NotImplementedError("RMSprop with alpha==0.0 is not supported!")
        if centered:
            raise NotImplementedError("Centered RMSprop is not supported!")

        # The base class receives the two coefficients as a single tuple.
        coefficients = (alpha, momentum)
        super().__init__(
            "rmsprop",
            params,
            lr,
            coefficients,
            eps,
            weight_decay,
            32,  # fixed bit width for this variant
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )