import operator
import warnings
from dataclasses import dataclass
from functools import reduce  # Required in Python 3
from typing import Tuple, Optional, List
from warnings import warn

import torch

import bitsandbytes.functional as F


# math.prod not compatible with python < 3.8
def prod(iterable):
    return reduce(operator.mul, iterable, 1)
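
# Example: prod((4, 128)) == 512, and prod(()) == 1 for the empty product,
# matching math.prod from Python >= 3.8.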

tensor = torch.Tensor

# The inverse transformations for the colTuring and colAmpere formats were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
"""
2022-07-22 21:41:05 +00:00
This class pools outlier dimensions across layers .
2022-10-27 11:11:29 +00:00
This is particularly important for small models where outlier features
2022-07-22 21:41:05 +00:00
are less systematic and occur with low frequency .
2022-08-01 10:31:48 +00:00
"""
2022-10-27 11:14:13 +00:00
class GlobalOutlierPooler :
2022-07-22 21:41:05 +00:00
_instance = None
def __init__ ( self ) :
2022-08-01 10:31:48 +00:00
raise RuntimeError ( " Call get_instance() instead " )
2022-07-22 21:41:05 +00:00
def initialize ( self ) :
self . outliers = set ( )
self . model_dim = None
@classmethod
def get_instance ( cls ) :
if cls . _instance is None :
cls . _instance = cls . __new__ ( cls )
cls . _instance . initialize ( )
return cls . _instance
def add_outliers ( self , outlier_idx , feature_dim ) :
2022-08-01 10:31:48 +00:00
if self . model_dim is None :
self . model_dim = feature_dim
if feature_dim != self . model_dim :
return # we do not encode outliers for the 2nd FFN layer
2022-07-22 21:41:05 +00:00
self . outliers . update ( outlier_idx . tolist ( ) )
def get_current_outlier_idx ( self ) :
return torch . Tensor ( list ( self . outliers ) ) . to ( torch . int64 )
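
# Hedged usage sketch (not part of the library API): the pooler is a process-wide
# singleton, so it must be obtained via get_instance(), never constructed directly.
# The index and dimension values below are hypothetical.
def _example_outlier_pooling():
    pooler = GlobalOutlierPooler.get_instance()
    outlier_idx = torch.tensor([3, 17, 42])        # hypothetical outlier feature dims
    pooler.add_outliers(outlier_idx, feature_dim=768)
    return pooler.get_current_outlier_idx()        # pooled indices as an int64 tensor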


def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]):
    """
    Compute a permutation of indices that inverts the specified (tiled) matrix transformation

    :param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2]
    :param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere
    :note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size
    :example: transform_tile function for the turing layout (bitsandbytes.functional as F)
    :returns: indices
    """
    d1, d2 = tile_size
    assert 0 < d1 * d2 < 2**64
    tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2)
    # encode each position in tile as a tuple of <= 8 unique bytes
    permuted_tile_indices = torch.zeros_like(tile_indices)
    for i in range(8):
        # select i-th byte, apply transformation and trace where each index ended up
        ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256
        sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous()
        assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow"
        permuted_tile_i = transform_tile(sample_tile_i)
        ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128
        permuted_tile_indices += ith_permuted_indices * (256**i)
        if d1 * d2 < 256**i:
            break  # if all indices fit in i bytes, stop early
    return permuted_tile_indices
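
# Hedged CPU-only sketch: check that get_inverse_transform_indices inverts a toy tile
# transform (here: transposing a square 4x4 tile). The returned indices give, for
# every output slot, its source position, so scattering through them restores the tile.
def _example_inverse_indices_toy():
    transform = lambda t: t.t().contiguous()             # toy forward tile transform
    inv = get_inverse_transform_indices(transform, (4, 4))
    tile = (torch.arange(16, dtype=torch.int64) - 8).to(torch.int8).view(4, 4)
    restored = torch.empty(16, dtype=torch.int8)
    restored[inv.flatten()] = transform(tile).flatten()  # scatter back to source slots
    assert torch.equal(restored.view(4, 4), tile)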


def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor:
    """
    Undo a tiled permutation such as turing or ampere layout

    :param permuted_tensor: torch tensor in a permuted layout
    :param tile_indices: reverse transformation indices, from get_inverse_transform_indices
    :return: contiguous row-major tensor
    """
    (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape
    assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles"
    tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t()
    outputs = torch.empty_like(tensor)  # note: not using .index_copy because it was slower on cuda
    outputs[tile_indices.flatten()] = tensor
    outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows)
    outputs = outputs.permute(3, 0, 2, 1)  # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols)
    return outputs.reshape(rows, cols).contiguous()
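
# Hedged GPU sketch (assumes a CUDA device with igemmlt support): build the inverse
# indices for the Turing layout, then recover a row-major int8 matrix from its
# transformed form. Shapes are chosen as whole (8, 32) tiles.
def _example_undo_turing_layout():
    device = torch.device("cuda")
    to_turing = lambda x: F.transform(x.to(device), from_order="row", to_order="col_turing")[0].to(x.device)
    inds = get_inverse_transform_indices(to_turing, (8, 32)).to(device)
    W = torch.randint(-128, 127, (64, 64), dtype=torch.int8, device=device)
    CxW, _ = F.transform(W, to_order="col_turing")
    assert torch.equal(undo_layout(CxW, inds), W)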


class MatMul8bit(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
        if precision is None:
            precision = [8, 8, 8]
        if precision[0] != 8:
            with torch.no_grad():
                output = torch.matmul(A, B)
        else:
            if len(B.shape) == 2:
                dim = 0
            else:
                dim = 1
            qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
            qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
            iout = F.igemm(qA, qB)
            output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type)

        if A.requires_grad or B.requires_grad:
            ctx.save_for_backward(A, B)

        ctx.quant_type = quant_type
        ctx.precision = precision

        return output
    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        quant_type = ctx.quant_type
        precision = ctx.precision
        grad_A = grad_B = None

        if B.requires_grad:
            if len(A.shape) == 3:
                dims = [0, 1]
                # bsi -> ibs
                permute_dim = [0, 2, 1]
            else:
                dims = [0]
                # bs -> sb
                permute_dim = [1, 0]

            if precision[1] != 8:
                with torch.no_grad():
                    grad_B = torch.matmul(A.permute(permute_dim), grad_output)
            else:
                if len(B.shape) == 2 and len(A.shape) == 3:
                    grad_output = grad_output.contiguous()
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output.view(-1, grad_output.shape[2]),
                        dim=0,
                        quant_type=quant_type,
                    )
                    if not A.is_contiguous():
                        A = A.contiguous()
                    qA, S2 = F.vectorwise_quant(
                        A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
                    )
                    igrad_B = F.igemm(qA.t(), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B, S2.t(), S1, grad_output.dtype, quant_type
                    )
                else:
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output, dim=dims, quant_type=quant_type
                    )
                    qA, S2 = F.vectorwise_quant(
                        A, dim=dims, quant_type=quant_type
                    )
                    igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B,
                        S2.permute(permute_dim),
                        S1,
                        grad_output.dtype,
                        quant_type,
                    )

        if A.requires_grad:
            if len(grad_output.shape) == 3:
                dims = [2]
            else:
                dims = [1]

            if len(B.shape) == 3:
                # bio -> boi
                permute_dim = [0, 2, 1]
                dim_B = dims
            else:
                # io -> oi
                permute_dim = [1, 0]
                dim_B = [1]

            if precision[2] != 8:
                with torch.no_grad():
                    grad_A = torch.matmul(grad_output, B.permute(permute_dim))
            else:
                qgrad_output, S1 = F.vectorwise_quant(
                    grad_output, dim=dims, quant_type=quant_type
                )
                qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
                grad_A = F.vectorwise_mm_dequant(
                    igrad_A,
                    S1,
                    S3.permute(permute_dim),
                    grad_output.dtype,
                    quant_type,
                )

        return grad_A, grad_B, None, None, None


mm_cublas = MatMul8bit.apply
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply
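
# Hedged usage sketch (assumes a CUDA device): matmul_cublas behaves like
# torch.matmul but quantizes both operands to int8 for the forward pass.
# precision=[8, 8, 8] controls forward/grad_B/grad_A; any entry != 8 falls back
# to a regular floating-point matmul for that computation.
def _example_mm_cublas():
    A = torch.randn(4, 128, device="cuda", requires_grad=True)
    B = torch.randn(128, 256, device="cuda", requires_grad=True)
    out = matmul_cublas(A, B, None, "vector")  # positional: out=None, quant_type="vector"
    out.sum().backward()                       # int8 gradients for both A and B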


def supports_igemmlt(device: torch.device) -> bool:
    """check if this device supports the optimized int8 kernel"""
    if torch.cuda.get_device_capability(device=device) < (7, 5):
        return False
    device_name = torch.cuda.get_device_name(device=device)
    nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660')  # https://en.wikipedia.org/wiki/GeForce_16_series
    if any(model_name in device_name for model_name in nvidia16_models):
        return False  # these devices are technically cuda 7.5-capable, but they lack tensor cores
    return True
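
# Hedged example: GTX 16xx cards report compute capability 7.5 but lack tensor cores,
# so they take the fallback path; e.g. supports_igemmlt(torch.device("cuda:0")) is
# False on a GTX 1660 and True on Turing/Ampere cards such as a T4 or an RTX 3090.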


def _get_tile_size(format):
    assert format in (
        "col_turing",
        "col_ampere",
    ), f"please find this assert and manually enter tile size for {format}"
    return (8, 32) if format == "col_turing" else (32, 32)


def get_tile_inds(format, device):
    transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
    with torch.no_grad():
        return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
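
# Hedged note: get_tile_inds does no caching itself; MatmulLtState below memoizes the
# result per instance via its tile_indices property. For example,
# get_tile_inds("col_ampere", torch.device("cuda")) yields a (32, 32) int64 index tile.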


@dataclass
class MatmulLtState:
    _tile_indices: Optional[torch.Tensor] = None
    force_no_igemmlt: bool = False
    CB = None
    CxB = None
    SB = None
    SCB = None

    CxBt = None
    SBt = None
    CBt = None

    subB = None

    outlier_pool = None
    has_accumulated_gradients = False
    threshold = 0.0
    idx = None
    is_training = True
    has_fp16_weights = True
    memory_efficient_backward = False
    use_pool = False
    formatB = F.get_special_format_str()

    def reset_grads(self):
        self.CB = None
        self.CxB = None
        self.SB = None
        self.SCB = None

        self.CxBt = None
        self.SBt = None
        self.CBt = None

    @property
    def tile_indices(self):
        if self._tile_indices is None:
            self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
        return self._tile_indices
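
# Hedged sketch: a typical inference-time state for MatMul8bitLt with mixed-precision
# decomposition enabled. Values are illustrative, not library defaults; with
# has_fp16_weights=False, pre-quantized weights (CB/SCB) must be assigned before use.
def _example_state():
    state = MatmulLtState()
    state.threshold = 6.0           # outlier threshold passed to F.double_quant
    state.has_fp16_weights = False  # weights are already int8 (set state.CB, state.SCB)
    state.is_training = False
    return state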


class MatMul8bitLt(torch.autograd.Function):
    # forward is the same, but we added the fallback for pre-turing GPUs
    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")

    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
        using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt
        # default of pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            ctx.bias = bias
            if A.shape[-1] == B.shape[0]:
                return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device)
            else:
                return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)

        # 1. Quantize A
        # 2. Quantize B
        # 3. Matmul
        # 4. Mixed-precision decomposition matmul
        # 5. Save state
        formatB = state.formatB
        input_shape = A.shape
        if state.outlier_pool is None:
            state.outlier_pool = GlobalOutlierPooler.get_instance()

        # Cast A to fp16
        if A.dtype != torch.float16:
            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")

        # 1. Quantize A
        if len(A.shape) == 3:
            A = A.view(-1, A.shape[-1]).contiguous()
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)

        if state.threshold > 0.0 and coo_tensorA is not None:
            if state.has_fp16_weights:
                idx = torch.unique(coo_tensorA.colidx).long()
                CA[:, idx] = 0
                CAt[:, idx] = 0
                subA = A[:, idx]
                state.subB = B[:, idx].t().contiguous()
                state.idx = idx
            else:
                if state.CxB is None and using_igemmlt:
                    # B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
                    # we also need to convert it to the turing/ampere format
                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
        else:
            if not state.has_fp16_weights and state.CxB is None and using_igemmlt:
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
            subA = None

        # 2. Quantize B
        if state.has_fp16_weights:
            has_grad = getattr(B, "grad", None) is not None
            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
            if is_transposed:
                B = B.contiguous()

            if (state.is_training and not has_grad) or state.CxB is None:
                state.reset_grads()
                (
                    CB,
                    state.CBt,
                    state.SCB,
                    state.SCBt,
                    coo_tensorB,
                ) = F.double_quant(B.to(torch.float16))
                if using_igemmlt:
                    state.CxB, state.SB = F.transform(CB, to_order=formatB)
                else:
                    state.CB = CB
        else:
            has_grad = False

        if coo_tensorA is not None and not state.has_fp16_weights:
            # extract outliers
            outlier_idx = torch.unique(coo_tensorA.colidx)
            state.idx = outlier_idx
            # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
            # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
            #     # do not use pool for 2nd FFN layer
            #     state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
            # else:
            #     state.idx = outlier_idx
            if state.CxB is not None:
                outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
            else:
                outliers = state.CB[:, state.idx.long()].clone()

            state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
            CA[:, state.idx.long()] = 0
            CAt[:, state.idx.long()] = 0
            subA = A[:, state.idx.long()]

        shapeB = state.SB[0] if state.SB else B.shape

        if len(input_shape) == 3:
            output_shape = (input_shape[0], input_shape[1], shapeB[0])
        else:
            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
        if using_igemmlt:
            C32A, SA = F.transform(CA, "col32")
            out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
            if bias is None or bias.dtype == torch.float16:
                # we apply the fused bias here
                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
                output = output.to(A.dtype)
            else:  # apply bias separately
                output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
                output = output.to(A.dtype).add_(bias)
        else:
            A_wo_outliers = A.clone()
            if state.idx is not None:
                A_wo_outliers[:, state.idx.long()] = 0
            output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
            output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
            if bias is not None:
                output = output.add_(bias)

        # 4. Mixed-precision decomposition matmul
        if coo_tensorA is not None and subA is not None:
            output += torch.matmul(subA, state.subB)

        # 5. Save state
        ctx.state = state
        ctx.formatB = formatB
        ctx.grad_shape = input_shape
        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

        if any(ctx.needs_input_grad[:2]):
            ctx.tensors = (CAt, subA, A)
            ctx.tensor_states = (SCAt, state.idx)
        else:
            ctx.tensors = (None, None, A)
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

        clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
        return clone_func(output.view(output_shape))
    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
        CAt, subA, A = ctx.tensors
        SCAt, idx = ctx.tensor_states
        formatB = ctx.formatB
        state = ctx.state
        grad_A = grad_B = grad_bias = None

        if req_gradBias:
            # compute grad_bias first before changing grad_output dtype
            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

        # Cast grad_output to fp16
        if len(grad_output.shape) == 3:
            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
        if req_gradB:
            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
            if state.threshold > 0.0 and subA is not None:
                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

        if req_gradA:
            if state.CBt is not None:
                C32grad, Sgrad = F.transform(Cgrad, "col32")
                if state.CxBt is None:
                    state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
                gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
            elif state.CB is not None:
                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
            elif state.CxB is not None:
                CB = (
                    undo_layout(state.CxB, state.tile_indices)
                    .to(ctx.dtype_A)
                    .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                )
                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
            else:
                raise Exception("State must contain either CBt or CB or CxB matrix for backward")

        return grad_A, grad_B, None, grad_bias, None
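
# Hedged end-to-end sketch (assumes an igemmlt-capable CUDA GPU): run the 8-bit
# matmul on fp16 weights with outlier decomposition enabled via the threshold.
def _example_matmul8bitlt():
    A = torch.randn(8, 1024, dtype=torch.float16, device="cuda", requires_grad=True)
    B = torch.randn(2048, 1024, dtype=torch.float16, device="cuda", requires_grad=True)
    state = MatmulLtState()
    state.threshold = 6.0
    out = MatMul8bitLt.apply(A, B, None, None, state)  # out=None, bias=None
    out.sum().backward()                               # exercises the int8 backward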


class MatMul4Bit(torch.autograd.Function):
    # forward: dequantize the 4-bit weight and run a regular fp matmul
    # backward: dequantize again to compute grad_A; grad_B is not currently supported

    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, state=None):
        # default of pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            ctx.bias = bias
            B_shape = state[1]
            if A.shape[-1] == B_shape[0]:
                return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
            else:
                return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)

        # 1. Dequantize
        # 2. Matmul
        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, state).to(A.dtype).t(), bias)

        # 3. Save state
        ctx.state = state
        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

        if any(ctx.needs_input_grad[:2]):
            ctx.tensors = (A, B)
        else:
            ctx.tensors = (None, None)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

        req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad
        A, B = ctx.tensors
        state = ctx.state

        grad_A, grad_B, grad_bias = None, None, None

        if req_gradBias:
            # compute grad_bias first before changing grad_output dtype
            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

        # not supported by PyTorch. TODO: create work-around
        # if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
        if req_gradA:
            grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())

        return grad_A, grad_B, None, grad_bias, None


def matmul(
    A: tensor,
    B: tensor,
    out: tensor = None,
    state: MatmulLtState = None,
    threshold=0.0,
    bias=None
):
    state = state or MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    return MatMul8bitLt.apply(A, B, out, bias, state)
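
# Hedged usage sketch: the high-level 8-bit entry point. A threshold > 0 enables the
# mixed-precision decomposition, with outlier columns handled in fp16.
def _example_matmul():
    A = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
    W = torch.randn(4096, 1024, dtype=torch.float16, device="cuda")
    return matmul(A, W, threshold=6.0)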


def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
    assert quant_state is not None
    if A.numel() == A.shape[-1] and not A.requires_grad:
        absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
        if A.shape[-1] % blocksize != 0:
            warn(
                f"The input's hidden dimension is not a multiple of the blocksize {blocksize}; "
                f"efficient inference kernels are not supported for this size (inference will be slow). "
                f"Matrix input size found: {A.shape}"
            )
            return MatMul4Bit.apply(A, B, out, bias, quant_state)
        else:
            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
            if bias is not None:
                out += bias
            return out
    else:
        return MatMul4Bit.apply(A, B, out, bias, quant_state)
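
# Hedged usage sketch: quant_state is produced by F.quantize_4bit. With a single
# input row and no grad, the fused gemv path above is taken; otherwise MatMul4Bit
# dequantizes B and falls back to a regular matmul.
def _example_matmul_4bit():
    W = torch.randn(4096, 1024, dtype=torch.float16, device="cuda")
    B, quant_state = F.quantize_4bit(W, quant_type="nf4")
    A = torch.randn(1, 1024, dtype=torch.float16, device="cuda")
    return matmul_4bit(A, B, quant_state)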