# bitsandbytes-rocm/bitsandbytes/optim/optimizer.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import abc as container_abcs
from collections import defaultdict
from copy import deepcopy
from itertools import chain
import torch
import bitsandbytes.functional as F
class MockArgs(object):
def __init__(self, initial_data):
for key in initial_data:
setattr(self, key, initial_data[key])
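# GlobalOptimManager is a process-wide singleton that records per-parameter
# optimizer config overrides (keyed by parameter id) so that bitsandbytes
# optimizers can pick them up when their state is initialized.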
class GlobalOptimManager(object):
_instance = None
def __init__(self):
raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.pid2config = {}
self.index2config = {}
self.optimizer = None
self.uses_config_override = False
self.module_weight_config_triple = []
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls.__new__(cls)
cls._instance.initialize()
return cls._instance
def register_parameters(self, params):
param_groups = list(params)
if not isinstance(param_groups[0], dict):
param_groups = [{"params": param_groups}]
for group_index, group in enumerate(param_groups):
for p_index, p in enumerate(group["params"]):
if id(p) in self.pid2config:
self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
    def override_config(self, parameters, key=None, value=None, key_value_dict=None):
        """
        Overrides the initial optimizer config for specific parameters.

        The key-value pairs of the optimizer config for the input parameters
        are overridden. This can be optimizer parameters like "betas" or "lr",
        or 8-bit specific parameters like "optim_bits" or "percentile_clipping".

        Parameters
        ----------
        parameters : torch.Tensor or list(torch.Tensor)
            The input parameters.
        key : str
            The hyperparameter to override.
        value : object
            The value for the hyperparameter.
        key_value_dict : dict
            A dictionary with multiple key-value pairs to override.
        """
self.uses_config_override = True
if isinstance(parameters, torch.nn.Parameter):
parameters = [parameters]
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
if key is not None and value is not None:
assert key_value_dict is None
key_value_dict = {key: value}
if key_value_dict is not None:
for p in parameters:
if id(p) in self.pid2config:
self.pid2config[id(p)].update(key_value_dict)
else:
self.pid2config[id(p)] = key_value_dict
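    # A minimal usage sketch (the model and layer names below are
    # hypothetical, and bnb.optim.Adam is assumed to be available). Note that
    # in this implementation register_parameters() only copies configs that
    # are already in pid2config, so overrides must be registered first:
    #
    #   mng = GlobalOptimManager.get_instance()
    #   mng.override_config(model.fc1.weight, "optim_bits", 32)
    #   mng.register_parameters(model.parameters())
    #   adam = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=8)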
def register_module_override(self, module, param_name, config):
self.module_weight_config_triple.append((module, param_name, config))
class Optimizer8bit(torch.optim.Optimizer):
def __init__(self, params, defaults, optim_bits=32):
super(Optimizer8bit, self).__init__(params, defaults)
self.initialized = False
self.name2qmap = {}
self.mng = GlobalOptimManager.get_instance()
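        # Quantization maps and statistics must keep their own dtype when a
        # state dict is loaded; they are only moved to the parameter's device,
        # never cast to the parameter's dtype (see cast() in load_state_dict).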
        self.non_castable_tensor_keys = {
            "qmap1",
            "qmap2",
            "max1",
            "max2",
            "new_max1",
            "new_max2",
            "state1",
            "state2",
            "gnorm_vec",
            "absmax1",
            "absmax2",
            "unorm_vec",
        }
if optim_bits == 8:
self.fill_qmap()
def fill_qmap(self):
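        # Build the 8-bit dynamic quantization maps: a signed map for state1
        # (e.g. Adam's first moment) and an unsigned map for state2 (e.g.
        # Adam's second moment, which is non-negative).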
self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True)
self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)
def __setstate__(self, state):
super(Optimizer8bit, self).__setstate__(state)
def load_state_dict(self, state_dict):
r"""Loads the optimizer state.
Args:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`.
"""
# deepcopy, to be consistent with module API
state_dict = deepcopy(state_dict)
# Validate the state_dict
groups = self.param_groups
saved_groups = state_dict["param_groups"]
if len(groups) != len(saved_groups):
            raise ValueError(
                "loaded state dict has a different number of parameter groups"
            )
param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError(
"loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group"
)
# Update the state
id_map = {
old_id: p
for old_id, p in zip(
chain.from_iterable((g["params"] for g in saved_groups)),
chain.from_iterable((g["params"] for g in groups)),
)
}
def cast(param, value):
r"""Make a deep copy of value, casting all tensors to device of param."""
if isinstance(value, torch.Tensor):
# Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params.
if param.is_floating_point() and value.dtype != torch.uint8:
value = value.to(param.dtype)
return value
elif isinstance(value, dict):
for k, v in value.items():
if k in self.non_castable_tensor_keys:
value[k] = v.to(param.device)
else:
value[k] = cast(param, v)
return value
elif isinstance(value, container_abcs.Iterable):
return type(value)(cast(param, v) for v in value)
else:
return value
# Copy state assigned to params (and cast tensors to appropriate types).
# State that is not assigned to params is copied as is (needed for
# backward compatibility).
state = defaultdict(dict)
for k, v in state_dict["state"].items():
if k in id_map:
param = id_map[k]
state[param] = cast(param, v)
else:
state[k] = v
# Update parameter groups, setting their 'params' value
def update_group(group, new_group):
new_group["params"] = group["params"]
return new_group
param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({"state": state, "param_groups": param_groups})
def to_gpu(self):
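        # Move all existing state tensors onto the device of their parameter,
        # e.g. when the optimizer state was created while the model was still
        # on the CPU.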
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]):
if p in self.state:
values = self.state[p]
for k, v in values.items():
if isinstance(v, torch.Tensor):
self.state[p][k] = v.to(p.device)
def check_overrides(self):
for module, attr, config in self.mng.module_weight_config_triple:
pmodule = getattr(module, attr)
assert pmodule is not None
            assert isinstance(pmodule, (torch.Tensor, torch.nn.Parameter))
found = False
for gindex, group in enumerate(self.param_groups):
if found:
break
for pindex, p in enumerate(group["params"]):
if found:
break
if id(p) == id(pmodule):
# found the matching parameter
# init override
self.mng.pid2config[id(p)] = config
self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[
id(p)
]
found = True
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
if not self.initialized:
self.check_overrides()
self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]):
if p.grad is None:
continue
state = self.state[p]
if len(state) == 0:
self.init_state(group, p, gindex, pindex)
self.update_step(group, p, gindex, pindex)
return loss
def get_config(self, gindex, pindex, group):
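        # Start from the param group and global args defaults, then apply any
        # per-parameter overrides registered with the GlobalOptimManager.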
config = {}
config["betas"] = group["betas"]
config["eps"] = group["eps"]
config["weight_decay"] = group["weight_decay"]
config["lr"] = group["lr"]
config["optim_bits"] = self.args.optim_bits
config["min_8bit_size"] = self.args.min_8bit_size
config["percentile_clipping"] = self.args.percentile_clipping
config["block_wise"] = self.args.block_wise
config["max_unorm"] = self.args.max_unorm
config["skip_zeros"] = self.args.skip_zeros
if (gindex, pindex) in self.mng.index2config:
config.update(self.mng.index2config[(gindex, pindex)])
return config
    def init_state(self, group, p, gindex, pindex):
        raise NotImplementedError("The init_state method needs to be overridden")

    def update_step(self, group, p, gindex, pindex):
        raise NotImplementedError("The update_step method needs to be overridden")
class Optimizer2State(Optimizer8bit):
def __init__(
self,
optimizer_name,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0.0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
max_unorm=0.0,
skip_zeros=False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if isinstance(betas, str):
# format: '(beta1, beta2)'
betas = betas.replace("(", "").replace(")", "").strip().split(",")
betas = [float(b) for b in betas]
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer2State, self).__init__(params, defaults, optim_bits)
if args is None:
args = {}
args["optim_bits"] = optim_bits
args["percentile_clipping"] = 100
args["min_8bit_size"] = min_8bit_size
args["percentile_clipping"] = percentile_clipping
args["block_wise"] = block_wise
args["max_unorm"] = max_unorm
args["skip_zeros"] = skip_zeros
self.args = MockArgs(args)
else:
self.args = args
self.optimizer_name = optimizer_name
@torch.no_grad()
def init_state(self, group, p, gindex, pindex):
config = self.get_config(gindex, pindex, group)
if config["optim_bits"] == 32:
dtype = torch.float32
elif config["optim_bits"] == 8:
dtype = torch.uint8
else:
            raise NotImplementedError(
                f'Unsupported number of optimizer bits: {config["optim_bits"]}'
            )
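        # Tensors smaller than min_8bit_size keep 32-bit state regardless of
        # the requested optim_bits.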
if p.numel() < config["min_8bit_size"]:
dtype = torch.float32
state = self.state[p]
state["step"] = 0
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.float32,
device=p.device,
)
state["state2"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.float32,
device=p.device,
)
elif dtype == torch.uint8:
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.uint8,
device=p.device,
)
state["qmap1"] = self.name2qmap["dynamic"]
state["state2"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.uint8,
device=p.device,
)
state["qmap2"] = self.name2qmap["udynamic"]
if config["block_wise"]:
n = p.numel()
blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
state["absmax1"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
state["absmax2"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
else:
state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["new_max2"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
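        # gnorm_vec keeps a rolling history of the last 100 gradient norms for
        # percentile clipping; unorm_vec accumulates the update norm when
        # max_unorm clipping is enabled.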
if config["percentile_clipping"] < 100:
state["gnorm_vec"] = torch.zeros((100,), device=p.device)
if config["max_unorm"] > 0.0:
state["unorm_vec"] = torch.zeros((1,), device=p.device)
@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
state = self.state[p]
grad = p.grad
config = self.get_config(gindex, pindex, group)
state["step"] += 1
step = state["step"]
if config["percentile_clipping"] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
grad, state["gnorm_vec"], step, config["percentile_clipping"]
)
else:
gnorm_scale = 1.0
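        # Dispatch on the state dtype: fp32 states use the 32-bit update
        # kernel, uint8 states use the 8-bit kernel with either block-wise or
        # global per-tensor scaling.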
if state["state1"].dtype == torch.float:
F.optimizer_update_32bit(
self.optimizer_name,
grad,
p,
state["state1"],
config["betas"][0],
config["eps"],
step,
config["lr"],
state["state2"],
config["betas"][1],
config["weight_decay"],
gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
skip_zeros=config["skip_zeros"],
)
elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
F.optimizer_update_8bit(
self.optimizer_name,
grad,
p,
state["state1"],
state["state2"],
config["betas"][0],
config["betas"][1],
config["eps"],
step,
config["lr"],
state["qmap1"],
state["qmap2"],
state["max1"],
state["max2"],
state["new_max1"],
state["new_max2"],
config["weight_decay"],
gnorm_scale=gnorm_scale,
unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
)
            # Swap the max buffers: this step's new maxima become the
            # quantization scales for the next step.
            state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
            state["max2"], state["new_max2"] = state["new_max2"], state["max2"]
elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
F.optimizer_update_8bit_blockwise(
self.optimizer_name,
grad,
p,
state["state1"],
state["state2"],
config["betas"][0],
config["betas"][1],
config["eps"],
step,
config["lr"],
state["qmap1"],
state["qmap2"],
state["absmax1"],
state["absmax2"],
config["weight_decay"],
gnorm_scale=gnorm_scale,
skip_zeros=config["skip_zeros"],
)
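# Optimizer1State mirrors Optimizer2State for optimizers that track a single
# state tensor per parameter (e.g. momentum SGD, RMSprop, Adagrad); the logic
# below follows the same structure with the second state set to None.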
class Optimizer1State(Optimizer8bit):
def __init__(
self,
optimizer_name,
params,
lr=1e-3,
betas=(0.9, 0.0),
eps=1e-8,
weight_decay=0.0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
max_unorm=0.0,
skip_zeros=False,
):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0:
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(Optimizer1State, self).__init__(params, defaults, optim_bits)
if args is None:
args = {}
args["optim_bits"] = optim_bits
args["percentile_clipping"] = 100
args["min_8bit_size"] = min_8bit_size
args["percentile_clipping"] = percentile_clipping
args["block_wise"] = block_wise
args["max_unorm"] = max_unorm
args["skip_zeros"] = skip_zeros
self.args = MockArgs(args)
else:
self.args = args
self.optimizer_name = optimizer_name
@torch.no_grad()
def init_state(self, group, p, gindex, pindex):
config = self.get_config(gindex, pindex, group)
if config["optim_bits"] == 32:
dtype = torch.float32
elif config["optim_bits"] == 8:
dtype = torch.uint8
else:
            raise NotImplementedError(
                f'Unsupported number of optimizer bits: {config["optim_bits"]}'
            )
if p.numel() < config["min_8bit_size"]:
dtype = torch.float32
state = self.state[p]
state["step"] = 0
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.float32,
device=p.device,
)
elif dtype == torch.uint8:
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
self.fill_qmap()
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.uint8,
device=p.device,
)
state["qmap1"] = self.name2qmap["dynamic"]
if config["block_wise"]:
n = p.numel()
blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
state["absmax1"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
else:
state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
if config["percentile_clipping"] < 100:
state["gnorm_vec"] = torch.zeros((100,), device=p.device)
if config["max_unorm"] > 0.0:
state["unorm_vec"] = torch.zeros((1,), device=p.device)
@torch.no_grad()
def update_step(self, group, p, gindex, pindex):
state = self.state[p]
grad = p.grad
config = self.get_config(gindex, pindex, group)
state["step"] += 1
step = state["step"]
if config["percentile_clipping"] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
grad, state["gnorm_vec"], step, config["percentile_clipping"]
)
else:
gnorm_scale = 1.0
if state["state1"].dtype == torch.float:
F.optimizer_update_32bit(
self.optimizer_name,
grad,
p,
state["state1"],
config["betas"][0],
config["eps"],
step,
config["lr"],
None,
0.0,
config["weight_decay"],
gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
skip_zeros=config["skip_zeros"],
)
elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
F.optimizer_update_8bit(
self.optimizer_name,
grad,
p,
state["state1"],
None,
config["betas"][0],
config["betas"][1],
config["eps"],
step,
config["lr"],
state["qmap1"],
None,
state["max1"],
None,
state["new_max1"],
None,
config["weight_decay"],
gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
)
state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
F.optimizer_update_8bit_blockwise(
self.optimizer_name,
grad,
p,
state["state1"],
None,
config["betas"][0],
config["betas"][1],
config["eps"],
step,
config["lr"],
state["qmap1"],
None,
state["absmax1"],
None,
config["weight_decay"],
gnorm_scale=gnorm_scale,
skip_zeros=config["skip_zeros"],
)