From 44d68ff29cc19e54db13242e7f8cff3c7e4c5196 Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sat, 6 May 2023 14:59:29 -0700 Subject: [PATCH] Added paged optimizers. --- bitsandbytes/cextension.py | 1 - bitsandbytes/functional.py | 33 ++++++++-- bitsandbytes/optim/__init__.py | 4 +- bitsandbytes/optim/adam.py | 104 +++++++----------------------- bitsandbytes/optim/adamw.py | 108 ++++++++------------------------ bitsandbytes/optim/optimizer.py | 72 ++++++++++----------- tests/test_functional.py | 14 ++--- tests/test_optim.py | 87 +++++++++++-------------- 8 files changed, 157 insertions(+), 266 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 17c2a46..29621c9 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -27,7 +27,6 @@ try: lib.get_context.restype = ct.c_void_p lib.get_cusparse.restype = ct.c_void_p lib.cget_managed_ptr.restype = ct.c_void_p - lib.cget_stream.restype = ct.c_void_p COMPILED_WITH_CUDA = True except AttributeError: warn("The installed version of bitsandbytes was compiled without GPU support. " diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index f548475..a6ed675 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -83,6 +83,27 @@ if COMPILED_WITH_CUDA: lib.cadagrad_8bit_blockwise_fp16, ) +class GlobalPageManager: + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def initialize(self): + self.paged_tensors = [] + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + def prefetch_all(self, to_cpu=False): + for t in self.paged_tensors: + prefetch_tensor(t, to_cpu) + + class CUBLAS_Context: _instance = None @@ -142,7 +163,7 @@ def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)) cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) new_array = np.ctypeslib.as_array(c_ptr, shape=shape) - out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)) + out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)).view(shape) out.is_paged = True out.page_deviceid = device.index return out @@ -415,10 +436,14 @@ def is_on_gpu(tensors): gpu_ids = set() for t in tensors: if t is None: continue # NULL pointers are fine - on_gpu &= t.device.type == 'cuda' - gpu_ids.add(t.device.index) + is_paged = getattr(t, 'is_paged', False) + on_gpu &= (t.device.type == 'cuda' or is_paged) + if not is_paged: + gpu_ids.add(t.device.index) + if not on_gpu: + raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}') if len(gpu_ids) > 1: - raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:{[(t.shape, t.device) for t in tensors]}') + raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}') return on_gpu def get_ptr(A: Tensor) -> ct.c_void_p: diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py index 8c8a8f4..994dae5 100644 --- a/bitsandbytes/optim/__init__.py +++ b/bitsandbytes/optim/__init__.py @@ -6,8 +6,8 @@ from bitsandbytes.cextension import COMPILED_WITH_CUDA from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit -from .adam import Adam, Adam8bit, Adam32bit -from .adamw 
import AdamW, AdamW8bit, AdamW32bit +from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit +from .adamw import AdamW, AdamW8bit, AdamW32bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit from .lamb import LAMB, LAMB8bit, LAMB32bit from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS from .optimizer import GlobalOptimManager diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py index 396aeb8..86981eb 100644 --- a/bitsandbytes/optim/adam.py +++ b/bitsandbytes/optim/adam.py @@ -14,92 +14,34 @@ from bitsandbytes.optim.optimizer import Optimizer2State class Adam(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - optim_bits, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Adam8bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) class Adam32bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 32, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) +class PagedAdam(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdam8bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, 
percentile_clipping, block_wise, is_paged=True) + +class PagedAdam32bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) class AnalysisAdam(torch.optim.Optimizer): """Adam that performs 8-bit vs 32-bit error analysis. diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py index 022e64c..21077f1 100644 --- a/bitsandbytes/optim/adamw.py +++ b/bitsandbytes/optim/adamw.py @@ -5,89 +5,35 @@ from bitsandbytes.optim.optimizer import Optimizer2State -class AdamW(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - optim_bits=32, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - optim_bits, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) +class AdamW(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) class AdamW8bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 8, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) - + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) class AdamW32bit(Optimizer2State): - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, - args=None, - min_8bit_size=4096, - percentile_clipping=100, - block_wise=True, - ): - super().__init__( - "adam", - params, - lr, - betas, - eps, - weight_decay, - 32, - args, - min_8bit_size, - percentile_clipping, - block_wise, - ) + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + + +class PagedAdamW(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdamW8bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 
weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdamW32bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index 867ad3d..4f8dcc7 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -92,10 +92,12 @@ class GlobalOptimManager: class Optimizer8bit(torch.optim.Optimizer): - def __init__(self, params, defaults, optim_bits=32): + def __init__(self, params, defaults, optim_bits=32, is_paged=False): super().__init__(params, defaults) self.initialized = False self.name2qmap = {} + self.is_paged = is_paged + self.page_mng = F.GlobalPageManager.get_instance() self.mng = GlobalOptimManager.get_instance() self.non_castable_tensor_keys = { @@ -207,7 +209,9 @@ class Optimizer8bit(torch.optim.Optimizer): values = self.state[p] for k, v in values.items(): if isinstance(v, torch.Tensor): - self.state[p][k] = v.to(p.device) + is_paged = getattr(v, 'is_paged', False) + if not is_paged: + self.state[p][k] = v.to(p.device) def check_overrides(self): for module, attr, config in self.mng.module_weight_config_triple: @@ -252,6 +256,7 @@ class Optimizer8bit(torch.optim.Optimizer): self.to_gpu() # needed for fairseq pure fp16 training self.initialized = True + if self.is_paged: self.page_mng.prefetch_all() for gindex, group in enumerate(self.param_groups): for pindex, p in enumerate(group["params"]): if p.grad is None: @@ -261,6 +266,11 @@ class Optimizer8bit(torch.optim.Optimizer): self.init_state(group, p, gindex, pindex) self.update_step(group, p, gindex, pindex) + if self.is_paged: + # all paged operation are asynchronous, we need + # to sync to make sure all tensors are in the right state + torch.cuda.synchronize() + return loss @@ -289,6 +299,16 @@ class Optimizer8bit(torch.optim.Optimizer): "The update_step method needs to be overridden" ) + def get_state_buffer(self, p, dtype=torch.float32): + if not self.is_paged or p.numel() < 1e5: + return torch.zeros_like(p, dtype=dtype, device=p.device) + else: + # > 1 MB + buff = F.get_paged(*p.shape, dtype=dtype, device=p.device) + F.fill(buff, 0) + self.page_mng.paged_tensors.append(buff) + return buff + class Optimizer2State(Optimizer8bit): def __init__( @@ -306,6 +326,7 @@ class Optimizer2State(Optimizer8bit): block_wise=True, max_unorm=0.0, skip_zeros=False, + is_paged=False ): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") @@ -325,7 +346,7 @@ class Optimizer2State(Optimizer8bit): f"Invalid weight_decay value: {weight_decay}" ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - super().__init__(params, defaults, optim_bits) + super().__init__(params, defaults, optim_bits, is_paged) if args is None: args = {} @@ -365,18 +386,8 @@ class Optimizer2State(Optimizer8bit): if dtype == torch.float32 or ( dtype == torch.uint8 and p.numel() < 4096 ): - state["state1"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.float32, - device=p.device, - ) - state["state2"] = 
torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.float32, - device=p.device, - ) + state["state1"] = self.get_state_buffer(p, dtype=torch.float32) + state["state2"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: if "dynamic" not in self.name2qmap: @@ -388,20 +399,10 @@ class Optimizer2State(Optimizer8bit): p.device ) - state["state1"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.uint8, - device=p.device, - ) + state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap1"] = self.name2qmap["dynamic"] - state["state2"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.uint8, - device=p.device, - ) + state["state2"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap2"] = self.name2qmap["udynamic"] if config["block_wise"]: @@ -538,6 +539,7 @@ class Optimizer1State(Optimizer8bit): block_wise=True, max_unorm=0.0, skip_zeros=False, + is_paged=False ): if not 0.0 <= lr: raise ValueError(f"Invalid learning rate: {lr}") @@ -553,7 +555,7 @@ class Optimizer1State(Optimizer8bit): f"Invalid weight_decay value: {weight_decay}" ) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) - super().__init__(params, defaults, optim_bits) + super().__init__(params, defaults, optim_bits, is_paged) if args is None: args = {} @@ -593,12 +595,7 @@ class Optimizer1State(Optimizer8bit): if dtype == torch.float32 or ( dtype == torch.uint8 and p.numel() < 4096 ): - state["state1"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.float32, - device=p.device, - ) + state["state1"] = self.get_state_buffer(p, dtype=torch.float32) elif dtype == torch.uint8: if state["step"] == 0: if "dynamic" not in self.name2qmap: @@ -607,12 +604,7 @@ class Optimizer1State(Optimizer8bit): p.device ) - state["state1"] = torch.zeros_like( - p, - memory_format=torch.preserve_format, - dtype=torch.uint8, - device=p.device, - ) + state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) state["qmap1"] = self.name2qmap["dynamic"] if config["block_wise"]: diff --git a/tests/test_functional.py b/tests/test_functional.py index 145c267..6bda1a8 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -172,8 +172,8 @@ def test_dynamic_blockwise_quantization(nested, blocksize): relerr = sum(reldiffs)/len(reldiffs) assert abserr < 0.011 assert relerr < 0.018 - print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs)) - print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs)) + #print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs)) + #print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs)) diffs = [] for i in range(100): @@ -189,8 +189,8 @@ def test_dynamic_blockwise_quantization(nested, blocksize): relerr = sum(reldiffs)/len(reldiffs) assert abserr < 0.0035 assert relerr < 0.015 - print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) - print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) + #print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) + #print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) def test_dynamic_blockwise_stochastic_quantization(): @@ -320,7 +320,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): dim2 = dim2 - (dim2 % 32) errors = [] relerrors = [] - print("") + #print("") for i in range(5): if batched: A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda") @@ 
-349,8 +349,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched): relerr = err / torch.abs(out2) errors.append(err.mean().item()) relerrors.append(relerr.mean().item()) - print(mean(errors)) - print(mean(relerrors)) + #print(mean(errors)) + #print(mean(relerrors)) def test_stable_embedding(): diff --git a/tests/test_optim.py b/tests/test_optim.py index a13b332..a5ecb6e 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -39,6 +39,8 @@ str2optimizers["momentum_pytorch"] = ( bnb.optim.Adam, ) str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam) +str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW) +str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam) # str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) str2optimizers["momentum"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), @@ -48,10 +50,7 @@ str2optimizers["rmsprop"] = ( lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), ) -str2optimizers["adam8bit"] = ( - torch.optim.Adam, - lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False), -) +str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False)) str2optimizers["momentum8bit"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False), @@ -61,10 +60,9 @@ str2optimizers["rmsprop8bit"] = ( lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False), ) -str2optimizers["adam8bit_blockwise"] = ( - torch.optim.Adam, - lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True), -) +str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) +str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True)) +str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True)) str2optimizers["momentum8bit_blockwise"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True), @@ -76,36 +74,25 @@ str2optimizers["rmsprop8bit_blockwise"] = ( str2statenames = {} str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["momentum"] = [("momentum_buffer", "state1")] str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] str2statenames["rmsprop"] = [("square_avg", "state1")] -str2statenames["adam8bit"] = [ - ("exp_avg", "state1", "qmap1", "max1"), - ("exp_avg_sq", "state2", "qmap2", "max2"), -] -str2statenames["lamb8bit"] = [ - ("exp_avg", "state1", "qmap1", "max1"), - ("exp_avg_sq", "state2", "qmap2", "max2"), -] -str2statenames["adam8bit_blockwise"] = [ - ("exp_avg", "state1", "qmap1", "absmax1"), - ("exp_avg_sq", "state2", "qmap2", "absmax2"), -] -str2statenames["momentum8bit"] = [ - ("momentum_buffer", "state1", "qmap1", "max1") -] -str2statenames["momentum8bit_blockwise"] = [ - ("momentum_buffer", "state1", "qmap1", "absmax1") -] +str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] +str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] +str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", 
"absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] +str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")] str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")] -str2statenames["rmsprop8bit_blockwise"] = [ - ("square_avg", "state1", "qmap1", "absmax1") -] +str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")] dim1 = [1024] dim2 = [32, 1024, 4097, 1] -gtype = [torch.float32, torch.float16, torch.bfloat16] -optimizer_names = ["adam", "momentum", "rmsprop"] +gtype = [torch.float32, torch.float16] +optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam'] values = list(product(dim1, dim2, gtype, optimizer_names)) names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values] @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @@ -135,14 +122,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): torch_optimizer.step() for name1, name2 in str2statenames[optim_name]: - torch.testing.assert_allclose( + torch.testing.assert_close( torch_optimizer.state[p1][name1], - bnb_optimizer.state[p2][name2], + bnb_optimizer.state[p2][name2].cuda(), atol=atol, rtol=rtol, ) - torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) + torch.testing.assert_close(p1, p2.float(), atol=atol, rtol=rtol) if i % (k // 5) == 0 and i > 0: path = get_temp_dir() @@ -152,9 +139,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) rm_path(path) - torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) + torch.testing.assert_close(p1, p2.float(), atol=atol, rtol=rtol) for name1, name2 in str2statenames[optim_name]: - torch.testing.assert_allclose( + torch.testing.assert_close( torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, @@ -168,7 +155,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name): # --> copy the state to keep weights close p1.data = p1.data.to(p2.dtype).float() p2.copy_(p1.data) - torch.testing.assert_allclose(p1.to(p2.dtype), p2) + torch.testing.assert_close(p1.to(p2.dtype), p2) if optim_name in ["lars", "lamb"]: assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0 @@ -277,7 +264,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): bnb_optimizer.step() torch_optimizer.step() - torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol) + torch.testing.assert_close(p1, p2.float(), atol=patol, rtol=prtol) dequant_states = [] for name1, name2, qmap, max_val in str2statenames[optim_name]: @@ -331,8 +318,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) rm_path(path) - torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2]) - torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap]) + torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2]) + torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap]) if "blockwise" in optim_name: s1 = 
F.dequantize_blockwise( @@ -347,17 +334,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name): absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], ) - torch.testing.assert_allclose(s1cpy, s1) + torch.testing.assert_close(s1cpy, s1) num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0) assert num_not_close.sum().item() < 20 - torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol) + torch.testing.assert_close(p1, p2.float(), atol=patol, rtol=prtol) # the parameters diverge quickly. Here we keep them close # together so we can test against the Adam error p1.data = p1.data.to(gtype).float() p2.copy_(p1.data) - torch.testing.assert_allclose(p1.to(gtype), p2) + torch.testing.assert_close(p1.to(gtype), p2) for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): torch_optimizer.state[p1][name1].copy_(s.data) @@ -419,28 +406,28 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state if optim_bits == 32: - torch.testing.assert_allclose(p1, p2) - torch.testing.assert_allclose( + torch.testing.assert_close(p1, p2) + torch.testing.assert_close( adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=5e-5, rtol=1e-4, ) - torch.testing.assert_allclose( + torch.testing.assert_close( adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=5e-5, rtol=1e-4, ) elif optim_bits == 8: - torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3) - torch.testing.assert_allclose( + torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3) + torch.testing.assert_close( adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3, ) - torch.testing.assert_allclose( + torch.testing.assert_close( adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, @@ -472,7 +459,7 @@ gtype = [torch.float32, torch.float16] # optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch'] # optimizer_names = ['lamb_apex', 'lamb8bit'] # optimizer_names = ['lars_apex', 'lars8bit'] -optimizer_names = ["adam8bit_blockwise"] +optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise'] values = list(product(dim1, dim2, gtype, optimizer_names)) names = [ "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
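
A minimal usage sketch for the paged-memory helpers this patch adds to bitsandbytes/functional.py. get_paged() and GlobalPageManager are introduced here; F.fill() and F.prefetch_tensor() are existing helpers the patch relies on. Treat this as illustrative, not as part of the change itself:

    import torch
    import bitsandbytes.functional as F

    buf = F.get_paged(4096, 4096, dtype=torch.float32)  # CUDA managed allocation; buf.is_paged is True
    F.fill(buf, 0)                                       # zero the buffer with the paged-aware fill helper
    F.prefetch_tensor(buf)                               # hint the driver to migrate pages to buf.page_deviceid
    F.prefetch_tensor(buf, to_cpu=True)                  # or page the buffer back out to host memory

    # GlobalPageManager is a singleton that collects paged state buffers so the
    # optimizers can prefetch all of them at once before a step:
    pm = F.GlobalPageManager.get_instance()
    pm.paged_tensors.append(buf)
    pm.prefetch_all()                                    # calls prefetch_tensor() on every registered tensor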
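
On the .view(shape) added to get_paged(): torch.frombuffer() always returns a flat 1-D tensor over the raw managed allocation, so the result has to be reshaped before it can stand in for the state of a multi-dimensional parameter. A small standalone illustration with an ordinary buffer (no managed memory involved):

    import numpy as np
    import torch

    raw = np.zeros(6, dtype=np.float32)
    flat = torch.frombuffer(raw, dtype=torch.float32, count=6)  # always 1-D: shape (6,)
    out = flat.view((2, 3))                                     # reshape without copying
    assert out.shape == (2, 3)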
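
How the paged state interacts with the optimizer step, per the changes in bitsandbytes/optim/optimizer.py: init_state() now allocates state1/state2 through get_state_buffer(), which returns a paged buffer only when is_paged is set and the parameter has at least 1e5 elements; smaller states stay as regular CUDA tensors. A simplified outline of the resulting step() flow (not the full method):

    def step(self, closure=None):
        loss = closure() if closure is not None else None

        if not self.initialized:
            self.check_overrides()
            self.to_gpu()                    # paged state tensors are skipped here
            self.initialized = True

        if self.is_paged:
            self.page_mng.prefetch_all()     # migrate every registered paged buffer to the GPU

        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group["params"]):
                if p.grad is None:
                    continue
                if len(self.state[p]) == 0:
                    self.init_state(group, p, gindex, pindex)   # uses get_state_buffer()
                self.update_step(group, p, gindex, pindex)

        if self.is_paged:
            # paged operations are asynchronous; sync so all tensors are in a consistent state
            torch.cuda.synchronize()

        return loss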
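
Usage sketch for the new optimizer classes (assumes a CUDA build of bitsandbytes): the Paged* classes are drop-in replacements for their non-paged counterparts and simply forward is_paged=True to Optimizer2State, so PagedAdamW(params) behaves like AdamW(params, is_paged=True):

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Linear(4096, 4096).cuda()
    opt = bnb.optim.PagedAdamW8bit(model.parameters(), lr=1e-3, weight_decay=1e-2)

    x = torch.randn(8, 4096, device="cuda")
    model(x).sum().backward()
    opt.step()        # with ~16.7M elements, state1/state2 are paged (unified-memory) buffers
    opt.zero_grad()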
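
On the test changes: the deprecated torch.testing.assert_allclose calls are replaced with torch.testing.assert_close, and the 32-bit comparison now moves the bitsandbytes state with .cuda() first. My reading (an inference from the patch, not stated in it) is that a paged state buffer is backed by managed memory created via torch.frombuffer and therefore reports a CPU device, so it has to be materialized on the GPU before an elementwise comparison against the torch optimizer's state:

    torch.testing.assert_close(
        torch_optimizer.state[p1]["exp_avg"],
        bnb_optimizer.state[p2]["state1"].cuda(),  # paged buffers report a CPU device; .cuda() copies them on-device
        atol=atol, rtol=rtol,
    )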