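# Benchmark forward+backward speed of an MLP block (LayerNorm -> Linear -> GELU -> Linear)
# in four variants: torch.nn.Linear, the Triton-based StandardLinear, a torch.compile'd
# baseline, and SwitchBackGlobalLinear, all run under torch.cuda.amp.autocast.
# Timings are appended as JSON lines to tests/triton_tests/info_mlp_autocast_ln.jsonl.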
import json
import time

import torch

from bitsandbytes.nn.triton_based_modules import (
    StandardLinear,
    SwitchBackGlobalLinear,
    SwitchBackGlobalMLP,
)


if __name__ == '__main__':

    print('Starting')

    for dim in [1024, 1280, 1408, 1664, 2048]:
        for batch in [2**14, 2**15, 2**16, 2**17]:

            x1 = torch.randn(batch, dim).cuda().requires_grad_(True)
            d = 2  # layer index used only by the commented-out gradient checks below

            # Baseline MLP built from torch.nn.Linear.
            standard = torch.nn.Sequential(
                torch.nn.LayerNorm(dim),
                torch.nn.Linear(dim, 4 * dim),
                torch.nn.GELU(),
                torch.nn.Linear(4 * dim, dim),
            ).cuda()

            # Same MLP built from the Triton-based StandardLinear layers.
            my_standard = torch.nn.Sequential(
                torch.nn.LayerNorm(dim),
                StandardLinear(dim, 4 * dim),
                torch.nn.GELU(),
                StandardLinear(4 * dim, dim),
            ).cuda()

            # Fused SwitchBack MLP; only referenced by the commented-out error checks below.
            fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda()

            # Same MLP built from SwitchBackGlobalLinear layers.
            sb = torch.nn.Sequential(
                torch.nn.LayerNorm(dim),
                SwitchBackGlobalLinear(dim, 4 * dim),
                torch.nn.GELU(),
                SwitchBackGlobalLinear(4 * dim, dim),
            ).cuda()

            standard_compiled = torch.compile(standard)

            print('Model part 2')

            repeat = 32

            info = {'repeat': repeat, 'batch_size': batch, 'dim': dim}

            k = 'standard'
            # Warmup iterations (not timed).
            for _ in range(repeat // 2):
                with torch.cuda.amp.autocast():
                    out_standard = standard(x1)
                # Surrogate loss; the 2**16 factor behaves like a static loss scale under autocast.
                ((2 ** 16) * out_standard).abs().mean().backward()

            torch.cuda.synchronize()
            start = time.time()
            for _ in range(repeat):
                with torch.cuda.amp.autocast():
                    out_standard = standard(x1)
                ((2 ** 16) * out_standard).abs().mean().backward()

            torch.cuda.synchronize()
            end = time.time()
            ms = (end - start) / repeat * 1000
            print(f"time {k}: {ms:.3f} ms")
            info[k] = ms

            x1.grad.zero_()

            k = 'my_standard'
            for _ in range(repeat // 2):
                with torch.cuda.amp.autocast():
                    out_my_standard = my_standard(x1)
                ((2 ** 16) * out_my_standard).abs().mean().backward()

            torch.cuda.synchronize()
            start = time.time()
            for _ in range(repeat):
                with torch.cuda.amp.autocast():
                    out_my_standard = my_standard(x1)
                ((2 ** 16) * out_my_standard).abs().mean().backward()

            torch.cuda.synchronize()
            end = time.time()
            ms = (end - start) / repeat * 1000
            print(f"time {k}: {ms:.3f} ms")
            info[k] = ms

            x1.grad.zero_()

            k = 'standard_compiled'
            for _ in range(repeat // 2):
                with torch.cuda.amp.autocast():
                    out_standard_compiled = standard_compiled(x1)
                ((2 ** 16) * out_standard_compiled).abs().mean().backward()

            torch.cuda.synchronize()
            start = time.time()
            for _ in range(repeat):
                with torch.cuda.amp.autocast():
                    out_standard_compiled = standard_compiled(x1)
                ((2 ** 16) * out_standard_compiled).abs().mean().backward()

            torch.cuda.synchronize()
            end = time.time()
            ms = (end - start) / repeat * 1000
            print(f"time {k}: {ms:.3f} ms")
            info[k] = ms

            x1.grad.zero_()

            k = 'sb'
            for _ in range(repeat // 2):
                with torch.cuda.amp.autocast():
                    out_sb = sb(x1)
                ((2 ** 16) * out_sb).abs().mean().backward()

            torch.cuda.synchronize()
            start = time.time()
            for _ in range(repeat):
                with torch.cuda.amp.autocast():
                    out_sb = sb(x1)
                ((2 ** 16) * out_sb).abs().mean().backward()

            torch.cuda.synchronize()
            end = time.time()
            ms = (end - start) / repeat * 1000
            print(f"time {k}: {ms:.3f} ms")
            info[k] = ms

            info_json = json.dumps(info)

            with open("tests/triton_tests/info_mlp_autocast_ln.jsonl", "a") as file:
                file.write(info_json + "\n")

            # exit()

            # err_fused = (out_standard - out_fused).abs().mean()
            # err_sb = (out_standard - out_sb).abs().mean()
            # print('OUT', err_fused, err_sb)

            # err_fused = (standard[d].weight.grad - fused_mlp.linear2.weight.grad).abs().mean()
            # err_sb = (standard[d].weight.grad - sb[d].weight.grad).abs().mean()
            # print('GW2', err_fused, err_sb)

            # err_fused = (standard[0].weight.grad - fused_mlp.linear1.weight.grad).abs().mean()
            # err_sb = (standard[0].weight.grad - sb[0].weight.grad).abs().mean()
            # print('GW1', err_fused, err_sb)

            # err_fused = (x1.grad - x2.grad).abs().mean()
            # err_sb = (x1.grad - x3.grad).abs().mean()
            # print('GX1', err_fused, err_sb)

            # import pdb; pdb.set_trace()

            # # NO GELU, ST GRADIENTS, EVERYTHING FINE.