import math
import torch
import time
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# TODO: autotune this better.
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=1, num_warps=8),
        triton.Config({}, num_stages=2, num_warps=8),
        triton.Config({}, num_stages=4, num_warps=8),
        triton.Config({}, num_stages=8, num_warps=8),
        triton.Config({}, num_stages=1),
        triton.Config({}, num_stages=2),
        triton.Config({}, num_stages=4),
        triton.Config({}, num_stages=8),
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=['n_elements'],
)
@triton.jit
def _quantize_rowwise_nogroup(
    x_ptr,
    output_ptr,
    output_maxs,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    P2: tl.constexpr,
):
    # One program per row: find the row's absmax, then quantize every element
    # to int8 as round(127 * x / absmax).
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    arange = tl.arange(0, P2)
    offsets = block_start + arange
    row_mask = arange < BLOCK_SIZE
    x = tl.load(x_ptr + offsets, mask=row_mask)

    abs_x = tl.abs(x)
    max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
    output = tl.libdevice.llrint(127. * (x / max_val))
    tl.store(output_ptr + offsets, output, mask=row_mask)
    tl.store(output_maxs + pid, max_val)


def quantize_rowwise_nogroup(x: torch.Tensor):
    output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
    output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)

    # tl.arange needs a power-of-two extent, so round the row length up to P2
    # and mask off the excess lanes inside the kernel.
    P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))

    assert x.is_cuda and output.is_cuda
    n_elements = output.numel()
    grid = lambda meta: (x.shape[0],)
    _quantize_rowwise_nogroup[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
    return output, output_maxs


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=['n_elements'],
)
@triton.jit
def _experimental_quantize_rowwise_nogroup(
    x_ptr,
    output_ptr,
    bias_grad_ptr,
    output_maxs,
    n_elements,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    P2: tl.constexpr,
    P2M: tl.constexpr,
):
    # Fused variant launched on M + N programs: the first M quantize one row
    # each (as above); the remaining N each sum one column into bias_grad.
    pid = tl.program_id(axis=0)
    if pid < M:
        block_start = pid * BLOCK_SIZE
        arange = tl.arange(0, P2)
        offsets = block_start + arange
        row_mask = arange < BLOCK_SIZE
        x = tl.load(x_ptr + offsets, mask=row_mask)

        abs_x = tl.abs(x)
        max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
        output = tl.libdevice.llrint(127. * (x / max_val))
        tl.store(output_ptr + offsets, output, mask=row_mask)
        tl.store(output_maxs + pid, max_val)
    else:
        real_pid = pid - M
        arange_new = tl.arange(0, P2M)
        mask_new = arange_new < M
        offsets_new = real_pid + arange_new * N
        new_x = tl.load(x_ptr + offsets_new, mask=mask_new)
        s = tl.sum(tl.where(mask_new, new_x, 0).to(tl.float32), axis=0)
        tl.store(bias_grad_ptr + real_pid, s)


def experimental_quantize_rowwise_nogroup(x: torch.Tensor):
    M, N = x.shape
    output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
    output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
    bias_grad = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)

    P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
    P2M = int(2 ** (math.ceil(math.log2(x.shape[0]))))

    assert x.is_cuda and output.is_cuda
    n_elements = output.numel()
    grid = lambda meta: (x.shape[0] + x.shape[1],)
    _experimental_quantize_rowwise_nogroup[grid](x, output, bias_grad, output_maxs, n_elements, M, N, BLOCK_SIZE=x.shape[1], P2=P2, P2M=P2M)
    return output, output_maxs, bias_grad
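

# Not part of the original kernels: a minimal pure-PyTorch sketch of the
# inverse mapping, included for reference. It assumes the row-wise scheme
# used above (q = round(127 * x / absmax(row))), so x is recovered only
# approximately as q * absmax(row) / 127. The helper name is illustrative.
def dequantize_rowwise_nogroup_reference(q: torch.Tensor, row_maxs: torch.Tensor) -> torch.Tensor:
    # q: int8 tensor of shape (M, N); row_maxs: per-row absmax of shape (M,).
    return q.float() * (row_maxs.float() / 127.0).unsqueeze(1)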


if __name__ == '__main__':
    torch.manual_seed(0)

    x = torch.randn(1280, 768).cuda().to(torch.float16)

    out = quantize_rowwise_nogroup(x)

    # Reference row-wise quantization in plain PyTorch.
    x_real = (127 * x.float() / x.abs().max(dim=1, keepdim=True)[0]).round().to(torch.int8)
    max2 = x.abs().max(1)[0]

    print(torch.allclose(out[1], max2))
    print( (x_real == out[0]).float().mean() )

    # for i in range(x.shape[0]):
    #     print( (x_real[i, :] == out[0][i, :]).float().mean() )

    # print(out[0])
    # print(x_real)
    # import pdb; pdb.set_trace()
    # print(out[2])
    # print(out[2][:10])

    sums = x.sum(dim=0)
    # print(sums[:10])
    # print( (sums == out[2]).float().mean() )

    # import pdb; pdb.set_trace()
    # exit()

    # repeat = 16

    # for _ in range(8):
    #     out = quantize_rowwise_nogroup(x)

    # triton_graph = torch.cuda.CUDAGraph()
    # with torch.cuda.graph(triton_graph):
    #     out = quantize_rowwise_nogroup(x)

    # triton_graph.replay()

    # torch.cuda.synchronize()
    # start = time.time()
    # for _ in range(repeat):
    #     triton_graph.replay()
    # torch.cuda.synchronize()
    # end = time.time()

    # print(out[0])
    # print(out[1])
    # print(x / x.abs().max(dim=1, keepdim=True)[0])
    # max1 = out[1]
    # max2 = x.abs().max(1)[0]
    # print(max1, max2)
    # print(torch.allclose(max1, max2))

    # print(f"time: {(end - start) / repeat * 1000:.3f} ms")
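
    # Extra sanity checks, not in the original script: a minimal sketch with
    # assumed tolerances, reusing only the kernels defined above.
    deq = out[0].float() * (out[1].float() / 127.0).unsqueeze(1)
    print("max dequantization error:", (x.float() - deq).abs().max().item())

    # The fused kernel should also return the column sums of x in bias_grad
    # (accumulated in fp32, stored as fp16); atol/rtol here are assumptions.
    q2, maxs2, bias_grad = experimental_quantize_rowwise_nogroup(x)
    print(torch.allclose(bias_grad.float(), x.float().sum(dim=0), atol=0.5, rtol=1e-2))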