diff --git a/requirements.txt b/requirements.txt index e079f8a..883b2e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,2 @@ +lion-pytorch pytest diff --git a/tests/test_optim.py b/tests/test_optim.py index 3df2dad..9f815ab 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -7,6 +7,8 @@ from itertools import product from os.path import join import pytest +from lion_pytorch import Lion + import torch import bitsandbytes as bnb @@ -31,6 +33,7 @@ str2optimizers = {} str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) # str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) # str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) +str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion) str2optimizers["momentum_pytorch"] = ( None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), @@ -38,6 +41,7 @@ str2optimizers["momentum_pytorch"] = ( ) str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam) # str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) +str2optimizers["lion"] = (Lion, bnb.optim.Lion) str2optimizers["momentum"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False), @@ -54,6 +58,10 @@ str2optimizers["adam8bit"] = ( torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False), ) +str2optimizers["lion8bit"] = ( + Lion, + lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False), +) str2optimizers["momentum8bit"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False), @@ -71,6 +79,10 @@ str2optimizers["adam8bit_blockwise"] = ( torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True), ) +str2optimizers["lion8bit_blockwise"] = ( + Lion, + lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True), +) str2optimizers["momentum8bit_blockwise"] = ( lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True), @@ -82,6 +94,7 @@ str2optimizers["rmsprop8bit_blockwise"] = ( str2statenames = {} str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["lion"] = [("exp_avg", "state1")] str2statenames["momentum"] = [("momentum_buffer", "state1")] str2statenames["lars"] = [("momentum_buffer", "state1")] str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] @@ -90,6 +103,9 @@ str2statenames["adam8bit"] = [ ("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2"), ] +str2statenames["lion8bit"] = [ + ("exp_avg", "state1", "qmap1", "max1") +] str2statenames["lamb8bit"] = [ ("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2"), @@ -98,6 +114,9 @@ str2statenames["adam8bit_blockwise"] = [ ("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2"), ] +str2statenames["lion8bit_blockwise"] = [ + ("exp_avg", "state1", "qmap1", "absmax1") +] str2statenames["momentum8bit"] = [ ("momentum_buffer", "state1", "qmap1", "max1") ] @@ -113,7 +132,7 @@ str2statenames["rmsprop8bit_blockwise"] = [ dim1 = [1024] dim2 = [32, 1024, 4097, 1] gtype = [torch.float32, torch.float16] -optimizer_names = ["adam", "momentum", "rmsprop", "lars"] +optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lion"] values = list(product(dim1, dim2, gtype, optimizer_names)) names = [ "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values @@ -241,9 +260,11 @@ dim2 = [32, 1024, 4097] gtype = [torch.float32, torch.float16] optimizer_names = [ "adam8bit", + "lion8bit", "momentum8bit", "rmsprop8bit", "adam8bit_blockwise", + "lion8bit_blockwise", "lars8bit", "momentum8bit_blockwise", "rmsprop8bit_blockwise",