@@ -516,7 +516,10 @@ modules.append(bnb.nn.LinearFP4)
 modules.append(bnb.nn.LinearNF4)
 modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True))
 modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True))
-names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C']
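+# Additionally exercise explicit 4-bit compute dtypes (fp32, fp16, bf16).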
+modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32))
+modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16))
+modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16))
+names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C', 'NF4+fp32', 'NF4+fp16', 'NF4+bf16']
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
 @pytest.mark.parametrize("module", modules, ids=names)
 def test_kbit_backprop(module):
@@ -563,10 +566,10 @@ def test_kbit_backprop(module):
         relerrs2.append(relerr2.mean().item())
         if isinstance(module, bnb.nn.Linear8bitLt):
-            torch.testing.assert_close(grad1, grad2, atol=0.008, rtol=0.05)
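+            # assert_all_approx_close tolerates up to `count` elements outside the tolerances (here a single outlier)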
+            assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1)
             torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
         else:
-            torch.testing.assert_close(grad1, grad2, atol=0.015, rtol=0.05)
+            assert_all_approx_close(grad1, grad2, atol=0.015, rtol=0.05, count=1)
             torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
         ref.zero_grad()
         kbit.zero_grad()
@@ -608,9 +611,33 @@ def test_fp8linear():
     assert graderr < 0.00002
     assert bgraderr < 0.00002
+
+
+def test_4bit_warnings():
+    dim1 = 64
+
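+    # Batched fp16 input with an fp32 compute dtype should warn about slow "inference or training"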
+    with pytest.warns(UserWarning, match=r'inference or training'):
+        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
+        net = net.cuda()
+        inp = torch.rand(10, dim1).cuda().half()
+        net(inp)
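+    # A single-row (batch size 1) input counts as inference only, so the warning mentions just "inference"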
+    with pytest.warns(UserWarning, match=r'inference.'):
+        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
+        net = net.cuda()
+        inp = torch.rand(1, dim1).cuda().half()
+        net(inp)
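+
+    # Trigger both cases under a single recorder and check that exactly two warnings are emitted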
+    with pytest.warns(UserWarning) as record:
+        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
+        net = net.cuda()
+        inp = torch.rand(10, dim1).cuda().half()
+        net(inp)
+
+        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
+        net = net.cuda()
+        inp = torch.rand(1, dim1).cuda().half()
+        net(inp)
+
+    assert len(record) == 2