Fixed LinearFP8 and added tests.

Author: Tim Dettmers
Date:   2023-02-13 17:48:52 -08:00
Parent: fa255cbc56
Commit: 2dfa3ce16d

2 changed files with 40 additions and 3 deletions


@@ -352,10 +352,10 @@ class LinearFP8(nn.Linear):
     def forward(self, x: torch.Tensor):
         if self.fw_code is None:
-            self.bw_code = F.create_fp8_map(True, 5, 2, 8).to(x.device)
-            self.fw_code = F.create_fp8_map(True, 4, 3, 8).to(x.device)
+            self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
+            self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
 
-        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, code=self.bw_code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code)
         if self.bias is not None:
             out += self.bias

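The hunk above makes two changes: the quantization maps are now built through the fully qualified bnb.functional.create_fp8_map, and the backward map is passed to bnb.matmul_fp8 as bw_code= instead of code=. The arguments appear to correspond to two different FP8 layouts, a signed 4-exponent / 3-mantissa-bit (E4M3-style) map for the forward matmul and a signed 5-exponent / 2-mantissa-bit (E5M2-style) map for the backward pass. Below is a minimal standalone sketch of the same call pattern, using only the calls visible in the hunk; the tensor shapes and explicit device handling are illustrative assumptions, not part of this commit.

import torch
import bitsandbytes as bnb

# Illustrative shapes only: a (10, 1024) input against a Linear-style (2048, 1024) weight.
x = torch.randn(10, 1024).cuda()
W = torch.randn(2048, 1024).cuda()

# The same maps that LinearFP8.forward builds lazily in the hunk above:
fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)  # E4M3-style map, forward
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)  # E5M2-style map, backward

# The fixed call: the backward map goes in as bw_code, not code.
out = bnb.matmul_fp8(x, W.t(), fw_code=fw_code, bw_code=bw_code)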

@@ -525,3 +525,40 @@ def test_linear8bitlt_fp32_bias():
     b1 = torch.randn(16, 8, 32, device="cuda").half()
     o1 = l1(b1)
     assert l1.bias is None
+
+
+def test_fp8linear():
+    b = 10
+    h = 1024
+    inp = torch.randn(b, h).cuda()
+    fp32 = torch.nn.Linear(h, h*2).cuda()
+    fp8 = bnb.nn.LinearFP8(h, h*2).cuda()
+    fp32b = torch.nn.Linear(h*2, h).cuda()
+    fp8b = bnb.nn.LinearFP8(h*2, h).cuda()
+
+    fp8.weight.data.copy_(fp32.weight.data)
+    fp8.bias.data.copy_(fp32.bias.data)
+    fp8b.weight.data.copy_(fp32b.weight.data)
+    fp8b.bias.data.copy_(fp32b.bias.data)
+
+    a = fp32b(torch.nn.functional.gelu(fp32(inp)))
+    b = fp8b(torch.nn.functional.gelu(fp8(inp)))
+
+    err = (a-b).abs().mean()
+
+    a.mean().backward()
+    b.mean().backward()
+
+    graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean()
+    bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean()
+
+    assert err < 0.05
+    assert graderr < 0.00002
+    assert bgraderr < 0.00002
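The new test treats LinearFP8 as a drop-in replacement for torch.nn.Linear: weights and biases are copied from the fp32 layers, the same Linear-GELU-Linear stack is run through both versions, and the mean absolute error of the outputs (< 0.05) and of the weight and bias gradients (< 2e-5) is asserted. In the same spirit, a short usage sketch with hypothetical layer sizes; the GELU and the .mean().backward() call simply mirror the test, not any required API.

import torch
import bitsandbytes as bnb

# Hypothetical sizes; any (in_features, out_features) pair is handled the same way.
layer = bnb.nn.LinearFP8(1024, 2048).cuda()
x = torch.randn(10, 1024).cuda()

out = torch.nn.functional.gelu(layer(x))  # forward pass through the FP8 matmul
out.mean().backward()                     # backward pass populates the usual .grad fields
print(layer.weight.grad.shape, layer.bias.grad.shape)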