Added paged optimizers.

Tim Dettmers 2023-05-06 14:59:29 -07:00
parent ec38ba95b0
commit 44d68ff29c
8 changed files with 157 additions and 266 deletions

bitsandbytes/cextension.py

@ -27,7 +27,6 @@ try:
lib.get_context.restype = ct.c_void_p
lib.get_cusparse.restype = ct.c_void_p
lib.cget_managed_ptr.restype = ct.c_void_p
lib.cget_stream.restype = ct.c_void_p
COMPILED_WITH_CUDA = True
except AttributeError:
warn("The installed version of bitsandbytes was compiled without GPU support. "

bitsandbytes/functional.py

@ -83,6 +83,27 @@ if COMPILED_WITH_CUDA:
lib.cadagrad_8bit_blockwise_fp16,
)
class GlobalPageManager:
_instance = None
def __init__(self):
raise RuntimeError("Call get_instance() instead")
def initialize(self):
self.paged_tensors = []
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls.__new__(cls)
cls._instance.initialize()
return cls._instance
def prefetch_all(self, to_cpu=False):
for t in self.paged_tensors:
prefetch_tensor(t, to_cpu)
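GlobalPageManager is a process-wide singleton that tracks every paged tensor handed out by the library so their pages can be migrated to the GPU (or back to host memory) ahead of use. A minimal usage sketch, assuming a CUDA build where get_paged and the prefetch_tensor helper called above are available:

import torch
import bitsandbytes.functional as F

mgr = F.GlobalPageManager.get_instance()             # singleton; calling __init__ directly is blocked on purpose
buf = F.get_paged(4096, 4096, dtype=torch.float32)   # unified-memory tensor (see get_paged below)
mgr.paged_tensors.append(buf)                        # register it, as the paged optimizers do

mgr.prefetch_all()              # ask the driver to migrate all tracked pages onto the GPU
mgr.prefetch_all(to_cpu=True)   # or stage them back to host memory before a memory-heavy phase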
class CUBLAS_Context:
_instance = None
@ -142,7 +163,7 @@ def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0))
cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape))
out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)).view(shape)
out.is_paged = True
out.page_deviceid = device.index
return out
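get_paged allocates CUDA unified (managed) memory through the cget_managed_ptr binding, wraps the raw pointer in a NumPy array, and exposes it as a torch tensor tagged with is_paged and the owning device index; the added .view(shape) fixes torch.frombuffer returning a flat 1-D tensor. A short sketch of the resulting buffer, assuming a CUDA build:

import torch
import bitsandbytes.functional as F

state = F.get_paged(1024, 1024, dtype=torch.float32)  # backed by managed memory
assert getattr(state, 'is_paged', False)
assert state.page_deviceid == 0                       # default device index
assert state.shape == (1024, 1024)                    # reshaped by the new .view(shape)
F.fill(state, 0)                                      # zero it, as get_state_buffer does later in this commit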
@ -415,10 +436,14 @@ def is_on_gpu(tensors):
gpu_ids = set()
for t in tensors:
if t is None: continue # NULL pointers are fine
on_gpu &= t.device.type == 'cuda'
gpu_ids.add(t.device.index)
is_paged = getattr(t, 'is_paged', False)
on_gpu &= (t.device.type == 'cuda' or is_paged)
if not is_paged:
gpu_ids.add(t.device.index)
if not on_gpu:
raise TypeError(f'All input tensors need to be on the same GPU, but the following tensors are not on a GPU:\n {[(t.shape, t.device) for t in tensors]}')
if len(gpu_ids) > 1:
raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:{[(t.shape, t.device) for t in tensors]}')
raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}')
return on_gpu
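The relaxed check matters because paged tensors are created with torch.frombuffer and therefore report a CPU device even though their pages can migrate to the GPU; they are accepted as valid inputs and excluded from the single-device-index check. An illustrative sketch (not library code), assuming a CUDA build:

import torch
import bitsandbytes.functional as F

t_gpu = torch.empty(8, device='cuda:0')
t_paged = F.get_paged(8, dtype=torch.float32)  # t_paged.device reports 'cpu', but is_paged is True
assert F.is_on_gpu([t_gpu, t_paged])           # the paged tensor is exempt from the device check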
def get_ptr(A: Tensor) -> ct.c_void_p:

bitsandbytes/optim/__init__.py

@ -6,8 +6,8 @@
from bitsandbytes.cextension import COMPILED_WITH_CUDA
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .adam import Adam, Adam8bit, Adam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit
from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit
from .lamb import LAMB, LAMB8bit, LAMB32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .optimizer import GlobalOptimManager
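With these re-exports the paged variants become importable from bitsandbytes.optim next to the existing classes, e.g.:

from bitsandbytes.optim import (
    PagedAdam, PagedAdam8bit, PagedAdam32bit,
    PagedAdamW, PagedAdamW8bit, PagedAdamW32bit,
)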

bitsandbytes/optim/adam.py

@ -14,92 +14,34 @@ from bitsandbytes.optim.optimizer import Optimizer2State
class Adam(Optimizer2State):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class Adam8bit(Optimizer2State):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class Adam32bit(Optimizer2State):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class PagedAdam(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class PagedAdam8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class PagedAdam32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class AnalysisAdam(torch.optim.Optimizer):
"""Adam that performs 8-bit vs 32-bit error analysis.

bitsandbytes/optim/adamw.py

@ -5,89 +5,35 @@
from bitsandbytes.optim.optimizer import Optimizer2State
class AdamW(Optimizer2State):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class AdamW(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )
class AdamW8bit(Optimizer2State):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )
class AdamW32bit(Optimizer2State):
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=1e-2,
amsgrad=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
super().__init__(
"adam",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)
class PagedAdamW(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class PagedAdamW8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
class PagedAdamW32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)

bitsandbytes/optim/optimizer.py

@ -92,10 +92,12 @@ class GlobalOptimManager:
class Optimizer8bit(torch.optim.Optimizer):
def __init__(self, params, defaults, optim_bits=32):
def __init__(self, params, defaults, optim_bits=32, is_paged=False):
super().__init__(params, defaults)
self.initialized = False
self.name2qmap = {}
self.is_paged = is_paged
self.page_mng = F.GlobalPageManager.get_instance()
self.mng = GlobalOptimManager.get_instance()
self.non_castable_tensor_keys = {
@ -207,7 +209,9 @@ class Optimizer8bit(torch.optim.Optimizer):
values = self.state[p]
for k, v in values.items():
if isinstance(v, torch.Tensor):
self.state[p][k] = v.to(p.device)
is_paged = getattr(v, 'is_paged', False)
if not is_paged:
self.state[p][k] = v.to(p.device)
def check_overrides(self):
for module, attr, config in self.mng.module_weight_config_triple:
@ -252,6 +256,7 @@ class Optimizer8bit(torch.optim.Optimizer):
self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True
if self.is_paged: self.page_mng.prefetch_all()
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]):
if p.grad is None:
@ -261,6 +266,11 @@ class Optimizer8bit(torch.optim.Optimizer):
self.init_state(group, p, gindex, pindex)
self.update_step(group, p, gindex, pindex)
if self.is_paged:
# all paged operations are asynchronous; we need
# to sync to make sure all tensors are in the right state
torch.cuda.synchronize()
return loss
@ -289,6 +299,16 @@ class Optimizer8bit(torch.optim.Optimizer):
"The update_step method needs to be overridden"
)
def get_state_buffer(self, p, dtype=torch.float32):
if not self.is_paged or p.numel() < 1e5:
return torch.zeros_like(p, dtype=dtype, device=p.device)
else:
# large state (>= 1e5 elements): use a paged buffer
buff = F.get_paged(*p.shape, dtype=dtype, device=p.device)
F.fill(buff, 0)
self.page_mng.paged_tensors.append(buff)
return buff
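get_state_buffer chooses, per parameter, where optimizer state lives: non-paged optimizers and parameters with fewer than 1e5 elements get an ordinary torch.zeros_like buffer, while larger parameters get a zero-filled paged buffer that is also registered with the page manager for prefetching. A sketch around that cutoff, with hypothetical shapes and assuming a CUDA build:

import torch
import bitsandbytes as bnb

small = torch.nn.Parameter(torch.zeros(256, 256, device='cuda'))    # 65,536 elements  -> plain CUDA state
large = torch.nn.Parameter(torch.zeros(1024, 1024, device='cuda'))  # 1,048,576 elements -> paged state

opt = bnb.optim.PagedAdam([small, large], lr=1e-3)
(small.sum() + large.sum()).backward()
opt.step()

assert not getattr(opt.state[small]['state1'], 'is_paged', False)
assert getattr(opt.state[large]['state1'], 'is_paged', False)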
class Optimizer2State(Optimizer8bit):
def __init__(
@ -306,6 +326,7 @@ class Optimizer2State(Optimizer8bit):
block_wise=True,
max_unorm=0.0,
skip_zeros=False,
is_paged=False
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
@ -325,7 +346,7 @@ class Optimizer2State(Optimizer8bit):
f"Invalid weight_decay value: {weight_decay}"
)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults, optim_bits)
super().__init__(params, defaults, optim_bits, is_paged)
if args is None:
args = {}
@ -365,18 +386,8 @@ class Optimizer2State(Optimizer8bit):
if dtype == torch.float32 or (
dtype == torch.uint8 and p.numel() < 4096
):
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.float32,
device=p.device,
)
state["state2"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.float32,
device=p.device,
)
state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
state["state2"] = self.get_state_buffer(p, dtype=torch.float32)
elif dtype == torch.uint8:
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
@ -388,20 +399,10 @@ class Optimizer2State(Optimizer8bit):
p.device
)
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.uint8,
device=p.device,
)
state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
state["qmap1"] = self.name2qmap["dynamic"]
state["state2"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.uint8,
device=p.device,
)
state["state2"] = self.get_state_buffer(p, dtype=torch.uint8)
state["qmap2"] = self.name2qmap["udynamic"]
if config["block_wise"]:
@ -538,6 +539,7 @@ class Optimizer1State(Optimizer8bit):
block_wise=True,
max_unorm=0.0,
skip_zeros=False,
is_paged=False
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
@ -553,7 +555,7 @@ class Optimizer1State(Optimizer8bit):
f"Invalid weight_decay value: {weight_decay}"
)
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super().__init__(params, defaults, optim_bits)
super().__init__(params, defaults, optim_bits, is_paged)
if args is None:
args = {}
@ -593,12 +595,7 @@ class Optimizer1State(Optimizer8bit):
if dtype == torch.float32 or (
dtype == torch.uint8 and p.numel() < 4096
):
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.float32,
device=p.device,
)
state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
elif dtype == torch.uint8:
if state["step"] == 0:
if "dynamic" not in self.name2qmap:
@ -607,12 +604,7 @@ class Optimizer1State(Optimizer8bit):
p.device
)
state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.uint8,
device=p.device,
)
state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
state["qmap1"] = self.name2qmap["dynamic"]
if config["block_wise"]:

tests/test_functional.py

@ -172,8 +172,8 @@ def test_dynamic_blockwise_quantization(nested, blocksize):
relerr = sum(reldiffs)/len(reldiffs)
assert abserr < 0.011
assert relerr < 0.018
print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs))
print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs))
#print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs))
#print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs))
diffs = []
for i in range(100):
@ -189,8 +189,8 @@ def test_dynamic_blockwise_quantization(nested, blocksize):
relerr = sum(reldiffs)/len(reldiffs)
assert abserr < 0.0035
assert relerr < 0.015
print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
#print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
#print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
def test_dynamic_blockwise_stochastic_quantization():
@ -320,7 +320,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
dim2 = dim2 - (dim2 % 32)
errors = []
relerrors = []
print("")
#print("")
for i in range(5):
if batched:
A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
@ -349,8 +349,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
relerr = err / torch.abs(out2)
errors.append(err.mean().item())
relerrors.append(relerr.mean().item())
print(mean(errors))
print(mean(relerrors))
#print(mean(errors))
#print(mean(relerrors))
def test_stable_embedding():

tests/test_optim.py

@ -39,6 +39,8 @@ str2optimizers["momentum_pytorch"] = (
bnb.optim.Adam,
)
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers["momentum"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
@ -48,10 +50,7 @@ str2optimizers["rmsprop"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["adam8bit"] = (
torch.optim.Adam,
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
)
str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers["momentum8bit"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
@ -61,10 +60,9 @@ str2optimizers["rmsprop8bit"] = (
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["adam8bit_blockwise"] = (
torch.optim.Adam,
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
)
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
str2optimizers["momentum8bit_blockwise"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
@ -76,36 +74,25 @@ str2optimizers["rmsprop8bit_blockwise"] = (
str2statenames = {}
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [
("exp_avg", "state1", "qmap1", "max1"),
("exp_avg_sq", "state2", "qmap2", "max2"),
]
str2statenames["lamb8bit"] = [
("exp_avg", "state1", "qmap1", "max1"),
("exp_avg_sq", "state2", "qmap2", "max2"),
]
str2statenames["adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["momentum8bit"] = [
("momentum_buffer", "state1", "qmap1", "max1")
]
str2statenames["momentum8bit_blockwise"] = [
("momentum_buffer", "state1", "qmap1", "absmax1")
]
str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [
("square_avg", "state1", "qmap1", "absmax1")
]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16, torch.bfloat16]
optimizer_names = ["adam", "momentum", "rmsprop"]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam']
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@ -135,14 +122,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
torch_optimizer.step()
for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_allclose(
torch.testing.assert_close(
torch_optimizer.state[p1][name1],
bnb_optimizer.state[p2][name2],
bnb_optimizer.state[p2][name2].cuda(),
atol=atol,
rtol=rtol,
)
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
torch.testing.assert_close(p1, p2.float(), atol=atol, rtol=rtol)
if i % (k // 5) == 0 and i > 0:
path = get_temp_dir()
@ -152,9 +139,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path)
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
torch.testing.assert_close(p1, p2.float(), atol=atol, rtol=rtol)
for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_allclose(
torch.testing.assert_close(
torch_optimizer.state[p1][name1],
bnb_optimizer.state[p2][name2],
atol=atol,
@ -168,7 +155,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
# --> copy the state to keep weights close
p1.data = p1.data.to(p2.dtype).float()
p2.copy_(p1.data)
torch.testing.assert_allclose(p1.to(p2.dtype), p2)
torch.testing.assert_close(p1.to(p2.dtype), p2)
if optim_name in ["lars", "lamb"]:
assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
@ -277,7 +264,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
bnb_optimizer.step()
torch_optimizer.step()
torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
torch.testing.assert_close(p1, p2.float(), atol=patol, rtol=prtol)
dequant_states = []
for name1, name2, qmap, max_val in str2statenames[optim_name]:
@ -331,8 +318,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path)
torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])
torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2])
torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap])
if "blockwise" in optim_name:
s1 = F.dequantize_blockwise(
@ -347,17 +334,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
)
torch.testing.assert_allclose(s1cpy, s1)
torch.testing.assert_close(s1cpy, s1)
num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0)
assert num_not_close.sum().item() < 20
torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
torch.testing.assert_close(p1, p2.float(), atol=patol, rtol=prtol)
# the parameters diverge quickly. Here we keep them close
# together so we can test against the Adam error
p1.data = p1.data.to(gtype).float()
p2.copy_(p1.data)
torch.testing.assert_allclose(p1.to(gtype), p2)
torch.testing.assert_close(p1.to(gtype), p2)
for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
torch_optimizer.state[p1][name1].copy_(s.data)
@ -419,28 +406,28 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
# gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
if optim_bits == 32:
torch.testing.assert_allclose(p1, p2)
torch.testing.assert_allclose(
torch.testing.assert_close(p1, p2)
torch.testing.assert_close(
adam1.state[p1]["state1"],
adam2.state[p2]["state1"],
atol=5e-5,
rtol=1e-4,
)
torch.testing.assert_allclose(
torch.testing.assert_close(
adam1.state[p1]["state2"],
adam2.state[p2]["state2"],
atol=5e-5,
rtol=1e-4,
)
elif optim_bits == 8:
torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
torch.testing.assert_allclose(
torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3)
torch.testing.assert_close(
adam1.state[p1]["state1"],
adam2.state[p2]["state1"],
atol=2,
rtol=1e-3,
)
torch.testing.assert_allclose(
torch.testing.assert_close(
adam1.state[p1]["state2"],
adam2.state[p2]["state2"],
atol=2,
@ -472,7 +459,7 @@ gtype = [torch.float32, torch.float16]
# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
# optimizer_names = ['lamb_apex', 'lamb8bit']
# optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise"]
optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise']
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values