# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, TypeVar, Union, overload

import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn

import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager

T = TypeVar("T", bound="torch.nn.Module")


class StableEmbedding(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
            device,
            dtype,
        )
        self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
    to make the layer compatible with PyTorch < 1.9.
    This means that if the upstream implementation changes in future PyTorch releases,
    this needs to change too, which is cumbersome. However, with this we can ensure
    compatibility with previous PyTorch releases.
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        # always apply layer norm in full precision
        emb = emb.to(torch.get_default_dtype())

        return self.norm(emb).to(self.weight.dtype)
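
# Illustrative usage sketch for StableEmbedding (not part of the module itself; the
# 8-bit optimizer shown is bnb.optim.Adam8bit and the sizes are arbitrary):
#
#   emb = StableEmbedding(num_embeddings=10_000, embedding_dim=512)
#   tokens = torch.randint(0, 10_000, (2, 16))
#   out = emb(tokens)  # LayerNorm is applied in full precision internally
#   # the register_module_override above keeps this weight in 32-bit optimizer state
#   optimizer = bnb.optim.Adam8bit(emb.parameters(), lr=1e-3)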


class Embedding(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        super().__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
        )
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
    to make the layer compatible with PyTorch < 1.9.
    This means that if the upstream implementation changes in future PyTorch releases,
    this needs to change too, which is cumbersome. However, with this we can ensure
    compatibility with previous PyTorch releases.
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        return emb
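
# Illustrative sketch: Embedding is a drop-in torch.nn.Embedding that only adds the
# 32-bit optimizer-state override (no LayerNorm, unlike StableEmbedding above):
#
#   emb = Embedding(10_000, 512)
#   out = emb(torch.randint(0, 10_000, (2, 16)))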


class Int8Params(torch.nn.Parameter):
    def __new__(
        cls,
        data=None,
        requires_grad=True,
        has_fp16_weights=False,
        CB=None,
        SCB=None,
    ):
        cls.has_fp16_weights = has_fp16_weights
        cls.CB = None
        cls.SCB = None
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def cuda(self, device):
        if self.has_fp16_weights:
            return super().cuda(device)
        else:
            # we store the 8-bit row-major weight
            # we convert this weight to the Turing/Ampere layout during the first inference pass
            B = self.data.contiguous().half().cuda(device)
            CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
            # only the row-major weight and its quantization statistics are kept
            del CBt
            del SCBt
            self.data = CB
            self.CB = CB
            self.SCB = SCB

        return self

    @overload
    def to(
        self: T,
        device: Optional[Union[int, device]] = ...,
        dtype: Optional[Union[dtype, str]] = ...,
        non_blocking: bool = ...,
    ) -> T:
        ...

    @overload
    def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T:
        ...

    @overload
    def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T:
        ...

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )

        if (
            device is not None
            and device.type == "cuda"
            and self.data.device.type == "cpu"
        ):
            return self.cuda(device)
        else:
            new_param = Int8Params(
                super().to(
                    device=device, dtype=dtype, non_blocking=non_blocking
                ),
                requires_grad=self.requires_grad,
                has_fp16_weights=self.has_fp16_weights,
            )
            new_param.CB = self.CB
            new_param.SCB = self.SCB

            return new_param
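
# Illustrative sketch of the quantize-on-transfer behaviour above (assumes a CUDA
# device and bnb.functional.double_quant as imported in this file):
#
#   w = torch.randn(4096, 4096, dtype=torch.float16)
#   p = Int8Params(w, requires_grad=False, has_fp16_weights=False)
#   p = p.to("cuda")  # dispatches to cuda(), which runs double_quant()
#   # p.data and p.CB now hold the int8 row-major weight, p.SCB the row-wise scales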


class Linear8bitLt(nn.Linear):
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        has_fp16_weights=True,
        memory_efficient_backward=False,
        threshold=0.0,
        index=None,
    ):
        super().__init__(input_features, output_features, bias)
        assert not memory_efficient_backward, (
            "memory_efficient_backward is no longer required and the argument "
            "is deprecated in 0.37.0 and will be removed in 0.39.0"
        )
        self.state = bnb.MatmulLtState()
        self.index = index

        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        self.state.memory_efficient_backward = memory_efficient_backward
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True

        self.weight = Int8Params(
            self.weight.data,
            has_fp16_weights=has_fp16_weights,
            requires_grad=has_fp16_weights,
        )

    def init_8bit_state(self):
        self.state.CB = self.weight.CB
        self.state.SCB = self.weight.SCB
        self.weight.CB = None
        self.weight.SCB = None

    def forward(self, x: torch.Tensor):
        self.state.is_training = self.training
        if self.weight.CB is not None:
            self.init_8bit_state()

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights:
            if self.state.CB is not None and self.state.CxB is not None:
                # we converted the 8-bit row-major weight to Turing/Ampere format
                # in the first inference pass; we no longer need the row-major copy
                del self.state.CB
                self.weight.data = self.state.CxB

        return out
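
# Illustrative usage sketch for Linear8bitLt (inference-style 8-bit matmul; assumes a
# CUDA device and fp16 activations, sizes are arbitrary):
#
#   linear = Linear8bitLt(1024, 4096, has_fp16_weights=False, threshold=6.0).cuda()
#   x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
#   y = linear(x)  # the first call converts CB to the Turing/Ampere layout (CxB)
#                  # and frees the row-major copy, as in forward() above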