Add Lion optimizer test cases to test_optim.py, although the new tests seem to be failing
This commit is contained in:
parent
9b656f461a
commit
a43cd2008d
|
@ -1 +1,2 @@
|
||||||
|
lion-pytorch
|
||||||
pytest
|
pytest
|
||||||
|
|
|
@ -7,6 +7,8 @@ from itertools import product
|
||||||
from os.path import join
|
from os.path import join
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from lion_pytorch import Lion
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
@ -31,6 +33,7 @@ str2optimizers = {}
|
||||||
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
|
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
|
||||||
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
|
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
|
||||||
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
|
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
|
||||||
|
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
|
||||||
str2optimizers["momentum_pytorch"] = (
|
str2optimizers["momentum_pytorch"] = (
|
||||||
None,
|
None,
|
||||||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||||
|
@ -38,6 +41,7 @@ str2optimizers["momentum_pytorch"] = (
|
||||||
)
|
)
|
||||||
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
|
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
|
||||||
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
|
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
|
||||||
|
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
|
||||||
str2optimizers["momentum"] = (
|
str2optimizers["momentum"] = (
|
||||||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||||
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
|
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
|
||||||
|
@ -54,6 +58,10 @@ str2optimizers["adam8bit"] = (
|
||||||
torch.optim.Adam,
|
torch.optim.Adam,
|
||||||
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
|
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
|
||||||
)
|
)
|
||||||
|
str2optimizers["lion8bit"] = (
|
||||||
|
Lion,
|
||||||
|
lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False),
|
||||||
|
)
|
||||||
str2optimizers["momentum8bit"] = (
|
str2optimizers["momentum8bit"] = (
|
||||||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||||
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
|
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
|
||||||
|
@ -71,6 +79,10 @@ str2optimizers["adam8bit_blockwise"] = (
|
||||||
torch.optim.Adam,
|
torch.optim.Adam,
|
||||||
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
|
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
|
||||||
)
|
)
|
||||||
|
str2optimizers["lion8bit_blockwise"] = (
|
||||||
|
Lion,
|
||||||
|
lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True),
|
||||||
|
)
|
||||||
str2optimizers["momentum8bit_blockwise"] = (
|
str2optimizers["momentum8bit_blockwise"] = (
|
||||||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||||
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
|
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
|
||||||
|
@ -82,6 +94,7 @@ str2optimizers["rmsprop8bit_blockwise"] = (
|
||||||
|
|
||||||
str2statenames = {}
|
str2statenames = {}
|
||||||
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
||||||
|
str2statenames["lion"] = [("exp_avg", "state1")]
|
||||||
str2statenames["momentum"] = [("momentum_buffer", "state1")]
|
str2statenames["momentum"] = [("momentum_buffer", "state1")]
|
||||||
str2statenames["lars"] = [("momentum_buffer", "state1")]
|
str2statenames["lars"] = [("momentum_buffer", "state1")]
|
||||||
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
||||||
|
@ -90,6 +103,9 @@ str2statenames["adam8bit"] = [
|
||||||
("exp_avg", "state1", "qmap1", "max1"),
|
("exp_avg", "state1", "qmap1", "max1"),
|
||||||
("exp_avg_sq", "state2", "qmap2", "max2"),
|
("exp_avg_sq", "state2", "qmap2", "max2"),
|
||||||
]
|
]
|
||||||
|
str2statenames["lion8bit"] = [
|
||||||
|
("exp_avg", "state1", "qmap1", "max1")
|
||||||
|
]
|
||||||
str2statenames["lamb8bit"] = [
|
str2statenames["lamb8bit"] = [
|
||||||
("exp_avg", "state1", "qmap1", "max1"),
|
("exp_avg", "state1", "qmap1", "max1"),
|
||||||
("exp_avg_sq", "state2", "qmap2", "max2"),
|
("exp_avg_sq", "state2", "qmap2", "max2"),
|
||||||
|
@ -98,6 +114,9 @@ str2statenames["adam8bit_blockwise"] = [
|
||||||
("exp_avg", "state1", "qmap1", "absmax1"),
|
("exp_avg", "state1", "qmap1", "absmax1"),
|
||||||
("exp_avg_sq", "state2", "qmap2", "absmax2"),
|
("exp_avg_sq", "state2", "qmap2", "absmax2"),
|
||||||
]
|
]
|
||||||
|
str2statenames["lion8bit_blockwise"] = [
|
||||||
|
("exp_avg", "state1", "qmap1", "absmax1")
|
||||||
|
]
|
||||||
str2statenames["momentum8bit"] = [
|
str2statenames["momentum8bit"] = [
|
||||||
("momentum_buffer", "state1", "qmap1", "max1")
|
("momentum_buffer", "state1", "qmap1", "max1")
|
||||||
]
|
]
|
||||||
|
@ -113,7 +132,7 @@ str2statenames["rmsprop8bit_blockwise"] = [
|
||||||
dim1 = [1024]
|
dim1 = [1024]
|
||||||
dim2 = [32, 1024, 4097, 1]
|
dim2 = [32, 1024, 4097, 1]
|
||||||
gtype = [torch.float32, torch.float16]
|
gtype = [torch.float32, torch.float16]
|
||||||
optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
|
optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lion"]
|
||||||
values = list(product(dim1, dim2, gtype, optimizer_names))
|
values = list(product(dim1, dim2, gtype, optimizer_names))
|
||||||
names = [
|
names = [
|
||||||
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
|
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
|
||||||
|
@ -241,9 +260,11 @@ dim2 = [32, 1024, 4097]
|
||||||
gtype = [torch.float32, torch.float16]
|
gtype = [torch.float32, torch.float16]
|
||||||
optimizer_names = [
|
optimizer_names = [
|
||||||
"adam8bit",
|
"adam8bit",
|
||||||
|
"lion8bit",
|
||||||
"momentum8bit",
|
"momentum8bit",
|
||||||
"rmsprop8bit",
|
"rmsprop8bit",
|
||||||
"adam8bit_blockwise",
|
"adam8bit_blockwise",
|
||||||
|
"lion8bit_blockwise",
|
||||||
"lars8bit",
|
"lars8bit",
|
||||||
"momentum8bit_blockwise",
|
"momentum8bit_blockwise",
|
||||||
"rmsprop8bit_blockwise",
|
"rmsprop8bit_blockwise",
|
||||||
|
|
Loading…
Reference in New Issue
Block a user