Most tests passing.

This commit is contained in:
Tim Dettmers 2022-07-22 14:41:05 -07:00
parent 4cd7ea62b2
commit c771b3a75a
16 changed files with 5270 additions and 160 deletions

View File

@ -4,12 +4,13 @@
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from .nn import modules from .nn import modules
from .autograd._functions import mm_cublas, bmm_cublas, matmul_cublas, matmul, MatmulLtState
from .cextension import COMPILED_WITH_CUDA from .cextension import COMPILED_WITH_CUDA
if COMPILED_WITH_CUDA: if COMPILED_WITH_CUDA:
from .optim import adam from .optim import adam
__pdoc__ = {'libBitsNBytes': False, __pdoc__ = {'libbitsandbytes': False,
'optim.optimizer.Optimizer8bit': False, 'optim.optimizer.Optimizer8bit': False,
'optim.optimizer.MockArgs': False 'optim.optimizer.MockArgs': False
} }

View File

View File

@ -0,0 +1,307 @@
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F
from dataclasses import dataclass
tensor = torch.Tensor
'''
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
'''
class GlobalOutlierPooler(object):
_instance = None
def __init__(self):
raise RuntimeError('Call get_instance() instead')
def initialize(self):
self.outliers = set()
self.model_dim = None
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls.__new__(cls)
cls._instance.initialize()
return cls._instance
def add_outliers(self, outlier_idx, feature_dim):
if self.model_dim is None: self.model_dim = feature_dim
if feature_dim != self.model_dim: return # we do not encode outliers for the 2nd FFN layer
self.outliers.update(outlier_idx.tolist())
def get_current_outlier_idx(self):
return torch.Tensor(list(self.outliers)).to(torch.int64)
class MatMul8bit(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]):
if precision[0] != 8:
with torch.no_grad():
output = torch.matmul(A, B)
else:
if len(B.shape) == 2: dim = 0
else: dim = 1
qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
iout = F.igemm(qA, qB)
output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type)
if A.requires_grad or B.requires_grad:
ctx.save_for_backward(A, B)
ctx.quant_type = quant_type
ctx.precision = precision
return output
@staticmethod
def backward(ctx, grad_output):
A, B = ctx.saved_tensors
quant_type = ctx.quant_type
precision = ctx.precision
grad_A = grad_B = None
if B.requires_grad:
if len(A.shape) == 3:
dims = [0, 1]
# bsi -> ibs
permute_dim = [0, 2, 1]
else:
dims = [0]
# bs -> sb
permute_dim = [1, 0]
if precision[1] != 8:
with torch.no_grad():
grad_B = torch.matmul(A.permute(permute_dim), grad_output)
else:
if len(B.shape) == 2 and len(A.shape) == 3:
grad_output = grad_output.contiguous()
if not grad_output.is_contiguous(): grad_output.contiguous()
qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type)
if not A.is_contiguous(): A = A.contiguous()
qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
igrad_B = F.igemm(qA.t(), qgrad_output)
grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
else:
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type)
if A.requires_grad:
if len(grad_output.shape) == 3: dims = [2]
else: dims = [1]
if len(B.shape) == 3:
# bio -> boi
permute_dim = [0, 2, 1]
dim_B = dims
else:
# io -> oi
permute_dim = [1, 0]
dim_B = [1]
if precision[2] != 8:
with torch.no_grad():
grad_A = torch.matmul(grad_output, B.permute(permute_dim))
else:
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type)
return grad_A, grad_B, None, None, None
mm_cublas = MatMul8bit.apply
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply
@dataclass
class MatmulLtState:
CB = None
CxB = None
SB = None
SCB = None
CxBt = None
SBt = None
CBt = None
subB = None
outlier_pool = None
has_accumulated_gradients = False
threshold = 0.0
idx = None
is_training = True
has_fp16_weights = True
use_pool = False
formatB = F.get_special_format_str()
def reset_grads(self):
self.CB = None
self.CxB = None
self.SB = None
self.SCB = None
self.CxBt = None
self.SBt = None
self.CBt = None
class MatMul8bitLt(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, state=MatmulLtState()):
# 1. Quantize A
# 2. Quantize B
# 3. Matmul
# 4. Mixed-precision decomposition matmul
# 5. Save state
requires_gradA = A.requires_grad
requires_gradB = B.requires_grad
formatB = state.formatB
input_shape = A.shape
if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance()
assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!'
# 1. Quantize A
if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)
if state.threshold > 0.0 and coo_tensorA is not None:
if state.has_fp16_weights:
idx = torch.unique(coo_tensorA.colidx).long()
CA[:, idx] = 0
CAt[:, idx] = 0
subA = A[:, idx]
state.subB = B[:, idx].t().contiguous()
state.idx = idx
else:
if state.CxB is None:
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
# generate outlier index and subB
outlier_idx = torch.unique(coo_tensorA.colidx).long()
state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# do not use pool for 2nd FFN layer
state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
else:
state.idx = outlier_idx
state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()
if state.idx is not None:
# extract outliers
CA[:, state.idx] = 0
CAt[:, state.idx] = 0
subA = A[:, state.idx]
else:
subA = None
else:
if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None
C32A, SA = F.transform(CA, 'col32')
# 2. Quantize B
if state.has_fp16_weights:
has_grad = (True if (getattr(B, 'grad', None) is not None) else False)
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed: B = B.contiguous()
if (state.is_training and not has_grad) or state.CxB is None:
state.reset_grads()
CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B)
state.CxB, state.SB = F.transform(CB, to_order=formatB)
else:
has_grad = False
shapeB = state.SB[0]
if len(input_shape) == 3:
output_shape = (input_shape[0], input_shape[1], shapeB[0])
else:
output_shape = (input_shape[0], shapeB[0])
# 3. Matmul
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
output = F.mm_dequant(out32, Sout32, SCA, state.SCB)
# 4. Mixed-precision decomposition matmul
if state.threshold > 0.0 and coo_tensorA is not None and subA is not None:
output += torch.matmul(subA, state.subB)
# 5. Save state
ctx.state = state
ctx.formatB = formatB
ctx.grad_shape = input_shape
ctx.req_grads = [requires_gradA, requires_gradB]
if requires_gradA or requires_gradB:
ctx.tensors = (CAt, subA)
ctx.tensor_states = (SCAt, state.idx)
else:
ctx.tensors = [None, None]
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)
#clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
clone_func = torch.clone
return clone_func(output.view(output_shape))
@staticmethod
def backward(ctx, grad_output):
req_gradA, req_gradB = ctx.req_grads
CAt, subA = ctx.tensors
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.'
if len(grad_output.shape) == 3:
grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
grad_A = grad_B = None
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True)
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
if req_gradA:
C32grad, Sgrad = F.transform(Cgrad, 'col32')
if state.CxBt is None:
state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
return grad_A, grad_B, None, None, None, None, None
matmul = MatMul8bitLt.apply
def matmul(A : tensor, B : tensor, out : tensor=None, state : MatmulLtState = None, threshold=0.0):
state = state or MatmulLtState()
if threshold > 0.0:
state.threshold = threshold
return MatMul8bitLt.apply(A, B, out, state)

View File

@ -6,6 +6,8 @@ lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
try: try:
lib.cadam32bit_g32 lib.cadam32bit_g32
lib.get_context.restype = ct.c_void_p
lib.get_cusparse.restype = ct.c_void_p
COMPILED_WITH_CUDA = True COMPILED_WITH_CUDA = True
except AttributeError: except AttributeError:
warn("The installed version of bitsandbytes was compiled without GPU support. " warn("The installed version of bitsandbytes was compiled without GPU support. "

File diff suppressed because one or more lines are too long

View File

@ -2,4 +2,4 @@
# #
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from .modules import StableEmbedding, Embedding from .modules import StableEmbedding, Linear8bit, Linear8bitLt, Int8Params

View File

@ -3,14 +3,19 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch import torch
import bitsandbytes as bnb
from typing import Optional from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict
from torch import Tensor from torch import Tensor, device, dtype
from torch import nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F import torch.nn.functional as F
from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.optim import GlobalOptimManager
T = TypeVar('T', bound='torch.nn.Module')
class StableEmbedding(torch.nn.Embedding): class StableEmbedding(torch.nn.Embedding):
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
@ -70,3 +75,118 @@ class Embedding(torch.nn.Embedding):
self.norm_type, self.scale_grad_by_freq, self.sparse) self.norm_type, self.scale_grad_by_freq, self.sparse)
return emb return emb
class Int8Params(torch.nn.Parameter):
def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None):
cls.has_fp16_weights = has_fp16_weights
cls.CB = None
cls.SCB = None
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, requires_grad)
def cuda(self, device):
if self.has_fp16_weights:
return super().cuda(device)
else:
# we store the 8-bit rows-major weight
# we convert this weight to the turning/ampere weight during the first inference pass
B = self.data.contiguous().half().cuda(device)
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
del CBt
del SCBt
self.data = CB
setattr(self, 'CB', CB)
setattr(self, 'SCB', SCB)
return self
@overload
def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ...,
non_blocking: bool = ...) -> T:
...
@overload
def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
...
@overload
def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
...
def to(self, *args, **kwargs):
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
if device is not None and device.type == 'cuda' and self.data.device.type == 'cpu': return self.cuda(device)
else:
new_param = Int8Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights)
new_param.CB = self.CB
new_param.SCB = self.SCB
return new_param
class Linear8bitLt(nn.Linear):
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, threshold=0.0, index=None):
super(Linear8bitLt, self).__init__(input_features, output_features, bias)
self.state = bnb.MatmulLtState()
self.index=index
self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
def init_8bit_state(self):
self.state.CB = self.weight.CB
self.state.SCB = self.weight.SCB
self.weight.CB = None
self.weight.SCB = None
def forward(self, x):
self.state.is_training = self.training
if self.weight.CB is not None: self.init_8bit_state()
#assert not self.state.has_fp16_weights
#if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
out = bnb.matmul(x, self.weight, state=self.state)
if self.bias is not None:
out += self.bias.unsqueeze(0).expand_as(out)
if not self.state.has_fp16_weights and self.state.CB is not None:
# we converted 8-bit row major to turing/ampere format in the first inference pass
# we no longer need the row-major weight
del self.state.CB
self.weight.data = self.state.CxB
return out
class Linear8bit(nn.Linear):
def __init__(self, input_features, output_features, bias=True, quant_type='vector', index=None, args=None, sparse_decomp=False):
super(Linear8bit, self).__init__(input_features, output_features, bias)
self.quant_type = quant_type
self.index = index
self.args = args
self.iter = 0
def forward(self, x):
self.iter += 1
if self.iter % self.args.clip_freq == 0:
with torch.no_grad():
maxval, maxidx = torch.topk(torch.abs(self.weight.flatten()), k=self.args.clip_idx)
if not dist.is_initialized() or dist.get_rank() == 0:
print('clip', maxval[-1].item())
self.weight.clip_(-maxval[-1], maxval[-1])
if self.args is not None:
out = bnb.nn.functional.sparse_decomposed_linear8bit(x, self.weight, self.bias, qval=self.args.sparse_decomp_val, quant_type=self.args.quant_type)
else:
out = bnb.nn.functional.linear8bit(x, self.weight, self.bias, quant_type=self.args.quant_type)
return out

View File

@ -1737,10 +1737,884 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
} }
} }
template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols)
{
// 0. reset stats to -FLT_MAX
// 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
// 2. compute col max (per thread); store in smem due to register pressure
// 3. compute row max (per block); store in smem to accumulate full global mem transation
// 4. store data via atomicMax
// each block loads TILE_COLs columns and TILE_ROW rows
// after reading a tile the row counter increase by TILE_ROWS
// the col counter reset after reading TILE_COL elements
const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
// col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
const int base_idx = (base_row*cols) + base_col;
const int items_per_load = ITEMS_PER_THREAD*THREADS;
typedef cub::BlockLoad<T, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadT;
typedef cub::BlockReduce<float, THREADS> BlockRowReduce;
typedef cub::BlockReduce<int, THREADS> BlockRowSum;
typedef cub::BlockExchange<float, THREADS, ITEMS_PER_THREAD> BlockExchange;
__shared__ union {
typename BlockExchange::TempStorage exchange;
typename BlockRowReduce::TempStorage rowreduce;
typename BlockRowSum::TempStorage rowsum;
typename LoadT::TempStorage loadt;
} temp_storage;
__shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS];
__shared__ int smem_row_nnz_values[TILE_ROWS];
//__shared__ float smem_col_absmax_values[ITEMS_PER_THREAD*THREADS];
half local_data[ITEMS_PER_THREAD];
float local_data_fp32[ITEMS_PER_THREAD];
float local_col_absmax_values[ITEMS_PER_THREAD];
int local_row_nnz_count = 0;
float row_absmax = -FLT_MAX;
// 0. reset stats to -FLT_MAX
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
//smem_col_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX;
smem_row_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX;
smem_row_nnz_values[threadIdx.x + (j*THREADS)] = 0;
}
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_col_absmax_values[j] = -FLT_MAX;
__syncthreads();
int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;
int i = base_idx;
// we load row after row from the base_position
// 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
for(int row = 0; row < TILE_ROWS; row++)
{
if(base_row+row >= rows){ break; }
local_row_nnz_count = 0;
i = base_idx + ((row)*cols);
// each thread gets data from the same column
__syncthreads();
LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, __float2half(0.0f));
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_data[j] = fabsf(local_data[j]);
if(SPARSE_DECOMP)
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
if((float)local_data[j] >= nnz_threshold)
{
local_row_nnz_count += 1;
local_data[j] = 0.0f;
}
}
// 2. compute col max (per thread); store in smem due to register pressure
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
// take the col max for this row
// we use shared memory because register pressure is too high if we do this locally
//smem_col_absmax_values[threadIdx.x + (j*THREADS)] = fmaxf(smem_col_absmax_values[threadIdx.x + (j*THREADS)], __half2float(local_data[j]));
local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j]));
// 3. compute row max (per block); store in smem to accumulate full global mem transation
__syncthreads();
// this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units)
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_data_fp32[j] = local_data[j];
row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max());
if(SPARSE_DECOMP)
{
__syncthreads();
local_row_nnz_count = BlockRowSum(temp_storage.rowsum).Sum(local_row_nnz_count);
}
// we store the data temporarily in shared memory so we
// can execute a full atomic block transaction into global memory later
// we use a striped arrangement [0, 8, 16, 24, ..] for t0 for faster stores
if(threadIdx.x == 0)
{
smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax;
// each blockIdx.x process 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block
smem_row_nnz_values[row] = local_row_nnz_count;
}
__syncthreads();
}
// 4. store data via atomicMax
// to store col data efficienctly we need to rewrite the smem blocked data [0, 1, 2, 3...] for t0
// into a striped arangement: [0, 8, 16, 24, ..] for t0
__syncthreads();
BlockExchange(temp_storage.exchange).BlockedToStriped(local_col_absmax_values);
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
if(base_col+threadIdx.x+(j*THREADS) < cols)
{
float val = colStats[base_col+(threadIdx.x+(j*THREADS))];
if(val < local_col_absmax_values[j])
atomicMax(&colStats[base_col+(threadIdx.x+(j*THREADS))], local_col_absmax_values[j]);
}
for(int j = 0; j < ITEMS_PER_THREAD; j++)
if(base_row+threadIdx.x+(j*THREADS) < rows)
{
float val = rowStats[base_row+(threadIdx.x+(j*THREADS))];
if(val < smem_row_absmax_values[threadIdx.x+(j*THREADS)])
atomicMax(&rowStats[base_row+(threadIdx.x+(j*THREADS))], smem_row_absmax_values[threadIdx.x+(j*THREADS)]);
}
if(SPARSE_DECOMP)
if(threadIdx.x < TILE_ROWS)
nnz_count_row[blockIdx.x*TILE_ROWS+threadIdx.x+1] = smem_row_nnz_values[threadIdx.x];
}
template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 0>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 1>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)
template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n)
{
// Strategy: To dequantize we need to load col/row statistics. This can be very expensive
// since different row/col stats need to be loaded with each thread.
// (1, bad algorithm) Loading 32 items per thread would only occur 1 row load, but this increases register pressure
// and would lead to low global load utilization.
// (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do lot of row loads
// for each thread and this is duplicated by a factor of 32/num-cols-per-thread.
// (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock.
// This allows for efficient row/col loading from shared memory within the tile.
// We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has
// the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts
// we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the
// shared memory loads.
// data is in 32 column-tile major with tile width 32 columns and numRows rows
// L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
// L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3])
// C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127 into register))
// C2. Compute normalization values and store col values in register
// S1. Store C1 into 16-bit output
// S2. Store col/row statistics of new buffer in shared memory
// We allow for sub-tiles to span multiple col32 tiles. This is okay
// since the items per thread only rely on a single column statistic.
const int n_out = numRows*numCols;
int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1);
// we have tiles of size numRows*32, thus col only increases every numRows
// num_row_tiles is the tiles after which the column increases by 32
// blockIdx.x is the index of the current tile
int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32));
// base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached
int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS);
// SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS
// subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD
// Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROW*32/4 threads.
// For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have
// 1024*1024/(128*32) = 256 tiles
// 256 tiles are 256*128*32/4 = 256*1024 threads
// 1. Figure out how index relates to the start of the sub-tile
// 2. Each thread < SUBTILE_ROWS calculates row index
// 3. Load striped and store in shared memory
int local_values[ITEMS_PER_THREAD];
half local_output[ITEMS_PER_THREAD];
float local_rowStats[ITEMS_PER_THREAD];
__shared__ float smem_rowStats[SUBTILE_ROWS];
typedef cub::BlockLoad<int, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_DIRECT> LoadInt32;
typedef cub::BlockExchange<int, THREADS, ITEMS_PER_THREAD> ExchangeInt32;
__shared__ typename LoadInt32::TempStorage loadint32;
__shared__ typename ExchangeInt32::TempStorage exchangeint32;
// L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
float colStat = col >= numCols ? 0.0f : colStats[col];
// no block loads for rows for now -- keep it simple
for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x)
{
// todo: is this global mem access slow due to overlaps or does the L1 cache work well here?
int row = (base_row+j) % numRows; // wrap around
// each warp accesses the same element, for four consequitive elements
// todo: update description about striped shared memory, it is not needed
// rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements
smem_rowStats[j] = rowStats[row];
}
__syncthreads();
// each block processes SUBTILE_ROWS*32 elements
const int items_per_load = THREADS*ITEMS_PER_THREAD;
const int rows_per_load = items_per_load/32;
int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile
int row_offset = 0;
// subtile_idx starts at the base_row*32 + the total offset for a full numRow*32 tile is passed
int subtile_start = (blockIdx.x/num_row_tiles)*(numRows*32) + (base_row*32);
for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load)
{
int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset);
int valid_items = valid_rows*32;
if(valid_items <= 0) // the sub-tile might have more elements than the tile itself
break;
// L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3])
LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0);
ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values);
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j];
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_output[j] = __float2half(local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat);
//absmax_col = fmax(fabsf(local_output[j]), absmax_col);
// we store data in row major
// to store data efficiently, we want to use block exchange: [0, 32, 64, 92] -> [0, 1, 2, 3]
// so that each thread holds ITEMS_PER_THREAD consecutive items for each row
// this way throughput into storage is increased by a factor of ~2x
// for now we use a simple store
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols);
if(outIdx< n_out && col < numCols)
out[outIdx] = local_output[j];
}
row_offset += rows_per_load;
}
}
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols)
{
// assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD
// Each thread reads the same column but multiple rows
// Rows are loaded in shared memory and access is shared across the threadblock (broadcast)
// 0. Load row stats data into shared memory; load col stat (1 fixed per thread)
// 1. Load data row by row (should be at least with TILE_SIZE = 512)
// 2. quantize data with row/col stats
// 3. Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance)
// each block loads TILE_COLs columns and TILE_ROW rows
// after reading a tile the row counter increase by TILE_ROWS
// the col counter reset after reading TILE_COL elements
const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
// col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
const int base_idx = (base_row*cols) + base_col;
const int items_per_load = ITEMS_PER_THREAD*THREADS;
typedef cub::BlockLoad<half, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadHalf;
__shared__ typename LoadHalf::TempStorage loadhalf;
typedef cub::BlockStore<char, THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_VECTORIZE> StoreInt8;
__shared__ typename StoreInt8::TempStorage storeint8;
__shared__ float smem_row_stats[TILE_ROWS];
__shared__ unsigned int smem_nnz_row_idx[TILE_ROWS];
half local_data[ITEMS_PER_THREAD];
float local_col_stats[ITEMS_PER_THREAD];
char local_quantized_data[ITEMS_PER_THREAD];
// 0. Load row stats data into shared memory; load col stat (1 fixed per thread)
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
if(base_col+(threadIdx.x*ITEMS_PER_THREAD) + j < cols)
local_col_stats[j] = __fdividef(127.0f, colStats[base_col+(threadIdx.x*ITEMS_PER_THREAD)+j]);
for(int i = threadIdx.x; i < TILE_ROWS; i+=blockDim.x)
{
if(base_row + i < rows)
smem_row_stats[i] = rowStats[base_row+i];
if(SPARSE_DECOMP)
smem_nnz_row_idx[i] = nnz_block_ptr[(TILE_ROWS*blockIdx.x) + i];
}
__syncthreads();
// we load row after row from the base_position
// 1. Load data row by row (should be at least with TILE_SIZE = 512)
for(int row = 0; row < TILE_ROWS; row++)
{
if(base_row + row >= rows){ break; }
int i = base_idx + (row*cols);
int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;
LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f);
float row_stat = __fdividef(127.0f, smem_row_stats[row]);
// 2. quantize data with row/col stats
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
// we already pre-normalized the col/row stat:
// what this does is float/absmax*127 = int8
if(SPARSE_DECOMP)
{
if(fabsf((float)local_data[j]) >= threshold)
{
local_quantized_data[j] = 0;
int old_idx = atomicInc(&smem_nnz_row_idx[row], UINT_MAX);
rowidx[old_idx] = base_row+row;
colidx[old_idx] = base_col+(threadIdx.x*ITEMS_PER_THREAD)+j;
val[old_idx] = local_data[j];
}
else
{
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat));
}
}
else
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat));
}
StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items);
// 2. quantize data with row/col stats
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
// we already pre-normalized the col/row stat:
// what this does is float/absmax*127 = int8
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*local_col_stats[j]));
}
__syncthreads();
StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items);
}
}
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols)
{
// 0. Load data into 32*32 shared memory tiles
// 1. transpose / reorder in shared memory
// 2. store
// COL32 FORMAT:
// rows*32 tiles
// TURING FORMAT:
// 8*32 tiles with 4*4 subtiles
// the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
// the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
// the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
// the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
// index increases by 32
// AMPERE FORMAT:
// 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
// row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
// the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
// To have efficient loads and stores if we transpose we need 128 consequitive bytes which at 1 byte are 128 values
// As such we need:
// at least 32*4 shared memory tiles for col32; preferably 32*32
// at least 32*6 shared memory tiles for col32_ampere: preferably 32*32
// at least 32*8 shared memory tiles for col4_turing: preferably 32*32
// for efficient loading of row major we need to load 128 elements and repeat this 32 items
// this would imply a 32x128 shared memory tile -> 4kb
// It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb
// we have 64k sharded mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy
// for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough
// register pressure should be low with: 8 registers from local memoryh per block and 64 registers per SM
//
// to make the shared memory work with that occupancy we might need to union the block loads/stores
// each block loads TILE_COLs columns and TILE_ROW rows
// after reading a tile the row counter increase by TILE_ROWS
// the col counter reset after reading TILE_COL elements
const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
// col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
const int base_idx = (base_row*cols) + base_col;
// we load 128 bytes per warp with
// 32 rows for transposes that fill col32 types
// so that we can have contiguous stores
__shared__ char smem_data[32*33*ITEMS_PER_THREAD];
char local_data[ITEMS_PER_THREAD];
typedef cub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
__shared__ typename BlockExchange::TempStorage temp_storage;
// we load row after row from the base_position
// Load data row by row
int warps = blockDim.x/32;
int warp_id = threadIdx.x/32;
int warp_lane = threadIdx.x % 32;
int offset = 0;
int smem_row = 0;
// each warp loads one row of 128 bytes
for(int row = warp_id; row < TILE_ROWS; row+=warps)
{
int i = base_idx + (row*cols);
// we load up to 128 bytes/items per load
int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col;
// 0. Load data into 32*32 shared memory tiles
if(base_row + row < rows)
{
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
int col_idx = warp_lane+(j*32);
if(col_idx < valid_items)
local_data[j] = A[i+col_idx];
else
local_data[j] = 0;
}
}
else
{
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
local_data[j] = 0;
}
if(TRANSPOSE)
{
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
int local_col = (32*j)+warp_lane;
//int local_row = row;
// store as 256x32
smem_data[(local_col*33) + row] = local_data[j];
}
}
else
{
// treat smem as 32x256, that is 32 rows and 256 columns
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j];
}
smem_row += warps;
// 1. transpose / reorder in shared memory
if(smem_row % 32 == 0)
{
smem_row = 0;
__syncthreads();
for(int subrow = warp_id; subrow < 32; subrow+=warps)
{
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
switch(FORMAT)
{
case COL32:
if(TRANSPOSE)
{
// data lies in shared memory in the following way:
// row0 [col0 col1 ... col31]
// row1 [col0 col1 ... col31]
// ...
//
// As such we read consequtive entries with 256 threads (8rows x 32 columns)
// as j increase, the row increase by a factor of 8
// We load 8 rows per subrow loop, and subrow increase by 8 per loop
// so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8
const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
//const int local_row = warp_id; // each warp_id is one row
//const int block_row = base_col; // block offset for row
//const int local_col = warp_lane
//const int global_col = base_row; // block offset for col
if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
{
// each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem
char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
// each 32 columns we have new tile
// each tile has size outRows*32 and base_row is done in increments of 32
offset = base_row*outRows;
out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data;
}
}
else
{
if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
{
offset = (base_col/32)*(32*rows);
char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data;
}
}
break;
case COL_TURING:
// TURING FORMAT:
// 8*32 tiles with 4*4 subtiles
// the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
// the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
// the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
// the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
// index increases by 32
//
// [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
if(TRANSPOSE)
{
const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
//const int local_row = warp_id; // each warp_id is one row
//const int block_row = base_col; // block offset for row
//const int local_col = warp_lane
//const int global_col = base_row; // block offset for col
if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
{
// each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem
char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
// each 32 columns we have new tile
// each tile has size 8*32 = 256 elements offset
// for each row offset of 8 we increaes the tile first
// after all rows are exhausted, we increase the col
int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows
// we increase by row_tile_column every 32 columns
// base_row increase in increments of 32
//int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements
//int col_offset = (base_row/32)*row_tile_column;
// -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
// 256*outRows/8*base_row/32 = outRows*base_row
int col_offset = outRows*base_row;
offset = row_offset+col_offset;
// since we process even number of rows with each j (8) and with each subrow (8j) we can determine
// odd or even rows with the warp_id (each warp processes one row)
// the col is warp_lane (max 32 columns per row) and the row warp_id
if(warp_id % 2 == 1)
// odd
offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2);
else
// even
offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2);
out[offset] = data;
}
}
else
{
if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
{
char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
// set offset designates the tile offset among the 8*32 tiles
// we first increase rows and then columns. Since we load 128 columns at once
// we increase the offset by outRows*32 every 32 columns
// additionally, we increase the offset by 8*32=256 every 8 rows
offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile)
// first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd
// each of these has 32 values in total for 32*4 = 128 as offset if odd
// every set of 4 columns increases the total offset by 16
// each even row increase the offset by 4, for example row 2 is offset by 4, 4 by 6 etc so: subrow/2*4 = subrow*2
// this happends every 8 rows anew (subrow % 8)
// one writes 4 columns at once that is (col % 4) for the particular index in the subtile
int subcol = warp_lane;
// add local offset (4x4 sub-tile)
if(subrow % 2 == 1)
// odd
offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2);
else
// even
offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2);
out[offset] = data;
}
}
break;
case COL_AMPERE:
// AMPERE FORMAT:
// 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
// row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
// the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
if(TRANSPOSE)
{
const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
//const int local_row = warp_id; // each warp_id is one row
//const int block_row = base_col; // block offset for row
//const int local_col = warp_lane
//const int global_col = base_row; // block offset for col
if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
{
// each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem
char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
// each 32 columns we have new tile
// each tile has size 32*32 = 1024 elements offset
// for each row offset of 32 we increaes the tile first
// after all rows are exhausted, we increase the col
int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows
// we increase by row_tile_column every 32 columns
// base_row increase in increments of 32
//int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements
//int col_offset = (base_row/32)*row_tile_column;
// -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
// 1024*outRows/32*base_row/32 = outRows*base_row
int col_offset = outRows*base_row;
offset = row_offset+col_offset;
// same as in the non-transpose case (see below)
// the difference is that now rows = cols
// in this case warp_id = subrow
// [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
// subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
// subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
// every 2 rows, the offset increases by two [0, 1, 8, 9...]
// every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset
int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2);
// global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane
out[offset + (ampere_row*32) + warp_lane] = data;
}
}
else
{
if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
{
char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
// set offset designates the tile offset among the 32*32 tiles
// we first increase rows and then columns. Since we load 128 columns at once
// we increase the offset by outRows*32 every 32 columns
// additionally, we increase the offset by 32*32=1024 every 32 rows
offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile)
// [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
// subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
// subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
// every 2 rows, the offset increases by two [0, 1, 8, 9...]
// every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2);
// global offset + row with 32 cols each + 32 cols per j + col_idx
out[offset + (local_row*32) + warp_lane] = data;
}
}
break;
}
}
}
}
}
}
#define C 1.0f/127.0f
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256
template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
{
// 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block
// If a block finishes, the next one is scheduled. Since the last blocks like have fewer
// elements they finish faster "fillin up" the gaps left by larger blocks
// without tensor cores
// 1. use rowidx_length to find what to load (as many blocks as there are rows)
// 2. Load A into registers
// 3. each warp loads all required rows of B but each warp is offset by k
// 4. Do mma operations that accumulate into registers
// 5. Each warp stores its output row into matrix C
const int count = max_count[blockIdx.x];
const int local_max_idx = max_idx[blockIdx.x];
const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1];
const int local_row_idx = rowidx[offset];
const int warp_id = threadIdx.x / 32;
const int warp_idx = threadIdx.x % 32;
const int warp_offset = (warp_id*32)*SPMM_ITEMS;
const int num_items = BITS == 8 ? 8 : 8;
int idx_col_B = warp_offset;
int local_idx_col_B_offset = 0;
half local_valA[MAX_SPARSE_COUNT];
int local_colidxA[MAX_SPARSE_COUNT];
half local_valC[SPMM_ITEMS];
T local_valsB[num_items];
half local_valOut[num_items];
// 128 byte loads per warp == 4 bytes per thread
// 2. Load A into registers
for(int j = 0; j < MAX_SPARSE_COUNT; j++)
{
local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f);
local_colidxA[j] = j < count ? colidx[offset+j] : 0;
}
// each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192
// we expect each warp to be SPMM_ITEMS*32 apart
// we have a total of 128 bytes for the bank with a bank size of 4 bytes
// added 3 bytes = 6 values between warps should reduce bank conflicts
__shared__ half smem_dequant_stats[SMEM_SIZE];
while(idx_col_B < colsB)
{
if(dequant_stats != NULL)
{
for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x)
if((idx_col_B+i-local_idx_col_B_offset) < colsB)
smem_dequant_stats[i] = __ldg(&dequant_stats[idx_col_B+i-local_idx_col_B_offset]);
__syncthreads();
}
#pragma unroll SPMM_ITEMS
for(int j = 0; j < SPMM_ITEMS; j++)
local_valC[j] = 0.0f;
#pragma unroll
for(int i = 0; i < count; i++)
{
// 3. each warp loads all required rows of B but each warp is offset by k
int row_offset = colsB*local_colidxA[i];
#pragma unroll SPMM_ITEMS
for(int j = 0; j < SPMM_ITEMS; j+=num_items)
{
// 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached
int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j;
if(idx >= colsB){ break; }
//printf("%i %i\n", (row_offset+idx) % num_items, row_offset+idx);
if((idx+num_items < colsB))
{
if(BITS == 8)
reinterpret_cast<float2(&)[num_items]>(local_valsB)[0] = reinterpret_cast<float2*>(B)[(row_offset+ idx)/num_items];
else
reinterpret_cast<float4(&)[num_items]>(local_valsB)[0] = reinterpret_cast<float4*>(B)[(row_offset+ idx)/num_items];
}
else
{
#pragma unroll num_items
for(int k = 0; k < num_items; k++)
if(idx+k < colsB)
local_valsB[k] = B[row_offset+idx+k];
else
local_valsB[k] = 0.0f;
}
#pragma unroll num_items
for(int k = 0; k < num_items; k++)
{
//if((float)local_valsB[k] != 0.0)
// printf("%f %i %i %i\n", (float)local_valsB[k], k, idx, colsB);
if(BITS == 8 && dequant_stats != NULL)
// we do texture cache reads (__ldg) on dequant_stats which should be super fast
{
float valB = local_valsB[k];
float valA = local_valA[i];
if(valB != 0.0 && valA != 0.0)
local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*C*valB*valA;
}
else
local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i];
}
}
}
int idx_row_C = (colsB*local_row_idx);
#pragma unroll SPMM_ITEMS
for(int j = 0; j < SPMM_ITEMS; j+=num_items)
{
//int idx_col_C = idx_col_B + (32*j) + warp_idx;
int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j;
int idx_val = idx_col_C + idx_row_C;
if(idx_col_C +num_items < colsB)
{
// load outputs to do inplace addition
reinterpret_cast<float4(&)[num_items/4]>(local_valOut)[0] = reinterpret_cast<float4*>(out)[idx_val/num_items];
#pragma unroll num_items
for(int k = 0; k < num_items; k++)
local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k];
reinterpret_cast<float4*>(out)[idx_val/num_items] = reinterpret_cast<float4(&)[num_items]>(local_valC)[j/num_items];
}
else
{
#pragma unroll num_items
for(int k = 0; k < num_items; k++)
if(idx_col_C + k < colsB)
out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k];
}
}
idx_col_B += blockDim.x*SPMM_ITEMS;
local_idx_col_B_offset += blockDim.x*SPMM_ITEMS;
}
}
//============================================================== //==============================================================
// TEMPLATE DEFINITIONS // TEMPLATE DEFINITIONS
//============================================================== //==============================================================
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n);
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x);

View File

@ -106,6 +106,18 @@ template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileCl
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);
template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(
int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n);
template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
#endif #endif

View File

@ -8,6 +8,7 @@
#include <cub/device/device_scan.cuh> #include <cub/device/device_scan.cuh>
#include <limits> #include <limits>
#include <BinSearch.h> #include <BinSearch.h>
#include <cassert>
#include <common.h> #include <common.h>
@ -188,11 +189,416 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
CUDA_CHECK_RETURN(cudaPeekAtLastError()); CUDA_CHECK_RETURN(cudaPeekAtLastError());
} }
void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
{
const int falpha = 1;
const int fbeta = 0;
const void * alpha = &falpha;
const void * beta = &fbeta;
cublasStatus_t status;
status = cublasGemmEx(context->m_handle,
transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
m, n, k,
alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta,
C, CUDA_R_32I, ldc,
CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
if (status != CUBLAS_STATUS_SUCCESS)
{
std::cout << "CUBLAS ERROR: Status " << status << std::endl;
}
}
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
long long int strideA, long long int strideB, long long int strideC, int batchCount)
{
const int falpha = 1;
const int fbeta = 0;
const void * alpha = &falpha;
const void * beta = &fbeta;
cublasStatus_t status;
//cout << transposeA << transposeB << endl;
//printf("%i %i %i\n", m,n,k);
//printf("%i %i %i\n", lda,ldb,ldc);
//printf("%i %i %i\n", strideA, strideB, strideC);
//printf("%i\n", batchCount);
status = cublasGemmStridedBatchedEx(context->m_handle,
transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
m, n, k,
alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta,
C, CUDA_R_32I, ldc, (long long int)strideC, batchCount,
CUDA_R_32I, CUBLAS_GEMM_DEFAULT);
if (status != CUBLAS_STATUS_SUCCESS)
{
std::cout << "CUBLAS ERROR: Status " << status << std::endl;
}
}
int roundoff(int v, int d) {
return (v + d - 1) / d * d;
}
template<int ORDER> cublasLtOrder_t get_order()
{
switch(ORDER)
{
case ROW:
return CUBLASLT_ORDER_ROW;
break;
case COL:
return CUBLASLT_ORDER_COL;
break;
case COL32:
return CUBLASLT_ORDER_COL32;
break;
case COL_TURING:
return CUBLASLT_ORDER_COL4_4R2_8C;
break;
case COL_AMPERE:
return CUBLASLT_ORDER_COL32_2R_4R4;
break;
}
}
template cublasLtOrder_t get_order<ROW>();
template cublasLtOrder_t get_order<COL>();
template cublasLtOrder_t get_order<COL32>();
template cublasLtOrder_t get_order<COL_TURING>();
template cublasLtOrder_t get_order<COL_AMPERE>();
template<int ORDER> int get_leading_dim(int dim1, int dim2)
{
switch(ORDER)
{
case ROW:
return dim2;
break;
case COL:
return dim1;
break;
case COL32:
// 32*row tiles
return dim1*32;
break;
case COL_TURING:
return 32*roundoff(dim1, 8);
break;
case COL_AMPERE:
// 32*32 tiles
return 32*roundoff(dim1, 32);
break;
}
}
template int get_leading_dim<ROW>(int dim1, int dim2);
template int get_leading_dim<COL>(int dim1, int dim2);
template int get_leading_dim<COL32>(int dim1, int dim2);
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
{
cublasLtOrder_t orderA = get_order<SRC>();
cublasLtOrder_t orderOut = get_order<TARGET>();
int ldA = get_leading_dim<SRC>(dim1, dim2);
int ldOut = get_leading_dim<TARGET>(dim1, dim2);
cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL;
cublasLtMatrixTransformDesc_t A2Out_desc = NULL;
cublasOperation_t opTranspose = CUBLAS_OP_T;
float transformAlpha = 1.0f, transformBeta = 0.0f;
if(DTYPE == 8)
{
checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, dim1, dim2, ldA));
checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ldOut));
}
else if(DTYPE == 32)
{
checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_32I, dim1, dim2, ldA));
checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32I, dim1, dim2, ldOut));
}
else
{
printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE);
}
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA)));
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut)));
checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, CUDA_R_32F));
if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); }
checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0));
if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
}
template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int32_t, ROW, COL32, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL_TURING, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{
int has_error = 0;
cublasLtMatmulDesc_t matmulDesc = NULL;
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
cublasOperation_t opT = CUBLAS_OP_T;
cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C;
cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4;
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda));
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb));
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
if(FORMATB == COL_TURING)
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing)));
else
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere)));
if(DTYPE_OUT == 32)
{
has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I));
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc));
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
int alpha = 1, beta = 0;
has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0));
}
else
{
has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F));
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc));
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
if(!SCALE_ROWS)
{
float alpha = 1.0f, beta = 0.0f;
has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
}
else
{
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec)));
has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
}
}
if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc));
if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc));
if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc));
if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc));
if(has_error == 1)
printf("error detected");
return has_error;
}
int fill_up_to_nearest_multiple(int value, int multiple)
{
return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
}
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols)
{
int threads = 512;
int tileCols = fill_up_to_nearest_multiple(numCols, 32);
int n = numRows*tileCols;
int subtile_rows = 128;
int tilesize = 32*subtile_rows;
int num_blocks = numRows/subtile_rows;
num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
num_blocks = num_blocks*(tileCols/32);
assert(threads <= tilesize);
//cout << num_blocks << " blocks" << endl;
kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
#define STATS_THREADS 64
#define STATS_ITEMS 4
#define STATS_ROWS 16
void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
{
int tile_cols = STATS_THREADS*STATS_ITEMS;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS);
if(nnz_threshold == 0.0)
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
else if(nnz_threshold != 0.0)
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 1><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols)
{
int threads = 64;
int items_per_thread = 4;
int tile_cols = threads*items_per_thread;
int tile_rows = 16;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
//cout << cols << " " << tiledCols << " " << tiledRows << endl;
//cout << "num blocks " << num_blocks << endl;
//cout << A << " " << out_col_normed << endl;
if(threshold > 0.0f)
kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
else
kDoubleRowColQuant<64, 4, 16, 64*4, 0><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols)
{
int threads = 256;
int items_per_thread = 8;
// we load 128 column values per warp
int tile_cols = 32*items_per_thread;
int tile_rows = 32;
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
int outCols = fill_up_to_nearest_multiple(cols, 32);
int outRows = fill_up_to_nearest_multiple(rows, 32);
if(FORMAT == COL_TURING)
{
if(TRANSPOSE)
outRows = fill_up_to_nearest_multiple(cols, 8);
else
outRows = fill_up_to_nearest_multiple(rows, 8);
}
else if(FORMAT == COL_AMPERE)
{
if(TRANSPOSE)
outRows = fill_up_to_nearest_multiple(cols, 32);
else
outRows = fill_up_to_nearest_multiple(rows, 32);
}
else
{
if(TRANSPOSE)
{
outCols = fill_up_to_nearest_multiple(rows, 32);
outRows = cols;
}
}
//cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl;
//cout << "num blocks " << num_blocks << endl;
//cout << A << " " << out_col_normed << endl;
kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{
cusparseSpMatDescr_t descA;
cusparseDnMatDescr_t descB, descC;
float alpha = 1.0f;
float beta = 0.0f;
void *dBuffer = NULL;
size_t bufferSize = 0;
CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz,
A_rowidx, A_colidx, A_vals,
CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) );
// Create dense matrix C
CHECK_CUSPARSE( cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C,
CUDA_R_16F, CUSPARSE_ORDER_ROW) );
// Create dense matrix B
if(transposed_B)
{
int tmp = A_cols;
A_cols = B_cols;
B_cols = tmp;
}
CHECK_CUSPARSE( cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B,
CUDA_R_16F, CUSPARSE_ORDER_ROW) );
// allocate an external buffer if needed
CHECK_CUSPARSE( cusparseSpMM_bufferSize(
handle,
CUSPARSE_OPERATION_NON_TRANSPOSE,
transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
&alpha, descA, descB, &beta, descC, CUDA_R_32F,
CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize) );
CUDA_CHECK_RETURN( cudaMalloc(&dBuffer, bufferSize) );
// execute SpMM
CHECK_CUSPARSE( cusparseSpMM(handle,
CUSPARSE_OPERATION_NON_TRANSPOSE,
transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
&alpha, descA, descB, &beta, descC, CUDA_R_32F,
CUSPARSE_SPMM_ALG_DEFAULT, dBuffer));
// destroy matrix/vector descriptors
CHECK_CUSPARSE( cusparseDestroySpMat(descA) );
CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
CUDA_CHECK_RETURN( cudaFree(dBuffer) );
}
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{
kspmm_coo_very_sparse_naive<T, 8, BITS><<<nnz_rows, 256>>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
//============================================================== //==============================================================
// TEMPLATE DEFINITIONS // TEMPLATE DEFINITIONS
//============================================================== //==============================================================
template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template int igemmlt<COL_TURING, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_TURING, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_TURING, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_AMPERE, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_AMPERE, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template int igemmlt<COL_AMPERE, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template void transformRowToFormat<COL32, 0>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL32, 1>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_TURING, 0>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_TURING, 1>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_AMPERE, 0>(char * A, char *out, int rows, int cols);
template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows, int cols);
template void estimateQuantiles(half *A, float *code, float offset, int n); template void estimateQuantiles(half *A, float *code, float offset, int n);
template void estimateQuantiles(float *A, float *code, float offset, int n); template void estimateQuantiles(float *A, float *code, float offset, int n);

View File

@ -14,6 +14,11 @@
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cublas_v2.h>
#include <cublasLt.h>
#include <cusparse.h>
#include <vector>
#include <functional>
#define CUDA_CHECK_RETURN(value) { \ #define CUDA_CHECK_RETURN(value) { \
cudaError_t _m_cudaStat = value; \ cudaError_t _m_cudaStat = value; \
@ -25,6 +30,34 @@
#define THREADS_PER_BLOCKS (512) #define THREADS_PER_BLOCKS (512)
#define CHECK_CUSPARSE(value) { \
cusparseStatus_t _m_cudaStat = value; \
if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \
fprintf(stderr, "Error %s at line %d in file %s\n", \
cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
exit(1); \
} }
#define THREADS_PER_BLOCKS (512)
inline void checkCudaStatus(cudaError_t status) {
if (status != cudaSuccess) {
printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status));
throw std::logic_error("cuda API failed");
}
}
inline int checkCublasStatus(cublasStatus_t status) {
if (status != CUBLAS_STATUS_SUCCESS) {
printf("cuBLAS API failed with status %d\n", status);
//throw std::logic_error("cuBLAS API failed");
return 1;
}
return 0;
}
typedef enum Operations_t typedef enum Operations_t
{ {
ksmul = 0, ksmul = 0,
@ -39,6 +72,57 @@ typedef enum Optimizer_t
ADAGRAD = 4, ADAGRAD = 4,
} Optimizer_t; } Optimizer_t;
typedef enum Transform_t
{
ROW = 0,
COL = 1,
COL32 = 2,
COL_TURING = 3,
COL_AMPERE = 4,
} Transform_t;
class Context
{
public:
cublasHandle_t m_handle;
Context()
{
cublasHandle_t handle;
cublasCreate_v2(&handle);
m_handle = handle;
}
};
class ContextLt
{
public:
cublasLtHandle_t m_handle;
ContextLt()
{
cublasLtHandle_t handle;
cublasLtCreate(&handle);
m_handle = handle;
}
};
class ContextCusparse
{
public:
cusparseHandle_t m_handle;
ContextCusparse()
{
cusparseHandle_t handle;
cusparseCreate(&handle);
m_handle = handle;
}
};
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n); template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n);
@ -70,4 +154,24 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);
void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
long long int strideA, long long int strideB, long long int strideC, int batchCount);
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols);
void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);
void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed,
int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols);
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols);
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B);
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
#endif #endif

View File

@ -84,6 +84,52 @@ void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); } void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
#endif #endif
#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
{ \
transform<dtype, src, target, transpose, bits>(ltHandle, A, out, dim1, dim2); \
} \
MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8);
MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8);
MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8);
MAKE_FUNC_TRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32);
MAKE_FUNC_TRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8);
MAKE_FUNC_TRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8);
MAKE_FUNC_TRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8);
MAKE_FUNC_TRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32);
void transform_row2col32(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 0>(A, out, rows, cols); }
void transform_row2col32T(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 1>(A, out, rows, cols); }
void transform_row2turing(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 0>(A, out, rows, cols); }
void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 1>(A, out, rows, cols); }
void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 0>(A, out, rows, cols); }
void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 1>(A, out, rows, cols); }
int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_TURING, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int igemmlt_turing_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_TURING, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int igemmlt_turing_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_TURING, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int igemmlt_ampere_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_AMPERE, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int igemmlt_ampere_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_AMPERE, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int igemmlt_ampere_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt<COL_AMPERE, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive<half, 16>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive<signed char, 8>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
extern "C" extern "C"
{ {
#if BUILD_CUDA #if BUILD_CUDA
@ -155,7 +201,86 @@ extern "C"
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); } void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
#endif void cigemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
{ gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc); }
void cbatched_igemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
long strideA, long strideB, long strideC, int batchCount)
{ strided_gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount); }
Context *get_context(){ return new Context(); }
ContextCusparse *get_cusparse(){ return new ContextCusparse(); }
int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
//{ (cublasLtHandle_t)context->m_handle; return 0; }
//{ return 0; }//igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int cigemmlt_turing_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_turing_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int cigemmlt_turing_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_turing_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int cigemmlt_ampere_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_ampere_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int cigemmlt_ampere_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{ return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
#define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \
{ \
transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); \
} \
MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8)
MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8)
MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8)
MAKE_FUNC_CTRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32)
MAKE_FUNC_CTRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8)
MAKE_FUNC_CTRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8)
MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8)
MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32)
void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols)
{ dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols); }
void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
{ getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); }
void cdouble_rowcol_quant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_row_ptr, float threshold, int rows, int cols)
{ doubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_row_ptr, threshold, rows, cols); }
void ctransform_row2col32(char * A, char *out, int rows, int cols)
{ transform_row2col32(A, out, rows, cols); }
void ctransform_row2col32T(char * A, char *out, int rows, int cols)
{ transform_row2col32T(A, out, rows, cols); }
void ctransform_row2turing(char * A, char *out, int rows, int cols)
{ transform_row2turing(A, out, rows, cols); }
void ctransform_row2turingT(char * A, char *out, int rows, int cols)
{ transform_row2turingT(A, out, rows, cols); }
void ctransform_row2ampere(char * A, char *out, int rows, int cols)
{ transform_row2ampere(A, out, rows, cols); }
void ctransform_row2ampereT(char * A, char *out, int rows, int cols)
{ transform_row2ampereT(A, out, rows, cols); }
void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{ spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); }
void cspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive_fp16(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{ spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
#endif
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); } void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); } void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
} }

270
tests/test_autograd.py Normal file
View File

@ -0,0 +1,270 @@
import pytest
import torch
import bitsandbytes as bnb
from itertools import product
n = 1
k = 25
dim1 = torch.randint(16,64, size=(n,)).tolist()
dim2 = torch.randint(32,96, size=(n,)).tolist()
dim3 = torch.randint(32,96, size=(n,)).tolist()
dim4 = torch.randint(32,96, size=(n,)).tolist()
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ['bmm', 'matmul']
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ['FF', 'TF', 'TT', 'FT']
transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ['FF', 'FT', 'TT', 'TF']
dtype = [torch.float32, torch.float16]
values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose))
str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose))
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
for i in range(k):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0])
B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1])
torch.nn.init.xavier_uniform_(B)
if not transpose[0] and not transpose[1]:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)
elif not transpose[0] and transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B.t())
elif transpose[0] and not transpose[1]:
out_torch = funcs[0](A.t(), B)
out_bnb = funcs[1](A.t(), B)
elif transpose[0] and transpose[1]:
out_torch = funcs[0](A.t(), B.t())
out_bnb = funcs[1](A.t(), B.t())
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx==0).sum().item() < n*0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx==0).sum().item() < n*0.001
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx==0).sum().item() < n*0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx==0).sum().item() < n*0.02
torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
# batched matrix multiply
if funcs[0] in [torch.bmm, torch.matmul]:
A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
B = torch.randn(size=(dim1, dim3, dim4), device='cuda', requires_grad=req_grad[1])
target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
torch.nn.init.xavier_uniform_(B)
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx==0).sum().item() < n*0.01
torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx==0).sum().item() < n*0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx==0).sum().item() < n*0.02
if funcs[0] in [torch.matmul]:
dim1 = dim1 - (dim1 % 16)
A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
torch.nn.init.xavier_uniform_(B)
if transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B.t())
else:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B)
n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx==0).sum().item() < n*0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx==0).sum().item() < n*0.001
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx==0).sum().item() < n*0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx==0).sum().item() < n*0.02
n = 1
k = 3
dim1 = torch.randint(16,64, size=(n,)).tolist()
dim2 = torch.randint(32,96, size=(n,)).tolist()
dim3 = torch.randint(32,96, size=(n,)).tolist()
dim4 = torch.randint(32,96, size=(n,)).tolist()
#dim1 = (17,)
#dim2 = (7,)
#dim3 = (37,)
#dim4 = (23,)
decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)]
str_funcs = ['matmul']
req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ['FF', 'TF', 'TT', 'FT']
transpose = [(False, True), (False, False)]
str_transpose = ['NT', 'NN']
dtype = [torch.float16]
has_fp16_weights = [True, False]
values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose, decomp, has_fp16_weights))
str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose, decomp, has_fp16_weights))
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}'.format(*vals) for vals in str_values]
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights", values, ids=names)
def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1]//8,), device='cuda')
for i in range(k):
# normal multiply
if funcs[0] in [torch.mm, torch.matmul]:
A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0], dtype=dtype)
if decomp == 6.0:
with torch.no_grad():
A[:, outlier_dim] = 6.0
B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1], dtype=dtype)
target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1], dtype=dtype)
torch.nn.init.xavier_uniform_(B)
B2 = B.clone()
state = bnb.MatmulLtState()
state.threshold = decomp
state.has_fp16_weights = has_fp16_weights
if not has_fp16_weights:
if not transpose[0] and not transpose[1]: B2 = B2.t().contiguous()
state.CB, CBt, state.SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B2)
B2 = state.CB
if not transpose[0] and transpose[1]:
out_torch = funcs[0](A, B.t())
out_bnb = funcs[1](A, B2, state=state)
elif not transpose[0] and not transpose[1]:
out_torch = funcs[0](A, B)
out_bnb = funcs[1](A, B2.t(), state=state)
n = out_bnb.numel()
err = torch.abs(out_bnb-out_torch).mean().item()
#print(f'abs error {err:.4f}')
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx==0).sum().item() < n*0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx==0).sum().item() < n*0.001
if has_fp16_weights:
if any(req_grad):
out_bnb.data.copy_(out_torch)
torch.cuda.synchronize()
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if req_grad[0]:
torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
assert torch.abs(gradB1).sum() > 0.0
assert torch.abs(gradB2).sum() > 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx==0).sum().item() < n*0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx==0).sum().item() < n*0.02
torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)

File diff suppressed because it is too large Load Diff

View File

@ -1,42 +1,470 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import pytest import pytest
import torch import torch
from itertools import product
from torch import nn
import bitsandbytes as bnb import bitsandbytes as bnb
class MockArgs(object):
def __init__(self, initial_data):
for key in initial_data:
setattr(self, key, initial_data[key])
@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding']) class MLP8bit(torch.nn.Module):
def test_embeddings(embcls): def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
bnb.optim.GlobalOptimManager.get_instance().initialize() super(MLP8bit, self).__init__()
emb1 = torch.nn.Embedding(100, 512).cuda() self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold)
emb2 = embcls(100, 512).cuda() self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold)
adam1 = bnb.optim.Adam8bit(emb1.parameters()) def forward(self, x):
adam2 = bnb.optim.Adam8bit(emb2.parameters()) x = self.fc1(x)
x = self.fc2(x)
return x
batches = torch.randint(1, 100, size=(100, 4, 32)).cuda()
def get_args():
args = MockArgs([])
args.quant_type = 'vector'
args.use_8bit_training = 'full'
args.clip_freq = 9999
return args
def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
idx = torch.isclose(a, b, rtol, atol)
sumval = (idx==0).sum().item()
if sumval > count:
print(f'Too many values not close: assert {sumval} < {count}')
torch.testing.assert_allclose(a, b, rtol, atol)
class LinearFunction(torch.autograd.Function):
@staticmethod
def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round
norm = math.sqrt(math.pi)/math.sqrt(2.0)
#std = torch.abs(x).mean()*norm
std = torch.std(x)
max1 = std*trim_value
x = x/max1*127
x = round_func(x)
x[x > 127] = 127
x[x < -127] = -127
x = x/127*max1
return x
def quant(x, quant_type, dim=1):
if quant_type == 'linear':
max1 = torch.abs(x).max().float()
xq = torch.round(x/max1*127).to(torch.int8)
return xq, max1
elif quant_type == 'vector':
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
xq = torch.round(x/max1*127).to(torch.int8)
return xq, max1
elif quant_type == 'min-max':
maxA = torch.amax(x, dim=dim, keepdim=True).float()
minA = torch.amin(x, dim=dim, keepdim=True).float()
scale = (maxA-minA)/2.0
xq = torch.round(127*(x-minA-scale)/scale).to(torch.int8)
return xq, (minA.float(), scale.float())
else: return None
def dequant(xq, S1, S2, dtype, quant_type):
if quant_type == 'linear':
norm = S1*S2/(127*127)
# double cast needed to prevent overflows
return (xq.float()*norm).to(dtype)
elif quant_type == 'vector':
x = xq.float()
if len(xq.shape) == 2 and len(S1.shape) == 3: S1 = S1.squeeze(0)
if len(xq.shape) == 2 and len(S2.shape) == 3: S2 = S2.squeeze(0)
#print(x.shape, S1.shape, S2.shape)
if len(S1.shape) == 2:
x *= S1.t()/127
else:
x *= S1/127
x *= S2/127
return x.to(dtype)
else: return None
def dequant_min_max(xq, A, B, SA, SB, dtype):
offset = B.float().t().sum(0)*(SA[0]+SA[1])
x = xq.float()
if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
if len(xq.shape) == 2 and len(SA.shape) == 3: SA = SA.squeeze(0)
if len(SB.shape) == 2:
x *= SB.t()/127
else:
x *= SB/127
x *= SA[1]/127
x +=offset
return x.to(dtype)
def get_8bit_linear(x, stochastic=False):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round
max1 = torch.abs(x).max()
x = x/max1*127
x = round_func(x)/127*max1
#x = torch.round(x)/128*max1
return x
@staticmethod
def get_8bit_vector_wise(x, dim, stochastic=False):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
max1[max1==0] = 1.0
x = (x*127)/max1
x = round_func(x)/127*max1
return x
@staticmethod
def round_stoachastic(x):
sign = torch.sign(x)
absx = torch.abs(x)
decimal = absx-torch.floor(absx)
rdm = torch.rand_like(decimal)
return sign*(torch.floor(absx)+(rdm < decimal).to(x.dtype))
@staticmethod
def fake_8bit_storage(w, exponent_bits):
code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
out = bnb.functional.dequantize_blockwise(absmax, C, code)
out = out.half()
w.copy_(out)
return out
@staticmethod
def fake_8bit_storage_quantile(w, args):
code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
#C = bnb.functional.quantize_no_absmax(code, w)
#out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
#print(out)
#out = out.half()
code /= torch.max(torch.abs(code))
absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
out = bnb.functional.dequantize_blockwise(absmax, C, code)
out = out.half()
w.copy_(out)
return out
@staticmethod
def fake_8bit_storage_stoachstic(w):
rand = torch.rand(1024, device=w.device)
absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
out = bnb.functional.dequantize_blockwise(absmax, C)
out = out.half()
w.copy_(out)
return out
@staticmethod
def fake_8bit_storage_with_max(w, topk=8):
blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256)
max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
idx = idx[:, :topk]
max_val = max_val[:, :topk]
mask = torch.zeros_like(blocked_w)
mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
mask = mask.bool()
# 1. zero out max values
# 2. quantize + dequantize
# 3. write back max values
# 4. copy matrix back to weight
values = blocked_w[mask]
blocked_w[mask] = 0
code = bnb.functional.create_dynamic_map()
code = code.to(w.device)
absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)
blocked_w[mask] = values
unblocked_w = blocked_w.flatten().view(w.shape)
w.copy_(unblocked_w)
return unblocked_w
@staticmethod
def forward(ctx, x, weight, bias=None, args=None):
if args.use_8bit_training != 'off':
weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
outputq = bnb.functional.igemm(x8, weight8.t())
output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
#if torch.rand(1) < 0.01:
#output32 = torch.matmul(x, weight.t())
#err = torch.abs(output-output32).float()
#relerr = err/(torch.abs(output32).float()+1e-8)
#print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
else:
#output = torch.matmul(x, weight.t())
output = torch.einsum('bsi,oi->bso', x, weight)
ctx.save_for_backward(x, weight, bias)
ctx.args = args
if bias is not None:
output += bias.unsqueeze(0).expand_as(output)
return output
@staticmethod
def backward(ctx, grad_output):
x, weight, bias = ctx.saved_tensors
args = ctx.args
stochastic = False
grad_input = grad_weight = grad_bias = None
if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0)
# weight and x are already 8bit
# -> transform grad_output to 8-bit
if args.use_8bit_training == 'forward+wgrad':
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = bnb.functional.igemm(grad_output8, x8)
grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
#grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
grad_input = grad_output.matmul(weight)
elif args.use_8bit_training == 'full':
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
grad_input8 = bnb.functional.igemm(grad_output8, weight8)
grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)
else:
grad_input = grad_output.matmul(weight)
grad_weight = torch.einsum('bsi,bso->oi', x, grad_output)
return grad_input, grad_weight, grad_bias, None
class Linear8bit(nn.Module):
def __init__(self, input_features, output_features, bias=True, args=None):
super(Linear8bit, self).__init__()
self.input_features = input_features
self.output_features = output_features
self.args = args
self.weight = nn.Parameter(torch.empty(output_features, input_features))
if bias:
self.bias = nn.Parameter(torch.empty(output_features))
else:
self.register_parameter('bias', None)
torch.nn.init.xavier_uniform_(self.weight)
if self.bias is not None:
torch.nn.init.zeros_(self.bias)
def forward(self, x):
self.args.training = self.training
return LinearFunction.apply(x, self.weight, self.bias, self.args)
def test_linear8bit():
l0 = torch.nn.Linear(32, 64).cuda().half()
l1 = bnb.nn.Linear8bit(32,64, args=get_args()).cuda().half()
l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
l3 = bnb.nn.Linear8bitLt(32,64).cuda().half()
l0.weight.data = l2.weight.data.clone()
l0.bias.data = l2.bias.data.clone()
l1.weight.data = l2.weight.data.clone()
l1.bias.data = l2.bias.data.clone()
l3.weight.data = l2.weight.data.clone()
l3.bias.data = l2.bias.data.clone()
for i in range(100): for i in range(100):
batch = batches[i] b1 = torch.randn(16, 8, 32, device='cuda').half()
t = torch.randn(16, 8, 64, device='cuda').half()
b2 = b1.clone()
b3 = b1.clone()
b0 = b1.clone()
embedded1 = emb1(batch) o0 = l0(b0)
embedded2 = emb2(batch) o1 = l1(b1)
o2 = l2(b2)
o3 = l3(b3)
l1 = embedded1.mean() assert_all_approx_close(o1, o2, atol=0.013, rtol=0.05, count=1)
l2 = embedded2.mean() assert_all_approx_close(o3, o2, atol=0.013, rtol=0.05, count=1)
l1.backward() loss0 = torch.nn.functional.mse_loss(o0, t)
l2.backward() loss1 = torch.nn.functional.mse_loss(o1, t)
loss2 = torch.nn.functional.mse_loss(o2, t)
loss3 = torch.nn.functional.mse_loss(o3, t)
adam1.step() loss0.backward()
adam2.step() loss1.backward()
loss2.backward()
loss3.backward()
adam1.zero_grad() assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
adam2.zero_grad() assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
assert adam1.state[emb1.weight]['state1'].dtype == torch.uint8 err1 = torch.abs(l0.weight.grad-l1.weight.grad).mean().item()
assert adam2.state[emb2.weight]['state1'].dtype == torch.float32 err2 = torch.abs(l0.weight.grad-l2.weight.grad).mean().item()
err3 = torch.abs(l0.weight.grad-l3.weight.grad).mean().item()
assert err1*0.8 < err2
assert err2*0.8 < err3
assert err3*0.8 < err1
l0.weight.grad = None
l1.weight.grad = None
l2.weight.grad = None
l3.weight.grad = None
l0.bias.grad = None
l1.bias.grad = None
l2.bias.grad = None
l3.bias.grad = None
threshold = [0.0, 3.0]
values = threshold
names = ['threshold_{0}'.format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold):
l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold).cuda().half()
assert l1.weight.device.type == 'cuda'
assert l1.weight.dtype == torch.float16
l1.eval()
for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half()
o1 = l1(b1)
if i == 1:
assert l1.state.CxB is not None
def test_linear8bitlt_accumulated_gradient():
l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32,32).cuda().half() for i in range(2)])
l2 = torch.nn.Sequential(*[torch.nn.Linear(32,32).cuda().half() for i in range(2)])
l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)
acc_steps = 10
for i in range(10):
b1 = torch.randn(16, 8, 32, device='cuda').half()
o1 = l1(b1)
o2 = l2(b1)
loss1 = o1.mean()
loss2 = o2.mean()
loss1.backward()
loss2.backward()
if i == 2:
assert l1[0].state.CxB is not None
assert l1[1].state.CxB is not None
if i > 0 and i % acc_steps == 0:
opt1.step()
opt1.zero_grad(True)
opt2.step()
opt2.zero_grad(True)
assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
# we do this copy because otherwise we have small divergences over time that add up
l1[0].weight.data.copy_(l2[0].weight.data)
l1[1].weight.data.copy_(l2[1].weight.data)
else:
torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)
threshold = [0.0, 2.0]
values = threshold
names = ['threshold_{0}'.format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_no_fp16_weights(threshold):
l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold, has_fp16_weights=False).cuda().half()
assert l1.weight.dtype == torch.int8
l1.eval()
for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half()
o1 = l1(b1)
assert o1.dtype == torch.float16
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None
if threshold > 0: assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None
if threshold > 0: assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None
if threshold > 0: assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda')
for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None
if threshold > 0: assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == 'cuda'
assert mlp.fc2.weight.device.type == 'cuda'
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda')
for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half()
o1 = mlp(b1)
assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None
if threshold > 0: assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == 'cuda'
assert mlp.fc2.weight.device.type == 'cuda'

View File

@ -1,12 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os import os
import time import time
import shutil import shutil
import uuid import uuid
import pytest import pytest
import ctypes
import torch import torch
import bitsandbytes as bnb import bitsandbytes as bnb
import bitsandbytes.functional as F import bitsandbytes.functional as F
@ -14,7 +11,9 @@ import bitsandbytes.functional as F
from os.path import join from os.path import join
from itertools import product from itertools import product
import apex #import apex
k = 20
def get_temp_dir(): def get_temp_dir():
path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4())) path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4()))
@ -26,55 +25,47 @@ def rm_path(path):
str2optimizers = {} str2optimizers = {}
str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam) str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam)
str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) #str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) #str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam) str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam) #str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam) #str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam) str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers['adamw'] = (torch.optim.AdamW, bnb.optim.AdamW) #str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False)) str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9)) str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB) #str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False)) str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['adagrad'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad(pxx, 0.01, block_wise=False))
str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False)) str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False)) str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False)) str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit) #str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9)) str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9))
str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers['adamw8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.AdamW8bit(pxx, block_wise=True))
str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True)) str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True)) str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
str2optimizers['adagrad8bit_blockwise'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad8bit(pxx, 0.01, block_wise=True))
str2statenames = {} str2statenames = {}
str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['adamw'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['momentum'] = [('momentum_buffer', 'state1')] str2statenames['momentum'] = [('momentum_buffer', 'state1')]
str2statenames['lars'] = [('momentum_buffer', 'state1')] str2statenames['lars'] = [('momentum_buffer', 'state1')]
str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['rmsprop'] = [('square_avg', 'state1')] str2statenames['rmsprop'] = [('square_avg', 'state1')]
str2statenames['adagrad'] = [('sum', 'state1')]
str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')] str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
str2statenames['adamw8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')] str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')]
str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')] str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')] str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')]
str2statenames['adagrad8bit_blockwise'] = [('sum', 'state1', 'qmap1', 'absmax1')]
dim1 = [1024] dim1 = [1024]
dim2 = [32, 1024, 4097, 1] dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16] gtype = [torch.float32, torch.float16]
optimizer_names = ['adam', 'adamw', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad'] optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb']
values = list(product(dim1,dim2, gtype, optimizer_names)) values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@ -89,12 +80,12 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer = str2optimizers[optim_name][1]([p2])
if gtype == torch.float32: if gtype == torch.float32:
atol, rtol = 2e-6, 1e-5 atol, rtol = 1e-6, 1e-5
else: else:
atol, rtol = 1e-4, 1e-3 atol, rtol = 1e-4, 1e-3
for i in range(50): for i in range(k):
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
p1.grad = g.clone().float() p1.grad = g.clone().float()
p2.grad = g.clone() p2.grad = g.clone()
@ -107,7 +98,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
if i % 10 == 0 and i > 0: if i % (k//5) == 0 and i > 0:
path = get_temp_dir() path = get_temp_dir()
torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt')) torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt'))
del bnb_optimizer del bnb_optimizer
@ -148,7 +139,6 @@ def test_global_config(dim1, dim2, gtype):
eps = 1e-8 eps = 1e-8
bnb.optim.GlobalOptimManager.get_instance().initialize() bnb.optim.GlobalOptimManager.get_instance().initialize()
bnb.optim.GlobalOptimManager.get_instance().override_config(p2, 'skip_zeros', True)
bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8) bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)
bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
@ -163,8 +153,6 @@ def test_global_config(dim1, dim2, gtype):
else: else:
atol, rtol = 1e-4, 1e-3 atol, rtol = 1e-4, 1e-3
original_p2 = p2[mask].clone()
for i in range(50): for i in range(50):
g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
@ -173,38 +161,17 @@ def test_global_config(dim1, dim2, gtype):
p2.grad = g2 p2.grad = g2
p3.grad = g3 p3.grad = g3
if i > 30 and i % 10 == 0:
g1.data[mask] = 0.0
g2.data[mask] = 0.0
p1.grad = g1
p2.grad = g2
original_p1 = p1[mask].clone()
original_p2 = p2[mask].clone()
og_s1 = adam2.state[p2]['state1'][mask].clone()
og_s2 = adam2.state[p2]['state2'][mask].clone()
og_s11 = adam2.state[p1]['state1'][mask].clone()
og_s21 = adam2.state[p1]['state2'][mask].clone()
adam2.step() adam2.step()
assert adam2.state[p3]['state1'].dtype == torch.uint8 assert adam2.state[p3]['state1'].dtype == torch.uint8
assert adam2.state[p3]['state2'].dtype == torch.uint8 assert adam2.state[p3]['state2'].dtype == torch.uint8
if i > 30 and i % 10 == 0:
torch.testing.assert_allclose(original_p2, p2[mask])
torch.testing.assert_allclose(adam2.state[p2]['state1'][mask], og_s1)
torch.testing.assert_allclose(adam2.state[p2]['state2'][mask], og_s2)
assert ((p1[mask]- original_p1)==0.0).sum() < p1.numel()
assert ((adam2.state[p1]['state1'][mask]- og_s11)==0.0).sum() == 0.0
assert ((adam2.state[p1]['state2'][mask]- og_s21)==0.0).sum() == 0.0
dim1 = [1024] dim1 = [1024]
dim2 = [32, 1024, 4097] dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16] gtype = [torch.float32, torch.float16]
optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'adamw8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise'] optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise']
values = list(product(dim1,dim2, gtype, optimizer_names)) values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
@ -370,13 +337,12 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1: return if dim1 == 1 and dim2 == 1: return
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
bnb_optimizer = str2optimizers[optim_name][1]([p1]) bnb_optimizer = str2optimizers[optim_name][1]([p1])
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
p1.grad = g p1.grad = g
for i in range(5000): for i in range(k):
if i == 500: if i == k//5:
# 100 iterations for burn-in # 100 iterations for burn-in
torch.cuda.synchronize() torch.cuda.synchronize()
t0 = time.time() t0 = time.time()
@ -386,23 +352,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
torch.cuda.synchronize() torch.cuda.synchronize()
s = time.time()-t0 s = time.time()-t0
print('') print('')
params = 4500*4096*4096 params = (k-k//5)*dim1*dim2
print(optim_name, gtype, s/params) print(optim_name, gtype, s/params)
#assert s < 3.9 #assert s < 3.9
def test_str_betas():
betas = (0.80, 0.95)
strbetas = '(0.80, 0.95)'
layer = torch.nn.Linear(10, 10)
base = bnb.optim.Adam(layer.parameters(), betas=betas)
strbase = bnb.optim.Adam(layer.parameters(), betas=strbetas)
assert base.defaults['betas'][0] == 0.8
assert base.defaults['betas'][1] == 0.95
assert strbase.defaults['betas'][0] == 0.8
assert strbase.defaults['betas'][1] == 0.95