Added paged optimizers.
commit 44d68ff29c (parent ec38ba95b0)
@@ -27,7 +27,6 @@ try:
    lib.get_context.restype = ct.c_void_p
    lib.get_cusparse.restype = ct.c_void_p
    lib.cget_managed_ptr.restype = ct.c_void_p
    lib.cget_stream.restype = ct.c_void_p
    COMPILED_WITH_CUDA = True
except AttributeError:
    warn("The installed version of bitsandbytes was compiled without GPU support. "
@@ -83,6 +83,27 @@ if COMPILED_WITH_CUDA:
        lib.cadagrad_8bit_blockwise_fp16,
    )

class GlobalPageManager:
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.paged_tensors = []

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def prefetch_all(self, to_cpu=False):
        for t in self.paged_tensors:
            prefetch_tensor(t, to_cpu)


class CUBLAS_Context:
    _instance = None
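The `GlobalPageManager` added above is a process-wide singleton that keeps a list of every paged state tensor so all of them can be prefetched in one call. A minimal sketch of the intended access pattern, assuming the CUDA build of bitsandbytes with this commit applied:

import bitsandbytes.functional as F

mng = F.GlobalPageManager.get_instance()          # never call GlobalPageManager() directly
assert mng is F.GlobalPageManager.get_instance()  # same instance everywhere in the process

# Paged optimizer state registered by the optimizers ends up in mng.paged_tensors;
# prefetch_all() asks the driver to migrate those pages toward the GPU (or back to the CPU).
mng.prefetch_all()
mng.prefetch_all(to_cpu=True)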
@@ -142,7 +163,7 @@ def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0))
    cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes))
    c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int))
    new_array = np.ctypeslib.as_array(c_ptr, shape=shape)
    out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape))
    out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)).view(shape)
    out.is_paged = True
    out.page_deviceid = device.index
    return out
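The one-line change above makes `get_paged` return a tensor with the requested shape (via the added `.view(shape)`) rather than a flat buffer over the managed allocation. A usage sketch, assuming a CUDA build so that `cget_managed_ptr` is available:

import torch
import bitsandbytes.functional as F

# Allocate a 1024x1024 fp32 tensor backed by CUDA unified (managed) memory;
# the driver may page it between GPU and CPU on demand.
buf = F.get_paged(1024, 1024, dtype=torch.float32, device=torch.device('cuda', 0))
print(buf.shape)          # torch.Size([1024, 1024]), thanks to the .view(shape) fix
print(buf.is_paged)       # True: marks the tensor for GlobalPageManager / is_on_gpu
print(buf.page_deviceid)  # 0: the device its pages should be prefetched to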
@@ -415,10 +436,14 @@ def is_on_gpu(tensors):
    gpu_ids = set()
    for t in tensors:
        if t is None: continue # NULL pointers are fine
        on_gpu &= t.device.type == 'cuda'
        gpu_ids.add(t.device.index)
        is_paged = getattr(t, 'is_paged', False)
        on_gpu &= (t.device.type == 'cuda' or is_paged)
        if not is_paged:
            gpu_ids.add(t.device.index)
    if not on_gpu:
        raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}')
    if len(gpu_ids) > 1:
        raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:{[(t.shape, t.device) for t in tensors]}')
        raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}')
    return on_gpu

def get_ptr(A: Tensor) -> ct.c_void_p:
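The hunk above relaxes the device check: a tensor flagged `is_paged` passes even if its pages currently sit in host memory, and it is excluded from the same-GPU-index check. A reduced, hypothetical standalone version of that logic for illustration (not part of the library):

def devices_ok(tensors):
    on_gpu, gpu_ids = True, set()
    for t in tensors:
        if t is None:
            continue
        is_paged = getattr(t, 'is_paged', False)
        on_gpu &= (t.device.type == 'cuda' or is_paged)  # paged tensors always count as on-GPU
        if not is_paged:
            gpu_ids.add(t.device.index)                  # only non-paged tensors must share one GPU
    return on_gpu and len(gpu_ids) <= 1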
@@ -6,8 +6,8 @@
from bitsandbytes.cextension import COMPILED_WITH_CUDA

from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .adam import Adam, Adam8bit, Adam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit
from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit
from .lamb import LAMB, LAMB8bit, LAMB32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .optimizer import GlobalOptimManager
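With these exports in place, the paged variants are importable straight from `bitsandbytes.optim`. A minimal sketch, assuming a CUDA-enabled install of this branch:

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096).cuda()

# Drop-in replacements for the existing classes; "Paged" means the optimizer state is
# allocated in CUDA unified memory and can be evicted to CPU RAM under memory pressure.
opt32 = bnb.optim.PagedAdamW(model.parameters(), lr=1e-3)      # 32-bit paged state
opt8 = bnb.optim.PagedAdamW8bit(model.parameters(), lr=1e-3)   # 8-bit blockwise paged state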
@@ -14,92 +14,34 @@ from bitsandbytes.optim.optimizer import Optimizer2State

class Adam(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super().__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
        super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)

class Adam8bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super().__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
        super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)

class Adam32bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super().__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
        super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)

class PagedAdam(Optimizer2State):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
        super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)

class PagedAdam8bit(Optimizer2State):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
        super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)

class PagedAdam32bit(Optimizer2State):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32,
            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
        super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)

class AnalysisAdam(torch.optim.Optimizer):
    """Adam that performs 8-bit vs 32-bit error analysis.
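Net effect of this hunk: every Adam variant now takes an `is_paged` flag, and the `Paged*` classes are thin wrappers that force `is_paged=True`. The two constructions below are therefore expected to configure the same optimizer (a sketch; sizes and hyperparameters are placeholders):

import torch
import bitsandbytes as bnb

params = [torch.nn.Parameter(torch.randn(4096, 4096, device="cuda"))]

opt_flag = bnb.optim.Adam(params, lr=1e-3, is_paged=True)  # base class with the new flag
opt_wrap = bnb.optim.PagedAdam(params, lr=1e-3)            # convenience wrapper, same behavior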
@@ -5,89 +5,35 @@
from bitsandbytes.optim.optimizer import Optimizer2State


class AdamW(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super().__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )

class AdamW(Optimizer2State):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
        super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )

class AdamW8bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super().__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
        super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged )

class AdamW32bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super().__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False):
        super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged)


class PagedAdamW(Optimizer2State):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
        super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)

class PagedAdamW8bit(Optimizer2State):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
        super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)

class PagedAdamW32bit(Optimizer2State):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32,
            args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
        super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True)
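Same pattern for AdamW. For a sense of how the paged 8-bit variant slots into an ordinary training loop (model, data, and hyperparameters are placeholders):

import torch
import bitsandbytes as bnb

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.ReLU(), torch.nn.Linear(4096, 1024)
).cuda()
opt = bnb.optim.PagedAdamW8bit(model.parameters(), lr=1e-4, weight_decay=1e-2)

for _ in range(10):
    x = torch.randn(8, 1024, device="cuda")
    loss = model(x).pow(2).mean()
    loss.backward()
    opt.step()        # optimizer state lives in paged memory and is prefetched inside step()
    opt.zero_grad()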
@@ -92,10 +92,12 @@ class GlobalOptimManager:


class Optimizer8bit(torch.optim.Optimizer):
    def __init__(self, params, defaults, optim_bits=32):
    def __init__(self, params, defaults, optim_bits=32, is_paged=False):
        super().__init__(params, defaults)
        self.initialized = False
        self.name2qmap = {}
        self.is_paged = is_paged
        self.page_mng = F.GlobalPageManager.get_instance()

        self.mng = GlobalOptimManager.get_instance()
        self.non_castable_tensor_keys = {
@@ -207,7 +209,9 @@ class Optimizer8bit(torch.optim.Optimizer):
                    values = self.state[p]
                    for k, v in values.items():
                        if isinstance(v, torch.Tensor):
                            self.state[p][k] = v.to(p.device)
                            is_paged = getattr(v, 'is_paged', False)
                            if not is_paged:
                                self.state[p][k] = v.to(p.device)

    def check_overrides(self):
        for module, attr, config in self.mng.module_weight_config_triple:
@@ -252,6 +256,7 @@ class Optimizer8bit(torch.optim.Optimizer):
            self.to_gpu() # needed for fairseq pure fp16 training
            self.initialized = True

        if self.is_paged: self.page_mng.prefetch_all()
        for gindex, group in enumerate(self.param_groups):
            for pindex, p in enumerate(group["params"]):
                if p.grad is None:
@@ -261,6 +266,11 @@ class Optimizer8bit(torch.optim.Optimizer):
                    self.init_state(group, p, gindex, pindex)

                self.update_step(group, p, gindex, pindex)
        if self.is_paged:
            # all paged operation are asynchronous, we need
            # to sync to make sure all tensors are in the right state
            torch.cuda.synchronize()


        return loss
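The two hunks above give `step()` its paged life-cycle: state is prefetched toward the GPU before the parameter loop, and the CUDA stream is synchronized afterwards because the paged copies run asynchronously. A condensed sketch of that ordering (illustrative, not a drop-in replacement for the real method):

import torch
import bitsandbytes.functional as F

def step_sketch(optimizer):
    if optimizer.is_paged:
        # ask the driver to migrate all registered paged state toward the GPU (asynchronous)
        F.GlobalPageManager.get_instance().prefetch_all()

    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            # ... init_state() / update_step() run here in the real optimizer ...

    if optimizer.is_paged:
        # paged operations are asynchronous; sync before anyone reads the updated state
        torch.cuda.synchronize()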
@@ -289,6 +299,16 @@ class Optimizer8bit(torch.optim.Optimizer):
            "The update_step method needs to be overridden"
        )

    def get_state_buffer(self, p, dtype=torch.float32):
        if not self.is_paged or p.numel() < 1e5:
            return torch.zeros_like(p, dtype=dtype, device=p.device)
        else:
            # > 1 MB
            buff = F.get_paged(*p.shape, dtype=dtype, device=p.device)
            F.fill(buff, 0)
            self.page_mng.paged_tensors.append(buff)
            return buff


class Optimizer2State(Optimizer8bit):
    def __init__(
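`get_state_buffer` is the switch that decides where optimizer state lives: tensors with fewer than 1e5 elements, or any tensor when the optimizer is not paged, get a regular `torch.zeros_like` buffer, while larger ones are allocated through `F.get_paged` and registered with the page manager. A quick illustration of which branch a parameter would take, using hypothetical sizes:

import torch

small = torch.empty(64, 64)        # 4,096 elements      -> regular state buffer
large = torch.empty(4096, 4096)    # 16,777,216 elements -> paged buffer (when is_paged=True)

for p in (small, large):
    takes_paged_branch = p.numel() >= 1e5   # mirrors the threshold in get_state_buffer
    print(p.numel(), "paged" if takes_paged_branch else "regular")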
@@ -306,6 +326,7 @@ class Optimizer2State(Optimizer8bit):
        block_wise=True,
        max_unorm=0.0,
        skip_zeros=False,
        is_paged=False
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
@@ -325,7 +346,7 @@ class Optimizer2State(Optimizer8bit):
                f"Invalid weight_decay value: {weight_decay}"
            )
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults, optim_bits)
        super().__init__(params, defaults, optim_bits, is_paged)

        if args is None:
            args = {}
@@ -365,18 +386,8 @@ class Optimizer2State(Optimizer8bit):
        if dtype == torch.float32 or (
            dtype == torch.uint8 and p.numel() < 4096
        ):
            state["state1"] = torch.zeros_like(
                p,
                memory_format=torch.preserve_format,
                dtype=torch.float32,
                device=p.device,
            )
            state["state2"] = torch.zeros_like(
                p,
                memory_format=torch.preserve_format,
                dtype=torch.float32,
                device=p.device,
            )
            state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
            state["state2"] = self.get_state_buffer(p, dtype=torch.float32)
        elif dtype == torch.uint8:
            if state["step"] == 0:
                if "dynamic" not in self.name2qmap:
@@ -388,20 +399,10 @@ class Optimizer2State(Optimizer8bit):
                    p.device
                )

            state["state1"] = torch.zeros_like(
                p,
                memory_format=torch.preserve_format,
                dtype=torch.uint8,
                device=p.device,
            )
            state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
            state["qmap1"] = self.name2qmap["dynamic"]

            state["state2"] = torch.zeros_like(
                p,
                memory_format=torch.preserve_format,
                dtype=torch.uint8,
                device=p.device,
            )
            state["state2"] = self.get_state_buffer(p, dtype=torch.uint8)
            state["qmap2"] = self.name2qmap["udynamic"]

            if config["block_wise"]:
@@ -538,6 +539,7 @@ class Optimizer1State(Optimizer8bit):
        block_wise=True,
        max_unorm=0.0,
        skip_zeros=False,
        is_paged=False
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
@@ -553,7 +555,7 @@ class Optimizer1State(Optimizer8bit):
                f"Invalid weight_decay value: {weight_decay}"
            )
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults, optim_bits)
        super().__init__(params, defaults, optim_bits, is_paged)

        if args is None:
            args = {}
@@ -593,12 +595,7 @@ class Optimizer1State(Optimizer8bit):
        if dtype == torch.float32 or (
            dtype == torch.uint8 and p.numel() < 4096
        ):
            state["state1"] = torch.zeros_like(
                p,
                memory_format=torch.preserve_format,
                dtype=torch.float32,
                device=p.device,
            )
            state["state1"] = self.get_state_buffer(p, dtype=torch.float32)
        elif dtype == torch.uint8:
            if state["step"] == 0:
                if "dynamic" not in self.name2qmap:
@@ -607,12 +604,7 @@ class Optimizer1State(Optimizer8bit):
                    p.device
                )

            state["state1"] = torch.zeros_like(
                p,
                memory_format=torch.preserve_format,
                dtype=torch.uint8,
                device=p.device,
            )
            state["state1"] = self.get_state_buffer(p, dtype=torch.uint8)
            state["qmap1"] = self.name2qmap["dynamic"]

            if config["block_wise"]:
@@ -172,8 +172,8 @@ def test_dynamic_blockwise_quantization(nested, blocksize):
    relerr = sum(reldiffs)/len(reldiffs)
    assert abserr < 0.011
    assert relerr < 0.018
    print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs))
    print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs))
    #print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs))
    #print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs))

    diffs = []
    for i in range(100):
@@ -189,8 +189,8 @@ def test_dynamic_blockwise_quantization(nested, blocksize):
    relerr = sum(reldiffs)/len(reldiffs)
    assert abserr < 0.0035
    assert relerr < 0.015
    print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
    print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))
    #print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs))
    #print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs))


def test_dynamic_blockwise_stochastic_quantization():
@@ -320,7 +320,7 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
    dim2 = dim2 - (dim2 % 32)
    errors = []
    relerrors = []
    print("")
    #print("")
    for i in range(5):
        if batched:
            A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
@@ -349,8 +349,8 @@ def test_approx_igemm(dim1, dim2, quant_methods, batched):
        relerr = err / torch.abs(out2)
        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())
    print(mean(errors))
    print(mean(relerrors))
    #print(mean(errors))
    #print(mean(relerrors))


def test_stable_embedding():
@@ -39,6 +39,8 @@ str2optimizers["momentum_pytorch"] = (
    bnb.optim.Adam,
)
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW)
str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam)
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers["momentum"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
@@ -48,10 +50,7 @@ str2optimizers["rmsprop"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["adam8bit"] = (
    torch.optim.Adam,
    lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
)
str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers["momentum8bit"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
@@ -61,10 +60,9 @@ str2optimizers["rmsprop8bit"] = (
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
)

str2optimizers["adam8bit_blockwise"] = (
    torch.optim.Adam,
    lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
)
str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True))
str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True))
str2optimizers["momentum8bit_blockwise"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
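These registry entries pair each paged bitsandbytes optimizer with its plain PyTorch reference so the tests can step both in lockstep. A stripped-down version of that comparison, with toy sizes and illustrative tolerances:

import torch
import bitsandbytes as bnb

p_ref = torch.nn.Parameter(torch.randn(1024, 1024, device="cuda"))
p_bnb = torch.nn.Parameter(p_ref.detach().clone())

ref_opt = torch.optim.AdamW([p_ref], lr=1e-3)      # reference implementation
bnb_opt = bnb.optim.PagedAdamW([p_bnb], lr=1e-3)   # paged 32-bit counterpart

for _ in range(10):
    g = torch.randn_like(p_ref) * 0.01
    p_ref.grad, p_bnb.grad = g, g.clone()
    ref_opt.step()
    bnb_opt.step()

torch.testing.assert_close(p_ref, p_bnb, atol=1e-4, rtol=1e-3)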
@@ -76,36 +74,25 @@ str2optimizers["rmsprop8bit_blockwise"] = (

str2statenames = {}
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [
    ("exp_avg", "state1", "qmap1", "max1"),
    ("exp_avg_sq", "state2", "qmap2", "max2"),
]
str2statenames["lamb8bit"] = [
    ("exp_avg", "state1", "qmap1", "max1"),
    ("exp_avg_sq", "state2", "qmap2", "max2"),
]
str2statenames["adam8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["momentum8bit"] = [
    ("momentum_buffer", "state1", "qmap1", "max1")
]
str2statenames["momentum8bit_blockwise"] = [
    ("momentum_buffer", "state1", "qmap1", "absmax1")
]
str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")]
str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [
    ("square_avg", "state1", "qmap1", "absmax1")
]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]

dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16, torch.bfloat16]
optimizer_names = ["adam", "momentum", "rmsprop"]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam']
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -135,14 +122,14 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
        torch_optimizer.step()

        for name1, name2 in str2statenames[optim_name]:
            torch.testing.assert_allclose(
            torch.testing.assert_close(
                torch_optimizer.state[p1][name1],
                bnb_optimizer.state[p2][name2],
                bnb_optimizer.state[p2][name2].cuda(),
                atol=atol,
                rtol=rtol,
            )

        torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
        torch.testing.assert_close(p1, p2.float(), atol=atol, rtol=rtol)

        if i % (k // 5) == 0 and i > 0:
            path = get_temp_dir()
@@ -152,9 +139,9 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
            bnb_optimizer = str2optimizers[optim_name][1]([p2])
            bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
            rm_path(path)
            torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
            torch.testing.assert_close(p1, p2.float(), atol=atol, rtol=rtol)
            for name1, name2 in str2statenames[optim_name]:
                torch.testing.assert_allclose(
                torch.testing.assert_close(
                    torch_optimizer.state[p1][name1],
                    bnb_optimizer.state[p2][name2],
                    atol=atol,
@@ -168,7 +155,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
            # --> copy the state to keep weights close
            p1.data = p1.data.to(p2.dtype).float()
            p2.copy_(p1.data)
            torch.testing.assert_allclose(p1.to(p2.dtype), p2)
            torch.testing.assert_close(p1.to(p2.dtype), p2)
        if optim_name in ["lars", "lamb"]:
            assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
@@ -277,7 +264,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
        bnb_optimizer.step()
        torch_optimizer.step()

        torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
        torch.testing.assert_close(p1, p2.float(), atol=patol, rtol=prtol)

        dequant_states = []
        for name1, name2, qmap, max_val in str2statenames[optim_name]:
@@ -331,8 +318,8 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
            bnb_optimizer = str2optimizers[optim_name][1]([p2])
            bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
            rm_path(path)
            torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
            torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])
            torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2])
            torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap])

            if "blockwise" in optim_name:
                s1 = F.dequantize_blockwise(
@@ -347,17 +334,17 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
                    absmax=bnb_optimizer.state[p2][max_val],
                    A=bnb_optimizer.state[p2][name2],
                )
            torch.testing.assert_allclose(s1cpy, s1)
            torch.testing.assert_close(s1cpy, s1)

            num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0)
            assert num_not_close.sum().item() < 20
        torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
        torch.testing.assert_close(p1, p2.float(), atol=patol, rtol=prtol)

        # the parameters diverge quickly. Here we keep them close
        # together so we can test against the Adam error
        p1.data = p1.data.to(gtype).float()
        p2.copy_(p1.data)
        torch.testing.assert_allclose(p1.to(gtype), p2)
        torch.testing.assert_close(p1.to(gtype), p2)
        for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
            torch_optimizer.state[p1][name1].copy_(s.data)
@@ -419,28 +406,28 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):

        # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
        if optim_bits == 32:
            torch.testing.assert_allclose(p1, p2)
            torch.testing.assert_allclose(
            torch.testing.assert_close(p1, p2)
            torch.testing.assert_close(
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=5e-5,
                rtol=1e-4,
            )
            torch.testing.assert_allclose(
            torch.testing.assert_close(
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=5e-5,
                rtol=1e-4,
            )
        elif optim_bits == 8:
            torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
            torch.testing.assert_allclose(
            torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3)
            torch.testing.assert_close(
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=2,
                rtol=1e-3,
            )
            torch.testing.assert_allclose(
            torch.testing.assert_close(
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=2,
@@ -472,7 +459,7 @@ gtype = [torch.float32, torch.float16]
# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
# optimizer_names = ['lamb_apex', 'lamb8bit']
# optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise"]
optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise']
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
    "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
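To run only the new paged cases once this commit is applied, pytest's keyword filter can select them; the test module path below is an assumption, not taken from this diff:

import pytest

# -k 'paged' matches the paged_adam / paged_adamw parametrizations added above
pytest.main(["-q", "-k", "paged", "tests/test_optim.py"])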