Add a CPU-only build option

This commit is contained in:
Max Ryabinin 2022-07-01 17:16:10 +03:00
parent 33efe4a09f
commit 8258b4364a
14 changed files with 454 additions and 382 deletions

View File

@ -10,10 +10,10 @@ NVCC := $(CUDA_HOME)/bin/nvcc
########################################### ###########################################
CSRC := $(ROOT_DIR)/csrc CSRC := $(ROOT_DIR)/csrc
BUILD_DIR:= $(ROOT_DIR)/cuda_build BUILD_DIR:= $(ROOT_DIR)/build
FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu
FILES_CPP := $(CSRC)/pythonInterface.c FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c
INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
@ -46,27 +46,30 @@ CC_CUDA11x += -gencode arch=compute_86,code=sm_86
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB) $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB) $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
cuda10x: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env cuda10x: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB) $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
cuda110: $(BUILD_DIR) env cuda110: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB) $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
cuda11x: $(BUILD_DIR) env cuda11x: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB) $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
cpuonly: $(BUILD_DIR) env
$(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so
env: env:
@echo "ENVIRONMENT" @echo "ENVIRONMENT"
@ -80,7 +83,7 @@ env:
@echo "============================" @echo "============================"
$(BUILD_DIR): $(BUILD_DIR):
mkdir -p cuda_build mkdir -p build
mkdir -p dependencies mkdir -p dependencies
$(ROOT_DIR)/dependencies/cub: $(ROOT_DIR)/dependencies/cub:

View File

@ -2,8 +2,13 @@
# #
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from .optim import adam
from .nn import modules from .nn import modules
from cextension import COMPILED_WITH_CUDA
if COMPILED_WITH_CUDA:
from .optim import adam
__pdoc__ = {'libBitsNBytes': False, __pdoc__ = {'libBitsNBytes': False,
'optim.optimizer.Optimizer8bit': False, 'optim.optimizer.Optimizer8bit': False,
'optim.optimizer.MockArgs': False 'optim.optimizer.MockArgs': False

View File

@ -0,0 +1,13 @@
import ctypes as ct
import os
from warnings import warn
lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
try:
lib.cadam32bit_g32
COMPILED_WITH_CUDA = True
except AttributeError:
warn("The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers and GPU quantization are unavailable.")
COMPILED_WITH_CUDA = False

View File

@ -3,18 +3,18 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import ctypes as ct import ctypes as ct
import os
import random import random
from typing import Tuple from typing import Tuple
import torch import torch
from torch import Tensor from torch import Tensor
lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so') from cextension import lib, COMPILED_WITH_CUDA
name2qmap = {} name2qmap = {}
if COMPILED_WITH_CUDA:
''' C FUNCTIONS FOR OPTIMIZERS ''' ''' C FUNCTIONS FOR OPTIMIZERS '''
str2optimizer32bit = {} str2optimizer32bit = {}
str2optimizer32bit['adam'] = (lib.cadam32bit_g32, lib.cadam32bit_g16) str2optimizer32bit['adam'] = (lib.cadam32bit_g32, lib.cadam32bit_g16)
str2optimizer32bit['momentum'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16) str2optimizer32bit['momentum'] = (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16)
@ -138,7 +138,7 @@ def estimate_quantiles(A: Tensor, out: Tensor=None, offset: float=1/512) -> Tens
elif A.dtype == torch.float16: elif A.dtype == torch.float16:
lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
else: else:
raise NotImplementError(f'Not supported data type {A.dtype}') raise NotImplementedError(f'Not supported data type {A.dtype}')
return out return out
def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=None, out: Tensor=None) -> Tensor: def quantize_blockwise(A: Tensor, code: Tensor=None, absmax: Tensor=None, rand=None, out: Tensor=None) -> Tensor:
@ -384,7 +384,7 @@ def optimizer_update_32bit(optimizer_name:str, g: Tensor, p: Tensor, state1: Ten
param_norm = torch.norm(p.data.float()) param_norm = torch.norm(p.data.float())
if optimizer_name not in str2optimizer32bit: if optimizer_name not in str2optimizer32bit:
raise NotImplementError(f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}') raise NotImplementedError(f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}')
if g.dtype == torch.float32 and state1.dtype == torch.float32: if g.dtype == torch.float32 and state1.dtype == torch.float32:
str2optimizer32bit[optimizer_name][0](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm), str2optimizer32bit[optimizer_name][0](get_ptr(g), get_ptr(p), get_ptr(state1), get_ptr(state2), get_ptr(unorm_vec), ct.c_float(max_unorm),

View File

@ -2,6 +2,10 @@
# #
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from bitsandbytes.cextension import COMPILED_WITH_CUDA
if COMPILED_WITH_CUDA:
from .adam import Adam, Adam8bit, Adam32bit from .adam import Adam, Adam8bit, Adam32bit
from .adamw import AdamW, AdamW8bit, AdamW32bit from .adamw import AdamW, AdamW8bit, AdamW32bit
from .sgd import SGD, SGD8bit, SGD32bit from .sgd import SGD, SGD8bit, SGD32bit

View File

@ -31,6 +31,6 @@ class RMSprop32bit(Optimizer1State):
if alpha == 0: if alpha == 0:
raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!') raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!')
if centered: if centered:
raise NotImplementError(f'Centered RMSprop is not supported!') raise NotImplementedError(f'Centered RMSprop is not supported!')
super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps, super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)

39
csrc/common.cpp Normal file
View File

@ -0,0 +1,39 @@
#include <common.h>
#include <float.h>
void *quantize_block(void *arguments) {
// 1. find absmax in block
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value
// 4. check minimal distance
// 5. store index
struct quantize_block_args *args = (quantize_block_args *) arguments;
// 1. find absmax in block
float absmax_block = -FLT_MAX;
for (int i = args->block_idx; i < args->block_end; i++)
absmax_block = fmax(absmax_block, fabs(args->A[i]));
args->absmax[args->block_idx / BLOCK_SIZE] = absmax_block;
for (int i = args->block_idx; i < args->block_end; i++) {
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value
float normed_value = args->A[i] / absmax_block;
int idx = args->bin_searcher->scalar(normed_value);
// 4. check minimal distance
// The binary search returns always the value to the left, which might not be the closest value
if (idx < 255) {
float dist_left = fabs(normed_value - (args->code[idx]));
float dist_right = fabs(normed_value - (args->code[idx + 1]));
if (dist_right < dist_left) { idx += 1; }
}
// 5. store index
args->out[i] = (unsigned char) idx;
}
return NULL;
}

23
csrc/common.h Normal file
View File

@ -0,0 +1,23 @@
#include <BinSearch.h>
#ifndef common
#define common
using namespace BinSearch;
struct quantize_block_args {
BinAlgo<Scalar, float, Direct2> *bin_searcher;
float *code;
float *A;
float *absmax;
unsigned char *out;
int block_end;
int block_idx;
int threadidx;
};
#define BLOCK_SIZE 4096
void *quantize_block(void *arguments);
#endif

57
csrc/cpu_ops.cpp Normal file
View File

@ -0,0 +1,57 @@
#include <BinSearch.h>
#include <pthread.h>
#include <common.h>
using namespace BinSearch;
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n) {
for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) {
int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
int block_end = block_idx + valid_items;
for (int i = block_idx; i < block_end; i++)
out[i] = code[A[i]] * absmax[block_idx / BLOCK_SIZE];
}
}
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n) {
// the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
code[0] = -1.0f;
int num_blocks = n / BLOCK_SIZE;
num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1;
pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * num_blocks);
struct quantize_block_args **args = (quantize_block_args **) malloc(num_blocks * sizeof(quantize_block_args *));
for (int i = 0; i < num_blocks; i++)
args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args));
const uint32 elements_code = 256;
BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
for (int block_idx = 0; block_idx < n; block_idx += BLOCK_SIZE) {
int valid_items = n - block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
int block_end = block_idx + valid_items;
struct quantize_block_args *arg = args[block_idx / BLOCK_SIZE];
arg->bin_searcher = &bin_searcher;
arg->code = code;
arg->A = A;
arg->absmax = absmax;
arg->out = out;
arg->block_end = block_end;
arg->block_idx = block_idx;
arg->threadidx = block_idx / BLOCK_SIZE;
pthread_create(&threads[block_idx / BLOCK_SIZE], NULL, &quantize_block, (void *) arg);
}
for (int i = 0; i < num_blocks; i++)
int err = pthread_join(threads[i], NULL);
free(threads);
for (int i = 0; i < num_blocks; i++)
free(args[i]);
free(args);
}

9
csrc/cpu_ops.h Normal file
View File

@ -0,0 +1,9 @@
#ifndef BITSANDBYTES_CPU_OPS_H
#define BITSANDBYTES_CPU_OPS_H
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n);
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n);
#endif

View File

@ -8,125 +8,14 @@
#include <cub/device/device_scan.cuh> #include <cub/device/device_scan.cuh>
#include <limits> #include <limits>
#include <BinSearch.h> #include <BinSearch.h>
#include <common.h>
using namespace BinSearch; using namespace BinSearch;
using std::cout; using std::cout;
using std::endl; using std::endl;
#define BLOCK_SIZE 4096 void histogramScatterAdd2D(float *histogram, int *index1, int *index2, float *src, int maxidx1, int n) {
struct quantize_block_args
{
BinAlgo<Scalar, float, Direct2> *bin_searcher;
float *code;
float *A;
float *absmax;
unsigned char *out;
int block_end;
int block_idx;
int threadidx;
};
void *quantize_block(void *arguments)
{
// 1. find absmax in block
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value
// 4. check minimal distance
// 5. store index
struct quantize_block_args *args = (quantize_block_args*)arguments;
// 1. find absmax in block
float absmax_block = -FLT_MAX;
for (int i = args->block_idx; i < args->block_end; i++)
absmax_block = fmax(absmax_block, fabs(args->A[i]));
args->absmax[args->block_idx/BLOCK_SIZE] = absmax_block;
for (int i = args->block_idx; i < args->block_end; i++)
{
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
// 3. do binary search to find the closest value
float normed_value = args->A[i]/absmax_block;
int idx = args->bin_searcher->scalar(normed_value);
// 4. check minimal distance
// The binary search returns always the value to the left, which might not be the closest value
if(idx < 255)
{
float dist_left = fabs(normed_value-(args->code[idx]));
float dist_right = fabs(normed_value-(args->code[idx+1]));
if(dist_right < dist_left){ idx+=1; }
}
// 5. store index
args->out[i] = (unsigned char)idx;
}
return NULL;
}
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n)
{
// the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
code[0] = -1.0f;
int num_blocks = n/BLOCK_SIZE;
num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1;
pthread_t *threads = (pthread_t*)malloc(sizeof(pthread_t)*num_blocks);
struct quantize_block_args **args = (quantize_block_args**)malloc(num_blocks*sizeof(quantize_block_args*));
for(int i = 0; i < num_blocks; i++)
args[i] = (quantize_block_args*)malloc(sizeof(quantize_block_args));
const uint32 elements_code = 256;
BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
for(int block_idx = 0; block_idx < n; block_idx+=BLOCK_SIZE)
{
int valid_items = n-block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
int block_end = block_idx + valid_items;
struct quantize_block_args *arg = args[block_idx/BLOCK_SIZE];
arg->bin_searcher = &bin_searcher;
arg->code = code;
arg->A = A;
arg->absmax = absmax;
arg->out = out;
arg->block_end = block_end;
arg->block_idx = block_idx;
arg->threadidx = block_idx/BLOCK_SIZE;
pthread_create(&threads[block_idx/BLOCK_SIZE], NULL, &quantize_block, (void *)arg);
}
for(int i = 0; i < num_blocks; i++)
int err = pthread_join(threads[i], NULL);
free(threads);
for(int i = 0; i < num_blocks; i++)
free(args[i]);
free(args);
}
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n)
{
for(int block_idx = 0; block_idx < n; block_idx+=BLOCK_SIZE)
{
int valid_items = n-block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
int block_end = block_idx + valid_items;
for (int i = block_idx; i < block_end; i++)
out[i] = code[A[i]]*absmax[block_idx/BLOCK_SIZE];
}
}
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n)
{
int threads = 512; int threads = 512;
int blocks = n / threads; int blocks = n / threads;
blocks = n % threads == 0 ? blocks : blocks + 1; blocks = n % threads == 0 ? blocks : blocks + 1;
@ -134,8 +23,8 @@ void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *sr
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n) template<typename T>
{ void estimateQuantiles(T *A, float *code, float offset, int n) {
int blocks = n / 4096; int blocks = n / 4096;
blocks = n % 4096 == 0 ? blocks : blocks + 1; blocks = n % 4096 == 0 ? blocks : blocks + 1;
CUDA_CHECK_RETURN(cudaMemset(code, 0, 256 * sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(code, 0, 256 * sizeof(float)));
@ -143,32 +32,30 @@ template <typename T> void estimateQuantiles(T *A, float *code, float offset, in
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
void quantize(float *code, float *A, unsigned char *out, int n) void quantize(float *code, float *A, unsigned char *out, int n) {
{
int blocks = n / 1024; int blocks = n / 1024;
blocks = n % 1024 == 0 ? blocks : blocks + 1; blocks = n % 1024 == 0 ? blocks : blocks + 1;
kQuantize<<<blocks, 1024>>>(code, A, out, n); kQuantize<<<blocks, 1024>>>(code, A, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
void dequantize(float *code, unsigned char *A, float *out, int n) void dequantize(float *code, unsigned char *A, float *out, int n) {
{
int blocks = n / 1024; int blocks = n / 1024;
blocks = n % 1024 == 0 ? blocks : blocks + 1; blocks = n % 1024 == 0 ? blocks : blocks + 1;
kDequantize<<<blocks, 1024>>>(code, A, out, n); kDequantize<<<blocks, 1024>>>(code, A, out, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n) template<typename T, int STOCHASTIC>
{ void quantizeBlockwise(float *code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n) {
int blocks = n / 4096; int blocks = n / 4096;
blocks = n % 4096 == 0 ? blocks : blocks + 1; blocks = n % 4096 == 0 ? blocks : blocks + 1;
kQuantizeBlockwise < T, 4096, 4, STOCHASTIC ><<<blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n); kQuantizeBlockwise < T, 4096, 4, STOCHASTIC ><<<blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) template<typename T>
{ void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) {
int blocks = n / blocksize; int blocks = n / blocksize;
blocks = n % blocksize == 0 ? blocks : blocks + 1; blocks = n % blocksize == 0 ? blocks : blocks + 1;
if (blocksize == 4096) if (blocksize == 4096)
@ -178,43 +65,45 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, template<typename T, int OPTIMIZER>
void optimizer32bit(T *g, T *p,
float *state1, float *state2, float *unorm, float max_unorm, float param_norm, float *state1, float *state2, float *unorm, float max_unorm, float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay, const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) {
{
int blocks = n / 4096; int blocks = n / 4096;
blocks = n % 4096 == 0 ? blocks : blocks + 1; blocks = n % 4096 == 0 ? blocks : blocks + 1;
switch(OPTIMIZER) switch (OPTIMIZER) {
{
case ADAM: case ADAM:
if(max_unorm > 0.0f) if (max_unorm > 0.0f) {
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); kPreconditionOptimizer32bit2State < T, OPTIMIZER, 4096,
8 ><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); kOptimizer32bit2State < T,
OPTIMIZER ><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
case MOMENTUM: case MOMENTUM:
case RMSPROP: case RMSPROP:
case ADAGRAD: case ADAGRAD:
if(max_unorm > 0.0f) if (max_unorm > 0.0f) {
{
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float)));
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n); kPreconditionOptimizer32bit1State < T, OPTIMIZER, 4096,
8 ><<<blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); kOptimizer32bit1State < T,
OPTIMIZER ><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
} }
} }
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, template<typename T, int OPTIMIZER>
void optimizerStatic8bit(T *p, T *g,
unsigned char *state1, unsigned char *state2, unsigned char *state1, unsigned char *state2,
float *unorm, float max_unorm, float param_norm, float *unorm, float max_unorm, float param_norm,
float beta1, float beta2, float beta1, float beta2,
@ -222,21 +111,21 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
float *quantiles1, float *quantiles2, float *quantiles1, float *quantiles2,
float *max1, float *max2, float *new_max1, float *new_max2, float *max1, float *max2, float *new_max1, float *new_max2,
float weight_decay, float weight_decay,
const float gnorm_scale, int n) const float gnorm_scale, int n) {
{
int blocks = n / 4096; int blocks = n / 4096;
blocks = n % 4096 == 0 ? blocks : blocks + 1; blocks = n % 4096 == 0 ? blocks : blocks + 1;
if (max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); } if (max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1 * sizeof(float))); }
switch(OPTIMIZER) switch (OPTIMIZER) {
{
case ADAM: case ADAM:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float)));
CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1 * sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1 * sizeof(float)));
kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); kPreconditionOptimizerStatic8bit2State < T,
OPTIMIZER ><<<blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
kOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, kOptimizerStatic8bit2State < T,
OPTIMIZER ><<<blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
@ -244,7 +133,8 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
case RMSPROP: case RMSPROP:
case ADAGRAD: case ADAGRAD:
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1 * sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); kPreconditionOptimizerStatic8bit1State < T,
OPTIMIZER ><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
kOptimizerStatic8bit1State < T, OPTIMIZER ><<<blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr, kOptimizerStatic8bit1State < T, OPTIMIZER ><<<blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
@ -260,18 +150,19 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
#define BLOCKSIZE_1STATE 2048 #define BLOCKSIZE_1STATE 2048
#define NUM_1STATE 8 #define NUM_1STATE 8
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g, template<typename T, int OPTIMIZER>
void optimizerStatic8bitBlockwise(T *p, T *g,
unsigned char *state1, unsigned char *state2, float beta1, float beta2, float eps, int step, float lr, unsigned char *state1, unsigned char *state2, float beta1, float beta2, float eps, int step, float lr,
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) float *quantiles1, float *quantiles2, float *absmax1, float *absmax2, float weight_decay,
{ const float gnorm_scale, bool skip_zeros, int n) {
int blocks = 0; int blocks = 0;
switch(OPTIMIZER) switch (OPTIMIZER) {
{
case ADAM: case ADAM:
blocks = n / BLOCKSIZE_2STATE; blocks = n / BLOCKSIZE_2STATE;
blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1; blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1;
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr, kOptimizerStatic8bit2StateBlockwise < T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE ><<<blocks, BLOCKSIZE_2STATE /
NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
@ -280,7 +171,8 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
case ADAGRAD: case ADAGRAD:
blocks = n / BLOCKSIZE_1STATE; blocks = n / BLOCKSIZE_1STATE;
blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1; blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr, kOptimizerStatic8bit1StateBlockwise < T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE ><<<blocks, BLOCKSIZE_1STATE /
NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
break; break;
@ -288,9 +180,8 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
} }
template<typename T>
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n) void percentileClipping(T *g, float *gnorm_vec, int step, const int n) {
{
int blocks = n / 2048; int blocks = n / 2048;
blocks = n % 2048 == 0 ? blocks : blocks + 1; blocks = n % 2048 == 0 ? blocks : blocks + 1;
CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1 * sizeof(float))); CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1 * sizeof(float)));
@ -304,13 +195,23 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
//============================================================== //==============================================================
template void estimateQuantiles(half *A, float *code, float offset, int n); template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n); template void estimateQuantiles(float *A, float *code, float offset, int n);
template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); template void
template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); quantizeBlockwise<half, 0>(float *code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n);
template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n); template void
quantizeBlockwise<float, 0>(float *code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n);
template void
quantizeBlockwise<half, 1>(float *code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n);
template void
quantizeBlockwise<float, 1>(float *code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n);
template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
#define MAKE_optimizer32bit(name, gtype) \ #define MAKE_optimizer32bit(name, gtype) \
@ -320,12 +221,19 @@ template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
MAKE_optimizer32bit(ADAM, half) MAKE_optimizer32bit(ADAM, half)
MAKE_optimizer32bit(ADAM, float) MAKE_optimizer32bit(ADAM, float)
MAKE_optimizer32bit(MOMENTUM, half) MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float) MAKE_optimizer32bit(RMSPROP, float)
MAKE_optimizer32bit(ADAGRAD, half) MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float) MAKE_optimizer32bit(ADAGRAD, float)
#define MAKE_optimizerStatic8bit(name, gtype) \ #define MAKE_optimizerStatic8bit(name, gtype) \
@ -338,11 +246,17 @@ template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char
float weight_decay, \ float weight_decay, \
const float gnorm_scale, int n); \ const float gnorm_scale, int n); \
MAKE_optimizerStatic8bit(ADAM, half) MAKE_optimizerStatic8bit(ADAM, half)
MAKE_optimizerStatic8bit(ADAM, float) MAKE_optimizerStatic8bit(ADAM, float)
MAKE_optimizerStatic8bit(MOMENTUM, half) MAKE_optimizerStatic8bit(MOMENTUM, half)
MAKE_optimizerStatic8bit(MOMENTUM, float) MAKE_optimizerStatic8bit(MOMENTUM, float)
MAKE_optimizerStatic8bit(RMSPROP, half) MAKE_optimizerStatic8bit(RMSPROP, half)
MAKE_optimizerStatic8bit(RMSPROP, float) MAKE_optimizerStatic8bit(RMSPROP, float)
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
@ -350,14 +264,23 @@ template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \
MAKE_optimizerStatic8bitBlockwise(half, ADAM); MAKE_optimizerStatic8bitBlockwise(half, ADAM);
MAKE_optimizerStatic8bitBlockwise(float, ADAM); MAKE_optimizerStatic8bitBlockwise(float, ADAM);
MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
template void percentileClipping(float *g, float *gnorm_vec, int step, const int n); template void percentileClipping(float *g, float *gnorm_vec, int step, const int n);
template void percentileClipping(half *g, float *gnorm_vec, int step, const int n); template void percentileClipping(half *g, float *gnorm_vec, int step, const int n);

View File

@ -68,16 +68,6 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n); template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n);
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n);
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);
#endif #endif

View File

@ -3,7 +3,10 @@
// This source code is licensed under the MIT license found in the // This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree. // LICENSE file in the root directory of this source tree.
#if BUILD_CUDA
#include <ops.cuh> #include <ops.cuh>
#endif
#include <cpu_ops.h>
// We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary. // We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary.
// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to // We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to
@ -12,6 +15,7 @@
// UNMANGLED CALLS // UNMANGLED CALLS
//=================================================================================== //===================================================================================
#if BUILD_CUDA
void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles<float>(A, code, offset, n); } void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles<float>(A, code, offset, n); }
void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); } void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }
@ -78,9 +82,11 @@ void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, un
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \ void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); } void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
#endif
extern "C" extern "C"
{ {
#if BUILD_CUDA
void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); } void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); }
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
@ -147,11 +153,11 @@ extern "C"
void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
#endif
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); } void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
} }

View File

@ -6,17 +6,18 @@ import os
from setuptools import setup, find_packages from setuptools import setup, find_packages
def read(fname): def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read() return open(os.path.join(os.path.dirname(__file__), fname)).read()
version = os.getenv("CUDA_VERSION", "cpu")
setup( setup(
name = f"bitsandbytes-cuda{os.environ['CUDA_VERSION']}", name="bitsandbytes",
version = "0.26.0", version=f"0.26.0+{version}",
author="Tim Dettmers", author="Tim Dettmers",
author_email="dettmers@cs.washington.edu", author_email="dettmers@cs.washington.edu",
description = ("8-bit optimizers and quantization routines."), description="8-bit optimizers and quantization routines.",
license="MIT", license="MIT",
keywords="gpu optimizers optimization 8-bit quantization compression", keywords="gpu optimizers optimization 8-bit quantization compression",
url="http://packages.python.org/bitsandbytes", url="http://packages.python.org/bitsandbytes",
@ -29,4 +30,3 @@ setup(
'Topic :: Scientific/Engineering :: Artificial Intelligence' 'Topic :: Scientific/Engineering :: Artificial Intelligence'
], ],
) )