add some code in test_optim.py, although it seems to be failing
This commit is contained in:
parent
9b656f461a
commit
a43cd2008d
|
@ -1 +1,2 @@
|
|||
lion-pytorch
|
||||
pytest
|
||||
|
|
|
@ -7,6 +7,8 @@ from itertools import product
|
|||
from os.path import join
|
||||
|
||||
import pytest
|
||||
from lion_pytorch import Lion
|
||||
|
||||
import torch
|
||||
|
||||
import bitsandbytes as bnb
|
||||
|
@ -31,6 +33,7 @@ str2optimizers = {}
|
|||
str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
|
||||
# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
|
||||
# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
|
||||
str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion)
|
||||
str2optimizers["momentum_pytorch"] = (
|
||||
None,
|
||||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||
|
@ -38,6 +41,7 @@ str2optimizers["momentum_pytorch"] = (
|
|||
)
|
||||
str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
|
||||
# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
|
||||
str2optimizers["lion"] = (Lion, bnb.optim.Lion)
|
||||
str2optimizers["momentum"] = (
|
||||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
|
||||
|
@ -54,6 +58,10 @@ str2optimizers["adam8bit"] = (
|
|||
torch.optim.Adam,
|
||||
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
|
||||
)
|
||||
str2optimizers["lion8bit"] = (
|
||||
Lion,
|
||||
lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False),
|
||||
)
|
||||
str2optimizers["momentum8bit"] = (
|
||||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
|
||||
|
@ -71,6 +79,10 @@ str2optimizers["adam8bit_blockwise"] = (
|
|||
torch.optim.Adam,
|
||||
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
|
||||
)
|
||||
str2optimizers["lion8bit_blockwise"] = (
|
||||
Lion,
|
||||
lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True),
|
||||
)
|
||||
str2optimizers["momentum8bit_blockwise"] = (
|
||||
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
|
||||
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
|
||||
|
@ -82,6 +94,7 @@ str2optimizers["rmsprop8bit_blockwise"] = (
|
|||
|
||||
str2statenames = {}
|
||||
str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
||||
str2statenames["lion"] = [("exp_avg", "state1")]
|
||||
str2statenames["momentum"] = [("momentum_buffer", "state1")]
|
||||
str2statenames["lars"] = [("momentum_buffer", "state1")]
|
||||
str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
|
||||
|
@ -90,6 +103,9 @@ str2statenames["adam8bit"] = [
|
|||
("exp_avg", "state1", "qmap1", "max1"),
|
||||
("exp_avg_sq", "state2", "qmap2", "max2"),
|
||||
]
|
||||
str2statenames["lion8bit"] = [
|
||||
("exp_avg", "state1", "qmap1", "max1")
|
||||
]
|
||||
str2statenames["lamb8bit"] = [
|
||||
("exp_avg", "state1", "qmap1", "max1"),
|
||||
("exp_avg_sq", "state2", "qmap2", "max2"),
|
||||
|
@ -98,6 +114,9 @@ str2statenames["adam8bit_blockwise"] = [
|
|||
("exp_avg", "state1", "qmap1", "absmax1"),
|
||||
("exp_avg_sq", "state2", "qmap2", "absmax2"),
|
||||
]
|
||||
str2statenames["lion8bit_blockwise"] = [
|
||||
("exp_avg", "state1", "qmap1", "absmax1")
|
||||
]
|
||||
str2statenames["momentum8bit"] = [
|
||||
("momentum_buffer", "state1", "qmap1", "max1")
|
||||
]
|
||||
|
@ -113,7 +132,7 @@ str2statenames["rmsprop8bit_blockwise"] = [
|
|||
dim1 = [1024]
|
||||
dim2 = [32, 1024, 4097, 1]
|
||||
gtype = [torch.float32, torch.float16]
|
||||
optimizer_names = ["adam", "momentum", "rmsprop", "lars"]
|
||||
optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lion"]
|
||||
values = list(product(dim1, dim2, gtype, optimizer_names))
|
||||
names = [
|
||||
"dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values
|
||||
|
@ -241,9 +260,11 @@ dim2 = [32, 1024, 4097]
|
|||
gtype = [torch.float32, torch.float16]
|
||||
optimizer_names = [
|
||||
"adam8bit",
|
||||
"lion8bit",
|
||||
"momentum8bit",
|
||||
"rmsprop8bit",
|
||||
"adam8bit_blockwise",
|
||||
"lion8bit_blockwise",
|
||||
"lars8bit",
|
||||
"momentum8bit_blockwise",
|
||||
"rmsprop8bit_blockwise",
|
||||
|
|
Loading…
Reference in New Issue
Block a user