add some code in test_optim.py, although it seems to be failing

Phil Wang 2023-03-22 09:14:05 -07:00
parent 9b656f461a
commit a43cd2008d
2 changed files with 23 additions and 1 deletion

View File

@@ -1 +1,2 @@
+lion-pytorch
 pytest
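The first diff appears to be a test requirements file: pytest was already listed, and lion-pytorch is added alongside it. With both installed, the new cases in test_optim.py below should be runnable with a plain pytest test_optim.py invocation; per the commit message, they currently fail.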

View File: test_optim.py

@@ -7,6 +7,8 @@ from itertools import product
 from os.path import join
 
 import pytest
+from lion_pytorch import Lion
+
 import torch
 
 import bitsandbytes as bnb
@@ -31,6 +33,7 @@ str2optimizers = {}
 str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
 # str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
 # str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
+str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
 str2optimizers["momentum_pytorch"] = (
     None,
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
@@ -38,6 +41,7 @@ str2optimizers["momentum_pytorch"] = (
 )
 str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
 # str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
+str2optimizers["lion"] = (Lion, bnb.optim.Lion)
 str2optimizers["momentum"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
@@ -54,6 +58,10 @@ str2optimizers["adam8bit"] = (
     torch.optim.Adam,
     lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
 )
+str2optimizers["lion8bit"] = (
+    Lion,
+    lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False),
+)
 str2optimizers["momentum8bit"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
@@ -71,6 +79,10 @@ str2optimizers["adam8bit_blockwise"] = (
     torch.optim.Adam,
     lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
 )
+str2optimizers["lion8bit_blockwise"] = (
+    Lion,
+    lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True),
+)
 str2optimizers["momentum8bit_blockwise"] = (
     lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
     lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
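The two new lion8bit entries differ only in the block_wise flag. With block_wise=False the quantized optimizer state carries a single dynamic-range scale for the whole tensor (the max1 buffer referenced in the state-name tables below); with block_wise=True it carries one scale per block (absmax1), which bounds quantization error locally. A minimal sketch of instantiating both modes, with a toy parameter size assumed (only one of the two would actually be stepped):

import torch
import bitsandbytes as bnb

# small tensors fall back to 32-bit state (min_8bit_size defaults to 4096),
# so use a parameter large enough to get quantized state
p = torch.nn.Parameter(torch.randn(8192, device="cuda"))

# one global scale for the whole quantized momentum tensor
opt_global = bnb.optim.Lion8bit([p], lr=1e-4, block_wise=False)

# one scale per block of the quantized momentum tensor
opt_block = bnb.optim.Lion8bit([p], lr=1e-4, block_wise=True)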
@@ -82,6 +94,7 @@ str2optimizers["rmsprop8bit_blockwise"] = (
 
 str2statenames = {}
 str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
+str2statenames["lion"] = [("exp_avg", "state1")]
 str2statenames["momentum"] = [("momentum_buffer", "state1")]
 str2statenames["lars"] = [("momentum_buffer", "state1")]
 str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
@@ -90,6 +103,9 @@ str2statenames["adam8bit"] = [
     ("exp_avg", "state1", "qmap1", "max1"),
     ("exp_avg_sq", "state2", "qmap2", "max2"),
 ]
+str2statenames["lion8bit"] = [
+    ("exp_avg", "state1", "qmap1", "max1")
+]
 str2statenames["lamb8bit"] = [
     ("exp_avg", "state1", "qmap1", "max1"),
     ("exp_avg_sq", "state2", "qmap2", "max2"),
@@ -98,6 +114,9 @@ str2statenames["adam8bit_blockwise"] = [
     ("exp_avg", "state1", "qmap1", "absmax1"),
     ("exp_avg_sq", "state2", "qmap2", "absmax2"),
 ]
+str2statenames["lion8bit_blockwise"] = [
+    ("exp_avg", "state1", "qmap1", "absmax1")
+]
 str2statenames["momentum8bit"] = [
     ("momentum_buffer", "state1", "qmap1", "max1")
 ]
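Each tuple above pairs a PyTorch optimizer state key with the bitsandbytes buffers that back it: the quantized state (state1), its quantization codebook (qmap1), and its scale (max1 for global scaling, absmax1 for blockwise). Lion gets single-entry lists because it tracks only the exp_avg momentum and, unlike Adam, no second moment. A small sketch, with a toy parameter assumed, of inspecting those buffers after one step:

import torch
import bitsandbytes as bnb

p = torch.nn.Parameter(torch.randn(8192, device="cuda"))
opt = bnb.optim.Lion8bit([p], lr=1e-4, block_wise=True)

p.grad = torch.randn_like(p)
opt.step()

# Lion carries one quantized buffer, so only state1/qmap1/absmax1 appear;
# an Adam-style optimizer would also have state2/qmap2/absmax2
print(sorted(opt.state[p].keys()))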
@@ -113,7 +132,7 @@ str2statenames["rmsprop8bit_blockwise"] = [
 dim1 = [1024]
 dim2 = [32, 1024, 4097, 1]
 gtype = [torch.float32, torch.float16]
-optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
+optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lion"]
 values = list(product(dim1, dim2, gtype, optimizer_names))
 names = [
     "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
@@ -241,9 +260,11 @@ dim2 = [32, 1024, 4097]
 gtype = [torch.float32, torch.float16]
 optimizer_names = [
     "adam8bit",
+    "lion8bit",
     "momentum8bit",
     "rmsprop8bit",
     "adam8bit_blockwise",
+    "lion8bit_blockwise",
     "lars8bit",
     "momentum8bit_blockwise",
     "rmsprop8bit_blockwise",