Added adagrad with tests (no clipping).
parent 22b2877c7f
commit 8b3c0f355c

(One file diff was suppressed because one or more lines are too long.)
@@ -7,4 +7,5 @@ from .sgd import SGD, SGD8bit, SGD32bit
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
from .lamb import LAMB, LAMB8bit, LAMB32bit
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
from .optimizer import GlobalOptimManager

bitsandbytes/optim/adagrad.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from bitsandbytes.optim.optimizer import Optimizer1State

torch.optim.Adagrad

class Adagrad(Optimizer1State):
    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
            optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if initial_accumulator_value != 0.0:
            raise ValueError('Initial accumulator value != 0.0 not supported!')
        if lr_decay != 0.0:
            raise ValueError('Lr Decay != 0.0 not supported!')
        super(Adagrad, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
                weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)

class Adagrad8bit(Optimizer1State):
    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
            optim_bits=8, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if initial_accumulator_value != 0.0:
            raise ValueError('Initial accumulator value != 0.0 not supported!')
        if lr_decay != 0.0:
            raise ValueError('Lr Decay != 0.0 not supported!')
        assert block_wise
        super(Adagrad8bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
                weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)

class Adagrad32bit(Optimizer1State):
    def __init__(self, params, lr=1e-2, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10,
            optim_bits=32, args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if initial_accumulator_value != 0.0:
            raise ValueError('Initial accumulator value != 0.0 not supported!')
        if lr_decay != 0.0:
            raise ValueError('Lr Decay != 0.0 not supported!')
        super(Adagrad32bit, self).__init__('adagrad', params, lr, (0.0, 0.0), eps,
                weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
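
For context, a minimal usage sketch of the classes added above (illustrative only; the model and training step are hypothetical, a CUDA-enabled bitsandbytes build is assumed, and the constructor arguments follow the signatures shown in this diff):

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(512, 512).cuda()                  # stand-in for a real model
opt = bnb.optim.Adagrad8bit(model.parameters(), lr=1e-2)  # 8-bit, block-wise optimizer state

x = torch.randn(16, 512, device='cuda')
loss = model(x).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
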
@@ -790,6 +790,11 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
        s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
        s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
        break;
      case ADAGRAD:
        s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update
        s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
        s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
        break;
    }
  }

@@ -884,6 +889,10 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
        s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
        p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
        break;
      case ADAGRAD:
        s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]);
        p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
        break;
    }
  }
}
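
As a sanity check on the ADAGRAD branch above, the per-element update it performs is the standard Adagrad rule. A minimal NumPy sketch of the same math (not the kernel itself, just a reference for the state and parameter updates; variable names are illustrative):

import numpy as np

def adagrad_step(p, g, state, lr=1e-2, eps=1e-10):
    # state accumulates squared gradients; the step is g / (sqrt(state) + eps)
    state += g * g
    p -= lr * g / (np.sqrt(state) + eps)
    return p, state

p = np.ones(4, dtype=np.float32)
g = np.full(4, 0.1, dtype=np.float32)
state = np.zeros(4, dtype=np.float32)
p, state = adagrad_step(p, g, state)
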
@@ -1653,6 +1662,9 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
      case RMSPROP:
        s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
        break;
      case ADAGRAD:
        s1_vals[j] = s1_vals[j] + (g_val*g_val);
        break;
    }
  }

@@ -1688,6 +1700,10 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
        g_val = g_vals[j];
        p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
        break;
      case ADAGRAD:
        g_val = g_vals[j];
        p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
        break;
    }
  }
}
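
The blockwise 8-bit path above keeps the Adagrad state quantized to 8 bits and dequantizes/requantizes it around each update, one block at a time. A simplified sketch of block-wise quantization of a state tensor (illustrative only; this uses plain linear absmax scaling, whereas the actual kernels use dynamic quantile maps and a block size of 2048):

import numpy as np

def quantize_blockwise(x, block=2048):
    # split into blocks, keep one absmax scale per block, map values to int8
    pad = (-len(x)) % block
    xb = np.pad(x, (0, pad)).reshape(-1, block)
    absmax = np.abs(xb).max(axis=1, keepdims=True) + 1e-12
    q = np.round(xb / absmax * 127).astype(np.int8)
    return q, absmax

def dequantize_blockwise(q, absmax, n):
    xb = q.astype(np.float32) / 127.0 * absmax
    return xb.reshape(-1)[:n]

state = np.random.rand(5000).astype(np.float32)
q, scales = quantize_blockwise(state)
approx = dequantize_blockwise(q, scales, len(state))
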
@@ -1738,6 +1754,8 @@ MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)

#define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
@@ -1747,6 +1765,8 @@ MAKE_Optimizer32bit1State(MOMENTUM, half)
MAKE_Optimizer32bit1State(MOMENTUM, float)
MAKE_Optimizer32bit1State(RMSPROP, half)
MAKE_Optimizer32bit1State(RMSPROP, float)
MAKE_Optimizer32bit1State(ADAGRAD, half)
MAKE_Optimizer32bit1State(ADAGRAD, float)

#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
@@ -1862,3 +1882,5 @@ MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)

@@ -199,6 +199,8 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
      break;
    case MOMENTUM:
    case RMSPROP:
    case ADAGRAD:

      if(max_unorm > 0.0f)
      {
        CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
@@ -240,6 +242,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
      break;
    case MOMENTUM:
    case RMSPROP:
    case ADAGRAD:
      CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
      kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
      CUDA_CHECK_RETURN(cudaPeekAtLastError());
@@ -274,6 +277,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
      break;
    case MOMENTUM:
    case RMSPROP:
    case ADAGRAD:
      blocks = n/BLOCKSIZE_1STATE;
      blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
      kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
@@ -321,6 +325,8 @@ MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float)
MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float)

#define MAKE_optimizerStatic8bit(name, gtype) \
template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
@@ -350,6 +356,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);

template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);

@@ -36,6 +36,7 @@ typedef enum Optimizer_t
  MOMENTUM = 1,
  RMSPROP = 2,
  LARS = 3,
  ADAGRAD = 4,
} Optimizer_t;

@@ -29,6 +29,8 @@ MAKE_FUNC32(adam, ADAM, float, 32)
MAKE_FUNC32(adam, ADAM, half, 16)
MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
MAKE_FUNC32(adagrad, ADAGRAD, float, 32)
MAKE_FUNC32(adagrad, ADAGRAD, half, 16)

#define MAKE_FUNC8(fname, oname, gtype, gbits) \
void fname##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
@@ -62,6 +64,8 @@ MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, 16)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, 32)

void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
@@ -102,6 +106,8 @@ extern "C"
    MAKE_CFUNC32(momentum, half, 16)
    MAKE_CFUNC32(rmsprop, float, 32)
    MAKE_CFUNC32(rmsprop, half, 16)
    MAKE_CFUNC32(adagrad, float, 32)
    MAKE_CFUNC32(adagrad, half, 16)

    #define MAKE_CFUNC8(name, gtype, gbits) \
    void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
@@ -135,6 +141,8 @@ extern "C"
    MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
    MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
    MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
    MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16)
    MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32)

    void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }

@@ -39,6 +39,7 @@ str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambd
str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['adagrad'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad(pxx, 0.01, block_wise=False))
str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False))
@@ -48,6 +49,7 @@ str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
str2optimizers['adagrad8bit_blockwise'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad8bit(pxx, 0.01, block_wise=True))

str2statenames = {}
str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
@@ -55,6 +57,7 @@ str2statenames['momentum'] = [('momentum_buffer', 'state1')]
str2statenames['lars'] = [('momentum_buffer', 'state1')]
str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['rmsprop'] = [('square_avg', 'state1')]
str2statenames['adagrad'] = [('sum', 'state1')]
str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
@@ -63,11 +66,12 @@ str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1
str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')]
str2statenames['adagrad8bit_blockwise'] = [('sum', 'state1', 'qmap1', 'absmax1')]

dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb']
+optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@@ -197,7 +201,7 @@ def test_global_config(dim1, dim2, gtype):
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
-optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise']
+optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
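
For orientation, a minimal sketch of the kind of parity check these parametrized tests run for the new 'adagrad' entries (hypothetical and standalone; the actual test bodies are not shown in this diff, it requires a CUDA GPU with bitsandbytes installed, and the 'sum'/'state1' names follow str2statenames above):

import torch
import bitsandbytes as bnb

p1 = torch.randn(1024, 1024, device='cuda') * 0.1
p2 = p1.clone()
torch_opt = torch.optim.Adagrad([p1], lr=0.01)
bnb_opt = bnb.optim.Adagrad([p2], lr=0.01, block_wise=False)

for _ in range(10):
    g = torch.randn_like(p1) * 0.01
    p1.grad, p2.grad = g, g.clone()
    torch_opt.step()
    bnb_opt.step()

# accumulated state and updated parameters should agree closely
torch.testing.assert_close(torch_opt.state[p1]['sum'], bnb_opt.state[p2]['state1'], rtol=1e-4, atol=1e-4)
torch.testing.assert_close(p1, p2, rtol=1e-4, atol=1e-4)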