Added better default compute_dtype handling for Linear4bit layers.
This commit is contained in:
parent
c82f51c0f7
commit
412fd0e717
|
@ -4,6 +4,7 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
from typing import Optional, TypeVar, Union, overload
|
||||
|
||||
import warnings
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, device, dtype, nn
|
||||
|
@ -205,6 +206,28 @@ class Linear4bit(nn.Linear):
|
|||
super().__init__(input_features, output_features, bias, device)
|
||||
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
|
||||
self.compute_dtype = compute_dtype
|
||||
self.compute_type_is_set = False
|
||||
|
||||
def set_compute_type(self, x):
|
||||
if x.dtype in [torch.float32, torch.bfloat16]:
|
||||
# the input is in a dtype that is safe to compute in, we switch
|
||||
# to this type for speed and stability
|
||||
self.compute_dtype = x.dtype
|
||||
elif x.dtype == torch.float16:
|
||||
# we take the compoute dtype passed into the layer
|
||||
if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
|
||||
# single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
|
||||
# warn the user about this
|
||||
warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference.')
|
||||
warnings.filterwarnings('ignore', message='.*inference.')
|
||||
if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
|
||||
warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.')
|
||||
warnings.filterwarnings('ignore', message='.*inference or training')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# weights are cast automatically as Int8Params, but the bias has to be cast manually
|
||||
|
@ -213,6 +236,10 @@ class Linear4bit(nn.Linear):
|
|||
|
||||
if getattr(self.weight, 'quant_state', None) is None:
|
||||
print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
|
||||
if not self.compute_type_is_set:
|
||||
self.set_compute_type(x)
|
||||
self.compute_type_is_set = True
|
||||
|
||||
inp_dtype = x.dtype
|
||||
if self.compute_dtype is not None:
|
||||
x = x.to(self.compute_dtype)
|
||||
|
|
|
@ -516,7 +516,10 @@ modules.append(bnb.nn.LinearFP4)
|
|||
modules.append(bnb.nn.LinearNF4)
|
||||
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True))
|
||||
modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True))
|
||||
names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C']
|
||||
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32))
|
||||
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16))
|
||||
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16))
|
||||
names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C', 'NF4+fp32', 'NF4+fp16', 'NF4+bf16']
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
|
||||
@pytest.mark.parametrize("module", modules, ids=names)
|
||||
def test_kbit_backprop(module):
|
||||
|
@ -563,10 +566,10 @@ def test_kbit_backprop(module):
|
|||
relerrs2.append(relerr2.mean().item())
|
||||
|
||||
if isinstance(module, bnb.nn.Linear8bitLt):
|
||||
torch.testing.assert_close(grad1, grad2, atol=0.008, rtol=0.05)
|
||||
assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1)
|
||||
torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
|
||||
else:
|
||||
torch.testing.assert_close(grad1, grad2, atol=0.015, rtol=0.05)
|
||||
assert_all_approx_close(grad1, grad2, atol=0.015, rtol=0.05, count=1)
|
||||
torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
|
||||
ref.zero_grad()
|
||||
kbit.zero_grad()
|
||||
|
@ -608,9 +611,33 @@ def test_fp8linear():
|
|||
assert graderr < 0.00002
|
||||
assert bgraderr < 0.00002
|
||||
|
||||
|
||||
|
||||
|
||||
def test_4bit_warnings():
|
||||
dim1 = 64
|
||||
|
||||
with pytest.warns(UserWarning, match=r'inference or training'):
|
||||
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
|
||||
net = net.cuda()
|
||||
inp = torch.rand(10, dim1).cuda().half()
|
||||
net(inp)
|
||||
with pytest.warns(UserWarning, match=r'inference.'):
|
||||
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
|
||||
net = net.cuda()
|
||||
inp = torch.rand(1, dim1).cuda().half()
|
||||
net(inp)
|
||||
|
||||
with pytest.warns(UserWarning) as record:
|
||||
|
||||
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
|
||||
net = net.cuda()
|
||||
inp = torch.rand(10, dim1).cuda().half()
|
||||
net(inp)
|
||||
|
||||
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
|
||||
net = net.cuda()
|
||||
inp = torch.rand(1, dim1).cuda().half()
|
||||
net(inp)
|
||||
|
||||
assert len(record) == 2
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user