add some code in test_optim.py, although it seems to be failing

This commit is contained in:
Phil Wang 2023-03-22 09:14:05 -07:00
parent 9b656f461a
commit a43cd2008d
2 changed files with 23 additions and 1 deletions

View File

@ -1 +1,2 @@
lion-pytorch
pytest

View File

@ -7,6 +7,8 @@ from itertools import product
from os.path import join
import pytest
from lion_pytorch import Lion
import torch
import bitsandbytes as bnb
@ -31,6 +33,7 @@ str2optimizers = {}
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
str2optimizers["momentum_pytorch"] = (
None,
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
@ -38,6 +41,7 @@ str2optimizers["momentum_pytorch"] = (
)
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["momentum"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
@ -54,6 +58,10 @@ str2optimizers["adam8bit"] = (
torch.optim.Adam,
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
)
str2optimizers["lion8bit"] = (
Lion,
lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False),
)
str2optimizers["momentum8bit"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
@ -71,6 +79,10 @@ str2optimizers["adam8bit_blockwise"] = (
torch.optim.Adam,
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
)
str2optimizers["lion8bit_blockwise"] = (
Lion,
lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True),
)
str2optimizers["momentum8bit_blockwise"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
@ -82,6 +94,7 @@ str2optimizers["rmsprop8bit_blockwise"] = (
str2statenames = {}
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames["lion"] = [("exp_avg", "state1")]
str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames["lars"] = [("momentum_buffer", "state1")]
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
@ -90,6 +103,9 @@ str2statenames["adam8bit"] = [
("exp_avg", "state1", "qmap1", "max1"),
("exp_avg_sq", "state2", "qmap2", "max2"),
]
str2statenames["lion8bit"] = [
("exp_avg", "state1", "qmap1", "max1")
]
str2statenames["lamb8bit"] = [
("exp_avg", "state1", "qmap1", "max1"),
("exp_avg_sq", "state2", "qmap2", "max2"),
@ -98,6 +114,9 @@ str2statenames["adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["lion8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1")
]
str2statenames["momentum8bit"] = [
("momentum_buffer", "state1", "qmap1", "max1")
]
@ -113,7 +132,7 @@ str2statenames["rmsprop8bit_blockwise"] = [
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lion"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = [
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
@ -241,9 +260,11 @@ dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
optimizer_names = [
"adam8bit",
"lion8bit",
"momentum8bit",
"rmsprop8bit",
"adam8bit_blockwise",
"lion8bit_blockwise",
"lars8bit",
"momentum8bit_blockwise",
"rmsprop8bit_blockwise",