2022-08-01 10:31:48 +00:00
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
2021-10-06 02:16:20 +00:00
# LICENSE file in the root directory of this source tree.
2022-06-30 15:14:20 +00:00
import ctypes as ct
2022-11-17 14:22:29 +00:00
import itertools
2022-08-08 16:13:22 +00:00
import operator
2021-10-06 02:16:20 +00:00
import random
import torch
2022-11-04 02:49:50 +00:00
import itertools
2022-11-19 15:24:03 +00:00
import math
2022-08-03 18:54:01 +00:00
2022-10-27 11:15:21 +00:00
from functools import reduce # Required in Python 3
2022-08-03 18:54:01 +00:00
from typing import Tuple
2021-10-06 02:16:20 +00:00
from torch import Tensor
2022-08-01 10:31:48 +00:00
from . cextension import COMPILED_WITH_CUDA , lib
2022-10-27 11:15:21 +00:00
2022-08-08 16:13:22 +00:00
# math.prod is only available from Python 3.8 onwards, so define a compatible prod() here.
def prod(iterable):
    """Return the product of all items in ``iterable`` (1 for an empty iterable)."""
    result = 1
    for item in iterable:
        # Plain multiplication (not in-place) so tensor inputs are never mutated.
        result = result * item
    return result
2022-07-01 14:16:10 +00:00
2021-10-06 02:16:20 +00:00
# Cache of named quantization maps (e.g. the "dynamic" map), filled lazily by
# the quantize/dequantize helpers the first time they need a default map.
name2qmap = {}

if COMPILED_WITH_CUDA:
    # C FUNCTIONS FOR OPTIMIZERS
    # Each table maps an optimizer name to a pair of native kernels. Judging by
    # the kernel names, the first entry is the fp32 variant and the second the
    # fp16 variant.
    str2optimizer32bit = {
        "adam": (lib.cadam32bit_g32, lib.cadam32bit_g16),
        "momentum": (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16),
        "rmsprop": (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16),
        "lion": (lib.clion32bit_g32, lib.clion32bit_g16),
        "adagrad": (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16),
        # LARS reuses the momentum kernels and LAMB the Adam kernels.
        "lars": (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16),
        "lamb": (lib.cadam32bit_g32, lib.cadam32bit_g16),
    }

    str2optimizer8bit = {
        "adam": (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16),
        "momentum": (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16),
        "rmsprop": (lib.crmsprop_static_8bit_g32, lib.crmsprop_static_8bit_g16),
        "lion": (lib.clion_static_8bit_g32, lib.clion_static_8bit_g16),
        # LAMB reuses the Adam kernels and LARS the momentum kernels.
        "lamb": (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16),
        "lars": (lib.cmomentum_static_8bit_g32, lib.cmomentum_static_8bit_g16),
    }

    str2optimizer8bit_blockwise = {
        "adam": (lib.cadam_8bit_blockwise_fp32, lib.cadam_8bit_blockwise_fp16),
        "momentum": (lib.cmomentum_8bit_blockwise_fp32, lib.cmomentum_8bit_blockwise_fp16),
        "rmsprop": (lib.crmsprop_8bit_blockwise_fp32, lib.crmsprop_8bit_blockwise_fp16),
        "lion": (lib.clion_8bit_blockwise_fp32, lib.clion_8bit_blockwise_fp16),
        "adagrad": (lib.cadagrad_8bit_blockwise_fp32, lib.cadagrad_8bit_blockwise_fp16),
    }
2021-10-06 02:16:20 +00:00
2022-10-27 11:14:13 +00:00
class CUBLAS_Context:
    """Singleton cache of native context handles, keyed by CUDA device index.

    Handles come from ``lib.get_context()`` (presumably cuBLAS handles, per
    the class name) and are created lazily, one per device. Obtain the
    singleton with :meth:`get_instance`; direct construction raises.
    """

    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        # device index -> ctypes.c_void_p handle, filled on demand by get_context().
        self.context = {}

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it (bypassing __init__) on first use."""
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def get_context(self, device):
        """Return the cached handle for ``device``, creating it on a cache miss.

        Creating a handle temporarily makes ``device`` the current CUDA device;
        the previously current device is restored even if creation fails.
        """
        if device.index not in self.context:
            prev_device = torch.cuda.current_device()
            torch.cuda.set_device(device)
            try:
                self.context[device.index] = ct.c_void_p(lib.get_context())
            finally:
                torch.cuda.set_device(prev_device)
        return self.context[device.index]
2022-08-01 10:31:48 +00:00
2022-10-27 11:14:13 +00:00
class Cusparse_Context:
    """Lazily created singleton wrapping the handle returned by ``lib.get_cusparse()``.

    Use :meth:`get_instance`; calling the constructor directly raises.
    """

    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        raw_handle = lib.get_cusparse()
        self.context = ct.c_void_p(raw_handle)

    @classmethod
    def get_instance(cls):
        instance = cls._instance
        if instance is None:
            # Publish the instance before initializing it, as callers expect.
            instance = cls.__new__(cls)
            cls._instance = instance
            instance.initialize()
        return instance
2021-10-06 02:16:20 +00:00
2022-08-01 10:31:48 +00:00
2022-11-19 15:24:03 +00:00
def create_linear_map(signed=True, total_bits=8, add_zero=True):
    """Create a linearly spaced quantization code padded to 256 entries.

    Parameters
    ----------
    signed : bool
        If True the code spans [-1, 1], otherwise [0, 1].
    total_bits : int
        Number of bits to simulate; fewer than 8 bits is emulated by padding
        the code with zeros.
    add_zero : bool
        For signed codes, use an odd number of points so zero is representable.

    Returns
    -------
    torch.Tensor
        The code (256 entries when total_bits <= 8).
    """
    lower = -1.0 if signed else 0.0
    n_values = 2 ** total_bits
    if add_zero or total_bits < 8:
        # Simulating fewer bits pads the code with zeros, so quantization has
        # to be centred on zero; a signed code gives up one value for that.
        n_values = 2 ** total_bits - 1 if signed else 2 ** total_bits
    values = torch.linspace(lower, 1.0, n_values)

    padding = 256 - values.numel()
    if padding == 0:
        return values

    # Insert the padding zeros between the lower and the upper half of the code.
    half = values.numel() // 2
    return torch.Tensor(values[:half].tolist() + [0] * padding + values[half:].tolist())
2021-10-06 02:16:20 +00:00
2022-08-01 10:31:48 +00:00
2022-11-06 19:59:37 +00:00
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8):
    """Create a floating-point style quantization code normalized to max 1.

    Every exponent-field value (0 .. 2**exponent_bits - 1) is combined with
    every mantissa bit pattern. Exponent 0 yields subnormals (no implicit
    leading one); all other exponents yield normals. The code is sorted and
    divided by its largest value.

    Parameters
    ----------
    signed : bool
        Whether one bit is spent on the sign (each value is mirrored).
    exponent_bits : int
        Number of exponent bits.
    precision_bits : int
        Number of mantissa bits.
    total_bits : int
        Total bit budget; codes with fewer than 8 bits are zero-padded to 256.

    Returns
    -------
    torch.Tensor
        Sorted code with maximum value 1.0.

    Raises
    ------
    AssertionError
        If exponent_bits + precision_bits (+1 if signed) != total_bits.
    """
    has_sign = 1 if signed else 0
    assert exponent_bits + precision_bits == total_bits - has_sign

    # All mantissa bit patterns, most significant bit first.
    bit_patterns = list(itertools.product([0, 1], repeat=precision_bits))
    bias = 2 ** (exponent_bits - 1) - 1

    values = []
    for evalue in range(2 ** exponent_bits):
        for bit_pattern in bit_patterns:
            # Implicit leading one for normals, none for subnormals.
            value = 1 if evalue != 0 else 0
            for i, pval in enumerate(bit_pattern):
                value += pval * (2 ** -(i + 1))
            if evalue == 0:
                # subnormals
                value = value * 2 ** -(bias - 1)
            else:
                # normals
                value = value * 2 ** -(evalue - bias - 2)
            values.append(value)
            if signed:
                values.append(-value)

    assert len(values) == 2 ** total_bits
    if total_bits < 8:
        # Pad with zeros so a reduced-bit code still has 256 entries.
        values.extend([0] * (256 - len(values)))
    values.sort()
    code = torch.Tensor(values)
    code /= code.max()
    return code
2022-11-06 21:05:25 +00:00
def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
    """
    Creates the dynamic quantization map.

    The dynamic data type is made up of a dynamic exponent and a fraction.
    As the exponent increases from 0 to -7 the number of bits available for
    the fraction shrinks.

    This is a generalization of the dynamic type where a certain number of
    the bits can be reserved for the linear quantization region (the
    fraction). ``max_exponent_bits`` determines the maximum number of
    exponent bits.

    For more details see
    (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]

    Parameters
    ----------
    signed : bool
        If True the map is symmetric around zero, otherwise it covers [0, 1].
    max_exponent_bits : int
        Maximum number of exponent bits.
    total_bits : int
        Bit budget; the map has 2**total_bits values, zero-padded to 256.

    Returns
    -------
    torch.Tensor
        Sorted map whose largest value is 1.0.
    """
    data = []
    # One bit is always set aside: it is the sign bit for signed maps, while
    # unsigned maps spend it as an extra fraction bit (the ``+ 1`` on the
    # exponent below and the doubling of additional_items). Subtracting it
    # only for signed maps would count that bit twice for unsigned maps and
    # yield 2 * 2**total_bits values instead of 2**total_bits.
    non_sign_bits = total_bits - 1
    # these are additional items that come from the case
    # where all the exponent bits are zero and no
    # indicator bit is present
    additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1
    if not signed:
        additional_items = 2 * additional_items

    for i in range(max_exponent_bits):
        exponent = i + non_sign_bits - max_exponent_bits
        if not signed:
            exponent += 1
        fraction_items = int(2 ** exponent + 1)
        boundaries = torch.linspace(0.1, 1, fraction_items)
        means = (boundaries[:-1] + boundaries[1:]) / 2.0
        scale = 10 ** (-(max_exponent_bits - 1) + i)
        data += (scale * means).tolist()
        if signed:
            data += (-scale * means).tolist()

    if additional_items > 0:
        # These items use scale 10 ** 0 == 1, the same as the last exponent
        # step above; computed explicitly so max_exponent_bits == 0 also works.
        boundaries = torch.linspace(0.1, 1, additional_items + 1)
        means = (boundaries[:-1] + boundaries[1:]) / 2.0
        data += means.tolist()
        if signed:
            data += (-means).tolist()

    data.append(0)
    data.append(1.0)

    # Pad with zeros when fewer than 8 bits are simulated.
    data.extend([0] * (256 - len(data)))

    data.sort()
    return Tensor(data)
2022-11-19 15:24:03 +00:00
def create_quantile_map(A, total_bits=8):
    """Build a quantization code from estimated quantiles of ``A``.

    Estimates 2**total_bits - 1 quantiles, adds an explicit zero, pads with
    zeros up to 256 entries, sorts, and scales so the largest absolute value
    is 1.
    """
    quantiles = estimate_quantiles(A, num_quantiles=2 ** total_bits - 1).tolist()
    # Reserve one slot for an exact zero.
    quantiles.append(0)
    quantiles.extend([0] * (256 - len(quantiles)))
    quantiles.sort()
    code = Tensor(quantiles)
    return code / code.abs().max()
2022-08-01 10:31:48 +00:00
2022-07-22 21:41:05 +00:00
def get_special_format_str():
    """Return the tiled int8 layout name to use on the current CUDA device.

    "col_ampere" is chosen only when CUDA is available and the device reports
    compute-capability major version 8; every other case yields "col_turing".
    """
    if torch.cuda.is_available():
        major, _minor = torch.cuda.get_device_capability()
        if major == 8:
            return "col_ampere"
    return "col_turing"
2022-08-01 10:31:48 +00:00
2022-07-22 21:41:05 +00:00
2022-08-03 16:05:37 +00:00
def is_on_gpu(tensors):
    """Return True when every non-None entry of ``tensors`` is on a CUDA device.

    ``None`` entries (NULL pointers for the native library) are ignored, so an
    empty or all-None sequence yields True.
    """
    # Materialize the list so every tensor's device is inspected (no short-circuit).
    device_checks = [t.device.type == 'cuda' for t in tensors if t is not None]
    return all(device_checks)
2021-10-06 02:16:20 +00:00
def get_ptr(A: Tensor) -> ct.c_void_p:
    """
    Get the ctypes pointer to a PyTorch tensor's data.

    Parameters
    ----------
    A : torch.tensor
        The PyTorch tensor, or None.

    Returns
    -------
    ctypes.c_void_p
        Pointer to the first element of ``A``; None when ``A`` is None
        (ctypes passes None to C as a NULL pointer).
    """
    return None if A is None else ct.c_void_p(A.data.data_ptr())
2022-08-01 10:31:48 +00:00
2021-10-06 02:16:20 +00:00
2022-07-22 21:41:05 +00:00
def pre_call(device):
    """Make ``device`` the current CUDA device; return the one that was current before."""
    previous = torch.cuda.current_device()
    torch.cuda.set_device(device)
    return previous
2022-08-01 10:31:48 +00:00
2022-07-22 21:41:05 +00:00
def post_call(prev_device):
    """Re-select ``prev_device`` (the value returned by pre_call) as the current CUDA device."""
    torch.cuda.set_device(prev_device)
    return None
2022-08-01 10:31:48 +00:00
2022-07-22 21:41:05 +00:00
def get_transform_func(dtype, orderA, orderOut, transpose=False):
    """Look up the native layout-transform kernel for a dtype/order combination.

    The kernel name encodes the element width (8 for int8, else 32), the
    source and target orders, and whether the transform transposes.

    Raises
    ------
    ValueError
        If the native library exposes no kernel for this combination.
    """
    name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
    if not hasattr(lib, name):
        raise ValueError(
            f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}"
        )
    return getattr(lib, name)
2022-08-01 10:31:48 +00:00
def get_transform_buffer(
    shape, dtype, device, to_order, from_order="row", transpose=False
):
    """Allocate a zero-filled output buffer for a layout transform.

    Parameters
    ----------
    shape : sequence of int
        Logical input shape; all leading dimensions are flattened into rows.
    dtype, device
        Passed through to the allocation.
    to_order : str
        Target layout: "row", "col", "col32", "col_turing" or "col_ampere".
    from_order : str
        Not used here; kept for API compatibility.
    transpose : bool
        Swap rows and columns for the padded layouts and reverse the shape
        recorded in the returned state. "row"/"col" buffers keep ``shape``.

    Returns
    -------
    (torch.Tensor, tuple)
        The buffer and its state ``(shape, to_order)``.

    Raises
    ------
    NotImplementedError
        If ``to_order`` is not a supported layout.
    """
    init_func = torch.zeros
    # Flatten every leading dimension into rows: (d0, d1, d2) -> d0 * d1 rows.
    # Works for any rank >= 1 instead of only 2-D and 3-D shapes.
    rows = reduce(operator.mul, shape[:-1], 1)
    cols = shape[-1]
    state = (shape, to_order)
    if transpose:
        rows, cols = cols, rows
        state = (shape[::-1], to_order)

    if to_order == "row" or to_order == "col":
        return init_func(shape, dtype=dtype, device=device), state
    elif to_order == "col32":
        # blocks of 32 columns (padded)
        cols = 32 * ((cols + 31) // 32)
        return init_func((rows, cols), dtype=dtype, device=device), state
    elif to_order == "col_turing":
        # blocks of 32 columns and 8 rows
        cols = 32 * ((cols + 31) // 32)
        rows = 8 * ((rows + 7) // 8)
        return init_func((rows, cols), dtype=dtype, device=device), state
    elif to_order == "col_ampere":
        # blocks of 32 columns and 32 rows
        cols = 32 * ((cols + 31) // 32)
        rows = 32 * ((rows + 31) // 32)
        return init_func((rows, cols), dtype=dtype, device=device), state
    else:
        raise NotImplementedError(f"To_order not supported: {to_order}")
2022-07-22 21:41:05 +00:00
2022-08-01 10:31:48 +00:00
def nvidia_transform(
    A,
    to_order,
    from_order="row",
    out=None,
    transpose=False,
    state=None,
    ld=None,
):
    """Convert ``A`` between memory layouts with a native transform kernel.

    Parameters
    ----------
    A : torch.Tensor
        Input tensor.
    to_order : str
        Target layout name.
    from_order : str
        Source layout; ignored when ``state`` is given (state[1] is used).
    out : torch.Tensor, optional
        Output buffer; allocated via get_transform_buffer when None.
    transpose : bool
        Select the transposing kernel variant.
    state : tuple, optional
        ``(shape, order)`` describing ``A``; defaults to ``(A.shape, from_order)``.
    ld : sequence of int, optional
        For non-2-D shapes, indices of the dimensions multiplied into the
        first kernel dimension; the rest form the second.

    Returns
    -------
    (torch.Tensor, tuple)
        The output tensor and its state ``(shape, to_order)``.
    """
    if state is None:
        state = (A.shape, from_order)
    else:
        from_order = state[1]
    if out is None:
        # NOTE(review): ``transpose`` is not forwarded here, so the buffer's
        # rows/cols are not swapped for transposing transforms — confirm intent.
        out, new_state = get_transform_buffer(
            state[0], A.dtype, A.device, to_order, state[1]
        )
    else:
        # States are (shape, order) tuples, as built from A.shape above;
        # the shape belongs in the first slot, not the source order.
        new_state = (state[0], to_order)
    func = get_transform_func(A.dtype, from_order, to_order, transpose)

    shape = state[0]
    if len(shape) == 2:
        dim1 = ct.c_int32(shape[0])
        dim2 = ct.c_int32(shape[1])
    elif ld is not None:
        n = prod(shape)
        dim1 = prod([shape[i] for i in ld])
        dim2 = ct.c_int32(n // dim1)
        dim1 = ct.c_int32(dim1)
    else:
        dim1 = ct.c_int32(shape[0] * shape[1])
        dim2 = ct.c_int32(shape[2])

    ptr = CUBLAS_Context.get_instance().get_context(A.device)
    func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)
    return out, new_state
2022-08-01 10:31:48 +00:00
2022-11-06 21:05:25 +00:00
def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor:
    '''
    Estimates up to 256 equidistant quantiles on the input tensor eCDF.

    Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
    via the eCDF of the input tensor `A`. This is a fast but approximate algorithm
    and the extreme quantiles close to 0 and 1 have high variance / large estimation
    errors. These large errors can be avoided by using the offset variable which trims
    the distribution. The default offset value of 1/512 ensures minimum entropy encoding -- it
    trims 1/512 = 0.2% from each side of the distribution. An offset value of 0.01 to 0.02
    usually has a much lower error but is not a minimum entropy encoding. Given an offset
    of 0.02 equidistance points in the range [0.02, 0.98] are used for the quantiles.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor (float32 or float16). Any shape, at least 256 values.
    out : torch.Tensor
        Buffer for the 256 estimated quantiles; allocated when None.
    offset : float
        The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
        (left at 1/512 it is replaced by 1/(2*num_quantiles) when num_quantiles < 256).
    num_quantiles : int
        The number of equally spaced quantiles (at most 256). Fewer than 256 are
        obtained by subsampling the 256 estimates.

    Returns
    -------
    torch.Tensor:
        The ``num_quantiles`` quantiles in float32 datatype.

    Raises
    ------
    NotImplementedError
        If A has fewer than 256 values, num_quantiles > 256, or A's dtype is unsupported.
    '''
    if A.numel() < 256:
        raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.')
    if num_quantiles > 256:
        raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}")
    if num_quantiles < 256 and offset == 1 / (512):
        # override default arguments
        offset = 1 / (2 * num_quantiles)

    if out is None:
        out = torch.zeros((256,), dtype=torch.float32, device=A.device)
    # NOTE(review): the result is ignored, so this check currently has no effect.
    is_on_gpu([A, out])
    device = pre_call(A.device)
    try:
        if A.dtype == torch.float32:
            lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
        elif A.dtype == torch.float16:
            lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
        else:
            raise NotImplementedError(f"Not supported data type {A.dtype}")
    finally:
        # Restore the previously current device even if the dtype check raises.
        post_call(device)

    if num_quantiles < 256:
        # Pick num_quantiles evenly spaced entries out of the 256 estimates.
        idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
        out = out[idx]
    return out
2022-08-01 10:31:48 +00:00
2022-09-11 18:55:09 +00:00
def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096) -> Tensor:
    """
    Quantize tensor A in blocks of ``blocksize`` values.

    Quantizes tensor A by dividing it into blocks of ``blocksize`` values.
    Then the absolute maximum value within these blocks is calculated
    for the non-linear quantization.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor. On GPU it must be float32 or float16.
    code : torch.Tensor
        The quantization map. Defaults to the cached dynamic map.
    absmax : torch.Tensor
        Buffer for the per-block absmax values; allocated with one entry per
        block when None.
    rand : torch.Tensor
        The tensor for stochastic rounding (GPU only; requires
        ``blocksize == 4096`` and at least 1024 elements).
    out : torch.Tensor
        The output tensor (8-bit).
    blocksize : int
        Values per block. On GPU one of 4096, 2048, 1024, 512, 256, 128, 64.

    Returns
    -------
    torch.Tensor:
        The 8-bit tensor.
    tuple(torch.Tensor, torch.Tensor):
        The quantization state ``(absmax, code)`` to undo the quantization.

    Raises
    ------
    AssertionError
        On an unsupported GPU blocksize, an invalid stochastic-rounding setup,
        or ``rand`` given for a CPU tensor.
    ValueError
        If ``A`` is not float32/float16 on GPU.
    """
    if code is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]

    if absmax is None:
        n = A.numel()
        blocks = n // blocksize
        blocks += 1 if n % blocksize > 0 else 0
        absmax = torch.zeros((blocks,), device=A.device)

    if out is None:
        out = torch.zeros_like(A, dtype=torch.uint8)

    if A.device.type != 'cpu':
        assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
        cblocksize = ct.c_int32(blocksize)
        prev_device = pre_call(A.device)
        try:
            code = code.to(A.device)
            if rand is not None:
                is_on_gpu([code, A, out, absmax, rand])
                assert blocksize == 4096
                assert rand.numel() >= 1024
                rand_offset = random.randint(0, 1023)
                if A.dtype == torch.float32:
                    lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
                elif A.dtype == torch.float16:
                    lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
                else:
                    raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
            else:
                is_on_gpu([code, A, out, absmax])
                if A.dtype == torch.float32:
                    lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
                elif A.dtype == torch.float16:
                    lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
                else:
                    raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
        finally:
            # Restore the device that was current before pre_call (passing
            # A.device here would simply leave A's device selected).
            post_call(prev_device)
    else:
        # cpu
        code = code.cpu()
        assert rand is None
        lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))

    return out, (absmax, code)
2022-08-01 10:31:48 +00:00
def dequantize_blockwise(
    A: Tensor,
    quant_state: Tuple[Tensor, Tensor] = None,
    absmax: Tensor = None,
    code: Tensor = None,
    out: Tensor = None,
    blocksize: int = 4096,
) -> Tensor:
    """
    Dequantizes blockwise quantized values.

    Dequantizes the tensor A with maximum absolute values absmax in
    blocks of size ``blocksize``.

    Parameters
    ----------
    A : torch.Tensor
        The input 8-bit tensor.
    quant_state : tuple(torch.Tensor, torch.Tensor)
        ``(absmax, code)`` as returned by quantize_blockwise; takes precedence
        over the ``absmax``/``code`` arguments.
    absmax : torch.Tensor
        The absmax values.
    code : torch.Tensor
        The quantization map. Defaults to the cached dynamic map.
    out : torch.Tensor
        Dequantized output tensor (default: float32)
    blocksize : int
        Block size used for quantization. On GPU one of 4096, 2048, 1024,
        512, 256, 128, 64.

    Returns
    -------
    torch.Tensor:
        Dequantized tensor (default: float32)

    Raises
    ------
    ValueError
        On GPU, for an unsupported blocksize or an ``out`` dtype other than
        float32/float16.
    """
    assert quant_state is not None or absmax is not None
    if code is None and quant_state is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]

    if out is None:
        out = torch.zeros_like(A, dtype=torch.float32)
    if quant_state is None:
        quant_state = (absmax, code)
    else:
        absmax, code = quant_state

    if A.device.type != 'cpu':
        device = pre_call(A.device)
        try:
            code = code.to(A.device)
            if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
                raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]")
            is_on_gpu([A, out])
            if out.dtype == torch.float32:
                lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
            elif out.dtype == torch.float16:
                lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
            else:
                # The check is on the output dtype, so report that one.
                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
        finally:
            # Restore the device returned by pre_call, also on errors.
            post_call(device)
    else:
        # Use the CPU copy of the map; quant_state[1] may still live on a GPU.
        code = code.cpu()
        lib.cdequantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))

    return out
2022-08-01 10:31:48 +00:00
def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
    """Quantize ``A`` to 8 bits using one tensor-wide absmax scale.

    Parameters
    ----------
    A : torch.Tensor
        Input tensor.
    code : torch.Tensor, optional
        Quantization map; defaults to the cached dynamic map.
    out : torch.Tensor, optional
        8-bit output buffer.

    Returns
    -------
    torch.Tensor
        The 8-bit tensor.
    tuple(torch.Tensor, torch.Tensor)
        ``(absmax, code)``, the state needed to dequantize.
    """
    if code is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]
    # Move the map next to A whether it came from the caller or the cache;
    # quantize_no_absmax hands both pointers to the same native call.
    code = code.to(A.device)

    absmax = torch.abs(A).max()
    inp = A / absmax
    out = quantize_no_absmax(inp, code, out)
    return out, (absmax, code)
2022-08-01 10:31:48 +00:00
def dequantize(
    A: Tensor,
    quant_state: Tuple[Tensor, Tensor] = None,
    absmax: Tensor = None,
    code: Tensor = None,
    out: Tensor = None,
) -> Tensor:
    """Invert :func:`quantize`: map 8-bit values through the code and rescale by absmax.

    Either ``quant_state`` (``(absmax, code)`` as returned by quantize) or
    ``absmax`` must be provided; ``quant_state`` takes precedence over the
    ``absmax``/``code`` arguments. Without either a code or a quant_state the
    cached dynamic map is used.
    """
    assert quant_state is not None or absmax is not None
    if quant_state is None:
        if code is None:
            if "dynamic" not in name2qmap:
                name2qmap["dynamic"] = create_dynamic_map().to(A.device)
            code = name2qmap["dynamic"]
        quant_state = (absmax, code)
    absmax, code = quant_state
    # The map must sit on A's device for the native lookup, regardless of
    # whether it came from quant_state, the caller, or the cache.
    code = code.to(A.device)

    out = dequantize_no_absmax(A, code, out)
    return out * absmax
2021-10-06 02:16:20 +00:00
2022-08-01 10:31:48 +00:00
def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
    '''
    Quantize ``A`` to 8-bit values via the quantization map ``code``.

    No scaling is applied here; callers such as quantize() divide by the
    absmax beforehand.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor.
    code : torch.Tensor
        The quantization map.
    out : torch.Tensor, optional
        Output buffer; must be of type byte. Allocated when None.

    Returns
    -------
    torch.Tensor:
        Quantized 8-bit tensor.
    '''
    result = out if out is not None else torch.zeros_like(A, dtype=torch.uint8)
    is_on_gpu([A, result])
    element_count = ct.c_int(A.numel())
    lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(result), element_count)
    return result
2022-08-01 10:31:48 +00:00
def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
    '''
    Map the 8-bit tensor ``A`` back to 32-bit values via the map ``code``.

    No absmax rescaling is applied; dequantize() multiplies afterwards.

    Parameters
    ----------
    A : torch.Tensor
        The 8-bit input tensor.
    code : torch.Tensor
        The quantization map.
    out : torch.Tensor, optional
        The 32-bit output tensor; allocated as float32 when None.

    Returns
    -------
    torch.Tensor:
        32-bit output tensor.
    '''
    result = out if out is not None else torch.zeros_like(A, dtype=torch.float32)
    is_on_gpu([code, A, result])
    element_count = ct.c_int(A.numel())
    lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(result), element_count)
    return result
2022-08-01 10:31:48 +00:00
def optimizer_update_32bit(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    beta1: float,
    eps: float,
    step: int,
    lr: float,
    state2: Tensor = None,
    beta2: float = 0.0,
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    unorm_vec: Tensor = None,
    max_unorm: float = 0.0,
    skip_zeros=False,
) -> None:
    """
    Performs an inplace optimizer update with one or two optimizer states.

    Universal optimizer update for 32-bit state and 32/16-bit gradients/weights.

    Parameters
    ----------
    optimizer_name : str
        The name of the optimizer: {adam, momentum, rmsprop, lion, adagrad, lars, lamb}.
    g : torch.Tensor
        Gradient tensor (float32 or float16).
    p : torch.Tensor
        Parameter tensor.
    state1 : torch.Tensor
        Optimizer state 1 (must be float32).
    beta1 : float
        Optimizer beta1.
    eps : float
        Optimizer epsilon.
    weight_decay : float
        Weight decay.
    step : int
        Current optimizer step.
    lr : float
        The learning rate.
    state2 : torch.Tensor
        Optimizer state 2.
    beta2 : float
        Optimizer beta2.
    gnorm_scale : float
        The factor to rescale the gradient to the max clip value.
    unorm_vec : torch.Tensor
        The tensor for the update norm.
    max_unorm : float
        The maximum update norm relative to the weight norm.
    skip_zeros : bool
        Whether to skip zero-valued gradients or not (default: False).

    Raises
    ------
    NotImplementedError
        If ``optimizer_name`` has no 32-bit kernel.
    ValueError
        If the gradient/state dtype combination is unsupported.
    """
    param_norm = 0.0
    if max_unorm > 0.0:
        param_norm = torch.norm(p.data.float())
    if optimizer_name not in str2optimizer32bit:
        raise NotImplementedError(
            f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
        )

    # Kernel pairs are (fp32-gradient kernel, fp16-gradient kernel); the
    # optimizer state itself must be fp32 in both cases.
    if g.dtype == torch.float32 and state1.dtype == torch.float32:
        kernel = str2optimizer32bit[optimizer_name][0]
    elif g.dtype == torch.float16 and state1.dtype == torch.float32:
        kernel = str2optimizer32bit[optimizer_name][1]
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )

    kernel(
        get_ptr(g),
        get_ptr(p),
        get_ptr(state1),
        get_ptr(state2),
        get_ptr(unorm_vec),
        ct.c_float(max_unorm),
        ct.c_float(param_norm),
        ct.c_float(beta1),
        ct.c_float(beta2),
        ct.c_float(eps),
        ct.c_float(weight_decay),
        ct.c_int32(step),
        ct.c_float(lr),
        ct.c_float(gnorm_scale),
        ct.c_bool(skip_zeros),
        ct.c_int32(g.numel()),
    )
def optimizer_update_8bit(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    state2: Tensor,
    beta1: float,
    beta2: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: Tensor,
    qmap2: Tensor,
    max1: Tensor,
    max2: Tensor,
    new_max1: Tensor,
    new_max2: Tensor,
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    unorm_vec: Tensor = None,
    max_unorm: float = 0.0,
) -> None:
    """
    Performs an inplace optimizer update with 8-bit (uint8) optimizer states.

    Universal update for 8-bit state and 32/16-bit gradients/weights.
    Uses AdamW formulation if weight decay > 0.0.

    Parameters
    ----------
    optimizer_name : str
        The name of the optimizer; must be a key of ``str2optimizer8bit``.
    g : torch.Tensor
        Gradient tensor (float32 or float16).
    p : torch.Tensor
        Parameter tensor.
    state1 : torch.Tensor
        First optimizer state (uint8).
    state2 : torch.Tensor
        Second optimizer state.
    beta1 : float
        Optimizer beta1.
    beta2 : float
        Optimizer beta2.
    eps : float
        Optimizer epsilon.
    step : int
        Current optimizer step.
    lr : float
        The learning rate.
    qmap1 : torch.Tensor
        Quantization map for the first state.
    qmap2 : torch.Tensor
        Quantization map for the second state.
    max1 : torch.Tensor
        Max value for the first state update.
    max2 : torch.Tensor
        Max value for the second state update.
    new_max1 : torch.Tensor
        Max value for the next update of the first state.
    new_max2 : torch.Tensor
        Max value for the next update of the second state.
    weight_decay : float
        Weight decay.
    gnorm_scale : float
        The factor to rescale the gradient to the max clip value.
    unorm_vec : torch.Tensor
        The tensor for the update norm.
    max_unorm : float
        The maximum update norm relative to the weight norm.

    Raises
    ------
    NotImplementedError
        If ``optimizer_name`` has no 8-bit kernel.
    ValueError
        If the gradient/state dtype combination is not supported.
    """
    # Fail with the same error type as the 32-bit path instead of a bare KeyError.
    if optimizer_name not in str2optimizer8bit:
        raise NotImplementedError(
            f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer8bit.keys())}'
        )

    param_norm = 0.0
    if max_unorm > 0.0:
        param_norm = torch.norm(p.data.float())

    # Index 0 is the float32-gradient kernel, index 1 the float16-gradient kernel.
    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
        kernel = str2optimizer8bit[optimizer_name][0]
    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
        kernel = str2optimizer8bit[optimizer_name][1]
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )

    # NOTE(review): p is passed before g here (the 32-bit path passes g first);
    # presumably this matches the 8-bit kernel signature — confirm against the C API.
    kernel(
        get_ptr(p),
        get_ptr(g),
        get_ptr(state1),
        get_ptr(state2),
        get_ptr(unorm_vec),
        ct.c_float(max_unorm),
        ct.c_float(param_norm),
        ct.c_float(beta1),
        ct.c_float(beta2),
        ct.c_float(eps),
        ct.c_int32(step),
        ct.c_float(lr),
        get_ptr(qmap1),
        get_ptr(qmap2),
        get_ptr(max1),
        get_ptr(max2),
        get_ptr(new_max1),
        get_ptr(new_max2),
        ct.c_float(weight_decay),
        ct.c_float(gnorm_scale),
        ct.c_int32(g.numel()),
    )
def optimizer_update_8bit_blockwise(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    state2: Tensor,
    beta1: float,
    beta2: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: Tensor,
    qmap2: Tensor,
    absmax1: Tensor,
    absmax2: Tensor,
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    skip_zeros: bool = False,
) -> None:
    """
    Performs an inplace optimizer update with blockwise-quantized 8-bit states.

    Parameters
    ----------
    optimizer_name : str
        The name of the optimizer; must be a key of ``str2optimizer8bit_blockwise``.
    g : torch.Tensor
        Gradient tensor (float32 or float16).
    p : torch.Tensor
        Parameter tensor.
    state1 : torch.Tensor
        First optimizer state (uint8).
    state2 : torch.Tensor
        Second optimizer state.
    beta1, beta2, eps : float
        Optimizer hyperparameters.
    step : int
        Current optimizer step.
    lr : float
        The learning rate.
    qmap1, qmap2 : torch.Tensor
        Quantization maps for the first and second state.
    absmax1, absmax2 : torch.Tensor
        Absmax buffers for the first and second state, passed to the kernel.
    weight_decay : float
        Weight decay.
    gnorm_scale : float
        The factor to rescale the gradient to the max clip value.
    skip_zeros : bool
        Whether to skip zero-valued gradients or not (default: False).

    Raises
    ------
    NotImplementedError
        If ``optimizer_name`` has no blockwise 8-bit kernel.
    ValueError
        If the gradient/state dtype combination is not supported.
    """
    # Fail with the same error type as the 32-bit path instead of a bare KeyError.
    if optimizer_name not in str2optimizer8bit_blockwise:
        raise NotImplementedError(
            f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer8bit_blockwise.keys())}'
        )

    # Index 0 is the float32-gradient kernel, index 1 the float16-gradient kernel.
    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
        kernel = str2optimizer8bit_blockwise[optimizer_name][0]
    elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
        kernel = str2optimizer8bit_blockwise[optimizer_name][1]
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )

    kernel(
        get_ptr(p),
        get_ptr(g),
        get_ptr(state1),
        get_ptr(state2),
        ct.c_float(beta1),
        ct.c_float(beta2),
        ct.c_float(eps),
        ct.c_int32(step),
        ct.c_float(lr),
        get_ptr(qmap1),
        get_ptr(qmap2),
        get_ptr(absmax1),
        get_ptr(absmax2),
        ct.c_float(weight_decay),
        ct.c_float(gnorm_scale),
        ct.c_bool(skip_zeros),
        ct.c_int32(g.numel()),
    )
2021-10-06 02:16:20 +00:00
2022-08-01 10:31:48 +00:00
def percentile_clipping(
    grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5
):
    """Applies percentile clipping

    grad: torch.Tensor
        The gradient tensor (float32 or float16).
    gnorm_vec: torch.Tensor
        Vector of gradient norms. 100 elements expected. The kernel updates it;
        its entries are square-rooted below, so they are treated as squared norms.
    step: int
        The current optimization step (number of past gradient norms).
    percentile: int
        Index into the sorted ``gnorm_vec`` whose value is used as the clip threshold.

    Returns
    -------
    (current_gnorm, clip_value, gnorm_scale)
        The current gradient norm, the clip threshold, and the factor to scale
        the gradient by (1.0 when no clipping is needed).
    """
    is_on_gpu([grad, gnorm_vec])

    if grad.dtype == torch.float32:
        clip_kernel = lib.cpercentile_clipping_g32
    elif grad.dtype == torch.float16:
        clip_kernel = lib.cpercentile_clipping_g16
    else:
        raise ValueError(f"Gradient type {grad.dtype} not supported!")
    clip_kernel(
        get_ptr(grad),
        get_ptr(gnorm_vec),
        ct.c_int32(step),
        ct.c_int32(grad.numel()),
    )

    current_gnorm = torch.sqrt(gnorm_vec[step % 100])
    sorted_norms, _ = torch.sort(gnorm_vec)
    clip_value = torch.sqrt(sorted_norms[percentile])
    gnorm_scale = clip_value / current_gnorm if current_gnorm > clip_value else 1.0
    return current_gnorm, clip_value, gnorm_scale
2022-08-01 10:31:48 +00:00
def histogram_scatter_add_2d(
    histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor
):
    """Scatter-add ``source`` into the 2D float32 ``histogram`` using a CUDA kernel.

    All tensors must live on a CUDA device. ``index1``/``index2`` are int32 and
    ``source`` is float32. Presumably ``source[i]`` is accumulated at
    ``histogram[index1[i], index2[i]]``; the exact semantics live in
    ``lib.chistogram_scatter_add_2d`` (NOTE(review): confirm).
    """
    # dtype / rank checks
    assert len(histogram.shape) == 2
    assert histogram.dtype == torch.float32
    assert source.dtype == torch.float32
    assert index1.dtype == torch.int32
    assert index2.dtype == torch.int32

    # every operand must be on the GPU
    for tensor in (histogram, index1, index2, source):
        assert tensor.device.type == "cuda"

    num_rows = ct.c_int32(histogram.shape[0])
    num_items = ct.c_int32(index1.numel())

    is_on_gpu([histogram, index1, index2, source])
    lib.chistogram_scatter_add_2d(
        get_ptr(histogram),
        get_ptr(index1),
        get_ptr(index2),
        get_ptr(source),
        num_rows,
        num_items,
    )
2022-08-01 10:31:48 +00:00
2022-07-22 21:41:05 +00:00
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
    """Validate the operands of an integer matmul and return the output shape.

    Supported rank combinations are 2D x 2D, 3D x 2D and 3D x 3D, with optional
    logical transposition (swap of the last two dims) of either operand.

    Parameters
    ----------
    A, B : torch.Tensor
        Input operands; both must have dtype ``expected_type``.
    out : torch.Tensor or None
        Optional output buffer. When given, its shape is returned as the output
        shape (after the inner-dimension check).
    transposed_A, transposed_B : bool
        Whether A / B are used transposed.
    expected_type : torch.dtype
        Required dtype of A and B (default ``torch.int8``).

    Returns
    -------
    tuple or torch.Size
        Shape of the matmul result (``out.shape`` when ``out`` is given).

    Raises
    ------
    TypeError
        If A or B does not have dtype ``expected_type``.
    ValueError
        If the inner dimensions do not match, or if ``out`` is None and the rank
        combination is not supported (previously an UnboundLocalError).
    """
    if A.dtype != expected_type or B.dtype != expected_type:
        raise TypeError(
            f"Expected {expected_type} input tensors A and B, but got {A.dtype} and {B.dtype}"
        )

    sA = A.shape
    sB = B.shape
    tA = transposed_A
    tB = transposed_B
    rank_A, rank_B = len(sA), len(sB)

    # Compare the contracted (inner) dim of A against that of B; transposing an
    # operand swaps which of its last two dims is the inner one.
    if rank_A == 2 and rank_B == 2:
        correct = sA[0 if tA else 1] == sB[1 if tB else 0]
    elif rank_A == 3 and rank_B == 2:
        correct = sA[1 if tA else 2] == sB[1 if tB else 0]
    elif rank_A == 3 and rank_B == 3:
        correct = sA[1 if tA else 2] == sB[2 if tB else 1]
    else:
        # No inner-dim rule for other rank pairs; only an explicit `out` is accepted.
        correct = True

    if out is not None:
        sout = out.shape
        # special case common in backprop
        if (
            not correct
            and rank_A == 3
            and rank_B == 3
            and sout[0] == sA[2]
            and sout[1] == sB[2]
            and sA[0] == sB[0]
            and sA[1] == sB[1]
        ):
            correct = True
    elif rank_A == 2 and rank_B == 2:
        sout = (sA[1] if tA else sA[0], sB[0] if tB else sB[1])
    elif rank_A == 3 and rank_B == 2:
        sout = (sA[0], sA[2] if tA else sA[1], sB[0] if tB else sB[1])
    elif rank_A == 3 and rank_B == 3:
        sout = (sA[0], sA[2] if tA else sA[1], sB[1] if tB else sB[2])
    else:
        raise ValueError(
            f"Unsupported tensor ranks for matrix multiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}."
        )

    if not correct:
        raise ValueError(
            f"Tensor dimensions incorrect for matrix multiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}."
        )

    # Initialize CUDA lazily, only once the inputs have been validated.
    if not torch.cuda.is_initialized():
        torch.cuda.init()
    return sout
2022-08-01 10:31:48 +00:00
def igemm(
    A: Tensor,
    B: Tensor,
    out: Tensor = None,
    transposed_A=False,
    transposed_B=False,
):
    """Integer GEMM of int8 tensors ``A`` and ``B`` into an int32 result via ``lib.cigemm``.

    Handles 2D x 2D, 3D x 2D and the 3D x 3D contraction ``bsi,bso->io``.
    3D x 3D inputs whose batch and inner dims line up are delegated to
    ``batched_igemm``.

    Parameters
    ----------
    A, B : Tensor
        Operands; dtype/shape are validated by ``check_matmul``.
    out : Tensor, optional
        int32 output buffer; a zero-filled one is allocated when None.
    transposed_A, transposed_B : bool
        Treat the operand as transposed. The flags actually passed to cuBLAS are
        re-derived from the operands' strides when B is 2D, and fixed when B is 3D.

    Returns
    -------
    Tensor
        ``out``.
    """
    sout = check_matmul(A, B, out, transposed_A, transposed_B)
    if out is None:
        out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
    if len(A.shape) == 3 and len(B.shape) == 3:
        if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]:
            return batched_igemm(A, B, out)

    sA = A.shape
    sB = B.shape
    # Logical shapes after transposition: swap the last two dims.
    if transposed_A and len(sA) == 2:
        sA = (sA[1], sA[0])
    elif transposed_A and len(sA) == 3:
        sA = (sA[0], sA[2], sA[1])  # previously repeated sA[0] instead of sA[1]
    if transposed_B and len(sB) == 2:
        sB = (sB[1], sB[0])
    elif transposed_B and len(sB) == 3:
        sB = (sB[0], sB[2], sB[1])  # previously repeated sB[0] instead of sB[1]

    # this is a mess: cuBLAS expect column major, but PyTorch is row major.
    # So to perform the matrix multiplication, we have to treat A, B, and C matrices
    # (transpose of row major is column major)
    # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these
    # matrices in the input arguments for cuBLAS
    # column major: A @ B = C: [m, k] @ [k, n] = [m, n]
    # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
    # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
    if len(sB) == 2:
        # derive the transpose flags from the actual memory layout
        if B.stride()[0] == B.shape[1]:
            transposed_B = False
        elif B.stride()[1] == B.shape[0]:
            transposed_B = True
        if len(A.shape) == 2:
            if A.stride()[0] == A.shape[1]:
                transposed_A = False
            elif A.stride()[1] == A.shape[0]:
                transposed_A = True
        else:
            if A.stride()[1] == A.shape[2]:
                transposed_A = False
            elif A.stride()[2] == A.shape[1]:
                transposed_A = True

        if len(sA) == 2:
            n = sA[0]
            ldb = A.stride()[1 if transposed_A else 0]
        elif len(sA) == 3 and len(sB) == 2:
            n = sA[0] * sA[1]
            ldb = sA[2]

        m = sB[1]
        k = sB[0]
        lda = B.stride()[(1 if transposed_B else 0)]
        ldc = sB[1]
    elif len(sB) == 3:
        # special case
        assert len(sA) == 3
        if not (sA[0] == sB[0] and sA[1] == sB[1]):
            raise ValueError(
                f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
            )

        transposed_A = True
        transposed_B = False

        m = sB[2]
        n = sA[2]
        k = sB[0] * sB[1]

        lda = m
        ldb = sA[2]
        ldc = m

    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    # B^T @ A^T = C^T
    # [km, nk -> mn]
    is_on_gpu([B, A, out])
    lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
               get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
    return out
2022-08-01 10:31:48 +00:00
def batched_igemm(
    A: Tensor,
    B: Tensor,
    out: Tensor = None,
    transposed_A=False,
    transposed_B=False,
):
    """Batched int8 x int8 -> int32 matmul of two 3D tensors via ``lib.cbatched_igemm``.

    Non-contiguous operands that are not in the recognised transposed layout are
    copied with ``.contiguous()`` first.

    Raises
    ------
    ValueError
        If either operand is not 3-dimensional.
    """
    if len(A.shape) != 3 or len(B.shape) != 3:
        raise ValueError(
            f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}"
        )

    sout = check_matmul(A, B, out, transposed_A, transposed_B)
    if out is None:
        out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)

    # Layout of B (the first operand handed to cuBLAS); its flag is transposed_A.
    if B.is_contiguous():
        transposed_A = False
        lda = B.stride()[1]
    elif B.stride()[0] == B.shape[0] and B.stride()[2] == B.shape[1]:
        transposed_A = True
        lda = B.stride()[2]
    else:
        # Fallback: make a contiguous copy; transposed_A keeps the caller's value.
        B = B.contiguous()
        lda = B.stride()[1]

    # Layout of A (the second operand handed to cuBLAS); its flag is transposed_B.
    if (
        not A.is_contiguous()
        and A.stride()[0] == A.shape[0]
        and A.stride()[2] == A.shape[1]
    ):
        transposed_B = True
        ldb = A.stride()[2]
    else:
        A = A.contiguous()  # no-op when A is already contiguous
        transposed_B = False
        ldb = A.stride()[1]

    # this is a mess: cuBLAS expect column major, but PyTorch is row major.
    # So to perform the matrix multiplication, we have to treat A, B, and C matrices
    # (transpose of row major is column major)
    # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these
    # matrices in the input arguments for cuBLAS
    # column major: A @ B = C: [batch, m, k] @ [batch, k, n] = [batch, m, n]
    # row major: B^T @ A^T = C^T: [batch, m, k] @ [batch, k, n] = [batch, m, n]
    # column major with row major layout: B^T @ A^T = C^T: [batch, k, m] @ [batch, n, k] = [batch, n, m]
    num_batch, n, a_inner = A.shape
    _, k, m = B.shape
    ldc = m

    # per-batch element strides of the (B, A, out) buffers as seen by cuBLAS
    strideA = k * m
    strideB = n * a_inner
    strideC = n * m

    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    is_on_gpu([B, A, out])
    # NOTE(review): the first bool is transposed_B (derived from A's layout) while the
    # first matrix pointer is B; confirm this pairing against cbatched_igemm's signature.
    lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
                       get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
                       ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
    return out
2022-08-01 10:31:48 +00:00
2022-07-26 00:27:57 +00:00
def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
    """Int8 matmul ``A @ B^T`` with cublasLt on tile-layout operands.

    Parameters
    ----------
    A, B : Tensor
        int8 operands on a CUDA device. A is in "col32" layout; B is in
        "col_turing" or "col_ampere" layout.
    SA, SB : tuple
        ``(shape, format)`` state pairs describing A and B. B must be 2D.
    out, Sout : optional
        Output buffer and its ``(shape, format)`` state. Allocated in "col32"
        layout via ``get_transform_buffer`` when ``out`` is None.
    dtype : torch.dtype
        Output dtype; ``torch.int32`` selects the 32-bit kernel, anything else
        the 8-bit kernel.

    Returns
    -------
    (out, Sout)

    The caller's current CUDA device is restored before returning, including
    when the kernel reports an error.
    """
    shapeA = SA[0]
    shapeB = SB[0]
    dimsA = len(shapeA)
    dimsB = len(shapeB)
    assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
    if dimsA == 2:
        m = shapeA[0]
    elif dimsA == 3:
        m = shapeA[0] * shapeA[1]

    rows = n = shapeB[0]
    assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'

    # if the tensor is empty, return a transformed empty tensor with the right dimensions
    # NOTE(review): given the assert above, these branches are only reachable with
    # assertions disabled (python -O), and they return a bare tensor instead of (out, Sout).
    if shapeA[0] == 0 and dimsA == 2:
        return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
    elif shapeA[1] == 0 and dimsA == 3:
        # tuple(...) first so this also works when shapeA is a torch.Size
        return torch.empty(tuple(shapeA[:2]) + (shapeB[0],), device=A.device, dtype=torch.float16)

    if dimsA == 2 and out is None:
        out, Sout = get_transform_buffer(
            (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row"
        )
    elif dimsA == 3 and out is None:
        out, Sout = get_transform_buffer(
            (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row"
        )

    assert dimsB != 3, "len(B.shape)==3 not supported"
    assert A.device.type == "cuda"
    assert B.device.type == "cuda"
    assert A.dtype == torch.int8
    assert B.dtype == torch.int8
    assert out.dtype == dtype
    assert SA[1] == "col32"
    assert SB[1] in ["col_turing", "col_ampere"]
    assert Sout[1] == "col32"
    assert (
        shapeA[-1] == shapeB[-1]
    ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}"

    formatB = SB[1]
    # Save the caller's device (previously this stored A.device, so the
    # "restore" below was a no-op and the caller's device was clobbered).
    prev_device = torch.cuda.current_device()
    torch.cuda.set_device(A.device)
    try:
        ptr = CUBLAS_Context.get_instance().get_context(A.device)
        ptrA = get_ptr(A)
        ptrB = get_ptr(B)
        ptrC = get_ptr(out)

        k = shapeA[-1]
        lda = ct.c_int32(m * 32)
        if formatB == "col_turing":
            # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
            # n = rows
            ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
        else:
            # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
            # n = rows
            ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)

        ldc = ct.c_int32(m * 32)
        m = ct.c_int32(m)
        n = ct.c_int32(n)
        k = ct.c_int32(k)

        has_error = 0
        ptrRowScale = get_ptr(None)
        is_on_gpu([A, B, out])
        if formatB == 'col_turing':
            if dtype == torch.int32:
                has_error = lib.cigemmlt_turing_32(
                    ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
                )
            else:
                has_error = lib.cigemmlt_turing_8(
                    ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
                )
        elif formatB == "col_ampere":
            if dtype == torch.int32:
                has_error = lib.cigemmlt_ampere_32(
                    ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
                )
            else:
                has_error = lib.cigemmlt_ampere_8(
                    ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
                )

        if has_error == 1:
            print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
            raise Exception('cublasLt ran into an error!')
    finally:
        torch.cuda.set_device(prev_device)

    return out, Sout
2022-08-01 10:31:48 +00:00
def mm_dequant(
    A,
    quant_state,
    row_stats,
    col_stats,
    out=None,
    new_row_stats=None,
    new_col_stats=None,
    bias=None
):
    """Dequantize an int32 matmul result to float16 with ``lib.cdequant_mm_int32_fp16``.

    Parameters
    ----------
    A : Tensor
        int32 matmul result.
    quant_state : sequence
        ``quant_state[0]`` is the output shape; a 3D shape is flattened to
        ``(d0 * d1, d2)``.
    row_stats, col_stats : Tensor
        Row / column statistics passed to the kernel.
    out : Tensor, optional
        float16 output buffer; allocated when None.
    new_row_stats, new_col_stats : Tensor, optional
        float32 buffers (one entry per output row / column), allocated when
        None; their length must match ``row_stats`` / ``col_stats``.
    bias : Tensor, optional
        float16 bias passed to the kernel.

    Returns
    -------
    Tensor
        ``out``.
    """
    assert A.dtype == torch.int32
    if bias is not None:
        assert bias.dtype == torch.float16

    out_shape = quant_state[0]
    if len(out_shape) == 3:
        out_shape = (out_shape[0] * out_shape[1], out_shape[2])
    num_out_rows = out_shape[0]
    num_out_cols = out_shape[1]

    if out is None:
        out = torch.empty(out_shape, dtype=torch.float16, device=A.device)
    if new_row_stats is None:
        new_row_stats = torch.empty(num_out_rows, dtype=torch.float32, device=A.device)
    if new_col_stats is None:
        new_col_stats = torch.empty(num_out_cols, dtype=torch.float32, device=A.device)

    assert (
        new_row_stats.shape[0] == row_stats.shape[0]
    ), f"{new_row_stats.shape} vs {row_stats.shape}"
    assert (
        new_col_stats.shape[0] == col_stats.shape[0]
    ), f"{new_col_stats.shape} vs {col_stats.shape}"

    prev_device = pre_call(A.device)
    a_ptr = get_ptr(A)
    out_ptr = get_ptr(out)
    row_stats_ptr = get_ptr(row_stats)
    col_stats_ptr = get_ptr(col_stats)
    new_row_stats_ptr = get_ptr(new_row_stats)
    new_col_stats_ptr = get_ptr(new_col_stats)
    bias_ptr = get_ptr(bias)
    c_rows = ct.c_int32(num_out_rows)
    c_cols = ct.c_int32(num_out_cols)

    is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias])
    lib.cdequant_mm_int32_fp16(
        a_ptr,
        row_stats_ptr,
        col_stats_ptr,
        out_ptr,
        new_row_stats_ptr,
        new_col_stats_ptr,
        bias_ptr,
        c_rows,
        c_cols,
    )
    post_call(prev_device)

    return out
2022-08-01 10:31:48 +00:00
def get_colrow_absmax(
    A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0
):
    """Compute per-row and per-column statistics of a float16 matrix via ``lib.cget_col_row_stats``.

    A 3D ``A`` is treated as a ``(d0 * d1, d2)`` matrix. Missing stat buffers
    are allocated as float32 and pre-filled with -50000.0. Presumably the
    kernel writes absolute maxima, as the name suggests (NOTE(review): confirm).

    When ``threshold > 0`` a ``nnz_block_ptr`` int32 buffer of size
    ``tiled_rows * col_tiles + 1`` is used (allocated if missing) and
    cumulatively summed in place after the kernel runs.

    Returns
    -------
    (row_stats, col_stats, nnz_block_ptr)
        ``nnz_block_ptr`` is None when ``threshold <= 0`` and none was given.
    """
    assert A.dtype == torch.float16
    device = A.device

    n_cols = A.shape[-1]
    n_rows = A.shape[0] * A.shape[1] if len(A.shape) == 3 else A.shape[0]

    # 256-wide column tiles and rows padded up to a multiple of 16
    col_tiles = (n_cols + 255) // 256
    tiled_rows = ((n_rows + 15) // 16) * 16

    if row_stats is None:
        row_stats = torch.empty((n_rows,), dtype=torch.float32, device=device).fill_(-50000.0)
    if col_stats is None:
        col_stats = torch.empty((n_cols,), dtype=torch.float32, device=device).fill_(-50000.0)
    if nnz_block_ptr is None and threshold > 0.0:
        nnz_block_ptr = torch.zeros(
            ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device
        )

    a_ptr = get_ptr(A)
    row_stats_ptr = get_ptr(row_stats)
    col_stats_ptr = get_ptr(col_stats)
    nnz_ptr = get_ptr(nnz_block_ptr)
    c_rows = ct.c_int32(n_rows)
    c_cols = ct.c_int32(n_cols)

    prev_device = pre_call(A.device)
    is_on_gpu([A, row_stats, col_stats, nnz_block_ptr])
    lib.cget_col_row_stats(a_ptr, row_stats_ptr, col_stats_ptr, nnz_ptr, ct.c_float(threshold), c_rows, c_cols)
    post_call(prev_device)

    if threshold > 0.0:
        nnz_block_ptr.cumsum_(0)

    return row_stats, col_stats, nnz_block_ptr
2022-08-01 10:31:48 +00:00
2022-10-27 11:14:13 +00:00
class COOSparseTensor:
    """Sparse float16 matrix in coordinate (COO) format.

    ``rowidx``/``colidx`` are int32 index tensors and ``values`` is float16;
    all three hold exactly ``nnz`` entries.
    """

    def __init__(self, rows, cols, nnz, rowidx, colidx, values):
        for index_tensor in (rowidx, colidx):
            assert index_tensor.dtype == torch.int32
        assert values.dtype == torch.float16
        for per_entry in (values, rowidx, colidx):
            assert per_entry.numel() == nnz

        self.rows, self.cols, self.nnz = rows, cols, nnz
        self.rowidx = rowidx
        self.colidx = colidx
        self.values = values
2022-08-01 10:31:48 +00:00
2022-10-27 11:14:13 +00:00
class CSRSparseTensor:
    """Sparse float16 matrix in compressed sparse row (CSR) format.

    ``rowptr`` (int32) holds ``rows + 1`` offsets; ``colidx`` (int32) and
    ``values`` (float16) hold ``nnz`` entries each.
    """

    def __init__(self, rows, cols, nnz, rowptr, colidx, values):
        for index_tensor in (rowptr, colidx):
            assert index_tensor.dtype == torch.int32
        assert values.dtype == torch.float16
        for per_entry in (values, colidx):
            assert per_entry.numel() == nnz
        assert rowptr.numel() == rows + 1

        self.rows, self.cols, self.nnz = rows, cols, nnz
        self.rowptr = rowptr
        self.colidx = colidx
        self.values = values
2022-08-01 10:31:48 +00:00
2022-10-27 11:14:13 +00:00
class CSCSparseTensor:
    """Sparse float16 matrix in compressed sparse column (CSC) format.

    ``colptr`` (int32) holds ``cols + 1`` offsets; ``rowidx`` (int32) and
    ``values`` (float16) hold ``nnz`` entries each.
    """

    def __init__(self, rows, cols, nnz, colptr, rowidx, values):
        for index_tensor in (colptr, rowidx):
            assert index_tensor.dtype == torch.int32
        assert values.dtype == torch.float16
        for per_entry in (values, rowidx):
            assert per_entry.numel() == nnz
        assert colptr.numel() == cols + 1

        self.rows, self.cols, self.nnz = rows, cols, nnz
        self.colptr = colptr
        self.rowidx = rowidx
        self.values = values
2022-08-01 10:31:48 +00:00
2022-07-22 21:41:05 +00:00
def coo2csr(cooA):
    """Convert a COOSparseTensor into a CSRSparseTensor.

    The column indices and values are reused as-is (not reordered); only the
    int32 ``rowptr`` of length ``rows + 1`` is built, by counting entries per
    row and taking a cumulative sum.
    """
    present_rows, row_counts = torch.unique(cooA.rowidx, return_counts=True)
    rowptr = torch.zeros(
        (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device
    )
    # the count of row r goes into slot r + 1; the cumsum then yields offsets
    rowptr.scatter_(index=(present_rows + 1).long(), src=row_counts.int(), dim=0)
    rowptr.cumsum_(0)
    return CSRSparseTensor(
        cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values
    )
2022-07-22 21:41:05 +00:00
def coo2csc(cooA):
    """Convert a COOSparseTensor into a CSCSparseTensor.

    Entries are reordered by column index (via ``torch.sort``); the int32
    ``colptr`` of length ``cols + 1`` is built by counting entries per column
    and taking a cumulative sum.
    """
    sorted_cols, order = torch.sort(cooA.colidx)
    present_cols, col_counts = torch.unique(sorted_cols, return_counts=True)
    colptr = torch.zeros(
        (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device
    )
    # the count of column c goes into slot c + 1; the cumsum then yields offsets
    colptr.scatter_(index=(present_cols + 1).long(), src=col_counts.int(), dim=0)
    colptr.cumsum_(0)
    return CSCSparseTensor(
        cooA.rows, cooA.cols, cooA.nnz, colptr, cooA.rowidx[order], cooA.values[order]
    )
2022-07-22 21:41:05 +00:00
2022-08-01 10:31:48 +00:00
2022-07-22 21:41:05 +00:00
def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
    """Allocate a zero-filled COOSparseTensor with room for ``nnz`` entries.

    Index tensors are int32; ``values`` uses ``dtype``.
    """
    def _zeros(element_type):
        return torch.zeros((nnz,), dtype=element_type, device=device)

    return COOSparseTensor(
        rows, cols, nnz, _zeros(torch.int32), _zeros(torch.int32), _zeros(dtype)
    )
2022-08-01 10:31:48 +00:00
def double_quant(
    A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
):
    """Quantize a float16 CUDA matrix to int8 both row-wise and column-wise.

    A 3D ``A`` is treated as a ``(d0 * d1, d2)`` matrix. Missing row/column
    statistics are computed with ``get_colrow_absmax``. With ``threshold > 0``,
    entries selected by the kernel as outliers are returned in a COO tensor,
    sorted by row index.

    Parameters
    ----------
    A : Tensor
        float16 input on a CUDA device.
    col_stats, row_stats : Tensor, optional
        Precomputed statistics; both are recomputed if either is missing.
    out_col, out_row : Tensor, optional
        int8 output buffers with A's shape; allocated when None.
    threshold : float
        Outlier threshold; 0.0 disables outlier extraction.

    Returns
    -------
    (out_row, out_col, row_stats, col_stats, coo_tensor)
        ``coo_tensor`` is None unless ``threshold > 0`` and outliers exist.
    """
    device = A.device
    assert A.dtype == torch.half
    assert device.type == "cuda"

    prev_device = pre_call(A.device)
    cols = A.shape[-1]
    if len(A.shape) == 3:
        rows = A.shape[0] * A.shape[1]
    else:
        rows = A.shape[0]

    nnz_row_ptr = None
    if row_stats is None or col_stats is None:
        row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(
            A, threshold=threshold
        )
    elif threshold > 0.0:
        # Stats were supplied, but the outlier row pointer is still required
        # (previously this path hit an UnboundLocalError). It is computed from
        # A and threshold; the freshly computed stats are discarded.
        # NOTE(review): assumes the pointer does not depend on the stats — confirm.
        _, _, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)

    if out_col is None:
        out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
    if out_row is None:
        out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)

    coo_tensor = None
    ptrA = get_ptr(A)
    ptrColStats = get_ptr(col_stats)
    ptrRowStats = get_ptr(row_stats)
    ptrOutCol = get_ptr(out_col)
    ptrOutRow = get_ptr(out_row)
    is_on_gpu([A, col_stats, row_stats, out_col, out_row])

    # Outlier buffers are only handed to the kernel when there are outliers.
    ptrRowIdx = ptrColIdx = ptrVal = ptrRowPtr = None
    kernel_threshold = threshold
    if threshold > 0.0:
        nnz = nnz_row_ptr[-1].item()
        if nnz > 0:
            # Use the flattened (rows, cols) the kernel operates on; previously
            # A.shape[0], A.shape[1] was used, which is wrong for 3D inputs.
            coo_tensor = coo_zeros(rows, cols, nnz, device)
            ptrRowIdx = get_ptr(coo_tensor.rowidx)
            ptrColIdx = get_ptr(coo_tensor.colidx)
            ptrVal = get_ptr(coo_tensor.values)
            ptrRowPtr = get_ptr(nnz_row_ptr)
        else:
            # no outliers: run the plain quantization path
            kernel_threshold = 0.0

    lib.cdouble_rowcol_quant(
        ptrA,
        ptrRowStats,
        ptrColStats,
        ptrOutCol,
        ptrOutRow,
        ptrRowIdx,
        ptrColIdx,
        ptrVal,
        ptrRowPtr,
        ct.c_float(kernel_threshold),
        ct.c_int32(rows),
        ct.c_int32(cols),
    )

    if coo_tensor is not None:
        # order the outliers by row index
        val, idx = torch.sort(coo_tensor.rowidx)
        coo_tensor.rowidx = val
        coo_tensor.colidx = coo_tensor.colidx[idx]
        coo_tensor.values = coo_tensor.values[idx]

    post_call(prev_device)

    return out_row, out_col, row_stats, col_stats, coo_tensor
def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
    """Convert ``A`` between row-major and the tiled layouts used by the CUDA kernels.

    Args:
        A: tensor to convert. Its logical shape is taken from ``state`` (or
            ``A.shape``). A 2-D shape is used as-is. For any other rank the
            first two dims are folded into rows and ``shape[2]`` is the column
            count.
        to_order: target layout: ``"col32"``, ``"col_turing"``,
            ``"col_ampere"`` or ``"row"``.
        from_order: layout of ``A``. Only consulted when ``state`` is None.
        out: optional destination. When None it is allocated by
            ``get_transform_buffer``.
        transpose: selects the transposing kernel variant for the
            ``"col32"``, ``"col_turing"`` and ``"col_ampere"`` targets. It is
            not used for the ``"row"`` target.
        state: optional ``(shape, order)`` tuple describing ``A``. When given,
            its order replaces ``from_order``.
        ld: unused; kept for interface compatibility.

    Returns:
        ``(out, new_state)``. ``new_state`` is what ``get_transform_buffer``
        returned, or ``(shape, to_order)`` when ``out`` was supplied.

    Raises:
        NotImplementedError: for an unsupported (from_order, to_order) pair.
            For the ``"row"`` target only ``"col_turing"`` and ``"col_ampere"``
            sources are supported.
    """
    prev_device = pre_call(A.device)
    try:
        if state is None:
            state = (A.shape, from_order)
        else:
            from_order = state[1]
        if out is None:
            out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose)
        else:
            new_state = (state[0], to_order)  # (shape, order)

        shape = state[0]
        if len(shape) == 2:
            dim1 = ct.c_int32(shape[0])
            dim2 = ct.c_int32(shape[1])
        else:
            # Fold the two leading dimensions into the row count.
            dim1 = ct.c_int32(shape[0] * shape[1])
            dim2 = ct.c_int32(shape[2])

        is_on_gpu([A, out])

        # For the tiled targets the row->X kernels are called regardless of from_order.
        if to_order == 'col32':
            if transpose:
                lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2)
            else:
                lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2)
        elif to_order == "col_turing":
            if transpose:
                lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2)
            else:
                lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2)
        elif to_order == "col_ampere":
            if transpose:
                lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2)
            else:
                lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2)
        elif to_order == "row" and from_order == "col_turing":
            lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2)
        elif to_order == "row" and from_order == "col_ampere":
            lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2)
        else:
            # Covers unknown targets AND "row" from an unsupported source, which
            # would otherwise hand back an unconverted buffer.
            raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}')
    finally:
        # Restore the caller's device even when the conversion fails.
        post_call(prev_device)

    return out, new_state
2022-08-01 10:31:48 +00:00
2022-07-22 21:41:05 +00:00
def spmm_coo(cooA, B, out=None):
    """Compute ``cooA @ B`` for a COO sparse ``cooA`` and dense ``B`` via ``lib.cspmm_coo``.

    Args:
        cooA: COO matrix exposing ``rows``, ``cols``, ``nnz``, ``rowidx``,
            ``colidx`` and ``values``.
        B: dense right-hand operand with ``B.shape[0] == cooA.cols``. A
            non-contiguous ``B`` is flagged to the kernel as transposed, and
            its second stride is used as the leading dimension.
        out: optional destination of shape ``(cooA.rows, B.shape[1])``. When
            None, an uninitialized tensor with ``B``'s dtype and device is
            allocated.

    Returns:
        ``out``.
    """
    if out is None:
        out = torch.empty(
            (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype
        )

    nnz = cooA.nnz
    for component in (cooA.rowidx, cooA.colidx, cooA.values):
        assert component.numel() == nnz
    assert cooA.cols == B.shape[0]

    transposed_B = not B.is_contiguous()
    leading_dim_b = B.stride(1 if transposed_B else 0)
    leading_dim_c = B.shape[1]

    handle = Cusparse_Context.get_instance().context
    row_ptr = get_ptr(cooA.rowidx)
    col_ptr = get_ptr(cooA.colidx)
    val_ptr = get_ptr(cooA.values)
    b_ptr = get_ptr(B)
    c_ptr = get_ptr(out)

    is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])

    lib.cspmm_coo(
        handle,
        row_ptr,
        col_ptr,
        val_ptr,
        ct.c_int32(nnz),
        ct.c_int32(cooA.rows),
        ct.c_int32(cooA.cols),
        ct.c_int32(B.shape[1]),
        ct.c_int32(leading_dim_b),
        b_ptr,
        ct.c_int32(leading_dim_c),
        c_ptr,
        ct.c_bool(transposed_B),
    )
    return out
2022-08-01 10:31:48 +00:00
2022-07-22 21:41:05 +00:00
def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
    """Multiply a very sparse COO matrix by a dense fp16/int8 matrix on the GPU.

    Each row of ``cooA`` may hold at most 32 non-zeros; this is enforced with
    an assertion.

    Args:
        cooA: COO matrix exposing ``rows``, ``cols``, ``nnz``, ``rowidx``,
            ``colidx`` and ``values``.
        B: dense right-hand operand (``torch.float16`` or ``torch.int8``) with
            ``B.shape[0] == cooA.cols``.
            NOTE(review): no stride or transpose flag is passed to the kernel,
            so ``B`` presumably must be contiguous row-major. Confirm.
        dequant_stats: optional tensor forwarded to the kernel (its pointer is
            None when omitted).
        out: optional destination of shape ``(cooA.rows, B.shape[1])``. When
            None, a zero tensor with ``cooA.values``' dtype is allocated.

    Returns:
        ``out``. If ``cooA`` has no non-zeros the kernel is not launched and
        ``out`` is returned unchanged.
    """
    if out is None:
        out = torch.zeros(
            (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
        )
    nnz = cooA.nnz
    assert cooA.rowidx.numel() == nnz
    assert cooA.colidx.numel() == nnz
    assert cooA.values.numel() == nnz
    assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}"
    assert B.dtype in [torch.float16, torch.int8]

    if nnz == 0:
        # Nothing to accumulate. The per-row statistics below would be empty,
        # and indexing max_count[0] would raise IndexError.
        return out

    # Non-zero count per distinct row index, plus their running sum.
    _, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    # Rows ordered by descending non-zero count.
    max_count, max_idx = torch.sort(counts, descending=True)
    max_idx = max_idx.int()
    max_count = max_count.int()
    assert (
        max_count[0] <= 32
    ), f"Max supported non-zeros per row is 32 but found {int(max_count[0])}."

    ptrOffset = get_ptr(offset)
    ptrMaxCount = get_ptr(max_count)
    ptrMaxIdx = get_ptr(max_idx)
    ptrRowidx = get_ptr(cooA.rowidx)
    ptrColidx = get_ptr(cooA.colidx)
    ptrValues = get_ptr(cooA.values)
    ptrB = get_ptr(B)
    ptrC = get_ptr(out)
    ptrDequantStats = get_ptr(dequant_stats)
    cnnz_rows = ct.c_int32(counts.numel())
    cnnz = ct.c_int32(nnz)
    crowsA = ct.c_int32(cooA.rows)
    # NOTE(review): rowsB receives B.shape[1] (same as colsB); kept unchanged,
    # confirm against the kernel signature.
    crowsB = ct.c_int32(B.shape[1])
    ccolsB = ct.c_int32(B.shape[1])

    is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])

    args = (
        ptrMaxCount,
        ptrMaxIdx,
        ptrOffset,
        ptrRowidx,
        ptrColidx,
        ptrValues,
        ptrB,
        ptrC,
        ptrDequantStats,
        cnnz_rows,
        cnnz,
        crowsA,
        crowsB,
        ccolsB,
    )
    if B.dtype == torch.float16:
        lib.cspmm_coo_very_sparse_naive_fp16(*args)
    else:  # torch.int8, guaranteed by the dtype assertion above
        lib.cspmm_coo_very_sparse_naive_int8(*args)

    return out
# Largest magnitude of a signed 8-bit integer, kept as a float; the absmax
# quantizers below scale values into [-C, C] before casting to int8.
C = float(2 ** 7 - 1)
2022-08-01 10:31:48 +00:00
def vectorwise_quant(x, dim=1, quant_type="vector"):
    """Quantize ``x`` with one of several experimental schemes.

    ``x`` is never modified in place.

    Args:
        x: floating-point tensor to quantize.
        dim: dimension reduced for the per-slice statistics. Used by
            ``"vector"``, ``"row"``, ``"vector-zeropoint"``,
            ``"row-zeropoint"`` and ``"truncated-vector"``.
        quant_type: one of
            - ``"linear"``: one absmax for the whole tensor; int8 output.
            - ``"vector"`` / ``"row"``: absmax per slice along ``dim``; int8
              output.
            - ``"zeropoint"``: global affine scaling by ``255 / (max - min)``.
              Returns rounded float32 values, not int8.
            - ``"vector-zeropoint"`` / ``"row-zeropoint"``: like
              ``"zeropoint"`` but with per-slice min/max along ``dim``.
            - ``"truncated-vector"``: per-slice absmax scaled by 0.7. Values
              beyond it are clipped before int8 quantization.

    Returns:
        ``(quantized, scale)``, or None for an unknown ``quant_type``. For the
        absmax schemes ``scale`` is the (possibly truncated) absmax. For the
        zero-point schemes it is the multiplier ``255 / (max - min)``. An
        all-zero tensor/slice quantizes to 0 and keeps a scale of 0.
    """
    if quant_type == "linear":
        max1 = torch.abs(x).max().float()
        # Divide by 1 instead of 0 for an all-zero tensor (0/0 would give NaN).
        xq = torch.round(x / max1.masked_fill(max1 == 0, 1.0) * C).to(torch.int8)
        return xq, max1

    if quant_type in ["vector", "row"]:
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        # All-zero slices get divisor 1 so they quantize to 0 instead of NaN.
        xq = torch.round(x * (C / max1.masked_fill(max1 == 0, 1))).to(torch.int8)
        return xq, max1

    if quant_type == "zeropoint":
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 255.0 / dyna
        minx = x.min()
        zpx = torch.round(minx * qx)
        x = torch.round(qx * x - zpx) + zpx
        return x, qx

    if quant_type in ["vector-zeropoint", "row-zeropoint"]:
        x = x.float()
        dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(
            x, dim=dim, keepdim=True
        )
        dyna[dyna == 0] = 1
        qx = 255.0 / dyna
        minx = torch.amin(x, dim=dim, keepdim=True)
        zpx = torch.round(minx * qx)
        x = torch.round(qx * x - zpx) + zpx
        return x, qx

    if quant_type == "truncated-vector":
        with torch.no_grad():
            # The clipping below writes into x; work on a copy so the
            # caller's tensor is left untouched.
            x = x.clone()
            absx = torch.abs(x)
            max1 = torch.amax(absx, dim=dim, keepdim=True)
            max1 = max1 * 0.7
            idx = absx > max1.expand_as(absx)
            sign = torch.sign(x[idx])
            x[idx] = max1.expand_as(absx)[idx] * sign
            xq = torch.round(x / max1.masked_fill(max1 == 0, 1) * C).to(torch.int8)
        return xq, max1

    return None
2022-07-22 21:41:05 +00:00
2022-08-01 10:31:48 +00:00
def vectorwise_dequant(xq, max1, quant_type="vector"):
    """Map output of ``vectorwise_quant(..., quant_type="vector")`` back to float32.

    Args:
        xq: quantized tensor.
        max1: absmax statistics returned alongside ``xq``.
        quant_type: only ``"vector"`` is handled.

    Returns:
        float32 tensor ``xq / C * max1``, or None for any other ``quant_type``.
    """
    if quant_type != "vector":
        return None
    rescaled = xq / C * max1
    return rescaled.to(torch.float32)
2022-07-22 21:41:05 +00:00
2022-08-01 10:31:48 +00:00
def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
    """Rescale the integer result of a quantized matmul back to floating point.

    Args:
        xq: accumulated matmul output. It is never modified in place.
        S1, S2: scale statistics of the two operands, presumably as returned
            by ``vectorwise_quant`` for the left and right operand (confirm
            against callers). When ``xq`` is 2-D and a statistic is 3-D,
            ``squeeze(0)`` is applied to it first.
        dtype: dtype of the returned tensor.
        quant_type: scheme used for the operands:
            - ``"linear"``: ``xq * S1 * S2 / C**2``
            - ``"zeropoint"`` / ``"row-zeropoint"``: ``xq / (S1 * S2)``
            - ``"vector-zeropoint"``: ``xq / S1 / S2.t()``
            - ``"row"``: ``xq * S1 * S2 / C**2``
            - ``"vector"`` / ``"truncated-vector"``: ``xq * (S1 / C) * (S2 / C)``

    Returns:
        The rescaled tensor in ``dtype``, or None for an unknown ``quant_type``.
    """
    if quant_type == "linear":
        norm = S1 * S2 / (C * C)
        # double cast needed to prevent overflows
        return (xq.float() * norm).to(dtype)
    if quant_type == "zeropoint":
        norm = 1.0 / (S1 * S2)
        return (xq.float() * norm).to(dtype)
    if quant_type not in ("row-zeropoint", "vector-zeropoint", "row", "truncated-vector", "vector"):
        return None

    # Explicit copy: xq.float() returns xq itself when xq is already float32,
    # and the in-place scaling below would then overwrite the caller's tensor.
    x = xq.to(torch.float32, copy=True)
    # Align 3-D statistics with a 2-D result BEFORE they are used in any
    # scale computation.
    if len(S1.shape) == 3 and len(x.shape) == 2:
        S1 = S1.squeeze(0)
    if len(S2.shape) == 3 and len(x.shape) == 2:
        S2 = S2.squeeze(0)

    if quant_type == "row-zeropoint":
        x *= 1.0 / (S1 * S2)
    elif quant_type == "vector-zeropoint":
        x *= 1.0 / S1
        x *= 1.0 / S2.t()
    elif quant_type == "row":
        x *= S1 * S2 / (C * C)
    else:  # "truncated-vector" or "vector"
        x *= S1 / C
        x *= S2 / C
    return x.to(dtype)
2022-07-22 21:41:05 +00:00
def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
    """Rescale a quantized matmul result and add a per-column offset derived from ``B``.

    Computes ``xq * (SB / 127) * (SA[1] / 127) + rowsum(B) * (SA[0] + SA[1])``.
    ``SB`` is transposed first when it is 2-D, and squeezed first when it is
    3-D and ``xq`` is 2-D.

    Args:
        xq: accumulated matmul output. It is never modified in place.
        A: unused; kept for interface compatibility.
        B: tensor whose row sums (``B.t().sum(0)``) form the additive offset.
        SA: indexable pair of scalars/tensors. NOTE(review): presumably
            min/max-style statistics of the left operand, given the name.
            Confirm.
        SB: scale statistics of the right operand.
        dtype: dtype of the returned tensor.

    Returns:
        The dequantized tensor in ``dtype``.
    """
    offset = B.float().t().sum(0) * (SA[0] + SA[1])
    # Explicit copy: xq.float() returns xq itself when xq is already float32,
    # and the in-place ops below would then overwrite the caller's tensor.
    x = xq.to(torch.float32, copy=True)
    if len(xq.shape) == 2 and len(SB.shape) == 3:
        SB = SB.squeeze(0)
    if len(SB.shape) == 2:
        x *= SB.t() / 127
    else:
        x *= SB / 127
    x *= SA[1] / 127
    x += offset
    return x.to(dtype)
2022-07-26 19:12:38 +00:00
2022-08-01 10:31:48 +00:00
2022-07-26 19:12:38 +00:00
def extract_outliers(A, SA, idx):
    """Gather the columns listed in ``idx`` from a CUDA int8 matrix in a tiled layout.

    Args:
        A: CUDA tensor stored in the ``"col_turing"`` or ``"col_ampere"``
            layout.
        SA: state whose ``[0]`` is the logical 2-D shape of ``A`` and whose
            ``[1]`` is its layout name.
        idx: tensor of column indices to extract.

    Returns:
        int8 tensor of shape ``(SA[0][0], idx.numel())`` on ``A``'s device,
        filled by ``lib.cextractOutliers_turing`` / ``lib.cextractOutliers_ampere``.
    """
    shapeA = SA[0]
    layout = SA[1]

    assert layout in ["col_turing", "col_ampere"]
    assert A.device.type == "cuda"

    n_rows, n_cols = shapeA[0], shapeA[1]
    out = torch.zeros((n_rows, idx.numel()), dtype=torch.int8, device=A.device)

    kernel_args = (
        get_ptr(A),
        get_ptr(idx),
        get_ptr(out),
        ct.c_int32(idx.numel()),
        ct.c_int32(n_rows),
        ct.c_int32(n_cols),
    )

    prev_device = pre_call(A.device)
    if layout == 'col_turing':
        lib.cextractOutliers_turing(*kernel_args)
    elif layout == "col_ampere":
        lib.cextractOutliers_ampere(*kernel_args)
    post_call(prev_device)

    return out