Added better default compute_dtype handling for Linear4bit layers.

Tim Dettmers 2023-07-22 12:56:29 -07:00
parent c82f51c0f7
commit 412fd0e717
2 changed files with 60 additions and 6 deletions
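
In practice the change makes Linear4bit resolve its compute dtype from the first input it sees, and warn when a torch.float16 input meets the slow torch.float32 compute dtype. A minimal sketch of the new behavior (assuming a CUDA machine and a bitsandbytes build that includes this commit):

import warnings

import torch
import bitsandbytes as bnb

# fp16 input + compute_dtype=torch.float32 is the combination that now warns
layer = bnb.nn.Linear4bit(64, 64, compute_dtype=torch.float32).cuda()

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    layer(torch.rand(10, 64, device='cuda', dtype=torch.float16))  # first forward resolves the compute type

print([str(w.message) for w in caught])  # expect the 'slow inference or training speed' warning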

bitsandbytes/nn/modules.py

@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 from typing import Optional, TypeVar, Union, overload
+import warnings
 import torch
 import torch.nn.functional as F
 from torch import Tensor, device, dtype, nn
@@ -205,6 +206,28 @@ class Linear4bit(nn.Linear):
         super().__init__(input_features, output_features, bias, device)
         self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type)
         self.compute_dtype = compute_dtype
+        self.compute_type_is_set = False
+
+    def set_compute_type(self, x):
+        if x.dtype in [torch.float32, torch.bfloat16]:
+            # the input is in a dtype that is safe to compute in; we switch
+            # to this dtype for speed and stability
+            self.compute_dtype = x.dtype
+        elif x.dtype == torch.float16:
+            # we take the compute dtype passed into the layer
+            if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
+                # single-batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast;
+                # warn the user about this
+                warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.')
+                warnings.filterwarnings('ignore', message='.*inference.')
+            if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
+                warnings.warn('Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.')
+                warnings.filterwarnings('ignore', message='.*inference or training')
+
     def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
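
The single-batch check above relies on a shape identity worth making explicit: x.numel() == x.shape[-1] holds exactly when every leading dimension is 1, i.e. exactly one row passes through the layer. A small self-contained illustration:

import torch

single = torch.empty(1, 1, 4096)     # batch=1, seq_len=1: numel equals the last dim
batched = torch.empty(8, 128, 4096)  # batched training or inference input

assert single.numel() == single.shape[-1]
assert batched.numel() != batched.shape[-1]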
@@ -213,6 +236,10 @@ class Linear4bit(nn.Linear):
         if getattr(self.weight, 'quant_state', None) is None:
             print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
+
+        if not self.compute_type_is_set:
+            self.set_compute_type(x)
+            self.compute_type_is_set = True
         inp_dtype = x.dtype
         if self.compute_dtype is not None:
             x = x.to(self.compute_dtype)
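
The hunk above shows only the cast into compute_dtype; the rest of forward() (unchanged, hence not part of this diff) casts the result back to the caller's dtype. A simplified, self-contained sketch of that round-trip, with a plain matmul standing in for the real 4-bit kernel:

import torch

def cast_in_cast_out(x, weight, compute_dtype=None):
    inp_dtype = x.dtype               # remember the caller's dtype
    if compute_dtype is not None:
        x = x.to(compute_dtype)       # compute in the resolved dtype
    out = x @ weight.to(x.dtype).t()  # stand-in for the quantized matmul
    return out.to(inp_dtype)          # hand the original dtype back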

tests/test_modules.py

@@ -516,7 +516,10 @@ modules.append(bnb.nn.LinearFP4)
 modules.append(bnb.nn.LinearNF4)
 modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True))
 modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True))
-names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C']
+modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float32))
+modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.float16))
+modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compute_dtype=torch.bfloat16))
+names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C', 'NF4+fp32', 'NF4+fp16', 'NF4+bf16']
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
 @pytest.mark.parametrize("module", modules, ids=names)
 def test_kbit_backprop(module):
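
With the parametrizations above in place, an individual new case can be selected by its full node id (test file path assumed to be tests/test_modules.py):

pytest "tests/test_modules.py::test_kbit_backprop[NF4+fp16]"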
@@ -563,10 +566,10 @@ def test_kbit_backprop(module):
         relerrs2.append(relerr2.mean().item())

         if isinstance(module, bnb.nn.Linear8bitLt):
-            torch.testing.assert_close(grad1, grad2, atol=0.008, rtol=0.05)
+            assert_all_approx_close(grad1, grad2, atol=0.008, rtol=0.05, count=1)
             torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05)
         else:
-            torch.testing.assert_close(grad1, grad2, atol=0.015, rtol=0.05)
+            assert_all_approx_close(grad1, grad2, atol=0.015, rtol=0.05, count=1)
             torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05)
         ref.zero_grad()
         kbit.zero_grad()
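
assert_all_approx_close is a helper defined elsewhere in the test suite; the diff swaps it in for torch.testing.assert_close so that up to count outlier elements are tolerated instead of failing on the first mismatch. A sketch of the assumed semantics:

import torch

def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0):
    # count how many elements fall outside the rtol/atol envelope
    close = torch.isclose(a, b, rtol=rtol, atol=atol)
    mismatched = (~close).sum().item()
    if mismatched > count:
        raise AssertionError(f'too many values not close: {mismatched} > {count}')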
@@ -608,9 +611,33 @@ def test_fp8linear():
     assert graderr < 0.00002
     assert bgraderr < 0.00002

+
+def test_4bit_warnings():
+    dim1 = 64
+
+    with pytest.warns(UserWarning, match=r'inference or training'):
+        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
+        net = net.cuda()
+        inp = torch.rand(10, dim1).cuda().half()
+        net(inp)
+
+    with pytest.warns(UserWarning, match=r'inference.'):
+        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
+        net = net.cuda()
+        inp = torch.rand(1, dim1).cuda().half()
+        net(inp)
+
+    with pytest.warns(UserWarning) as record:
+        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
+        net = net.cuda()
+        inp = torch.rand(10, dim1).cuda().half()
+        net(inp)
+
+        net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, compute_dtype=torch.float32) for i in range(10)])
+        net = net.cuda()
+        inp = torch.rand(1, dim1).cuda().half()
+        net(inp)
+
+    assert len(record) == 2
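
The final count works because of the filterwarnings('ignore', ...) calls in set_compute_type: each of the two messages fires once and is then suppressed, so the two ten-layer networks inside the recording block still produce exactly two warnings.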