Add a CPU-only build option

This commit is contained in:
Max Ryabinin 2022-07-01 17:16:10 +03:00
parent 33efe4a09f
commit 8258b4364a
14 changed files with 454 additions and 382 deletions

View File

@ -10,10 +10,10 @@ NVCC := $(CUDA_HOME)/bin/nvcc
###########################################
CSRC := $(ROOT_DIR)/csrc
BUILD_DIR:= $(ROOT_DIR)/cuda_build
BUILD_DIR:= $(ROOT_DIR)/build
FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu
FILES_CPP := $(CSRC)/pythonInterface.c
FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c
INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
@ -46,27 +46,30 @@ CC_CUDA11x += -gencode arch=compute_86,code=sm_86
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
cuda10x: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
cuda110: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
cuda11x: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
cpuonly: $(BUILD_DIR) env
$(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so
env:
@echo "ENVIRONMENT"
@ -80,7 +83,7 @@ env:
@echo "============================"
$(BUILD_DIR):
mkdir -p cuda_build
mkdir -p build
mkdir -p dependencies
$(ROOT_DIR)/dependencies/cub:

View File

@ -2,8 +2,13 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .optim import adam
from .nn import modules
from cextension import COMPILED_WITH_CUDA
if COMPILED_WITH_CUDA:
from .optim import adam
__pdoc__ = {'libBitsNBytes': False,
'optim.optimizer.Optimizer8bit': False,
'optim.optimizer.MockArgs': False

View File

@ -0,0 +1,13 @@
import ctypes as ct
import os
from warnings import warn
lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
try:
lib.cadam32bit_g32
COMPILED_WITH_CUDA = True
except AttributeError:
warn("The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers and GPU quantization are unavailable.")
COMPILED_WITH_CUDA = False

View File

@ -3,18 +3,18 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import ctypes as ct
import os
import random
from typing import Tuple
import torch
from torch import Tensor
lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
from cextension import lib, COMPILED_WITH_CUDA
name2qmap = {}
if COMPILED_WITH_CUDA:
''' C FUNCTIONS FOR OPTIMIZERS '''
str2optimizer32bit = {}
str2optimizer32bit['adam'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
str2optimizer32bit['momentum'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
@ -138,7 +138,7 @@ def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tens
elif A.dtype == torch.float16:
lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
else:
raise NotImplementError(f'Not supported data type {A.dtype}')
raise NotImplementedError(f'Not supported data type {A.dtype}')
return out
def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=None, out: Tensor=None) -> Tensor:
@ -384,7 +384,7 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten
param_norm = torch.norm(p.data.float())
if optimizer_name not in str2optimizer32bit:
raise NotImplementError(f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}')
raise NotImplementedError(f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}')
if g.dtype == torch.float32 and state1.dtype == torch.float32:
str2optimizer32bit[optimizer_name][0](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm),

View File

@ -2,6 +2,10 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.cextension import COMPILED_WITH_CUDA
if COMPILED_WITH_CUDA:
from .adam import Adam, Adam8bit, Adam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit
from .sgd import SGD, SGD8bit, SGD32bit

View File

@ -31,6 +31,6 @@ class RMSprop32bit(Optimizer1State):
if alpha == 0:
raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
if centered:
raise NotImplementError(f'Centered RMSprop is not supported!')
raise NotImplementedError(f'Centered RMSprop is not supported!')
super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)

39
csrc/common.cpp Normal file
View File

@ -0,0 +1,39 @@
#include <common.h>
#include <float.h>
void *quantize_block(void *arguments) {
// 1. find absmax in block
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value
// 4. check minimal distance
// 5. store index
struct quantize_block_args *args = (quantize_block_args *) arguments;
// 1. find absmax in block
float absmax_block = -FLT_MAX;
for (int i = args->block_idx; i < args->block_end; i++)
absmax_block = fmax(absmax_block, fabs(args->A[i]));
args->absmax[args->block_idx / BLOCK_SIZE] = absmax_block;
for (int i = args->block_idx; i < args->block_end; i++) {
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value
float normed_value = args->A[i] / absmax_block;
int idx = args->bin_searcher->scalar(normed_value);
// 4. check minimal distance
// The binary search returns always the value to the left, which might not be the closest value
if (idx < 255) {
float dist_left = fabs(normed_value - (args->code[idx]));
float dist_right = fabs(normed_value - (args->code[idx + 1]));
if (dist_right < dist_left) { idx += 1; }
}
// 5. store index
args->out[i] = (unsigned char) idx;
}
return NULL;
}

23
csrc/common.h Normal file
View File

@ -0,0 +1,23 @@
#include <BinSearch.h>
#ifndef common
#define common
using namespace BinSearch;
struct quantize_block_args {
BinAlgo<Scalar, float, Direct2> *bin_searcher;
float *code;
float *A;
float *absmax;
unsigned char *out;
int block_end;
int block_idx;
int threadidx;
};
#define BLOCK_SIZE 4096
void *quantize_block(void *arguments);
#endif

57
csrc/cpu_ops.cpp Normal file
View File

@ -0,0 +1,57 @@
#include <BinSearch.h>
#include <pthread.h>
#include <common.h>
using namespace BinSearch;
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n) {
for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) {
int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
int block_end = block_idx + valid_items;
for (int i = block_idx; i < block_end; i++)
out[i] = code[A[i]] * absmax[block_idx / BLOCK_SIZE];
}
}
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n) {
// the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
code[0] = -1.0f;
int num_blocks = n / BLOCK_SIZE;
num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1;
pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * num_blocks);
struct quantize_block_args **args = (quantize_block_args **) malloc(num_blocks * sizeof(quantize_block_args *));
for (int i = 0; i < num_blocks; i++)
args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));
const uint32 elements_code = 256;
BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) {
int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
int block_end = block_idx + valid_items;
struct quantize_block_args *arg = args[block_idx / BLOCK_SIZE];
arg->bin_searcher = &bin_searcher;
arg->code = code;
arg->A = A;
arg->absmax = absmax;
arg->out = out;
arg->block_end = block_end;
arg->block_idx = block_idx;
arg->threadidx = block_idx / BLOCK_SIZE;
pthread_create(&threads[block_idx / BLOCK_SIZE], NULL, &quantize_block, (void *) arg);
}
for (int i = 0; i < num_blocks; i++)
int err = pthread_join(threads[i], NULL);
free(threads);
for (int i = 0; i < num_blocks; i++)
free(args[i]);
free(args);
}

9
csrc/cpu_ops.h Normal file
View File

@ -0,0 +1,9 @@
#ifndef BITSANDBYTES_CPU_OPS_H
#define BITSANDBYTES_CPU_OPS_H
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n);
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n);
#endif

View File

@ -8,125 +8,14 @@
#include <cub/device/device_scan.cuh>
#include <limits>
#include <BinSearch.h>
#include <common.h>
using namespace BinSearch;
using std::cout;
using std::endl;
#define BLOCK_SIZE 4096
struct quantize_block_args
{
BinAlgo<Scalar, float, Direct2> *bin_searcher;
float *code;
float *A;
float *absmax;
unsigned char *out;
int block_end;
int block_idx;
int threadidx;
};
void *quantize_block(void *arguments)
{
// 1. find absmax in block
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value
// 4. check minimal distance
// 5. store index
struct quantize_block_args *args = (quantize_block_args*)arguments;
// 1. find absmax in block
float absmax_block = -FLT_MAX;
for (int i = args->block_idx; i < args->block_end; i++)
absmax_block = fmax(absmax_block, fabs(args->A[i]));
args->absmax[args->block_idx/BLOCK_SIZE] = absmax_block;
for (int i = args->block_idx; i < args->block_end; i++)
{
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value
float normed_value = args->A[i]/absmax_block;
int idx = args->bin_searcher->scalar(normed_value);
// 4. check minimal distance
// The binary search returns always the value to the left, which might not be the closest value
if(idx < 255)
{
float dist_left = fabs(normed_value-(args->code[idx]));
float dist_right = fabs(normed_value-(args->code[idx+1]));
if(dist_right < dist_left){ idx+=1; }
}
// 5. store index
args->out[i] = (unsigned char)idx;
}
return NULL;
}
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n)
{
// the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
code[0] = -1.0f;
int num_blocks = n/BLOCK_SIZE;
num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1;
pthread_t *threads = (pthread_t*)malloc(sizeof(pthread_t)*num_blocks);
struct quantize_block_args **args = (quantize_block_args**)malloc(num_blocks*sizeof(quantize_block_args*));
for(int i = 0; i < num_blocks; i++)
args[i] = (quantize_block_args*)malloc(sizeof(quantize_block_args));
const uint32 elements_code = 256;
BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
for(int block_idx = 0; block_idx < n; block_idx+=BLOCK_SIZE)
{
int valid_items = n-block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
int block_end = block_idx + valid_items;
struct quantize_block_args *arg = args[block_idx/BLOCK_SIZE];
arg->bin_searcher = &bin_searcher;
arg->code = code;
arg->A = A;
arg->absmax = absmax;
arg->out = out;
arg->block_end = block_end;
arg->block_idx = block_idx;
arg->threadidx = block_idx/BLOCK_SIZE;
pthread_create(&threads[block_idx/BLOCK_SIZE], NULL, &quantize_block, (void *)arg);
}
for(int i = 0; i < num_blocks; i++)
int err = pthread_join(threads[i], NULL);
free(threads);
for(int i = 0; i < num_blocks; i++)
free(args[i]);
free(args);
}
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n)
{
for(int block_idx = 0; block_idx < n; block_idx+=BLOCK_SIZE)
{
int valid_items = n-block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
int block_end = block_idx + valid_items;
for (int i = block_idx; i < block_end; i++)
out[i] = code[A[i]]*absmax[block_idx/BLOCK_SIZE];
}
}
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n)
{
void histogramScatterAdd2D(float *histogram, int *index1, int *index2, float *src, int maxidx1, int n) {
int threads = 512;
int blocks = n / threads;
blocks = n % threads == 0 ? blocks : blocks + 1;
@ -134,8 +23,8 @@ void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *sr
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n)
{
template<typename T>
void estimateQuantiles(T *A, float *code, float offset, int n) {
int blocks = n / 4096;
blocks = n % 4096 == 0 ? blocks : blocks + 1;
CUDA_CHECK_RETURN(cudaMemset(code, 0, 256 * sizeof(float)));
@ -143,32 +32,30 @@ template <typename T> void estimateQuantiles(T *A, float *code, float offset, in
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void quantize(float *code, float *A, unsigned char *out, int n)
{
void quantize(float *code, float *A, unsigned char *out, int n) {
int blocks = n / 1024;
blocks = n % 1024 == 0 ? blocks : blocks + 1;
kQuantize<<<blocks, 1024>>>(code, A, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void dequantize(float *code, unsigned char *A, float *out, int n)
{
void dequantize(float *code, unsigned char *A, float *out, int n) {
int blocks = n / 1024;
blocks = n % 1024 == 0 ? blocks : blocks + 1;
kDequantize<<<blocks, 1024>>>(code, A, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n)
{
template<typename T, int STOCHASTIC>
void quantizeBlockwise(float *code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n) {
int blocks = n / 4096;
blocks = n % 4096 == 0 ? blocks : blocks + 1;
kQuantizeBlockwise < T, 4096, 4, STOCHASTIC ><<<blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
{
template<typename T>
void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) {
int blocks = n / blocksize;
blocks = n % blocksize == 0 ? blocks : blocks + 1;
if (blocksize == 4096)
@ -178,43 +65,45 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
template<typename T, int OPTIMIZER>
void optimizer32bit(T *g, T *p,
float *state1, float *state2, float *unorm, float max_unorm, float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
{
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) {
int blocks = n / 4096;
blocks = n % 4096 == 0 ? blocks : blocks + 1;
switch(OPTIMIZER)
{
switch (OPTIMIZER) {
case ADAM:
if(max_unorm > 0.0f)
{
if (max_unorm > 0.0f) {
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
kPreconditionOptimizer32bit2State < T, OPTIMIZER, 4096,
8 ><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
kOptimizer32bit2State < T,
OPTIMIZER ><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case MOMENTUM:
case RMSPROP:
case ADAGRAD:
if(max_unorm > 0.0f)
{
if (max_unorm > 0.0f) {
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
kPreconditionOptimizer32bit1State < T, OPTIMIZER, 4096,
8 ><<<blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
kOptimizer32bit1State < T,
OPTIMIZER ><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
}
}
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
template<typename T, int OPTIMIZER>
void optimizerStatic8bit(T *p, T *g,
unsigned char *state1, unsigned char *state2,
float *unorm, float max_unorm, float param_norm,
float beta1, float beta2,
@ -222,21 +111,21 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
float *quantiles1, float *quantiles2,
float *max1, float *max2, float *new_max1, float *new_max2,
float weight_decay,
const float gnorm_scale, int n)
{
const float gnorm_scale, int n) {
int blocks = n / 4096;
blocks = n % 4096 == 0 ? blocks : blocks + 1;
if (max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); }
switch(OPTIMIZER)
{
switch (OPTIMIZER) {
case ADAM:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float)));
CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1 * sizeof(float)));
kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
kPreconditionOptimizerStatic8bit2State < T,
OPTIMIZER ><<<blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
kOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
kOptimizerStatic8bit2State < T,
OPTIMIZER ><<<blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
@ -244,7 +133,8 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
case RMSPROP:
case ADAGRAD:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
kPreconditionOptimizerStatic8bit1State < T,
OPTIMIZER ><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
kOptimizerStatic8bit1State < T, OPTIMIZER ><<<blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
@ -260,18 +150,19 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
#define BLOCKSIZE_1STATE 2048
#define NUM_1STATE 8
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
template<typename T, int OPTIMIZER>
void optimizerStatic8bitBlockwise(T *p, T *g,
unsigned char *state1, unsigned char *state2, float beta1, float beta2, float eps, int step, float lr,
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
{
float *quantiles1, float *quantiles2, float *absmax1, float *absmax2, float weight_decay,
const float gnorm_scale, bool skip_zeros, int n) {
int blocks = 0;
switch(OPTIMIZER)
{
switch (OPTIMIZER) {
case ADAM:
blocks = n / BLOCKSIZE_2STATE;
blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1;
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
kOptimizerStatic8bit2StateBlockwise < T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE ><<<blocks, BLOCKSIZE_2STATE /
NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
@ -280,7 +171,8 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
case ADAGRAD:
blocks = n / BLOCKSIZE_1STATE;
blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
kOptimizerStatic8bit1StateBlockwise < T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE ><<<blocks, BLOCKSIZE_1STATE /
NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
@ -288,9 +180,8 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
}
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n)
{
template<typename T>
void percentileClipping(T *g, float *gnorm_vec, int step, const int n) {
int blocks = n / 2048;
blocks = n % 2048 == 0 ? blocks : blocks + 1;
CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1 * sizeof(float)));
@ -304,13 +195,23 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
//==============================================================
template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n);
template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
template void
quantizeBlockwise<half, 0>(float *code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n);
template void
quantizeBlockwise<float, 0>(float *code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n);
template void
quantizeBlockwise<half, 1>(float *code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n);
template void
quantizeBlockwise<float, 1>(float *code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n);
template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
#define MAKE_optimizer32bit(name, gtype) \
@ -320,12 +221,19 @@ template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
MAKE_optimizer32bit(ADAM, half)
MAKE_optimizer32bit(ADAM, float)
MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float)
MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float)
#define MAKE_optimizerStatic8bit(name, gtype) \
@ -338,11 +246,17 @@ template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char
float weight_decay, \
const float gnorm_scale, int n); \
MAKE_optimizerStatic8bit(ADAM, half)
MAKE_optimizerStatic8bit(ADAM, float)
MAKE_optimizerStatic8bit(MOMENTUM, half)
MAKE_optimizerStatic8bit(MOMENTUM, float)
MAKE_optimizerStatic8bit(RMSPROP, half)
MAKE_optimizerStatic8bit(RMSPROP, float)
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
@ -350,14 +264,23 @@ template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \
MAKE_optimizerStatic8bitBlockwise(half, ADAM);
MAKE_optimizerStatic8bitBlockwise(float, ADAM);
MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
template void percentileClipping(float *g, float *gnorm_vec, int step, const int n);
template void percentileClipping(half *g, float *gnorm_vec, int step, const int n);

View File

@ -68,16 +68,6 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n);
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n);
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);
#endif

View File

@ -3,7 +3,10 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#if BUILD_CUDA
#include <ops.cuh>
#endif
#include <cpu_ops.h>
// We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary.
// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to
@ -12,6 +15,7 @@
// UNMANGLED CALLS
//===================================================================================
#if BUILD_CUDA
void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles<float>(A, code, offset, n); }
void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }
@ -78,9 +82,11 @@ void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, un
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
#endif
extern "C"
{
#if BUILD_CUDA
void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); }
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
@ -147,11 +153,11 @@ extern "C"
void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
#endif
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
}

View File

@ -6,17 +6,18 @@ import os
from setuptools import setup, find_packages
def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read()
version = os.getenv("CUDA_VERSION", "cpu")
setup(
name = f"bitsandbytes-cuda{os.environ['CUDA_VERSION']}",
version = "0.26.0",
name="bitsandbytes",
version=f"0.26.0+{version}",
author="Tim Dettmers",
author_email="dettmers@cs.washington.edu",
description = ("8-bit optimizers and quantization routines."),
description="8-bit optimizers and quantization routines.",
license="MIT",
keywords="gpu optimizers optimization 8-bit quantization compression",
url="http://packages.python.org/bitsandbytes",
@ -29,4 +30,3 @@ setup(
'Topic :: Scientific/Engineering :: Artificial Intelligence'
],
)