parent ca2078a697
commit 2f8083bd8b
@@ -42,3 +42,7 @@ Docs:

Features:
- Added Adagrad (without gradient clipping) as a 32-bit and 8-bit block-wise optimizer
- Added AdamW (a copy of Adam with weight decay initialized to 1e-2)

Bug fixes:
- Fixed a bug where weight decay was incorrectly applied to 32-bit Adam
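The Adagrad entry above surfaces in the test changes below as bnb.optim.Adagrad8bit(..., 0.01, block_wise=True). A minimal sketch of swapping it in for torch.optim.Adagrad, assuming a CUDA build of bitsandbytes (the model here is a placeholder, not part of the commit):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096).cuda()   # placeholder module

# Reference 32-bit PyTorch optimizer ...
# optimizer = torch.optim.Adagrad(model.parameters(), 0.01)

# ... swapped for the new 8-bit block-wise Adagrad (no gradient clipping,
# per the changelog); the arguments mirror the test file in this commit.
optimizer = bnb.optim.Adagrad8bit(model.parameters(), 0.01, block_wise=True)   # lr=0.01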
Makefile (19 changed lines)
@@ -19,15 +19,16 @@ INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/inclu
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib

# NVIDIA NVCC compilation flags
COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
#COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
#COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
#COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
#COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
#COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
#COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
#COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
#COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
COMPUTE_CAPABILITY := -gencode arch=compute_75,code=sm_75 # Volta

# CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
CC_CUDA92 := -gencode arch=compute_30,code=sm_30
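These -gencode flags pick which GPU architectures the CUDA kernels are compiled for; after this hunk the default COMPUTE_CAPABILITY is reduced to compute capability 7.5 (sm_75, i.e. Turing-class GPUs, despite the inherited "# Volta" comment), with the older architectures left commented out. An illustrative check, assuming PyTorch is installed, for which sm_XY flag matches your GPU (not part of the commit):

import torch

# Print the compute capability of each visible GPU so the matching
# -gencode arch=compute_XY,code=sm_XY line can be enabled in the Makefile.
for i in range(torch.cuda.device_count()):
    major, minor = torch.cuda.get_device_capability(i)
    name = torch.cuda.get_device_name(i)
    print(f"GPU {i}: {name} -> compute capability {major}.{minor} (sm_{major}{minor})")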
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .adam import Adam, Adam8bit, Adam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit
from .sgd import SGD, SGD8bit, SGD32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lamb import LAMB, LAMB8bit, LAMB32bit
@@ -28,7 +28,6 @@ class Adam32bit(Optimizer2State):
                weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)



class AnalysisAdam(torch.optim.Optimizer):
    """Adam that performs 8-bit vs 32-bit error analysis.
bitsandbytes/optim/adamw.py (new file, 29 lines)
@@ -0,0 +1,29 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from bitsandbytes.optim.optimizer import Optimizer2State
import bitsandbytes.functional as F

class AdamW(Optimizer2State):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
            weight_decay=1e-2, amsgrad=False, optim_bits=32, args=None,
            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
        super(AdamW, self).__init__('adam', params, lr, betas, eps,
                weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)

class AdamW8bit(Optimizer2State):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
            weight_decay=1e-2, amsgrad=False, args=None,
            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
        super(AdamW8bit, self).__init__('adam', params, lr, betas, eps,
                weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)

class AdamW32bit(Optimizer2State):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
            weight_decay=1e-2, amsgrad=False, args=None,
            min_8bit_size=4096, percentile_clipping=100, block_wise=True):
        super(AdamW32bit, self).__init__('adam', params, lr, betas, eps,
                weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
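The three classes above differ only in the bit width handed to Optimizer2State; they reuse the 'adam' update and leave weight decay to the kernel change shown in the next hunk. A minimal usage sketch, assuming a CUDA build of bitsandbytes (the model and dummy loss are placeholders):

import torch
import bitsandbytes as bnb

# Placeholder model; any nn.Module with CUDA parameters works the same way.
model = torch.nn.Linear(4096, 4096).cuda()

# Drop-in replacement for torch.optim.AdamW; weight_decay defaults to 1e-2
# as in the classes above, and block_wise=True (the default) selects 8-bit
# block-wise optimizer states.
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-3, weight_decay=1e-2)

x = torch.randn(16, 4096, device='cuda')
loss = model(x).pow(2).mean()   # dummy loss, only to produce gradients
loss.backward()
optimizer.step()                # 8-bit block-wise AdamW update
optimizer.zero_grad()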
@@ -720,6 +720,9 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
                  s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
                  s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
                  p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));

                  if(weight_decay > 0.0f)
                      p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
              }
              break;
          }
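This is the weight-decay fix from the changelog: in the 32-bit two-state (Adam) path, decay is now applied to the parameter after the Adam step, i.e. decoupled, AdamW-style decay scaled by the learning rate. A scalar Python sketch of the per-element update; correction1, correction2 and step_size are computed elsewhere in the kernel, so their formulas below are reconstructed and the whole function is illustrative only:

import math

def adam32bit_element_update(p, g, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999,
                             eps=1e-8, weight_decay=1e-2):
    # State updates (exp_avg and exp_avg_sq in PyTorch terms).
    m = m * beta1 + (1.0 - beta1) * g
    v = v * beta2 + (1.0 - beta2) * g * g

    # Bias corrections; update_scale from the kernel is treated as 1.0 here
    # (it only changes when percentile clipping kicks in).
    correction1 = 1.0 - beta1 ** step
    correction2 = math.sqrt(1.0 - beta2 ** step)
    step_size = -lr * correction2 / correction1

    # Adam parameter update, matching the p_vals[j] line above.
    p = p + step_size * (m / (math.sqrt(v) + eps * correction2))

    # The fix: decoupled weight decay applied after the Adam step.
    if weight_decay > 0.0:
        p = p * (1.0 - lr * weight_decay)
    return p, m, v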
@@ -34,6 +34,7 @@ str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx,
str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)

str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers['adamw'] = (torch.optim.AdamW, bnb.optim.AdamW)
str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
@@ -47,12 +48,14 @@ str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_
str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9))

str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers['adamw8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.AdamW8bit(pxx, block_wise=True))
str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
str2optimizers['adagrad8bit_blockwise'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad8bit(pxx, 0.01, block_wise=True))

str2statenames = {}
str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['adamw'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['momentum'] = [('momentum_buffer', 'state1')]
str2statenames['lars'] = [('momentum_buffer', 'state1')]
str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
@@ -61,6 +64,7 @@ str2statenames['adagrad'] = [('sum', 'state1')]
str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
str2statenames['adamw8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')]
str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
@@ -71,7 +75,7 @@ str2statenames['adagrad8bit_blockwise'] = [('sum', 'state1', 'qmap1', 'absmax1')
dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad']
optimizer_names = ['adam', 'adamw', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -86,7 +90,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 1e-6, 1e-5
        atol, rtol = 2e-6, 1e-5
    else:
        atol, rtol = 1e-4, 1e-3
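Together with the str2optimizers / str2statenames entries above, these tolerances drive a reference-vs-bitsandbytes comparison along the lines of the condensed sketch below (p1/p2, the single step, and assert_close are illustrative, not the actual test body):

import torch
import bitsandbytes as bnb

# Two identical parameter copies, one per optimizer.
p1 = torch.randn(1024, 4097, device='cuda', requires_grad=True)
p2 = p1.clone().detach().requires_grad_(True)

torch_optimizer = torch.optim.AdamW([p1], lr=1e-3)
bnb_optimizer = bnb.optim.AdamW([p2], lr=1e-3)   # 32-bit path for 'adamw'

g = torch.randn_like(p1)
p1.grad, p2.grad = g.clone(), g.clone()
torch_optimizer.step()
bnb_optimizer.step()

# Same float32 tolerances as above (after the relaxation to 2e-6).
torch.testing.assert_close(p1, p2, atol=2e-6, rtol=1e-5)
# Optimizer states are matched via pairs such as ('exp_avg', 'state1').
torch.testing.assert_close(torch_optimizer.state[p1]['exp_avg'],
                           bnb_optimizer.state[p2]['state1'],
                           atol=2e-6, rtol=1e-5)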
@@ -201,7 +205,7 @@ def test_global_config(dim1, dim2, gtype):
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise']
optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'adamw8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)