Added paged optimizers.

Tim Dettmers 2023-05-06 14:59:29 -07:00
parent ec38ba95b0
commit 44d68ff29c
8 changed files with 157 additions and 266 deletions

View File

@@ -27,7 +27,6 @@ try:
     lib.get_context.restype = ct.c_void_p
     lib.get_cusparse.restype = ct.c_void_p
     lib.cget_managed_ptr.restype = ct.c_void_p
-    lib.cget_stream.restype = ct.c_void_p
     COMPILED_WITH_CUDA = True
 except AttributeError:
     warn("The installed version of bitsandbytes was compiled without GPU support. "

View File

@ -83,6 +83,27 @@ if COMPILED_WITH_CUDA:
lib.cadagrad_8bit_blockwise_fp16, lib.cadagrad_8bit_blockwise_fp16,
) )
class GlobalPageManager:
_instance = None
def __init__(self):
raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.paged_tensors = []
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls.__new__(cls)
cls._instance.initialize()
return cls._instance
def prefetch_all(self, to_cpu=False):
for t in self.paged_tensors:
prefetch_tensor(t, to_cpu)
class CUBLAS_Context: class CUBLAS_Context:
_instance = None _instance = None
@@ -142,7 +163,7 @@ def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0))
     cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
     c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
     new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
-    out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape))
+    out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)).view(shape)
     out.is_paged = True
     out.page_deviceid = device.index
     return out
@@ -415,10 +436,14 @@ def is_on_gpu(tensors):
     gpu_ids = set()
     for t in tensors:
        if t is None: continue # NULL pointers are fine
-       on_gpu &= t.device.type == 'cuda'
-       gpu_ids.add(t.device.index)
+       is_paged = getattr(t, 'is_paged', False)
+       on_gpu &= (t.device.type == 'cuda' or is_paged)
+       if not is_paged:
+           gpu_ids.add(t.device.index)
+    if not on_gpu:
+        raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}')
     if len(gpu_ids) > 1:
-        raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:{[(t.shape, t.device) for t in tensors]}')
+        raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}')
     return on_gpu

 def get_ptr(A: Tensor) -> ct.c_void_p:
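
Note: taken together, the functional.py changes above define the paged-tensor primitives the optimizers build on. A minimal usage sketch, assuming only the helpers visible in this diff (get_paged, fill, GlobalPageManager) and a GPU-enabled build; it is illustrative, not part of the commit:

    import torch
    import bitsandbytes.functional as F

    # Allocate a 1024x1024 fp32 buffer in CUDA managed (unified) memory.
    # The tensor carries is_paged=True and page_deviceid, so the driver may
    # migrate its pages between CPU and GPU on demand.
    buf = F.get_paged(1024, 1024, dtype=torch.float32, device=torch.device('cuda', 0))
    F.fill(buf, 0)  # zero it with the library's fill kernel

    # Register it with the singleton page manager so it is prefetched
    # toward the GPU together with all other paged tensors.
    pm = F.GlobalPageManager.get_instance()
    pm.paged_tensors.append(buf)
    pm.prefetch_all()          # asynchronous prefetch
    torch.cuda.synchronize()   # make the prefetch visible before further use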

View File

@@ -6,8 +6,8 @@
 from bitsandbytes.cextension import COMPILED_WITH_CUDA

 from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
-from .adam import Adam, Adam8bit, Adam32bit
-from .adamw import AdamW, AdamW8bit, AdamW32bit
+from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit
+from .adamw import AdamW, AdamW8bit, AdamW32bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit
 from .lamb import LAMB, LAMB8bit, LAMB32bit
 from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
 from .optimizer import GlobalOptimManager

View File

@@ -14,92 +14,34 @@ from bitsandbytes.optim.optimizer import Optimizer2State

 class Adam(Optimizer2State):
-    def __init__(
-        self,
-        params,
-        lr=1e-3,
-        betas=(0.9, 0.999),
-        eps=1e-8,
-        weight_decay=0,
-        amsgrad=False,
-        optim_bits=32,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "adam",
-            params,
-            lr,
-            betas,
-            eps,
-            weight_decay,
-            optim_bits,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)

 class Adam8bit(Optimizer2State):
-    def __init__(
-        self,
-        params,
-        lr=1e-3,
-        betas=(0.9, 0.999),
-        eps=1e-8,
-        weight_decay=0,
-        amsgrad=False,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "adam",
-            params,
-            lr,
-            betas,
-            eps,
-            weight_decay,
-            8,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)

 class Adam32bit(Optimizer2State):
-    def __init__(
-        self,
-        params,
-        lr=1e-3,
-        betas=(0.9, 0.999),
-        eps=1e-8,
-        weight_decay=0,
-        amsgrad=False,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "adam",
-            params,
-            lr,
-            betas,
-            eps,
-            weight_decay,
-            32,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
+
+class PagedAdam(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedAdam8bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedAdam32bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)

 class AnalysisAdam(torch.optim.Optimizer):
     """Adam that performs 8-bit vs 32-bit error analysis.

View File

@@ -5,89 +5,35 @@
 from bitsandbytes.optim.optimizer import Optimizer2State

 class AdamW(Optimizer2State):
-    def __init__(
-        self,
-        params,
-        lr=1e-3,
-        betas=(0.9, 0.999),
-        eps=1e-8,
-        weight_decay=1e-2,
-        amsgrad=False,
-        optim_bits=32,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "adam",
-            params,
-            lr,
-            betas,
-            eps,
-            weight_decay,
-            optim_bits,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )

 class AdamW8bit(Optimizer2State):
-    def __init__(
-        self,
-        params,
-        lr=1e-3,
-        betas=(0.9, 0.999),
-        eps=1e-8,
-        weight_decay=1e-2,
-        amsgrad=False,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "adam",
-            params,
-            lr,
-            betas,
-            eps,
-            weight_decay,
-            8,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )

 class AdamW32bit(Optimizer2State):
-    def __init__(
-        self,
-        params,
-        lr=1e-3,
-        betas=(0.9, 0.999),
-        eps=1e-8,
-        weight_decay=1e-2,
-        amsgrad=False,
-        args=None,
-        min_8bit_size=4096,
-        percentile_clipping=100,
-        block_wise=True,
-    ):
-        super().__init__(
-            "adam",
-            params,
-            lr,
-            betas,
-            eps,
-            weight_decay,
-            32,
-            args,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
-        )
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
+
+
+class PagedAdamW(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedAdamW8bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
+
+class PagedAdamW32bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
+            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
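
Note: the paged AdamW variants follow the same pattern as the Adam ones. A sketch combining paging with 8-bit blockwise state; the parameter size is an arbitrary placeholder chosen to be large enough (at least 4096 elements for 8-bit state, and at least 1e5 elements for the state to actually be paged):

    import torch
    import bitsandbytes as bnb

    params = [torch.nn.Parameter(torch.randn(8192, 1024, device='cuda'))]

    # 8-bit blockwise optimizer state, stored in paged memory so it can be
    # swapped out of GPU memory under pressure.
    optim = bnb.optim.PagedAdamW8bit(params, lr=1e-4, weight_decay=1e-2)

    params[0].grad = torch.randn_like(params[0])
    optim.step()

    state = optim.state[params[0]]
    print(state['state1'].dtype)                         # torch.uint8 (quantized first moment)
    print(getattr(state['state1'], 'is_paged', False))   # True for a state this large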

View File

@@ -92,10 +92,12 @@ class GlobalOptimManager:

 class Optimizer8bit(torch.optim.Optimizer):
-    def __init__(self, params, defaults, optim_bits=32):
+    def __init__(self, params, defaults, optim_bits=32, is_paged=False):
         super().__init__(params, defaults)
         self.initialized = False
         self.name2qmap = {}
+        self.is_paged = is_paged
+        self.page_mng = F.GlobalPageManager.get_instance()

         self.mng = GlobalOptimManager.get_instance()
         self.non_castable_tensor_keys = {
@@ -207,7 +209,9 @@ class Optimizer8bit(torch.optim.Optimizer):
                     values = self.state[p]
                     for k, v in values.items():
                         if isinstance(v, torch.Tensor):
-                            self.state[p][k] = v.to(p.device)
+                            is_paged = getattr(v, 'is_paged', False)
+                            if not is_paged:
+                                self.state[p][k] = v.to(p.device)

     def check_overrides(self):
         for module, attr, config in self.mng.module_weight_config_triple:
@@ -252,6 +256,7 @@ class Optimizer8bit(torch.optim.Optimizer):
             self.to_gpu()  # needed for fairseq pure fp16 training
             self.initialized = True

+        if self.is_paged: self.page_mng.prefetch_all()
         for gindex, group in enumerate(self.param_groups):
             for pindex, p in enumerate(group["params"]):
                 if p.grad is None:
@@ -261,6 +266,11 @@ class Optimizer8bit(torch.optim.Optimizer):
                     self.init_state(group, p, gindex, pindex)

                 self.update_step(group, p, gindex, pindex)
+        if self.is_paged:
+            # all paged operation are asynchronous, we need
+            # to sync to make sure all tensors are in the right state
+            torch.cuda.synchronize()
+

         return loss
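
Note: the synchronize above is needed because prefetching managed memory and the update kernels are only enqueued on a CUDA stream. A small illustrative sketch of the pattern (not part of the commit); whether a host-side read observes the enqueued writes without the synchronize depends on the device, so the safe ordering is shown:

    import torch
    import bitsandbytes.functional as F

    state = F.get_paged(4096, 4096, dtype=torch.float32, device=torch.device('cuda', 0))
    F.fill(state, 0)           # enqueued asynchronously on the GPU

    torch.cuda.synchronize()   # make the paged writes visible before using the buffer
    print(state.sum())         # expected: tensor(0.)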
@@ -289,6 +299,16 @@ class Optimizer8bit(torch.optim.Optimizer):
             "The update_step method needs to be overridden"
         )

+    def get_state_buffer(self, p, dtype=torch.float32):
+        if not self.is_paged or p.numel() < 1e5:
+            return torch.zeros_like(p, dtype=dtype, device=p.device)
+        else:
+            # > 1 MB
+            buff = F.get_paged(*p.shape, dtype=dtype, device=p.device)
+            F.fill(buff, 0)
+            self.page_mng.paged_tensors.append(buff)
+            return buff
+

 class Optimizer2State(Optimizer8bit):
     def __init__(
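
Note: get_state_buffer only pages large states; anything under 1e5 elements (the inline comment puts this around the 1 MB scale) stays in ordinary CUDA memory. A sketch of the resulting behavior, with arbitrarily chosen parameter sizes:

    import torch
    import bitsandbytes as bnb

    small = torch.nn.Parameter(torch.randn(64, 64, device='cuda'))       # 4,096 elements
    large = torch.nn.Parameter(torch.randn(1024, 1024, device='cuda'))   # 1,048,576 elements

    opt = bnb.optim.PagedAdam32bit([small, large], lr=1e-3)
    small.grad = torch.randn_like(small)
    large.grad = torch.randn_like(large)
    opt.step()

    # Only the large parameter's moments end up in paged (managed) memory.
    print(getattr(opt.state[small]['state1'], 'is_paged', False))  # False
    print(getattr(opt.state[large]['state1'], 'is_paged', False))  # True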
@@ -306,6 +326,7 @@ class Optimizer2State(Optimizer8bit):
         block_wise=True,
         max_unorm=0.0,
         skip_zeros=False,
+        is_paged=False
     ):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
@@ -325,7 +346,7 @@ class Optimizer2State(Optimizer8bit):
                 f"Invalid weight_decay value: {weight_decay}"
             )
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
-        super().__init__(params, defaults, optim_bits)
+        super().__init__(params, defaults, optim_bits, is_paged)

         if args is None:
             args = {}
@@ -365,18 +386,8 @@ class Optimizer2State(Optimizer8bit):
         if dtype == torch.float32 or (
             dtype == torch.uint8 and p.numel() < 4096
         ):
-            state["state1"] = torch.zeros_like(
-                p,
-                memory_format=torch.preserve_format,
-                dtype=torch.float32,
-                device=p.device,
-            )
-            state["state2"] = torch.zeros_like(
-                p,
-                memory_format=torch.preserve_format,
-                dtype=torch.float32,
-                device=p.device,
-            )
+            state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
+            state["state2"] = self.get_state_buffer(p, dtype=torch.float32)
         elif dtype == torch.uint8:
             if state["step"] == 0:
                 if "dynamic" not in self.name2qmap:
@@ -388,20 +399,10 @@ class Optimizer2State(Optimizer8bit):
                     p.device
                 )

-            state["state1"] = torch.zeros_like(
-                p,
-                memory_format=torch.preserve_format,
-                dtype=torch.uint8,
-                device=p.device,
-            )
+            state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
             state["qmap1"] = self.name2qmap["dynamic"]

-            state["state2"] = torch.zeros_like(
-                p,
-                memory_format=torch.preserve_format,
-                dtype=torch.uint8,
-                device=p.device,
-            )
+            state["state2"] = self.get_state_buffer(p, dtype=torch.uint8)
             state["qmap2"] = self.name2qmap["udynamic"]

         if config["block_wise"]:
@@ -538,6 +539,7 @@ class Optimizer1State(Optimizer8bit):
         block_wise=True,
         max_unorm=0.0,
         skip_zeros=False,
+        is_paged=False
     ):
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
@@ -553,7 +555,7 @@ class Optimizer1State(Optimizer8bit):
                 f"Invalid weight_decay value: {weight_decay}"
            )
         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
-        super().__init__(params, defaults, optim_bits)
+        super().__init__(params, defaults, optim_bits, is_paged)

         if args is None:
             args = {}
@@ -593,12 +595,7 @@ class Optimizer1State(Optimizer8bit):
         if dtype == torch.float32 or (
             dtype == torch.uint8 and p.numel() < 4096
         ):
-            state["state1"] = torch.zeros_like(
-                p,
-                memory_format=torch.preserve_format,
-                dtype=torch.float32,
-                device=p.device,
-            )
+            state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
         elif dtype == torch.uint8:
             if state["step"] == 0:
                 if "dynamic" not in self.name2qmap:
@@ -607,12 +604,7 @@ class Optimizer1State(Optimizer8bit):
                     p.device
                 )

-            state["state1"] = torch.zeros_like(
-                p,
-                memory_format=torch.preserve_format,
-                dtype=torch.uint8,
-                device=p.device,
-            )
+            state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
             state["qmap1"] = self.name2qmap["dynamic"]

         if config["block_wise"]:

View File

@@ -172,8 +172,8 @@ def test_dynamic_blockwise_quantization(nested, blocksize):
     relerr = sum(reldiffs)/len(reldiffs)
     assert abserr < 0.011
     assert relerr < 0.018
-    print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs))
-    print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs))
+    #print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs))
+    #print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs))

     diffs = []
     for i in range(100):
@@ -189,8 +189,8 @@ def test_dynamic_blockwise_quantization(nested, blocksize):
     relerr = sum(reldiffs)/len(reldiffs)
     assert abserr < 0.0035
     assert relerr < 0.015
-    print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
-    print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
+    #print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
+    #print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))

 def test_dynamic_blockwise_stochastic_quantization():
@@ -320,7 +320,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
     dim2 = dim2 - (dim2 % 32)
     errors = []
     relerrors = []
-    print("")
+    #print("")
     for i in range(5):
         if batched:
             A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
@@ -349,8 +349,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
         relerr = err / torch.abs(out2)
         errors.append(err.mean().item())
         relerrors.append(relerr.mean().item())
-    print(mean(errors))
-    print(mean(relerrors))
+    #print(mean(errors))
+    #print(mean(relerrors))

 def test_stable_embedding():

View File

@ -39,6 +39,8 @@ str2optimizers["momentum_pytorch"] = (
bnb.optim.Adam, bnb.optim.Adam,
) )
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam) str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) # str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers["momentum"] = ( str2optimizers["momentum"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
@ -48,10 +50,7 @@ str2optimizers["rmsprop"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
) )
str2optimizers["adam8bit"] = ( str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
torch.optim.Adam,
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
)
str2optimizers["momentum8bit"] = ( str2optimizers["momentum8bit"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
@ -61,10 +60,9 @@ str2optimizers["rmsprop8bit"] = (
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
) )
str2optimizers["adam8bit_blockwise"] = ( str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
torch.optim.Adam, str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True), str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
)
str2optimizers["momentum8bit_blockwise"] = ( str2optimizers["momentum8bit_blockwise"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
@ -76,36 +74,25 @@ str2optimizers["rmsprop8bit_blockwise"] = (
str2statenames = {} str2statenames = {}
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["momentum"] = [("momentum_buffer", "state1")] str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")] str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [ str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
("exp_avg", "state1", "qmap1", "max1"), str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
("exp_avg_sq", "state2", "qmap2", "max2"), str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
] str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["lamb8bit"] = [ str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
("exp_avg", "state1", "qmap1", "max1"), str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
("exp_avg_sq", "state2", "qmap2", "max2"), str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
]
str2statenames["adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["momentum8bit"] = [
("momentum_buffer", "state1", "qmap1", "max1")
]
str2statenames["momentum8bit_blockwise"] = [
("momentum_buffer", "state1", "qmap1", "absmax1")
]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")] str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [ str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
("square_avg", "state1", "qmap1", "absmax1")
]
dim1 = [1024] dim1 = [1024]
dim2 = [32, 1024, 4097, 1] dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16, torch.bfloat16] gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop"] optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam']
values = list(product(dim1, dim2, gtype, optimizer_names)) values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values] names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -135,14 +122,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
        torch_optimizer.step()

        for name1, name2 in str2statenames[optim_name]:
-           torch.testing.assert_allclose(
+           torch.testing.assert_close(
                torch_optimizer.state[p1][name1],
-               bnb_optimizer.state[p2][name2],
+               bnb_optimizer.state[p2][name2].cuda(),
                atol=atol,
                rtol=rtol,
            )

-       torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
+       torch.testing.assert_close(p1, p2.float(), atol=atol, rtol=rtol)

        if i % (k // 5) == 0 and i > 0:
            path = get_temp_dir()
@@ -152,9 +139,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
            bnb_optimizer = str2optimizers[optim_name][1]([p2])
            bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
            rm_path(path)
-           torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
+           torch.testing.assert_close(p1, p2.float(), atol=atol, rtol=rtol)
            for name1, name2 in str2statenames[optim_name]:
-               torch.testing.assert_allclose(
+               torch.testing.assert_close(
                    torch_optimizer.state[p1][name1],
                    bnb_optimizer.state[p2][name2],
                    atol=atol,
@@ -168,7 +155,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
            # --> copy the state to keep weights close
            p1.data = p1.data.to(p2.dtype).float()
            p2.copy_(p1.data)
-           torch.testing.assert_allclose(p1.to(p2.dtype), p2)
+           torch.testing.assert_close(p1.to(p2.dtype), p2)
        if optim_name in ["lars", "lamb"]:
            assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
@@ -277,7 +264,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
        bnb_optimizer.step()
        torch_optimizer.step()

-       torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
+       torch.testing.assert_close(p1, p2.float(), atol=patol, rtol=prtol)

        dequant_states = []
        for name1, name2, qmap, max_val in str2statenames[optim_name]:
@@ -331,8 +318,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
            bnb_optimizer = str2optimizers[optim_name][1]([p2])
            bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
            rm_path(path)
-           torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
-           torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])
+           torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2])
+           torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap])

            if "blockwise" in optim_name:
                s1 = F.dequantize_blockwise(
@@ -347,17 +334,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
                    absmax=bnb_optimizer.state[p2][max_val],
                    A=bnb_optimizer.state[p2][name2],
                )
-               torch.testing.assert_allclose(s1cpy, s1)
+               torch.testing.assert_close(s1cpy, s1)

                num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0)
                assert num_not_close.sum().item() < 20
-           torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
+           torch.testing.assert_close(p1, p2.float(), atol=patol, rtol=prtol)

            # the parameters diverge quickly. Here we keep them close
            # together so we can test against the Adam error
            p1.data = p1.data.to(gtype).float()
            p2.copy_(p1.data)
-           torch.testing.assert_allclose(p1.to(gtype), p2)
+           torch.testing.assert_close(p1.to(gtype), p2)
            for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
                torch_optimizer.state[p1][name1].copy_(s.data)
@@ -419,28 +406,28 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
        # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
        if optim_bits == 32:
-           torch.testing.assert_allclose(p1, p2)
-           torch.testing.assert_allclose(
+           torch.testing.assert_close(p1, p2)
+           torch.testing.assert_close(
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=5e-5,
                rtol=1e-4,
            )
-           torch.testing.assert_allclose(
+           torch.testing.assert_close(
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=5e-5,
                rtol=1e-4,
            )
        elif optim_bits == 8:
-           torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
-           torch.testing.assert_allclose(
+           torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3)
+           torch.testing.assert_close(
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=2,
                rtol=1e-3,
            )
-           torch.testing.assert_allclose(
+           torch.testing.assert_close(
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=2,
@@ -472,7 +459,7 @@ gtype = [torch.float32, torch.float16]
 # optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
 # optimizer_names = ['lamb_apex', 'lamb8bit']
 # optimizer_names = ['lars_apex', 'lars8bit']
-optimizer_names = ["adam8bit_blockwise"]
+optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise']
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
     "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values