From 2f8083bd8b084290f888fe59b329d98ebd6dd468 Mon Sep 17 00:00:00 2001
From: Tim Dettmers
Date: Sun, 28 Nov 2021 21:18:11 -0800
Subject: [PATCH] Added AdamW. #10 #13

---
 CHANGELOG.md                   |  4 ++++
 Makefile                       | 19 ++++++++++---------
 bitsandbytes/optim/__init__.py |  1 +
 bitsandbytes/optim/adam.py     |  1 -
 bitsandbytes/optim/adamw.py    | 29 +++++++++++++++++++++++++++++
 csrc/kernels.cu                |  3 +++
 tests/test_optim.py            | 10 +++++++---
 7 files changed, 54 insertions(+), 13 deletions(-)
 create mode 100644 bitsandbytes/optim/adamw.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index beaa256..e943fa2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -42,3 +42,7 @@ Docs:
 
 Features:
     - Added Adagrad (without grad clipping) as 32-bit and 8-bit block-wise optimizer
+    - Added AdamW (copy of Adam with weight decay init 1e-2)
+
+Bug fixes:
+    - Fixed a bug where weight decay was incorrectly applied to 32-bit Adam
diff --git a/Makefile b/Makefile
index 6055541..5093410 100644
--- a/Makefile
+++ b/Makefile
@@ -19,15 +19,16 @@ INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/inclu
 LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
 
 # NVIDIA NVCC compilation flags
-COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
-COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
-COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
-COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
-COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
-COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
-COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
-COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
-COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
+#COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
+#COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
+#COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
+#COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
+#COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
+#COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
+#COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
+#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
+#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
+COMPUTE_CAPABILITY := -gencode arch=compute_75,code=sm_75 # Volta
 
 # CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
 CC_CUDA92 := -gencode arch=compute_30,code=sm_30
diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py
index af8a488..5e73414 100644
--- a/bitsandbytes/optim/__init__.py
+++ b/bitsandbytes/optim/__init__.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 from .adam import Adam, Adam8bit, Adam32bit
+from .adamw import AdamW, AdamW8bit, AdamW32bit
 from .sgd import SGD, SGD8bit, SGD32bit
 from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
 from .lamb import LAMB, LAMB8bit, LAMB32bit
diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index eb951ee..1e93a60 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -28,7 +28,6 @@ class Adam32bit(Optimizer2State):
                 weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
 
 
-
 class AnalysisAdam(torch.optim.Optimizer):
     """Adam that performs 8-bit vs 32-bit error analysis.
 
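Note on the Makefile hunk above: the multi-architecture -gencode list is commented out and a single arch=compute_75,code=sm_75 target is hard-coded (compute capability 7.5 corresponds to Turing GPUs such as the T4 and RTX 20xx series), so the resulting library only runs on matching hardware. A minimal sketch, assuming PyTorch is installed, for checking which compute capability the local GPU reports before adjusting COMPUTE_CAPABILITY; the printed flag format is the standard nvcc -gencode syntax, not something this patch generates:

    # Query the local GPU's compute capability with PyTorch so the Makefile's
    # COMPUTE_CAPABILITY flag can be set to match it (illustrative helper only).
    import torch

    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability(0)
        print(f"Detected compute capability sm_{major}{minor}")
        print(f"Suggested flag: -gencode arch=compute_{major}{minor},code=sm_{major}{minor}")
    else:
        print("No CUDA device visible; keep the multi-arch -gencode list instead.")
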
diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
new file mode 100644
index 0000000..7761f3b
--- /dev/null
+++ b/bitsandbytes/optim/adamw.py
@@ -0,0 +1,29 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+from bitsandbytes.optim.optimizer import Optimizer2State
+import bitsandbytes.functional as F
+
+class AdamW(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+            weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None,
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super(AdamW, self).__init__('adam', params, lr, betas, eps,
+                weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
+
+class AdamW8bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+            weight_decay=1e-2, amsgrad=False, args=None,
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super(AdamW8bit, self).__init__('adam', params, lr, betas, eps,
+                weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
+
+class AdamW32bit(Optimizer2State):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+            weight_decay=1e-2, amsgrad=False, args=None,
+            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
+        super(AdamW32bit, self).__init__('adam', params, lr, betas, eps,
+                weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
+
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 56f6a76..d0aabff 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -720,6 +720,9 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
                   s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
                   s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
                   p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
+
+                  if(weight_decay > 0.0f)
+                      p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
               }
             break;
           }
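The kOptimizer32bit2State hunk above applies decoupled (AdamW-style) weight decay after the Adam update: the parameter is shrunk by a factor of (1 - lr*weight_decay) instead of weight decay being folded into the gradient. Below is a minimal pure-Python sketch of that update order for a single scalar parameter. The bias-correction terms (correction1, correction2, step_size) follow the usual Adam formulation and are assumed here for completeness; only the final two lines mirror the kernel change itself:

    # Sketch of one AdamW step in the order the kernel performs it (illustrative only).
    import math

    def adamw_step(p, g, s1, s2, lr, beta1, beta2, eps, weight_decay, step):
        s1 = s1 * beta1 + (1.0 - beta1) * g                  # exp. moving avg of gradients
        s2 = s2 * beta2 + (1.0 - beta2) * g * g              # exp. moving avg of squared gradients
        correction1 = 1.0 - beta1 ** step                    # assumed bias-correction terms
        correction2 = math.sqrt(1.0 - beta2 ** step)
        step_size = lr * correction2 / correction1
        p = p - step_size * (s1 / (math.sqrt(s2) + eps * correction2))  # Adam update
        if weight_decay > 0.0:
            p = p * (1.0 - lr * weight_decay)                # decoupled decay, as in the kernel
        return p, s1, s2

    # Example: one step on a scalar parameter.
    p, s1, s2 = adamw_step(p=1.0, g=0.1, s1=0.0, s2=0.0, lr=1e-3,
                           beta1=0.9, beta2=0.999, eps=1e-8,
                           weight_decay=1e-2, step=1)

Because the decay multiplies the parameter directly, it matches the semantics of torch.optim.AdamW rather than L2 regularization added to the gradient.
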
diff --git a/tests/test_optim.py b/tests/test_optim.py
index ff0734b..d306511 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -34,6 +34,7 @@ str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx,
 str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
 
 str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
+str2optimizers['adamw'] = (torch.optim.AdamW, bnb.optim.AdamW)
 str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
 str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False))
 str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
@@ -47,12 +48,14 @@ str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_
 str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9))
 
 str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
+str2optimizers['adamw8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.AdamW8bit(pxx, block_wise=True))
 str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
 str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
 str2optimizers['adagrad8bit_blockwise'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad8bit(pxx, 0.01, block_wise=True))
 
 str2statenames = {}
 str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
+str2statenames['adamw'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
 str2statenames['momentum'] = [('momentum_buffer', 'state1')]
 str2statenames['lars'] = [('momentum_buffer', 'state1')]
 str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
@@ -61,6 +64,7 @@ str2statenames['adagrad'] = [('sum', 'state1')]
 str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
 str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
 str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
+str2statenames['adamw8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
 str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
 str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')]
 str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
@@ -71,7 +75,7 @@ str2statenames['adagrad8bit_blockwise'] = [('sum', 'state1', 'qmap1', 'absmax1')
 dim1 = [1024]
 dim2 = [32, 1024, 4097, 1]
 gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad']
+optimizer_names = ['adam', 'adamw', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad']
 values = list(product(dim1,dim2, gtype, optimizer_names))
 names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -86,7 +90,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
     bnb_optimizer = str2optimizers[optim_name][1]([p2])
 
     if gtype == torch.float32:
-        atol, rtol = 1e-6, 1e-5
+        atol, rtol = 2e-6, 1e-5
     else:
         atol, rtol = 1e-4, 1e-3
 
@@ -201,7 +205,7 @@ def test_global_config(dim1, dim2, gtype):
 dim1 = [1024]
 dim2 = [32, 1024, 4097]
 gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise']
+optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'adamw8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise']
 values = list(product(dim1,dim2, gtype, optimizer_names))
 names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
 @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
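The new 'adamw' and 'adamw8bit_blockwise' entries are picked up by the existing parametrized tests, so they can be run in isolation with something like pytest tests/test_optim.py -k adamw. A minimal standalone sketch of the same comparison the 32-bit test performs, assuming a CUDA-capable GPU (bitsandbytes optimizers run on the GPU); the tensor shape and tolerances mirror the float32 case above, and the variable names are illustrative:

    # Quick parity check: bnb.optim.AdamW32bit vs. torch.optim.AdamW on one tensor.
    import torch
    import bitsandbytes as bnb

    p1 = torch.randn(1024, 1024, device='cuda', requires_grad=True)
    p2 = p1.clone().detach().requires_grad_(True)

    torch_opt = torch.optim.AdamW([p1], lr=1e-3, weight_decay=1e-2)
    bnb_opt = bnb.optim.AdamW32bit([p2], lr=1e-3, weight_decay=1e-2)

    for _ in range(10):
        g = torch.randn_like(p1) * 0.01   # shared synthetic gradient
        p1.grad = g.clone()
        p2.grad = g.clone()
        torch_opt.step()
        bnb_opt.step()

    # Tolerances taken from test_optimizer32bit for float32 parameters.
    assert torch.allclose(p1, p2, atol=2e-6, rtol=1e-5)
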