bitsandbytes-rocm/bitsandbytes/optim/lars.py

211 lines
5.5 KiB
Python
Raw Normal View History

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch.optim import Optimizer
2021-10-06 02:16:20 +00:00
from bitsandbytes.optim.optimizer import Optimizer1State
2021-10-06 02:16:20 +00:00
class LARS(Optimizer1State):
    """LARS optimizer built on bitsandbytes' ``Optimizer1State`` ("lars" rule).

    Every argument except ``nesterov`` is forwarded to ``Optimizer1State``
    with ``block_wise=False``. Momentum and dampening are packed into the
    tuple ``(momentum, dampening)``.

    Args:
        params: Parameters (or parameter groups) to optimize; forwarded.
        lr: Learning rate; forwarded. There is no default.
        momentum: Momentum factor. Must be non-zero.
        dampening: Dampening; forwarded as the second element of the
            ``(momentum, dampening)`` tuple.
        weight_decay: Weight decay; forwarded.
        nesterov: NOT forwarded, so it has no effect. Passing ``True``
            emits a ``UserWarning``.
        optim_bits: Forwarded unchanged. Presumably the bit-width of the
            optimizer state (32 or 8); confirm against ``Optimizer1State``.
        args: Forwarded unchanged.
        min_8bit_size: Forwarded unchanged.
        percentile_clipping: Forwarded unchanged.
        max_unorm: Forwarded as the ``max_unorm`` keyword.

    Raises:
        NotImplementedError: If ``momentum == 0``.
    """

    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        max_unorm=0.02,
    ):
        if momentum == 0:
            raise NotImplementedError(
                "LARS without momentum is not supported!"
            )
        if nesterov:
            # `nesterov` is never passed to Optimizer1State below. Tell the
            # caller instead of silently dropping the flag.
            import warnings

            warnings.warn(
                f"{type(self).__name__}: nesterov=True is not supported "
                "and will be ignored.",
                stacklevel=2,
            )
        super().__init__(
            "lars",
            params,
            lr,
            (momentum, dampening),
            0.0,  # presumably eps; TODO confirm against Optimizer1State
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            max_unorm=max_unorm,
            block_wise=False,
        )


class LARS8bit(Optimizer1State):
    """LARS optimizer on ``Optimizer1State`` with ``optim_bits`` fixed to 8.

    Uses the "lars" rule with ``block_wise=False``. Momentum and dampening
    are forwarded as the tuple ``(momentum, dampening)``.

    Args:
        params: Parameters (or parameter groups) to optimize; forwarded.
        lr: Learning rate; forwarded. There is no default.
        momentum: Momentum factor. Must be non-zero.
        dampening: Dampening; forwarded as the second element of the
            ``(momentum, dampening)`` tuple.
        weight_decay: Weight decay; forwarded.
        nesterov: NOT forwarded, so it has no effect. Passing ``True``
            emits a ``UserWarning``.
        args: Forwarded unchanged.
        min_8bit_size: Forwarded unchanged.
        percentile_clipping: Forwarded unchanged.
        max_unorm: Forwarded as the ``max_unorm`` keyword.

    Raises:
        NotImplementedError: If ``momentum == 0``.
    """

    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        max_unorm=0.02,
    ):
        if momentum == 0:
            raise NotImplementedError(
                "LARS without momentum is not supported!"
            )
        if nesterov:
            # `nesterov` is never passed to Optimizer1State below. Tell the
            # caller instead of silently dropping the flag.
            import warnings

            warnings.warn(
                f"{type(self).__name__}: nesterov=True is not supported "
                "and will be ignored.",
                stacklevel=2,
            )
        super().__init__(
            "lars",
            params,
            lr,
            (momentum, dampening),
            0.0,  # presumably eps; TODO confirm against Optimizer1State
            weight_decay,
            8,  # optim_bits is fixed for this variant
            args,
            min_8bit_size,
            percentile_clipping,
            max_unorm=max_unorm,
            block_wise=False,
        )


class LARS32bit(Optimizer1State):
    """LARS optimizer on ``Optimizer1State`` with ``optim_bits`` fixed to 32.

    Uses the "lars" rule with ``block_wise=False``. Momentum and dampening
    are forwarded as the tuple ``(momentum, dampening)``.

    Args:
        params: Parameters (or parameter groups) to optimize; forwarded.
        lr: Learning rate; forwarded. There is no default.
        momentum: Momentum factor. Must be non-zero.
        dampening: Dampening; forwarded as the second element of the
            ``(momentum, dampening)`` tuple.
        weight_decay: Weight decay; forwarded.
        nesterov: NOT forwarded, so it has no effect. Passing ``True``
            emits a ``UserWarning``.
        args: Forwarded unchanged.
        min_8bit_size: Forwarded unchanged.
        percentile_clipping: Forwarded unchanged.
        max_unorm: Forwarded as the ``max_unorm`` keyword.

    Raises:
        NotImplementedError: If ``momentum == 0``.
    """

    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        max_unorm=0.02,
    ):
        if momentum == 0:
            raise NotImplementedError(
                "LARS without momentum is not supported!"
            )
        if nesterov:
            # `nesterov` is never passed to Optimizer1State below. Tell the
            # caller instead of silently dropping the flag.
            import warnings

            warnings.warn(
                f"{type(self).__name__}: nesterov=True is not supported "
                "and will be ignored.",
                stacklevel=2,
            )
        super().__init__(
            "lars",
            params,
            lr,
            (momentum, dampening),
            0.0,  # presumably eps; TODO confirm against Optimizer1State
            weight_decay,
            32,  # optim_bits is fixed for this variant
            args,
            min_8bit_size,
            percentile_clipping,
            max_unorm=max_unorm,
            block_wise=False,
        )


class PytorchLARS(Optimizer):
    """Pure-PyTorch SGD optimizer with LARS-style update-norm clipping.

    Each step applies SGD to every parameter that has a gradient. Weight
    decay, momentum, dampening and Nesterov momentum are all optional.
    When ``max_unorm > 0`` the update is rescaled so that its L2 norm does
    not exceed ``max_unorm * ||p||``.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate. Must be >= 0.
        momentum: Momentum factor. Must be >= 0; ``0`` gives plain SGD.
        dampening: Dampening applied to the gradient in the momentum update.
        weight_decay: L2 penalty added to the gradient. Must be >= 0.
        nesterov: Use Nesterov momentum. Requires ``momentum > 0`` and
            ``dampening == 0``.
        max_unorm: Maximum ratio of update norm to parameter norm. Values
            ``<= 0`` disable the clipping.

    Raises:
        ValueError: On a negative ``lr``, ``momentum`` or ``weight_decay``,
            or an invalid Nesterov configuration.
    """

    def __init__(
        self,
        params,
        lr=0.01,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        max_unorm=0.02,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(
                f"Invalid weight_decay value: {weight_decay}"
            )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening"
            )

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            max_unorm=max_unorm,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore pickled state, defaulting ``nesterov`` to ``False``."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the
                model and returns the loss. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]
            max_unorm = group["max_unorm"]
            lr = group["lr"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                state = self.state[p]
                d_p = p.grad
                if weight_decay != 0:
                    # Out-of-place add so that p.grad itself is not modified.
                    d_p = d_p.add(p, alpha=weight_decay)

                # Without momentum this is plain SGD. Binding `update` here
                # also guarantees it never refers to another parameter's
                # update from an earlier loop iteration.
                update = d_p
                if momentum != 0:
                    buf = state.get("momentum_buffer", None)
                    if buf is None:
                        # First step: seed the buffer with the gradient itself.
                        buf = torch.clone(d_p).detach()
                        state["momentum_buffer"] = buf
                    else:
                        # In place: buf = momentum * buf + (1 - dampening) * d_p
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

                    if nesterov:
                        update = d_p + buf * momentum
                    else:
                        update = buf

                update_scale = 1.0
                if max_unorm > 0.0:
                    # The clipping path requires fp32 parameters.
                    assert p.dtype == torch.float32
                    pnorm = torch.norm(p.detach())
                    unorm = torch.norm(update)
                    # NOTE(review): if pnorm == 0 and unorm > 0, the scale
                    # becomes 0, so an all-zero parameter never moves while
                    # clipping is on. Kept unchanged on purpose; confirm
                    # whether that is intended.
                    if unorm > max_unorm * pnorm:
                        update_scale = max_unorm * pnorm / unorm

                p.add_(update, alpha=-lr * update_scale)

        return loss