Fixed LinearFP8 and added tests.
parent fa255cbc56
commit 2dfa3ce16d
@@ -352,10 +352,10 @@ class LinearFP8(nn.Linear):
     def forward(self, x: torch.Tensor):
         if self.fw_code is None:
-            self.bw_code = F.create_fp8_map(True, 5, 2, 8).to(x.device)
-            self.fw_code = F.create_fp8_map(True, 4, 3, 8).to(x.device)
+            self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
+            self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
 
-        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, code=self.bw_code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code)
         if self.bias is not None:
             out += self.bias
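The hunk above routes codebook creation through bnb.functional.create_fp8_map and passes the backward codebook under its proper keyword, bw_code, instead of code. For context, a minimal usage sketch follows; it uses only names visible in this diff (LinearFP8, create_fp8_map, matmul_fp8), and reading (True, 4, 3, 8) as a signed E4M3 format and (True, 5, 2, 8) as signed E5M2 is an interpretation of the arguments, not something stated in the commit.

# Sketch, not part of this commit: LinearFP8 as a drop-in for torch.nn.Linear.
import torch
import bitsandbytes as bnb

layer = bnb.nn.LinearFP8(1024, 2048).cuda()  # same constructor shape as torch.nn.Linear
x = torch.randn(10, 1024).cuda()
y = layer(x)             # first call lazily builds fw_code (E4M3) and bw_code (E5M2)
y.mean().backward()      # backward runs through bnb.matmul_fp8 using the bw_code codebook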
@@ -525,3 +525,40 @@ def test_linear8bitlt_fp32_bias():
     b1 = torch.randn(16, 8, 32, device="cuda").half()
     o1 = l1(b1)
     assert l1.bias is None
+
+
+def test_fp8linear():
+    b = 10
+    h = 1024
+    inp = torch.randn(b, h).cuda()
+    fp32 = torch.nn.Linear(h, h*2).cuda()
+    fp8 = bnb.nn.LinearFP8(h, h*2).cuda()
+    fp32b = torch.nn.Linear(h*2, h).cuda()
+    fp8b = bnb.nn.LinearFP8(h*2, h).cuda()
+
+    fp8.weight.data.copy_(fp32.weight.data)
+    fp8.bias.data.copy_(fp32.bias.data)
+    fp8b.weight.data.copy_(fp32b.weight.data)
+    fp8b.bias.data.copy_(fp32b.bias.data)
+
+    a = fp32b(torch.nn.functional.gelu(fp32(inp)))
+    b = fp8b(torch.nn.functional.gelu(fp8(inp)))
+
+    err = (a-b).abs().mean()
+
+    a.mean().backward()
+    b.mean().backward()
+
+    graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean()
+    bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean()
+
+    assert err < 0.05
+    assert graderr < 0.00002
+    assert bgraderr < 0.00002
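A note on the thresholds: the loss is a plain mean over a (10, 1024) output, so each output element receives a gradient of 1/(10*1024), roughly 1e-4; weight and bias gradients are correspondingly small in magnitude, which helps explain why the gradient bounds (2e-5 mean absolute error) are much tighter than the 0.05 bound on the activations. A rough illustration of that scale (not part of the test):

# Illustrative sketch only: the per-element gradient scale implied by a.mean().backward().
b, h = 10, 1024
grad_per_element = 1.0 / (b * h)   # d(mean)/d(out_ij) for a (b, h) output
print(grad_per_element)            # ~9.77e-05, so a 2e-5 MAE bound on gradients is tight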