import json import time import torch import torch.nn as nn from bitsandbytes.triton.quantize_rowwise import quantize_rowwise from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze # KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large. def get_time(k, fn, info_dict): for _ in range(repeat // 2): fn() torch.cuda.synchronize() start = time.time() for _ in range(repeat): fn() torch.cuda.synchronize() end = time.time() ms = (end - start) / repeat * 1000 print(f"time {k}: {ms:.3f} ms") info_dict[k] = ms if __name__ == '__main__': torch.manual_seed(0) wm = 4 for dim in [1024, 1280, 1408, 1664, 2048, 4096]: # note "batch_size" is actually "batch_size * embed_dim", which is why it's large for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]: # switch switches dim_in and dim_out for switch in [False, True]: # hparams repeat = 64 batch_size = batch_size dim_out = dim * wm dim_in = dim if switch: dim_out = dim dim_in = wm * dim dim_in = round(dim_in) dim_out = round(dim_out) # simulate forward pass x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda() g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda() w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda() x_int8 = x.clone().to(torch.int8) g_int8 = g.clone().to(torch.int8) w_int8 = w.clone().to(torch.int8) wt_int8 = w.t().contiguous().clone().to(torch.int8) state_x_rowwise = x.max(dim=1)[0] state_g_rowwise = g.max(dim=1)[0] state_w_columnwise = w.max(dim=0)[0] state_w_rowwise = w.max(dim=1)[0] state_w_global = w.max() info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch} get_time('standard_fwd', lambda : x.matmul(w.t()), info) get_time('standard_gw', lambda : g.t().matmul(x), info) get_time('standard_gx', lambda : g.matmul(w), info) get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info) get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info) get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info) get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info) get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info) get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info) get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info) get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info) get_time('w_quantize_global', lambda : quantize_global(w), info) get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info) time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw'] time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd'] time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd'] print('TOTAL STANDARD', time_standard) print('TOTAL ROWWISE', time_rowwise) print('TOTAL GLOBAL', time_global) print('speedup', -100*(time_global - time_standard)/time_standard) info['time_standard'] = time_standard info['time_rowwise'] = time_rowwise info['time_global'] = time_global info_json = json.dumps(info) # TODO: change this to what you want. with open("speed_benchmark/info.jsonl", "a") as file: file.write(info_json + "\n")