bitsandbytes-rocm/tests/triton_tests/rowwise.py

43 lines
870 B
Python
Raw Normal View History

2023-03-29 06:47:08 +00:00
import time
import torch
import torch
import torch.nn as nn
import bitsandbytes.nn as bnn
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear
from bitsandbytes.nn.triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup
# 256 * 256 * 4096 _> 0.7
# 256 * 128 * 8192 -> 10
if __name__ == '__main__':
torch.manual_seed(0)
# hparams
repeat = 16
dim=8192
layers = 4
batch_size = 256 * 128
# simulate forward pass
x = torch.randn(batch_size, dim, dtype=torch.float16).cuda()
for _ in range(repeat // 2):
quantize_rowwise_nogroup(x)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
quantize_rowwise_nogroup(x)
torch.cuda.synchronize()
end = time.time()
print(f"time: {(end - start) / repeat * 1000:.3f} ms")