723 lines
24 KiB
Python
723 lines
24 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
from collections import abc as container_abcs
|
|
from collections import defaultdict
|
|
from copy import deepcopy
|
|
from itertools import chain
|
|
|
|
import torch
|
|
|
|
import bitsandbytes.functional as F
|
|
|
|
|
|
class MockArgs:
    """Attribute-style view over a plain settings mapping.

    Each key of ``initial_data`` becomes an instance attribute holding the
    corresponding value, so the object can be read like an ``args``
    namespace (``obj.optim_bits`` instead of ``data["optim_bits"]``).
    """

    def __init__(self, initial_data):
        for name, value in initial_data.items():
            setattr(self, name, value)
|
|
|
|
|
|
class GlobalOptimManager:
    """Singleton registry of per-parameter optimizer config overrides.

    The registry keeps three kinds of records:

    * ``pid2config``: ``id(parameter)`` -> override dict,
    * ``index2config``: ``(group_index, param_index)`` -> override dict,
    * ``module_weight_config_triple``: ``(module, attribute name, config)``
      tuples recorded by :meth:`register_module_override`.

    Instances must be obtained through :meth:`get_instance`; calling the
    constructor directly raises ``RuntimeError``.
    """

    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        """Reset all registries to their empty state."""
        self.pid2config = {}
        self.index2config = {}
        self.optimizer = None
        self.uses_config_override = False
        self.module_weight_config_triple = []

    @classmethod
    def get_instance(cls):
        """Return the shared manager, creating and initializing it on first use."""
        if cls._instance is None:
            # __new__ bypasses __init__, which deliberately raises.
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def register_parameters(self, params):
        """Map previously overridden parameters to their group/param indices.

        Parameters
        ----------
        params : iterable of torch.Tensor or iterable of dict
            Either a flat iterable of parameters (treated as one group) or an
            iterable of param-group dicts that each hold a ``"params"`` entry.
            An empty iterable is accepted and registers nothing.
        """
        param_groups = list(params)
        if not param_groups:
            # Nothing to register; avoids indexing into an empty list below.
            return
        if not isinstance(param_groups[0], dict):
            param_groups = [{"params": param_groups}]

        for group_index, group in enumerate(param_groups):
            for p_index, p in enumerate(group["params"]):
                if id(p) in self.pid2config:
                    # Share the dict object so later override_config() calls
                    # on this parameter are visible through the index too.
                    self.index2config[(group_index, p_index)] = self.pid2config[
                        id(p)
                    ]

    def override_config(
        self, parameters, key=None, value=None, key_value_dict=None
    ):
        """
        Overrides initial optimizer config for specific parameters.

        The key-values of the optimizer config for the input parameters are
        overridden. This can be both optimizer parameters like "betas" or
        "lr", or 8-bit specific parameters like "optim_bits" or
        "percentile_clipping".

        Parameters
        ----------
        parameters : torch.Tensor or list(torch.Tensors)
            The input parameters.
        key : str
            The hyperparameter to override.
        value : object
            The value for the hyperparameter.
        key_value_dict : dict
            A dictionary with multiple key-values to override. It must not
            be combined with ``key``/``value``. The dictionary is copied per
            parameter and is never mutated by this manager.
        """
        self.uses_config_override = True
        # torch.nn.Parameter is a Tensor subclass, so one check covers both.
        if isinstance(parameters, torch.Tensor):
            parameters = [parameters]
        if key is not None and value is not None:
            assert (
                key_value_dict is None
            ), "pass either key/value or key_value_dict, not both"
            key_value_dict = {key: value}

        if key_value_dict is not None:
            for p in parameters:
                if id(p) in self.pid2config:
                    self.pid2config[id(p)].update(key_value_dict)
                else:
                    # Each parameter gets its own copy; sharing one dict would
                    # make a later override of a single parameter leak into
                    # every other parameter registered in the same call.
                    self.pid2config[id(p)] = dict(key_value_dict)

    def register_module_override(self, module, param_name, config):
        """Record a ``(module, param_name, config)`` override for later lookup."""
        self.module_weight_config_triple.append((module, param_name, config))
|
|
|
|
|
|
class Optimizer8bit(torch.optim.Optimizer):
    """Base class for optimizers whose state may be stored in 8-bit form.

    Subclasses must implement :meth:`init_state` and :meth:`update_step` and
    must set ``self.args`` (read by :meth:`get_config`). Per-parameter config
    overrides are looked up through the shared ``GlobalOptimManager``.

    Parameters
    ----------
    params : iterable
        Parameters or param-group dicts, forwarded to ``torch.optim.Optimizer``.
    defaults : dict
        Default per-group hyperparameters, forwarded to ``torch.optim.Optimizer``.
    optim_bits : int
        When equal to 8 the quantization maps are created eagerly.
    """

    def __init__(self, params, defaults, optim_bits=32):
        super().__init__(params, defaults)
        self.initialized = False
        self.name2qmap = {}

        self.mng = GlobalOptimManager.get_instance()
        # State entries that load_state_dict() only moves to the parameter's
        # device and never converts to the parameter's dtype.
        self.non_castable_tensor_keys = {
            "qmap1",
            "qmap2",
            "max1",
            "max2",
            "new_max1",
            "new_max2",
            "state1",
            "state2",
            "gnorm_vec",
            "absmax1",
            "absmax2",
            "unorm_vec",
        }

        if optim_bits == 8:
            self.fill_qmap()

    def fill_qmap(self):
        """Create the signed ("dynamic") and unsigned ("udynamic") maps."""
        self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True)
        self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)

    def __setstate__(self, state):
        super().__setstate__(state)

    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.

        Args:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.

        Raises:
            ValueError: if the number of parameter groups, or the number of
                parameters within a group, differs from this optimizer's.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict["param_groups"]

        if len(groups) != len(saved_groups):
            raise ValueError(
                "loaded state dict has a different number of "
                "parameter groups"
            )
        param_lens = (len(g["params"]) for g in groups)
        saved_lens = (len(g["params"]) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError(
                "loaded state dict contains a parameter group "
                "that doesn't match the size of optimizer's group"
            )

        # Update the state
        id_map = {
            old_id: p
            for old_id, p in zip(
                chain.from_iterable(g["params"] for g in saved_groups),
                chain.from_iterable(g["params"] for g in groups),
            )
        }

        def cast(param, value):
            r"""Return ``value`` with its tensors adapted to ``param``.

            Plain tensors are converted to ``param``'s dtype when ``param`` is
            floating point and the tensor is not ``uint8``. Dict entries named
            in ``non_castable_tensor_keys`` are only moved to ``param``'s
            device; other entries are cast recursively.
            """
            if isinstance(value, torch.Tensor):
                # Floating-point types are a bit special here. They are the only ones
                # that are assumed to always match the type of params.
                if param.is_floating_point() and value.dtype != torch.uint8:
                    value = value.to(param.dtype)
                return value
            if isinstance(value, dict):
                for k, v in value.items():
                    if k in self.non_castable_tensor_keys:
                        value[k] = v.to(param.device)
                    else:
                        value[k] = cast(param, v)
                return value
            if isinstance(value, (str, bytes)):
                # Strings are iterable, but rebuilding one from a generator
                # would yield the generator's repr, so pass them through.
                return value
            if isinstance(value, container_abcs.Iterable):
                return type(value)(cast(param, v) for v in value)
            return value

        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state = defaultdict(dict)
        for k, v in state_dict["state"].items():
            if k in id_map:
                param = id_map[k]
                state[param] = cast(param, v)
            else:
                state[k] = v

        # Update parameter groups, setting their 'params' value
        def update_group(group, new_group):
            new_group["params"] = group["params"]
            return new_group

        param_groups = [
            update_group(g, ng) for g, ng in zip(groups, saved_groups)
        ]
        self.__setstate__({"state": state, "param_groups": param_groups})

    def to_gpu(self):
        """Move every tensor in each parameter's state to that parameter's device."""
        for group in self.param_groups:
            for p in group["params"]:
                if p not in self.state:
                    continue
                param_state = self.state[p]
                # Only existing keys are reassigned, so iterating while
                # writing does not change the dict's size.
                for k, v in param_state.items():
                    if isinstance(v, torch.Tensor):
                        param_state[k] = v.to(p.device)

    def check_overrides(self):
        """Apply module-level overrides registered with the manager.

        For every ``(module, attr, config)`` triple, the tensor found at
        ``getattr(module, attr)`` is searched for (by identity) in this
        optimizer's param groups. On the first match, ``config`` is stored in
        the manager under both the parameter id and its
        ``(group index, param index)``.

        Raises:
            AssertionError: if the attribute is ``None`` or not a tensor.
        """
        for module, attr, config in self.mng.module_weight_config_triple:
            pmodule = getattr(module, attr)
            assert pmodule is not None
            # torch.nn.Parameter subclasses torch.Tensor, so this single check
            # covers both plain tensors and parameters.
            assert isinstance(pmodule, torch.Tensor)
            found = False
            for gindex, group in enumerate(self.param_groups):
                if found:
                    break
                for pindex, p in enumerate(group["params"]):
                    if p is pmodule:
                        # found the matching parameter: init override
                        self.mng.pid2config[id(p)] = config
                        self.mng.index2config[(gindex, pindex)] = config
                        found = True
                        break

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if not self.initialized:
            self.check_overrides()
            self.to_gpu()  # needed for fairseq pure fp16 training
            self.initialized = True

        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group["params"]):
                if p.grad is None:
                    continue
                # State is created lazily on the first step with a gradient.
                if len(self.state[p]) == 0:
                    self.init_state(group, p, gindex, pindex)

                self.update_step(group, p, gindex, pindex)

        return loss

    def get_config(self, gindex, pindex, group):
        """Build the effective config for one parameter.

        Group hyperparameters and the values on ``self.args`` form the base;
        an override registered for ``(gindex, pindex)`` takes precedence.
        """
        config = {
            "betas": group["betas"],
            "eps": group["eps"],
            "weight_decay": group["weight_decay"],
            "lr": group["lr"],
            "optim_bits": self.args.optim_bits,
            "min_8bit_size": self.args.min_8bit_size,
            "percentile_clipping": self.args.percentile_clipping,
            "block_wise": self.args.block_wise,
            "max_unorm": self.args.max_unorm,
            "skip_zeros": self.args.skip_zeros,
        }

        if (gindex, pindex) in self.mng.index2config:
            config.update(self.mng.index2config[(gindex, pindex)])
        return config

    def init_state(self, group, p, gindex, pindex):
        """Create the optimizer state for ``p``; must be overridden."""
        raise NotImplementedError("init_state method needs to be overidden")

    def update_step(self, group, p, gindex, pindex):
        """Apply one update to ``p``; must be overridden."""
        raise NotImplementedError(
            "The update_step method needs to be overidden"
        )
|
|
|
|
|
|
class Optimizer2State(Optimizer8bit):
    """Optimizer base that keeps two state tensors per parameter.

    ``state1`` and ``state2`` are stored either as float32 or as uint8
    (with quantization maps and max/absmax buffers), depending on the
    effective config for each parameter.

    Parameters
    ----------
    optimizer_name : str
        Name forwarded to the ``F.optimizer_update_*`` functions.
    params : iterable
        Parameters or param-group dicts to optimize.
    lr, eps, weight_decay : float
        Must each be non-negative.
    betas : tuple(float, float) or str
        Each value must lie in ``[0, 1)``. A string of the form
        ``'(beta1, beta2)'`` is parsed into floats.
    optim_bits : int
        32 or 8 (checked when state is initialized).
    args : object, optional
        If given, used as ``self.args`` as-is, and the remaining keyword
        settings below are not consulted when building ``self.args``.
    min_8bit_size, percentile_clipping, block_wise, max_unorm, skip_zeros
        Stored on ``self.args`` when ``args`` is ``None``.

    Raises
    ------
    ValueError
        If ``lr``, ``eps``, ``weight_decay`` or any beta is out of range.
    """

    def __init__(
        self,
        optimizer_name,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.0,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        max_unorm=0.0,
        skip_zeros=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if isinstance(betas, str):
            # format: '(beta1, beta2)'
            betas = betas.replace("(", "").replace(")", "").strip().split(",")
            betas = [float(b) for b in betas]
        for i, beta in enumerate(betas):
            if not 0.0 <= beta < 1.0:
                raise ValueError(
                    f"Invalid beta parameter at index {i}: {beta}"
                )
        if not 0.0 <= weight_decay:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay)
            )
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults, optim_bits)

        if args is None:
            self.args = MockArgs(
                {
                    "optim_bits": optim_bits,
                    "percentile_clipping": percentile_clipping,
                    "min_8bit_size": min_8bit_size,
                    "block_wise": block_wise,
                    "max_unorm": max_unorm,
                    "skip_zeros": skip_zeros,
                }
            )
        else:
            self.args = args

        self.optimizer_name = optimizer_name

    @torch.no_grad()
    def init_state(self, group, p, gindex, pindex):
        """Allocate the state entries for parameter ``p``.

        Always sets ``step`` to 0 and creates ``state1``/``state2``. In the
        8-bit case it also stores ``qmap1``/``qmap2`` plus either
        ``absmax1``/``absmax2`` (block-wise) or ``max1``/``new_max1``/
        ``max2``/``new_max2``. ``gnorm_vec`` and ``unorm_vec`` are added when
        percentile clipping (< 100) or ``max_unorm`` (> 0) are enabled.

        Raises:
            NotImplementedError: if ``optim_bits`` is neither 32 nor 8.
        """
        config = self.get_config(gindex, pindex, group)

        if config["optim_bits"] == 32:
            dtype = torch.float32
        elif config["optim_bits"] == 8:
            dtype = torch.uint8
        else:
            raise NotImplementedError(
                f'Amount of optimizer bits not supported: {config["optim_bits"]}'
            )

        if p.numel() < config["min_8bit_size"]:
            dtype = torch.float32

        state = self.state[p]
        state["step"] = 0

        # NOTE(review): besides min_8bit_size, parameters with fewer than 4096
        # elements always get float32 state; this fixed floor is kept as-is.
        use_32bit_state = dtype == torch.float32 or p.numel() < 4096

        if use_32bit_state:
            for name in ("state1", "state2"):
                state[name] = torch.zeros_like(
                    p,
                    memory_format=torch.preserve_format,
                    dtype=torch.float32,
                    device=p.device,
                )
        else:
            if "dynamic" not in self.name2qmap:
                self.fill_qmap()
            # Keep the shared quantization maps on this parameter's device.
            self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
            self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(
                p.device
            )

            state["state1"] = torch.zeros_like(
                p,
                memory_format=torch.preserve_format,
                dtype=torch.uint8,
                device=p.device,
            )
            state["qmap1"] = self.name2qmap["dynamic"]

            state["state2"] = torch.zeros_like(
                p,
                memory_format=torch.preserve_format,
                dtype=torch.uint8,
                device=p.device,
            )
            state["qmap2"] = self.name2qmap["udynamic"]

            if config["block_wise"]:
                # One absmax entry per block of 2048 elements, rounded up.
                blocks = (p.numel() + 2047) // 2048
                for name in ("absmax1", "absmax2"):
                    state[name] = torch.zeros(
                        (blocks,), dtype=torch.float32, device=p.device
                    )
            else:
                for name in ("max1", "new_max1", "max2", "new_max2"):
                    state[name] = torch.zeros(
                        (1,), dtype=torch.float32, device=p.device
                    )

        if config["percentile_clipping"] < 100:
            state["gnorm_vec"] = torch.zeros((100,), device=p.device)

        if config["max_unorm"] > 0.0:
            state["unorm_vec"] = torch.zeros((1,), device=p.device)

    @torch.no_grad()
    def update_step(self, group, p, gindex, pindex):
        """Advance ``step`` and apply one update to ``p``.

        The update function is chosen from the dtype of ``state1`` and the
        ``block_wise`` setting: 32-bit, 8-bit, or 8-bit block-wise.
        """
        state = self.state[p]
        grad = p.grad

        config = self.get_config(gindex, pindex, group)

        state["step"] += 1
        step = state["step"]

        if config["percentile_clipping"] < 100:
            # Only the third returned value (the gradient scale) is used.
            _, _, gnorm_scale = F.percentile_clipping(
                grad, state["gnorm_vec"], step, config["percentile_clipping"]
            )
        else:
            gnorm_scale = 1.0

        state_dtype = state["state1"].dtype

        if state_dtype == torch.float:
            F.optimizer_update_32bit(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                config["betas"][0],
                config["eps"],
                step,
                config["lr"],
                state["state2"],
                config["betas"][1],
                config["weight_decay"],
                gnorm_scale,
                state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                max_unorm=config["max_unorm"],
                skip_zeros=config["skip_zeros"],
            )

        elif state_dtype == torch.uint8 and not config["block_wise"]:
            F.optimizer_update_8bit(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                state["state2"],
                config["betas"][0],
                config["betas"][1],
                config["eps"],
                step,
                config["lr"],
                state["qmap1"],
                state["qmap2"],
                state["max1"],
                state["max2"],
                state["new_max1"],
                state["new_max2"],
                config["weight_decay"],
                gnorm_scale=gnorm_scale,
                unorm_vec=state["unorm_vec"]
                if config["max_unorm"] > 0.0
                else None,
                max_unorm=config["max_unorm"],
            )

            # swap maxes
            state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
            state["max2"], state["new_max2"] = state["new_max2"], state["max2"]

        elif state_dtype == torch.uint8 and config["block_wise"]:
            F.optimizer_update_8bit_blockwise(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                state["state2"],
                config["betas"][0],
                config["betas"][1],
                config["eps"],
                step,
                config["lr"],
                state["qmap1"],
                state["qmap2"],
                state["absmax1"],
                state["absmax2"],
                config["weight_decay"],
                gnorm_scale=gnorm_scale,
                skip_zeros=config["skip_zeros"],
            )
|
|
|
|
|
|
class Optimizer1State(Optimizer8bit):
    """Optimizer base that keeps a single state tensor per parameter.

    ``state1`` is stored either as float32 or as uint8 (with a quantization
    map and max/absmax buffers), depending on the effective config for each
    parameter. The second-state arguments of the ``F.optimizer_update_*``
    functions are passed as ``None``.

    Parameters
    ----------
    optimizer_name : str
        Name forwarded to the ``F.optimizer_update_*`` functions.
    params : iterable
        Parameters or param-group dicts to optimize.
    lr, eps, weight_decay : float
        Must each be non-negative.
    betas : tuple(float, float) or str
        Each value must lie in ``[0, 1)``. A string of the form
        ``'(beta1, beta2)'`` is parsed into floats, matching
        ``Optimizer2State``.
    optim_bits : int
        32 or 8 (checked when state is initialized).
    args : object, optional
        If given, used as ``self.args`` as-is, and the remaining keyword
        settings below are not consulted when building ``self.args``.
    min_8bit_size, percentile_clipping, block_wise, max_unorm, skip_zeros
        Stored on ``self.args`` when ``args`` is ``None``.

    Raises
    ------
    ValueError
        If ``lr``, ``eps``, ``weight_decay`` or any beta is out of range.
    """

    def __init__(
        self,
        optimizer_name,
        params,
        lr=1e-3,
        betas=(0.9, 0.0),
        eps=1e-8,
        weight_decay=0.0,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        max_unorm=0.0,
        skip_zeros=False,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if isinstance(betas, str):
            # format: '(beta1, beta2)'; without parsing, the range check
            # below would compare floats against characters.
            betas = betas.replace("(", "").replace(")", "").strip().split(",")
            betas = [float(b) for b in betas]
        for i, beta in enumerate(betas):
            if not 0.0 <= beta < 1.0:
                raise ValueError(
                    f"Invalid beta parameter at index {i}: {beta}"
                )
        if not 0.0 <= weight_decay:
            raise ValueError(
                "Invalid weight_decay value: {}".format(weight_decay)
            )
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults, optim_bits)

        if args is None:
            self.args = MockArgs(
                {
                    "optim_bits": optim_bits,
                    "percentile_clipping": percentile_clipping,
                    "min_8bit_size": min_8bit_size,
                    "block_wise": block_wise,
                    "max_unorm": max_unorm,
                    "skip_zeros": skip_zeros,
                }
            )
        else:
            self.args = args

        self.optimizer_name = optimizer_name

    @torch.no_grad()
    def init_state(self, group, p, gindex, pindex):
        """Allocate the state entries for parameter ``p``.

        Always sets ``step`` to 0 and creates ``state1``. In the 8-bit case
        it also stores ``qmap1`` plus either ``absmax1`` (block-wise) or
        ``max1``/``new_max1``. ``gnorm_vec`` and ``unorm_vec`` are added when
        percentile clipping (< 100) or ``max_unorm`` (> 0) are enabled.

        Raises:
            NotImplementedError: if ``optim_bits`` is neither 32 nor 8.
        """
        config = self.get_config(gindex, pindex, group)

        if config["optim_bits"] == 32:
            dtype = torch.float32
        elif config["optim_bits"] == 8:
            dtype = torch.uint8
        else:
            raise NotImplementedError(
                f'Amount of optimizer bits not supported: {config["optim_bits"]}'
            )

        if p.numel() < config["min_8bit_size"]:
            dtype = torch.float32

        state = self.state[p]
        state["step"] = 0

        # NOTE(review): besides min_8bit_size, parameters with fewer than 4096
        # elements always get float32 state; this fixed floor is kept as-is.
        use_32bit_state = dtype == torch.float32 or p.numel() < 4096

        if use_32bit_state:
            state["state1"] = torch.zeros_like(
                p,
                memory_format=torch.preserve_format,
                dtype=torch.float32,
                device=p.device,
            )
        else:
            if "dynamic" not in self.name2qmap:
                self.fill_qmap()
            # Keep the shared quantization map on this parameter's device.
            self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)

            state["state1"] = torch.zeros_like(
                p,
                memory_format=torch.preserve_format,
                dtype=torch.uint8,
                device=p.device,
            )
            state["qmap1"] = self.name2qmap["dynamic"]

            if config["block_wise"]:
                # One absmax entry per block of 2048 elements, rounded up.
                blocks = (p.numel() + 2047) // 2048
                state["absmax1"] = torch.zeros(
                    (blocks,), dtype=torch.float32, device=p.device
                )
            else:
                for name in ("max1", "new_max1"):
                    state[name] = torch.zeros(
                        (1,), dtype=torch.float32, device=p.device
                    )

        if config["percentile_clipping"] < 100:
            state["gnorm_vec"] = torch.zeros((100,), device=p.device)

        if config["max_unorm"] > 0.0:
            state["unorm_vec"] = torch.zeros((1,), device=p.device)

    @torch.no_grad()
    def update_step(self, group, p, gindex, pindex):
        """Advance ``step`` and apply one update to ``p``.

        The update function is chosen from the dtype of ``state1`` and the
        ``block_wise`` setting: 32-bit, 8-bit, or 8-bit block-wise.
        """
        state = self.state[p]
        grad = p.grad

        config = self.get_config(gindex, pindex, group)

        state["step"] += 1
        step = state["step"]

        if config["percentile_clipping"] < 100:
            # Only the third returned value (the gradient scale) is used.
            _, _, gnorm_scale = F.percentile_clipping(
                grad, state["gnorm_vec"], step, config["percentile_clipping"]
            )
        else:
            gnorm_scale = 1.0

        state_dtype = state["state1"].dtype

        if state_dtype == torch.float:
            F.optimizer_update_32bit(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                config["betas"][0],
                config["eps"],
                step,
                config["lr"],
                None,
                0.0,
                config["weight_decay"],
                gnorm_scale,
                state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                max_unorm=config["max_unorm"],
                skip_zeros=config["skip_zeros"],
            )

        elif state_dtype == torch.uint8 and not config["block_wise"]:
            F.optimizer_update_8bit(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                None,
                config["betas"][0],
                config["betas"][1],
                config["eps"],
                step,
                config["lr"],
                state["qmap1"],
                None,
                state["max1"],
                None,
                state["new_max1"],
                None,
                config["weight_decay"],
                gnorm_scale,
                state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
                max_unorm=config["max_unorm"],
            )

            # swap maxes
            state["max1"], state["new_max1"] = state["new_max1"], state["max1"]

        elif state_dtype == torch.uint8 and config["block_wise"]:
            F.optimizer_update_8bit_blockwise(
                self.optimizer_name,
                grad,
                p,
                state["state1"],
                None,
                config["betas"][0],
                config["betas"][1],
                config["eps"],
                step,
                config["lr"],
                state["qmap1"],
                None,
                state["absmax1"],
                None,
                config["weight_decay"],
                gnorm_scale=gnorm_scale,
                skip_zeros=config["skip_zeros"],
            )
|