Added triton test.
This commit is contained in:
parent a13a522c4c
commit 30d21d585c
44	tests/test_triton.py	Normal file
@@ -0,0 +1,44 @@
import pytest
import torch

from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear


@pytest.mark.parametrize("triton_module", [SwitchBackGlobalLinear, SwitchBackLinear])
def test_switchback(triton_module):
    for dim in [83, 17, 128]:
        for batch in [13, 128, 256]:

            # fp16 reference layer and the Triton-based layer under test,
            # initialized with identical parameters.
            standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
            switchback = triton_module(dim, 4 * dim).cuda().half()
            switchback.weight.data.copy_(standard.weight)
            switchback.bias.data.copy_(standard.bias)

            for i in range(100):
                # Two detached copies of the same input so each layer
                # backpropagates through its own graph.
                x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True)
                x2 = x1.clone().detach().requires_grad_(True)

                print('standard')
                out_standard = standard(x1)

                print('switchback')
                out_sb = switchback(x2)

                (out_standard.abs().mean()).backward()
                (out_sb.abs().mean()).backward()

                # Mean absolute error of the forward outputs.
                err_sb = (out_standard - out_sb).abs().mean()
                print('OUT', err_sb)

                # Mean absolute error of the bias gradients.
                err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean()
                print('GW2', err_sb)

                # Mean absolute error of the weight gradients.
                err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean()
                print('GW1', err_sb)

                # Mean absolute error of the input gradients.
                #err_sb = (x1.grad - x2.grad).abs().mean()
                #print('GX1', err_sb)
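The test prints the error magnitudes rather than asserting on them, so it documents accuracy but cannot fail on a regression. A minimal sketch of how the same comparisons could be turned into assertions inside the inner loop is shown below; the tolerance values are illustrative assumptions, not numbers from this commit, and running it assumes a CUDA device with Triton installed, since both SwitchBack modules dispatch to Triton kernels.

# Sketch only: the tolerances below are assumed, not taken from this commit.
def assert_close(reference, candidate, tol):
    err = (reference - candidate).abs().mean().item()
    assert err < tol, f"mean abs error {err:.4f} exceeds {tol}"

assert_close(out_standard, out_sb, tol=1e-2)                          # forward output
assert_close(standard.bias.grad, switchback.bias.grad, tol=1e-2)      # bias gradient
assert_close(standard.weight.grad, switchback.weight.grad, tol=1e-2)  # weight gradient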