Added better default compute_dtype handling for Linear4bit layers.

This commit is contained in:
Tim Dettmers 2023-07-22 12:56:29 -07:00
parent c82f51c0f7
commit 412fd0e717
2 changed files with 60 additions and 6 deletions

View File

@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
from typing import Optional, TypeVar, Union, overload
import warnings
import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
@ -205,6 +206,28 @@ class Linear4bit(nn.Linear):
super().__init__(input_features, output_features, bias, device)
self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
self.compute_dtype = compute_dtype
self.compute_type_is_set = False
def set_compute_type(self, x):
if x.dtype in [torch.float32, torch.bfloat16]:
# the input is in a dtype that is safe to compute in, we switch
# to this type for speed and stability
self.compute_dtype = x.dtype
elif x.dtype == torch.float16:
# we take the compoute dtype passed into the layer
if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
# single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
# warn the user about this
warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference.')
warnings.filterwarnings('ignore', message='.*inference.')
if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.')
warnings.filterwarnings('ignore', message='.*inference or training')
def forward(self, x: torch.Tensor):
# weights are cast automatically as Int8Params, but the bias has to be cast manually
@ -213,6 +236,10 @@ class Linear4bit(nn.Linear):
if getattr(self.weight, 'quant_state', None) is None:
print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
if not self.compute_type_is_set:
self.set_compute_type(x)
self.compute_type_is_set = True
inp_dtype = x.dtype
if self.compute_dtype is not None:
x = x.to(self.compute_dtype)

View File

@ -516,7 +516,10 @@ modules.append(bnb.nn.LinearFP4)
modules.append(bnb.nn.LinearNF4)
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True))
modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True))
names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C']
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32))
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16))
modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16))
names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C', 'NF4+fp32', 'NF4+fp16', 'NF4+bf16']
@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
@pytest.mark.parametrize("module", modules, ids=names)
def test_kbit_backprop(module):
@ -563,10 +566,10 @@ def test_kbit_backprop(module):
relerrs2.append(relerr2.mean().item())
if isinstance(module, bnb.nn.Linear8bitLt):
torch.testing.assert_close(grad1, grad2, atol=0.008, rtol=0.05)
assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1)
torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
else:
torch.testing.assert_close(grad1, grad2, atol=0.015, rtol=0.05)
assert_all_approx_close(grad1, grad2, atol=0.015, rtol=0.05, count=1)
torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
ref.zero_grad()
kbit.zero_grad()
@ -608,9 +611,33 @@ def test_fp8linear():
assert graderr < 0.00002
assert bgraderr < 0.00002
def test_4bit_warnings():
dim1 = 64
with pytest.warns(UserWarning, match=r'inference or training'):
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
net = net.cuda()
inp = torch.rand(10, dim1).cuda().half()
net(inp)
with pytest.warns(UserWarning, match=r'inference.'):
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
net = net.cuda()
inp = torch.rand(1, dim1).cuda().half()
net(inp)
with pytest.warns(UserWarning) as record:
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
net = net.cuda()
inp = torch.rand(10, dim1).cuda().half()
net(inp)
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
net = net.cuda()
inp = torch.rand(1, dim1).cuda().half()
net(inp)
assert len(record) == 2