forked from mrq/bitsandbytes-rocm
Most tests passing.
This commit is contained in:
parent 4cd7ea62b2
commit c771b3a75a
@@ -4,12 +4,13 @@
# LICENSE file in the root directory of this source tree.

from .nn import modules
from .autograd._functions import mm_cublas, bmm_cublas, matmul_cublas, matmul, MatmulLtState
from .cextension import COMPILED_WITH_CUDA

if COMPILED_WITH_CUDA:
    from .optim import adam

__pdoc__ = {'libBitsNBytes': False,
__pdoc__ = {'libbitsandbytes': False,
            'optim.optimizer.Optimizer8bit': False,
            'optim.optimizer.MockArgs': False
           }
0	bitsandbytes/autograd/__init__.py	Normal file
307	bitsandbytes/autograd/_functions.py	Normal file

@@ -0,0 +1,307 @@
import torch
|
||||
import bitsandbytes as bnb
|
||||
import bitsandbytes.functional as F
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
tensor = torch.Tensor
|
||||
|
||||
'''
|
||||
This class pools outlier dimensions across layers.
|
||||
This is particularly important for small models where outlier features
|
||||
are less systematic and occur with low frequency.
|
||||
'''
|
||||
class GlobalOutlierPooler(object):
|
||||
_instance = None
|
||||
|
||||
def __init__(self):
|
||||
raise RuntimeError('Call get_instance() instead')
|
||||
|
||||
def initialize(self):
|
||||
self.outliers = set()
|
||||
self.model_dim = None
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = cls.__new__(cls)
|
||||
cls._instance.initialize()
|
||||
return cls._instance
|
||||
|
||||
def add_outliers(self, outlier_idx, feature_dim):
|
||||
if self.model_dim is None: self.model_dim = feature_dim
|
||||
if feature_dim != self.model_dim: return # we do not encode outliers for the 2nd FFN layer
|
||||
|
||||
self.outliers.update(outlier_idx.tolist())
|
||||
|
||||
def get_current_outlier_idx(self):
|
||||
return torch.Tensor(list(self.outliers)).to(torch.int64)
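A short usage sketch of the pooler above (the indices and dimensions are made up for illustration):

import torch
from bitsandbytes.autograd._functions import GlobalOutlierPooler

pooler = GlobalOutlierPooler.get_instance()                       # singleton accessor
pooler.add_outliers(torch.tensor([3, 17, 42]), feature_dim=1024)  # outliers from one layer
pooler.add_outliers(torch.tensor([17, 99]), feature_dim=1024)     # pooled (union) with the next layer
pooler.add_outliers(torch.tensor([5]), feature_dim=4096)          # ignored: 2nd FFN dimension
print(pooler.get_current_outlier_idx())                           # e.g. tensor([3, 17, 42, 99]), order not guaranteed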
|
||||
|
||||
class MatMul8bit(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]):
|
||||
|
||||
if precision[0] != 8:
|
||||
with torch.no_grad():
|
||||
output = torch.matmul(A, B)
|
||||
else:
|
||||
if len(B.shape) == 2: dim = 0
|
||||
else: dim = 1
|
||||
qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
|
||||
qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
|
||||
iout = F.igemm(qA, qB)
|
||||
output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type)
|
||||
|
||||
if A.requires_grad or B.requires_grad:
|
||||
ctx.save_for_backward(A, B)
|
||||
|
||||
ctx.quant_type = quant_type
|
||||
ctx.precision = precision
|
||||
|
||||
return output
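For reference, a hedged CPU-level sketch of the idea behind the vector-wise 8-bit path above (per-row scales for A, per-column scales for a 2-D B); the real F.vectorwise_quant/F.igemm kernels differ in layout and rounding:

import torch

def vectorwise_int8_matmul(A, B):
    sA = A.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)   # per-row absmax of A
    sB = B.abs().amax(dim=0, keepdim=True).clamp(min=1e-8)   # per-column absmax of B
    qA = torch.round(A / sA * 127).to(torch.int8)
    qB = torch.round(B / sB * 127).to(torch.int8)
    iout = qA.to(torch.int32) @ qB.to(torch.int32)            # int32 accumulation, like F.igemm
    out = iout.float() * (sA.float() * sB.float() / (127 * 127))
    return out.to(A.dtype)                                    # undo both scalings, like vectorwise_mm_dequant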
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
A, B = ctx.saved_tensors
|
||||
quant_type = ctx.quant_type
|
||||
precision = ctx.precision
|
||||
grad_A = grad_B = None
|
||||
|
||||
if B.requires_grad:
|
||||
if len(A.shape) == 3:
|
||||
dims = [0, 1]
|
||||
# bsi -> ibs
|
||||
permute_dim = [0, 2, 1]
|
||||
else:
|
||||
dims = [0]
|
||||
# bs -> sb
|
||||
permute_dim = [1, 0]
|
||||
|
||||
if precision[1] != 8:
|
||||
with torch.no_grad():
|
||||
grad_B = torch.matmul(A.permute(permute_dim), grad_output)
|
||||
else:
|
||||
if len(B.shape) == 2 and len(A.shape) == 3:
|
||||
grad_output = grad_output.contiguous()
|
||||
if not grad_output.is_contiguous(): grad_output = grad_output.contiguous()
|
||||
qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type)
|
||||
if not A.is_contiguous(): A = A.contiguous()
|
||||
qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
|
||||
igrad_B = F.igemm(qA.t(), qgrad_output)
|
||||
grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
|
||||
else:
|
||||
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
|
||||
qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
|
||||
igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
|
||||
grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type)
|
||||
|
||||
if A.requires_grad:
|
||||
if len(grad_output.shape) == 3: dims = [2]
|
||||
else: dims = [1]
|
||||
|
||||
if len(B.shape) == 3:
|
||||
# bio -> boi
|
||||
permute_dim = [0, 2, 1]
|
||||
dim_B = dims
|
||||
else:
|
||||
# io -> oi
|
||||
permute_dim = [1, 0]
|
||||
dim_B = [1]
|
||||
|
||||
if precision[2] != 8:
|
||||
with torch.no_grad():
|
||||
grad_A = torch.matmul(grad_output, B.permute(permute_dim))
|
||||
else:
|
||||
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
|
||||
qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
|
||||
igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
|
||||
grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type)
|
||||
|
||||
return grad_A, grad_B, None, None, None
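In the 2-D case the branches above reduce to the standard matmul gradients; a hedged fp32 reference of what the int8 branches approximate:

import torch

A = torch.randn(4, 8)
B = torch.randn(8, 16)
grad_output = torch.randn(4, 16)

grad_A_ref = grad_output @ B.t()   # approximated by the precision[2] (int8) branch
grad_B_ref = A.t() @ grad_output   # approximated by the precision[1] (int8) branch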
|
||||
|
||||
|
||||
mm_cublas = MatMul8bit.apply
|
||||
bmm_cublas = MatMul8bit.apply
|
||||
matmul_cublas = MatMul8bit.apply
|
||||
|
||||
@dataclass
|
||||
class MatmulLtState:
|
||||
CB = None
|
||||
CxB = None
|
||||
SB = None
|
||||
SCB = None
|
||||
|
||||
CxBt = None
|
||||
SBt = None
|
||||
CBt = None
|
||||
|
||||
subB = None
|
||||
|
||||
outlier_pool = None
|
||||
has_accumulated_gradients = False
|
||||
threshold = 0.0
|
||||
idx = None
|
||||
is_training = True
|
||||
has_fp16_weights = True
|
||||
use_pool = False
|
||||
formatB = F.get_special_format_str()
|
||||
|
||||
def reset_grads(self):
|
||||
self.CB = None
|
||||
self.CxB = None
|
||||
self.SB = None
|
||||
self.SCB = None
|
||||
|
||||
self.CxBt = None
|
||||
self.SBt = None
|
||||
self.CBt = None
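A hedged example of how a state object is threaded through the Lt path (assumes a CUDA device and the fp16 inputs required by the forward below):

import torch
import bitsandbytes as bnb

state = bnb.MatmulLtState()       # holds the quantization buffers (CB, CxB, SCB, ...) and flags above
state.threshold = 6.0             # enables the mixed-precision (outlier) decomposition

A = torch.randn(8, 1024, dtype=torch.float16, device='cuda')
W = torch.randn(4096, 1024, dtype=torch.float16, device='cuda')   # (out, in), as in Linear8bitLt
out = bnb.matmul(A, W, state=state)                                # shape (8, 4096)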
|
||||
|
||||
|
||||
class MatMul8bitLt(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, state=MatmulLtState()):
|
||||
# 1. Quantize A
|
||||
# 2. Quantize B
|
||||
# 3. Matmul
|
||||
# 4. Mixed-precision decomposition matmul
|
||||
# 5. Save state
|
||||
requires_gradA = A.requires_grad
|
||||
requires_gradB = B.requires_grad
|
||||
formatB = state.formatB
|
||||
input_shape = A.shape
|
||||
if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance()
|
||||
assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!'
|
||||
|
||||
# 1. Quantize A
|
||||
if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous()
|
||||
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)
|
||||
|
||||
if state.threshold > 0.0 and coo_tensorA is not None:
|
||||
if state.has_fp16_weights:
|
||||
idx = torch.unique(coo_tensorA.colidx).long()
|
||||
CA[:, idx] = 0
|
||||
CAt[:, idx] = 0
|
||||
subA = A[:, idx]
|
||||
state.subB = B[:, idx].t().contiguous()
|
||||
state.idx = idx
|
||||
else:
|
||||
if state.CxB is None:
|
||||
# B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
|
||||
# we also need to convert it to the turing/ampere format
|
||||
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
||||
if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
|
||||
# generate outlier index and subB
|
||||
outlier_idx = torch.unique(coo_tensorA.colidx).long()
|
||||
state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
|
||||
if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
|
||||
# do not use pool for 2nd FFN layer
|
||||
state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
|
||||
else:
|
||||
state.idx = outlier_idx
|
||||
state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()
|
||||
|
||||
if state.idx is not None:
|
||||
# extract outliers
|
||||
CA[:, state.idx] = 0
|
||||
CAt[:, state.idx] = 0
|
||||
subA = A[:, state.idx]
|
||||
else:
|
||||
subA = None
|
||||
else:
|
||||
if not state.has_fp16_weights and state.CxB is None:
|
||||
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
||||
subA = None
|
||||
|
||||
C32A, SA = F.transform(CA, 'col32')
|
||||
|
||||
# 2. Quantize B
|
||||
if state.has_fp16_weights:
|
||||
has_grad = (True if (getattr(B, 'grad', None) is not None) else False)
|
||||
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
|
||||
if is_transposed: B = B.contiguous()
|
||||
|
||||
if (state.is_training and not has_grad) or state.CxB is None:
|
||||
state.reset_grads()
|
||||
CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B)
|
||||
state.CxB, state.SB = F.transform(CB, to_order=formatB)
|
||||
else:
|
||||
has_grad = False
|
||||
|
||||
shapeB = state.SB[0]
|
||||
|
||||
if len(input_shape) == 3:
|
||||
output_shape = (input_shape[0], input_shape[1], shapeB[0])
|
||||
else:
|
||||
output_shape = (input_shape[0], shapeB[0])
|
||||
|
||||
# 3. Matmul
|
||||
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
|
||||
output = F.mm_dequant(out32, Sout32, SCA, state.SCB)
|
||||
|
||||
# 4. Mixed-precision decomposition matmul
|
||||
if state.threshold > 0.0 and coo_tensorA is not None and subA is not None:
|
||||
output += torch.matmul(subA, state.subB)
|
||||
|
||||
# 5. Save state
|
||||
ctx.state = state
|
||||
|
||||
ctx.formatB = formatB
|
||||
ctx.grad_shape = input_shape
|
||||
ctx.req_grads = [requires_gradA, requires_gradB]
|
||||
|
||||
if requires_gradA or requires_gradB:
|
||||
ctx.tensors = (CAt, subA)
|
||||
ctx.tensor_states = (SCAt, state.idx)
|
||||
else:
|
||||
ctx.tensors = [None, None]
|
||||
ctx.tensor_states = (None, None)
|
||||
ctx.save_for_backward(None, None)
|
||||
|
||||
#clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
|
||||
clone_func = torch.clone
|
||||
return clone_func(output.view(output_shape))
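A hedged fp32 toy version of step 4 (the mixed-precision decomposition): activation columns holding values at or above the threshold are multiplied in full precision, everything else goes through the int8 path.

import torch

def decomposed_matmul_ref(A, W, threshold=6.0):
    # A: (batch, in_features), W: (out_features, in_features)
    outlier_cols = (A.abs() >= threshold).any(dim=0)
    A_inlier = A.masked_fill(outlier_cols, 0)                # this part would be int8-quantized
    out = A_inlier @ W.t()
    out += A[:, outlier_cols] @ W[:, outlier_cols].t()       # the fp16 outlier product (subA @ subB)
    return out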
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
req_gradA, req_gradB = ctx.req_grads
|
||||
CAt, subA = ctx.tensors
|
||||
SCAt, idx = ctx.tensor_states
|
||||
formatB = ctx.formatB
|
||||
state = ctx.state
|
||||
assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.'
|
||||
|
||||
if len(grad_output.shape) == 3:
|
||||
grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
|
||||
|
||||
grad_A = grad_B = None
|
||||
|
||||
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
|
||||
if req_gradB:
|
||||
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
|
||||
C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True)
|
||||
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
|
||||
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
|
||||
if state.threshold > 0.0 and subA is not None:
|
||||
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
|
||||
|
||||
if req_gradA:
|
||||
C32grad, Sgrad = F.transform(Cgrad, 'col32')
|
||||
if state.CxBt is None:
|
||||
state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
|
||||
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
|
||||
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
|
||||
|
||||
return grad_A, grad_B, None, None, None, None, None
|
||||
|
||||
|
||||
matmul = MatMul8bitLt.apply
|
||||
|
||||
|
||||
def matmul(A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0):
|
||||
state = state or MatmulLtState()
|
||||
if threshold > 0.0:
|
||||
state.threshold = threshold
|
||||
return MatMul8bitLt.apply(A, B, out, state)
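This wrapper is the public entry point re-exported as bnb.matmul; a hedged call sketch (assumes a CUDA device):

import torch
import bitsandbytes as bnb

x = torch.randn(8, 1024, dtype=torch.float16, device='cuda')
W = torch.randn(4096, 1024, dtype=torch.float16, device='cuda')
y = bnb.matmul(x, W, threshold=6.0)    # a fresh MatmulLtState is created and threshold set on it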
|
||||
|
|
@@ -6,6 +6,8 @@ lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')

try:
    lib.cadam32bit_g32
    lib.get_context.restype = ct.c_void_p
    lib.get_cusparse.restype = ct.c_void_p
    COMPILED_WITH_CUDA = True
except AttributeError:
    warn("The installed version of bitsandbytes was compiled without GPU support. "

File diff suppressed because one or more lines are too long
@@ -2,4 +2,4 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modules import StableEmbedding, Embedding
from .modules import StableEmbedding, Linear8bit, Linear8bitLt, Int8Params
@@ -3,14 +3,19 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import bitsandbytes as bnb

from typing import Optional
from typing import Union, Tuple, Any, Callable, Iterator, Set, Optional, overload, TypeVar, Mapping, Dict

from torch import Tensor
from torch import Tensor, device, dtype
from torch import nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F

from bitsandbytes.optim import GlobalOptimManager

T = TypeVar('T', bound='torch.nn.Module')

class StableEmbedding(torch.nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
                 max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
@@ -70,3 +75,118 @@ class Embedding(torch.nn.Embedding):
|
|||
self.norm_type, self.scale_grad_by_freq, self.sparse)
|
||||
|
||||
return emb
|
||||
|
||||
class Int8Params(torch.nn.Parameter):
|
||||
def __new__(cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None):
|
||||
cls.has_fp16_weights = has_fp16_weights
|
||||
cls.CB = None
|
||||
cls.SCB = None
|
||||
if data is None:
|
||||
data = torch.empty(0)
|
||||
return torch.Tensor._make_subclass(cls, data, requires_grad)
|
||||
|
||||
def cuda(self, device):
|
||||
if self.has_fp16_weights:
|
||||
return super().cuda(device)
|
||||
else:
|
||||
# we store the 8-bit row-major weight
# we convert this weight to the turing/ampere weight during the first inference pass
|
||||
B = self.data.contiguous().half().cuda(device)
|
||||
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
|
||||
del CBt
|
||||
del SCBt
|
||||
self.data = CB
|
||||
setattr(self, 'CB', CB)
|
||||
setattr(self, 'SCB', SCB)
|
||||
|
||||
return self
|
||||
|
||||
@overload
|
||||
def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ...,
|
||||
non_blocking: bool = ...) -> T:
|
||||
...
|
||||
|
||||
@overload
|
||||
def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
|
||||
...
|
||||
|
||||
@overload
|
||||
def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
|
||||
...
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
||||
|
||||
if device is not None and device.type == 'cuda' and self.data.device.type == 'cpu': return self.cuda(device)
|
||||
else:
|
||||
new_param = Int8Params(super().to(device=device, dtype=dtype, non_blocking=non_blocking), requires_grad=self.requires_grad, has_fp16_weights=self.has_fp16_weights)
|
||||
new_param.CB = self.CB
|
||||
new_param.SCB = self.SCB
|
||||
|
||||
return new_param
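A hedged sketch of how Int8Params behaves when moved to the GPU (assumes a CUDA device is available):

import torch
import bitsandbytes as bnb

W = torch.randn(4096, 1024)                                 # fp32/fp16 weight on CPU
p = bnb.nn.Int8Params(W, requires_grad=False, has_fp16_weights=False)
p = p.to('cuda')                                            # dispatches to .cuda(): runs double_quant once
print(p.dtype, p.CB is not None, p.SCB is not None)         # int8 data plus its scaling statistics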
|
||||
|
||||
|
||||
|
||||
class Linear8bitLt(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, threshold=0.0, index=None):
|
||||
super(Linear8bitLt, self).__init__(input_features, output_features, bias)
|
||||
self.state = bnb.MatmulLtState()
|
||||
self.index=index
|
||||
|
||||
self.state.threshold = threshold
|
||||
self.state.has_fp16_weights = has_fp16_weights
|
||||
if threshold > 0.0 and not has_fp16_weights:
|
||||
self.state.use_pool = True
|
||||
|
||||
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
|
||||
|
||||
def init_8bit_state(self):
|
||||
self.state.CB = self.weight.CB
|
||||
self.state.SCB = self.weight.SCB
|
||||
self.weight.CB = None
|
||||
self.weight.SCB = None
|
||||
|
||||
def forward(self, x):
|
||||
self.state.is_training = self.training
|
||||
|
||||
if self.weight.CB is not None: self.init_8bit_state()
|
||||
#assert not self.state.has_fp16_weights
|
||||
#if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None
|
||||
|
||||
out = bnb.matmul(x, self.weight, state=self.state)
|
||||
|
||||
if self.bias is not None:
|
||||
out += self.bias.unsqueeze(0).expand_as(out)
|
||||
|
||||
if not self.state.has_fp16_weights and self.state.CB is not None:
|
||||
# we converted 8-bit row major to turing/ampere format in the first inference pass
|
||||
# we no longer need the row-major weight
|
||||
del self.state.CB
|
||||
self.weight.data = self.state.CxB
|
||||
|
||||
return out
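A hedged drop-in usage sketch for the layer above (bias omitted so every tensor stays fp16; assumes a CUDA device):

import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(1024, 4096, bias=False,
                            has_fp16_weights=False, threshold=6.0)
layer = layer.cuda()                           # Int8Params.cuda() quantizes the weight here
x = torch.randn(8, 1024, dtype=torch.float16, device='cuda')
y = layer(x)                                   # first call builds the turing/ampere CxB layout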
|
||||
|
||||
class Linear8bit(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True, quant_type='vector', index=None, args=None, sparse_decomp=False):
|
||||
super(Linear8bit, self).__init__(input_features, output_features, bias)
|
||||
self.quant_type = quant_type
|
||||
self.index = index
|
||||
self.args = args
|
||||
self.iter = 0
|
||||
|
||||
def forward(self, x):
|
||||
self.iter += 1
|
||||
if self.iter % self.args.clip_freq == 0:
|
||||
with torch.no_grad():
|
||||
maxval, maxidx = torch.topk(torch.abs(self.weight.flatten()), k=self.args.clip_idx)
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
print('clip', maxval[-1].item())
|
||||
self.weight.clip_(-maxval[-1], maxval[-1])
|
||||
|
||||
|
||||
if self.args is not None:
|
||||
out = bnb.nn.functional.sparse_decomposed_linear8bit(x, self.weight, self.bias, qval=self.args.sparse_decomp_val, quant_type=self.args.quant_type)
|
||||
else:
|
||||
out = bnb.nn.functional.linear8bit(x, self.weight, self.bias, quant_type=self.args.quant_type)
|
||||
|
||||
return out
|
||||
|
|
874	csrc/kernels.cu

@@ -1737,10 +1737,884 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
|
|||
}
|
||||
}
|
||||
|
||||
template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols)
|
||||
{
|
||||
// 0. reset stats to -FLT_MAX
|
||||
// 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
|
||||
// 2. compute col max (per thread); store in smem due to register pressure
|
||||
// 3. compute row max (per block); store in smem to accumulate full global mem transaction
|
||||
// 4. store data via atomicMax
|
||||
|
||||
// each block loads TILE_COLS columns and TILE_ROWS rows
// after reading a tile the row counter increases by TILE_ROWS
// the col counter resets after reading TILE_COLS elements
|
||||
const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
|
||||
// col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
|
||||
const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
|
||||
const int base_idx = (base_row*cols) + base_col;
|
||||
const int items_per_load = ITEMS_PER_THREAD*THREADS;
|
||||
|
||||
typedef cub::BlockLoad<T, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadT;
|
||||
typedef cub::BlockReduce<float, THREADS> BlockRowReduce;
|
||||
typedef cub::BlockReduce<int, THREADS> BlockRowSum;
|
||||
typedef cub::BlockExchange<float, THREADS, ITEMS_PER_THREAD> BlockExchange;
|
||||
|
||||
__shared__ union {
|
||||
typename BlockExchange::TempStorage exchange;
|
||||
typename BlockRowReduce::TempStorage rowreduce;
|
||||
typename BlockRowSum::TempStorage rowsum;
|
||||
typename LoadT::TempStorage loadt;
|
||||
} temp_storage;
|
||||
|
||||
__shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS];
|
||||
__shared__ int smem_row_nnz_values[TILE_ROWS];
|
||||
//__shared__ float smem_col_absmax_values[ITEMS_PER_THREAD*THREADS];
|
||||
|
||||
half local_data[ITEMS_PER_THREAD];
|
||||
float local_data_fp32[ITEMS_PER_THREAD];
|
||||
float local_col_absmax_values[ITEMS_PER_THREAD];
|
||||
int local_row_nnz_count = 0;
|
||||
float row_absmax = -FLT_MAX;
|
||||
|
||||
// 0. reset stats to -FLT_MAX
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
{
|
||||
//smem_col_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX;
|
||||
smem_row_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX;
|
||||
smem_row_nnz_values[threadIdx.x + (j*THREADS)] = 0;
|
||||
}
|
||||
|
||||
#pragma unroll ITEMS_PER_THREAD
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
local_col_absmax_values[j] = -FLT_MAX;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;
|
||||
int i = base_idx;
|
||||
// we load row after row from the base_position
|
||||
// 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD)
|
||||
for(int row = 0; row < TILE_ROWS; row++)
|
||||
{
|
||||
if(base_row+row >= rows){ break; }
|
||||
local_row_nnz_count = 0;
|
||||
i = base_idx + ((row)*cols);
|
||||
// each thread gets data from the same column
|
||||
__syncthreads();
|
||||
LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, __float2half(0.0f));
|
||||
|
||||
#pragma unroll ITEMS_PER_THREAD
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
local_data[j] = fabsf(local_data[j]);
|
||||
|
||||
|
||||
if(SPARSE_DECOMP)
|
||||
#pragma unroll ITEMS_PER_THREAD
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
{
|
||||
if((float)local_data[j] >= nnz_threshold)
|
||||
{
|
||||
local_row_nnz_count += 1;
|
||||
local_data[j] = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// 2. compute col max (per thread); store in smem due to register pressure
|
||||
#pragma unroll ITEMS_PER_THREAD
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
// take the col max for this row
|
||||
// we use shared memory because register pressure is too high if we do this locally
|
||||
//smem_col_absmax_values[threadIdx.x + (j*THREADS)] = fmaxf(smem_col_absmax_values[threadIdx.x + (j*THREADS)], __half2float(local_data[j]));
|
||||
local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j]));
|
||||
|
||||
// 3. compute row max (per block); store in smem to accumulate full global mem transaction
|
||||
__syncthreads();
|
||||
|
||||
// this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units)
|
||||
#pragma unroll ITEMS_PER_THREAD
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
local_data_fp32[j] = local_data[j];
|
||||
|
||||
row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max());
|
||||
if(SPARSE_DECOMP)
|
||||
{
|
||||
__syncthreads();
|
||||
local_row_nnz_count = BlockRowSum(temp_storage.rowsum).Sum(local_row_nnz_count);
|
||||
}
|
||||
// we store the data temporarily in shared memory so we
|
||||
// can execute a full atomic block transaction into global memory later
|
||||
// we use a striped arrangement [0, 8, 16, 24, ..] for t0 for faster stores
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax;
|
||||
// each blockIdx.x process 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block
|
||||
smem_row_nnz_values[row] = local_row_nnz_count;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
}
|
||||
|
||||
// 4. store data via atomicMax
|
||||
// to store col data efficiently we need to rewrite the smem blocked data [0, 1, 2, 3...] for t0
// into a striped arrangement: [0, 8, 16, 24, ..] for t0
|
||||
__syncthreads();
|
||||
BlockExchange(temp_storage.exchange).BlockedToStriped(local_col_absmax_values);
|
||||
|
||||
#pragma unroll ITEMS_PER_THREAD
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
if(base_col+threadIdx.x+(j*THREADS) < cols)
|
||||
{
|
||||
float val = colStats[base_col+(threadIdx.x+(j*THREADS))];
|
||||
if(val < local_col_absmax_values[j])
|
||||
atomicMax(&colStats[base_col+(threadIdx.x+(j*THREADS))], local_col_absmax_values[j]);
|
||||
}
|
||||
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
if(base_row+threadIdx.x+(j*THREADS) < rows)
|
||||
{
|
||||
float val = rowStats[base_row+(threadIdx.x+(j*THREADS))];
|
||||
if(val < smem_row_absmax_values[threadIdx.x+(j*THREADS)])
|
||||
atomicMax(&rowStats[base_row+(threadIdx.x+(j*THREADS))], smem_row_absmax_values[threadIdx.x+(j*THREADS)]);
|
||||
}
|
||||
|
||||
if(SPARSE_DECOMP)
|
||||
if(threadIdx.x < TILE_ROWS)
|
||||
nnz_count_row[blockIdx.x*TILE_ROWS+threadIdx.x+1] = smem_row_nnz_values[threadIdx.x];
|
||||
|
||||
}
|
||||
|
||||
template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 0>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
|
||||
template __global__ void kgetColRowStats<half, 64, 4, 16, 64*4, 1>(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
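What the kernel computes, expressed as a hedged PyTorch-level reference (the tiling, atomics and the per-tile nnz layout are ignored):

import torch

def get_col_row_stats_ref(A, threshold=0.0):         # A: fp16 matrix of shape (rows, cols)
    absA = A.float().abs()
    nnz_per_row = None
    if threshold > 0.0:                               # SPARSE_DECOMP path
        outliers = absA >= threshold
        nnz_per_row = outliers.sum(dim=1)             # per-row outlier counts (the kernel stores these per tile)
        absA = absA.masked_fill(outliers, 0.0)        # outliers excluded from the statistics
    return absA.amax(dim=1), absA.amax(dim=0), nnz_per_row   # rowStats, colStats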
|
||||
|
||||
#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)
|
||||
|
||||
template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n)
|
||||
{
|
||||
|
||||
// Strategy: To dequantize we need to load col/row statistics. This can be very expensive
|
||||
// since different row/col stats need to be loaded with each thread.
|
||||
// (1, bad algorithm) Loading 32 items per thread would only incur 1 row load, but this increases register pressure
|
||||
// and would lead to low global load utilization.
|
||||
// (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do a lot of row loads
|
||||
// for each thread and this is duplicated by a factor of 32/num-cols-per-thread.
|
||||
// (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock.
|
||||
// This allows for efficient row/col loading from shared memory within the tile.
|
||||
// We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has
|
||||
// the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts
|
||||
// we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the
|
||||
// shared memory loads.
|
||||
|
||||
// data is in 32 column-tile major with tile width 32 columns and numRows rows
|
||||
// L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
|
||||
// L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3])
|
||||
// C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127) into register)
|
||||
// C2. Compute normalization values and store col values in register
|
||||
// S1. Store C1 into 16-bit output
|
||||
// S2. Store col/row statistics of new buffer in shared memory
|
||||
|
||||
// We allow for sub-tiles to span multiple col32 tiles. This is okay
|
||||
// since the items per thread only rely on a single column statistic.
|
||||
|
||||
|
||||
const int n_out = numRows*numCols;
|
||||
|
||||
int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1);
|
||||
// we have tiles of size numRows*32, thus col only increases every numRows
|
||||
// num_row_tiles is the tiles after which the column increases by 32
|
||||
// blockIdx.x is the index of the current tile
|
||||
int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32));
|
||||
// base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached
|
||||
int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS);
|
||||
|
||||
// SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS
|
||||
// subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD
|
||||
// Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROW*32/4 threads.
|
||||
// For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have
|
||||
// 1024*1024/(128*32) = 256 tiles
|
||||
// 256 tiles are 256*128*32/4 = 256*1024 threads
|
||||
|
||||
// 1. Figure out how index relates to the start of the sub-tile
|
||||
// 2. Each thread < SUBTILE_ROWS calculates row index
|
||||
// 3. Load striped and store in shared memory
|
||||
|
||||
int local_values[ITEMS_PER_THREAD];
|
||||
half local_output[ITEMS_PER_THREAD];
|
||||
float local_rowStats[ITEMS_PER_THREAD];
|
||||
__shared__ float smem_rowStats[SUBTILE_ROWS];
|
||||
|
||||
typedef cub::BlockLoad<int, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_DIRECT> LoadInt32;
|
||||
typedef cub::BlockExchange<int, THREADS, ITEMS_PER_THREAD> ExchangeInt32;
|
||||
__shared__ typename LoadInt32::TempStorage loadint32;
|
||||
__shared__ typename ExchangeInt32::TempStorage exchangeint32;
|
||||
|
||||
|
||||
// L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
|
||||
float colStat = col >= numCols ? 0.0f : colStats[col];
|
||||
// no block loads for rows for now -- keep it simple
|
||||
for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x)
|
||||
{
|
||||
// todo: is this global mem access slow due to overlaps or does the L1 cache work well here?
|
||||
int row = (base_row+j) % numRows; // wrap around
|
||||
// each warp accesses the same element, for four consecutive elements
// todo: update description about striped shared memory, it is not needed
// rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consecutive elements
|
||||
smem_rowStats[j] = rowStats[row];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
|
||||
// each block processes SUBTILE_ROWS*32 elements
|
||||
const int items_per_load = THREADS*ITEMS_PER_THREAD;
|
||||
const int rows_per_load = items_per_load/32;
|
||||
|
||||
int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile
|
||||
int row_offset = 0;
|
||||
// subtile_idx starts at base_row*32 plus the total offset of the full numRows*32 tiles already passed
|
||||
int subtile_start = (blockIdx.x/num_row_tiles)*(numRows*32) + (base_row*32);
|
||||
for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load)
|
||||
{
|
||||
int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset);
|
||||
int valid_items = valid_rows*32;
|
||||
if(valid_items <= 0) // the sub-tile might have more elements than the tile itself
|
||||
break;
|
||||
|
||||
// L2. Load data in warp-striped arrangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3])
|
||||
LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0);
|
||||
ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values);
|
||||
|
||||
#pragma unroll ITEMS_PER_THREAD
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j];
|
||||
|
||||
#pragma unroll ITEMS_PER_THREAD
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
local_output[j] = __float2half(local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat);
|
||||
//absmax_col = fmax(fabsf(local_output[j]), absmax_col);
|
||||
|
||||
// we store data in row major
|
||||
// to store data efficiently, we want to use block exchange: [0, 32, 64, 92] -> [0, 1, 2, 3]
|
||||
// so that each thread holds ITEMS_PER_THREAD consecutive items for each row
|
||||
// this way throughput into storage is increased by a factor of ~2x
|
||||
// for now we use a simple store
|
||||
#pragma unroll ITEMS_PER_THREAD
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
{
|
||||
int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols);
|
||||
if(outIdx< n_out && col < numCols)
|
||||
out[outIdx] = local_output[j];
|
||||
}
|
||||
|
||||
row_offset += rows_per_load;
|
||||
}
|
||||
}
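At the tensor level the dequantization above is (a hedged reference, ignoring the col32 tiling of the int32 input):

import torch

def dequant_mm_int32_fp16_ref(A_int32, row_stats, col_stats):
    scale = row_stats[:, None] * col_stats[None, :] * (1.0 / (127.0 * 127.0))  # MM_DEQUANT_CONST
    return (A_int32.float() * scale).half()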
|
||||
|
||||
|
||||
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols)
|
||||
{
|
||||
// assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD
|
||||
// Each thread reads the same column but multiple rows
|
||||
// Rows are loaded in shared memory and access is shared across the threadblock (broadcast)
|
||||
|
||||
// 0. Load row stats data into shared memory; load col stat (1 fixed per thread)
|
||||
// 1. Load data row by row (should be at least with TILE_SIZE = 512)
|
||||
// 2. quantize data with row/col stats
|
||||
// 3. Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance)
|
||||
|
||||
// each block loads TILE_COLs columns and TILE_ROW rows
|
||||
// after reading a tile the row counter increase by TILE_ROWS
|
||||
// the col counter reset after reading TILE_COL elements
|
||||
const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
|
||||
// col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
|
||||
const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
|
||||
const int base_idx = (base_row*cols) + base_col;
|
||||
const int items_per_load = ITEMS_PER_THREAD*THREADS;
|
||||
|
||||
typedef cub::BlockLoad<half, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadHalf;
|
||||
__shared__ typename LoadHalf::TempStorage loadhalf;
|
||||
typedef cub::BlockStore<char, THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_VECTORIZE> StoreInt8;
|
||||
__shared__ typename StoreInt8::TempStorage storeint8;
|
||||
|
||||
__shared__ float smem_row_stats[TILE_ROWS];
|
||||
__shared__ unsigned int smem_nnz_row_idx[TILE_ROWS];
|
||||
|
||||
half local_data[ITEMS_PER_THREAD];
|
||||
float local_col_stats[ITEMS_PER_THREAD];
|
||||
char local_quantized_data[ITEMS_PER_THREAD];
|
||||
|
||||
// 0. Load row stats data into shared memory; load col stat (1 fixed per thread)
|
||||
#pragma unroll ITEMS_PER_THREAD
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
if(base_col+(threadIdx.x*ITEMS_PER_THREAD) + j < cols)
|
||||
local_col_stats[j] = __fdividef(127.0f, colStats[base_col+(threadIdx.x*ITEMS_PER_THREAD)+j]);
|
||||
|
||||
for(int i = threadIdx.x; i < TILE_ROWS; i+=blockDim.x)
|
||||
{
|
||||
if(base_row + i < rows)
|
||||
smem_row_stats[i] = rowStats[base_row+i];
|
||||
|
||||
if(SPARSE_DECOMP)
|
||||
smem_nnz_row_idx[i] = nnz_block_ptr[(TILE_ROWS*blockIdx.x) + i];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// we load row after row from the base_position
|
||||
// 1. Load data row by row (should be at least with TILE_SIZE = 512)
|
||||
for(int row = 0; row < TILE_ROWS; row++)
|
||||
{
|
||||
if(base_row + row >= rows){ break; }
|
||||
int i = base_idx + (row*cols);
|
||||
int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col;
|
||||
|
||||
|
||||
LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f);
|
||||
float row_stat = __fdividef(127.0f, smem_row_stats[row]);
|
||||
|
||||
// 2. quantize data with row/col stats
|
||||
#pragma unroll ITEMS_PER_THREAD
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
{
|
||||
// we already pre-normalized the col/row stat:
|
||||
// what this does is float/absmax*127 = int8
|
||||
if(SPARSE_DECOMP)
|
||||
{
|
||||
if(fabsf((float)local_data[j]) >= threshold)
|
||||
{
|
||||
local_quantized_data[j] = 0;
|
||||
|
||||
int old_idx = atomicInc(&smem_nnz_row_idx[row], UINT_MAX);
|
||||
|
||||
rowidx[old_idx] = base_row+row;
|
||||
colidx[old_idx] = base_col+(threadIdx.x*ITEMS_PER_THREAD)+j;
|
||||
val[old_idx] = local_data[j];
|
||||
}
|
||||
else
|
||||
{
|
||||
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat));
|
||||
}
|
||||
}
|
||||
else
|
||||
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat));
|
||||
}
|
||||
|
||||
StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items);
|
||||
|
||||
// 2. quantize data with row/col stats
|
||||
#pragma unroll ITEMS_PER_THREAD
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
{
|
||||
// we already pre-normalized the col/row stat:
|
||||
// what this does is float/absmax*127 = int8
|
||||
local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*local_col_stats[j]));
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items);
|
||||
|
||||
}
|
||||
}
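A hedged tensor-level reference for the double quantization above: the same fp16 tile is written out twice, once scaled per row and once per column; with SPARSE_DECOMP the row-normalized copy has the outliers zeroed and they are emitted as a COO list instead.

import torch

def double_row_col_quant_ref(A, row_stats, col_stats, threshold=0.0):
    A32, coo = A.float(), None
    A_row = A32
    if threshold > 0.0:
        mask = A32.abs() >= threshold
        rowidx, colidx = mask.nonzero(as_tuple=True)
        coo = (rowidx, colidx, A[mask])               # outliers kept at full precision
        A_row = A32.masked_fill(mask, 0.0)            # zeroed only in the row-normalized output
    out_row = torch.round(A_row * (127.0 / row_stats)[:, None]).to(torch.int8)
    out_col = torch.round(A32 * (127.0 / col_stats)[None, :]).to(torch.int8)
    return out_row, out_col, coo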
|
||||
|
||||
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols)
|
||||
{
|
||||
|
||||
// 0. Load data into 32*32 shared memory tiles
|
||||
// 1. transpose / reorder in shared memory
|
||||
// 2. store
|
||||
|
||||
// COL32 FORMAT:
|
||||
// rows*32 tiles
|
||||
|
||||
// TURING FORMAT:
|
||||
// 8*32 tiles with 4*4 subtiles
|
||||
// the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
|
||||
// the subsequent 4*4 subtiles are for all odd rows; if some rows/columns are empty the values are zero
|
||||
// the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
|
||||
// the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
|
||||
// index increases by 32
|
||||
|
||||
// AMPERE FORMAT:
|
||||
// 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
|
||||
// row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
|
||||
// the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
|
||||
|
||||
|
||||
// To have efficient loads and stores if we transpose we need 128 consecutive bytes which at 1 byte are 128 values
|
||||
// As such we need:
|
||||
// at least 32*4 shared memory tiles for col32; preferably 32*32
|
||||
// at least 32*6 shared memory tiles for col32_ampere: preferably 32*32
|
||||
// at least 32*8 shared memory tiles for col4_turing: preferably 32*32
|
||||
// for efficient loading of row major we need to load 128 elements and repeat this 32 items
|
||||
// this would imply a 32x128 shared memory tile -> 4kb
|
||||
// It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb
|
||||
// we have 64k shared mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy
// for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough
// register pressure should be low with: 8 registers from local memory per block and 64 registers per SM
|
||||
//
|
||||
// to make the shared memory work with that occupancy we might need to union the block loads/stores
|
||||
|
||||
// each block loads TILE_COLs columns and TILE_ROW rows
|
||||
// after reading a tile the row counter increase by TILE_ROWS
|
||||
// the col counter reset after reading TILE_COL elements
|
||||
const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS;
|
||||
// col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached
|
||||
const int base_col = (blockIdx.x*TILE_COLS) % tiledCols;
|
||||
const int base_idx = (base_row*cols) + base_col;
|
||||
|
||||
// we load 128 bytes per warp with
|
||||
// 32 rows for transposes that fill col32 types
|
||||
// so that we can have contiguous stores
|
||||
__shared__ char smem_data[32*33*ITEMS_PER_THREAD];
|
||||
char local_data[ITEMS_PER_THREAD];
|
||||
typedef cub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange;
|
||||
__shared__ typename BlockExchange::TempStorage temp_storage;
|
||||
|
||||
// we load row after row from the base_position
|
||||
// Load data row by row
|
||||
int warps = blockDim.x/32;
|
||||
int warp_id = threadIdx.x/32;
|
||||
int warp_lane = threadIdx.x % 32;
|
||||
int offset = 0;
|
||||
|
||||
int smem_row = 0;
|
||||
// each warp loads one row of 128 bytes
|
||||
for(int row = warp_id; row < TILE_ROWS; row+=warps)
|
||||
{
|
||||
int i = base_idx + (row*cols);
|
||||
// we load up to 128 bytes/items per load
|
||||
int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col;
|
||||
|
||||
// 0. Load data into 32*32 shared memory tiles
|
||||
if(base_row + row < rows)
|
||||
{
|
||||
#pragma unroll ITEMS_PER_THREAD
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
{
|
||||
int col_idx = warp_lane+(j*32);
|
||||
if(col_idx < valid_items)
|
||||
local_data[j] = A[i+col_idx];
|
||||
else
|
||||
local_data[j] = 0;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
#pragma unroll ITEMS_PER_THREAD
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
local_data[j] = 0;
|
||||
}
|
||||
|
||||
if(TRANSPOSE)
|
||||
{
|
||||
#pragma unroll ITEMS_PER_THREAD
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
{
|
||||
int local_col = (32*j)+warp_lane;
|
||||
//int local_row = row;
|
||||
// store as 256x32
|
||||
smem_data[(local_col*33) + row] = local_data[j];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// treat smem as 32x256, that is 32 rows and 256 columns
|
||||
#pragma unroll ITEMS_PER_THREAD
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j];
|
||||
}
|
||||
|
||||
|
||||
|
||||
smem_row += warps;
|
||||
|
||||
// 1. transpose / reorder in shared memory
|
||||
if(smem_row % 32 == 0)
|
||||
{
|
||||
smem_row = 0;
|
||||
__syncthreads();
|
||||
|
||||
for(int subrow = warp_id; subrow < 32; subrow+=warps)
|
||||
{
|
||||
for(int j = 0; j < ITEMS_PER_THREAD; j++)
|
||||
{
|
||||
|
||||
switch(FORMAT)
|
||||
{
|
||||
case COL32:
|
||||
if(TRANSPOSE)
|
||||
{
|
||||
// data lies in shared memory in the following way:
|
||||
// row0 [col0 col1 ... col31]
|
||||
// row1 [col0 col1 ... col31]
|
||||
// ...
|
||||
//
|
||||
// As such we read consecutive entries with 256 threads (8 rows x 32 columns)
|
||||
// as j increase, the row increase by a factor of 8
|
||||
// We load 8 rows per subrow loop, and subrow increase by 8 per loop
|
||||
// so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8
|
||||
const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
|
||||
const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
|
||||
//const int local_row = warp_id; // each warp_id is one row
|
||||
//const int block_row = base_col; // block offset for row
|
||||
//const int local_col = warp_lane
|
||||
//const int global_col = base_row; // block offset for col
|
||||
if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
|
||||
{
|
||||
// each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
|
||||
char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
|
||||
|
||||
// each 32 columns we have new tile
|
||||
// each tile has size outRows*32 and base_row is done in increments of 32
|
||||
offset = base_row*outRows;
|
||||
out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
|
||||
{
|
||||
offset = (base_col/32)*(32*rows);
|
||||
char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
|
||||
out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data;
|
||||
}
|
||||
}
|
||||
break;
|
||||
case COL_TURING:
|
||||
// TURING FORMAT:
|
||||
// 8*32 tiles with 4*4 subtiles
|
||||
// the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements)
|
||||
// the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero
|
||||
// the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32])
|
||||
// the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column
|
||||
// index increases by 32
|
||||
//
|
||||
// [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...]
|
||||
if(TRANSPOSE)
|
||||
{
|
||||
const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
|
||||
const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
|
||||
//const int local_row = warp_id; // each warp_id is one row
|
||||
//const int block_row = base_col; // block offset for row
|
||||
//const int local_col = warp_lane
|
||||
//const int global_col = base_row; // block offset for col
|
||||
if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
|
||||
{
|
||||
// each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
|
||||
char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
|
||||
|
||||
// each 32 columns we have new tile
|
||||
// each tile has size 8*32 = 256 elements offset
|
||||
// for each row offset of 8 we increase the tile first
|
||||
// after all rows are exhausted, we increase the col
|
||||
int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows
|
||||
|
||||
// we increase by row_tile_column every 32 columns
|
||||
// base_row increase in increments of 32
|
||||
//int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements
|
||||
//int col_offset = (base_row/32)*row_tile_column;
|
||||
// -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
|
||||
// 256*outRows/8*base_row/32 = outRows*base_row
|
||||
int col_offset = outRows*base_row;
|
||||
|
||||
offset = row_offset+col_offset;
|
||||
|
||||
// since we process even number of rows with each j (8) and with each subrow (8j) we can determine
|
||||
// odd or even rows with the warp_id (each warp processes one row)
|
||||
// the col is warp_lane (max 32 columns per row) and the row warp_id
|
||||
if(warp_id % 2 == 1)
|
||||
// odd
|
||||
offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2);
|
||||
else
|
||||
// even
|
||||
offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2);
|
||||
|
||||
out[offset] = data;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
|
||||
{
|
||||
char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
|
||||
// set offset designates the tile offset among the 8*32 tiles
|
||||
// we first increase rows and then columns. Since we load 128 columns at once
|
||||
// we increase the offset by outRows*32 every 32 columns
|
||||
// additionally, we increase the offset by 8*32=256 every 8 rows
|
||||
offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile)
|
||||
// first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd
|
||||
// each of these has 32 values in total for 32*4 = 128 as offset if odd
|
||||
// every set of 4 columns increases the total offset by 16
|
||||
// each even row increases the offset by 4: row 2 is offset by 4, row 4 by 8, etc., so (subrow/2)*4 = subrow*2
// this happens every 8 rows anew (subrow % 8)
|
||||
// one writes 4 columns at once that is (col % 4) for the particular index in the subtile
|
||||
int subcol = warp_lane;
|
||||
|
||||
// add local offset (4x4 sub-tile)
|
||||
if(subrow % 2 == 1)
|
||||
// odd
|
||||
offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2);
|
||||
else
|
||||
// even
|
||||
offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2);
|
||||
|
||||
out[offset] = data;
|
||||
}
|
||||
}
|
||||
break;
|
||||
case COL_AMPERE:
|
||||
// AMPERE FORMAT:
|
||||
// 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows:
|
||||
// row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
|
||||
// the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32]
|
||||
if(TRANSPOSE)
|
||||
{
|
||||
const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j
|
||||
const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps)
|
||||
//const int local_row = warp_id; // each warp_id is one row
|
||||
//const int block_row = base_col; // block offset for row
|
||||
//const int local_col = warp_lane
|
||||
//const int global_col = base_row; // block offset for col
|
||||
if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows))
|
||||
{
|
||||
// each row has 32 columns and is offset by 1 to prevent bank conflict during storage into smem
|
||||
char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane];
|
||||
|
||||
// each 32 columns we have new tile
|
||||
// each tile has size 32*32 = 1024 elements offset
|
||||
// for each row offset of 32 we increase the tile first
|
||||
// after all rows are exhausted, we increase the col
|
||||
int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows
|
||||
|
||||
// we increase by row_tile_column every 32 columns
|
||||
// base_row increase in increments of 32
|
||||
//int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements
|
||||
//int col_offset = (base_row/32)*row_tile_column;
|
||||
// -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
|
||||
// 1024*outRows/32*base_row/32 = outRows*base_row
|
||||
int col_offset = outRows*base_row;
|
||||
|
||||
offset = row_offset+col_offset;
|
||||
|
||||
|
||||
// same as in the non-transpose case (see below)
|
||||
// the difference is that now rows = cols
|
||||
// in this case warp_id = subrow
|
||||
|
||||
// [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
|
||||
// subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
|
||||
// subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
|
||||
// every 2 rows, the offset increases by two [0, 1, 8, 9...]
|
||||
// every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
|
||||
int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset
|
||||
int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2);
|
||||
|
||||
// global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane
|
||||
out[offset + (ampere_row*32) + warp_lane] = data;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols))
|
||||
{
|
||||
char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane];
|
||||
|
||||
// set offset designates the tile offset among the 32*32 tiles
|
||||
// we first increase rows and then columns. Since we load 128 columns at once
|
||||
// we increase the offset by outRows*32 every 32 columns
|
||||
// additionally, we increase the offset by 32*32=1024 every 32 rows
|
||||
offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile)
|
||||
|
||||
// [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]...
|
||||
// subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc
|
||||
// subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row
|
||||
// every 2 rows, the offset increases by two [0, 1, 8, 9...]
|
||||
// every 2 rows, the row index increase by 8 [0, 1, 8, 9...]
|
||||
int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2);
|
||||
|
||||
// global offset + row with 32 cols each + 32 cols per j + col_idx
|
||||
out[offset + (local_row*32) + warp_lane] = data;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define C 1.0f/127.0f
|
||||
#define MAX_SPARSE_COUNT 32
|
||||
#define SMEM_SIZE 8*256
|
||||
template <typename T, int SPMM_ITEMS, int BITS>
|
||||
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
|
||||
{
|
||||
|
||||
// 0. load balancing: We process rows with the most columns first (count_vec) and we process one row per block
// If a block finishes, the next one is scheduled. Since the last blocks likely have fewer
// elements they finish faster, "filling up" the gaps left by larger blocks
|
||||
|
||||
// without tensor cores
|
||||
// 1. use rowidx_length to find what to load (as many blocks as there are rows)
|
||||
// 2. Load A into registers
|
||||
// 3. each warp loads all required rows of B but each warp is offset by k
|
||||
// 4. Do mma operations that accumulate into registers
|
||||
// 5. Each warp stores its output row into matrix C
|
||||
|
||||
const int count = max_count[blockIdx.x];
|
||||
const int local_max_idx = max_idx[blockIdx.x];
|
||||
const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1];
|
||||
const int local_row_idx = rowidx[offset];
|
||||
|
||||
const int warp_id = threadIdx.x / 32;
|
||||
const int warp_idx = threadIdx.x % 32;
|
||||
const int warp_offset = (warp_id*32)*SPMM_ITEMS;
|
||||
const int num_items = BITS == 8 ? 8 : 8;
|
||||
int idx_col_B = warp_offset;
|
||||
int local_idx_col_B_offset = 0;
|
||||
|
||||
half local_valA[MAX_SPARSE_COUNT];
|
||||
int local_colidxA[MAX_SPARSE_COUNT];
|
||||
half local_valC[SPMM_ITEMS];
|
||||
T local_valsB[num_items];
|
||||
half local_valOut[num_items];
|
||||
// 128 byte loads per warp == 4 bytes per thread
|
||||
|
||||
// 2. Load A into registers
|
||||
for(int j = 0; j < MAX_SPARSE_COUNT; j++)
|
||||
{
|
||||
local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f);
|
||||
local_colidxA[j] = j < count ? colidx[offset+j] : 0;
|
||||
}
|
||||
|
||||
// each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=8192
|
||||
// we expect each warp to be SPMM_ITEMS*32 apart
|
||||
// we have a total of 128 bytes for the bank with a bank size of 4 bytes
|
||||
// added 3 bytes = 6 values between warps should reduce bank conflicts
|
||||
__shared__ half smem_dequant_stats[SMEM_SIZE];
|
||||
|
||||
|
||||
while(idx_col_B < colsB)
|
||||
{
|
||||
|
||||
if(dequant_stats != NULL)
|
||||
{
|
||||
for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x)
|
||||
if((idx_col_B+i-local_idx_col_B_offset) < colsB)
|
||||
smem_dequant_stats[i] = __ldg(&dequant_stats[idx_col_B+i-local_idx_col_B_offset]);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
#pragma unroll SPMM_ITEMS
|
||||
for(int j = 0; j < SPMM_ITEMS; j++)
|
||||
local_valC[j] = 0.0f;
|
||||
|
||||
#pragma unroll
|
||||
for(int i = 0; i < count; i++)
|
||||
{
|
||||
// 3. each warp loads all required rows of B but each warp is offset by k
|
||||
int row_offset = colsB*local_colidxA[i];
|
||||
|
||||
#pragma unroll SPMM_ITEMS
|
||||
for(int j = 0; j < SPMM_ITEMS; j+=num_items)
|
||||
{
|
||||
// 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes is reached
|
||||
int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j;
|
||||
if(idx >= colsB){ break; }
|
||||
//printf("%i %i\n", (row_offset+idx) % num_items, row_offset+idx);
|
||||
if((idx+num_items < colsB))
|
||||
{
|
||||
if(BITS == 8)
|
||||
reinterpret_cast<float2(&)[num_items]>(local_valsB)[0] = reinterpret_cast<float2*>(B)[(row_offset+ idx)/num_items];
|
||||
else
|
||||
reinterpret_cast<float4(&)[num_items]>(local_valsB)[0] = reinterpret_cast<float4*>(B)[(row_offset+ idx)/num_items];
|
||||
}
|
||||
else
|
||||
{
|
||||
#pragma unroll num_items
|
||||
for(int k = 0; k < num_items; k++)
|
||||
if(idx+k < colsB)
|
||||
local_valsB[k] = B[row_offset+idx+k];
|
||||
else
|
||||
local_valsB[k] = 0.0f;
|
||||
}
|
||||
#pragma unroll num_items
|
||||
for(int k = 0; k < num_items; k++)
|
||||
{
|
||||
//if((float)local_valsB[k] != 0.0)
|
||||
// printf("%f %i %i %i\n", (float)local_valsB[k], k, idx, colsB);
|
||||
if(BITS == 8 && dequant_stats != NULL)
|
||||
// we do texture cache reads (__ldg) on dequant_stats which should be super fast
|
||||
{
|
||||
float valB = local_valsB[k];
|
||||
float valA = local_valA[i];
|
||||
if(valB != 0.0 && valA != 0.0)
|
||||
local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*C*valB*valA;
|
||||
}
|
||||
else
|
||||
local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int idx_row_C = (colsB*local_row_idx);
|
||||
|
||||
#pragma unroll SPMM_ITEMS
|
||||
for(int j = 0; j < SPMM_ITEMS; j+=num_items)
|
||||
{
|
||||
//int idx_col_C = idx_col_B + (32*j) + warp_idx;
|
||||
int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j;
|
||||
int idx_val = idx_col_C + idx_row_C;
|
||||
|
||||
if(idx_col_C +num_items < colsB)
|
||||
{
|
||||
|
||||
// load outputs to do inplace addition
|
||||
reinterpret_cast<float4(&)[num_items/4]>(local_valOut)[0] = reinterpret_cast<float4*>(out)[idx_val/num_items];
|
||||
|
||||
#pragma unroll num_items
|
||||
for(int k = 0; k < num_items; k++)
|
||||
local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k];
|
||||
|
||||
reinterpret_cast<float4*>(out)[idx_val/num_items] = reinterpret_cast<float4(&)[num_items]>(local_valC)[j/num_items];
|
||||
}
|
||||
else
|
||||
{
|
||||
#pragma unroll num_items
|
||||
for(int k = 0; k < num_items; k++)
|
||||
if(idx_col_C + k < colsB)
|
||||
out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k];
|
||||
}
|
||||
}
|
||||
|
||||
idx_col_B += blockDim.x*SPMM_ITEMS;
|
||||
local_idx_col_B_offset += blockDim.x*SPMM_ITEMS;
|
||||
}
|
||||
}
|
||||
|
||||
//==============================================================
|
||||
// TEMPLATE DEFINITIONS
|
||||
//==============================================================
|
||||
|
||||
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
|
||||
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
|
||||
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
|
||||
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
|
||||
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
|
||||
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
|
||||
|
||||
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
|
||||
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
|
||||
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
|
||||
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
|
||||
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
|
||||
template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
|
||||
|
||||
template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n);
|
||||
|
||||
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
|
||||
template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
|
||||
|
||||
template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
|
||||
template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x);
|
||||
|
||||
|
|
|
@ -106,6 +106,18 @@ template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileCl
|
|||
|
||||
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);
|
||||
|
||||
|
||||
template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
|
||||
|
||||
template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(
|
||||
int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
|
||||
half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n);
|
||||
|
||||
template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
|
||||
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
|
||||
|
||||
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
|
|
csrc/ops.cu
|
@ -8,6 +8,7 @@
|
|||
#include <cub/device/device_scan.cuh>
|
||||
#include <limits>
|
||||
#include <BinSearch.h>
|
||||
#include <cassert>
|
||||
#include <common.h>
|
||||
|
||||
|
||||
|
@ -188,11 +189,416 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
|
|||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
|
||||
{
|
||||
const int falpha = 1;
|
||||
const int fbeta = 0;
|
||||
const void * alpha = &falpha;
|
||||
const void * beta = &fbeta;
|
||||
cublasStatus_t status;
|
||||
|
||||
status = cublasGemmEx(context->m_handle,
|
||||
transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
|
||||
transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
|
||||
m, n, k,
|
||||
alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta,
|
||||
C, CUDA_R_32I, ldc,
|
||||
CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
|
||||
|
||||
if (status != CUBLAS_STATUS_SUCCESS)
|
||||
{
|
||||
std::cout << "CUBLAS ERROR: Status " << status << std::endl;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
|
||||
long long int strideA, long long int strideB, long long int strideC, int batchCount)
|
||||
{
|
||||
const int falpha = 1;
|
||||
const int fbeta = 0;
|
||||
const void * alpha = &falpha;
|
||||
const void * beta = &fbeta;
|
||||
cublasStatus_t status;
|
||||
|
||||
//cout << transposeA << transposeB << endl;
|
||||
//printf("%i %i %i\n", m,n,k);
|
||||
//printf("%i %i %i\n", lda,ldb,ldc);
|
||||
//printf("%i %i %i\n", strideA, strideB, strideC);
|
||||
//printf("%i\n", batchCount);
|
||||
|
||||
status = cublasGemmStridedBatchedEx(context->m_handle,
|
||||
transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
|
||||
transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
|
||||
m, n, k,
|
||||
alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta,
|
||||
C, CUDA_R_32I, ldc, (long long int)strideC, batchCount,
|
||||
CUDA_R_32I, CUBLAS_GEMM_DEFAULT);
|
||||
|
||||
if (status != CUBLAS_STATUS_SUCCESS)
|
||||
{
|
||||
std::cout << "CUBLAS ERROR: Status " << status << std::endl;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
int roundoff(int v, int d) {
|
||||
return (v + d - 1) / d * d;
|
||||
}
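// For example, roundoff(10, 8) == 16 and roundoff(16, 8) == 16.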
|
||||
|
||||
|
||||
template<int ORDER> cublasLtOrder_t get_order()
|
||||
{
|
||||
switch(ORDER)
|
||||
{
|
||||
case ROW:
|
||||
return CUBLASLT_ORDER_ROW;
|
||||
break;
|
||||
case COL:
|
||||
return CUBLASLT_ORDER_COL;
|
||||
break;
|
||||
case COL32:
|
||||
return CUBLASLT_ORDER_COL32;
|
||||
break;
|
||||
case COL_TURING:
|
||||
return CUBLASLT_ORDER_COL4_4R2_8C;
|
||||
break;
|
||||
case COL_AMPERE:
|
||||
return CUBLASLT_ORDER_COL32_2R_4R4;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template cublasLtOrder_t get_order<ROW>();
|
||||
template cublasLtOrder_t get_order<COL>();
|
||||
template cublasLtOrder_t get_order<COL32>();
|
||||
template cublasLtOrder_t get_order<COL_TURING>();
|
||||
template cublasLtOrder_t get_order<COL_AMPERE>();
|
||||
|
||||
|
||||
template<int ORDER> int get_leading_dim(int dim1, int dim2)
|
||||
{
|
||||
switch(ORDER)
|
||||
{
|
||||
case ROW:
|
||||
return dim2;
|
||||
break;
|
||||
case COL:
|
||||
return dim1;
|
||||
break;
|
||||
case COL32:
|
||||
// 32*row tiles
|
||||
return dim1*32;
|
||||
break;
|
||||
case COL_TURING:
|
||||
return 32*roundoff(dim1, 8);
|
||||
break;
|
||||
case COL_AMPERE:
|
||||
// 32*32 tiles
|
||||
return 32*roundoff(dim1, 32);
|
||||
break;
|
||||
}
|
||||
}
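// For example, for a hypothetical 10 x 256 matrix (dim1=10, dim2=256) the leading
// dimensions come out as: ROW -> 256, COL -> 10, COL32 -> 10*32 = 320,
// COL_TURING -> 32*roundoff(10, 8) = 512, COL_AMPERE -> 32*roundoff(10, 32) = 1024.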
|
||||
|
||||
template int get_leading_dim<ROW>(int dim1, int dim2);
|
||||
template int get_leading_dim<COL>(int dim1, int dim2);
|
||||
template int get_leading_dim<COL32>(int dim1, int dim2);
|
||||
|
||||
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2)
|
||||
{
|
||||
|
||||
cublasLtOrder_t orderA = get_order<SRC>();
|
||||
cublasLtOrder_t orderOut = get_order<TARGET>();
|
||||
int ldA = get_leading_dim<SRC>(dim1, dim2);
|
||||
int ldOut = get_leading_dim<TARGET>(dim1, dim2);
|
||||
|
||||
cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL;
|
||||
cublasLtMatrixTransformDesc_t A2Out_desc = NULL;
|
||||
cublasOperation_t opTranspose = CUBLAS_OP_T;
|
||||
float transformAlpha = 1.0f, transformBeta = 0.0f;
|
||||
|
||||
|
||||
if(DTYPE == 8)
|
||||
{
|
||||
checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, dim1, dim2, ldA));
|
||||
checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ldOut));
|
||||
}
|
||||
else if(DTYPE == 32)
|
||||
{
|
||||
checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_32I, dim1, dim2, ldA));
|
||||
checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32I, dim1, dim2, ldOut));
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE);
|
||||
}
|
||||
|
||||
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA)));
|
||||
checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut)));
|
||||
|
||||
checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, CUDA_R_32F));
|
||||
|
||||
if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); }
|
||||
|
||||
checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0));
|
||||
|
||||
if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc));
|
||||
if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc));
|
||||
if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc));
|
||||
}
|
||||
|
||||
template void transform<int8_t, ROW, COL, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
|
||||
template void transform<int8_t, ROW, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
|
||||
template void transform<int8_t, ROW, COL32, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
|
||||
template void transform<int32_t, ROW, COL32, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
|
||||
template void transform<int8_t, ROW, COL_TURING, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
|
||||
template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
|
||||
template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
|
||||
template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
|
||||
|
||||
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
{
|
||||
int has_error = 0;
|
||||
cublasLtMatmulDesc_t matmulDesc = NULL;
|
||||
cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL;
|
||||
cublasOperation_t opT = CUBLAS_OP_T;
|
||||
cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO;
|
||||
cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32;
|
||||
cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C;
|
||||
cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4;
|
||||
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda));
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb));
|
||||
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
|
||||
if(FORMATB == COL_TURING)
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing)));
|
||||
else
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere)));
|
||||
|
||||
if(DTYPE_OUT == 32)
|
||||
{
|
||||
has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I));
|
||||
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc));
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
|
||||
int alpha = 1, beta = 0;
|
||||
has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0));
|
||||
}
|
||||
else
|
||||
{
|
||||
has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F));
|
||||
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc));
|
||||
has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
|
||||
if(!SCALE_ROWS)
|
||||
{
|
||||
float alpha = 1.0f, beta = 0.0f;
|
||||
has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
|
||||
}
|
||||
else
|
||||
{
|
||||
has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec)));
|
||||
has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc));
|
||||
if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc));
|
||||
if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc));
|
||||
if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc));
|
||||
if(has_error == 1)
|
||||
printf("error detected");
|
||||
|
||||
return has_error;
|
||||
}
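// For illustration: igemmlt<COL_TURING, 32, 0> computes C(int32, COL32 layout) =
// A(int8, COL32) * B(int8, COL4_4R2_8C)^T with alpha=1, beta=0; the DTYPE_OUT==8
// variants write int8 outputs, and with SCALE_ROWS the row_scale vector is passed
// as a device-side alpha vector (beta fixed to zero by the pointer mode).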
|
||||
|
||||
int fill_up_to_nearest_multiple(int value, int multiple)
|
||||
{
|
||||
return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
|
||||
}
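// For example, fill_up_to_nearest_multiple(100, 32) == 128 and fill_up_to_nearest_multiple(64, 32) == 64.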
|
||||
|
||||
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols)
|
||||
{
|
||||
int threads = 512;
|
||||
int tileCols = fill_up_to_nearest_multiple(numCols, 32);
|
||||
int n = numRows*tileCols;
|
||||
int subtile_rows = 128;
|
||||
int tilesize = 32*subtile_rows;
|
||||
int num_blocks = numRows/subtile_rows;
|
||||
num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
|
||||
num_blocks = num_blocks*(tileCols/32);
|
||||
assert(threads <= tilesize);
|
||||
|
||||
//cout << num_blocks << " blocks" << endl;
|
||||
|
||||
kdequant_mm_int32_fp16<4, 128, 512><<<num_blocks, threads>>>(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols, tileCols, n);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
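// Worked example (illustrative numbers): for numRows=500, numCols=100 this gives
// tileCols=128, n=500*128=64000, ceil(500/128)=4 row subtiles and 128/32=4 column
// tiles, i.e. num_blocks=16 blocks of 512 threads each.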
|
||||
|
||||
#define STATS_THREADS 64
|
||||
#define STATS_ITEMS 4
|
||||
#define STATS_ROWS 16
|
||||
void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
|
||||
{
|
||||
int tile_cols = STATS_THREADS*STATS_ITEMS;
|
||||
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
|
||||
int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS);
|
||||
int num_blocks = (tiledCols/tile_cols) * (tiledRows/STATS_ROWS);
|
||||
|
||||
if(nnz_threshold == 0.0)
|
||||
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 0><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
|
||||
else if(nnz_threshold != 0.0)
|
||||
kgetColRowStats<half, STATS_THREADS, STATS_ITEMS, STATS_ROWS, STATS_THREADS*STATS_ITEMS, 1><<<num_blocks, STATS_THREADS>>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
|
||||
}
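// Worked example (illustrative numbers): for rows=512, cols=1000 the tile is
// 256 columns (64 threads * 4 items) by 16 rows, so tiledCols=1024, tiledRows=512
// and num_blocks = (1024/256)*(512/16) = 4*32 = 128.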
|
||||
|
||||
void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols)
|
||||
{
|
||||
int threads = 64;
|
||||
int items_per_thread = 4;
|
||||
int tile_cols = threads*items_per_thread;
|
||||
int tile_rows = 16;
|
||||
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
|
||||
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
|
||||
int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
|
||||
|
||||
//cout << cols << " " << tiledCols << " " << tiledRows << endl;
|
||||
//cout << "num blocks " << num_blocks << endl;
|
||||
|
||||
//cout << A << " " << out_col_normed << endl;
|
||||
if(threshold > 0.0f)
|
||||
kDoubleRowColQuant<64, 4, 16, 64*4, 1><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
|
||||
else
|
||||
kDoubleRowColQuant<64, 4, 16, 64*4, 0><<<num_blocks, threads>>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols);
|
||||
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols)
|
||||
{
|
||||
int threads = 256;
|
||||
int items_per_thread = 8;
|
||||
// we load 128 column values per warp
|
||||
int tile_cols = 32*items_per_thread;
|
||||
int tile_rows = 32;
|
||||
int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols);
|
||||
int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows);
|
||||
int num_blocks = (tiledCols/tile_cols) * (tiledRows/tile_rows);
|
||||
int outCols = fill_up_to_nearest_multiple(cols, 32);
|
||||
int outRows = fill_up_to_nearest_multiple(rows, 32);
|
||||
if(FORMAT == COL_TURING)
|
||||
{
|
||||
if(TRANSPOSE)
|
||||
outRows = fill_up_to_nearest_multiple(cols, 8);
|
||||
else
|
||||
outRows = fill_up_to_nearest_multiple(rows, 8);
|
||||
}
|
||||
else if(FORMAT == COL_AMPERE)
|
||||
{
|
||||
if(TRANSPOSE)
|
||||
outRows = fill_up_to_nearest_multiple(cols, 32);
|
||||
else
|
||||
outRows = fill_up_to_nearest_multiple(rows, 32);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(TRANSPOSE)
|
||||
{
|
||||
outCols = fill_up_to_nearest_multiple(rows, 32);
|
||||
outRows = cols;
|
||||
}
|
||||
}
|
||||
|
||||
//cout << cols << " " << tiledCols << " " << tiledRows << " " << outCols << endl;
|
||||
//cout << "num blocks " << num_blocks << endl;
|
||||
|
||||
//cout << A << " " << out_col_normed << endl;
|
||||
kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<<num_blocks, threads>>>(A, out, rows, cols, tiledCols, outRows, outCols);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
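// For example, with rows=100 and cols=128 (no transpose): COL32 pads both
// dimensions to multiples of 32 (outRows=128, outCols=128), COL_TURING pads the
// row count to a multiple of 8 (outRows=104), and COL_AMPERE to a multiple of 32
// (outRows=128).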
|
||||
|
||||
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
|
||||
{
|
||||
|
||||
cusparseSpMatDescr_t descA;
|
||||
cusparseDnMatDescr_t descB, descC;
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
void *dBuffer = NULL;
|
||||
size_t bufferSize = 0;
|
||||
|
||||
CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz,
|
||||
A_rowidx, A_colidx, A_vals,
|
||||
CUSPARSE_INDEX_32I,
|
||||
CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) );
|
||||
// Create dense matrix C
|
||||
CHECK_CUSPARSE( cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C,
|
||||
CUDA_R_16F, CUSPARSE_ORDER_ROW) );
|
||||
// Create dense matrix B
|
||||
if(transposed_B)
|
||||
{
|
||||
int tmp = A_cols;
|
||||
A_cols = B_cols;
|
||||
B_cols = tmp;
|
||||
}
|
||||
|
||||
CHECK_CUSPARSE( cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B,
|
||||
CUDA_R_16F, CUSPARSE_ORDER_ROW) );
|
||||
// allocate an external buffer if needed
|
||||
CHECK_CUSPARSE( cusparseSpMM_bufferSize(
|
||||
handle,
|
||||
CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
&alpha, descA, descB, &beta, descC, CUDA_R_32F,
|
||||
CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize) );
|
||||
CUDA_CHECK_RETURN( cudaMalloc(&dBuffer, bufferSize) );
|
||||
|
||||
// execute SpMM
|
||||
CHECK_CUSPARSE( cusparseSpMM(handle,
|
||||
CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE,
|
||||
&alpha, descA, descB, &beta, descC, CUDA_R_32F,
|
||||
CUSPARSE_SPMM_ALG_DEFAULT, dBuffer));
|
||||
|
||||
// destroy matrix/vector descriptors
|
||||
CHECK_CUSPARSE( cusparseDestroySpMat(descA) );
|
||||
CHECK_CUSPARSE( cusparseDestroyDnMat(descB) );
|
||||
CHECK_CUSPARSE( cusparseDestroyDnMat(descC) );
|
||||
CUDA_CHECK_RETURN( cudaFree(dBuffer) );
|
||||
}
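// A minimal usage sketch (illustrative; d_rowidx, d_colidx, d_vals, d_B, d_C are
// hypothetical device pointers allocated and filled by the caller): multiply a COO
// matrix A (A_rows x A_cols with A_nnz fp16 nonzeros) by a dense row-major fp16
// matrix B of shape A_cols x B_cols into C of shape A_rows x B_cols:
//
//   ContextCusparse ctx;
//   spmm_coo(ctx.m_handle, d_rowidx, d_colidx, d_vals, A_nnz, A_rows, A_cols,
//            B_cols, /*ldb=*/B_cols, d_B, /*ldc=*/B_cols, d_C, /*transposed_B=*/false);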
|
||||
|
||||
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
|
||||
{
|
||||
|
||||
kspmm_coo_very_sparse_naive<T, 8, BITS><<<nnz_rows, 256>>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB);
|
||||
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
//==============================================================
|
||||
// TEMPLATE DEFINITIONS
|
||||
//==============================================================
|
||||
|
||||
template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
|
||||
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
|
||||
|
||||
template int igemmlt<COL_TURING, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
|
||||
template int igemmlt<COL_TURING, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
|
||||
template int igemmlt<COL_TURING, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
|
||||
template int igemmlt<COL_AMPERE, 32, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
|
||||
template int igemmlt<COL_AMPERE, 8, 0>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
|
||||
template int igemmlt<COL_AMPERE, 8, 1>(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
|
||||
|
||||
template void transformRowToFormat<COL32, 0>(char * A, char *out, int rows, int cols);
|
||||
template void transformRowToFormat<COL32, 1>(char * A, char *out, int rows, int cols);
|
||||
template void transformRowToFormat<COL_TURING, 0>(char * A, char *out, int rows, int cols);
|
||||
template void transformRowToFormat<COL_TURING, 1>(char * A, char *out, int rows, int cols);
|
||||
template void transformRowToFormat<COL_AMPERE, 0>(char * A, char *out, int rows, int cols);
|
||||
template void transformRowToFormat<COL_AMPERE, 1>(char * A, char *out, int rows, int cols);
|
||||
|
||||
template void estimateQuantiles(half *A, float *code, float offset, int n);
|
||||
template void estimateQuantiles(float *A, float *code, float offset, int n);
|
||||
|
||||
|
|
csrc/ops.cuh
|
@ -14,6 +14,11 @@
|
|||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cublasLt.h>
|
||||
#include <cusparse.h>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
|
||||
#define CUDA_CHECK_RETURN(value) { \
|
||||
cudaError_t _m_cudaStat = value; \
|
||||
|
@ -25,6 +30,34 @@
|
|||
|
||||
#define THREADS_PER_BLOCKS (512)
|
||||
|
||||
#define CHECK_CUSPARSE(value) { \
|
||||
cusparseStatus_t _m_cudaStat = value; \
|
||||
if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \
|
||||
fprintf(stderr, "Error %s at line %d in file %s\n", \
|
||||
cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
|
||||
exit(1); \
|
||||
} }
|
||||
|
||||
|
||||
#define THREADS_PER_BLOCKS (512)
|
||||
|
||||
|
||||
inline void checkCudaStatus(cudaError_t status) {
|
||||
if (status != cudaSuccess) {
|
||||
printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status));
|
||||
throw std::logic_error("cuda API failed");
|
||||
}
|
||||
}
|
||||
|
||||
inline int checkCublasStatus(cublasStatus_t status) {
|
||||
if (status != CUBLAS_STATUS_SUCCESS) {
|
||||
printf("cuBLAS API failed with status %d\n", status);
|
||||
//throw std::logic_error("cuBLAS API failed");
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
typedef enum Operations_t
|
||||
{
|
||||
ksmul = 0,
|
||||
|
@ -39,6 +72,57 @@ typedef enum Optimizer_t
|
|||
ADAGRAD = 4,
|
||||
} Optimizer_t;
|
||||
|
||||
typedef enum Transform_t
|
||||
{
|
||||
ROW = 0,
|
||||
COL = 1,
|
||||
COL32 = 2,
|
||||
COL_TURING = 3,
|
||||
COL_AMPERE = 4,
|
||||
} Transform_t;
|
||||
|
||||
class Context
|
||||
{
|
||||
public:
|
||||
cublasHandle_t m_handle;
|
||||
|
||||
Context()
|
||||
{
|
||||
cublasHandle_t handle;
|
||||
cublasCreate_v2(&handle);
|
||||
m_handle = handle;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class ContextLt
|
||||
{
|
||||
public:
|
||||
cublasLtHandle_t m_handle;
|
||||
|
||||
ContextLt()
|
||||
{
|
||||
cublasLtHandle_t handle;
|
||||
cublasLtCreate(&handle);
|
||||
m_handle = handle;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
class ContextCusparse
|
||||
{
|
||||
public:
|
||||
cusparseHandle_t m_handle;
|
||||
|
||||
ContextCusparse()
|
||||
{
|
||||
cusparseHandle_t handle;
|
||||
cusparseCreate(&handle);
|
||||
m_handle = handle;
|
||||
}
|
||||
|
||||
};
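// These thin wrappers own a library handle for the lifetime of the object; the
// extern "C" layer in this commit exposes them through factory functions
// (get_context(), get_cusparse()) that hand an opaque pointer back to Python.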
|
||||
|
||||
|
||||
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n);
|
||||
|
||||
|
@ -70,4 +154,24 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
|
|||
|
||||
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);
|
||||
|
||||
void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
|
||||
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
|
||||
long long int strideA, long long int strideB, long long int strideC, int batchCount);
|
||||
|
||||
|
||||
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc);
|
||||
|
||||
template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2);
|
||||
void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
|
||||
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols);
|
||||
void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols);
|
||||
void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed,
|
||||
int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols);
|
||||
|
||||
template <int FORMAT, int TRANSPOSE> void transformRowToFormat(char * A, char *out, int rows, int cols);
|
||||
|
||||
void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B);
|
||||
|
||||
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
|
||||
|
||||
#endif
|
||||
|
|
|
@ -84,6 +84,52 @@ void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half
|
|||
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
|
||||
#endif
|
||||
|
||||
#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
|
||||
void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \
|
||||
{ \
|
||||
transform<dtype, src, target, transpose, bits>(ltHandle, A, out, dim1, dim2); \
|
||||
} \
|
||||
|
||||
MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8);
|
||||
MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8);
|
||||
MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8);
|
||||
MAKE_FUNC_TRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32);
|
||||
MAKE_FUNC_TRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8);
|
||||
MAKE_FUNC_TRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8);
|
||||
MAKE_FUNC_TRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8);
|
||||
MAKE_FUNC_TRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32);
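// For illustration, MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8) expands to:
//   void transform_8_row_to_col32_n(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2)
//   { transform<int8_t, ROW, COL32, false, 8>(ltHandle, A, out, dim1, dim2); }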
|
||||
|
||||
void transform_row2col32(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 0>(A, out, rows, cols); }
|
||||
void transform_row2col32T(char * A, char *out, int rows, int cols){ transformRowToFormat<COL32, 1>(A, out, rows, cols); }
|
||||
void transform_row2turing(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 0>(A, out, rows, cols); }
|
||||
void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_TURING, 1>(A, out, rows, cols); }
|
||||
void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 0>(A, out, rows, cols); }
|
||||
void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat<COL_AMPERE, 1>(A, out, rows, cols); }
|
||||
|
||||
int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
{ return igemmlt<COL_TURING, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
|
||||
|
||||
int igemmlt_turing_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
{ return igemmlt<COL_TURING, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
|
||||
|
||||
int igemmlt_turing_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
{ return igemmlt<COL_TURING, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
|
||||
|
||||
int igemmlt_ampere_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
{ return igemmlt<COL_AMPERE, 32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
|
||||
|
||||
int igemmlt_ampere_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
{ return igemmlt<COL_AMPERE, 8, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
|
||||
|
||||
int igemmlt_ampere_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
{ return igemmlt<COL_AMPERE, 8, 1>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
|
||||
|
||||
void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
|
||||
{ spmm_coo_very_sparse_naive<half, 16>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
|
||||
|
||||
void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
|
||||
{ spmm_coo_very_sparse_naive<signed char, 8>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
|
||||
|
||||
extern "C"
|
||||
{
|
||||
#if BUILD_CUDA
|
||||
|
@ -155,7 +201,86 @@ extern "C"
|
|||
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
|
||||
void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
|
||||
|
||||
#endif
|
||||
void cigemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
|
||||
{ gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc); }
|
||||
void cbatched_igemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
|
||||
long strideA, long strideB, long strideC, int batchCount)
|
||||
{ strided_gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount); }
|
||||
|
||||
Context *get_context(){ return new Context(); }
|
||||
ContextCusparse *get_cusparse(){ return new ContextCusparse(); }
|
||||
|
||||
int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
{ return igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
|
||||
//{ (cublasLtHandle_t)context->m_handle; return 0; }
|
||||
//{ return 0; }//igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
|
||||
|
||||
int cigemmlt_turing_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
{ return igemmlt_turing_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
|
||||
|
||||
int cigemmlt_turing_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
{ return igemmlt_turing_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
|
||||
|
||||
int cigemmlt_ampere_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
{ return igemmlt_ampere_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
|
||||
|
||||
int cigemmlt_ampere_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
{ return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
|
||||
|
||||
int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
|
||||
	{ return igemmlt_ampere_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); }
|
||||
|
||||
#define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \
|
||||
void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \
|
||||
{ \
|
||||
transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); \
|
||||
} \
|
||||
|
||||
MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8)
|
||||
MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8)
|
||||
MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8)
|
||||
MAKE_FUNC_CTRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32)
|
||||
MAKE_FUNC_CTRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8)
|
||||
MAKE_FUNC_CTRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8)
|
||||
MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8)
|
||||
MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32)
|
||||
|
||||
void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, int numRows, int numCols)
|
||||
{ dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, numRows, numCols); }
|
||||
void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols)
|
||||
{ getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); }
|
||||
|
||||
void cdouble_rowcol_quant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_row_ptr, float threshold, int rows, int cols)
|
||||
{ doubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_row_ptr, threshold, rows, cols); }
|
||||
|
||||
void ctransform_row2col32(char * A, char *out, int rows, int cols)
|
||||
{ transform_row2col32(A, out, rows, cols); }
|
||||
|
||||
void ctransform_row2col32T(char * A, char *out, int rows, int cols)
|
||||
{ transform_row2col32T(A, out, rows, cols); }
|
||||
|
||||
void ctransform_row2turing(char * A, char *out, int rows, int cols)
|
||||
{ transform_row2turing(A, out, rows, cols); }
|
||||
|
||||
void ctransform_row2turingT(char * A, char *out, int rows, int cols)
|
||||
{ transform_row2turingT(A, out, rows, cols); }
|
||||
|
||||
void ctransform_row2ampere(char * A, char *out, int rows, int cols)
|
||||
{ transform_row2ampere(A, out, rows, cols); }
|
||||
|
||||
void ctransform_row2ampereT(char * A, char *out, int rows, int cols)
|
||||
{ transform_row2ampereT(A, out, rows, cols); }
|
||||
|
||||
void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
|
||||
{ spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); }
|
||||
|
||||
void cspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
|
||||
{ spmm_coo_very_sparse_naive_fp16(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
|
||||
|
||||
void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
|
||||
{ spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); }
|
||||
|
||||
#endif
|
||||
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
|
||||
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
|
||||
}
|
||||
|
|
tests/test_autograd.py (new file)
|
@ -0,0 +1,270 @@
|
|||
import pytest
|
||||
|
||||
import torch
|
||||
import bitsandbytes as bnb
|
||||
|
||||
from itertools import product
|
||||
|
||||
n = 1
|
||||
k = 25
|
||||
dim1 = torch.randint(16,64, size=(n,)).tolist()
|
||||
dim2 = torch.randint(32,96, size=(n,)).tolist()
|
||||
dim3 = torch.randint(32,96, size=(n,)).tolist()
|
||||
dim4 = torch.randint(32,96, size=(n,)).tolist()
|
||||
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
|
||||
str_funcs = ['bmm', 'matmul']
|
||||
req_grad = [(False, False), (True, False), (True, True), (False, True)]
|
||||
req_grad_str = ['FF', 'TF', 'TT', 'FT']
|
||||
transpose = [(False, False), (False, True), (True, True), (True, False)]
|
||||
str_transpose = ['FF', 'FT', 'TT', 'TF']
|
||||
dtype = [torch.float32, torch.float16]
|
||||
values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose))
|
||||
str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose))
|
||||
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values]
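# e.g. one generated id reads
# 'dim1_32_dim2_48_dim3_64_dim4_48_func_bmm_dtype_torch.float16_requires_grad_FF_transpose_FF'
# (the dims are drawn randomly above, so the concrete numbers vary from run to run)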
|
||||
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
|
||||
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
|
||||
dim2 = dim2 - (dim2 % 16)
|
||||
dim3 = dim3 - (dim3 % 16)
|
||||
dim4 = dim4 - (dim4 % 16)
|
||||
for i in range(k):
|
||||
|
||||
# normal multiply
|
||||
if funcs[0] in [torch.mm, torch.matmul]:
|
||||
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
|
||||
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
|
||||
A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0])
|
||||
B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
|
||||
target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1])
|
||||
torch.nn.init.xavier_uniform_(B)
|
||||
|
||||
if not transpose[0] and not transpose[1]:
|
||||
out_torch = funcs[0](A, B)
|
||||
out_bnb = funcs[1](A, B)
|
||||
elif not transpose[0] and transpose[1]:
|
||||
out_torch = funcs[0](A, B.t())
|
||||
out_bnb = funcs[1](A, B.t())
|
||||
elif transpose[0] and not transpose[1]:
|
||||
out_torch = funcs[0](A.t(), B)
|
||||
out_bnb = funcs[1](A.t(), B)
|
||||
elif transpose[0] and transpose[1]:
|
||||
out_torch = funcs[0](A.t(), B.t())
|
||||
out_bnb = funcs[1](A.t(), B.t())
|
||||
|
||||
n = out_bnb.numel()
|
||||
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
|
||||
assert (idx==0).sum().item() < n*0.0175
|
||||
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
|
||||
assert (idx==0).sum().item() < n*0.001
|
||||
|
||||
if any(req_grad):
|
||||
out_bnb.data.copy_(out_torch)
|
||||
torch.cuda.synchronize()
|
||||
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
|
||||
loss_bnb.backward()
|
||||
gradA1 = A.grad
|
||||
gradB1 = B.grad
|
||||
A.grad = None
|
||||
B.grad = None
|
||||
|
||||
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
|
||||
loss_torch.backward()
|
||||
gradA2 = A.grad
|
||||
gradB2 = B.grad
|
||||
A.grad = None
|
||||
B.grad = None
|
||||
|
||||
if req_grad[0]:
|
||||
torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
|
||||
if req_grad[1]:
|
||||
n = gradB1.numel()
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
|
||||
assert (idx==0).sum().item() < n*0.1
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
|
||||
assert (idx==0).sum().item() < n*0.02
|
||||
torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
|
||||
|
||||
# batched matrix multiply
|
||||
if funcs[0] in [torch.bmm, torch.matmul]:
|
||||
A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
|
||||
B = torch.randn(size=(dim1, dim3, dim4), device='cuda', requires_grad=req_grad[1])
|
||||
target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
|
||||
torch.nn.init.xavier_uniform_(B)
|
||||
|
||||
out_torch = funcs[0](A, B)
|
||||
out_bnb = funcs[1](A, B)
|
||||
|
||||
n = out_bnb.numel()
|
||||
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
|
||||
assert (idx==0).sum().item() < n*0.01
|
||||
torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)
|
||||
|
||||
if any(req_grad):
|
||||
out_bnb.data.copy_(out_torch)
|
||||
torch.cuda.synchronize()
|
||||
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
|
||||
loss_bnb.backward()
|
||||
gradA1 = A.grad
|
||||
gradB1 = B.grad
|
||||
A.grad = None
|
||||
B.grad = None
|
||||
|
||||
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
|
||||
loss_torch.backward()
|
||||
gradA2 = A.grad
|
||||
gradB2 = B.grad
|
||||
A.grad = None
|
||||
B.grad = None
|
||||
|
||||
if req_grad[0]:
|
||||
torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
|
||||
if req_grad[1]:
|
||||
n = gradB1.numel()
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
|
||||
assert (idx==0).sum().item() < n*0.1
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
|
||||
assert (idx==0).sum().item() < n*0.02
|
||||
|
||||
if funcs[0] in [torch.matmul]:
|
||||
dim1 = dim1 - (dim1 % 16)
|
||||
A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
|
||||
dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
|
||||
B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
|
||||
target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
|
||||
torch.nn.init.xavier_uniform_(B)
|
||||
|
||||
if transpose[1]:
|
||||
out_torch = funcs[0](A, B.t())
|
||||
out_bnb = funcs[1](A, B.t())
|
||||
else:
|
||||
out_torch = funcs[0](A, B)
|
||||
out_bnb = funcs[1](A, B)
|
||||
|
||||
n = out_bnb.numel()
|
||||
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
|
||||
assert (idx==0).sum().item() < n*0.0175
|
||||
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
|
||||
assert (idx==0).sum().item() < n*0.001
|
||||
|
||||
if any(req_grad):
|
||||
out_bnb.data.copy_(out_torch)
|
||||
torch.cuda.synchronize()
|
||||
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
|
||||
loss_bnb.backward()
|
||||
gradA1 = A.grad
|
||||
gradB1 = B.grad
|
||||
A.grad = None
|
||||
B.grad = None
|
||||
|
||||
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
|
||||
loss_torch.backward()
|
||||
gradA2 = A.grad
|
||||
gradB2 = B.grad
|
||||
A.grad = None
|
||||
B.grad = None
|
||||
|
||||
if req_grad[0]:
|
||||
torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
|
||||
if req_grad[1]:
|
||||
n = gradB1.numel()
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
|
||||
assert (idx==0).sum().item() < n*0.1
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
|
||||
assert (idx==0).sum().item() < n*0.02
|
||||
|
||||
|
||||
n = 1
|
||||
k = 3
|
||||
dim1 = torch.randint(16,64, size=(n,)).tolist()
|
||||
dim2 = torch.randint(32,96, size=(n,)).tolist()
|
||||
dim3 = torch.randint(32,96, size=(n,)).tolist()
|
||||
dim4 = torch.randint(32,96, size=(n,)).tolist()
|
||||
|
||||
#dim1 = (17,)
|
||||
#dim2 = (7,)
|
||||
#dim3 = (37,)
|
||||
#dim4 = (23,)
|
||||
|
||||
decomp = [0.0, 6.0]
|
||||
funcs = [(torch.matmul, bnb.matmul)]
|
||||
str_funcs = ['matmul']
|
||||
req_grad = [(False, False), (True, False), (True, True), (False, True)]
|
||||
req_grad_str = ['FF', 'TF', 'TT', 'FT']
|
||||
transpose = [(False, True), (False, False)]
|
||||
str_transpose = ['NT', 'NN']
|
||||
dtype = [torch.float16]
|
||||
has_fp16_weights = [True, False]
|
||||
values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose, decomp, has_fp16_weights))
|
||||
str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose, decomp, has_fp16_weights))
|
||||
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}'.format(*vals) for vals in str_values]
|
||||
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights", values, ids=names)
|
||||
def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights):
|
||||
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
|
||||
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
|
||||
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1]//8,), device='cuda')
|
||||
|
||||
for i in range(k):
|
||||
|
||||
# normal multiply
|
||||
if funcs[0] in [torch.mm, torch.matmul]:
|
||||
A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0], dtype=dtype)
|
||||
if decomp == 6.0:
|
||||
with torch.no_grad():
|
||||
A[:, outlier_dim] = 6.0
|
||||
B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1], dtype=dtype)
|
||||
target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1], dtype=dtype)
|
||||
torch.nn.init.xavier_uniform_(B)
|
||||
B2 = B.clone()
|
||||
|
||||
state = bnb.MatmulLtState()
|
||||
state.threshold = decomp
|
||||
state.has_fp16_weights = has_fp16_weights
|
||||
if not has_fp16_weights:
|
||||
if not transpose[0] and not transpose[1]: B2 = B2.t().contiguous()
|
||||
state.CB, CBt, state.SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B2)
|
||||
B2 = state.CB
|
||||
|
||||
if not transpose[0] and transpose[1]:
|
||||
out_torch = funcs[0](A, B.t())
|
||||
out_bnb = funcs[1](A, B2, state=state)
|
||||
elif not transpose[0] and not transpose[1]:
|
||||
out_torch = funcs[0](A, B)
|
||||
out_bnb = funcs[1](A, B2.t(), state=state)
|
||||
|
||||
n = out_bnb.numel()
|
||||
err = torch.abs(out_bnb-out_torch).mean().item()
|
||||
#print(f'abs error {err:.4f}')
|
||||
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
|
||||
assert (idx==0).sum().item() < n*0.0175
|
||||
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
|
||||
assert (idx==0).sum().item() < n*0.001
|
||||
|
||||
if has_fp16_weights:
|
||||
if any(req_grad):
|
||||
out_bnb.data.copy_(out_torch)
|
||||
torch.cuda.synchronize()
|
||||
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
|
||||
loss_bnb.backward()
|
||||
gradA1 = A.grad
|
||||
gradB1 = B.grad
|
||||
A.grad = None
|
||||
B.grad = None
|
||||
|
||||
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
|
||||
loss_torch.backward()
|
||||
gradA2 = A.grad
|
||||
gradB2 = B.grad
|
||||
A.grad = None
|
||||
B.grad = None
|
||||
|
||||
if req_grad[0]:
|
||||
torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
|
||||
if req_grad[1]:
|
||||
n = gradB1.numel()
|
||||
assert torch.abs(gradB1).sum() > 0.0
|
||||
assert torch.abs(gradB2).sum() > 0.0
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
|
||||
assert (idx==0).sum().item() < n*0.1
|
||||
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
|
||||
assert (idx==0).sum().item() < n*0.02
|
||||
torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
|
||||
|
File diff suppressed because it is too large
|
@ -1,42 +1,470 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
import pytest
|
||||
import math
import torch
|
||||
|
||||
from itertools import product
|
||||
from torch import nn
|
||||
|
||||
import bitsandbytes as bnb
|
||||
|
||||
class MockArgs(object):
|
||||
def __init__(self, initial_data):
|
||||
for key in initial_data:
|
||||
setattr(self, key, initial_data[key])
|
||||
|
||||
@pytest.mark.parametrize("embcls", [bnb.nn.Embedding, bnb.nn.StableEmbedding], ids=['Embedding', 'StableEmbedding'])
|
||||
def test_embeddings(embcls):
|
||||
bnb.optim.GlobalOptimManager.get_instance().initialize()
|
||||
emb1 = torch.nn.Embedding(100, 512).cuda()
|
||||
emb2 = embcls(100, 512).cuda()
|
||||
class MLP8bit(torch.nn.Module):
|
||||
def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
|
||||
super(MLP8bit, self).__init__()
|
||||
self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold)
|
||||
self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold)
|
||||
|
||||
adam1 = bnb.optim.Adam8bit(emb1.parameters())
|
||||
adam2 = bnb.optim.Adam8bit(emb2.parameters())
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
batches = torch.randint(1, 100, size=(100, 4, 32)).cuda()
|
||||
|
||||
def get_args():
|
||||
args = MockArgs([])
|
||||
args.quant_type = 'vector'
|
||||
args.use_8bit_training = 'full'
|
||||
args.clip_freq = 9999
|
||||
return args
|
||||
|
||||
def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
|
||||
idx = torch.isclose(a, b, rtol, atol)
|
||||
sumval = (idx==0).sum().item()
|
||||
if sumval > count:
|
||||
print(f'Too many values not close: assert {sumval} < {count}')
|
||||
torch.testing.assert_allclose(a, b, rtol, atol)
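

# Custom autograd function that emulates 8-bit training: activations and weights are
# quantized to int8, multiplied with an integer GEMM (igemm), and dequantized again.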
class LinearFunction(torch.autograd.Function):

    @staticmethod
    def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        norm = math.sqrt(math.pi)/math.sqrt(2.0)
        #std = torch.abs(x).mean()*norm
        std = torch.std(x)
        max1 = std*trim_value
        x = x/max1*127
        x = round_func(x)
        x[x > 127] = 127
        x[x < -127] = -127
        x = x/127*max1

        return x

    def quant(x, quant_type, dim=1):
        if quant_type == 'linear':
            max1 = torch.abs(x).max().float()
            xq = torch.round(x/max1*127).to(torch.int8)
            return xq, max1
        elif quant_type == 'vector':
            max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
            xq = torch.round(x/max1*127).to(torch.int8)
            return xq, max1
        elif quant_type == 'min-max':
            maxA = torch.amax(x, dim=dim, keepdim=True).float()
            minA = torch.amin(x, dim=dim, keepdim=True).float()
            scale = (maxA-minA)/2.0
            xq = torch.round(127*(x-minA-scale)/scale).to(torch.int8)
            return xq, (minA.float(), scale.float())
        else: return None

    def dequant(xq, S1, S2, dtype, quant_type):
        if quant_type == 'linear':
            norm = S1*S2/(127*127)
            # double cast needed to prevent overflows
            return (xq.float()*norm).to(dtype)
        elif quant_type == 'vector':
            x = xq.float()
            if len(xq.shape) == 2 and len(S1.shape) == 3: S1 = S1.squeeze(0)
            if len(xq.shape) == 2 and len(S2.shape) == 3: S2 = S2.squeeze(0)
            #print(x.shape, S1.shape, S2.shape)
            if len(S1.shape) == 2:
                x *= S1.t()/127
            else:
                x *= S1/127
            x *= S2/127
            return x.to(dtype)
        else: return None
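
    # min-max quantization shifts each row by its minimum plus half the range, so
    # dequantizing a matmul result needs an additive offset built from B's column sums.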
    def dequant_min_max(xq, A, B, SA, SB, dtype):
        offset = B.float().t().sum(0)*(SA[0]+SA[1])
        x = xq.float()
        if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0)
        if len(xq.shape) == 2 and len(SA.shape) == 3: SA = SA.squeeze(0)
        if len(SB.shape) == 2:
            x *= SB.t()/127
        else:
            x *= SB/127
        x *= SA[1]/127
        x += offset
        return x.to(dtype)

    def get_8bit_linear(x, stochastic=False):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.abs(x).max()
        x = x/max1*127
        x = round_func(x)/127*max1
        #x = torch.round(x)/128*max1
        return x

    @staticmethod
    def get_8bit_vector_wise(x, dim, stochastic=False):
        round_func = LinearFunction.round_stoachastic if stochastic else torch.round
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        max1[max1==0] = 1.0
        x = (x*127)/max1
        x = round_func(x)/127*max1
        return x
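
    # Stochastic rounding: round up with probability equal to the fractional part,
    # so the rounding error is zero in expectation.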
    @staticmethod
    def round_stoachastic(x):
        sign = torch.sign(x)
        absx = torch.abs(x)
        decimal = absx-torch.floor(absx)
        rdm = torch.rand_like(decimal)
        return sign*(torch.floor(absx)+(rdm < decimal).to(x.dtype))
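
    # The fake_8bit_storage_* helpers simulate 8-bit weight storage: the weight is
    # quantized blockwise, dequantized again, and written back in fp16.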
    @staticmethod
    def fake_8bit_storage(w, exponent_bits):
        code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_quantile(w, args):
        code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
        #C = bnb.functional.quantize_no_absmax(code, w)
        #out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
        #print(out)
        #out = out.half()
        code /= torch.max(torch.abs(code))
        absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
        out = bnb.functional.dequantize_blockwise(absmax, C, code)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_stoachstic(w):
        rand = torch.rand(1024, device=w.device)
        absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand)
        out = bnb.functional.dequantize_blockwise(absmax, C)
        out = out.half()
        w.copy_(out)
        return out

    @staticmethod
    def fake_8bit_storage_with_max(w, topk=8):
        blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256)
        max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
        idx = idx[:, :topk]
        max_val = max_val[:, :topk]

        mask = torch.zeros_like(blocked_w)
        mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val))
        mask = mask.bool()

        # 1. zero out max values
        # 2. quantize + dequantize
        # 3. write back max values
        # 4. copy matrix back to weight

        values = blocked_w[mask]
        blocked_w[mask] = 0

        code = bnb.functional.create_dynamic_map()
        code = code.to(w.device)
        absmax, C = bnb.functional.quantize_blockwise(blocked_w.data)
        bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w)

        blocked_w[mask] = values

        unblocked_w = blocked_w.flatten().view(w.shape)

        w.copy_(unblocked_w)
        return unblocked_w
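
    # forward() runs the 8-bit linear: quantize x and the weight with the selected
    # quant_type, multiply with igemm, and dequantize to the input dtype (or fall
    # back to a floating-point einsum when 8-bit training is switched off).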
    @staticmethod
    def forward(ctx, x, weight, bias=None, args=None):
        if args.use_8bit_training != 'off':
            weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
            outputq = bnb.functional.igemm(x8, weight8.t())
            output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
            #if torch.rand(1) < 0.01:
                #output32 = torch.matmul(x, weight.t())
                #err = torch.abs(output-output32).float()
                #relerr = err/(torch.abs(output32).float()+1e-8)
                #print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
        else:
            #output = torch.matmul(x, weight.t())
            output = torch.einsum('bsi,oi->bso', x, weight)

        ctx.save_for_backward(x, weight, bias)
        ctx.args = args

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output
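
    # backward() mirrors forward(): depending on args.use_8bit_training, grad_weight
    # and grad_input are computed either in int8 via igemm or in floating point.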
    @staticmethod
    def backward(ctx, grad_output):
        x, weight, bias = ctx.saved_tensors
        args = ctx.args
        stochastic = False
        grad_input = grad_weight = grad_bias = None
        if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0)

        # weight and x are already 8bit
        # -> transform grad_output to 8-bit
        if args.use_8bit_training == 'forward+wgrad':
            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = bnb.functional.igemm(grad_output8, x8)
            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)

            #grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)

            grad_input = grad_output.matmul(weight)
        elif args.use_8bit_training == 'full':
            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1])
            x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
            grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
            bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
            grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type)

            grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
            weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
            grad_input8 = bnb.functional.igemm(grad_output8, weight8)
            grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type)

        else:
            grad_input = grad_output.matmul(weight)
            grad_weight = torch.einsum('bsi,bso->oi', x, grad_output)

        return grad_input, grad_weight, grad_bias, None


class Linear8bit(nn.Module):
    def __init__(self, input_features, output_features, bias=True, args=None):
        super(Linear8bit, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.args = args

        self.weight = nn.Parameter(torch.empty(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(output_features))
        else:
            self.register_parameter('bias', None)

        torch.nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            torch.nn.init.zeros_(self.bias)

    def forward(self, x):
        self.args.training = self.training

        return LinearFunction.apply(x, self.weight, self.bias, self.args)
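

# Compare four implementations of the same 32->64 linear layer: the fp16 reference
# (l0), the quant/dequant reference above (l2), and the two bitsandbytes modules
# (l1, l3); outputs and gradients should agree within loose tolerances.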
def test_linear8bit():
    l0 = torch.nn.Linear(32, 64).cuda().half()
    l1 = bnb.nn.Linear8bit(32,64, args=get_args()).cuda().half()
    l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
    l3 = bnb.nn.Linear8bitLt(32,64).cuda().half()

    l0.weight.data = l2.weight.data.clone()
    l0.bias.data = l2.bias.data.clone()

    l1.weight.data = l2.weight.data.clone()
    l1.bias.data = l2.bias.data.clone()

    l3.weight.data = l2.weight.data.clone()
    l3.bias.data = l2.bias.data.clone()

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        t = torch.randn(16, 8, 64, device='cuda').half()
        b2 = b1.clone()
        b3 = b1.clone()
        b0 = b1.clone()

        o0 = l0(b0)
        o1 = l1(b1)
        o2 = l2(b2)
        o3 = l3(b3)

        assert_all_approx_close(o1, o2, atol=0.013, rtol=0.05, count=1)
        assert_all_approx_close(o3, o2, atol=0.013, rtol=0.05, count=1)

        loss0 = torch.nn.functional.mse_loss(o0, t)
        loss1 = torch.nn.functional.mse_loss(o1, t)
        loss2 = torch.nn.functional.mse_loss(o2, t)
        loss3 = torch.nn.functional.mse_loss(o3, t)

        loss0.backward()
        loss1.backward()
        loss2.backward()
        loss3.backward()

        assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
        assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
        assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)
        assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2)

        err1 = torch.abs(l0.weight.grad-l1.weight.grad).mean().item()
        err2 = torch.abs(l0.weight.grad-l2.weight.grad).mean().item()
        err3 = torch.abs(l0.weight.grad-l3.weight.grad).mean().item()

        assert err1*0.8 < err2
        assert err2*0.8 < err3
        assert err3*0.8 < err1

        l0.weight.grad = None
        l1.weight.grad = None
        l2.weight.grad = None
        l3.weight.grad = None
        l0.bias.grad = None
        l1.bias.grad = None
        l2.bias.grad = None
        l3.bias.grad = None


threshold = [0.0, 3.0]
values = threshold
names = ['threshold_{0}'.format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold):
    l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold).cuda().half()
    assert l1.weight.device.type == 'cuda'
    assert l1.weight.dtype == torch.float16

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = l1(b1)
        if i == 1:
            assert l1.state.CxB is not None
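

# With gradient accumulation over acc_steps mini-batches, the int8 layers should track
# a plain fp16 nn.Linear stack closely; the weights are re-synced after each optimizer
# step so that small quantization drifts do not compound over time.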
def test_linear8bitlt_accumulated_gradient():
    l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32,32).cuda().half() for i in range(2)])
    l2 = torch.nn.Sequential(*[torch.nn.Linear(32,32).cuda().half() for i in range(2)])
    l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
    l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
    l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
    l2[1].bias = torch.nn.Parameter(l1[1].bias.clone())
    opt1 = bnb.optim.Adam8bit(l1.parameters(), lr=0.001)
    opt2 = bnb.optim.Adam8bit(l2.parameters(), lr=0.001)

    acc_steps = 10

    for i in range(10):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = l1(b1)
        o2 = l2(b1)
        loss1 = o1.mean()
        loss2 = o2.mean()
        loss1.backward()
        loss2.backward()
        if i == 2:
            assert l1[0].state.CxB is not None
            assert l1[1].state.CxB is not None

        if i > 0 and i % acc_steps == 0:
            opt1.step()
            opt1.zero_grad(True)
            opt2.step()
            opt2.zero_grad(True)
            assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2)
            assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2)
            # we do this copy because otherwise we have small divergences over time that add up
            l1[0].weight.data.copy_(l2[0].weight.data)
            l1[1].weight.data.copy_(l2[1].weight.data)
        else:
            torch.testing.assert_allclose(l1[0].weight.grad, l2[0].weight.grad)
            torch.testing.assert_allclose(l1[1].weight.grad, l2[1].weight.grad)
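

# With has_fp16_weights=False the weight stays int8 no matter how the module is moved
# or cast to CUDA/fp16, and a non-zero threshold should populate the outlier index
# (state.idx) during inference.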
threshold = [0.0, 2.0]
values = threshold
names = ['threshold_{0}'.format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_no_fp16_weights(threshold):
    l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold, has_fp16_weights=False).cuda().half()
    assert l1.weight.dtype == torch.int8

    l1.eval()
    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = l1(b1)
        assert o1.dtype == torch.float16

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda')

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == 'cuda'
    assert mlp.fc2.weight.device.type == 'cuda'

    mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda')

    for i in range(100):
        b1 = torch.randn(16, 8, 32, device='cuda').half()
        o1 = mlp(b1)
        assert o1.dtype == torch.float16
        if threshold > 0: assert mlp.fc1.state.idx is not None
        if threshold > 0: assert mlp.fc2.state.idx is not None
    assert mlp.fc1.weight.dtype == torch.int8
    assert mlp.fc2.weight.dtype == torch.int8
    assert mlp.fc1.weight.device.type == 'cuda'
    assert mlp.fc2.weight.device.type == 'cuda'

@@ -1,12 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import time
import shutil
import uuid
import pytest
import ctypes
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

@@ -14,7 +11,9 @@ import bitsandbytes.functional as F
from os.path import join
from itertools import product

import apex
#import apex

k = 20

def get_temp_dir():
    path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4()))

@@ -26,55 +25,47 @@ def rm_path(path):

str2optimizers = {}
str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam)
str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
#str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
#str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
#str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
#str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)

str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
str2optimizers['adamw'] = (torch.optim.AdamW, bnb.optim.AdamW)
str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
#str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
#str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['adagrad'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad(pxx, 0.01, block_wise=False))
str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False))
str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
#str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9))

str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
str2optimizers['adamw8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.AdamW8bit(pxx, block_wise=True))
str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
str2optimizers['adagrad8bit_blockwise'] = (lambda pxx: torch.optim.Adagrad(pxx, 0.01), lambda pxx: bnb.optim.Adagrad8bit(pxx, 0.01, block_wise=True))
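
# Maps each optimizer name to (torch_state_key, bnb_state_key) pairs, with two extra
# entries (quantization map, max/absmax scale) for the 8-bit variants, so the tests
# can compare the reference optimizer state against the bitsandbytes state buffers.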
str2statenames = {}
str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['adamw'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['momentum'] = [('momentum_buffer', 'state1')]
str2statenames['lars'] = [('momentum_buffer', 'state1')]
str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
str2statenames['rmsprop'] = [('square_avg', 'state1')]
str2statenames['adagrad'] = [('sum', 'state1')]
str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
str2statenames['adamw8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')]
str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')]
str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')]
str2statenames['adagrad8bit_blockwise'] = [('sum', 'state1', 'qmap1', 'absmax1')]

dim1 = [1024]
dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16]
optimizer_names = ['adam', 'adamw', 'momentum', 'rmsprop', 'lars', 'lamb', 'adagrad']
optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)

@@ -89,12 +80,12 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
    bnb_optimizer = str2optimizers[optim_name][1]([p2])

    if gtype == torch.float32:
        atol, rtol = 2e-6, 1e-5
        atol, rtol = 1e-6, 1e-5
    else:
        atol, rtol = 1e-4, 1e-3


    for i in range(50):
    for i in range(k):
        g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
        p1.grad = g.clone().float()
        p2.grad = g.clone()

@@ -107,7 +98,7 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):

        torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)

        if i % 10 == 0 and i > 0:
        if i % (k//5) == 0 and i > 0:
            path = get_temp_dir()
            torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt'))
            del bnb_optimizer

@@ -148,7 +139,6 @@ def test_global_config(dim1, dim2, gtype):
    eps = 1e-8

    bnb.optim.GlobalOptimManager.get_instance().initialize()
    bnb.optim.GlobalOptimManager.get_instance().override_config(p2, 'skip_zeros', True)
    bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)

    bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])

@@ -163,8 +153,6 @@ def test_global_config(dim1, dim2, gtype):
    else:
        atol, rtol = 1e-4, 1e-3

    original_p2 = p2[mask].clone()

    for i in range(50):
        g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
        g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001

@@ -173,38 +161,17 @@ def test_global_config(dim1, dim2, gtype):
        p2.grad = g2
        p3.grad = g3

        if i > 30 and i % 10 == 0:
            g1.data[mask] = 0.0
            g2.data[mask] = 0.0
            p1.grad = g1
            p2.grad = g2
            original_p1 = p1[mask].clone()
            original_p2 = p2[mask].clone()
            og_s1 = adam2.state[p2]['state1'][mask].clone()
            og_s2 = adam2.state[p2]['state2'][mask].clone()
            og_s11 = adam2.state[p1]['state1'][mask].clone()
            og_s21 = adam2.state[p1]['state2'][mask].clone()

        adam2.step()

        assert adam2.state[p3]['state1'].dtype == torch.uint8
        assert adam2.state[p3]['state2'].dtype == torch.uint8

        if i > 30 and i % 10 == 0:
            torch.testing.assert_allclose(original_p2, p2[mask])
            torch.testing.assert_allclose(adam2.state[p2]['state1'][mask], og_s1)
            torch.testing.assert_allclose(adam2.state[p2]['state2'][mask], og_s2)
            assert ((p1[mask]- original_p1)==0.0).sum() < p1.numel()
            assert ((adam2.state[p1]['state1'][mask]- og_s11)==0.0).sum() == 0.0
            assert ((adam2.state[p1]['state2'][mask]- og_s21)==0.0).sum() == 0.0


dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16]
optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'adamw8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise', 'adagrad8bit_blockwise']
optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise']
values = list(product(dim1,dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)

@@ -370,13 +337,12 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
    if dim1 == 1 and dim2 == 1: return
    p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1


    bnb_optimizer = str2optimizers[optim_name][1]([p1])

    g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
    p1.grad = g
    for i in range(5000):
        if i == 500:
    for i in range(k):
        if i == k//5:
            # 100 iterations for burn-in
            torch.cuda.synchronize()
            t0 = time.time()

@@ -386,23 +352,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
    torch.cuda.synchronize()
    s = time.time()-t0
    print('')
    params = 4500*4096*4096
    params = (k-k//5)*dim1*dim2
    print(optim_name, gtype, s/params)
    #assert s < 3.9
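

# bnb optimizers also accept betas passed as a string (e.g. straight from argparse);
# parsing the string should yield the same defaults as the tuple form.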
def test_str_betas():
    betas = (0.80, 0.95)
    strbetas = '(0.80, 0.95)'

    layer = torch.nn.Linear(10, 10)

    base = bnb.optim.Adam(layer.parameters(), betas=betas)
    strbase = bnb.optim.Adam(layer.parameters(), betas=strbetas)
    assert base.defaults['betas'][0] == 0.8
    assert base.defaults['betas'][1] == 0.95
    assert strbase.defaults['betas'][0] == 0.8
    assert strbase.defaults['betas'][1] == 0.95