bitsandbytes-rocm/speed_benchmark/speed_benchmark.py

103 lines
4.9 KiB
Python

import json
import time
import torch
import torch.nn as nn
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
# KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.
def get_time(k, fn, info_dict):
for _ in range(repeat // 2):
fn()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
fn()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info_dict[k] = ms
if __name__ == '__main__':
torch.manual_seed(0)
wm = 4
for dim in [1024, 1280, 1408, 1664, 2048, 4096]:
# note "batch_size" is actually "batch_size * embed_dim", which is why it's large
for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:
# switch switches dim_in and dim_out
for switch in [False, True]:
# hparams
repeat = 64
batch_size = batch_size
dim_out = dim * wm
dim_in = dim
if switch:
dim_out = dim
dim_in = wm * dim
dim_in = round(dim_in)
dim_out = round(dim_out)
# simulate forward pass
x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda()
g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda()
w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda()
x_int8 = x.clone().to(torch.int8)
g_int8 = g.clone().to(torch.int8)
w_int8 = w.clone().to(torch.int8)
wt_int8 = w.t().contiguous().clone().to(torch.int8)
state_x_rowwise = x.max(dim=1)[0]
state_g_rowwise = g.max(dim=1)[0]
state_w_columnwise = w.max(dim=0)[0]
state_w_rowwise = w.max(dim=1)[0]
state_w_global = w.max()
info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}
get_time('standard_fwd', lambda : x.matmul(w.t()), info)
get_time('standard_gw', lambda : g.t().matmul(x), info)
get_time('standard_gx', lambda : g.matmul(w), info)
get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info)
get_time('w_quantize_global', lambda : quantize_global(w), info)
get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info)
time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']
print('TOTAL STANDARD', time_standard)
print('TOTAL ROWWISE', time_rowwise)
print('TOTAL GLOBAL', time_global)
print('speedup', -100*(time_global - time_standard)/time_standard)
info['time_standard'] = time_standard
info['time_rowwise'] = time_rowwise
info['time_global'] = time_global
info_json = json.dumps(info)
# TODO: change this to what you want.
with open("speed_benchmark/info.jsonl", "a") as file:
file.write(info_json + "\n")