2022-08-01 10:31:48 +00:00
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
2021-10-06 02:16:20 +00:00
# LICENSE file in the root directory of this source tree.
2022-10-27 11:32:01 +00:00
from typing import Optional , TypeVar , Union , overload
2021-10-06 02:16:20 +00:00
2022-08-01 10:31:48 +00:00
import torch
2021-10-06 02:16:20 +00:00
import torch . nn . functional as F
2022-08-01 10:31:48 +00:00
from torch import Tensor , device , dtype , nn
2021-10-06 02:16:20 +00:00
2022-08-01 10:31:48 +00:00
import bitsandbytes as bnb
2021-10-06 02:16:20 +00:00
from bitsandbytes . optim import GlobalOptimManager
2022-08-01 10:31:48 +00:00
# Self-type used by the `to()` overloads below; the TypeVar name must match
# the variable it is bound to.
T = TypeVar("T", bound="torch.nn.Module")
2022-07-22 21:41:05 +00:00
2021-10-06 02:16:20 +00:00
class StableEmbedding(torch.nn.Embedding):
    """Embedding lookup followed by a LayerNorm over the embedding dimension.

    The weight is initialised with Xavier-uniform values and is registered
    with ``GlobalOptimManager`` so that its optimizer state uses
    ``optim_bits=32``.

    Args:
        num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
        scale_grad_by_freq, sparse, _weight: forwarded unchanged to
            ``torch.nn.Embedding``.
        device: forwarded to ``torch.nn.Embedding`` and to the LayerNorm.
        dtype: forwarded to ``torch.nn.Embedding`` only.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        device=None,
        dtype=None,
    ) -> None:
        # device/dtype go by keyword so they cannot bind to a different
        # positional parameter if the parent signature gains arguments after
        # `_weight` (NOTE(review): newer PyTorch appears to do so - confirm).
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
            device=device,
            dtype=dtype,
        )
        self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        """Xavier-uniform initialise the weight, then zero the padding row."""
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    # !!! This is a redefinition of _fill_padding_idx_with_zero in
    # torch.nn.Embedding to make the layer compatible with PyTorch < 1.9.
    # This means that if this changes in future PyTorch releases this needs
    # to change too, which is cumbersome. However, with this we can ensure
    # compatibility with previous PyTorch releases.
    def _fill_padding_idx_with_zero(self) -> None:
        """Zero the weight row at ``padding_idx`` (no-op when it is None)."""
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        """Look up ``input`` and return the layer-normalised embeddings.

        The result is cast back to the dtype of ``self.weight``.
        """
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        # always apply layer norm in full precision (the default dtype)
        emb = emb.to(torch.get_default_dtype())

        return self.norm(emb).to(self.weight.dtype)
2021-11-29 17:32:13 +00:00
class Embedding(torch.nn.Embedding):
    """``torch.nn.Embedding`` with Xavier-uniform init and 32-bit optimizer state.

    The weight is registered with ``GlobalOptimManager`` so that its optimizer
    state uses ``optim_bits=32``.

    Args:
        num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
        scale_grad_by_freq, sparse, _weight: forwarded unchanged to
            ``torch.nn.Embedding``.
        device, dtype: optional; forwarded to ``torch.nn.Embedding`` by
            keyword only when they are not None.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        device=None,
        dtype=None,
    ) -> None:
        # Only forward the factory arguments that were actually supplied, so
        # the default call is identical to the one made before these
        # parameters existed.
        factory_kwargs = {}
        if device is not None:
            factory_kwargs["device"] = device
        if dtype is not None:
            factory_kwargs["dtype"] = dtype
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
            **factory_kwargs,
        )
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        """Xavier-uniform initialise the weight, then zero the padding row."""
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    # !!! This is a redefinition of _fill_padding_idx_with_zero in
    # torch.nn.Embedding to make the layer compatible with PyTorch < 1.9.
    # This means that if this changes in future PyTorch releases this needs
    # to change too, which is cumbersome. However, with this we can ensure
    # compatibility with previous PyTorch releases.
    def _fill_padding_idx_with_zero(self) -> None:
        """Zero the weight row at ``padding_idx`` (no-op when it is None)."""
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        """Return the embedding rows selected by ``input``."""
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        return emb
2022-07-22 21:41:05 +00:00
2023-04-03 18:00:12 +00:00
class Params4bit(torch.nn.Parameter):
    """Parameter that is quantized to 4 bits when it is moved from CPU to CUDA.

    Attributes set on each instance:
        blocksize, compress_statistics, quant_type: forwarded to
            ``bnb.functional.quantize_4bit`` by :meth:`cuda`.
        quant_state: the state returned by ``quantize_4bit``; ``None`` until
            the parameter has been quantized (or one is passed in).
    """

    def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type="fp4"):
        if data is None:
            data = torch.empty(0)

        self = torch.Tensor._make_subclass(cls, data, requires_grad)
        self.blocksize = blocksize
        self.compress_statistics = compress_statistics
        self.quant_type = quant_type
        self.quant_state = quant_state
        self.data = data
        return self

    def cuda(self, device=None):
        """Move the data to ``device`` as fp16 and replace it by its 4-bit quantization.

        ``device`` defaults to ``None`` (as for ``torch.Tensor.cuda``) so that a
        bare ``param.cuda()`` call works. Returns ``self``.
        """
        w = self.data.contiguous().half().cuda(device)
        w_4bit, quant_state = bnb.functional.quantize_4bit(
            w,
            blocksize=self.blocksize,
            compress_statistics=self.compress_statistics,
            quant_type=self.quant_type,
        )
        self.data = w_4bit
        self.quant_state = quant_state

        return self

    @overload
    def to(self: "T", device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...) -> "T":
        ...

    @overload
    def to(self: "T", dtype: Union[dtype, str], non_blocking: bool = ...) -> "T":
        ...

    @overload
    def to(self: "T", tensor: Tensor, non_blocking: bool = ...) -> "T":
        ...

    def to(self, *args, **kwargs):
        """Move/cast the parameter; a CPU -> CUDA move triggers quantization.

        For any other request a new ``Params4bit`` is returned that shares
        this parameter's ``quant_state`` object.
        """
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)

        if device is not None and device.type == "cuda" and self.data.device.type == "cpu":
            return self.cuda(device)

        s = self.quant_state
        # Only relocate the quantization state when a target device was
        # actually requested (a dtype-only `.to()` has device None).
        if s is not None and device is not None:
            # make sure the quantization state is on the right device
            s[0] = s[0].to(device)
            if self.compress_statistics:
                # TODO: refactor this. This is a nightmare
                # NOTE(review): s[-2] is assumed to be [offset, nested_state]
                # with nested_state[0] / nested_state[1] being tensors -
                # confirm against bnb.functional.quantize_4bit.
                s[-2][0] = s[-2][0].to(device)  # offset
                s[-2][1][0] = s[-2][1][0].to(device)  # nested quantization state statistics
                s[-2][1][1] = s[-2][1][1].to(device)  # nested quantization codebook

        new_param = Params4bit(
            super().to(device=device, dtype=dtype, non_blocking=non_blocking),
            requires_grad=self.requires_grad,
            quant_state=self.quant_state,
            blocksize=self.blocksize,
            compress_statistics=self.compress_statistics,
            quant_type=self.quant_type,
        )

        return new_param
2023-04-03 18:00:12 +00:00
class Linear4bit(nn.Linear):
    """Linear layer whose weight is held as a ``Params4bit``.

    Args:
        input_features, output_features, bias: forwarded to ``nn.Linear``.
        compute_dtype: if not None, the input is cast to this dtype before
            the matmul; the output is always cast back to the input dtype.
        compress_statistics, quant_type: forwarded to ``Params4bit``.
    """

    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type="fp4"):
        super().__init__(input_features, output_features, bias)
        # Re-wrap the freshly initialised weight as a frozen 4-bit parameter.
        self.weight = Params4bit(
            self.weight.data,
            requires_grad=False,
            compress_statistics=compress_statistics,
            quant_type=quant_type,
        )
        self.compute_dtype = compute_dtype

    def forward(self, x: torch.Tensor):
        """Compute ``bnb.matmul_4bit`` of ``x`` with the transposed weight (+ bias)."""
        # the bias is not cast automatically, so align it with the input dtype here
        has_bias = self.bias is not None
        if has_bias and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, "quant_state", None) is None:
            print("FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.")

        # remember the caller's dtype so the result can be cast back to it
        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        if has_bias:
            bias = self.bias.to(self.compute_dtype)
        else:
            bias = None

        result = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
        return result.to(inp_dtype)
2023-04-03 18:00:12 +00:00
class LinearFP4(Linear4bit):
    """``Linear4bit`` with the quantization type fixed to ``'fp4'``."""

    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
        super().__init__(
            input_features,
            output_features,
            bias=bias,
            compute_dtype=compute_dtype,
            compress_statistics=compress_statistics,
            quant_type="fp4",
        )
class LinearNF4(Linear4bit):
    """``Linear4bit`` with the quantization type fixed to ``'nf4'``."""

    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
        super().__init__(
            input_features,
            output_features,
            bias=bias,
            compute_dtype=compute_dtype,
            compress_statistics=compress_statistics,
            quant_type="nf4",
        )
2022-08-01 10:31:48 +00:00
2023-04-26 00:15:51 +00:00
2022-07-22 21:41:05 +00:00
class Int8Params(torch.nn.Parameter):
    """Parameter that is quantized to int8 when it is moved from CPU to CUDA.

    Attributes (stored per instance):
        has_fp16_weights: when True, :meth:`cuda` behaves like the plain
            tensor ``cuda`` and no quantization happens.
        CB, SCB: outputs of ``bnb.functional.double_quant`` set by
            :meth:`cuda`; ``None`` until then unless passed to the constructor.
    """

    # Class-level fallbacks so the attributes always resolve; the real values
    # live on each instance (see __new__).
    has_fp16_weights = False
    CB = None
    SCB = None

    def __new__(
        cls,
        data=None,
        requires_grad=True,
        has_fp16_weights=False,
        CB=None,
        SCB=None,
    ):
        if data is None:
            data = torch.empty(0)
        obj = torch.Tensor._make_subclass(cls, data, requires_grad)
        # Store the state on the instance, not on the class: writing to `cls`
        # would overwrite the flag for every other Int8Params already created.
        obj.has_fp16_weights = has_fp16_weights
        obj.CB = CB
        obj.SCB = SCB
        return obj

    def cuda(self, device=None):
        """Move to ``device``; quantize to int8 unless ``has_fp16_weights``.

        Returns ``self`` in the quantizing branch, otherwise whatever the
        parent ``cuda`` returns.
        """
        if self.has_fp16_weights:
            return super().cuda(device)

        # we store the 8-bit rows-major weight
        # we convert this weight to the turing/ampere weight during the first inference pass
        B = self.data.contiguous().half().cuda(device)
        CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
        # the transposed results are not needed; drop them right away
        del CBt
        del SCBt
        self.data = CB
        self.CB = CB
        self.SCB = SCB

        return self

    @overload
    def to(
        self: "T",
        device: Optional[Union[int, device]] = ...,
        dtype: Optional[Union[dtype, str]] = ...,
        non_blocking: bool = ...,
    ) -> "T":
        ...

    @overload
    def to(self: "T", dtype: Union[dtype, str], non_blocking: bool = ...) -> "T":
        ...

    @overload
    def to(self: "T", tensor: Tensor, non_blocking: bool = ...) -> "T":
        ...

    def to(self, *args, **kwargs):
        """Move/cast the parameter; a CPU -> CUDA move goes through :meth:`cuda`.

        For any other request a new ``Int8Params`` is returned that carries
        over ``has_fp16_weights``, ``CB`` and ``SCB``.
        """
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )

        if (
            device is not None
            and device.type == "cuda"
            and self.data.device.type == "cpu"
        ):
            return self.cuda(device)

        new_param = Int8Params(
            super().to(device=device, dtype=dtype, non_blocking=non_blocking),
            requires_grad=self.requires_grad,
            has_fp16_weights=self.has_fp16_weights,
        )
        new_param.CB = self.CB
        new_param.SCB = self.SCB

        return new_param
2023-02-05 05:11:21 +00:00
2022-07-22 21:41:05 +00:00
class Linear8bitLt(nn.Linear):
    """Linear layer backed by ``bnb.matmul`` with an ``Int8Params`` weight.

    Args:
        input_features, output_features, bias: forwarded to ``nn.Linear``.
        has_fp16_weights: stored on the matmul state and on the weight; the
            weight requires grad only when this is True.
        memory_efficient_backward: deprecated; must be False.
        threshold: stored on the matmul state.
        index: stored on the module as ``self.index``.
    """

    def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True,
                 memory_efficient_backward=False, threshold=0.0, index=None):
        super().__init__(input_features, output_features, bias)
        assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"

        # Configure the matmul state first, then attach it to the module.
        state = bnb.MatmulLtState()
        state.threshold = threshold
        state.has_fp16_weights = has_fp16_weights
        state.memory_efficient_backward = memory_efficient_backward
        if not has_fp16_weights and threshold > 0.0:
            state.use_pool = True
        self.state = state
        self.index = index

        self.weight = Int8Params(
            self.weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )

    def init_8bit_state(self):
        """Hand ``CB``/``SCB`` over from the weight to the matmul state."""
        weight = self.weight
        self.state.CB, self.state.SCB = weight.CB, weight.SCB
        weight.CB = None
        weight.SCB = None

    def forward(self, x: torch.Tensor):
        """Run ``bnb.matmul`` on ``x`` with this layer's weight, bias and state."""
        state = self.state
        state.is_training = self.training

        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        bias = self.bias
        if bias is not None and bias.dtype != x.dtype:
            bias.data = bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=state)

        if not state.has_fp16_weights and state.CB is not None and state.CxB is not None:
            # we converted 8-bit row major to turing/ampere format in the first inference pass
            # we no longer need the row-major weight
            del state.CB
            self.weight.data = state.CxB

        return out