forked from mrq/bitsandbytes-rocm
Add a CPU-only build option
This commit is contained in:
parent
33efe4a09f
commit
8258b4364a
19
Makefile
19
Makefile
|
@ -10,10 +10,10 @@ NVCC := $(CUDA_HOME)/bin/nvcc
|
|||
###########################################
|
||||
|
||||
CSRC := $(ROOT_DIR)/csrc
|
||||
BUILD_DIR:= $(ROOT_DIR)/cuda_build
|
||||
BUILD_DIR:= $(ROOT_DIR)/build
|
||||
|
||||
FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu
|
||||
FILES_CPP := $(CSRC)/pythonInterface.c
|
||||
FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c
|
||||
|
||||
INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include
|
||||
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
|
||||
|
@ -46,27 +46,30 @@ CC_CUDA11x += -gencode arch=compute_86,code=sm_86
|
|||
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
|
||||
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
|
||||
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
|
||||
|
||||
cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
|
||||
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
|
||||
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
|
||||
|
||||
cuda10x: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
|
||||
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
|
||||
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
|
||||
|
||||
cuda110: $(BUILD_DIR) env
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
|
||||
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
|
||||
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
|
||||
|
||||
cuda11x: $(BUILD_DIR) env
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
|
||||
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
|
||||
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
|
||||
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
|
||||
|
||||
cpuonly: $(BUILD_DIR) env
|
||||
$(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so
|
||||
|
||||
env:
|
||||
@echo "ENVIRONMENT"
|
||||
|
@ -80,7 +83,7 @@ env:
|
|||
@echo "============================"
|
||||
|
||||
$(BUILD_DIR):
|
||||
mkdir -p cuda_build
|
||||
mkdir -p build
|
||||
mkdir -p dependencies
|
||||
|
||||
$(ROOT_DIR)/dependencies/cub:
|
||||
|
|
|
@ -2,9 +2,14 @@
|
|||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from .optim import adam
|
||||
|
||||
from .nn import modules
|
||||
__pdoc__ = {'libBitsNBytes' : False,
|
||||
from cextension import COMPILED_WITH_CUDA
|
||||
|
||||
if COMPILED_WITH_CUDA:
|
||||
from .optim import adam
|
||||
|
||||
__pdoc__ = {'libBitsNBytes': False,
|
||||
'optim.optimizer.Optimizer8bit': False,
|
||||
'optim.optimizer.MockArgs': False
|
||||
}
|
||||
}
|
||||
|
|
13
bitsandbytes/cextension.py
Normal file
13
bitsandbytes/cextension.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
import ctypes as ct
|
||||
import os
|
||||
from warnings import warn
|
||||
|
||||
lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
|
||||
|
||||
try:
|
||||
lib.cadam32bit_g32
|
||||
COMPILED_WITH_CUDA = True
|
||||
except AttributeError:
|
||||
warn("The installed version of bitsandbytes was compiled without GPU support. "
|
||||
"8-bit optimizers and GPU quantization are unavailable.")
|
||||
COMPILED_WITH_CUDA = False
|
File diff suppressed because one or more lines are too long
|
@ -2,11 +2,15 @@
|
|||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from .adam import Adam, Adam8bit, Adam32bit
|
||||
from .adamw import AdamW, AdamW8bit, AdamW32bit
|
||||
from .sgd import SGD, SGD8bit, SGD32bit
|
||||
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
|
||||
from .lamb import LAMB, LAMB8bit, LAMB32bit
|
||||
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
|
||||
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
|
||||
from .optimizer import GlobalOptimManager
|
||||
|
||||
from bitsandbytes.cextension import COMPILED_WITH_CUDA
|
||||
|
||||
if COMPILED_WITH_CUDA:
|
||||
from .adam import Adam, Adam8bit, Adam32bit
|
||||
from .adamw import AdamW, AdamW8bit, AdamW32bit
|
||||
from .sgd import SGD, SGD8bit, SGD32bit
|
||||
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
|
||||
from .lamb import LAMB, LAMB8bit, LAMB32bit
|
||||
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
|
||||
from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit
|
||||
from .optimizer import GlobalOptimManager
|
||||
|
|
|
@ -31,6 +31,6 @@ class RMSprop32bit(Optimizer1State):
|
|||
if alpha == 0:
|
||||
raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
|
||||
if centered:
|
||||
raise NotImplementError(f'Centered RMSprop is not supported!')
|
||||
raise NotImplementedError(f'Centered RMSprop is not supported!')
|
||||
super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
|
||||
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
|
||||
|
|
39
csrc/common.cpp
Normal file
39
csrc/common.cpp
Normal file
|
@ -0,0 +1,39 @@
|
|||
#include <common.h>
|
||||
#include <float.h>
|
||||
|
||||
void *quantize_block(void *arguments) {
|
||||
// 1. find absmax in block
|
||||
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
|
||||
// 3. do binary search to find the closest value
|
||||
// 4. check minimal distance
|
||||
// 5. store index
|
||||
|
||||
struct quantize_block_args *args = (quantize_block_args *) arguments;
|
||||
|
||||
// 1. find absmax in block
|
||||
float absmax_block = -FLT_MAX;
|
||||
for (int i = args->block_idx; i < args->block_end; i++)
|
||||
absmax_block = fmax(absmax_block, fabs(args->A[i]));
|
||||
|
||||
args->absmax[args->block_idx / BLOCK_SIZE] = absmax_block;
|
||||
|
||||
for (int i = args->block_idx; i < args->block_end; i++) {
|
||||
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
|
||||
// 3. do binary search to find the closest value
|
||||
float normed_value = args->A[i] / absmax_block;
|
||||
int idx = args->bin_searcher->scalar(normed_value);
|
||||
|
||||
// 4. check minimal distance
|
||||
// The binary search returns always the value to the left, which might not be the closest value
|
||||
if (idx < 255) {
|
||||
float dist_left = fabs(normed_value - (args->code[idx]));
|
||||
float dist_right = fabs(normed_value - (args->code[idx + 1]));
|
||||
if (dist_right < dist_left) { idx += 1; }
|
||||
}
|
||||
|
||||
// 5. store index
|
||||
args->out[i] = (unsigned char) idx;
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
23
csrc/common.h
Normal file
23
csrc/common.h
Normal file
|
@ -0,0 +1,23 @@
|
|||
#include <BinSearch.h>
|
||||
|
||||
#ifndef common
|
||||
#define common
|
||||
|
||||
using namespace BinSearch;
|
||||
|
||||
struct quantize_block_args {
|
||||
BinAlgo<Scalar, float, Direct2> *bin_searcher;
|
||||
float *code;
|
||||
float *A;
|
||||
float *absmax;
|
||||
unsigned char *out;
|
||||
int block_end;
|
||||
int block_idx;
|
||||
int threadidx;
|
||||
};
|
||||
|
||||
#define BLOCK_SIZE 4096
|
||||
|
||||
void *quantize_block(void *arguments);
|
||||
|
||||
#endif
|
57
csrc/cpu_ops.cpp
Normal file
57
csrc/cpu_ops.cpp
Normal file
|
@ -0,0 +1,57 @@
|
|||
#include <BinSearch.h>
|
||||
#include <pthread.h>
|
||||
#include <common.h>
|
||||
|
||||
using namespace BinSearch;
|
||||
|
||||
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n) {
|
||||
for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) {
|
||||
int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
|
||||
int block_end = block_idx + valid_items;
|
||||
for (int i = block_idx; i < block_end; i++)
|
||||
out[i] = code[A[i]] * absmax[block_idx / BLOCK_SIZE];
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n) {
|
||||
|
||||
// the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
|
||||
code[0] = -1.0f;
|
||||
|
||||
int num_blocks = n / BLOCK_SIZE;
|
||||
num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1;
|
||||
|
||||
pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * num_blocks);
|
||||
struct quantize_block_args **args = (quantize_block_args **) malloc(num_blocks * sizeof(quantize_block_args *));
|
||||
|
||||
for (int i = 0; i < num_blocks; i++)
|
||||
args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));
|
||||
|
||||
const uint32 elements_code = 256;
|
||||
BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
|
||||
|
||||
for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) {
|
||||
int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
|
||||
int block_end = block_idx + valid_items;
|
||||
|
||||
struct quantize_block_args *arg = args[block_idx / BLOCK_SIZE];
|
||||
arg->bin_searcher = &bin_searcher;
|
||||
arg->code = code;
|
||||
arg->A = A;
|
||||
arg->absmax = absmax;
|
||||
arg->out = out;
|
||||
arg->block_end = block_end;
|
||||
arg->block_idx = block_idx;
|
||||
arg->threadidx = block_idx / BLOCK_SIZE;
|
||||
|
||||
pthread_create(&threads[block_idx / BLOCK_SIZE], NULL, &quantize_block, (void *) arg);
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_blocks; i++)
|
||||
int err = pthread_join(threads[i], NULL);
|
||||
|
||||
free(threads);
|
||||
for (int i = 0; i < num_blocks; i++)
|
||||
free(args[i]);
|
||||
free(args);
|
||||
}
|
9
csrc/cpu_ops.h
Normal file
9
csrc/cpu_ops.h
Normal file
|
@ -0,0 +1,9 @@
|
|||
#ifndef BITSANDBYTES_CPU_OPS_H
|
||||
#define BITSANDBYTES_CPU_OPS_H
|
||||
|
||||
|
||||
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n);
|
||||
|
||||
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n);
|
||||
|
||||
#endif
|
453
csrc/ops.cu
453
csrc/ops.cu
|
@ -8,251 +8,141 @@
|
|||
#include <cub/device/device_scan.cuh>
|
||||
#include <limits>
|
||||
#include <BinSearch.h>
|
||||
#include <common.h>
|
||||
|
||||
|
||||
using namespace BinSearch;
|
||||
using std::cout;
|
||||
using std::endl;
|
||||
|
||||
#define BLOCK_SIZE 4096
|
||||
void histogramScatterAdd2D(float *histogram, int *index1, int *index2, float *src, int maxidx1, int n) {
|
||||
int threads = 512;
|
||||
int blocks = n / threads;
|
||||
blocks = n % threads == 0 ? blocks : blocks + 1;
|
||||
kHistogramScatterAdd2D<<<blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
struct quantize_block_args
|
||||
{
|
||||
BinAlgo<Scalar, float, Direct2> *bin_searcher;
|
||||
float *code;
|
||||
float *A;
|
||||
float *absmax;
|
||||
unsigned char *out;
|
||||
int block_end;
|
||||
int block_idx;
|
||||
int threadidx;
|
||||
};
|
||||
template<typename T>
|
||||
void estimateQuantiles(T *A, float *code, float offset, int n) {
|
||||
int blocks = n / 4096;
|
||||
blocks = n % 4096 == 0 ? blocks : blocks + 1;
|
||||
CUDA_CHECK_RETURN(cudaMemset(code, 0, 256 * sizeof(float)));
|
||||
kEstimateQuantiles < T ><<<blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
void *quantize_block(void *arguments)
|
||||
{
|
||||
// 1. find absmax in block
|
||||
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
|
||||
// 3. do binary search to find the closest value
|
||||
// 4. check minimal distance
|
||||
// 5. store index
|
||||
void quantize(float *code, float *A, unsigned char *out, int n) {
|
||||
int blocks = n / 1024;
|
||||
blocks = n % 1024 == 0 ? blocks : blocks + 1;
|
||||
kQuantize<<<blocks, 1024>>>(code, A, out, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
struct quantize_block_args *args = (quantize_block_args*)arguments;
|
||||
void dequantize(float *code, unsigned char *A, float *out, int n) {
|
||||
int blocks = n / 1024;
|
||||
blocks = n % 1024 == 0 ? blocks : blocks + 1;
|
||||
kDequantize<<<blocks, 1024>>>(code, A, out, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
// 1. find absmax in block
|
||||
float absmax_block = -FLT_MAX;
|
||||
for (int i = args->block_idx; i < args->block_end; i++)
|
||||
absmax_block = fmax(absmax_block, fabs(args->A[i]));
|
||||
template<typename T, int STOCHASTIC>
|
||||
void quantizeBlockwise(float *code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n) {
|
||||
int blocks = n / 4096;
|
||||
blocks = n % 4096 == 0 ? blocks : blocks + 1;
|
||||
kQuantizeBlockwise < T, 4096, 4, STOCHASTIC ><<<blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
args->absmax[args->block_idx/BLOCK_SIZE] = absmax_block;
|
||||
template<typename T>
|
||||
void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) {
|
||||
int blocks = n / blocksize;
|
||||
blocks = n % blocksize == 0 ? blocks : blocks + 1;
|
||||
if (blocksize == 4096)
|
||||
kDequantizeBlockwise < T, 4096, 1024, 4 ><<<blocks, 4096 / 4>>>(code, A, absmax, out, n);
|
||||
else if (blocksize == 2048)
|
||||
kDequantizeBlockwise < T, 2048, 512, 4 ><<<blocks, 2048 / 4>>>(code, A, absmax, out, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
for (int i = args->block_idx; i < args->block_end; i++)
|
||||
{
|
||||
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
|
||||
// 3. do binary search to find the closest value
|
||||
float normed_value = args->A[i]/absmax_block;
|
||||
int idx = args->bin_searcher->scalar(normed_value);
|
||||
template<typename T, int OPTIMIZER>
|
||||
void optimizer32bit(T *g, T *p,
|
||||
float *state1, float *state2, float *unorm, float max_unorm, float param_norm,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) {
|
||||
int blocks = n / 4096;
|
||||
blocks = n % 4096 == 0 ? blocks : blocks + 1;
|
||||
switch (OPTIMIZER) {
|
||||
case ADAM:
|
||||
if (max_unorm > 0.0f) {
|
||||
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
|
||||
kPreconditionOptimizer32bit2State < T, OPTIMIZER, 4096,
|
||||
8 ><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
kOptimizer32bit2State < T,
|
||||
OPTIMIZER ><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
|
||||
// 4. check minimal distance
|
||||
// The binary search returns always the value to the left, which might not be the closest value
|
||||
if(idx < 255)
|
||||
{
|
||||
float dist_left = fabs(normed_value-(args->code[idx]));
|
||||
float dist_right = fabs(normed_value-(args->code[idx+1]));
|
||||
if(dist_right < dist_left){ idx+=1; }
|
||||
if (max_unorm > 0.0f) {
|
||||
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
|
||||
kPreconditionOptimizer32bit1State < T, OPTIMIZER, 4096,
|
||||
8 ><<<blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
kOptimizer32bit1State < T,
|
||||
OPTIMIZER ><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
}
|
||||
|
||||
// 5. store index
|
||||
args->out[i] = (unsigned char)idx;
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n)
|
||||
{
|
||||
template<typename T, int OPTIMIZER>
|
||||
void optimizerStatic8bit(T *p, T *g,
|
||||
unsigned char *state1, unsigned char *state2,
|
||||
float *unorm, float max_unorm, float param_norm,
|
||||
float beta1, float beta2,
|
||||
float eps, int step, float lr,
|
||||
float *quantiles1, float *quantiles2,
|
||||
float *max1, float *max2, float *new_max1, float *new_max2,
|
||||
float weight_decay,
|
||||
const float gnorm_scale, int n) {
|
||||
int blocks = n / 4096;
|
||||
blocks = n % 4096 == 0 ? blocks : blocks + 1;
|
||||
|
||||
// the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
|
||||
code[0] = -1.0f;
|
||||
if (max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); }
|
||||
|
||||
int num_blocks = n/BLOCK_SIZE;
|
||||
num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1;
|
||||
|
||||
pthread_t *threads = (pthread_t*)malloc(sizeof(pthread_t)*num_blocks);
|
||||
struct quantize_block_args **args = (quantize_block_args**)malloc(num_blocks*sizeof(quantize_block_args*));
|
||||
|
||||
for(int i = 0; i < num_blocks; i++)
|
||||
args[i] = (quantize_block_args*)malloc(sizeof(quantize_block_args));
|
||||
|
||||
const uint32 elements_code = 256;
|
||||
BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
|
||||
|
||||
for(int block_idx = 0; block_idx < n; block_idx+=BLOCK_SIZE)
|
||||
{
|
||||
int valid_items = n-block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
|
||||
int block_end = block_idx + valid_items;
|
||||
|
||||
struct quantize_block_args *arg = args[block_idx/BLOCK_SIZE];
|
||||
arg->bin_searcher = &bin_searcher;
|
||||
arg->code = code;
|
||||
arg->A = A;
|
||||
arg->absmax = absmax;
|
||||
arg->out = out;
|
||||
arg->block_end = block_end;
|
||||
arg->block_idx = block_idx;
|
||||
arg->threadidx = block_idx/BLOCK_SIZE;
|
||||
|
||||
pthread_create(&threads[block_idx/BLOCK_SIZE], NULL, &quantize_block, (void *)arg);
|
||||
}
|
||||
|
||||
for(int i = 0; i < num_blocks; i++)
|
||||
int err = pthread_join(threads[i], NULL);
|
||||
|
||||
free(threads);
|
||||
for(int i = 0; i < num_blocks; i++)
|
||||
free(args[i]);
|
||||
free(args);
|
||||
}
|
||||
|
||||
|
||||
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n)
|
||||
{
|
||||
for(int block_idx = 0; block_idx < n; block_idx+=BLOCK_SIZE)
|
||||
{
|
||||
int valid_items = n-block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
|
||||
int block_end = block_idx + valid_items;
|
||||
for (int i = block_idx; i < block_end; i++)
|
||||
out[i] = code[A[i]]*absmax[block_idx/BLOCK_SIZE];
|
||||
}
|
||||
}
|
||||
|
||||
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n)
|
||||
{
|
||||
int threads = 512;
|
||||
int blocks = n/threads;
|
||||
blocks = n % threads == 0 ? blocks : blocks + 1;
|
||||
kHistogramScatterAdd2D<<<blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n)
|
||||
{
|
||||
int blocks = n/4096;
|
||||
blocks = n % 4096 == 0 ? blocks : blocks + 1;
|
||||
CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float)));
|
||||
kEstimateQuantiles<T><<<blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
void quantize(float *code, float *A, unsigned char *out, int n)
|
||||
{
|
||||
int blocks = n/1024;
|
||||
blocks = n % 1024 == 0 ? blocks : blocks + 1;
|
||||
kQuantize<<<blocks, 1024>>>(code, A, out, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
void dequantize(float *code, unsigned char *A, float *out, int n)
|
||||
{
|
||||
int blocks = n/1024;
|
||||
blocks = n % 1024 == 0 ? blocks : blocks + 1;
|
||||
kDequantize<<<blocks, 1024>>>(code, A, out, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n)
|
||||
{
|
||||
int blocks = n/4096;
|
||||
blocks = n % 4096 == 0 ? blocks : blocks + 1;
|
||||
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
|
||||
{
|
||||
int blocks = n/blocksize;
|
||||
blocks = n % blocksize == 0 ? blocks : blocks + 1;
|
||||
if(blocksize == 4096)
|
||||
kDequantizeBlockwise<T, 4096, 1024, 4><<<blocks, 4096/4>>>(code, A, absmax, out, n);
|
||||
else if(blocksize == 2048)
|
||||
kDequantizeBlockwise<T, 2048, 512, 4><<<blocks, 2048/4>>>(code, A, absmax, out, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
||||
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
|
||||
{
|
||||
int blocks = n/4096;
|
||||
blocks = n % 4096 == 0 ? blocks : blocks + 1;
|
||||
switch(OPTIMIZER)
|
||||
{
|
||||
case ADAM:
|
||||
if(max_unorm > 0.0f)
|
||||
{
|
||||
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
|
||||
if(max_unorm > 0.0f)
|
||||
{
|
||||
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
|
||||
unsigned char* state1, unsigned char* state2,
|
||||
float *unorm, float max_unorm, float param_norm,
|
||||
float beta1, float beta2,
|
||||
float eps, int step, float lr,
|
||||
float* quantiles1, float* quantiles2,
|
||||
float* max1, float* max2, float* new_max1, float* new_max2,
|
||||
float weight_decay,
|
||||
const float gnorm_scale, int n)
|
||||
{
|
||||
int blocks = n/4096;
|
||||
blocks = n % 4096 == 0 ? blocks : blocks + 1;
|
||||
|
||||
if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); }
|
||||
|
||||
switch(OPTIMIZER)
|
||||
{
|
||||
case ADAM:
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
kOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
|
||||
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
kOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
|
||||
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
switch (OPTIMIZER) {
|
||||
case ADAM:
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float)));
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1 * sizeof(float)));
|
||||
kPreconditionOptimizerStatic8bit2State < T,
|
||||
OPTIMIZER ><<<blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
kOptimizerStatic8bit2State < T,
|
||||
OPTIMIZER ><<<blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
|
||||
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float)));
|
||||
kPreconditionOptimizerStatic8bit1State < T,
|
||||
OPTIMIZER ><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
kOptimizerStatic8bit1State < T, OPTIMIZER ><<<blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
|
||||
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#define BLOCKSIZE_2STATE 2048
|
||||
|
@ -260,42 +150,43 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
|
|||
#define BLOCKSIZE_1STATE 2048
|
||||
#define NUM_1STATE 8
|
||||
|
||||
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
|
||||
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
|
||||
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
|
||||
{
|
||||
template<typename T, int OPTIMIZER>
|
||||
void optimizerStatic8bitBlockwise(T *p, T *g,
|
||||
unsigned char *state1, unsigned char *state2, float beta1, float beta2, float eps, int step, float lr,
|
||||
float *quantiles1, float *quantiles2, float *absmax1, float *absmax2, float weight_decay,
|
||||
const float gnorm_scale, bool skip_zeros, int n) {
|
||||
|
||||
int blocks = 0;
|
||||
switch(OPTIMIZER)
|
||||
{
|
||||
case ADAM:
|
||||
blocks = n/BLOCKSIZE_2STATE;
|
||||
blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1;
|
||||
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
|
||||
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
blocks = n/BLOCKSIZE_1STATE;
|
||||
blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
|
||||
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
|
||||
quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
}
|
||||
int blocks = 0;
|
||||
switch (OPTIMIZER) {
|
||||
case ADAM:
|
||||
blocks = n / BLOCKSIZE_2STATE;
|
||||
blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1;
|
||||
kOptimizerStatic8bit2StateBlockwise < T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE ><<<blocks, BLOCKSIZE_2STATE /
|
||||
NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
|
||||
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
case MOMENTUM:
|
||||
case RMSPROP:
|
||||
case ADAGRAD:
|
||||
blocks = n / BLOCKSIZE_1STATE;
|
||||
blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
|
||||
kOptimizerStatic8bit1StateBlockwise < T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE ><<<blocks, BLOCKSIZE_1STATE /
|
||||
NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
|
||||
quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n)
|
||||
{
|
||||
int blocks = n/2048;
|
||||
blocks = n % 2048 == 0 ? blocks : blocks + 1;
|
||||
CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
|
||||
kPercentileClipping<T, 2048, 4><<<blocks, 512>>>(g, gnorm_vec, step, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
template<typename T>
|
||||
void percentileClipping(T *g, float *gnorm_vec, int step, const int n) {
|
||||
int blocks = n / 2048;
|
||||
blocks = n % 2048 == 0 ? blocks : blocks + 1;
|
||||
CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1 * sizeof(float)));
|
||||
kPercentileClipping < T, 2048, 4 ><<<blocks, 512>>>(g, gnorm_vec, step, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
||||
|
@ -304,13 +195,23 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
|
|||
//==============================================================
|
||||
|
||||
template void estimateQuantiles(half *A, float *code, float offset, int n);
|
||||
|
||||
template void estimateQuantiles(float *A, float *code, float offset, int n);
|
||||
|
||||
template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||
template void
|
||||
quantizeBlockwise<half, 0>(float *code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n);
|
||||
|
||||
template void
|
||||
quantizeBlockwise<float, 0>(float *code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n);
|
||||
|
||||
template void
|
||||
quantizeBlockwise<half, 1>(float *code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n);
|
||||
|
||||
template void
|
||||
quantizeBlockwise<float, 1>(float *code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n);
|
||||
|
||||
template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||
|
||||
template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||
|
||||
#define MAKE_optimizer32bit(name, gtype) \
|
||||
|
@ -320,12 +221,19 @@ template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
|
|||
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
|
||||
|
||||
MAKE_optimizer32bit(ADAM, half)
|
||||
|
||||
MAKE_optimizer32bit(ADAM, float)
|
||||
|
||||
MAKE_optimizer32bit(MOMENTUM, half)
|
||||
|
||||
MAKE_optimizer32bit(MOMENTUM, float)
|
||||
|
||||
MAKE_optimizer32bit(RMSPROP, half)
|
||||
|
||||
MAKE_optimizer32bit(RMSPROP, float)
|
||||
|
||||
MAKE_optimizer32bit(ADAGRAD, half)
|
||||
|
||||
MAKE_optimizer32bit(ADAGRAD, float)
|
||||
|
||||
#define MAKE_optimizerStatic8bit(name, gtype) \
|
||||
|
@ -338,11 +246,17 @@ template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char
|
|||
float weight_decay, \
|
||||
const float gnorm_scale, int n); \
|
||||
|
||||
|
||||
MAKE_optimizerStatic8bit(ADAM, half)
|
||||
|
||||
MAKE_optimizerStatic8bit(ADAM, float)
|
||||
|
||||
MAKE_optimizerStatic8bit(MOMENTUM, half)
|
||||
|
||||
MAKE_optimizerStatic8bit(MOMENTUM, float)
|
||||
|
||||
MAKE_optimizerStatic8bit(RMSPROP, half)
|
||||
|
||||
MAKE_optimizerStatic8bit(RMSPROP, float)
|
||||
|
||||
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
|
||||
|
@ -350,14 +264,23 @@ template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g
|
|||
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
|
||||
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \
|
||||
|
||||
|
||||
MAKE_optimizerStatic8bitBlockwise(half, ADAM);
|
||||
|
||||
MAKE_optimizerStatic8bitBlockwise(float, ADAM);
|
||||
|
||||
MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
|
||||
|
||||
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
|
||||
|
||||
MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
|
||||
|
||||
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
|
||||
|
||||
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
|
||||
|
||||
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
|
||||
|
||||
template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
|
||||
template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
|
||||
template void percentileClipping(float *g, float *gnorm_vec, int step, const int n);
|
||||
|
||||
template void percentileClipping(half *g, float *gnorm_vec, int step, const int n);
|
||||
|
|
10
csrc/ops.cuh
10
csrc/ops.cuh
|
@ -68,16 +68,6 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
|
|||
|
||||
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);
|
||||
|
||||
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n);
|
||||
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n);
|
||||
|
||||
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -3,7 +3,10 @@
|
|||
// This source code is licensed under the MIT license found in the
|
||||
// LICENSE file in the root directory of this source tree.
|
||||
|
||||
#if BUILD_CUDA
|
||||
#include <ops.cuh>
|
||||
#endif
|
||||
#include <cpu_ops.h>
|
||||
|
||||
// We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary.
|
||||
// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to
|
||||
|
@ -12,6 +15,7 @@
|
|||
// UNMANGLED CALLS
|
||||
//===================================================================================
|
||||
|
||||
#if BUILD_CUDA
|
||||
void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles<float>(A, code, offset, n); }
|
||||
void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }
|
||||
|
||||
|
@ -34,15 +38,15 @@ MAKE_FUNC32(adagrad, ADAGRAD, half, 16)
|
|||
|
||||
#define MAKE_FUNC8(fname, oname, gtype, gbits) \
|
||||
void fname##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
|
||||
float *unorm, float max_unorm, float param_norm, \
|
||||
float *unorm, float max_unorm, float param_norm, \
|
||||
float beta1, float beta2, \
|
||||
float eps, int step, float lr, \
|
||||
float* quantiles1, float* quantiles2, \
|
||||
float* max1, float* max2, float* new_max1, float* new_max2, \
|
||||
float weight_decay, float gnorm_scale, int n) \
|
||||
{ \
|
||||
optimizerStatic8bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
|
||||
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
|
||||
optimizerStatic8bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
|
||||
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
|
||||
} \
|
||||
|
||||
MAKE_FUNC8(adam, ADAM, float, 32)
|
||||
|
@ -78,39 +82,41 @@ void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, un
|
|||
|
||||
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
|
||||
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
|
||||
#endif
|
||||
|
||||
extern "C"
|
||||
{
|
||||
void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); }
|
||||
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
|
||||
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
|
||||
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
|
||||
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, n); }
|
||||
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, n); }
|
||||
void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); }
|
||||
void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); }
|
||||
#if BUILD_CUDA
|
||||
void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); }
|
||||
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
|
||||
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
|
||||
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
|
||||
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, n); }
|
||||
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, n); }
|
||||
void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); }
|
||||
void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); }
|
||||
|
||||
void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||
void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
||||
|
||||
#define MAKE_CFUNC32(name, gtype, gbits) \
|
||||
void c##name##32bit_g##gbits(gtype *g, gtype *p, \
|
||||
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay, \
|
||||
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
|
||||
{ name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
|
||||
#define MAKE_CFUNC32(name, gtype, gbits) \
|
||||
void c##name##32bit_g##gbits(gtype *g, gtype *p, \
|
||||
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
|
||||
const float beta1, const float beta2, const float eps, const float weight_decay, \
|
||||
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
|
||||
{ name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
|
||||
|
||||
MAKE_CFUNC32(adam, float, 32)
|
||||
MAKE_CFUNC32(adam, half, 16)
|
||||
MAKE_CFUNC32(momentum, float, 32)
|
||||
MAKE_CFUNC32(momentum, half, 16)
|
||||
MAKE_CFUNC32(rmsprop, float, 32)
|
||||
MAKE_CFUNC32(rmsprop, half, 16)
|
||||
MAKE_CFUNC32(adagrad, float, 32)
|
||||
MAKE_CFUNC32(adagrad, half, 16)
|
||||
MAKE_CFUNC32(adam, float, 32)
|
||||
MAKE_CFUNC32(adam, half, 16)
|
||||
MAKE_CFUNC32(momentum, float, 32)
|
||||
MAKE_CFUNC32(momentum, half, 16)
|
||||
MAKE_CFUNC32(rmsprop, float, 32)
|
||||
MAKE_CFUNC32(rmsprop, half, 16)
|
||||
MAKE_CFUNC32(adagrad, float, 32)
|
||||
MAKE_CFUNC32(adagrad, half, 16)
|
||||
|
||||
#define MAKE_CFUNC8(name, gtype, gbits) \
|
||||
void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
|
||||
#define MAKE_CFUNC8(name, gtype, gbits) \
|
||||
void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
|
||||
float *unorm, float max_unorm, float param_norm, \
|
||||
float beta1, float beta2, \
|
||||
float eps, int step, float lr, \
|
||||
|
@ -118,40 +124,40 @@ extern "C"
|
|||
float* max1, float* max2, float* new_max1, float* new_max2, \
|
||||
float weight_decay, float gnorm_scale, int n) \
|
||||
{ \
|
||||
name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
|
||||
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
|
||||
name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
|
||||
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
|
||||
} \
|
||||
|
||||
MAKE_CFUNC8(adam, float, 32)
|
||||
MAKE_CFUNC8(adam, half, 16)
|
||||
MAKE_CFUNC8(momentum, float, 32)
|
||||
MAKE_CFUNC8(momentum, half, 16)
|
||||
MAKE_CFUNC8(rmsprop, float, 32)
|
||||
MAKE_CFUNC8(rmsprop, half, 16)
|
||||
MAKE_CFUNC8(adam, float, 32)
|
||||
MAKE_CFUNC8(adam, half, 16)
|
||||
MAKE_CFUNC8(momentum, float, 32)
|
||||
MAKE_CFUNC8(momentum, half, 16)
|
||||
MAKE_CFUNC8(rmsprop, float, 32)
|
||||
MAKE_CFUNC8(rmsprop, half, 16)
|
||||
|
||||
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
|
||||
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
|
||||
void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
|
||||
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
|
||||
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
|
||||
{ fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
|
||||
|
||||
MAKE_CBLOCKWISE8(adam, ADAM, half, 16)
|
||||
MAKE_CBLOCKWISE8(adam, ADAM, float, 32)
|
||||
MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, 16)
|
||||
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
|
||||
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
|
||||
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
|
||||
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16)
|
||||
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32)
|
||||
MAKE_CBLOCKWISE8(adam, ADAM, half, 16)
|
||||
MAKE_CBLOCKWISE8(adam, ADAM, float, 32)
|
||||
MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, 16)
|
||||
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
|
||||
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
|
||||
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
|
||||
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, 16)
|
||||
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, 32)
|
||||
|
||||
|
||||
void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
|
||||
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
|
||||
void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
|
||||
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
|
||||
void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
|
||||
#endif
|
||||
|
||||
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
|
||||
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
|
||||
|
||||
void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
|
||||
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
|
||||
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
|
||||
}
|
||||
|
||||
|
||||
|
|
22
setup.py
22
setup.py
|
@ -6,27 +6,27 @@ import os
|
|||
from setuptools import setup, find_packages
|
||||
|
||||
|
||||
|
||||
def read(fname):
|
||||
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||
|
||||
|
||||
version = os.getenv("CUDA_VERSION", "cpu")
|
||||
|
||||
setup(
|
||||
name = f"bitsandbytes-cuda{os.environ['CUDA_VERSION']}",
|
||||
version = "0.26.0",
|
||||
author = "Tim Dettmers",
|
||||
author_email = "dettmers@cs.washington.edu",
|
||||
description = ("8-bit optimizers and quantization routines."),
|
||||
license = "MIT",
|
||||
keywords = "gpu optimizers optimization 8-bit quantization compression",
|
||||
url = "http://packages.python.org/bitsandbytes",
|
||||
name="bitsandbytes",
|
||||
version=f"0.26.0+{version}",
|
||||
author="Tim Dettmers",
|
||||
author_email="dettmers@cs.washington.edu",
|
||||
description="8-bit optimizers and quantization routines.",
|
||||
license="MIT",
|
||||
keywords="gpu optimizers optimization 8-bit quantization compression",
|
||||
url="http://packages.python.org/bitsandbytes",
|
||||
packages=find_packages(),
|
||||
package_data={'': ['libbitsandbytes.so']},
|
||||
long_description=read('README.md'),
|
||||
long_description_content_type = 'text/markdown',
|
||||
long_description_content_type='text/markdown',
|
||||
classifiers=[
|
||||
"Development Status :: 4 - Beta",
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence'
|
||||
],
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user