2022-08-01 10:31:48 +00:00
|
|
|
import ctypes
|
2021-10-06 02:16:20 +00:00
|
|
|
import os
|
|
|
|
import shutil
|
2022-08-01 10:31:48 +00:00
|
|
|
import time
|
2021-10-06 02:16:20 +00:00
|
|
|
import uuid
|
2022-08-01 10:31:48 +00:00
|
|
|
from itertools import product
|
|
|
|
from os.path import join
|
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
import pytest
|
|
|
|
import torch
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
import bitsandbytes as bnb
|
|
|
|
import bitsandbytes.functional as F
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
# import apex
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
# Number of optimizer steps run by test_optimizer32bit and by the timing
# loop in test_benchmark_blockwise (whose first k // 5 steps are untimed).
k = 20
|
2021-10-06 02:16:20 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
def get_temp_dir():
    """Create a fresh, uniquely named scratch directory and return its path.

    The directory is ``<system temp dir>/autoswap/<uuid4>``.  The caller owns
    it and is responsible for deleting it again (e.g. with ``rm_path``).

    Returns:
        str: path of the newly created directory.
    """
    # Local import keeps this helper self-contained; the system temp dir is
    # used instead of a hard-coded "/tmp" so the helper also works on
    # platforms (or TMPDIR setups) where "/tmp" is not the scratch location.
    import tempfile

    path = os.path.join(tempfile.gettempdir(), "autoswap", str(uuid.uuid4()))
    os.makedirs(path, exist_ok=True)
    return path
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
def rm_path(path):
    """Recursively delete the directory tree rooted at ``path``.

    Thin wrapper around ``shutil.rmtree``; any error it raises (for example
    when ``path`` does not exist) propagates to the caller.
    """
    shutil.rmtree(path)
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
# Maps an optimizer name to a tuple of optimizer factories.  The tests in this
# file call entry [0] to build the reference optimizer and entry [1] to build
# the bitsandbytes optimizer under test; each factory takes a parameter list.
str2optimizers = {}

# NOTE(review): the "*_pytorch" entries are 3-tuples (None, reference, bnb),
# unlike every other entry.  Indexing [1] on them yields the torch optimizer;
# presumably intended for benchmarking only -- confirm before using them in
# the comparison tests.  "momentum_pytorch" also lists bnb.optim.Adam as its
# third element, which looks like a copy-paste leftover -- verify.
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers["momentum_pytorch"] = (
    None,
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    bnb.optim.Adam,
)

# 32-bit optimizers.
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers["momentum"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["lars"] = (
    lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
)
str2optimizers["rmsprop"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)

# 8-bit optimizers constructed with block_wise=False (LARS8bit takes no such
# argument here).
str2optimizers["adam8bit"] = (
    torch.optim.Adam,
    lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
)
str2optimizers["momentum8bit"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["lars8bit"] = (
    lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9),
)

# 8-bit optimizers constructed with block_wise=True.
str2optimizers["adam8bit_blockwise"] = (
    torch.optim.Adam,
    lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
)
str2optimizers["momentum8bit_blockwise"] = (
    lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
)
str2optimizers["rmsprop8bit_blockwise"] = (
    lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
    lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
)
|
2021-10-06 02:16:20 +00:00
|
|
|
|
|
|
|
# Maps an optimizer name to the optimizer-state entries the tests compare.
# 32-bit entries are (reference_state_key, bnb_state_key) pairs.
# 8-bit entries are 4-tuples
#   (reference_state_key, bnb_state_key, quantization_map_key, max_key)
# where the last two keys are passed as ``code`` and ``absmax`` when the tests
# dequantize the bnb state ("max*" for the non-blockwise variants, "absmax*"
# for the blockwise ones).
str2statenames = {}
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lars"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames["adam8bit"] = [
    ("exp_avg", "state1", "qmap1", "max1"),
    ("exp_avg_sq", "state2", "qmap2", "max2"),
]
str2statenames["lamb8bit"] = [
    ("exp_avg", "state1", "qmap1", "max1"),
    ("exp_avg_sq", "state2", "qmap2", "max2"),
]
str2statenames["adam8bit_blockwise"] = [
    ("exp_avg", "state1", "qmap1", "absmax1"),
    ("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["momentum8bit"] = [
    ("momentum_buffer", "state1", "qmap1", "max1")
]
str2statenames["momentum8bit_blockwise"] = [
    ("momentum_buffer", "state1", "qmap1", "absmax1")
]
str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [
    ("square_avg", "state1", "qmap1", "absmax1")
]
|
2021-10-06 02:16:20 +00:00
|
|
|
|
|
|
|
# Parametrization for test_optimizer32bit: every combination of tensor shape
# (dim1 x dim2), parameter/gradient dtype and 32-bit optimizer name.
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
values = list(product(dim1, dim2, gtype, optimizer_names))
# Human-readable pytest ids, one per entry of ``values``.
names = [
    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
    """Check a 32-bit bitsandbytes optimizer against its reference optimizer.

    Two copies of the same random CUDA parameter are updated for ``k`` steps
    with identical gradients: ``p1`` (always float32) by the reference
    optimizer and ``p2`` (dtype ``gtype``) by the bitsandbytes optimizer.
    After every step the optimizer states listed in ``str2statenames`` and
    the parameters themselves must agree within dtype-dependent tolerances.
    Periodically the bitsandbytes optimizer is serialized with ``torch.save``,
    rebuilt and restored via ``load_state_dict`` and the same checks are
    repeated on the restored optimizer.
    """
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    p1 = p1.float()

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 1e-6, 1e-5
    else:
        atol, rtol = 1e-4, 1e-3

    for i in range(k):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

        for name1, name2 in str2statenames[optim_name]:
            torch.testing.assert_allclose(
                torch_optimizer.state[p1][name1],
                bnb_optimizer.state[p2][name2],
                atol=atol,
                rtol=rtol,
            )

        torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)

        # Every k // 5 steps (except step 0): round-trip the optimizer state
        # through a file and make sure nothing changed.
        if i % (k // 5) == 0 and i > 0:
            path = get_temp_dir()
            try:
                torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
                del bnb_optimizer
                bnb_optimizer = None
                bnb_optimizer = str2optimizers[optim_name][1]([p2])
                bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
            finally:
                # Remove the scratch directory even if save/load fails so a
                # failing run does not leave files behind.
                rm_path(path)
            torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
            for name1, name2 in str2statenames[optim_name]:
                torch.testing.assert_allclose(
                    torch_optimizer.state[p1][name1],
                    bnb_optimizer.state[p2][name2],
                    atol=atol,
                    rtol=rtol,
                )

        if gtype == torch.float16:
            # the adam buffers should also be close because they are 32-bit
            # but the parameters can diverge because they are 16-bit
            # the difference grows larger and larger with each update
            # --> copy the state to keep weights close
            p1.data = p1.data.half().float()
            p2.copy_(p1.data)
            torch.testing.assert_allclose(p1.half(), p2)
        if optim_name in ["lars", "lamb"]:
            assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
|
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
|
|
|
|
# Parametrization for test_global_config: every combination of tensor shape
# (dim1 x dim2) and parameter/gradient dtype.
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
values = list(product(dim1, dim2, gtype))
# Human-readable pytest ids, one per entry of ``values``.
names = ["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values]
|
|
|
|
|
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
def test_global_config(dim1, dim2, gtype):
    """Check that a per-parameter ``optim_bits`` override is honoured.

    Three random parameters are created on the CPU.  ``p3`` is given an
    ``optim_bits = 8`` override through ``GlobalOptimManager`` and all three
    are registered with the manager before being moved to CUDA.  A single
    ``bnb.optim.Adam`` then updates them for 50 steps; after every step both
    optimizer states of ``p3`` must be stored as ``torch.uint8``.
    """
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8

    # The override is set and the parameters are registered while the
    # tensors are still on the CPU, i.e. before the .cuda() calls below.
    bnb.optim.GlobalOptimManager.get_instance().initialize()
    bnb.optim.GlobalOptimManager.get_instance().override_config(
        p3, "optim_bits", 8
    )

    bnb.optim.GlobalOptimManager.get_instance().register_parameters(
        [p1, p2, p3]
    )
    p1 = p1.cuda()
    p2 = p2.cuda()
    p3 = p3.cuda()

    adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps)

    for i in range(50):
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
        p1.grad = g1
        p2.grad = g2
        p3.grad = g3

        adam2.step()

        assert adam2.state[p3]["state1"].dtype == torch.uint8
        assert adam2.state[p3]["state2"].dtype == torch.uint8
|
2021-10-06 02:16:20 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Parametrization for test_optimizer8bit: every combination of tensor shape
# (dim1 x dim2), parameter/gradient dtype and 8-bit optimizer name.
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
optimizer_names = [
    "adam8bit",
    "momentum8bit",
    "rmsprop8bit",
    "adam8bit_blockwise",
    "lars8bit",
    "momentum8bit_blockwise",
    "rmsprop8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
# Human-readable pytest ids, one per entry of ``values``.
names = [
    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
def _dequantize_optimizer_state(
    optimizer, param, optim_name, name2, qmap, max_val, blocksize
):
    """Dequantize one 8-bit state tensor of ``optimizer`` for ``param``.

    Uses ``F.dequantize_blockwise`` when ``optim_name`` contains "blockwise"
    and ``F.dequantize`` otherwise; ``qmap`` and ``max_val`` are the state
    keys passed as ``code`` and ``absmax``.
    """
    state = optimizer.state[param]
    if "blockwise" in optim_name:
        return F.dequantize_blockwise(
            code=state[qmap],
            absmax=state[max_val],
            A=state[name2],
            blocksize=blocksize,
        )
    return F.dequantize(
        code=state[qmap],
        absmax=state[max_val],
        A=state[name2],
    )


@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
    """Check an 8-bit bitsandbytes optimizer against its reference optimizer.

    Two copies of the same random CUDA parameter are updated for 50 steps
    with identical gradients: ``p1`` (always float32) by the reference
    optimizer and ``p2`` (dtype ``gtype``) by the 8-bit optimizer.  After each
    step the parameters must agree, the dequantized 8-bit states may differ
    from the reference states in fewer than 20 elements, and the mean
    absolute / relative parameter errors must stay below fixed bounds.
    Periodically the 8-bit optimizer is saved, rebuilt and reloaded, and the
    raw state, quantization map and dequantized state must survive unchanged.
    At the end of every step the reference parameter and states are
    overwritten with the 8-bit optimizer's values so errors do not
    accumulate across steps.
    """
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
    p2 = p1.clone()
    p1 = p1.float()
    blocksize = 2048

    torch_optimizer = str2optimizers[optim_name][0]([p1])
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    # State tolerances (atol/rtol) and parameter tolerances (patol/prtol).
    # The original float32 and float16 branches used identical values, so a
    # single assignment covers both dtypes.
    atol, rtol = 3e-3, 1e-3
    patol, prtol = 1e-5, 1e-3

    # Per-step mean errors; only consumed by the debug prints at the bottom.
    errors = []
    relerrors = []

    for i in range(50):
        g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()

        bnb_optimizer.step()
        torch_optimizer.step()

        torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)

        dequant_states = []
        for name1, name2, qmap, max_val in str2statenames[optim_name]:
            # print(bnb_optimizer.state[p2][max_val], name1)
            s1 = _dequantize_optimizer_state(
                bnb_optimizer, p2, optim_name, name2, qmap, max_val, blocksize
            )
            num_not_close = (
                torch.isclose(
                    torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
                )
                == 0
            )
            assert num_not_close.sum().item() < 20
            dequant_states.append(s1.clone())

        err = torch.abs(p1 - p2)
        relerr = err / torch.abs(p1)
        assert err.mean() < 0.0001
        assert relerr.mean() < 0.001

        errors.append(err.mean().item())
        relerrors.append(relerr.mean().item())

        # Every 10 steps (except step 0): round-trip the optimizer through a
        # file, once per tracked state, and verify the state is unchanged.
        if i % 10 == 0 and i > 0:
            for (name1, name2, qmap, max_val), s in zip(
                str2statenames[optim_name], dequant_states
            ):
                s1cpy = s.clone()
                raws1cpy = bnb_optimizer.state[p2][name2].clone()
                qmap1 = bnb_optimizer.state[p2][qmap].clone()

                path = get_temp_dir()
                try:
                    torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
                    del bnb_optimizer
                    bnb_optimizer = None
                    bnb_optimizer = str2optimizers[optim_name][1]([p2])
                    bnb_optimizer.load_state_dict(
                        torch.load(join(path, "opt.pt"))
                    )
                finally:
                    # Remove the scratch directory even if save/load fails.
                    rm_path(path)
                torch.testing.assert_allclose(
                    raws1cpy, bnb_optimizer.state[p2][name2]
                )
                torch.testing.assert_allclose(
                    qmap1, bnb_optimizer.state[p2][qmap]
                )

                s1 = _dequantize_optimizer_state(
                    bnb_optimizer,
                    p2,
                    optim_name,
                    name2,
                    qmap,
                    max_val,
                    blocksize,
                )
                torch.testing.assert_allclose(s1cpy, s1)

                num_not_close = (
                    torch.isclose(
                        torch_optimizer.state[p1][name1],
                        s1,
                        atol=atol,
                        rtol=rtol,
                    )
                    == 0
                )
                assert num_not_close.sum().item() < 20
            torch.testing.assert_allclose(
                p1, p2.float(), atol=patol, rtol=prtol
            )

        # the parameters diverge quickly. Here we keep them close
        # together so we can test against the Adam error
        p1.data = p1.data.to(gtype).float()
        p2.copy_(p1.data)
        torch.testing.assert_allclose(p1.to(gtype), p2)
        for (name1, name2, qmap, max_val), s in zip(
            str2statenames[optim_name], dequant_states
        ):
            torch_optimizer.state[p1][name1].copy_(s.data)

    # print(sum(errors)/len(errors))
    # print(sum(relerrors)/len(relerrors))
|
2021-10-06 02:16:20 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Parametrization for test_adam_percentile_clipping: every combination of
# tensor shape (dim1 x dim2), dtype and optimizer state bit width.
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits))
# Human-readable pytest ids, one per entry of ``values``.
names = [
    "dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals)
    for vals in values
]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
    """Check Adam's built-in percentile clipping against manual clipping.

    ``adam1`` has no clipping configured and is fed gradients that were
    scaled beforehand with the ``gnorm_scale`` returned by
    ``F.percentile_clipping(g1, gnorm_vec, step, 5)``.  ``adam2`` is built
    with ``percentile_clipping=5`` and receives the unscaled gradient.  After
    each of 50 steps the parameters and both optimizer states must agree
    within tolerances that depend on ``optim_bits``.  Periodically ``adam2``
    is saved, rebuilt and restored from its state dict.
    """
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
    beta1 = 0.9
    beta2 = 0.999
    lr = 0.001
    eps = 1e-8
    p1 = p1.cuda()
    p2 = p1.clone()
    adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
    adam2 = bnb.optim.Adam(
        [p2],
        lr,
        (beta1, beta2),
        eps,
        optim_bits=optim_bits,
        percentile_clipping=5,
    )

    # Buffer handed to F.percentile_clipping on every step of the manual path.
    gnorm_vec = torch.zeros(100).cuda()
    step = 0

    for i in range(50):
        step += 1
        # The gradient offset grows with the step index.
        g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (
            0.01 * i
        )
        g2 = g1.clone()
        p2.grad = g2

        current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(
            g1, gnorm_vec, step, 5
        )
        g1 = (g1.float() * gnorm_scale).to(gtype)
        p1.grad = g1

        adam1.step()
        adam2.step()

        # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
        if optim_bits == 32:
            torch.testing.assert_allclose(p1, p2)
            torch.testing.assert_allclose(
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=5e-5,
                rtol=1e-4,
            )
            torch.testing.assert_allclose(
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=5e-5,
                rtol=1e-4,
            )
        elif optim_bits == 8:
            torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
            # NOTE(review): atol=2 presumably allows the raw quantized state
            # codes to differ by up to two levels -- confirm.
            torch.testing.assert_allclose(
                adam1.state[p1]["state1"],
                adam2.state[p2]["state1"],
                atol=2,
                rtol=1e-3,
            )
            torch.testing.assert_allclose(
                adam1.state[p1]["state2"],
                adam2.state[p2]["state2"],
                atol=2,
                rtol=1e-3,
            )
            # Re-synchronize adam1's states with adam2's after comparing.
            adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
            adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
        # Every 10 steps (except step 0): round-trip adam2 through a file.
        if i % 10 == 0 and i > 0:
            path = get_temp_dir()
            try:
                torch.save(adam2.state_dict(), join(path, "opt.pt"))
                del adam2
                adam2 = None
                adam2 = bnb.optim.Adam(
                    [p2],
                    lr,
                    (beta1, beta2),
                    eps,
                    optim_bits=optim_bits,
                    percentile_clipping=5,
                )
                adam2.load_state_dict(torch.load(join(path, "opt.pt")))
            finally:
                # The scratch directory was previously never deleted here;
                # remove it like the other tests in this file do.
                rm_path(path)
|
2021-10-06 02:16:20 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Parametrization for test_benchmark_blockwise: one 4096 x 4096 tensor per
# dtype and benchmarked optimizer name.
dim1 = [4096]
dim2 = [4096]
gtype = [torch.float32, torch.float16]
# Alternative optimizer sets kept for manual benchmarking runs:
# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
# optimizer_names = ['lamb_apex', 'lamb8bit']
# optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise"]
values = list(product(dim1, dim2, gtype, optimizer_names))
# Human-readable pytest ids, one per entry of ``values``.
names = [
    "dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values
]
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
    """Time ``step()`` of the optimizer at ``str2optimizers[optim_name][1]``.

    Runs ``k`` optimizer steps on one random CUDA parameter with a fixed
    gradient.  The first ``k // 5`` steps are burn-in and are not timed.  The
    elapsed time of the remaining steps, divided by
    ``(k - k // 5) * dim1 * dim2``, is printed; nothing is asserted.
    """
    if dim1 == 1 and dim2 == 1:
        return
    p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1

    bnb_optimizer = str2optimizers[optim_name][1]([p1])

    # The same gradient tensor is reused for every step.
    g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
    p1.grad = g
    for i in range(k):
        if i == k // 5:
            # The first k // 5 iterations are burn-in; start the clock once
            # all previously queued CUDA work has finished.
            torch.cuda.synchronize()
            # perf_counter is the monotonic, high-resolution clock intended
            # for measuring short intervals.
            t0 = time.perf_counter()

        bnb_optimizer.step()

    # Wait for the queued steps to finish before stopping the clock.
    torch.cuda.synchronize()
    s = time.perf_counter() - t0
    print("")
    params = (k - k // 5) * dim1 * dim2
    print(optim_name, gtype, s / params)
    # assert s < 3.9
|