diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 8be7674..8e3a598 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -3,3 +3,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2, Linear8bitLtMixed, LinearFP8Global, LinearFP4, LinearFP8Mixed +from .triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py new file mode 100644 index 0000000..9fe0b69 --- /dev/null +++ b/bitsandbytes/nn/triton_based_modules.py @@ -0,0 +1,247 @@ +import torch +import torch.nn as nn +import time + +from .triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup +from .triton_utils.v0.quantize_columnwise_nogroup_transpose import quantize_columnwise_nogroup_transpose +from .triton_utils.v0.int8_matmul_rowwise_dequantize_bias import int8_matmul_rowwise_dequantize_bias +from .triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize +from .triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose +from .triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze, int8_matmul_mixed_dequanitze_bias +from .triton_utils.v0.fused_gelu_quantize import quantize_rowwise_nogroup_gelu, quantize_rowwise_nogroup_back_gelu + +class _switchback(torch.autograd.Function): + + @staticmethod + def forward(ctx, X_3D, W, bias): + + X = X_3D.view(-1, X_3D.size(-1)) + + ctx.save_for_backward = X, W + X_int8, state_X = quantize_rowwise_nogroup(X) + W_int8, state_W = quantize_rowwise_nogroup(W) + return int8_matmul_rowwise_dequantize_bias( + X_int8, W_int8.t(), state_X, state_W, bias + ).view(*X_3D.size()[:-1], -1) + + @staticmethod + def backward(ctx, G_3D): + X, W = ctx.save_for_backward + + G = G_3D.reshape(-1, G_3D.size(-1)) + + grad_X = grad_W = grad_bias = None + + if ctx.needs_input_grad[0]: + G_int8, state_G = quantize_rowwise_nogroup(G) + W_int8, state_W = quantize_columnwise_nogroup_transpose(W) + grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W).view( + *G_3D.size()[:-1], -1 + ) + if ctx.needs_input_grad[1]: + grad_W = torch.matmul(G.t(), X.to(G.dtype)) + if ctx.needs_input_grad[2]: + grad_bias = G.sum(dim=0) + + return grad_X, grad_W, grad_bias + +class SwitchBackLinear(nn.Linear): + + def prepare_for_eval(self): + state_W = self.weight.abs().max(dim=1, keepdim=True)[0] + W_int8 = (127 * self.weight.float() / state_W).round().to(torch.int8) + state_W = state_W.squeeze() + + self.register_buffer("W_int8", W_int8) + self.register_buffer("state_W", state_W) + + del self.weight + + def forward(self, x): + if self.training: + return _switchback.apply(x, self.weight, self.bias) + else: + if not hasattr(self, "state_W"): + self.prepare_for_eval() + X = x.view(-1, x.size(-1)) + X_int8, state_X = quantize_rowwise_nogroup(X) + return int8_matmul_rowwise_dequantize_bias( + X_int8, self.W_int8.t(), state_X, self.state_W, self.bias + ).view(*x.size()[:-1], -1) + + +class _switchback_global(torch.autograd.Function): + + @staticmethod + def forward(ctx, X_3D, W, bias): + + X = X_3D.view(-1, X_3D.size(-1)) + + X_int8, state_X = quantize_rowwise_nogroup(X) + W_int8, state_W = quantize_global(W) + ctx.save_for_backward = X, W + return int8_matmul_mixed_dequanitze_bias( + X_int8, W_int8.t(), state_X, state_W, bias + ).view(*X_3D.size()[:-1], -1) + + @staticmethod + def backward(ctx, G_3D): + + G = G_3D.reshape(-1, G_3D.size(-1)) + + grad_X = grad_W = grad_bias = None + + X, W = ctx.save_for_backward + if ctx.needs_input_grad[0]: + G_int8, state_G = quantize_rowwise_nogroup(G) + W_int8, state_W = quantize_global_transpose(W) + grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W).view( + *G_3D.size()[:-1], -1 + ) + if ctx.needs_input_grad[1]: + grad_W = torch.matmul(G.t(), X.to(G.dtype)) + if ctx.needs_input_grad[2]: + grad_bias = G.sum(dim=0) + + return grad_X, grad_W, grad_bias + + + +class SwitchBackGlobalLinear(nn.Linear): + + def prepare_for_eval(self): + state_W = self.weight.abs().max() + W_int8 = (127 * self.weight.float() / state_W).round().to(torch.int8) + + self.register_buffer("W_int8", W_int8) + self.register_buffer("state_W", state_W) + + del self.weight + + def forward(self, x): + if self.training: + return _switchback_global.apply(x, self.weight, self.bias) + else: + if not hasattr(self, "state_W"): + self.prepare_for_eval() + X = x.view(-1, x.size(-1)) + X_int8, state_X = quantize_rowwise_nogroup(X) + return int8_matmul_mixed_dequanitze_bias( + X_int8, self.W_int8.t(), state_X, self.state_W, self.bias + ).view(*x.size()[:-1], -1) + + + + +class LinearFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias=None): + X = input.view(-1, input.size(-1)) + + ctx.save_for_backward(X, weight, bias) + output = input.matmul(weight.t()) + if bias is not None: + output += bias.unsqueeze(0).expand_as(output) + return output.view(*input.size()[:-1], -1) + + @staticmethod + def backward(ctx, grad_output_3D): + input, weight, bias = ctx.saved_tensors + + grad_output = grad_output_3D.reshape(-1, grad_output_3D.size(-1)) + + grad_input = grad_weight = grad_bias = None + + if ctx.needs_input_grad[0]: + grad_input = grad_output.matmul(weight.to(grad_output.dtype)).view(*grad_output_3D.size()[:-1], -1) + if ctx.needs_input_grad[1]: + grad_weight = grad_output.t().matmul(input.to(grad_output.dtype)) + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(0) + + return grad_input, grad_weight, grad_bias + +class MyLinear(nn.Linear): + + def forward(self, x): + return LinearFunction.apply(x, self.weight, self.bias) + + + + +class _switchback_mlp(torch.autograd.Function): + + + @staticmethod + def forward(ctx, X_3D, W1, B1, W2, B2): + + X1 = X_3D.view(-1, X_3D.size(-1)) + + X1_int8, state_X1 = quantize_rowwise_nogroup(X1) + W1_int8, state_W1 = quantize_global(W1) + + X2_pre = int8_matmul_mixed_dequanitze_bias( + X1_int8, W1_int8.t(), state_X1, state_W1, B1 + ) + + # X2_v1 = torch.nn.functional.gelu(X2) + # X2_int8, state_X2, = quantize_rowwise_nogroup(X2_v1) + X2_int8, state_X2, X2 = quantize_rowwise_nogroup_gelu(X2_pre) + + W2_int8, state_W2 = quantize_global(W2) + + out = int8_matmul_mixed_dequanitze_bias( + X2_int8, W2_int8.t(), state_X2, state_W2, B2 + ) + + ctx.save_for_backward = X1, W1, X2, X2_pre, W2 + + return out.view(*X_3D.size()[:-1], -1) + + @staticmethod + def backward(ctx, G_3D): + + G2 = G_3D.reshape(-1, G_3D.size(-1)) + + grad_X1 = grad_W1 = grad_B1 = grad_W2 = grad_B2 = None + + X1, W1, X2, X2_pre, W2 = ctx.save_for_backward + + G2_int8, state_G2 = quantize_rowwise_nogroup(G2) + W2_int8, state_W2 = quantize_global_transpose(W2) + + G1 = int8_matmul_mixed_dequanitze(G2_int8, W2_int8.t(), state_G2, state_W2).view( + *G_3D.size()[:-1], -1 + ) + + grad_W2 = torch.matmul(G2.t(), X2.to(G2.dtype)) + grad_B2 = G2.sum(dim=0) + + G1_int8, state_G1, G1 = quantize_rowwise_nogroup_back_gelu(G1, X2_pre) + + if ctx.needs_input_grad[0]: + + W1_int8, state_W1 = quantize_global_transpose(W1) + grad_X1 = int8_matmul_mixed_dequanitze(G1_int8, W1_int8.t(), state_G1, state_W1).view( + *G_3D.size()[:-1], -1 + ) + if ctx.needs_input_grad[1]: + grad_W1 = torch.matmul(G1.t(), X1.to(G1.dtype)) + if ctx.needs_input_grad[2]: + grad_B1 = G1.sum(dim=0) + + return grad_X1, grad_W1, grad_B1, grad_W2, grad_B2 + + +class SwitchBackGlobalMLP(nn.Module): + + + def __init__(self, dim_in, dim_hidden): + super().__init__() + self.linear1 = nn.Linear(dim_in, dim_hidden) + self.linear2 = nn.Linear(dim_hidden, dim_in) + + + def forward(self, x): + return _switchback_mlp.apply(x, self.linear1.weight, self.linear1.bias, self.linear2.weight, self.linear2.bias) + \ No newline at end of file diff --git a/bitsandbytes/nn/triton_utils/v0/__init__.py b/bitsandbytes/nn/triton_utils/v0/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bitsandbytes/nn/triton_utils/v0/fused_gelu_quantize.py b/bitsandbytes/nn/triton_utils/v0/fused_gelu_quantize.py new file mode 100644 index 0000000..50451cb --- /dev/null +++ b/bitsandbytes/nn/triton_utils/v0/fused_gelu_quantize.py @@ -0,0 +1,190 @@ +import math +import torch +import time +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +tl.libdevice + +# TODO: autotune this better. +@triton.autotune( + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] +) +@triton.jit +def _quantize_rowwise_nogroup_gelu( + x_ptr, + output_ptr, + output_maxs, + output_fp16, + n_elements, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + arange = tl.arange(0, P2) + offsets = block_start + arange + row_mask = arange < BLOCK_SIZE + x = tl.load(x_ptr + offsets, mask=row_mask) + + cdf = 0.5 * (1.0 + tl.libdevice.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x))) + x_new = x * cdf + + tl.store(output_fp16 + offsets, x_new, mask=row_mask) + + abs_x = tl.abs(x_new) + max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) + output = tl.libdevice.llrint(127. * (x_new / max_val)) + tl.store(output_ptr + offsets, output, mask=row_mask) + tl.store(output_maxs + pid, max_val) + +def quantize_rowwise_nogroup_gelu(x: torch.Tensor): + output = torch.empty(*x.shape, device=x.device, dtype=torch.int8) + output_fp16 = torch.empty(*x.shape, device=x.device, dtype=torch.float16) + output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (x.shape[0],) + _quantize_rowwise_nogroup_gelu[grid](x, output, output_maxs, output_fp16, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) + return output, output_maxs, output_fp16 + + + +# TODO: autotune this better. +@triton.autotune( + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] +) +@triton.jit +def _quantize_rowwise_nogroup_back_gelu( + x_ptr, + in_ptr, + output_ptr, + output_maxs, + output_fp16, + n_elements, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + arange = tl.arange(0, P2) + offsets = block_start + arange + row_mask = arange < BLOCK_SIZE + x_out = tl.load(x_ptr + offsets, mask=row_mask) + x_in = tl.load(in_ptr + offsets, mask=row_mask) + + cdf = 0.5 * (1.0 + tl.libdevice.tanh(x_in * 0.7978845608 * (1 + 0.044715 * x_in * x_in))) + intermediate = tl.libdevice.tanh(x_in * 0.7978845608 * (1 + 0.044715 * x_in * x_in)) + dcdf = 0.5 * (0.7978845608 + 0.1070322243 * x_in * x_in) * (1 - intermediate * intermediate) + x = x_out * (cdf + x_in * dcdf) + + tl.store(output_fp16 + offsets, x, mask=row_mask) + + abs_x = tl.abs(x) + max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) + output = tl.libdevice.llrint(127. * (x / max_val)) + tl.store(output_ptr + offsets, output, mask=row_mask) + tl.store(output_maxs + pid, max_val) + +def quantize_rowwise_nogroup_back_gelu(x: torch.Tensor, y : torch.Tensor): + output = torch.empty(*x.shape, device=x.device, dtype=torch.int8) + output_fp16 = torch.empty(*x.shape, device=x.device, dtype=torch.float16) + output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (x.shape[0],) + _quantize_rowwise_nogroup_back_gelu[grid](x, y, output, output_maxs, output_fp16, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) + return output, output_maxs, output_fp16 + + + +# if __name__ == '__main__': +# torch.manual_seed(0) + +# x = torch.randn(1280, 768).cuda().to(torch.float16) +# out = quantize_rowwise_nogroup(x) + +# x_real = (127 * x.float() / x.abs().max(dim=1, keepdim=True)[0]).round().to(torch.int8) +# max2 = x.abs().max(1)[0] + +# print(torch.allclose(out[1], max2)) +# print( (x_real == out[0]).float().mean() ) + +# # for i in range(x.shape[0]): +# # print( (x_real[i, :] == out[0][i, :]).float().mean() ) + +# # print(out[0]) +# # print(x_real) +# # import pdb; pdb.set_trace() +# # print(out[2]) +# # print(out[2][:10]) +# sums = x.sum(dim=0) +# #print(sums[:10]) +# #print( (sums == out[2]).float().mean() ) + +# import pdb; pdb.set_trace() +# # import pdb; pdb.set_trace() +# # exit() + +# # repeat = 16 + +# # for _ in range(8): +# # out = quantize_rowwise_nogroup(x) + +# # triton_graph = torch.cuda.CUDAGraph() +# # with torch.cuda.graph(triton_graph): +# # out = quantize_rowwise_nogroup(x) + +# # triton_graph.replay() + +# # torch.cuda.synchronize() +# # start = time.time() +# # for _ in range(repeat): +# # triton_graph.replay() +# # torch.cuda.synchronize() +# # end = time.time() + +# # print(out[0]) +# # print(out[1]) +# # print(x / x.abs().max(dim=1, keepdim=True)[0]) +# # max1 = out[1] +# # max2 = x.abs().max(1)[0] +# # print(max1, max2) +# # print(torch.allclose(max1, max2)) + +# #print(f"time: {(end - start) / repeat * 1000:.3f} ms") diff --git a/bitsandbytes/nn/triton_utils/v0/int8_matmul_mixed_dequanitze.py b/bitsandbytes/nn/triton_utils/v0/int8_matmul_mixed_dequanitze.py new file mode 100644 index 0000000..2ecfcb8 --- /dev/null +++ b/bitsandbytes/nn/triton_utils/v0/int8_matmul_mixed_dequanitze.py @@ -0,0 +1,276 @@ +import torch + +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 + }, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) +@triton.jit +def _kernel(A, B, C, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + w_factor = tl.load(state_w_ptr) + x_factor = tl.load(state_x_ptr + ram)[:, None] + + # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + acc = (w_factor * (x_factor * (acc * divfactor))) + acc = acc.to(C.dtype.element_ty) + + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +def int8_matmul_mixed_dequanitze(a, b, state_x, state_w): + device = a.device + divfactor = 1. / (127. * 127.) + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c = torch.empty((M, N), device=device, dtype=torch.float16) + # accumulator types + ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _kernel[grid](a, b, c, state_x, state_w, M, N, K, divfactor, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + GROUP_M=8, ACC_TYPE=ACC_TYPE) + return c + + + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 + }, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) +@triton.jit +def _kernel_bias(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + w_factor = tl.load(state_w_ptr) + x_factor = tl.load(state_x_ptr + ram)[:, None] + + # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + acc = (w_factor * (x_factor * (acc * divfactor))) + acc = acc.to(C.dtype.element_ty) + + if has_bias: + bias = tl.load(bias + rn).to(C.dtype.element_ty) + acc = acc + bias[None, :] + + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +def int8_matmul_mixed_dequanitze_bias(a, b, state_x, state_w, bias): + device = a.device + divfactor = 1. / (127. * 127.) + has_bias = 0 if bias is None else 1 + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c = torch.empty((M, N), device=device, dtype=torch.float16) + # accumulator types + ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _kernel_bias[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + GROUP_M=8, ACC_TYPE=ACC_TYPE) + return c diff --git a/bitsandbytes/nn/triton_utils/v0/int8_matmul_rowwise_dequantize.py b/bitsandbytes/nn/triton_utils/v0/int8_matmul_rowwise_dequantize.py new file mode 100644 index 0000000..fa0b516 --- /dev/null +++ b/bitsandbytes/nn/triton_utils/v0/int8_matmul_rowwise_dequantize.py @@ -0,0 +1,149 @@ +import torch + +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 + }, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) +@triton.jit +def _kernel(A, B, C, state_x_ptr, state_w_ptr, M, N, K, divfactor, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + w_factor = tl.load(state_w_ptr + rbn)[None, :] + x_factor = tl.load(state_x_ptr + ram)[:, None] + + # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + acc = (w_factor * (x_factor * (acc * divfactor))) + acc = acc.to(C.dtype.element_ty) + + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +def int8_matmul_rowwise_dequantize(a, b, state_x, state_w): + divfactor = 1. / (127. * 127.) + + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c = torch.empty((M, N), device=device, dtype=torch.float16) + # accumulator types + ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _kernel[grid](a, b, c, state_x, state_w, M, N, K, divfactor, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + GROUP_M=8, ACC_TYPE=ACC_TYPE) + return c diff --git a/bitsandbytes/nn/triton_utils/v0/int8_matmul_rowwise_dequantize_bias.py b/bitsandbytes/nn/triton_utils/v0/int8_matmul_rowwise_dequantize_bias.py new file mode 100644 index 0000000..5f524c1 --- /dev/null +++ b/bitsandbytes/nn/triton_utils/v0/int8_matmul_rowwise_dequantize_bias.py @@ -0,0 +1,160 @@ +import torch + +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 + }, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) +@triton.jit +def _kernel(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + w_factor = tl.load(state_w_ptr + rbn)[None, :] + x_factor = tl.load(state_x_ptr + ram)[:, None] + + # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + acc = (w_factor * (x_factor * (acc * divfactor))) + acc = acc.to(C.dtype.element_ty) + + if has_bias: + bias = tl.load(bias + rn).to(C.dtype.element_ty) + acc = acc + bias[None, :] + + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +def int8_matmul_rowwise_dequantize_bias(a, b, state_x, state_w, bias): + + #print(bias) + divfactor = 1. / (127. * 127.) + + has_bias = 0 if bias is None else 1 + + if bias is not None: + bias = bias.contiguous() + + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c = torch.empty((M, N), device=device, dtype=torch.float16) + # accumulator types + ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _kernel[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + GROUP_M=8, ACC_TYPE=ACC_TYPE) + return c diff --git a/bitsandbytes/nn/triton_utils/v0/quantize_columnwise_nogroup_transpose.py b/bitsandbytes/nn/triton_utils/v0/quantize_columnwise_nogroup_transpose.py new file mode 100644 index 0000000..fa3a9a9 --- /dev/null +++ b/bitsandbytes/nn/triton_utils/v0/quantize_columnwise_nogroup_transpose.py @@ -0,0 +1,122 @@ +import math +import torch +import time +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +# TODO: autotune this better. +@triton.autotune( + configs=[ + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_stages=16), + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=16, num_warps=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] +) +@triton.jit +def _quantize_columnwise_nogroup_transpose( + x_ptr, + output_ptr, + output_maxs, + n_elements, + M : tl.constexpr, N : tl.constexpr, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid + p2_arange = tl.arange(0, P2) + p2_arange_mask = p2_arange < M + arange = p2_arange * N + offsets = block_start + arange + x = tl.load(x_ptr + offsets, mask=p2_arange_mask) + abs_x = tl.abs(x) + max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0) + output = tl.libdevice.llrint(127. * (x / max_val)) + + new_start = pid * M + new_offsets = new_start + p2_arange + tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask) + tl.store(output_maxs + pid, max_val) + +def quantize_columnwise_nogroup_transpose(x: torch.Tensor): + M, N = x.shape + output = torch.empty(N, M, device=x.device, dtype=torch.int8) + output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(M)))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _quantize_columnwise_nogroup_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2) + return output, output_maxs + + + +if __name__ == '__main__': + torch.manual_seed(0) + + x = torch.randn(1280, 768).cuda().to(torch.float16) + out = quantize_columnwise_nogroup_transpose(x) + + + x_real = x.t().float() + x_real_int8 = (127. * x_real / x_real.abs().max(dim=1, keepdim=True)[0]).round().to(torch.int8) + maxs = x_real.abs().max(dim=1, keepdim=True)[0].half() + + #print(out[0][2,:]) + + print((out[0] == x_real_int8).float().mean()) + print((out[1] == maxs[:, 0]).float().mean()) + + # print(out[0]) + # print(out[1]) + + # print(out[0][2,:]) + # print(x_real[2, :]) + + # print((out[0] != x_real).nonzero()) + + #import pdb; pdb.set_trace() + # repeat = 16 + + # for _ in range(8): + # out = quantize_columnwise_nogroup_transpose(x) + + # triton_graph = torch.cuda.CUDAGraph() + # with torch.cuda.graph(triton_graph): + # out = quantize_columnwise_nogroup_transpose(x) + + # triton_graph.replay() + + # torch.cuda.synchronize() + # start = time.time() + # for _ in range(repeat): + # triton_graph.replay() + # torch.cuda.synchronize() + # end = time.time() + + # print(out[0]) + # print(out[1]) + # print(x / x.abs().max(dim=0, keepdim=True)[0]) + # x_real = (127 * (x / x.abs().max(dim=0, keepdim=True)[0])).round().to(torch.int8) + # max1 = out[1] + # max2 = x.abs().max(0)[0] + # print(max1, max2) + # import pdb; pdb.set_trace() + # print(torch.allclose(max1, max2)) + + # print(f"time: {(end - start) / repeat * 1000:.3f} ms") diff --git a/bitsandbytes/nn/triton_utils/v0/quantize_global.py b/bitsandbytes/nn/triton_utils/v0/quantize_global.py new file mode 100644 index 0000000..6d23aac --- /dev/null +++ b/bitsandbytes/nn/triton_utils/v0/quantize_global.py @@ -0,0 +1,130 @@ +import math +import torch +import time +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +# TODO: autotune this better. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4), + triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1), + + ], + key=['n_elements'] +) +@triton.jit +def _quantize_global( + x_ptr, + absmax_inv_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + absmax_inv = tl.load(absmax_inv_ptr) + output = tl.libdevice.llrint(127. * (x * absmax_inv)) + tl.store(output_ptr + offsets, output, mask=mask) + +def quantize_global(x: torch.Tensor): + absmax = x.abs().max().unsqueeze(0) + absmax_inv = 1./ absmax + output = torch.empty(*x.shape, device='cuda', dtype=torch.int8) + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _quantize_global[grid](x, absmax_inv, output, n_elements) + return output, absmax + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), + + # ... + ], + key=['M', 'N'] +) +@triton.jit +def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N, + BLOCK_M : tl.constexpr, + BLOCK_N : tl.constexpr, + GROUP_M : tl.constexpr): + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // group_size + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an) + mask = (rm < M)[:, None] & (rn < N)[None, :] + a = tl.load(A, mask=mask) + absmax_inv = tl.load(absmax_inv_ptr) + + # rematerialize to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + + output = tl.libdevice.llrint(127. * (a * absmax_inv)) + + tl.store(B, output, mask=mask) + +def quantize_global_transpose(input): + absmax = input.abs().max().unsqueeze(0) + absmax_inv = 1./ absmax + M, N = input.shape + out = torch.empty(N, M, device='cuda', dtype=torch.int8) + + assert out.size(0) == N and out.size(1) == M + assert input.stride(0) == 1 or input.stride(1) == 1 + assert out.stride(0) == 1 or out.stride(1) == 1 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) + _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N) + return out, absmax + +if __name__ == '__main__': + + + w = torch.randn(768, 1280).cuda().to(torch.float16) + W_int8, state_w = quantize_global(w) + r_state_w = w.abs().max() + r_W_int8 = ((127 * w.float()) / state_w).round().to(torch.int8) + print((r_W_int8 == W_int8).float().mean()) + + # print(r_W_int8) + # print(W_int8) + exit() + repeat = 16 + + for _ in range(8): + out = quantize_global(w) + + triton_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(triton_graph): + out = quantize_global(w) + + triton_graph.replay() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + triton_graph.replay() + torch.cuda.synchronize() + end = time.time() + + print(f"time: {(end - start) / repeat * 1000:.3f} ms") diff --git a/bitsandbytes/nn/triton_utils/v0/quantize_rowwise_nogroup.py b/bitsandbytes/nn/triton_utils/v0/quantize_rowwise_nogroup.py new file mode 100644 index 0000000..7e63f74 --- /dev/null +++ b/bitsandbytes/nn/triton_utils/v0/quantize_rowwise_nogroup.py @@ -0,0 +1,174 @@ +import math +import torch +import time +import triton +import triton.language as tl +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +# TODO: autotune this better. +@triton.autotune( + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] +) +@triton.jit +def _quantize_rowwise_nogroup( + x_ptr, + output_ptr, + output_maxs, + n_elements, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + arange = tl.arange(0, P2) + offsets = block_start + arange + row_mask = arange < BLOCK_SIZE + x = tl.load(x_ptr + offsets, mask=row_mask) + + abs_x = tl.abs(x) + max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) + output = tl.libdevice.llrint(127. * (x / max_val)) + tl.store(output_ptr + offsets, output, mask=row_mask) + tl.store(output_maxs + pid, max_val) + +def quantize_rowwise_nogroup(x: torch.Tensor): + output = torch.empty(*x.shape, device=x.device, dtype=torch.int8) + output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (x.shape[0],) + _quantize_rowwise_nogroup[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) + return output, output_maxs + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] +) +@triton.jit +def _experimental_quantize_rowwise_nogroup( + x_ptr, + output_ptr, + bias_grad_ptr, + output_maxs, + n_elements, + M: tl.constexpr, N: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, + P2M: tl.constexpr, +): + pid = tl.program_id(axis=0) + if pid < M: + block_start = pid * BLOCK_SIZE + arange = tl.arange(0, P2) + offsets = block_start + arange + row_mask = arange < BLOCK_SIZE + x = tl.load(x_ptr + offsets, mask=row_mask) + + abs_x = tl.abs(x) + max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) + output = tl.libdevice.llrint(127. * (x / max_val)) + tl.store(output_ptr + offsets, output, mask=row_mask) + tl.store(output_maxs + pid, max_val) + else: + real_pid = pid - M + arange_new = tl.arange(0, P2M) + mask_new = arange_new < M + offsets_new = real_pid + arange_new * N + new_x = tl.load(x_ptr + offsets_new, mask=mask_new) + s = tl.sum(tl.where(mask_new, new_x, 0).to(tl.float32), axis=0) + tl.store(bias_grad_ptr + real_pid, s) + +def experimental_quantize_rowwise_nogroup(x: torch.Tensor): + M, N = x.shape + output = torch.empty(*x.shape, device=x.device, dtype=torch.int8) + output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16) + bias_grad = torch.empty(x.shape[1], device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) + P2M = int(2 ** (math.ceil(math.log2(x.shape[0])))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (x.shape[0] + x.shape[1],) + _experimental_quantize_rowwise_nogroup[grid](x, output, bias_grad, output_maxs, n_elements, M, N, BLOCK_SIZE=x.shape[1], P2=P2, P2M=P2M) + return output, output_maxs, bias_grad + + +if __name__ == '__main__': + torch.manual_seed(0) + + x = torch.randn(1280, 768).cuda().to(torch.float16) + out = quantize_rowwise_nogroup(x) + + x_real = (127 * x.float() / x.abs().max(dim=1, keepdim=True)[0]).round().to(torch.int8) + max2 = x.abs().max(1)[0] + + print(torch.allclose(out[1], max2)) + print( (x_real == out[0]).float().mean() ) + + # for i in range(x.shape[0]): + # print( (x_real[i, :] == out[0][i, :]).float().mean() ) + + # print(out[0]) + # print(x_real) + # import pdb; pdb.set_trace() + # print(out[2]) + # print(out[2][:10]) + sums = x.sum(dim=0) + #print(sums[:10]) + #print( (sums == out[2]).float().mean() ) + + import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() + # exit() + + # repeat = 16 + + # for _ in range(8): + # out = quantize_rowwise_nogroup(x) + + # triton_graph = torch.cuda.CUDAGraph() + # with torch.cuda.graph(triton_graph): + # out = quantize_rowwise_nogroup(x) + + # triton_graph.replay() + + # torch.cuda.synchronize() + # start = time.time() + # for _ in range(repeat): + # triton_graph.replay() + # torch.cuda.synchronize() + # end = time.time() + + # print(out[0]) + # print(out[1]) + # print(x / x.abs().max(dim=1, keepdim=True)[0]) + # max1 = out[1] + # max2 = x.abs().max(1)[0] + # print(max1, max2) + # print(torch.allclose(max1, max2)) + + #print(f"time: {(end - start) / repeat * 1000:.3f} ms") diff --git a/tests/triton_tests/attn_decomp.py b/tests/triton_tests/attn_decomp.py new file mode 100644 index 0000000..9e8ed28 --- /dev/null +++ b/tests/triton_tests/attn_decomp.py @@ -0,0 +1,363 @@ + +import torch +import json +from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear +import time + +# class AttentionOld(torch.nn.Module): +# def __init__( +# self, +# dim, +# num_heads=8, +# qkv_bias=True, +# scaled_cosine=False, +# scale_heads=False, +# attn_drop=0., +# proj_drop=0., +# linear_module=torch.nn.Linear, +# ): +# super().__init__() +# self.scaled_cosine = scaled_cosine +# self.scale_heads = scale_heads +# assert dim % num_heads == 0, 'dim should be divisible by num_heads' +# self.num_heads = num_heads +# self.head_dim = dim // num_heads +# self.scale = self.head_dim ** -0.5 + +# self.in_proj_linear = linear_module(dim, 3 * dim, bias = qkv_bias) + +# self.attn_drop = torch.nn.Dropout(attn_drop) +# if self.scale_heads: +# self.head_scale = torch.nn.Parameter(torch.ones((num_heads, 1, 1))) +# else: +# self.head_scale = None +# self.out_proj = linear_module(dim, dim) +# self.out_drop = torch.nn.Dropout(proj_drop) + +# def forward(self, x, attn_mask = None): +# L, N, C = x.shape + +# q, k, v = self.in_proj_linear(x).chunk(3, dim=-1) + +# q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) +# k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) +# v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) + +# q = q * self.scale +# attn = torch.bmm(q, k.transpose(-1, -2)) + +# if attn_mask is not None: +# if attn_mask.dtype == torch.bool: +# new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) +# new_attn_mask.masked_fill_(attn_mask, float("-inf")) +# attn_mask = new_attn_mask +# attn += attn_mask + +# attn = attn.softmax(dim=-1) +# attn = self.attn_drop(attn) + +# x = torch.bmm(attn, v) +# x = x.transpose(0, 1).reshape(L, N, C) + +# x = self.out_proj(x) +# x = self.out_drop(x) +# return x + +class Attention(torch.nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + scaled_cosine=False, + scale_heads=False, + attn_drop=0., + proj_drop=0., + linear_module=torch.nn.Linear, + ): + super().__init__() + self.scaled_cosine = scaled_cosine + self.scale_heads = scale_heads + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.ln = torch.nn.LayerNorm(dim) + + self.in_proj_linear = linear_module(dim, 3 * dim, bias = qkv_bias) + + self.attn_drop = torch.nn.Dropout(attn_drop) + if self.scale_heads: + self.head_scale = torch.nn.Parameter(torch.ones((num_heads, 1, 1))) + else: + self.head_scale = None + self.out_proj = linear_module(dim, dim) + self.out_drop = torch.nn.Dropout(proj_drop) + + def forward(self, x, attn_mask = None): + q, k, v = self.in_proj_linear(self.ln(x)).chunk(3, dim=-1) + x = torch.compile(torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)) + x = self.out_proj(x) + return x + +if __name__ == '__main__': + + + for dim in [1024, 1280, 1408, 1664, 2048]: + for batch in [2**14, 2**15, 2**16, 2**17]: + + # if dim != 4096 or batch != 2**17: + # continue + + x1 = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True) + qu = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True) + ke = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True) + va = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True) + + standard = Attention(dim).cuda() + my_standard = Attention(dim, linear_module=MyLinear).cuda() + sb = Attention(dim, linear_module=SwitchBackGlobalLinear).cuda() + standard_compiled = torch.compile(standard) + ln_model = torch.nn.Sequential( + torch.nn.LayerNorm(dim), + torch.nn.LayerNorm(dim), + ).cuda() + ln_model_compiled = torch.compile( + ln_model + ) + gelu_model = torch.nn.Sequential( + torch.nn.GELU(), + ).cuda() + gelu_model_compiled = torch.compile( + gelu_model + ) + + + print('Model part 2') + + repeat = 32 + + info = {'repeat' : repeat, 'batch_size' : batch, 'dim' : dim} + + + k = 'attn' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_attn = torch.nn.functional.scaled_dot_product_attention(qu, ke, va) + ((2 ** 16) * out_attn).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_attn = torch.nn.functional.scaled_dot_product_attention(qu, ke, va) + ((2 ** 16) * out_attn).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + k = 'ln' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out = ln_model(x1) + ((2 ** 16) * out).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out = ln_model(x1) + ((2 ** 16) * out).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + x1.grad.zero_() + + k = 'ln_compiled' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out = ln_model_compiled(x1) + ((2 ** 16) * out).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out = ln_model_compiled(x1) + ((2 ** 16) * out).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + k = 'gelu' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out = gelu_model(x1) + ((2 ** 16) * out).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out = gelu_model(x1) + ((2 ** 16) * out).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + x1.grad.zero_() + + k = 'gelu_compiled' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out = gelu_model_compiled(x1) + ((2 ** 16) * out).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out = gelu_model_compiled(x1) + ((2 ** 16) * out).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + x1.grad.zero_() + + k = 'standard' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_standard = standard(x1) + ((2 ** 16) * out_standard).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_standard = standard(x1) + ((2 ** 16) * out_standard).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + x1.grad.zero_() + + k = 'my_standard' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_my_standard = my_standard(x1) + ((2 ** 16) * out_my_standard).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_my_standard = my_standard(x1) + ((2 ** 16) * out_my_standard).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + # + # + + x1.grad.zero_() + + + k = 'standard_compiled' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_standard_compiled = standard_compiled(x1) + ((2 ** 16) * out_standard_compiled).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_standard_compiled = standard_compiled(x1) + ((2 ** 16) * out_standard_compiled).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + x1.grad.zero_() + + + k = 'sb' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_sb = sb(x1) + ((2 ** 16) * out_sb).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_sb = sb(x1) + ((2 ** 16) * out_sb).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + info_json = json.dumps(info) + + + with open("tests/triton_tests/attn_info_ln.jsonl", "a") as file: + file.write(info_json + "\n") + + + #exit() + + # err_fused = (out_standard - out_fused).abs().mean() + # err_sb = (out_standard - out_sb).abs().mean() + # print('OUT', err_fused, err_sb) + + # err_fused = (standard[d].weight.grad - fused_mlp.linear2.weight.grad).abs().mean() + # err_sb = (standard[d].weight.grad - sb[d].weight.grad).abs().mean() + + # print('GW2', err_fused, err_sb) + + # err_fused = (standard[0].weight.grad - fused_mlp.linear1.weight.grad).abs().mean() + # err_sb = (standard[0].weight.grad - sb[0].weight.grad).abs().mean() + + # print('GW1', err_fused, err_sb) + + # err_fused = (x1.grad - x2.grad).abs().mean() + # err_sb = (x1.grad - x3.grad).abs().mean() + + # print('GX1', err_fused, err_sb) + + # import pdb; pdb.set_trace() + + + # # NO GELU, ST GRADIENTS, EVERYTHING FINE. \ No newline at end of file diff --git a/tests/triton_tests/attn_info_ln.jsonl b/tests/triton_tests/attn_info_ln.jsonl new file mode 100644 index 0000000..c2f239b --- /dev/null +++ b/tests/triton_tests/attn_info_ln.jsonl @@ -0,0 +1,20 @@ +{"repeat": 32, "batch_size": 16384, "dim": 1024, "attn": 2.1414458751678467, "ln": 1.6365647315979004, "ln_compiled": 1.799367368221283, "gelu": 1.0930374264717102, "gelu_compiled": 1.094818115234375, "standard": 4.159651696681976, "my_standard": 4.696495831012726, "standard_compiled": 3.675594925880432, "sb": 4.1465312242507935} +{"repeat": 32, "batch_size": 32768, "dim": 1024, "attn": 4.100345075130463, "ln": 3.1594187021255493, "ln_compiled": 3.437422215938568, "gelu": 2.109348773956299, "gelu_compiled": 2.11450457572937, "standard": 7.706902921199799, "my_standard": 8.799396455287933, "standard_compiled": 6.735652685165405, "sb": 7.66376405954361} +{"repeat": 32, "batch_size": 65536, "dim": 1024, "attn": 7.953710854053497, "ln": 6.236426532268524, "ln_compiled": 6.746955215930939, "gelu": 4.164382815361023, "gelu_compiled": 4.171714186668396, "standard": 14.894917607307434, "my_standard": 17.042435705661774, "standard_compiled": 12.985721230506897, "sb": 14.6140456199646} +{"repeat": 32, "batch_size": 131072, "dim": 1024, "attn": 15.638880431652069, "ln": 12.333884835243225, "ln_compiled": 13.272866606712341, "gelu": 8.228793740272522, "gelu_compiled": 8.243747055530548, "standard": 29.425136744976044, "my_standard": 35.08377820253372, "standard_compiled": 25.69487690925598, "sb": 28.760001063346863} +{"repeat": 32, "batch_size": 16384, "dim": 1280, "attn": 2.627238631248474, "ln": 2.0098239183425903, "ln_compiled": 2.4197474122047424, "gelu": 1.3455823063850403, "gelu_compiled": 1.35069340467453, "standard": 5.554787814617157, "my_standard": 6.2290579080581665, "standard_compiled": 5.132324993610382, "sb": 5.4178386926651} +{"repeat": 32, "batch_size": 32768, "dim": 1280, "attn": 5.0596073269844055, "ln": 3.903590142726898, "ln_compiled": 4.719957709312439, "gelu": 2.6203468441963196, "gelu_compiled": 2.627365291118622, "standard": 10.546617209911346, "my_standard": 11.850126087665558, "standard_compiled": 9.685918688774109, "sb": 10.088451206684113} +{"repeat": 32, "batch_size": 65536, "dim": 1280, "attn": 9.845800697803497, "ln": 7.711298763751984, "ln_compiled": 9.292080998420715, "gelu": 5.172915756702423, "gelu_compiled": 5.180932581424713, "standard": 21.371990442276, "my_standard": 23.921720683574677, "standard_compiled": 19.669152796268463, "sb": 20.267993211746216} +{"repeat": 32, "batch_size": 131072, "dim": 1280, "attn": 19.375711679458618, "ln": 15.333592891693115, "ln_compiled": 18.245264887809753, "gelu": 10.264746844768524, "gelu_compiled": 10.283775627613068, "standard": 41.79700464010239, "my_standard": 45.84744572639465, "standard_compiled": 38.35208714008331, "sb": 38.35364431142807} +{"repeat": 32, "batch_size": 16384, "dim": 1408, "attn": 2.9110386967658997, "ln": 2.1998360753059387, "ln_compiled": 2.581551671028137, "gelu": 1.4731436967849731, "gelu_compiled": 1.478634774684906, "standard": 6.764143705368042, "my_standard": 7.331632077693939, "standard_compiled": 6.24605268239975, "sb": 6.325609982013702} +{"repeat": 32, "batch_size": 32768, "dim": 1408, "attn": 5.542516708374023, "ln": 4.289716482162476, "ln_compiled": 5.065307021141052, "gelu": 2.8742849826812744, "gelu_compiled": 2.882353961467743, "standard": 12.749537825584412, "my_standard": 13.79828155040741, "standard_compiled": 11.728867888450623, "sb": 11.642806231975555} +{"repeat": 32, "batch_size": 65536, "dim": 1408, "attn": 10.80312579870224, "ln": 8.471302688121796, "ln_compiled": 9.96796041727066, "gelu": 5.681410431861877, "gelu_compiled": 5.6905597448349, "standard": 25.19702911376953, "my_standard": 27.226239442825317, "standard_compiled": 23.22910726070404, "sb": 22.682294249534607} +{"repeat": 32, "batch_size": 131072, "dim": 1408, "attn": 21.284908056259155, "ln": 16.85701310634613, "ln_compiled": 19.643358886241913, "gelu": 11.292420327663422, "gelu_compiled": 11.314474046230316, "standard": 50.06787180900574, "my_standard": 54.29378151893616, "standard_compiled": 44.58653926849365, "sb": 45.359253883361816} +{"repeat": 32, "batch_size": 16384, "dim": 1664, "attn": 3.382459282875061, "ln": 2.6206374168395996, "ln_compiled": 2.9666870832443237, "gelu": 1.7263293266296387, "gelu_compiled": 1.7317384481430054, "standard": 8.414775133132935, "my_standard": 9.117811918258667, "standard_compiled": 7.7542513608932495, "sb": 7.70898163318634} +{"repeat": 32, "batch_size": 32768, "dim": 1664, "attn": 6.468378007411957, "ln": 5.125559866428375, "ln_compiled": 5.791269242763519, "gelu": 3.3864825963974, "gelu_compiled": 3.3920034766197205, "standard": 16.016244888305664, "my_standard": 17.25083589553833, "standard_compiled": 14.60808515548706, "sb": 14.347739517688751} +{"repeat": 32, "batch_size": 65536, "dim": 1664, "attn": 12.645229697227478, "ln": 10.13532280921936, "ln_compiled": 11.427387595176697, "gelu": 6.6957250237464905, "gelu_compiled": 6.711684167385101, "standard": 31.792201101779938, "my_standard": 34.31189805269241, "standard_compiled": 29.10037338733673, "sb": 28.3128023147583} +{"repeat": 32, "batch_size": 131072, "dim": 1664, "attn": 24.970605969429016, "ln": 20.182937383651733, "ln_compiled": 22.7489173412323, "gelu": 13.326868414878845, "gelu_compiled": 13.345755636692047, "standard": 63.46555054187775, "my_standard": 70.19880414009094, "standard_compiled": 56.40875548124313, "sb": 56.22846633195877} +{"repeat": 32, "batch_size": 16384, "dim": 2048, "attn": 4.080049693584442, "ln": 3.2655522227287292, "ln_compiled": 3.3329352736473083, "gelu": 2.108432352542877, "gelu_compiled": 2.114713191986084, "standard": 11.370822787284851, "my_standard": 12.234866619110107, "standard_compiled": 10.377615690231323, "sb": 10.209612548351288} +{"repeat": 32, "batch_size": 32768, "dim": 2048, "attn": 7.74645060300827, "ln": 6.418220698833466, "ln_compiled": 6.55733048915863, "gelu": 4.163652658462524, "gelu_compiled": 4.171028733253479, "standard": 21.39316499233246, "my_standard": 23.04024249315262, "standard_compiled": 19.431106746196747, "sb": 18.732361495494843} +{"repeat": 32, "batch_size": 65536, "dim": 2048, "attn": 15.235155820846558, "ln": 12.684382498264313, "ln_compiled": 12.895286083221436, "gelu": 8.228868246078491, "gelu_compiled": 8.242718875408173, "standard": 42.55136102437973, "my_standard": 45.82635313272476, "standard_compiled": 38.663335144519806, "sb": 36.76284849643707} +{"repeat": 32, "batch_size": 131072, "dim": 2048, "attn": 30.24454414844513, "ln": 25.25731921195984, "ln_compiled": 25.67601203918457, "gelu": 16.384944319725037, "gelu_compiled": 16.409948468208313, "standard": 84.26841348409653, "my_standard": 91.10662341117859, "standard_compiled": 76.89539343118668, "sb": 71.73164188861847} diff --git a/tests/triton_tests/full_matrix_decomp.py b/tests/triton_tests/full_matrix_decomp.py new file mode 100644 index 0000000..de37b95 --- /dev/null +++ b/tests/triton_tests/full_matrix_decomp.py @@ -0,0 +1,353 @@ +import json + +import time +import torch +import torch.nn as nn +import bitsandbytes.nn as bnn +from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, MyLinear + +from bitsandbytes.nn.triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup +from bitsandbytes.nn.triton_utils.v0.quantize_columnwise_nogroup_transpose import quantize_columnwise_nogroup_transpose +from bitsandbytes.nn.triton_utils.v0.int8_matmul_rowwise_dequantize_bias import int8_matmul_rowwise_dequantize_bias +from bitsandbytes.nn.triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize +from bitsandbytes.nn.triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose +from bitsandbytes.nn.triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze, int8_matmul_mixed_dequanitze_bias + +# KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large. +# not that big of an issue. + +def get_time_standard_fwd(k, v): + + x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda() + g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda() + + ##### time matmul 1 + for _ in range(repeat // 2): + g.t().matmul(x) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + g.t().matmul(x) + + torch.cuda.synchronize() + end = time.time() + print(f"time {k}: {(end - start) / repeat * 1000:.3f} ms") + return (end - start) / repeat * 1000 + +if __name__ == '__main__': + torch.manual_seed(0) + #for (dim, wm) in [(1024, 4), (1280, 4), (1408, 4.3637), (1664, 4.9231), (2048, 4), (4096, 4), (8096, 4)] + for (dim, wm) in [(1408, 4), (1664, 4),]: + + for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]: + #for batch_size in [256*256, 256*512]: + + for switch in [False, True]: + + + # hparams + repeat = 64 + batch_size = batch_size + dim_out = dim * wm + dim_in = dim + if switch: + dim_out = dim + dim_in = wm * dim + + dim_in = round(dim_in) + dim_out = round(dim_out) + + + # simulate forward pass + x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda() + g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda() + w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda() + + x_int8 = x.clone().to(torch.int8) + g_int8 = g.clone().to(torch.int8) + w_int8 = w.clone().to(torch.int8) + wt_int8 = w.t().contiguous().clone().to(torch.int8) + state_x_rowwise = x.max(dim=1)[0] + state_g_rowwise = g.max(dim=1)[0] + state_w_columnwise = w.max(dim=0)[0] + state_w_rowwise = w.max(dim=1)[0] + state_w_global = w.max() + + info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch} + + k = 'standard_fwd' + for _ in range(repeat // 2): + x.matmul(w.t()) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + x.matmul(w.t()) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + k = 'standard_gw' + for _ in range(repeat // 2): + g.t().matmul(x) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + g.t().matmul(x) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + k = 'standard_gx' + for _ in range(repeat // 2): + g.matmul(w) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + g.matmul(w) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + + k = 'rowwise_fwd' + for _ in range(repeat // 2): + int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + k = 'rowwise_bwd' + for _ in range(repeat // 2): + int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + k = 'global_fwd' + for _ in range(repeat // 2): + int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + k = 'global_bwd' + for _ in range(repeat // 2): + int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + k = 'x_quantize_rowwise' + for _ in range(repeat // 2): + quantize_rowwise_nogroup(x) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + quantize_rowwise_nogroup(x) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + k = 'g_quantize_rowwise' + for _ in range(repeat // 2): + quantize_rowwise_nogroup(g) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + quantize_rowwise_nogroup(g) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + k = 'w_quantize_rowwise' + for _ in range(repeat // 2): + quantize_rowwise_nogroup(w) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + quantize_rowwise_nogroup(w) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + k = 'w_quantize_colwise_transpose' + for _ in range(repeat // 2): + quantize_columnwise_nogroup_transpose(w) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + quantize_columnwise_nogroup_transpose(w) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + k = 'w_quantize_global' + for _ in range(repeat // 2): + quantize_global(w) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + quantize_global(w) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + k = 'w_quantize_global_transpose' + for _ in range(repeat // 2): + quantize_global_transpose(w) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + quantize_global_transpose(w) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + k = 'cast_x' + for _ in range(repeat // 2): + newx = x.to(torch.int8) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + newx = x.to(torch.int8) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + + k = 'cast_g' + for _ in range(repeat // 2): + newx = g.to(torch.int8) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + newx = g.to(torch.int8) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + + k = 'cast_w' + for _ in range(repeat // 2): + newx = w.to(torch.int8) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + newx = w.to(torch.int8) + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw'] + time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd'] + time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd'] + + print('TOTAL STANDARD', time_standard) + print('TOTAL ROWWISE', time_rowwise) + print('TOTAL GLOBAL', time_global) + + print('speedup', -100*(time_global - time_standard)/time_standard) + + info['time_standard'] = time_standard + info['time_rowwise'] = time_rowwise + info['time_global'] = time_global + + + + info_json = json.dumps(info) + + + with open("tests/triton_tests/info.jsonl", "a") as file: + file.write(info_json + "\n") \ No newline at end of file diff --git a/tests/triton_tests/info.jsonl b/tests/triton_tests/info.jsonl new file mode 100644 index 0000000..879a65f --- /dev/null +++ b/tests/triton_tests/info.jsonl @@ -0,0 +1,142 @@ +{"repeat": 64, "batch_size": 1024, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.047907233238220215, "standard_gw": 0.04326179623603821, "standard_gx": 0.042986124753952026, "rowwise_fwd": 0.03902614116668701, "rowwise_bwd": 0.038955360651016235, "global_fwd": 0.03974884748458862, "global_bwd": 0.0391639769077301, "x_quantize_rowwise": 0.02619624137878418, "g_quantize_rowwise": 0.02695620059967041, "w_quantize_rowwise": 0.02631545066833496, "w_quantize_colwise_transpose": 0.08677691221237183, "w_quantize_global": 0.07359683513641357, "w_quantize_global_transpose": 0.08226558566093445, "cast_x": 0.007815659046173096, "cast_g": 0.016041100025177002, "cast_w": 0.01600012183189392, "time_standard": 0.13415515422821045, "time_rowwise": 0.28748810291290283, "time_global": 0.33118948340415955} +{"repeat": 64, "batch_size": 1024, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.04236400127410889, "standard_gw": 0.04898756742477417, "standard_gx": 0.04731118679046631, "rowwise_fwd": 0.03933534026145935, "rowwise_bwd": 0.03947317600250244, "global_fwd": 0.03688037395477295, "global_bwd": 0.039167702198028564, "x_quantize_rowwise": 0.02533942461013794, "g_quantize_rowwise": 0.02516806125640869, "w_quantize_rowwise": 0.02528354525566101, "w_quantize_colwise_transpose": 0.0903792679309845, "w_quantize_global": 0.0997595489025116, "w_quantize_global_transpose": 0.10209530591964722, "cast_x": 0.01626834273338318, "cast_g": 0.011973083019256592, "cast_w": 0.016044825315475464, "time_standard": 0.13866275548934937, "time_rowwise": 0.2939663827419281, "time_global": 0.37739798426628113} +{"repeat": 64, "batch_size": 2048, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.07753819227218628, "standard_gw": 0.08026883006095886, "standard_gx": 0.0906921923160553, "rowwise_fwd": 0.0630207359790802, "rowwise_bwd": 0.058263540267944336, "global_fwd": 0.06167963147163391, "global_bwd": 0.05801767110824585, "x_quantize_rowwise": 0.034205615520477295, "g_quantize_rowwise": 0.03341957926750183, "w_quantize_rowwise": 0.03244727849960327, "w_quantize_colwise_transpose": 0.08665025234222412, "w_quantize_global": 0.09483471512794495, "w_quantize_global_transpose": 0.10108202695846558, "cast_x": 0.012032687664031982, "cast_g": 0.03752484917640686, "cast_w": 0.01605972647666931, "time_standard": 0.24849921464920044, "time_rowwise": 0.3882758319377899, "time_global": 0.46350806951522827} +{"repeat": 64, "batch_size": 2048, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.09099021553993225, "standard_gw": 0.0799819827079773, "standard_gx": 0.07644668221473694, "rowwise_fwd": 0.05840510129928589, "rowwise_bwd": 0.06359070539474487, "global_fwd": 0.057831406593322754, "global_bwd": 0.06148591637611389, "x_quantize_rowwise": 0.03434717655181885, "g_quantize_rowwise": 0.03361701965332031, "w_quantize_rowwise": 0.03209337592124939, "w_quantize_colwise_transpose": 0.09028613567352295, "w_quantize_global": 0.0944770872592926, "w_quantize_global_transpose": 0.0994168221950531, "cast_x": 0.03769621253013611, "cast_g": 0.012010335922241211, "cast_w": 0.01600012183189392, "time_standard": 0.24741888046264648, "time_rowwise": 0.39232149720191956, "time_global": 0.4611574113368988} +{"repeat": 64, "batch_size": 4096, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.14450401067733765, "standard_gw": 0.14326348900794983, "standard_gx": 0.14762207865715027, "rowwise_fwd": 0.10525062680244446, "rowwise_bwd": 0.09800493717193604, "global_fwd": 0.10229647159576416, "global_bwd": 0.09718164801597595, "x_quantize_rowwise": 0.03429874777793884, "g_quantize_rowwise": 0.04567950963973999, "w_quantize_rowwise": 0.03365054726600647, "w_quantize_colwise_transpose": 0.08654966950416565, "w_quantize_global": 0.09663775563240051, "w_quantize_global_transpose": 0.10383129119873047, "cast_x": 0.01605972647666931, "cast_g": 0.08305534720420837, "cast_w": 0.01624971628189087, "time_standard": 0.43538957834243774, "time_rowwise": 0.5466975271701813, "time_global": 0.6231889128684998} +{"repeat": 64, "batch_size": 4096, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.14496594667434692, "standard_gw": 0.1412704586982727, "standard_gx": 0.14446303248405457, "rowwise_fwd": 0.10041892528533936, "rowwise_bwd": 0.10674074292182922, "global_fwd": 0.09856373071670532, "global_bwd": 0.10319426655769348, "x_quantize_rowwise": 0.045571476221084595, "g_quantize_rowwise": 0.03273040056228638, "w_quantize_rowwise": 0.033464282751083374, "w_quantize_colwise_transpose": 0.09154900908470154, "w_quantize_global": 0.0964440405368805, "w_quantize_global_transpose": 0.1031048595905304, "cast_x": 0.0835023820400238, "cast_g": 0.016242265701293945, "cast_w": 0.016283243894577026, "time_standard": 0.4306994378566742, "time_rowwise": 0.5517452955245972, "time_global": 0.6208792328834534} +{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.28106942772865295, "standard_gw": 0.2841465175151825, "standard_gx": 0.301852822303772, "rowwise_fwd": 0.19879266619682312, "rowwise_bwd": 0.16228482127189636, "global_fwd": 0.19488856196403503, "global_bwd": 0.1607760787010193, "x_quantize_rowwise": 0.033974647521972656, "g_quantize_rowwise": 0.08221715688705444, "w_quantize_rowwise": 0.03248825669288635, "w_quantize_colwise_transpose": 0.08646398782730103, "w_quantize_global": 0.0939294695854187, "w_quantize_global_transpose": 0.09895861148834229, "cast_x": 0.03753975033760071, "cast_g": 0.15900656580924988, "cast_w": 0.01603737473487854, "time_standard": 0.8670687675476074, "time_rowwise": 0.8803680539131165, "time_global": 0.9488910436630249} +{"repeat": 64, "batch_size": 8192, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.26415660977363586, "standard_gw": 0.2679601311683655, "standard_gx": 0.30617788434028625, "rowwise_fwd": 0.180121511220932, "rowwise_bwd": 0.21555647253990173, "global_fwd": 0.17506256699562073, "global_bwd": 0.2116672694683075, "x_quantize_rowwise": 0.08289515972137451, "g_quantize_rowwise": 0.033795833587646484, "w_quantize_rowwise": 0.03366544842720032, "w_quantize_colwise_transpose": 0.09965524077415466, "w_quantize_global": 0.09595602750778198, "w_quantize_global_transpose": 0.1024976372718811, "cast_x": 0.1602955162525177, "cast_g": 0.03787502646446228, "cast_w": 0.016216188669204712, "time_standard": 0.8382946252822876, "time_rowwise": 0.9136497974395752, "time_global": 0.9698346257209778} +{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.5719438195228577, "standard_gw": 0.524863600730896, "standard_gx": 0.6005167961120605, "rowwise_fwd": 0.3750324249267578, "rowwise_bwd": 0.28166547417640686, "global_fwd": 0.3674700856208801, "global_bwd": 0.2798214554786682, "x_quantize_rowwise": 0.04655122756958008, "g_quantize_rowwise": 0.1555122435092926, "w_quantize_rowwise": 0.03437697887420654, "w_quantize_colwise_transpose": 0.08634477853775024, "w_quantize_global": 0.09759142994880676, "w_quantize_global_transpose": 0.10081753134727478, "cast_x": 0.0828765332698822, "cast_g": 0.31184032559394836, "cast_w": 0.016063451766967773, "time_standard": 1.6973242163658142, "time_rowwise": 1.5043467283248901, "time_global": 1.5726275742053986} +{"repeat": 64, "batch_size": 16384, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.5423910915851593, "standard_gw": 0.5674734711647034, "standard_gx": 0.5907565355300903, "rowwise_fwd": 0.3149174153804779, "rowwise_bwd": 0.3899820148944855, "global_fwd": 0.2909451723098755, "global_bwd": 0.3783814609050751, "x_quantize_rowwise": 0.15584751963615417, "g_quantize_rowwise": 0.04688650369644165, "w_quantize_rowwise": 0.031463801860809326, "w_quantize_colwise_transpose": 0.09072571992874146, "w_quantize_global": 0.09774044156074524, "w_quantize_global_transpose": 0.10405108332633972, "cast_x": 0.3111511468887329, "cast_g": 0.08282437920570374, "cast_w": 0.015992671251296997, "time_standard": 1.700621098279953, "time_rowwise": 1.5972964465618134, "time_global": 1.6413256525993347} +{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 1.2115389108657837, "standard_gw": 1.1259466409683228, "standard_gx": 1.1027492582798004, "rowwise_fwd": 0.7407031953334808, "rowwise_bwd": 0.5539208650588989, "global_fwd": 0.7214657962322235, "global_bwd": 0.5515590310096741, "x_quantize_rowwise": 0.08765608072280884, "g_quantize_rowwise": 0.3022328019142151, "w_quantize_rowwise": 0.03347545862197876, "w_quantize_colwise_transpose": 0.08694455027580261, "w_quantize_global": 0.09706243872642517, "w_quantize_global_transpose": 0.10102614760398865, "cast_x": 0.1592189073562622, "cast_g": 0.6166175007820129, "cast_w": 0.01607835292816162, "time_standard": 3.440234810113907, "time_rowwise": 2.930879592895508, "time_global": 2.986948937177658} +{"repeat": 64, "batch_size": 32768, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 1.1010989546775818, "standard_gw": 1.1352524161338806, "standard_gx": 1.1676251888275146, "rowwise_fwd": 0.5864761769771576, "rowwise_bwd": 0.7485374808311462, "global_fwd": 0.5547590553760529, "global_bwd": 0.7249303162097931, "x_quantize_rowwise": 0.3021731972694397, "g_quantize_rowwise": 0.08751824498176575, "w_quantize_rowwise": 0.033952295780181885, "w_quantize_colwise_transpose": 0.09011104702949524, "w_quantize_global": 0.09443238377571106, "w_quantize_global_transpose": 0.10376051068305969, "cast_x": 0.6167255342006683, "cast_g": 0.15922263264656067, "cast_w": 0.016070902347564697, "time_standard": 3.403976559638977, "time_rowwise": 2.984020859003067, "time_global": 3.0028261244297028} +{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 2.472013235092163, "standard_gw": 2.218998968601227, "standard_gx": 2.2116564214229584, "rowwise_fwd": 1.466125249862671, "rowwise_bwd": 1.0577328503131866, "global_fwd": 1.431729644536972, "global_bwd": 1.0476894676685333, "x_quantize_rowwise": 0.16929209232330322, "g_quantize_rowwise": 0.5952082574367523, "w_quantize_rowwise": 0.032100826501846313, "w_quantize_colwise_transpose": 0.08670613169670105, "w_quantize_global": 0.09590759873390198, "w_quantize_global_transpose": 0.10358169674873352, "cast_x": 0.31175464391708374, "cast_g": 1.2264922261238098, "cast_w": 0.016067177057266235, "time_standard": 6.902668625116348, "time_rowwise": 5.626164376735687, "time_global": 5.662407726049423} +{"repeat": 64, "batch_size": 65536, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 2.181064337491989, "standard_gw": 2.2256113588809967, "standard_gx": 2.3229196667671204, "rowwise_fwd": 1.0886266827583313, "rowwise_bwd": 1.4654062688350677, "global_fwd": 1.0472461581230164, "global_bwd": 1.433148980140686, "x_quantize_rowwise": 0.5954094231128693, "g_quantize_rowwise": 0.16921386122703552, "w_quantize_rowwise": 0.03442913293838501, "w_quantize_colwise_transpose": 0.09007751941680908, "w_quantize_global": 0.09575113654136658, "w_quantize_global_transpose": 0.10503828525543213, "cast_x": 1.2264810502529144, "cast_g": 0.3119036555290222, "cast_w": 0.01605600118637085, "time_standard": 6.729595363140106, "time_rowwise": 5.668774247169495, "time_global": 5.671419203281403} +{"repeat": 64, "batch_size": 1024, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 0.08157268166542053, "standard_gw": 0.07601454854011536, "standard_gx": 0.09059160947799683, "rowwise_fwd": 0.053066760301589966, "rowwise_bwd": 0.04787370562553406, "global_fwd": 0.05243346095085144, "global_bwd": 0.04809349775314331, "x_quantize_rowwise": 0.02571195363998413, "g_quantize_rowwise": 0.025898218154907227, "w_quantize_rowwise": 0.02714991569519043, "w_quantize_colwise_transpose": 0.19773468375205994, "w_quantize_global": 0.07273256778717041, "w_quantize_global_transpose": 0.08068978786468506, "cast_x": 0.008046627044677734, "cast_g": 0.0252649188041687, "cast_w": 0.0393986701965332, "time_standard": 0.24817883968353271, "time_rowwise": 0.4534497857093811, "time_global": 0.38157403469085693} +{"repeat": 64, "batch_size": 1024, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 0.09134411811828613, "standard_gw": 0.07602199912071228, "standard_gx": 0.09555742144584656, "rowwise_fwd": 0.047691166400909424, "rowwise_bwd": 0.05320459604263306, "global_fwd": 0.04759058356285095, "global_bwd": 0.0521540641784668, "x_quantize_rowwise": 0.025313347578048706, "g_quantize_rowwise": 0.025119632482528687, "w_quantize_rowwise": 0.0269375741481781, "w_quantize_colwise_transpose": 0.1857280731201172, "w_quantize_global": 0.07451698184013367, "w_quantize_global_transpose": 0.08009746670722961, "cast_x": 0.02547726035118103, "cast_g": 0.007897615432739258, "cast_w": 0.039536505937576294, "time_standard": 0.26292353868484497, "time_rowwise": 0.44001638889312744, "time_global": 0.3808140754699707} +{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 4.940010607242584, "standard_gw": 4.434864968061447, "standard_gx": 4.4097937643527985, "rowwise_fwd": 2.9467344284057617, "rowwise_bwd": 2.09181010723114, "global_fwd": 2.8806477785110474, "global_bwd": 2.0816922187805176, "x_quantize_rowwise": 0.33279508352279663, "g_quantize_rowwise": 1.1817067861557007, "w_quantize_rowwise": 0.03306567668914795, "w_quantize_colwise_transpose": 0.08666515350341797, "w_quantize_global": 0.0957287847995758, "w_quantize_global_transpose": 0.10242313146591187, "cast_x": 0.6165988743305206, "cast_g": 2.446405589580536, "cast_w": 0.016100704669952393, "time_standard": 13.78466933965683, "time_rowwise": 11.107642203569412, "time_global": 11.109858751296997} +{"repeat": 64, "batch_size": 131072, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 4.293464124202728, "standard_gw": 4.461295902729034, "standard_gx": 4.638340324163437, "rowwise_fwd": 2.116892486810684, "rowwise_bwd": 2.9479674994945526, "global_fwd": 2.0760856568813324, "global_bwd": 2.8755851089954376, "x_quantize_rowwise": 1.1818408966064453, "g_quantize_rowwise": 0.33276528120040894, "w_quantize_rowwise": 0.03287568688392639, "w_quantize_colwise_transpose": 0.09038299322128296, "w_quantize_global": 0.09598955512046814, "w_quantize_global_transpose": 0.100649893283844, "cast_x": 2.4467408657073975, "cast_g": 0.6165951490402222, "cast_w": 0.016082078218460083, "time_standard": 13.3931003510952, "time_rowwise": 11.164020746946335, "time_global": 11.12421229481697} +{"repeat": 64, "batch_size": 2048, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 0.1699887216091156, "standard_gw": 0.14045089483261108, "standard_gx": 0.17407909035682678, "rowwise_fwd": 0.10082125663757324, "rowwise_bwd": 0.08344277739524841, "global_fwd": 0.09941309690475464, "global_bwd": 0.08352473378181458, "x_quantize_rowwise": 0.025317072868347168, "g_quantize_rowwise": 0.03849714994430542, "w_quantize_rowwise": 0.02596527338027954, "w_quantize_colwise_transpose": 0.19767135381698608, "w_quantize_global": 0.07257238030433655, "w_quantize_global_transpose": 0.08127838373184204, "cast_x": 0.012032687664031982, "cast_g": 0.06345659494400024, "cast_w": 0.03953278064727783, "time_standard": 0.48451870679855347, "time_rowwise": 0.612165778875351, "time_global": 0.5410537123680115} +{"repeat": 64, "batch_size": 2048, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 0.14855340123176575, "standard_gw": 0.15553459525108337, "standard_gx": 0.16282498836517334, "rowwise_fwd": 0.09259581565856934, "rowwise_bwd": 0.11080875992774963, "global_fwd": 0.09166449308395386, "global_bwd": 0.10796263813972473, "x_quantize_rowwise": 0.03939121961593628, "g_quantize_rowwise": 0.025227665901184082, "w_quantize_rowwise": 0.027202069759368896, "w_quantize_colwise_transpose": 0.1940988004207611, "w_quantize_global": 0.07397681474685669, "w_quantize_global_transpose": 0.08178502321243286, "cast_x": 0.065632164478302, "cast_g": 0.01268833875656128, "cast_w": 0.04057586193084717, "time_standard": 0.46691298484802246, "time_rowwise": 0.6448589265346527, "time_global": 0.5755424499511719} +{"repeat": 64, "batch_size": 4096, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 0.32291561365127563, "standard_gw": 0.2875030040740967, "standard_gx": 0.3379322588443756, "rowwise_fwd": 0.19295886158943176, "rowwise_bwd": 0.16265735030174255, "global_fwd": 0.19031018018722534, "global_bwd": 0.16187503933906555, "x_quantize_rowwise": 0.02730637788772583, "g_quantize_rowwise": 0.06797909736633301, "w_quantize_rowwise": 0.02642720937728882, "w_quantize_colwise_transpose": 0.19745901226997375, "w_quantize_global": 0.07253512740135193, "w_quantize_global_transpose": 0.08047744631767273, "cast_x": 0.022336840629577637, "cast_g": 0.1209154725074768, "cast_w": 0.039268285036087036, "time_standard": 0.9483508765697479, "time_rowwise": 0.9622909128665924, "time_global": 0.8879862725734711} +{"repeat": 64, "batch_size": 4096, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 0.3019683063030243, "standard_gw": 0.288400799036026, "standard_gx": 0.3154948353767395, "rowwise_fwd": 0.18264353275299072, "rowwise_bwd": 0.2075284719467163, "global_fwd": 0.17072632908821106, "global_bwd": 0.1960061490535736, "x_quantize_rowwise": 0.06893649697303772, "g_quantize_rowwise": 0.02561509609222412, "w_quantize_rowwise": 0.026594847440719604, "w_quantize_colwise_transpose": 0.18575787544250488, "w_quantize_global": 0.07266923785209656, "w_quantize_global_transpose": 0.08060410618782043, "cast_x": 0.12182071805000305, "cast_g": 0.022590160369873047, "cast_w": 0.04000961780548096, "time_standard": 0.9058639407157898, "time_rowwise": 0.9854771196842194, "time_global": 0.9029582142829895} +{"repeat": 64, "batch_size": 8192, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 0.6489232182502747, "standard_gw": 0.5987770855426788, "standard_gx": 0.6644465029239655, "rowwise_fwd": 0.35867467522621155, "rowwise_bwd": 0.31855329871177673, "global_fwd": 0.353105366230011, "global_bwd": 0.31349435448646545, "x_quantize_rowwise": 0.03382191061973572, "g_quantize_rowwise": 0.12668967247009277, "w_quantize_rowwise": 0.02681836485862732, "w_quantize_colwise_transpose": 0.19756704568862915, "w_quantize_global": 0.07336586713790894, "w_quantize_global_transpose": 0.08036196231842041, "cast_x": 0.0583939254283905, "cast_g": 0.23520365357398987, "cast_w": 0.03935396671295166, "time_standard": 1.912146806716919, "time_rowwise": 1.660902053117752, "time_global": 1.579616218805313} +{"repeat": 64, "batch_size": 8192, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 0.5789436399936676, "standard_gw": 0.6130896508693695, "standard_gx": 0.6558857858181, "rowwise_fwd": 0.3464221954345703, "rowwise_bwd": 0.3650560975074768, "global_fwd": 0.3174394369125366, "global_bwd": 0.35758689045906067, "x_quantize_rowwise": 0.12686848640441895, "g_quantize_rowwise": 0.034302473068237305, "w_quantize_rowwise": 0.02745911478996277, "w_quantize_colwise_transpose": 0.1847483217716217, "w_quantize_global": 0.07192790508270264, "w_quantize_global_transpose": 0.08050352334976196, "cast_x": 0.23534893989562988, "cast_g": 0.05846098065376282, "cast_w": 0.03949552774429321, "time_standard": 1.847919076681137, "time_rowwise": 1.6979463398456573, "time_global": 1.6017183661460876} +{"repeat": 64, "batch_size": 1024, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.0573769211769104, "standard_gw": 0.061042606830596924, "standard_gx": 0.0783093273639679, "rowwise_fwd": 0.046797096729278564, "rowwise_bwd": 0.04620850086212158, "global_fwd": 0.04521384835243225, "global_bwd": 0.04425644874572754, "x_quantize_rowwise": 0.03257766366004944, "g_quantize_rowwise": 0.03449246287345886, "w_quantize_rowwise": 0.033657997846603394, "w_quantize_colwise_transpose": 0.1426301896572113, "w_quantize_global": 0.09257346391677856, "w_quantize_global_transpose": 0.10266527533531189, "cast_x": 0.011991709470748901, "cast_g": 0.020314007997512817, "cast_w": 0.027321279048919678, "time_standard": 0.19672885537147522, "time_rowwise": 0.39740651845932007, "time_global": 0.41282176971435547} +{"repeat": 64, "batch_size": 1024, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.07858872413635254, "standard_gw": 0.06122514605522156, "standard_gx": 0.05758553743362427, "rowwise_fwd": 0.04598498344421387, "rowwise_bwd": 0.04618242383003235, "global_fwd": 0.04597380757331848, "global_bwd": 0.046450644731521606, "x_quantize_rowwise": 0.03332272171974182, "g_quantize_rowwise": 0.033274292945861816, "w_quantize_rowwise": 0.0337548553943634, "w_quantize_colwise_transpose": 0.14807656407356262, "w_quantize_global": 0.09948387742042542, "w_quantize_global_transpose": 0.10120868682861328, "cast_x": 0.020120292901992798, "cast_g": 0.011488795280456543, "cast_w": 0.027466565370559692, "time_standard": 0.19739940762519836, "time_rowwise": 0.40182098746299744, "time_global": 0.420939177274704} +{"repeat": 64, "batch_size": 16384, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 1.3515166938304901, "standard_gw": 1.1536777019500732, "standard_gx": 1.224767416715622, "rowwise_fwd": 0.6912238895893097, "rowwise_bwd": 0.5562454462051392, "global_fwd": 0.67867711186409, "global_bwd": 0.5518943071365356, "x_quantize_rowwise": 0.06204098463058472, "g_quantize_rowwise": 0.24417787790298462, "w_quantize_rowwise": 0.025238841772079468, "w_quantize_colwise_transpose": 0.19756704568862915, "w_quantize_global": 0.07240846753120422, "w_quantize_global_transpose": 0.08046254515647888, "cast_x": 0.11138245463371277, "cast_g": 0.4637613892555237, "cast_w": 0.03935769200325012, "time_standard": 3.7299618124961853, "time_rowwise": 2.9301717877388, "time_global": 2.8433389961719513} +{"repeat": 64, "batch_size": 16384, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 1.2090615928173065, "standard_gw": 1.1396333575248718, "standard_gx": 1.2223869562149048, "rowwise_fwd": 0.5849376320838928, "rowwise_bwd": 0.6985403597354889, "global_fwd": 0.5565173923969269, "global_bwd": 0.6789751350879669, "x_quantize_rowwise": 0.2445802092552185, "g_quantize_rowwise": 0.06200745701789856, "w_quantize_rowwise": 0.027727335691452026, "w_quantize_colwise_transpose": 0.18501654267311096, "w_quantize_global": 0.07182732224464417, "w_quantize_global_transpose": 0.08069723844528198, "cast_x": 0.4638172686100006, "cast_g": 0.11136755347251892, "cast_w": 0.039517879486083984, "time_standard": 3.571081906557083, "time_rowwise": 2.9424428939819336, "time_global": 2.834238111972809} +{"repeat": 64, "batch_size": 32768, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 2.683013677597046, "standard_gw": 2.2987723350524902, "standard_gx": 2.4510622024536133, "rowwise_fwd": 1.359008252620697, "rowwise_bwd": 1.1018887162208557, "global_fwd": 1.3311207294464111, "global_bwd": 1.0954029858112335, "x_quantize_rowwise": 0.11804327368736267, "g_quantize_rowwise": 0.479232519865036, "w_quantize_rowwise": 0.026308000087738037, "w_quantize_colwise_transpose": 0.1975223422050476, "w_quantize_global": 0.07223710417747498, "w_quantize_global_transpose": 0.08019432425498962, "cast_x": 0.2161264419555664, "cast_g": 0.9207837283611298, "cast_w": 0.03929063677787781, "time_standard": 7.432848215103149, "time_rowwise": 5.580775439739227, "time_global": 5.475003272294998} +{"repeat": 64, "batch_size": 2048, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.11088326573371887, "standard_gw": 0.10994821786880493, "standard_gx": 0.12367218732833862, "rowwise_fwd": 0.07392093539237976, "rowwise_bwd": 0.07127970457077026, "global_fwd": 0.0730752944946289, "global_bwd": 0.07089227437973022, "x_quantize_rowwise": 0.03361701965332031, "g_quantize_rowwise": 0.03525242209434509, "w_quantize_rowwise": 0.03341585397720337, "w_quantize_colwise_transpose": 0.14318525791168213, "w_quantize_global": 0.09704753756523132, "w_quantize_global_transpose": 0.10221078991889954, "cast_x": 0.012002885341644287, "cast_g": 0.05240738391876221, "cast_w": 0.027313828468322754, "time_standard": 0.3445036709308624, "time_rowwise": 0.5006194114685059, "time_global": 0.5220435559749603} +{"repeat": 64, "batch_size": 32768, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 2.4625882506370544, "standard_gw": 2.421922981739044, "standard_gx": 2.380847930908203, "rowwise_fwd": 1.1231191456317902, "rowwise_bwd": 1.360483467578888, "global_fwd": 1.0947436094284058, "global_bwd": 1.3314113020896912, "x_quantize_rowwise": 0.4795975983142853, "g_quantize_rowwise": 0.11777132749557495, "w_quantize_rowwise": 0.02699345350265503, "w_quantize_colwise_transpose": 0.18484890460968018, "w_quantize_global": 0.07201358675956726, "w_quantize_global_transpose": 0.0803135335445404, "cast_x": 0.920858234167099, "cast_g": 0.21616369485855103, "cast_w": 0.03937259316444397, "time_standard": 7.265359163284302, "time_rowwise": 5.714736878871918, "time_global": 5.597773939371109} +{"repeat": 64, "batch_size": 2048, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.12437254190444946, "standard_gw": 0.11018291115760803, "standard_gx": 0.10970607399940491, "rowwise_fwd": 0.07167831063270569, "rowwise_bwd": 0.07583573460578918, "global_fwd": 0.07314234972000122, "global_bwd": 0.07501617074012756, "x_quantize_rowwise": 0.035624951124191284, "g_quantize_rowwise": 0.0333636999130249, "w_quantize_rowwise": 0.03264099359512329, "w_quantize_colwise_transpose": 0.14795735478401184, "w_quantize_global": 0.09621679782867432, "w_quantize_global_transpose": 0.10380148887634277, "cast_x": 0.05278363823890686, "cast_g": 0.01249462366104126, "cast_w": 0.02767890691757202, "time_standard": 0.3442615270614624, "time_rowwise": 0.5072839558124542, "time_global": 0.5273483693599701} +{"repeat": 64, "batch_size": 4096, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.21922588348388672, "standard_gw": 0.20731613039970398, "standard_gx": 0.23101642727851868, "rowwise_fwd": 0.1423358917236328, "rowwise_bwd": 0.1195073127746582, "global_fwd": 0.1401938498020172, "global_bwd": 0.11940300464630127, "x_quantize_rowwise": 0.03353878855705261, "g_quantize_rowwise": 0.06387382745742798, "w_quantize_rowwise": 0.03428757190704346, "w_quantize_colwise_transpose": 0.14376267790794373, "w_quantize_global": 0.09389594197273254, "w_quantize_global_transpose": 0.10196119546890259, "cast_x": 0.020060688257217407, "cast_g": 0.10236725211143494, "cast_w": 0.02732500433921814, "time_standard": 0.6575584411621094, "time_rowwise": 0.7446222007274628, "time_global": 0.7601827383041382} +{"repeat": 64, "batch_size": 4096, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.20026043057441711, "standard_gw": 0.21172687411308289, "standard_gx": 0.2276189625263214, "rowwise_fwd": 0.12956932187080383, "rowwise_bwd": 0.15310943126678467, "global_fwd": 0.12427568435668945, "global_bwd": 0.14432892203330994, "x_quantize_rowwise": 0.06471946835517883, "g_quantize_rowwise": 0.03309175372123718, "w_quantize_rowwise": 0.03242120146751404, "w_quantize_colwise_transpose": 0.14733895659446716, "w_quantize_global": 0.09280815720558167, "w_quantize_global_transpose": 0.10265037417411804, "cast_x": 0.10267645120620728, "cast_g": 0.020150095224380493, "cast_w": 0.027399510145187378, "time_standard": 0.6396062672138214, "time_rowwise": 0.7719770073890686, "time_global": 0.773601233959198} +{"repeat": 64, "batch_size": 65536, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 5.324859172105789, "standard_gw": 4.977177828550339, "standard_gx": 4.468705505132675, "rowwise_fwd": 2.7004145085811615, "rowwise_bwd": 2.121664583683014, "global_fwd": 2.648312598466873, "global_bwd": 2.111390233039856, "x_quantize_rowwise": 0.22934377193450928, "g_quantize_rowwise": 0.9496547281742096, "w_quantize_rowwise": 0.02555176615715027, "w_quantize_colwise_transpose": 0.1977868378162384, "w_quantize_global": 0.0727437436580658, "w_quantize_global_transpose": 0.08098781108856201, "cast_x": 0.4259459674358368, "cast_g": 1.8352754414081573, "cast_w": 0.039637088775634766, "time_standard": 14.770742505788803, "time_rowwise": 11.201594024896622, "time_global": 11.069610714912415} +{"repeat": 64, "batch_size": 8192, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.49151480197906494, "standard_gw": 0.4681535065174103, "standard_gx": 0.42366236448287964, "rowwise_fwd": 0.2766512334346771, "rowwise_bwd": 0.2083033323287964, "global_fwd": 0.2709813416004181, "global_bwd": 0.20718947052955627, "x_quantize_rowwise": 0.034555792808532715, "g_quantize_rowwise": 0.11969730257987976, "w_quantize_rowwise": 0.03300607204437256, "w_quantize_colwise_transpose": 0.14345720410346985, "w_quantize_global": 0.09280070662498474, "w_quantize_global_transpose": 0.10214745998382568, "cast_x": 0.052288174629211426, "cast_g": 0.19747763872146606, "cast_w": 0.027339905500411987, "time_standard": 1.3833306729793549, "time_rowwise": 1.2838244438171387, "time_global": 1.2955255806446075} +{"repeat": 64, "batch_size": 8192, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.39635971188545227, "standard_gw": 0.44353678822517395, "standard_gx": 0.4724152386188507, "rowwise_fwd": 0.22813305258750916, "rowwise_bwd": 0.2868436276912689, "global_fwd": 0.2119205892086029, "global_bwd": 0.2749413251876831, "x_quantize_rowwise": 0.12082979083061218, "g_quantize_rowwise": 0.03444403409957886, "w_quantize_rowwise": 0.03444403409957886, "w_quantize_colwise_transpose": 0.14675036072731018, "w_quantize_global": 0.09495392441749573, "w_quantize_global_transpose": 0.1009330153465271, "cast_x": 0.19745156168937683, "cast_g": 0.05227327346801758, "cast_w": 0.027336180210113525, "time_standard": 1.312311738729477, "time_rowwise": 1.294981688261032, "time_global": 1.2815594673156738} +{"repeat": 64, "batch_size": 16384, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 1.0207034647464752, "standard_gw": 0.897720456123352, "standard_gx": 0.8374936878681183, "rowwise_fwd": 0.5457103252410889, "rowwise_bwd": 0.4088357090950012, "global_fwd": 0.5308091640472412, "global_bwd": 0.40555745363235474, "x_quantize_rowwise": 0.05984678864479065, "g_quantize_rowwise": 0.2306811511516571, "w_quantize_rowwise": 0.0334717333316803, "w_quantize_colwise_transpose": 0.14356523752212524, "w_quantize_global": 0.09340420365333557, "w_quantize_global_transpose": 0.09996071457862854, "cast_x": 0.10207295417785645, "cast_g": 0.3880411386489868, "cast_w": 0.027671456336975098, "time_standard": 2.7559176087379456, "time_rowwise": 2.3198314011096954, "time_global": 2.31797993183136} +{"repeat": 64, "batch_size": 65536, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 4.502948373556137, "standard_gw": 4.418112337589264, "standard_gx": 4.748217761516571, "rowwise_fwd": 2.1329298615455627, "rowwise_bwd": 2.6968345046043396, "global_fwd": 2.102244645357132, "global_bwd": 2.6461556553840637, "x_quantize_rowwise": 0.9493157267570496, "g_quantize_rowwise": 0.2290569245815277, "w_quantize_rowwise": 0.02551451325416565, "w_quantize_colwise_transpose": 0.18491223454475403, "w_quantize_global": 0.07426366209983826, "w_quantize_global_transpose": 0.08058920502662659, "cast_x": 1.8352717161178589, "cast_g": 0.425681471824646, "cast_w": 0.039402395486831665, "time_standard": 13.669278472661972, "time_rowwise": 10.636676102876663, "time_global": 10.499738156795502} +{"repeat": 64, "batch_size": 16384, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.8179470896720886, "standard_gw": 0.8687414228916168, "standard_gx": 0.9276494383811951, "rowwise_fwd": 0.4481859505176544, "rowwise_bwd": 0.5557462573051453, "global_fwd": 0.4100687801837921, "global_bwd": 0.5317367613315582, "x_quantize_rowwise": 0.2301819622516632, "g_quantize_rowwise": 0.05963817238807678, "w_quantize_rowwise": 0.033523887395858765, "w_quantize_colwise_transpose": 0.14462321996688843, "w_quantize_global": 0.094633549451828, "w_quantize_global_transpose": 0.10088086128234863, "cast_x": 0.3879927098751068, "cast_g": 0.10205060243606567, "cast_w": 0.02714991569519043, "time_standard": 2.6143379509449005, "time_rowwise": 2.3406408727169037, "time_global": 2.295881509780884} +{"repeat": 64, "batch_size": 32768, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 2.0698904991149902, "standard_gw": 1.7200261354446411, "standard_gx": 1.663345843553543, "rowwise_fwd": 1.0664835572242737, "rowwise_bwd": 0.8059032261371613, "global_fwd": 1.0454729199409485, "global_bwd": 0.801432877779007, "x_quantize_rowwise": 0.1127384603023529, "g_quantize_rowwise": 0.4529319703578949, "w_quantize_rowwise": 0.03398582339286804, "w_quantize_colwise_transpose": 0.14343857765197754, "w_quantize_global": 0.09441003203392029, "w_quantize_global_transpose": 0.09993091225624084, "cast_x": 0.19744038581848145, "cast_g": 0.769149512052536, "cast_w": 0.02734735608100891, "time_standard": 5.453262478113174, "time_rowwise": 4.335507750511169, "time_global": 4.3269433081150055} +{"repeat": 64, "batch_size": 32768, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 2.758193761110306, "standard_gw": 1.6880109906196594, "standard_gx": 1.8163062632083893, "rowwise_fwd": 0.8343160152435303, "rowwise_bwd": 1.073598861694336, "global_fwd": 0.8045099675655365, "global_bwd": 1.0492689907550812, "x_quantize_rowwise": 0.453021377325058, "g_quantize_rowwise": 0.11304020881652832, "w_quantize_rowwise": 0.0337064266204834, "w_quantize_colwise_transpose": 0.1452416181564331, "w_quantize_global": 0.09451434016227722, "w_quantize_global_transpose": 0.0998079776763916, "cast_x": 0.769101083278656, "cast_g": 0.19731372594833374, "cast_w": 0.027332454919815063, "time_standard": 6.2625110149383545, "time_rowwise": 4.340935498476028, "time_global": 4.302173852920532} +{"repeat": 64, "batch_size": 131072, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 10.728541761636734, "standard_gw": 9.228862822055817, "standard_gx": 8.837487548589706, "rowwise_fwd": 5.4414160549640656, "rowwise_bwd": 4.186157137155533, "global_fwd": 5.329187959432602, "global_bwd": 4.150416702032089, "x_quantize_rowwise": 0.4517659544944763, "g_quantize_rowwise": 1.890372484922409, "w_quantize_rowwise": 0.027563422918319702, "w_quantize_colwise_transpose": 0.1980513334274292, "w_quantize_global": 0.0733695924282074, "w_quantize_global_transpose": 0.08009746670722961, "cast_x": 0.8449330925941467, "cast_g": 3.6641769111156464, "cast_w": 0.03945454955101013, "time_standard": 28.794892132282257, "time_rowwise": 21.42418920993805, "time_global": 21.20407298207283} +{"repeat": 64, "batch_size": 65536, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 4.127204418182373, "standard_gw": 3.359321504831314, "standard_gx": 5.557261407375336, "rowwise_fwd": 2.1365806460380554, "rowwise_bwd": 1.6042962670326233, "global_fwd": 2.0923763513565063, "global_bwd": 1.5939176082611084, "x_quantize_rowwise": 0.21954253315925598, "g_quantize_rowwise": 0.8971206843852997, "w_quantize_rowwise": 0.03357976675033569, "w_quantize_colwise_transpose": 0.1431293785572052, "w_quantize_global": 0.10574981570243835, "w_quantize_global_transpose": 0.10281801223754883, "cast_x": 0.38795173168182373, "cast_g": 1.5318207442760468, "cast_w": 0.027142465114593506, "time_standard": 13.043787330389023, "time_rowwise": 8.39357078075409, "time_global": 8.370846509933472} +{"repeat": 64, "batch_size": 65536, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 5.576469004154205, "standard_gw": 3.361724317073822, "standard_gx": 3.6300085484981537, "rowwise_fwd": 1.6183294355869293, "rowwise_bwd": 2.1462254226207733, "global_fwd": 1.5953555703163147, "global_bwd": 2.0915642380714417, "x_quantize_rowwise": 0.8973218500614166, "g_quantize_rowwise": 0.2197064459323883, "w_quantize_rowwise": 0.03402307629585266, "w_quantize_colwise_transpose": 0.14822185039520264, "w_quantize_global": 0.09706616401672363, "w_quantize_global_transpose": 0.10339170694351196, "cast_x": 1.5312805771827698, "cast_g": 0.3879964351654053, "cast_w": 0.0269375741481781, "time_standard": 12.568201869726181, "time_rowwise": 8.425552397966385, "time_global": 8.366130292415619} +{"repeat": 64, "batch_size": 131072, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 8.900497108697891, "standard_gw": 9.188394993543625, "standard_gx": 9.503517299890518, "rowwise_fwd": 4.189815372228622, "rowwise_bwd": 5.426768213510513, "global_fwd": 4.155576229095459, "global_bwd": 5.329132080078125, "x_quantize_rowwise": 1.8885880708694458, "g_quantize_rowwise": 0.45193731784820557, "w_quantize_rowwise": 0.025987625122070312, "w_quantize_colwise_transpose": 0.1842118799686432, "w_quantize_global": 0.07349997758865356, "w_quantize_global_transpose": 0.08074194192886353, "cast_x": 3.6639943718910217, "cast_g": 0.8447282016277313, "cast_w": 0.03973767161369324, "time_standard": 27.592409402132034, "time_rowwise": 21.355703473091125, "time_global": 21.167870610952377} +{"repeat": 64, "batch_size": 131072, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 8.2329623401165, "standard_gw": 6.799045950174332, "standard_gx": 6.893906742334366, "rowwise_fwd": 4.252739250659943, "rowwise_bwd": 3.2025352120399475, "global_fwd": 4.176046699285507, "global_bwd": 3.173377364873886, "x_quantize_rowwise": 0.43221935629844666, "g_quantize_rowwise": 1.7872042953968048, "w_quantize_rowwise": 0.03328174352645874, "w_quantize_colwise_transpose": 0.1431480050086975, "w_quantize_global": 0.09707733988761902, "w_quantize_global_transpose": 0.10161846876144409, "cast_x": 0.7692091166973114, "cast_g": 3.057178109884262, "cast_w": 0.027302652597427368, "time_standard": 21.9259150326252, "time_rowwise": 16.65017381310463, "time_global": 16.56658947467804} +{"repeat": 64, "batch_size": 131072, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 11.278409510850906, "standard_gw": 6.815284490585327, "standard_gx": 7.280956953763962, "rowwise_fwd": 3.206692636013031, "rowwise_bwd": 4.246953874826431, "global_fwd": 3.1801797449588776, "global_bwd": 4.169579595327377, "x_quantize_rowwise": 1.7862766981124878, "g_quantize_rowwise": 0.4329495131969452, "w_quantize_rowwise": 0.03413483500480652, "w_quantize_colwise_transpose": 0.14493241906166077, "w_quantize_global": 0.09881332516670227, "w_quantize_global_transpose": 0.10376423597335815, "cast_x": 3.057088702917099, "cast_g": 0.7693544030189514, "cast_w": 0.027261674404144287, "time_standard": 25.374650955200195, "time_rowwise": 16.66722446680069, "time_global": 16.586847603321075} +{"repeat": 64, "batch_size": 1024, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 0.11636316776275635, "standard_gw": 0.11816620826721191, "standard_gx": 0.11482089757919312, "rowwise_fwd": 0.08482113480567932, "rowwise_bwd": 0.06284937262535095, "global_fwd": 0.08296221494674683, "global_bwd": 0.061664730310440063, "x_quantize_rowwise": 0.026706606149673462, "g_quantize_rowwise": 0.025641173124313354, "w_quantize_rowwise": 0.03740563988685608, "w_quantize_colwise_transpose": 0.2965778112411499, "w_quantize_global": 0.11304393410682678, "w_quantize_global_transpose": 0.12390688061714172, "cast_x": 0.008635222911834717, "cast_g": 0.037532299757003784, "cast_w": 0.06856024265289307, "time_standard": 0.3493502736091614, "time_rowwise": 0.652167946100235, "time_global": 0.5520917475223541} +{"repeat": 64, "batch_size": 1024, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 0.11609122157096863, "standard_gw": 0.11704489588737488, "standard_gx": 0.11566653847694397, "rowwise_fwd": 0.06706640124320984, "rowwise_bwd": 0.09074807167053223, "global_fwd": 0.06621330976486206, "global_bwd": 0.0859871506690979, "x_quantize_rowwise": 0.027574598789215088, "g_quantize_rowwise": 0.02520531415939331, "w_quantize_rowwise": 0.04095584154129028, "w_quantize_colwise_transpose": 0.37036463618278503, "w_quantize_global": 0.11350959539413452, "w_quantize_global_transpose": 0.12202560901641846, "cast_x": 0.03780052065849304, "cast_g": 0.00860169529914856, "cast_w": 0.06864592432975769, "time_standard": 0.3488026559352875, "time_rowwise": 0.7389597594738007, "time_global": 0.5575604736804962} +{"repeat": 64, "batch_size": 2048, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 0.22610649466514587, "standard_gw": 0.2229548990726471, "standard_gx": 0.22150203585624695, "rowwise_fwd": 0.1421608030796051, "rowwise_bwd": 0.10771304368972778, "global_fwd": 0.13930723071098328, "global_bwd": 0.10715052485466003, "x_quantize_rowwise": 0.02812594175338745, "g_quantize_rowwise": 0.04733726382255554, "w_quantize_rowwise": 0.03758445382118225, "w_quantize_colwise_transpose": 0.29515475034713745, "w_quantize_global": 0.11344626545906067, "w_quantize_global_transpose": 0.12392178177833557, "cast_x": 0.013589859008789062, "cast_g": 0.08285418152809143, "cast_w": 0.06850436329841614, "time_standard": 0.6705634295940399, "time_rowwise": 0.8810311555862427, "time_global": 0.7822439074516296} +{"repeat": 64, "batch_size": 2048, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 0.20173192024230957, "standard_gw": 0.2351999282836914, "standard_gx": 0.24710968136787415, "rowwise_fwd": 0.12035667896270752, "rowwise_bwd": 0.153418630361557, "global_fwd": 0.11473894119262695, "global_bwd": 0.14553219079971313, "x_quantize_rowwise": 0.04762038588523865, "g_quantize_rowwise": 0.02557411789894104, "w_quantize_rowwise": 0.04055723547935486, "w_quantize_colwise_transpose": 0.32641738653182983, "w_quantize_global": 0.1138448715209961, "w_quantize_global_transpose": 0.12255832552909851, "cast_x": 0.08405372500419617, "cast_g": 0.013835728168487549, "cast_w": 0.06961449980735779, "time_standard": 0.6840415298938751, "time_rowwise": 0.9491443634033203, "time_global": 0.8050687611103058} +{"repeat": 64, "batch_size": 4096, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 0.48126280307769775, "standard_gw": 0.46824291348457336, "standard_gx": 0.45252591371536255, "rowwise_fwd": 0.2749897539615631, "rowwise_bwd": 0.2111680805683136, "global_fwd": 0.2689175307750702, "global_bwd": 0.2104043960571289, "x_quantize_rowwise": 0.02676248550415039, "g_quantize_rowwise": 0.0842660665512085, "w_quantize_rowwise": 0.037495046854019165, "w_quantize_colwise_transpose": 0.2952851355075836, "w_quantize_global": 0.11366978287696838, "w_quantize_global_transpose": 0.12461841106414795, "cast_x": 0.0283755362033844, "cast_g": 0.1590624451637268, "cast_w": 0.06854161620140076, "time_standard": 1.4020316302776337, "time_rowwise": 1.3982094824314117, "time_global": 1.2968815863132477} +{"repeat": 64, "batch_size": 4096, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 0.4076175391674042, "standard_gw": 0.45526400208473206, "standard_gx": 0.4996545612812042, "rowwise_fwd": 0.238761305809021, "rowwise_bwd": 0.2913624048233032, "global_fwd": 0.2149641513824463, "global_bwd": 0.2717897295951843, "x_quantize_rowwise": 0.0845976173877716, "g_quantize_rowwise": 0.0266246497631073, "w_quantize_rowwise": 0.04038959741592407, "w_quantize_colwise_transpose": 0.33299997448921204, "w_quantize_global": 0.11374801397323608, "w_quantize_global_transpose": 0.12202560901641846, "cast_x": 0.15895813703536987, "cast_g": 0.028312206268310547, "cast_w": 0.06841868162155151, "time_standard": 1.3625361025333405, "time_rowwise": 1.4699995517730713, "time_global": 1.2890137732028961} +{"repeat": 64, "batch_size": 8192, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 1.02214515209198, "standard_gw": 0.9412020444869995, "standard_gx": 0.883936882019043, "rowwise_fwd": 0.5209781229496002, "rowwise_bwd": 0.41617080569267273, "global_fwd": 0.5089044570922852, "global_bwd": 0.4142932593822479, "x_quantize_rowwise": 0.03763660788536072, "g_quantize_rowwise": 0.15798211097717285, "w_quantize_rowwise": 0.0375211238861084, "w_quantize_colwise_transpose": 0.2973228693008423, "w_quantize_global": 0.11317431926727295, "w_quantize_global_transpose": 0.12396648526191711, "cast_x": 0.0685863196849823, "cast_g": 0.311531126499176, "cast_w": 0.0685080885887146, "time_standard": 2.8472840785980225, "time_rowwise": 2.4088136851787567, "time_global": 2.2971592843532562} +{"repeat": 64, "batch_size": 8192, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 0.8539073169231415, "standard_gw": 0.9352751076221466, "standard_gx": 0.9567439556121826, "rowwise_fwd": 0.4599541425704956, "rowwise_bwd": 0.531073659658432, "global_fwd": 0.42063742876052856, "global_bwd": 0.5125999450683594, "x_quantize_rowwise": 0.1581348478794098, "g_quantize_rowwise": 0.03755837678909302, "w_quantize_rowwise": 0.04056468605995178, "w_quantize_colwise_transpose": 0.3295913338661194, "w_quantize_global": 0.11314079165458679, "w_quantize_global_transpose": 0.12153387069702148, "cast_x": 0.3114752471446991, "cast_g": 0.06850063800811768, "cast_w": 0.06839632987976074, "time_standard": 2.7459263801574707, "time_rowwise": 2.492152154445648, "time_global": 2.2988803684711456} +{"repeat": 64, "batch_size": 16384, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 2.0550191402435303, "standard_gw": 1.7850138247013092, "standard_gx": 1.7571337521076202, "rowwise_fwd": 1.026798039674759, "rowwise_bwd": 0.8242167532444, "global_fwd": 1.0042376816272736, "global_bwd": 0.8189938962459564, "x_quantize_rowwise": 0.0688992440700531, "g_quantize_rowwise": 0.3054179251194, "w_quantize_rowwise": 0.03757700324058533, "w_quantize_colwise_transpose": 0.2973712980747223, "w_quantize_global": 0.11324509978294373, "w_quantize_global_transpose": 0.12398511171340942, "cast_x": 0.13050436973571777, "cast_g": 0.6165280938148499, "cast_w": 0.06848573684692383, "time_standard": 5.59716671705246, "time_rowwise": 4.345294088125229, "time_global": 4.2197927832603455} +{"repeat": 64, "batch_size": 16384, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 1.79310142993927, "standard_gw": 1.7801076173782349, "standard_gx": 1.9140169024467468, "rowwise_fwd": 0.8629709482192993, "rowwise_bwd": 1.0353922843933105, "global_fwd": 0.8200556039810181, "global_bwd": 1.002725213766098, "x_quantize_rowwise": 0.30517578125, "g_quantize_rowwise": 0.06880238652229309, "w_quantize_rowwise": 0.040318816900253296, "w_quantize_colwise_transpose": 0.3413744270801544, "w_quantize_global": 0.11326000094413757, "w_quantize_global_transpose": 0.12197345495223999, "cast_x": 0.6162337958812714, "cast_g": 0.13053417205810547, "cast_w": 0.06848946213722229, "time_standard": 5.487225949764252, "time_rowwise": 4.4341422617435455, "time_global": 4.212100058794022} +{"repeat": 64, "batch_size": 32768, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 4.0736086666584015, "standard_gw": 3.595758229494095, "standard_gx": 3.7020929157733917, "rowwise_fwd": 2.0306408405303955, "rowwise_bwd": 1.635722815990448, "global_fwd": 1.9890740513801575, "global_bwd": 1.627359539270401, "x_quantize_rowwise": 0.13131648302078247, "g_quantize_rowwise": 0.6001107394695282, "w_quantize_rowwise": 0.03781542181968689, "w_quantize_colwise_transpose": 0.2975836396217346, "w_quantize_global": 0.11357292532920837, "w_quantize_global_transpose": 0.12416765093803406, "cast_x": 0.2544410526752472, "cast_g": 1.2265890836715698, "cast_w": 0.06866827607154846, "time_standard": 11.371459811925888, "time_rowwise": 8.32894816994667, "time_global": 8.181359618902206} +{"repeat": 64, "batch_size": 32768, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 3.525231033563614, "standard_gw": 3.489706665277481, "standard_gx": 3.9937011897563934, "rowwise_fwd": 1.6627348959445953, "rowwise_bwd": 2.0311400294303894, "global_fwd": 1.6270726919174194, "global_bwd": 1.988884061574936, "x_quantize_rowwise": 0.5999915301799774, "g_quantize_rowwise": 0.1310594379901886, "w_quantize_rowwise": 0.04043802618980408, "w_quantize_colwise_transpose": 0.32950565218925476, "w_quantize_global": 0.11298432946205139, "w_quantize_global_transpose": 0.12201443314552307, "cast_x": 1.2257546186447144, "cast_g": 0.25444477796554565, "cast_w": 0.06848573684692383, "time_standard": 11.008638888597488, "time_rowwise": 8.28457623720169, "time_global": 8.071713149547577} +{"repeat": 64, "batch_size": 65536, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 8.123598992824554, "standard_gw": 8.085217326879501, "standard_gx": 7.293816655874252, "rowwise_fwd": 4.07782569527626, "rowwise_bwd": 3.196723759174347, "global_fwd": 4.001103341579437, "global_bwd": 3.1843744218349457, "x_quantize_rowwise": 0.2560615539550781, "g_quantize_rowwise": 1.1893659830093384, "w_quantize_rowwise": 0.037297606468200684, "w_quantize_colwise_transpose": 0.29668211936950684, "w_quantize_global": 0.11358782649040222, "w_quantize_global_transpose": 0.12476742267608643, "cast_x": 0.5020052194595337, "cast_g": 2.4454034864902496, "cast_w": 0.0684782862663269, "time_standard": 23.502632975578308, "time_rowwise": 17.139174044132233, "time_global": 16.95447787642479} +{"repeat": 64, "batch_size": 65536, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 6.932958960533142, "standard_gw": 7.0609524846076965, "standard_gx": 7.460080087184906, "rowwise_fwd": 3.1809918582439423, "rowwise_bwd": 4.078391939401627, "global_fwd": 3.185112029314041, "global_bwd": 3.99089977145195, "x_quantize_rowwise": 1.1891834437847137, "g_quantize_rowwise": 0.25588274002075195, "w_quantize_rowwise": 0.0406019389629364, "w_quantize_colwise_transpose": 0.3389529883861542, "w_quantize_global": 0.11313334107398987, "w_quantize_global_transpose": 0.12241676449775696, "cast_x": 2.4446770548820496, "cast_g": 0.5022138357162476, "cast_w": 0.06857141852378845, "time_standard": 21.453991532325745, "time_rowwise": 16.14495739340782, "time_global": 15.9175805747509} +{"repeat": 64, "batch_size": 131072, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 16.38999581336975, "standard_gw": 15.075922012329102, "standard_gx": 14.479495584964752, "rowwise_fwd": 8.128684014081955, "rowwise_bwd": 6.41091912984848, "global_fwd": 7.977847009897232, "global_bwd": 6.362702697515488, "x_quantize_rowwise": 0.5057230591773987, "g_quantize_rowwise": 2.3681968450546265, "w_quantize_rowwise": 0.037435442209243774, "w_quantize_colwise_transpose": 0.29555708169937134, "w_quantize_global": 0.11360272765159607, "w_quantize_global_transpose": 0.12426823377609253, "cast_x": 0.997692346572876, "cast_g": 4.8848651349544525, "cast_w": 0.0685565173625946, "time_standard": 45.945413410663605, "time_rowwise": 32.82243758440018, "time_global": 32.528262585401535} +{"repeat": 64, "batch_size": 131072, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 14.838922768831253, "standard_gw": 15.112213790416718, "standard_gx": 14.869242906570435, "rowwise_fwd": 6.402213126420975, "rowwise_bwd": 8.132629096508026, "global_fwd": 6.36359304189682, "global_bwd": 7.9823993146419525, "x_quantize_rowwise": 2.367999404668808, "g_quantize_rowwise": 0.5056969821453094, "w_quantize_rowwise": 0.04053488373756409, "w_quantize_colwise_transpose": 0.3559887409210205, "w_quantize_global": 0.1136288046836853, "w_quantize_global_transpose": 0.125102698802948, "cast_x": 4.880473017692566, "cast_g": 0.9965412318706512, "cast_w": 0.06855279207229614, "time_standard": 44.820379465818405, "time_rowwise": 32.91727602481842, "time_global": 32.57063403725624} +{"repeat": 64, "batch_size": 1024, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 0.15426427125930786, "standard_gw": 0.14531239867210388, "standard_gx": 0.1703128218650818, "rowwise_fwd": 0.09618699550628662, "rowwise_bwd": 0.10633841156959534, "global_fwd": 0.09483471512794495, "global_bwd": 0.10636076331138611, "x_quantize_rowwise": 0.02434849739074707, "g_quantize_rowwise": 0.026009976863861084, "w_quantize_rowwise": 0.04366040229797363, "w_quantize_colwise_transpose": 0.34148991107940674, "w_quantize_global": 0.13587623834609985, "w_quantize_global_transpose": 0.14698877930641174, "cast_x": 0.009745359420776367, "cast_g": 0.03773719072341919, "cast_w": 0.08277222514152527, "time_standard": 0.46988949179649353, "time_rowwise": 0.7833465933799744, "time_global": 0.6797313690185547} +{"repeat": 64, "batch_size": 1024, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 0.16738846898078918, "standard_gw": 0.14199689030647278, "standard_gx": 0.15476346015930176, "rowwise_fwd": 0.11660531163215637, "rowwise_bwd": 0.1050308346748352, "global_fwd": 0.11050701141357422, "global_bwd": 0.09868666529655457, "x_quantize_rowwise": 0.02781301736831665, "g_quantize_rowwise": 0.024966895580291748, "w_quantize_rowwise": 0.047437846660614014, "w_quantize_colwise_transpose": 0.5995631217956543, "w_quantize_global": 0.1362822949886322, "w_quantize_global_transpose": 0.14807283878326416, "cast_x": 0.0377558171749115, "cast_g": 0.00973045825958252, "cast_w": 0.0828281044960022, "time_standard": 0.4641488194465637, "time_rowwise": 1.063413918018341, "time_global": 0.6883256137371063} +{"repeat": 64, "batch_size": 2048, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 0.2727396786212921, "standard_gw": 0.2711080014705658, "standard_gx": 0.3120154142379761, "rowwise_fwd": 0.16424059867858887, "rowwise_bwd": 0.17686933279037476, "global_fwd": 0.161685049533844, "global_bwd": 0.17517060041427612, "x_quantize_rowwise": 0.025484710931777954, "g_quantize_rowwise": 0.047635287046432495, "w_quantize_rowwise": 0.04380941390991211, "w_quantize_colwise_transpose": 0.3401711583137512, "w_quantize_global": 0.13605505228042603, "w_quantize_global_transpose": 0.14705583453178406, "cast_x": 0.01584365963935852, "cast_g": 0.08274242281913757, "cast_w": 0.08281320333480835, "time_standard": 0.855863094329834, "time_rowwise": 1.0693185031414032, "time_global": 0.9641945362091064} +{"repeat": 64, "batch_size": 2048, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 0.28916075825691223, "standard_gw": 0.29472261667251587, "standard_gx": 0.30096620321273804, "rowwise_fwd": 0.19618868827819824, "rowwise_bwd": 0.17556175589561462, "global_fwd": 0.18328800797462463, "global_bwd": 0.16647577285766602, "x_quantize_rowwise": 0.047441571950912476, "g_quantize_rowwise": 0.026609748601913452, "w_quantize_rowwise": 0.04766508936882019, "w_quantize_colwise_transpose": 0.6060972809791565, "w_quantize_global": 0.1363418996334076, "w_quantize_global_transpose": 0.14806538820266724, "cast_x": 0.08295103907585144, "cast_g": 0.015836209058761597, "cast_w": 0.08285045623779297, "time_standard": 0.8848495781421661, "time_rowwise": 1.3942867517471313, "time_global": 1.0029450058937073} +{"repeat": 64, "batch_size": 4096, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 0.6430819630622864, "standard_gw": 0.5622953176498413, "standard_gx": 0.5780421197414398, "rowwise_fwd": 0.318676233291626, "rowwise_bwd": 0.29438361525535583, "global_fwd": 0.31290948390960693, "global_bwd": 0.290747731924057, "x_quantize_rowwise": 0.027455389499664307, "g_quantize_rowwise": 0.08405372500419617, "w_quantize_rowwise": 0.04369765520095825, "w_quantize_colwise_transpose": 0.34110620617866516, "w_quantize_global": 0.1360774040222168, "w_quantize_global_transpose": 0.14697015285491943, "cast_x": 0.037614256143569946, "cast_g": 0.15922263264656067, "cast_w": 0.08288025856018066, "time_standard": 1.7834194004535675, "time_rowwise": 1.671668142080307, "time_global": 1.560509204864502} +{"repeat": 64, "batch_size": 4096, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 0.551275908946991, "standard_gw": 0.591665506362915, "standard_gx": 0.6067268550395966, "rowwise_fwd": 0.33493712544441223, "rowwise_bwd": 0.32918527722358704, "global_fwd": 0.29528141021728516, "global_bwd": 0.31659379601478577, "x_quantize_rowwise": 0.08441135287284851, "g_quantize_rowwise": 0.025656074285507202, "w_quantize_rowwise": 0.04745647311210632, "w_quantize_colwise_transpose": 0.5993843078613281, "w_quantize_global": 0.1359879970550537, "w_quantize_global_transpose": 0.14815106987953186, "cast_x": 0.15932321548461914, "cast_g": 0.037439167499542236, "cast_w": 0.08288398385047913, "time_standard": 1.7496682703495026, "time_rowwise": 2.0126961171627045, "time_global": 1.5977472066879272} +{"repeat": 64, "batch_size": 8192, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 1.2295916676521301, "standard_gw": 1.116037368774414, "standard_gx": 1.1164769530296326, "rowwise_fwd": 0.603698194026947, "rowwise_bwd": 0.5168020725250244, "global_fwd": 0.5922466516494751, "global_bwd": 0.5151033401489258, "x_quantize_rowwise": 0.0437907874584198, "g_quantize_rowwise": 0.157918781042099, "w_quantize_rowwise": 0.044032931327819824, "w_quantize_colwise_transpose": 0.34073740243911743, "w_quantize_global": 0.13559311628341675, "w_quantize_global_transpose": 0.14679506421089172, "cast_x": 0.08263811469078064, "cast_g": 0.3115162253379822, "cast_w": 0.08287280797958374, "time_standard": 3.4621059894561768, "time_rowwise": 2.8230175375938416, "time_global": 2.707485109567642} +{"repeat": 64, "batch_size": 8192, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 1.090865582227707, "standard_gw": 1.1468492448329926, "standard_gx": 1.1166594922542572, "rowwise_fwd": 0.5559474229812622, "rowwise_bwd": 0.6105974316596985, "global_fwd": 0.5200020968914032, "global_bwd": 0.592011958360672, "x_quantize_rowwise": 0.15802308917045593, "g_quantize_rowwise": 0.04357844591140747, "w_quantize_rowwise": 0.04709511995315552, "w_quantize_colwise_transpose": 0.5969703197479248, "w_quantize_global": 0.13620033860206604, "w_quantize_global_transpose": 0.148136168718338, "cast_x": 0.31115859746932983, "cast_g": 0.08263811469078064, "cast_w": 0.08268281817436218, "time_standard": 3.3543743193149567, "time_rowwise": 3.159061074256897, "time_global": 2.744801342487335} +{"repeat": 64, "batch_size": 16384, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 2.4665743112564087, "standard_gw": 2.1993443369865417, "standard_gx": 2.1993033587932587, "rowwise_fwd": 1.192428171634674, "rowwise_bwd": 1.023314893245697, "global_fwd": 1.1711902916431427, "global_bwd": 1.0202191770076752, "x_quantize_rowwise": 0.08077174425125122, "g_quantize_rowwise": 0.30520185828208923, "w_quantize_rowwise": 0.043783336877822876, "w_quantize_colwise_transpose": 0.339999794960022, "w_quantize_global": 0.13628602027893066, "w_quantize_global_transpose": 0.14696642756462097, "cast_x": 0.15902891755104065, "cast_g": 0.6164535880088806, "cast_w": 0.08285418152809143, "time_standard": 6.865222007036209, "time_rowwise": 5.184844136238098, "time_global": 5.059979856014252} +{"repeat": 64, "batch_size": 16384, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 2.1861791610717773, "standard_gw": 2.157818526029587, "standard_gx": 2.321537584066391, "rowwise_fwd": 1.0536126792430878, "rowwise_bwd": 1.1971630156040192, "global_fwd": 1.02127343416214, "global_bwd": 1.1707991361618042, "x_quantize_rowwise": 0.30522048473358154, "g_quantize_rowwise": 0.08065253496170044, "w_quantize_rowwise": 0.04741176962852478, "w_quantize_colwise_transpose": 0.5979575216770172, "w_quantize_global": 0.1362040638923645, "w_quantize_global_transpose": 0.14854222536087036, "cast_x": 0.6162486970424652, "cast_g": 0.1591891050338745, "cast_w": 0.08288398385047913, "time_standard": 6.665535271167755, "time_rowwise": 5.439836531877518, "time_global": 5.020510405302048} +{"repeat": 64, "batch_size": 32768, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 4.891645163297653, "standard_gw": 4.233300685882568, "standard_gx": 4.2071714997291565, "rowwise_fwd": 2.3616664111614227, "rowwise_bwd": 1.9419342279434204, "global_fwd": 2.3244209587574005, "global_bwd": 1.9598640501499176, "x_quantize_rowwise": 0.15483051538467407, "g_quantize_rowwise": 0.6008371710777283, "w_quantize_rowwise": 0.043839216232299805, "w_quantize_colwise_transpose": 0.3400743007659912, "w_quantize_global": 0.1362822949886322, "w_quantize_global_transpose": 0.14691054821014404, "cast_x": 0.31141936779022217, "cast_g": 1.2254081666469574, "cast_w": 0.08280202746391296, "time_standard": 13.332117348909378, "time_rowwise": 9.676482528448105, "time_global": 9.556446224451065} +{"repeat": 64, "batch_size": 32768, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 4.267625510692596, "standard_gw": 4.237007349729538, "standard_gx": 4.666488617658615, "rowwise_fwd": 1.9670464098453522, "rowwise_bwd": 2.362079918384552, "global_fwd": 1.9469596445560455, "global_bwd": 2.32585147023201, "x_quantize_rowwise": 0.6000921130180359, "g_quantize_rowwise": 0.15481188893318176, "w_quantize_rowwise": 0.04725530743598938, "w_quantize_colwise_transpose": 0.5976222455501556, "w_quantize_global": 0.13619661331176758, "w_quantize_global_transpose": 0.14815852046012878, "cast_x": 1.2261345982551575, "cast_g": 0.3117173910140991, "cast_w": 0.08279457688331604, "time_standard": 13.17112147808075, "time_rowwise": 9.965915232896805, "time_global": 9.549077600240707} +{"repeat": 64, "batch_size": 65536, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 9.787477552890778, "standard_gw": 8.533861488103867, "standard_gx": 8.979786187410355, "rowwise_fwd": 4.741787910461426, "rowwise_bwd": 3.871854394674301, "global_fwd": 4.674319177865982, "global_bwd": 3.9110779762268066, "x_quantize_rowwise": 0.3025829792022705, "g_quantize_rowwise": 1.1898204684257507, "w_quantize_rowwise": 0.043705105781555176, "w_quantize_colwise_transpose": 0.33997371792793274, "w_quantize_global": 0.13592839241027832, "w_quantize_global_transpose": 0.14724954962730408, "cast_x": 0.6160177290439606, "cast_g": 2.4440810084342957, "cast_w": 0.08280575275421143, "time_standard": 27.301125228405, "time_rowwise": 19.023586064577103, "time_global": 18.89484003186226} +{"repeat": 64, "batch_size": 65536, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 8.461769670248032, "standard_gw": 8.428700268268585, "standard_gx": 9.447630494832993, "rowwise_fwd": 3.881257027387619, "rowwise_bwd": 4.7471001744270325, "global_fwd": 3.9101652801036835, "global_bwd": 4.662122577428818, "x_quantize_rowwise": 1.1892355978488922, "g_quantize_rowwise": 0.3024376928806305, "w_quantize_rowwise": 0.04708021879196167, "w_quantize_colwise_transpose": 0.5982778966426849, "w_quantize_global": 0.13624131679534912, "w_quantize_global_transpose": 0.1484602689743042, "cast_x": 2.4463236331939697, "cast_g": 0.6163865327835083, "cast_w": 0.08278340101242065, "time_standard": 26.33810043334961, "time_rowwise": 19.194088876247406, "time_global": 18.777363002300262} +{"repeat": 64, "batch_size": 131072, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 19.699689000844955, "standard_gw": 16.89574122428894, "standard_gx": 17.907552421092987, "rowwise_fwd": 9.453803300857544, "rowwise_bwd": 7.8153833746910095, "global_fwd": 9.313825517892838, "global_bwd": 7.8215524554252625, "x_quantize_rowwise": 0.5986690521240234, "g_quantize_rowwise": 2.368006855249405, "w_quantize_rowwise": 0.043682754039764404, "w_quantize_colwise_transpose": 0.3406330943107605, "w_quantize_global": 0.13626739382743835, "w_quantize_global_transpose": 0.14715641736984253, "cast_x": 1.2262165546417236, "cast_g": 4.8834048211574554, "cast_w": 0.08272379636764526, "time_standard": 54.50298264622688, "time_rowwise": 37.51591965556145, "time_global": 37.28121891617775} +{"repeat": 64, "batch_size": 131072, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 18.66700127720833, "standard_gw": 18.56840029358864, "standard_gx": 18.049821257591248, "rowwise_fwd": 7.742393761873245, "rowwise_bwd": 9.479016065597534, "global_fwd": 7.806576788425446, "global_bwd": 9.328477084636688, "x_quantize_rowwise": 2.368297427892685, "g_quantize_rowwise": 0.5978643894195557, "w_quantize_rowwise": 0.047303736209869385, "w_quantize_colwise_transpose": 0.5982741713523865, "w_quantize_global": 0.13678893446922302, "w_quantize_global_transpose": 0.1488029956817627, "cast_x": 4.880513995885849, "cast_g": 1.2248307466506958, "cast_w": 0.08270144462585449, "time_standard": 55.285222828388214, "time_rowwise": 39.401549845933914, "time_global": 38.955207914114} +{"repeat": 64, "batch_size": 1024, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 0.529509037733078, "standard_gw": 0.5781911313533783, "standard_gx": 0.6095841526985168, "rowwise_fwd": 0.2811029553413391, "rowwise_bwd": 0.3345906734466553, "global_fwd": 0.27928128838539124, "global_bwd": 0.33126771450042725, "x_quantize_rowwise": 0.025760382413864136, "g_quantize_rowwise": 0.06494298577308655, "w_quantize_rowwise": 0.15570968389511108, "w_quantize_colwise_transpose": 1.6086548566818237, "w_quantize_global": 0.481434166431427, "w_quantize_global_transpose": 0.505443662405014, "cast_x": 0.01582130789756775, "cast_g": 0.08295103907585144, "cast_w": 0.311531126499176, "time_standard": 1.7172843217849731, "time_rowwise": 3.048952668905258, "time_global": 2.2663213312625885} +{"repeat": 64, "batch_size": 1024, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 0.5729459226131439, "standard_gw": 0.5789846181869507, "standard_gx": 0.5775243043899536, "rowwise_fwd": 0.36711618304252625, "rowwise_bwd": 0.2913735806941986, "global_fwd": 0.33703818917274475, "global_bwd": 0.2821236848831177, "x_quantize_rowwise": 0.064849853515625, "g_quantize_rowwise": 0.025060027837753296, "w_quantize_rowwise": 0.22537633776664734, "w_quantize_colwise_transpose": 3.6401040852069855, "w_quantize_global": 0.4818551242351532, "w_quantize_global_transpose": 0.5101114511489868, "cast_x": 0.08286535739898682, "cast_g": 0.015828758478164673, "cast_w": 0.3114677965641022, "time_standard": 1.7294548451900482, "time_rowwise": 5.192864686250687, "time_global": 2.2800229489803314} +{"repeat": 64, "batch_size": 2048, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 1.1735819280147552, "standard_gw": 1.121576875448227, "standard_gx": 1.1242404580116272, "rowwise_fwd": 0.5535706877708435, "rowwise_bwd": 0.5567893385887146, "global_fwd": 0.5486570298671722, "global_bwd": 0.551365315914154, "x_quantize_rowwise": 0.02710893750190735, "g_quantize_rowwise": 0.11784210801124573, "w_quantize_rowwise": 0.15565752983093262, "w_quantize_colwise_transpose": 1.607745885848999, "w_quantize_global": 0.4824437201023102, "w_quantize_global_transpose": 0.5060508847236633, "cast_x": 0.03808736801147461, "cast_g": 0.15912577509880066, "cast_w": 0.31150132417678833, "time_standard": 3.4193992614746094, "time_rowwise": 4.14029136300087, "time_global": 3.35504487156868} +{"repeat": 64, "batch_size": 2048, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 1.1169910430908203, "standard_gw": 1.1065900325775146, "standard_gx": 1.1815577745437622, "rowwise_fwd": 0.5917288362979889, "rowwise_bwd": 0.5614385008811951, "global_fwd": 0.5646944046020508, "global_bwd": 0.5500949919223785, "x_quantize_rowwise": 0.118207186460495, "g_quantize_rowwise": 0.025041401386260986, "w_quantize_rowwise": 0.22566691040992737, "w_quantize_colwise_transpose": 3.635551780462265, "w_quantize_global": 0.4815608263015747, "w_quantize_global_transpose": 0.509701669216156, "cast_x": 0.15912950038909912, "cast_g": 0.03797560930252075, "cast_w": 0.3114044666290283, "time_standard": 3.405138850212097, "time_rowwise": 6.264224648475647, "time_global": 3.3558905124664307} +{"repeat": 64, "batch_size": 4096, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 2.3259930312633514, "standard_gw": 2.1472275257110596, "standard_gx": 2.213582396507263, "rowwise_fwd": 1.0509602725505829, "rowwise_bwd": 0.9888559579849243, "global_fwd": 1.0398179292678833, "global_bwd": 0.9887740015983582, "x_quantize_rowwise": 0.04647299647331238, "g_quantize_rowwise": 0.22570788860321045, "w_quantize_rowwise": 0.1554824411869049, "w_quantize_colwise_transpose": 1.610085368156433, "w_quantize_global": 0.48134103417396545, "w_quantize_global_transpose": 0.5054809153079987, "cast_x": 0.08297711610794067, "cast_g": 0.3115646541118622, "cast_w": 0.31159818172454834, "time_standard": 6.686802953481674, "time_rowwise": 6.224792450666428, "time_global": 5.434822291135788} +{"repeat": 64, "batch_size": 4096, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 2.19760462641716, "standard_gw": 2.2860951721668243, "standard_gx": 2.290956676006317, "rowwise_fwd": 1.0311491787433624, "rowwise_bwd": 1.0555200278759003, "global_fwd": 0.9858310222625732, "global_bwd": 1.0394863784313202, "x_quantize_rowwise": 0.22591277956962585, "g_quantize_rowwise": 0.046234577894210815, "w_quantize_rowwise": 0.22603943943977356, "w_quantize_colwise_transpose": 3.628809005022049, "w_quantize_global": 0.4819147288799286, "w_quantize_global_transpose": 0.5104243755340576, "cast_x": 0.3114528954029083, "cast_g": 0.08296966552734375, "cast_w": 0.3116317093372345, "time_standard": 6.7746564745903015, "time_rowwise": 8.499760180711746, "time_global": 5.575899034738541} +{"repeat": 64, "batch_size": 8192, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 4.633370786905289, "standard_gw": 4.397690296173096, "standard_gx": 4.286538809537888, "rowwise_fwd": 2.089906483888626, "rowwise_bwd": 1.9657425582408905, "global_fwd": 2.0679645240306854, "global_bwd": 1.9629858434200287, "x_quantize_rowwise": 0.08271634578704834, "g_quantize_rowwise": 0.43905526399612427, "w_quantize_rowwise": 0.1551508903503418, "w_quantize_colwise_transpose": 1.6106180846691132, "w_quantize_global": 0.48185884952545166, "w_quantize_global_transpose": 0.506274402141571, "cast_x": 0.15918537974357605, "cast_g": 0.6163418292999268, "cast_w": 0.311531126499176, "time_standard": 13.317599892616272, "time_rowwise": 10.74087992310524, "time_global": 9.938545525074005} +{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 4.424266517162323, "standard_gw": 4.391487687826157, "standard_gx": 4.61186096072197, "rowwise_fwd": 1.9874684512615204, "rowwise_bwd": 2.093140035867691, "global_fwd": 1.9647255539894104, "global_bwd": 2.06940621137619, "x_quantize_rowwise": 0.43999403715133667, "g_quantize_rowwise": 0.08271634578704834, "w_quantize_rowwise": 0.22581592202186584, "w_quantize_colwise_transpose": 3.631964325904846, "w_quantize_global": 0.4821456968784332, "w_quantize_global_transpose": 0.5102343857288361, "cast_x": 0.6164386868476868, "cast_g": 0.1591108739376068, "cast_w": 0.31154975295066833, "time_standard": 13.42761516571045, "time_rowwise": 12.852586805820465, "time_global": 9.940709918737411} +{"repeat": 64, "batch_size": 16384, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 9.229827672243118, "standard_gw": 8.319318294525146, "standard_gx": 8.652344346046448, "rowwise_fwd": 4.163607954978943, "rowwise_bwd": 3.778301179409027, "global_fwd": 4.121184349060059, "global_bwd": 3.7708766758441925, "x_quantize_rowwise": 0.1553669571876526, "g_quantize_rowwise": 0.8715838193893433, "w_quantize_rowwise": 0.15540048480033875, "w_quantize_colwise_transpose": 1.6092769801616669, "w_quantize_global": 0.4813969135284424, "w_quantize_global_transpose": 0.5070343613624573, "cast_x": 0.31150132417678833, "cast_g": 1.2259706854820251, "cast_w": 0.311482697725296, "time_standard": 26.201490312814713, "time_rowwise": 19.052855670452118, "time_global": 18.226761370897293} +{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 8.577890694141388, "standard_gw": 9.073298424482346, "standard_gx": 9.210295975208282, "rowwise_fwd": 3.7784352898597717, "rowwise_bwd": 4.165928810834885, "global_fwd": 3.7702471017837524, "global_bwd": 4.121150821447372, "x_quantize_rowwise": 0.868629664182663, "g_quantize_rowwise": 0.1554340124130249, "w_quantize_rowwise": 0.22614002227783203, "w_quantize_colwise_transpose": 3.6367811262607574, "w_quantize_global": 0.4828609526157379, "w_quantize_global_transpose": 0.510137528181076, "cast_x": 1.2258104979991913, "cast_g": 0.31299516558647156, "cast_w": 0.3114677965641022, "time_standard": 26.861485093832016, "time_rowwise": 21.90464735031128, "time_global": 18.981758505105972} +{"repeat": 64, "batch_size": 32768, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 18.52763444185257, "standard_gw": 17.835520207881927, "standard_gx": 17.375655472278595, "rowwise_fwd": 8.35346058011055, "rowwise_bwd": 7.584303617477417, "global_fwd": 8.300606161355972, "global_bwd": 7.550913840532303, "x_quantize_rowwise": 0.3016740083694458, "g_quantize_rowwise": 1.7321519553661346, "w_quantize_rowwise": 0.15538185834884644, "w_quantize_colwise_transpose": 1.6110800206661224, "w_quantize_global": 0.4815198481082916, "w_quantize_global_transpose": 0.5066357553005219, "cast_x": 0.6163753569126129, "cast_g": 2.4452805519104004, "cast_w": 0.31156837940216064, "time_standard": 53.73881012201309, "time_rowwise": 37.573572248220444, "time_global": 36.7090217769146} +{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 18.073823302984238, "standard_gw": 16.71283319592476, "standard_gx": 18.46104860305786, "rowwise_fwd": 7.542364299297333, "rowwise_bwd": 8.374195545911789, "global_fwd": 7.5644850730896, "global_bwd": 8.26016440987587, "x_quantize_rowwise": 1.7326027154922485, "g_quantize_rowwise": 0.30233338475227356, "w_quantize_rowwise": 0.2259574830532074, "w_quantize_colwise_transpose": 3.634512424468994, "w_quantize_global": 0.48204511404037476, "w_quantize_global_transpose": 0.5093887448310852, "cast_x": 2.445656806230545, "cast_g": 0.6163381040096283, "cast_w": 0.31144917011260986, "time_standard": 53.24770510196686, "time_rowwise": 38.524799048900604, "time_global": 35.56385263800621} +{"repeat": 64, "batch_size": 65536, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 36.123402416706085, "standard_gw": 32.68447890877724, "standard_gx": 34.13737937808037, "rowwise_fwd": 16.65867120027542, "rowwise_bwd": 15.004873275756836, "global_fwd": 16.536589711904526, "global_bwd": 14.949381351470947, "x_quantize_rowwise": 0.5952902138233185, "g_quantize_rowwise": 3.4581348299980164, "w_quantize_rowwise": 0.15559792518615723, "w_quantize_colwise_transpose": 1.6055963933467865, "w_quantize_global": 0.48203766345977783, "w_quantize_global_transpose": 0.5048215389251709, "cast_x": 1.2256354093551636, "cast_g": 4.875503480434418, "cast_w": 0.3110244870185852, "time_standard": 102.94526070356369, "time_rowwise": 70.16264274716377, "time_global": 69.210734218359} +{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 35.0223146378994, "standard_gw": 32.84081444144249, "standard_gx": 35.984884947538376, "rowwise_fwd": 15.018381178379059, "rowwise_bwd": 16.69919490814209, "global_fwd": 14.942582696676254, "global_bwd": 16.529250890016556, "x_quantize_rowwise": 3.442291170358658, "g_quantize_rowwise": 0.5951747298240662, "w_quantize_rowwise": 0.22576376795768738, "w_quantize_colwise_transpose": 3.621157258749008, "w_quantize_global": 0.48135966062545776, "w_quantize_global_transpose": 0.5095489323139191, "cast_x": 4.875205457210541, "cast_g": 1.2237727642059326, "cast_w": 0.3110431134700775, "time_standard": 103.84801402688026, "time_rowwise": 72.44277745485306, "time_global": 69.3410225212574} +{"repeat": 64, "batch_size": 131072, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 72.33698666095734, "standard_gw": 71.31465151906013, "standard_gx": 69.32922825217247, "rowwise_fwd": 33.37707370519638, "rowwise_bwd": 30.1642008125782, "global_fwd": 33.002063632011414, "global_bwd": 30.003495514392853, "x_quantize_rowwise": 1.1819563806056976, "g_quantize_rowwise": 6.896954029798508, "w_quantize_rowwise": 0.15557929873466492, "w_quantize_colwise_transpose": 1.6083605587482452, "w_quantize_global": 0.48125162720680237, "w_quantize_global_transpose": 0.5055665969848633, "cast_x": 2.442535012960434, "cast_g": 9.750165045261383, "cast_w": 0.31094998121261597, "time_standard": 212.98086643218994, "time_rowwise": 144.69877630472183, "time_global": 143.38593930006027} +{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 70.24158909916878, "standard_gw": 72.03734293580055, "standard_gx": 72.01339676976204, "rowwise_fwd": 30.072908848524094, "rowwise_bwd": 33.376410603523254, "global_fwd": 29.965493828058243, "global_bwd": 33.01112726330757, "x_quantize_rowwise": 6.894122809171677, "g_quantize_rowwise": 1.1817142367362976, "w_quantize_rowwise": 0.22567808628082275, "w_quantize_colwise_transpose": 3.616899251937866, "w_quantize_global": 0.4819147288799286, "w_quantize_global_transpose": 0.5107112228870392, "cast_x": 9.750377386808395, "cast_g": 2.4411343038082123, "cast_w": 0.31099095940589905, "time_standard": 214.29232880473137, "time_rowwise": 147.40507677197456, "time_global": 144.0824270248413} +{"repeat": 64, "batch_size": 65536, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 138.23134452104568, "standard_gw": 131.48364424705505, "standard_gx": 141.09868183732033, "rowwise_fwd": 65.38830325007439, "rowwise_bwd": 58.39048698544502, "global_fwd": 65.2194656431675, "global_bwd": 58.58004465699196, "x_quantize_rowwise": 1.1899955570697784, "g_quantize_rowwise": 6.623774766921997, "w_quantize_rowwise": 0.5935952067375183, "w_quantize_colwise_transpose": 24.08137544989586, "w_quantize_global": 1.740824431180954, "w_quantize_global_transpose": 1.8664970993995667, "cast_x": 2.413548529148102, "cast_g": 9.63655486702919, "cast_w": 1.1956281960010529, "time_standard": 410.81367060542107, "time_rowwise": 287.7511754631996, "time_global": 266.7042464017868} +{"repeat": 64, "batch_size": 65536, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 141.08363911509514, "standard_gw": 133.26667994260788, "standard_gx": 136.0350362956524, "rowwise_fwd": 58.49892646074295, "rowwise_bwd": 65.34496694803238, "global_fwd": 58.73573571443558, "global_bwd": 65.30505418777466, "x_quantize_rowwise": 6.648071110248566, "g_quantize_rowwise": 1.1903978884220123, "w_quantize_rowwise": 0.8329600095748901, "w_quantize_colwise_transpose": 15.297897160053253, "w_quantize_global": 1.7403066158294678, "w_quantize_global_transpose": 1.8791332840919495, "cast_x": 9.636614471673965, "cast_g": 2.4122819304466248, "cast_w": 1.1954344809055328, "time_standard": 410.3853553533554, "time_rowwise": 281.07989951968193, "time_global": 268.7653787434101} +{"repeat": 64, "batch_size": 1024, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 2.535879611968994, "standard_gw": 2.249978482723236, "standard_gx": 2.2262558341026306, "rowwise_fwd": 1.085665076971054, "rowwise_bwd": 1.069542020559311, "global_fwd": 1.0830685496330261, "global_bwd": 1.0597631335258484, "x_quantize_rowwise": 0.02650916576385498, "g_quantize_rowwise": 0.1200847327709198, "w_quantize_rowwise": 0.5937665700912476, "w_quantize_colwise_transpose": 23.926906287670135, "w_quantize_global": 1.7397291958332062, "w_quantize_global_transpose": 1.8652454018592834, "cast_x": 0.03688782453536987, "cast_g": 0.15725940465927124, "cast_w": 1.1969134211540222, "time_standard": 7.012113928794861, "time_rowwise": 29.07245233654976, "time_global": 8.144378662109375} +{"repeat": 64, "batch_size": 1024, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 2.245493233203888, "standard_gw": 2.2966675460338593, "standard_gx": 2.216015011072159, "rowwise_fwd": 1.1000856757164001, "rowwise_bwd": 1.0902360081672668, "global_fwd": 1.0597333312034607, "global_bwd": 1.0812543332576752, "x_quantize_rowwise": 0.11992454528808594, "g_quantize_rowwise": 0.026784837245941162, "w_quantize_rowwise": 0.8310377597808838, "w_quantize_colwise_transpose": 15.30550792813301, "w_quantize_global": 1.7401352524757385, "w_quantize_global_transpose": 1.8841177225112915, "cast_x": 0.1573599874973297, "cast_g": 0.03676116466522217, "cast_w": 1.195952296257019, "time_standard": 6.758175790309906, "time_rowwise": 20.770244300365448, "time_global": 8.208617568016052} +{"repeat": 64, "batch_size": 2048, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 4.197858273983002, "standard_gw": 4.288379102945328, "standard_gx": 4.155721515417099, "rowwise_fwd": 2.0567886531352997, "rowwise_bwd": 1.9073635339736938, "global_fwd": 2.0506344735622406, "global_bwd": 1.9086338579654694, "x_quantize_rowwise": 0.04758685827255249, "g_quantize_rowwise": 0.22284314036369324, "w_quantize_rowwise": 0.5935467779636383, "w_quantize_colwise_transpose": 23.935042321681976, "w_quantize_global": 1.7397813498973846, "w_quantize_global_transpose": 1.8662959337234497, "cast_x": 0.08194148540496826, "cast_g": 0.3077872097492218, "cast_w": 1.1968687176704407, "time_standard": 12.641958892345428, "time_rowwise": 33.05155038833618, "time_global": 12.124154716730118} +{"repeat": 64, "batch_size": 2048, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 4.126541316509247, "standard_gw": 4.309836775064468, "standard_gx": 4.117351025342941, "rowwise_fwd": 1.9266381859779358, "rowwise_bwd": 2.0577237010002136, "global_fwd": 1.908630132675171, "global_bwd": 2.0505934953689575, "x_quantize_rowwise": 0.22304058074951172, "g_quantize_rowwise": 0.04766136407852173, "w_quantize_rowwise": 0.8306317031383514, "w_quantize_colwise_transpose": 15.309855341911316, "w_quantize_global": 1.7415396869182587, "w_quantize_global_transpose": 1.8827766180038452, "cast_x": 0.30782073736190796, "cast_g": 0.08186325430870056, "cast_w": 1.1955127120018005, "time_standard": 12.553729116916656, "time_rowwise": 24.70538765192032, "time_global": 12.164078652858734} +{"repeat": 64, "batch_size": 4096, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 8.298952132463455, "standard_gw": 8.345257490873337, "standard_gx": 8.647706359624863, "rowwise_fwd": 4.106882959604263, "rowwise_bwd": 3.8046911358833313, "global_fwd": 4.09451499581337, "global_bwd": 3.8078874349594116, "x_quantize_rowwise": 0.08447840809822083, "g_quantize_rowwise": 0.4291348159313202, "w_quantize_rowwise": 0.5934201180934906, "w_quantize_colwise_transpose": 23.843105882406235, "w_quantize_global": 1.7399191856384277, "w_quantize_global_transpose": 1.8653236329555511, "cast_x": 0.1577921211719513, "cast_g": 0.6089024245738983, "cast_w": 1.1952444911003113, "time_standard": 25.291915982961655, "time_rowwise": 41.2069708108902, "time_global": 20.366515964269638} +{"repeat": 64, "batch_size": 4096, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 8.323360234498978, "standard_gw": 8.433796465396881, "standard_gx": 8.236430585384369, "rowwise_fwd": 3.8114115595817566, "rowwise_bwd": 4.106346517801285, "global_fwd": 3.8080140948295593, "global_bwd": 4.094675183296204, "x_quantize_rowwise": 0.4288516938686371, "g_quantize_rowwise": 0.08437782526016235, "w_quantize_rowwise": 0.8310228586196899, "w_quantize_colwise_transpose": 15.306610614061356, "w_quantize_global": 1.741155982017517, "w_quantize_global_transpose": 1.8809586763381958, "cast_x": 0.6091706454753876, "cast_g": 0.157233327627182, "cast_w": 1.1953115463256836, "time_standard": 24.993587285280228, "time_rowwise": 33.00241753458977, "time_global": 20.471829921007156} +{"repeat": 64, "batch_size": 8192, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 16.656354069709778, "standard_gw": 17.066240310668945, "standard_gx": 17.252348363399506, "rowwise_fwd": 8.220307528972626, "rowwise_bwd": 7.2372183203697205, "global_fwd": 8.2036592066288, "global_bwd": 7.236208766698837, "x_quantize_rowwise": 0.15832111239433289, "g_quantize_rowwise": 0.8406005799770355, "w_quantize_rowwise": 0.5935393273830414, "w_quantize_colwise_transpose": 23.86143058538437, "w_quantize_global": 1.7401576042175293, "w_quantize_global_transpose": 1.8653534352779388, "cast_x": 0.3079026937484741, "cast_g": 1.209162175655365, "cast_w": 1.1951625347137451, "time_standard": 50.97494274377823, "time_rowwise": 57.97765776515007, "time_global": 37.11054101586342} +{"repeat": 64, "batch_size": 8192, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 17.398890107870102, "standard_gw": 18.470749258995056, "standard_gx": 16.520217061042786, "rowwise_fwd": 7.235266268253326, "rowwise_bwd": 8.207589387893677, "global_fwd": 7.235914468765259, "global_bwd": 8.204508572816849, "x_quantize_rowwise": 0.8409880101680756, "g_quantize_rowwise": 0.15821680426597595, "w_quantize_rowwise": 0.8324198424816132, "w_quantize_colwise_transpose": 15.305522829294205, "w_quantize_global": 1.7396919429302216, "w_quantize_global_transpose": 1.8805749714374542, "cast_x": 1.2103468179702759, "cast_g": 0.30729547142982483, "cast_w": 1.1953599750995636, "time_standard": 52.389856427907944, "time_rowwise": 51.05075240135193, "time_global": 38.53064402937889} +{"repeat": 64, "batch_size": 16384, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 33.533211797475815, "standard_gw": 33.00020843744278, "standard_gx": 34.614477306604385, "rowwise_fwd": 16.364943236112595, "rowwise_bwd": 14.551006257534027, "global_fwd": 16.33496955037117, "global_bwd": 14.513172209262848, "x_quantize_rowwise": 0.3053396940231323, "g_quantize_rowwise": 1.6693994402885437, "w_quantize_rowwise": 0.5936138331890106, "w_quantize_colwise_transpose": 23.89485388994217, "w_quantize_global": 1.741711050271988, "w_quantize_global_transpose": 1.8656104803085327, "cast_x": 0.6089657545089722, "cast_g": 2.4122074246406555, "cast_w": 1.1951886117458344, "time_standard": 101.14789754152298, "time_rowwise": 90.37936478853226, "time_global": 69.430410861969} +{"repeat": 64, "batch_size": 16384, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 33.65536406636238, "standard_gw": 33.02193805575371, "standard_gx": 33.10496360063553, "rowwise_fwd": 14.54489678144455, "rowwise_bwd": 16.36252924799919, "global_fwd": 14.50401172041893, "global_bwd": 16.33254438638687, "x_quantize_rowwise": 1.6695670783519745, "g_quantize_rowwise": 0.3054291009902954, "w_quantize_rowwise": 0.83121657371521, "w_quantize_colwise_transpose": 15.305932611227036, "w_quantize_global": 1.7382949590682983, "w_quantize_global_transpose": 1.880194991827011, "cast_x": 2.412091940641403, "cast_g": 0.6079599261283875, "cast_w": 1.1950358748435974, "time_standard": 99.78226572275162, "time_rowwise": 82.04150944948196, "time_global": 69.45198029279709} +{"repeat": 64, "batch_size": 32768, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 67.96638667583466, "standard_gw": 67.99514591693878, "standard_gx": 69.66376304626465, "rowwise_fwd": 33.51752087473869, "rowwise_bwd": 29.131878167390823, "global_fwd": 32.65715390443802, "global_bwd": 29.13403883576393, "x_quantize_rowwise": 0.6002038717269897, "g_quantize_rowwise": 3.3336542546749115, "w_quantize_rowwise": 0.5934685468673706, "w_quantize_colwise_transpose": 23.92345294356346, "w_quantize_global": 1.7405375838279724, "w_quantize_global_transpose": 1.8656738102436066, "cast_x": 1.2112446129322052, "cast_g": 4.81804832816124, "cast_w": 1.1952146887779236, "time_standard": 205.6252956390381, "time_rowwise": 159.09532457590103, "time_global": 137.3264081776142} +{"repeat": 64, "batch_size": 32768, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 68.2341456413269, "standard_gw": 65.5074268579483, "standard_gx": 67.13805347681046, "rowwise_fwd": 29.153641313314438, "rowwise_bwd": 32.71844983100891, "global_fwd": 29.124341905117035, "global_bwd": 32.65979886054993, "x_quantize_rowwise": 3.3318176865577698, "g_quantize_rowwise": 0.6004795432090759, "w_quantize_rowwise": 0.8309967815876007, "w_quantize_colwise_transpose": 15.305690467357635, "w_quantize_global": 1.7405711114406586, "w_quantize_global_transpose": 1.8802620470523834, "cast_x": 4.8183538019657135, "cast_g": 1.2096390128135681, "cast_w": 1.1951103806495667, "time_standard": 200.87962597608566, "time_rowwise": 147.44850248098373, "time_global": 134.84469801187515} +{"repeat": 64, "batch_size": 1024, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 0.07764250040054321, "standard_gw": 0.07398426532745361, "standard_gx": 0.08482858538627625, "rowwise_fwd": 0.05266070365905762, "rowwise_bwd": 0.04478543996810913, "global_fwd": 0.052012503147125244, "global_bwd": 0.044364482164382935, "x_quantize_rowwise": 0.02640858292579651, "g_quantize_rowwise": 0.02539902925491333, "w_quantize_rowwise": 0.026457011699676514, "w_quantize_colwise_transpose": 0.17770379781723022, "w_quantize_global": 0.07440149784088135, "w_quantize_global_transpose": 0.08142739534378052, "cast_x": 0.008150935173034668, "cast_g": 0.022415071725845337, "cast_w": 0.03479421138763428, "time_standard": 0.23645535111427307, "time_rowwise": 0.42739883065223694, "time_global": 0.3779977560043335} +{"repeat": 64, "batch_size": 1024, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.08524581789970398, "standard_gw": 0.07383152842521667, "standard_gx": 0.07564574480056763, "rowwise_fwd": 0.04478171467781067, "rowwise_bwd": 0.052671879529953, "global_fwd": 0.04452839493751526, "global_bwd": 0.05219504237174988, "x_quantize_rowwise": 0.025328248739242554, "g_quantize_rowwise": 0.027123838663101196, "w_quantize_rowwise": 0.025607645511627197, "w_quantize_colwise_transpose": 0.17121434211730957, "w_quantize_global": 0.07916614413261414, "w_quantize_global_transpose": 0.08177384734153748, "cast_x": 0.022619962692260742, "cast_g": 0.008556991815567017, "cast_w": 0.034421682357788086, "time_standard": 0.23472309112548828, "time_rowwise": 0.42055919766426086, "time_global": 0.3839470446109772} +{"repeat": 64, "batch_size": 2048, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 0.13731792569160461, "standard_gw": 0.13414397835731506, "standard_gx": 0.14049187302589417, "rowwise_fwd": 0.10158121585845947, "rowwise_bwd": 0.07804110646247864, "global_fwd": 0.09908527135848999, "global_bwd": 0.07766112685203552, "x_quantize_rowwise": 0.026516616344451904, "g_quantize_rowwise": 0.03666803240776062, "w_quantize_rowwise": 0.024981796741485596, "w_quantize_colwise_transpose": 0.17706677317619324, "w_quantize_global": 0.07443130016326904, "w_quantize_global_transpose": 0.07870793342590332, "cast_x": 0.01224130392074585, "cast_g": 0.05828961730003357, "cast_w": 0.03501400351524353, "time_standard": 0.41195377707481384, "time_rowwise": 0.5789995193481445, "time_global": 0.5272142589092255} +{"repeat": 64, "batch_size": 2048, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.14651194214820862, "standard_gw": 0.14011189341545105, "standard_gx": 0.140264630317688, "rowwise_fwd": 0.081576406955719, "rowwise_bwd": 0.10671466588973999, "global_fwd": 0.08158013224601746, "global_bwd": 0.10219961404800415, "x_quantize_rowwise": 0.03775954246520996, "g_quantize_rowwise": 0.026103109121322632, "w_quantize_rowwise": 0.02656877040863037, "w_quantize_colwise_transpose": 0.17822161316871643, "w_quantize_global": 0.07506832480430603, "w_quantize_global_transpose": 0.07928535342216492, "cast_x": 0.05893409252166748, "cast_g": 0.012326985597610474, "cast_w": 0.03498047590255737, "time_standard": 0.42688846588134766, "time_rowwise": 0.5970560014247894, "time_global": 0.5421079695224762} +{"repeat": 64, "batch_size": 4096, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 0.2734065055847168, "standard_gw": 0.25558844208717346, "standard_gx": 0.29174983501434326, "rowwise_fwd": 0.173322856426239, "rowwise_bwd": 0.1515895128250122, "global_fwd": 0.17048418521881104, "global_bwd": 0.1506991684436798, "x_quantize_rowwise": 0.025950372219085693, "g_quantize_rowwise": 0.0653192400932312, "w_quantize_rowwise": 0.027138739824295044, "w_quantize_colwise_transpose": 0.17699971795082092, "w_quantize_global": 0.07373467087745667, "w_quantize_global_transpose": 0.07901713252067566, "cast_x": 0.02214685082435608, "cast_g": 0.11127442121505737, "cast_w": 0.03481656312942505, "time_standard": 0.8207447826862335, "time_rowwise": 0.8759088814258575, "time_global": 0.8207932114601135} +{"repeat": 64, "batch_size": 4096, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.27839839458465576, "standard_gw": 0.2537444233894348, "standard_gx": 0.28207898139953613, "rowwise_fwd": 0.16542896628379822, "rowwise_bwd": 0.18540024757385254, "global_fwd": 0.15722215175628662, "global_bwd": 0.17368420958518982, "x_quantize_rowwise": 0.06661936640739441, "g_quantize_rowwise": 0.027049332857131958, "w_quantize_rowwise": 0.025507062673568726, "w_quantize_colwise_transpose": 0.1741349697113037, "w_quantize_global": 0.07463246583938599, "w_quantize_global_transpose": 0.07879361510276794, "cast_x": 0.11301413178443909, "cast_g": 0.023346394300460815, "cast_w": 0.03505498170852661, "time_standard": 0.8142217993736267, "time_rowwise": 0.8978843688964844, "time_global": 0.8317455649375916} +{"repeat": 64, "batch_size": 8192, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 0.5755424499511719, "standard_gw": 0.5219094455242157, "standard_gx": 0.5992203950881958, "rowwise_fwd": 0.33193081617355347, "rowwise_bwd": 0.295441597700119, "global_fwd": 0.32791122794151306, "global_bwd": 0.2906434237957001, "x_quantize_rowwise": 0.0337548553943634, "g_quantize_rowwise": 0.1225881278514862, "w_quantize_rowwise": 0.024937093257904053, "w_quantize_colwise_transpose": 0.17729029059410095, "w_quantize_global": 0.0730752944946289, "w_quantize_global_transpose": 0.07835403084754944, "cast_x": 0.058166682720184326, "cast_g": 0.21592900156974792, "cast_w": 0.03454089164733887, "time_standard": 1.6966722905635834, "time_rowwise": 1.5078522264957428, "time_global": 1.4482364058494568} +{"repeat": 64, "batch_size": 8192, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.5104020237922668, "standard_gw": 0.5302242934703827, "standard_gx": 0.5842559039592743, "rowwise_fwd": 0.32220035791397095, "rowwise_bwd": 0.3576017916202545, "global_fwd": 0.2939775586128235, "global_bwd": 0.3313682973384857, "x_quantize_rowwise": 0.12369826436042786, "g_quantize_rowwise": 0.03423169255256653, "w_quantize_rowwise": 0.026501715183258057, "w_quantize_colwise_transpose": 0.16975775361061096, "w_quantize_global": 0.0768713653087616, "w_quantize_global_transpose": 0.08094683289527893, "cast_x": 0.21589547395706177, "cast_g": 0.05825608968734741, "cast_w": 0.03466010093688965, "time_standard": 1.6248822212219238, "time_rowwise": 1.5642158687114716, "time_global": 1.4713183045387268} +{"repeat": 64, "batch_size": 16384, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 1.194491982460022, "standard_gw": 1.0553859174251556, "standard_gx": 1.0726377367973328, "rowwise_fwd": 0.636763870716095, "rowwise_bwd": 0.5154944956302643, "global_fwd": 0.6281323730945587, "global_bwd": 0.5117170512676239, "x_quantize_rowwise": 0.062175095081329346, "g_quantize_rowwise": 0.23643672466278076, "w_quantize_rowwise": 0.025566667318344116, "w_quantize_colwise_transpose": 0.17768144607543945, "w_quantize_global": 0.07302314043045044, "w_quantize_global_transpose": 0.07866695523262024, "cast_x": 0.11140108108520508, "cast_g": 0.42498111724853516, "cast_w": 0.034831464290618896, "time_standard": 3.3225156366825104, "time_rowwise": 2.7095042169094086, "time_global": 2.645537257194519} +{"repeat": 64, "batch_size": 16384, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 1.0797791182994843, "standard_gw": 1.062549650669098, "standard_gx": 1.104947179555893, "rowwise_fwd": 0.5390122532844543, "rowwise_bwd": 0.6449781358242035, "global_fwd": 0.5145668983459473, "global_bwd": 0.6276033818721771, "x_quantize_rowwise": 0.23603439331054688, "g_quantize_rowwise": 0.062234699726104736, "w_quantize_rowwise": 0.02781301736831665, "w_quantize_colwise_transpose": 0.1703314483165741, "w_quantize_global": 0.07431954145431519, "w_quantize_global_transpose": 0.08028373122215271, "cast_x": 0.4249885678291321, "cast_g": 0.1113303005695343, "cast_w": 0.0348016619682312, "time_standard": 3.247275948524475, "time_rowwise": 2.742953598499298, "time_global": 2.657592296600342} +{"repeat": 64, "batch_size": 32768, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 2.392485737800598, "standard_gw": 2.046734094619751, "standard_gx": 2.177651971578598, "rowwise_fwd": 1.252591609954834, "rowwise_bwd": 1.0205842554569244, "global_fwd": 1.230098307132721, "global_bwd": 1.0132193565368652, "x_quantize_rowwise": 0.11823698878288269, "g_quantize_rowwise": 0.4639141261577606, "w_quantize_rowwise": 0.02602487802505493, "w_quantize_colwise_transpose": 0.17801672220230103, "w_quantize_global": 0.07301196455955505, "w_quantize_global_transpose": 0.07893890142440796, "cast_x": 0.21591037511825562, "cast_g": 0.843394547700882, "cast_w": 0.03460049629211426, "time_standard": 6.616871803998947, "time_rowwise": 5.106102675199509, "time_global": 5.0241537392139435} +{"repeat": 64, "batch_size": 32768, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 2.205628901720047, "standard_gw": 1.9917488098144531, "standard_gx": 2.1518059074878693, "rowwise_fwd": 1.040138304233551, "rowwise_bwd": 1.2538731098175049, "global_fwd": 1.0131187736988068, "global_bwd": 1.2291893362998962, "x_quantize_rowwise": 0.46381354331970215, "g_quantize_rowwise": 0.11790916323661804, "w_quantize_rowwise": 0.027123838663101196, "w_quantize_colwise_transpose": 0.17021596431732178, "w_quantize_global": 0.0752471387386322, "w_quantize_global_transpose": 0.08159875869750977, "cast_x": 0.8433908224105835, "cast_g": 0.215873122215271, "cast_w": 0.03452599048614502, "time_standard": 6.349183619022369, "time_rowwise": 5.064822733402252, "time_global": 4.972625523805618} +{"repeat": 64, "batch_size": 65536, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 4.755370318889618, "standard_gw": 4.736289381980896, "standard_gx": 4.0378570556640625, "rowwise_fwd": 2.4783052504062653, "rowwise_bwd": 1.9634142518043518, "global_fwd": 2.435591071844101, "global_bwd": 1.9498206675052643, "x_quantize_rowwise": 0.22948533296585083, "g_quantize_rowwise": 0.9186491370201111, "w_quantize_rowwise": 0.028233975172042847, "w_quantize_colwise_transpose": 0.17858296632766724, "w_quantize_global": 0.07418543100357056, "w_quantize_global_transpose": 0.07958710193634033, "cast_x": 0.4257224500179291, "cast_g": 1.680031418800354, "cast_w": 0.03458559513092041, "time_standard": 13.529516756534576, "time_rowwise": 10.532960295677185, "time_global": 10.423608124256134} +{"repeat": 64, "batch_size": 65536, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 4.050172865390778, "standard_gw": 3.916766494512558, "standard_gx": 4.281226545572281, "rowwise_fwd": 1.9789263606071472, "rowwise_bwd": 2.477586269378662, "global_fwd": 1.9495487213134766, "global_bwd": 2.434592694044113, "x_quantize_rowwise": 0.918261706829071, "g_quantize_rowwise": 0.22961944341659546, "w_quantize_rowwise": 0.025540590286254883, "w_quantize_colwise_transpose": 0.17032772302627563, "w_quantize_global": 0.07384642958641052, "w_quantize_global_transpose": 0.08105114102363586, "cast_x": 1.679886132478714, "cast_g": 0.42508915066719055, "cast_w": 0.03442913293838501, "time_standard": 12.248165905475616, "time_rowwise": 9.717028588056564, "time_global": 9.60368663072586} +{"repeat": 64, "batch_size": 131072, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 9.53347235918045, "standard_gw": 8.138865232467651, "standard_gx": 7.9666972160339355, "rowwise_fwd": 4.984956234693527, "rowwise_bwd": 3.850068897008896, "global_fwd": 4.9025751650333405, "global_bwd": 3.820303827524185, "x_quantize_rowwise": 0.45222043991088867, "g_quantize_rowwise": 1.8290691077709198, "w_quantize_rowwise": 0.026736408472061157, "w_quantize_colwise_transpose": 0.17832592129707336, "w_quantize_global": 0.07471069693565369, "w_quantize_global_transpose": 0.08177757263183594, "cast_x": 0.8435025811195374, "cast_g": 3.3529214560985565, "cast_w": 0.03475695848464966, "time_standard": 25.639034807682037, "time_rowwise": 19.460242241621017, "time_global": 19.299522042274475} +{"repeat": 64, "batch_size": 131072, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 7.996037602424622, "standard_gw": 8.2748644053936, "standard_gx": 8.523400872945786, "rowwise_fwd": 3.8556940853595734, "rowwise_bwd": 4.966288805007935, "global_fwd": 3.820043057203293, "global_bwd": 4.882067441940308, "x_quantize_rowwise": 1.8279887735843658, "g_quantize_rowwise": 0.4520900547504425, "w_quantize_rowwise": 0.02676248550415039, "w_quantize_colwise_transpose": 0.17083808779716492, "w_quantize_global": 0.07691606879234314, "w_quantize_global_transpose": 0.08223950862884521, "cast_x": 3.3530443906784058, "cast_g": 0.8434318006038666, "cast_w": 0.034671276807785034, "time_standard": 24.794302880764008, "time_rowwise": 19.574526697397232, "time_global": 19.416209310293198} +{"repeat": 64, "batch_size": 1024, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.09413063526153564, "standard_gw": 0.10038167238235474, "standard_gx": 0.09725615382194519, "rowwise_fwd": 0.05979463458061218, "rowwise_bwd": 0.0525452196598053, "global_fwd": 0.059057027101516724, "global_bwd": 0.05194917321205139, "x_quantize_rowwise": 0.02664700150489807, "g_quantize_rowwise": 0.02642720937728882, "w_quantize_rowwise": 0.030562281608581543, "w_quantize_colwise_transpose": 0.2400912344455719, "w_quantize_global": 0.09407848119735718, "w_quantize_global_transpose": 0.10256841778755188, "cast_x": 0.008724629878997803, "cast_g": 0.028502196073532104, "cast_w": 0.05552172660827637, "time_standard": 0.29176846146583557, "time_rowwise": 0.5364492535591125, "time_global": 0.4611089825630188} +{"repeat": 64, "batch_size": 1024, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.09753555059432983, "standard_gw": 0.10102242231369019, "standard_gx": 0.09121373295783997, "rowwise_fwd": 0.052150338888168335, "rowwise_bwd": 0.059779733419418335, "global_fwd": 0.05161017179489136, "global_bwd": 0.05943328142166138, "x_quantize_rowwise": 0.026702880859375, "g_quantize_rowwise": 0.02469494938850403, "w_quantize_rowwise": 0.03324449062347412, "w_quantize_colwise_transpose": 0.23468583822250366, "w_quantize_global": 0.09394437074661255, "w_quantize_global_transpose": 0.10142102837562561, "cast_x": 0.028360635042190552, "cast_g": 0.008717179298400879, "cast_w": 0.05577504634857178, "time_standard": 0.28977170586586, "time_rowwise": 0.5322806537151337, "time_global": 0.4588291049003601} +{"repeat": 64, "batch_size": 2048, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.18056854605674744, "standard_gw": 0.18374621868133545, "standard_gx": 0.19219890236854553, "rowwise_fwd": 0.1150965690612793, "rowwise_bwd": 0.0903494656085968, "global_fwd": 0.11263042688369751, "global_bwd": 0.08984282612800598, "x_quantize_rowwise": 0.027067959308624268, "g_quantize_rowwise": 0.040043145418167114, "w_quantize_rowwise": 0.03063306212425232, "w_quantize_colwise_transpose": 0.24128705263137817, "w_quantize_global": 0.09361281991004944, "w_quantize_global_transpose": 0.1024976372718811, "cast_x": 0.01381710171699524, "cast_g": 0.06845593452453613, "cast_w": 0.05572289228439331, "time_standard": 0.5565136671066284, "time_rowwise": 0.7282234728336334, "time_global": 0.6494410336017609} +{"repeat": 64, "batch_size": 2048, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.16536936163902283, "standard_gw": 0.19479170441627502, "standard_gx": 0.18597766757011414, "rowwise_fwd": 0.09634345769882202, "rowwise_bwd": 0.11937320232391357, "global_fwd": 0.09264424443244934, "global_bwd": 0.11524930596351624, "x_quantize_rowwise": 0.04038214683532715, "g_quantize_rowwise": 0.025559216737747192, "w_quantize_rowwise": 0.03334507346153259, "w_quantize_colwise_transpose": 0.23956596851348877, "w_quantize_global": 0.09445473551750183, "w_quantize_global_transpose": 0.1020580530166626, "cast_x": 0.06891414523124695, "cast_g": 0.013861805200576782, "cast_w": 0.05607306957244873, "time_standard": 0.546138733625412, "time_rowwise": 0.7493607699871063, "time_global": 0.6651394069194794} +{"repeat": 64, "batch_size": 4096, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.36064907908439636, "standard_gw": 0.3711991012096405, "standard_gx": 0.3863237798213959, "rowwise_fwd": 0.22270530462265015, "rowwise_bwd": 0.1760348677635193, "global_fwd": 0.21781772375106812, "global_bwd": 0.17484650015830994, "x_quantize_rowwise": 0.02625212073326111, "g_quantize_rowwise": 0.07131323218345642, "w_quantize_rowwise": 0.030372291803359985, "w_quantize_colwise_transpose": 0.23974105715751648, "w_quantize_global": 0.09407475590705872, "w_quantize_global_transpose": 0.1024492084980011, "cast_x": 0.028584152460098267, "cast_g": 0.1303069293498993, "cast_w": 0.05582347512245178, "time_standard": 1.1181719601154327, "time_rowwise": 1.137617975473404, "time_global": 1.057952642440796} +{"repeat": 64, "batch_size": 4096, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.32703205943107605, "standard_gw": 0.3764517605304718, "standard_gx": 0.3938935697078705, "rowwise_fwd": 0.18771737813949585, "rowwise_bwd": 0.2374798059463501, "global_fwd": 0.1843757927417755, "global_bwd": 0.23005902767181396, "x_quantize_rowwise": 0.07155537605285645, "g_quantize_rowwise": 0.02625212073326111, "w_quantize_rowwise": 0.03294646739959717, "w_quantize_colwise_transpose": 0.23755058646202087, "w_quantize_global": 0.09388476610183716, "w_quantize_global_transpose": 0.10246038436889648, "cast_x": 0.13131648302078247, "cast_g": 0.028781592845916748, "cast_w": 0.05638599395751953, "time_standard": 1.0973773896694183, "time_rowwise": 1.1699534952640533, "time_global": 1.0850392282009125} +{"repeat": 64, "batch_size": 8192, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.7961541414260864, "standard_gw": 0.7424280047416687, "standard_gx": 0.8688867092132568, "rowwise_fwd": 0.432576984167099, "rowwise_bwd": 0.34543126821517944, "global_fwd": 0.4248805344104767, "global_bwd": 0.3432855010032654, "x_quantize_rowwise": 0.03750622272491455, "g_quantize_rowwise": 0.13292208313941956, "w_quantize_rowwise": 0.030599534511566162, "w_quantize_colwise_transpose": 0.24292618036270142, "w_quantize_global": 0.09351596236228943, "w_quantize_global_transpose": 0.1026056706905365, "cast_x": 0.06843730807304382, "cast_g": 0.2539418637752533, "cast_w": 0.05568563938140869, "time_standard": 2.407468855381012, "time_rowwise": 1.9643902778625488, "time_global": 1.8771439790725708} +{"repeat": 64, "batch_size": 8192, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.7150471210479736, "standard_gw": 0.7525831460952759, "standard_gx": 0.8075274527072906, "rowwise_fwd": 0.36595389246940613, "rowwise_bwd": 0.4404708743095398, "global_fwd": 0.3485158085823059, "global_bwd": 0.4275962710380554, "x_quantize_rowwise": 0.1329965889453888, "g_quantize_rowwise": 0.03767386078834534, "w_quantize_rowwise": 0.03295019268989563, "w_quantize_colwise_transpose": 0.23509934544563293, "w_quantize_global": 0.09398534893989563, "w_quantize_global_transpose": 0.10186433792114258, "cast_x": 0.2537667751312256, "cast_g": 0.06839632987976074, "cast_w": 0.05571544170379639, "time_standard": 2.27515771985054, "time_rowwise": 1.9977279007434845, "time_global": 1.8952153623104095} +{"repeat": 64, "batch_size": 16384, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 1.6392990946769714, "standard_gw": 1.4941170811653137, "standard_gx": 1.4451220631599426, "rowwise_fwd": 0.8369758725166321, "rowwise_bwd": 0.6830468773841858, "global_fwd": 0.8197203278541565, "global_bwd": 0.6782263517379761, "x_quantize_rowwise": 0.06883591413497925, "g_quantize_rowwise": 0.2565309405326843, "w_quantize_rowwise": 0.03046169877052307, "w_quantize_colwise_transpose": 0.2430342137813568, "w_quantize_global": 0.09346380829811096, "w_quantize_global_transpose": 0.10301917791366577, "cast_x": 0.13044849038124084, "cast_g": 0.5010999739170074, "cast_w": 0.05590170621871948, "time_standard": 4.578538239002228, "time_rowwise": 3.613002598285675, "time_global": 3.5139136016368866} +{"repeat": 64, "batch_size": 16384, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 1.4654621481895447, "standard_gw": 1.5012174844741821, "standard_gx": 1.5183314681053162, "rowwise_fwd": 0.7059797644615173, "rowwise_bwd": 0.8470229804515839, "global_fwd": 0.6788894534111023, "global_bwd": 0.8200779557228088, "x_quantize_rowwise": 0.2564750611782074, "g_quantize_rowwise": 0.06899237632751465, "w_quantize_rowwise": 0.03293529152870178, "w_quantize_colwise_transpose": 0.23559853434562683, "w_quantize_global": 0.09375810623168945, "w_quantize_global_transpose": 0.10203942656517029, "cast_x": 0.5010105669498444, "cast_g": 0.13037025928497314, "cast_w": 0.05577504634857178, "time_standard": 4.485011100769043, "time_rowwise": 3.648221492767334, "time_global": 3.521449863910675} +{"repeat": 64, "batch_size": 32768, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 3.236088901758194, "standard_gw": 2.8601549565792084, "standard_gx": 2.8000958263874054, "rowwise_fwd": 1.6548968851566315, "rowwise_bwd": 1.3559646904468536, "global_fwd": 1.6249343752861023, "global_bwd": 1.3474412262439728, "x_quantize_rowwise": 0.13122707605361938, "g_quantize_rowwise": 0.5038455128669739, "w_quantize_rowwise": 0.03061816096305847, "w_quantize_colwise_transpose": 0.24301931262016296, "w_quantize_global": 0.09343400597572327, "w_quantize_global_transpose": 0.10178983211517334, "cast_x": 0.25383010506629944, "cast_g": 0.9955987334251404, "cast_w": 0.05569681525230408, "time_standard": 8.896339684724808, "time_rowwise": 6.779726594686508, "time_global": 6.662826985120773} +{"repeat": 64, "batch_size": 32768, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 2.8433389961719513, "standard_gw": 2.861086279153824, "standard_gx": 3.0227042734622955, "rowwise_fwd": 1.4057457447052002, "rowwise_bwd": 1.6565024852752686, "global_fwd": 1.3475008308887482, "global_bwd": 1.6247481107711792, "x_quantize_rowwise": 0.5038045346736908, "g_quantize_rowwise": 0.13130158185958862, "w_quantize_rowwise": 0.03298744559288025, "w_quantize_colwise_transpose": 0.23539364337921143, "w_quantize_global": 0.09393692016601562, "w_quantize_global_transpose": 0.10208785533905029, "cast_x": 0.9952597320079803, "cast_g": 0.25385990738868713, "cast_w": 0.05589798092842102, "time_standard": 8.72712954878807, "time_rowwise": 6.826821714639664, "time_global": 6.664466112852097} +{"repeat": 64, "batch_size": 65536, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 6.449159234762192, "standard_gw": 6.384443491697311, "standard_gx": 5.543403327465057, "rowwise_fwd": 3.3065229654312134, "rowwise_bwd": 2.6249960064888, "global_fwd": 3.2497718930244446, "global_bwd": 2.6061534881591797, "x_quantize_rowwise": 0.25821104645729065, "g_quantize_rowwise": 0.9981803596019745, "w_quantize_rowwise": 0.030606985092163086, "w_quantize_colwise_transpose": 0.24094432592391968, "w_quantize_global": 0.09358301758766174, "w_quantize_global_transpose": 0.10264664888381958, "cast_x": 0.5018562078475952, "cast_g": 1.9840113818645477, "cast_w": 0.05584210157394409, "time_standard": 18.37700605392456, "time_rowwise": 13.843905180692673, "time_global": 13.692989945411682} +{"repeat": 64, "batch_size": 65536, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 5.508493632078171, "standard_gw": 5.689781159162521, "standard_gx": 6.020743399858475, "rowwise_fwd": 2.640843391418457, "rowwise_bwd": 3.3075474202632904, "global_fwd": 2.605751156806946, "global_bwd": 3.2674334943294525, "x_quantize_rowwise": 0.9983181953430176, "g_quantize_rowwise": 0.25597214698791504, "w_quantize_rowwise": 0.03277510404586792, "w_quantize_colwise_transpose": 0.23587048053741455, "w_quantize_global": 0.09367987513542175, "w_quantize_global_transpose": 0.10236725211143494, "cast_x": 1.9848868250846863, "cast_g": 0.5010329186916351, "cast_w": 0.055771321058273315, "time_standard": 17.219018191099167, "time_rowwise": 13.161107897758484, "time_global": 13.013303279876709} +{"repeat": 64, "batch_size": 131072, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 12.975204735994339, "standard_gw": 11.424731463193893, "standard_gx": 11.05477660894394, "rowwise_fwd": 6.623122841119766, "rowwise_bwd": 5.253363400697708, "global_fwd": 6.506938487291336, "global_bwd": 5.211424082517624, "x_quantize_rowwise": 0.5057789385318756, "g_quantize_rowwise": 1.9870363175868988, "w_quantize_rowwise": 0.030517578125, "w_quantize_colwise_transpose": 0.24361908435821533, "w_quantize_global": 0.09384006261825562, "w_quantize_global_transpose": 0.10285153985023499, "cast_x": 0.9967051446437836, "cast_g": 3.9620958268642426, "cast_w": 0.05599111318588257, "time_standard": 35.45471280813217, "time_rowwise": 26.068169623613358, "time_global": 25.83260089159012} +{"repeat": 64, "batch_size": 131072, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 11.05555146932602, "standard_gw": 11.32136583328247, "standard_gx": 12.035444378852844, "rowwise_fwd": 5.243867635726929, "rowwise_bwd": 6.622854620218277, "global_fwd": 5.209986120462418, "global_bwd": 6.507329642772675, "x_quantize_rowwise": 1.9862838089466095, "g_quantize_rowwise": 0.506080687046051, "w_quantize_rowwise": 0.03318488597869873, "w_quantize_colwise_transpose": 0.23682788014411926, "w_quantize_global": 0.09349361062049866, "w_quantize_global_transpose": 0.1023709774017334, "cast_x": 3.962486982345581, "cast_g": 0.9956248104572296, "cast_w": 0.05572289228439331, "time_standard": 34.412361681461334, "time_rowwise": 25.950465351343155, "time_global": 25.726910680532455} diff --git a/tests/triton_tests/info_mlp.jsonl b/tests/triton_tests/info_mlp.jsonl new file mode 100644 index 0000000..a2076ee --- /dev/null +++ b/tests/triton_tests/info_mlp.jsonl @@ -0,0 +1,20 @@ +{"repeat": 32, "batch_size": 16384, "dim": 1024, "standard": 3.807276487350464, "my_standard": 4.196919500827789, "standard_compiled": 3.771558403968811, "sb": 3.5132691264152527} +{"repeat": 32, "batch_size": 32768, "dim": 1024, "standard": 7.215872406959534, "my_standard": 7.991522550582886, "standard_compiled": 7.241688668727875, "sb": 6.581142544746399} +{"repeat": 32, "batch_size": 65536, "dim": 1024, "standard": 14.26444947719574, "my_standard": 15.685759484767914, "standard_compiled": 14.251746237277985, "sb": 12.735314667224884} +{"repeat": 32, "batch_size": 131072, "dim": 1024, "standard": 28.49559485912323, "my_standard": 31.26966953277588, "standard_compiled": 28.414390981197357, "sb": 25.319166481494904} +{"repeat": 32, "batch_size": 16384, "dim": 1280, "standard": 5.887262523174286, "my_standard": 6.132654845714569, "standard_compiled": 5.902409553527832, "sb": 4.947789013385773} +{"repeat": 32, "batch_size": 32768, "dim": 1280, "standard": 11.14131510257721, "my_standard": 12.859955430030823, "standard_compiled": 11.133037507534027, "sb": 9.303092956542969} +{"repeat": 32, "batch_size": 65536, "dim": 1280, "standard": 22.193141281604767, "my_standard": 25.66336840391159, "standard_compiled": 22.22583442926407, "sb": 18.285617232322693} +{"repeat": 32, "batch_size": 131072, "dim": 1280, "standard": 44.23898458480835, "my_standard": 51.30268633365631, "standard_compiled": 44.08355802297592, "sb": 35.999126732349396} +{"repeat": 32, "batch_size": 16384, "dim": 1408, "standard": 6.938718259334564, "my_standard": 7.269218564033508, "standard_compiled": 6.94604218006134, "sb": 5.764961242675781} +{"repeat": 32, "batch_size": 32768, "dim": 1408, "standard": 13.04878294467926, "my_standard": 13.742901384830475, "standard_compiled": 13.011425733566284, "sb": 10.774023830890656} +{"repeat": 32, "batch_size": 65536, "dim": 1408, "standard": 26.738539338111877, "my_standard": 27.739346027374268, "standard_compiled": 26.75659954547882, "sb": 21.882005035877228} +{"repeat": 32, "batch_size": 131072, "dim": 1408, "standard": 51.905401051044464, "my_standard": 53.98637801408768, "standard_compiled": 51.8316924571991, "sb": 41.67725890874863} +{"repeat": 32, "batch_size": 16384, "dim": 1664, "standard": 9.233824908733368, "my_standard": 9.619377553462982, "standard_compiled": 9.214423596858978, "sb": 7.557623088359833} +{"repeat": 32, "batch_size": 32768, "dim": 1664, "standard": 17.324909567832947, "my_standard": 17.996780574321747, "standard_compiled": 17.29544997215271, "sb": 14.035224914550781} +{"repeat": 32, "batch_size": 65536, "dim": 1664, "standard": 35.51657497882843, "my_standard": 36.674730479717255, "standard_compiled": 35.43049842119217, "sb": 28.38330715894699} +{"repeat": 32, "batch_size": 131072, "dim": 1664, "standard": 69.0087378025055, "my_standard": 71.56594842672348, "standard_compiled": 68.82885098457336, "sb": 54.01633679866791} +{"repeat": 32, "batch_size": 16384, "dim": 2048, "standard": 12.590140104293823, "my_standard": 13.106442987918854, "standard_compiled": 12.606985867023468, "sb": 10.286301374435425} +{"repeat": 32, "batch_size": 32768, "dim": 2048, "standard": 24.830535054206848, "my_standard": 25.563716888427734, "standard_compiled": 24.895809590816498, "sb": 19.559212028980255} +{"repeat": 32, "batch_size": 65536, "dim": 2048, "standard": 49.55078661441803, "my_standard": 51.16480588912964, "standard_compiled": 49.739621579647064, "sb": 38.29141706228256} +{"repeat": 32, "batch_size": 131072, "dim": 2048, "standard": 98.36294502019882, "my_standard": 102.69322991371155, "standard_compiled": 98.76712411642075, "sb": 75.88706165552139} diff --git a/tests/triton_tests/info_mlp_autocast.jsonl b/tests/triton_tests/info_mlp_autocast.jsonl new file mode 100644 index 0000000..f2098cc --- /dev/null +++ b/tests/triton_tests/info_mlp_autocast.jsonl @@ -0,0 +1,20 @@ +{"repeat": 32, "batch_size": 16384, "dim": 1024, "standard": 4.91420179605484, "my_standard": 5.577877163887024, "standard_compiled": 4.810944199562073, "sb": 4.512995481491089} +{"repeat": 32, "batch_size": 32768, "dim": 1024, "standard": 8.876129984855652, "my_standard": 10.154612362384796, "standard_compiled": 8.820965886116028, "sb": 8.367843925952911} +{"repeat": 32, "batch_size": 65536, "dim": 1024, "standard": 17.47015118598938, "my_standard": 19.857674837112427, "standard_compiled": 17.338842153549194, "sb": 15.992552042007446} +{"repeat": 32, "batch_size": 131072, "dim": 1024, "standard": 34.824438393116, "my_standard": 39.499424397945404, "standard_compiled": 34.56207364797592, "sb": 31.573951244354248} +{"repeat": 32, "batch_size": 16384, "dim": 1280, "standard": 7.342606782913208, "my_standard": 7.9323723912239075, "standard_compiled": 7.279552519321442, "sb": 6.395488977432251} +{"repeat": 32, "batch_size": 32768, "dim": 1280, "standard": 13.69999349117279, "my_standard": 16.0503089427948, "standard_compiled": 13.603456318378448, "sb": 11.813104152679443} +{"repeat": 32, "batch_size": 65536, "dim": 1280, "standard": 29.557034373283386, "my_standard": 34.2303067445755, "standard_compiled": 29.382556676864624, "sb": 22.882774472236633} +{"repeat": 32, "batch_size": 131072, "dim": 1280, "standard": 53.629085421562195, "my_standard": 63.07622790336609, "standard_compiled": 53.33048850297928, "sb": 44.76426541805267} +{"repeat": 32, "batch_size": 16384, "dim": 1408, "standard": 8.81417840719223, "my_standard": 9.477965533733368, "standard_compiled": 8.73943418264389, "sb": 7.479414343833923} +{"repeat": 32, "batch_size": 32768, "dim": 1408, "standard": 16.242466866970062, "my_standard": 17.616644501686096, "standard_compiled": 16.14125818014145, "sb": 13.665586709976196} +{"repeat": 32, "batch_size": 65536, "dim": 1408, "standard": 32.429613173007965, "my_standard": 34.80646014213562, "standard_compiled": 32.319076359272, "sb": 27.123987674713135} +{"repeat": 32, "batch_size": 131072, "dim": 1408, "standard": 62.85770237445831, "my_standard": 67.55391508340836, "standard_compiled": 62.453076243400574, "sb": 51.53566598892212} +{"repeat": 32, "batch_size": 16384, "dim": 1664, "standard": 11.585861444473267, "my_standard": 12.565858662128448, "standard_compiled": 11.504307389259338, "sb": 9.657211601734161} +{"repeat": 32, "batch_size": 32768, "dim": 1664, "standard": 21.261662244796753, "my_standard": 22.771358489990234, "standard_compiled": 21.12410217523575, "sb": 17.64291524887085} +{"repeat": 32, "batch_size": 65536, "dim": 1664, "standard": 42.85307973623276, "my_standard": 45.70870101451874, "standard_compiled": 42.57970303297043, "sb": 34.918561577796936} +{"repeat": 32, "batch_size": 131072, "dim": 1664, "standard": 83.56057852506638, "my_standard": 89.11971747875214, "standard_compiled": 83.05662125349045, "sb": 66.32210314273834} +{"repeat": 32, "batch_size": 16384, "dim": 2048, "standard": 15.7279372215271, "my_standard": 16.854502260684967, "standard_compiled": 15.655294060707092, "sb": 13.228952884674072} +{"repeat": 32, "batch_size": 32768, "dim": 2048, "standard": 30.42648732662201, "my_standard": 32.26502239704132, "standard_compiled": 30.239209532737732, "sb": 24.354808032512665} +{"repeat": 32, "batch_size": 65536, "dim": 2048, "standard": 60.779355466365814, "my_standard": 64.11923468112946, "standard_compiled": 60.89268624782562, "sb": 46.91776633262634} +{"repeat": 32, "batch_size": 131072, "dim": 2048, "standard": 119.93677169084549, "my_standard": 128.19699943065643, "standard_compiled": 120.20225822925568, "sb": 92.3452153801918} diff --git a/tests/triton_tests/info_mlp_autocast_ln.jsonl b/tests/triton_tests/info_mlp_autocast_ln.jsonl new file mode 100644 index 0000000..706f949 --- /dev/null +++ b/tests/triton_tests/info_mlp_autocast_ln.jsonl @@ -0,0 +1,23 @@ +{"repeat": 32, "batch_size": 16384, "dim": 1024, "standard": 5.171686410903931, "my_standard": 5.839601159095764, "standard_compiled": 5.032263696193695, "sb": 4.89344447851181} +{"repeat": 32, "batch_size": 32768, "dim": 1024, "standard": 9.605035185813904, "my_standard": 10.910414159297943, "standard_compiled": 9.230785071849823, "sb": 9.128175675868988} +{"repeat": 32, "batch_size": 65536, "dim": 1024, "standard": 18.802084028720856, "my_standard": 21.311581134796143, "standard_compiled": 18.105976283550262, "sb": 17.489850521087646} +{"repeat": 32, "batch_size": 131072, "dim": 1024, "standard": 37.49683499336243, "my_standard": 42.40527004003525, "standard_compiled": 36.13145649433136, "sb": 34.58733111619949} +{"repeat": 32, "batch_size": 16384, "dim": 1280, "standard": 7.709823548793793, "my_standard": 8.290477097034454, "standard_compiled": 7.564418017864227, "sb": 6.8823546171188354} +{"repeat": 32, "batch_size": 32768, "dim": 1280, "standard": 14.64156061410904, "my_standard": 16.996942460536957, "standard_compiled": 14.4081711769104, "sb": 12.761622667312622} +{"repeat": 32, "batch_size": 65536, "dim": 1280, "standard": 31.40200674533844, "my_standard": 36.074504256248474, "standard_compiled": 30.981406569480896, "sb": 24.76389706134796} +{"repeat": 32, "batch_size": 131072, "dim": 1280, "standard": 56.93405121564865, "my_standard": 66.35250151157379, "standard_compiled": 56.07586354017258, "sb": 48.49743843078613} +{"repeat": 32, "batch_size": 16384, "dim": 1408, "standard": 9.188003838062286, "my_standard": 9.84550267457962, "standard_compiled": 9.006097912788391, "sb": 7.9473331570625305} +{"repeat": 32, "batch_size": 32768, "dim": 1408, "standard": 17.268165946006775, "my_standard": 18.64910125732422, "standard_compiled": 16.983114182949066, "sb": 14.70106840133667} +{"repeat": 32, "batch_size": 65536, "dim": 1408, "standard": 34.39047932624817, "my_standard": 36.69705241918564, "standard_compiled": 33.8401272892952, "sb": 29.188089072704315} +{"repeat": 32, "batch_size": 131072, "dim": 1408, "standard": 66.70494377613068, "my_standard": 71.27603143453598, "standard_compiled": 65.56134670972824, "sb": 55.6538850069046} +{"repeat": 32, "batch_size": 16384, "dim": 1664, "standard": 12.10707426071167, "my_standard": 12.931793928146362, "standard_compiled": 11.76995038986206, "sb": 10.228671133518219} +{"repeat": 32, "batch_size": 32768, "dim": 1664, "standard": 22.5130096077919, "my_standard": 23.962542414665222, "standard_compiled": 21.997176110744476, "sb": 18.89890432357788} +{"repeat": 32, "batch_size": 65536, "dim": 1664, "standard": 45.210108160972595, "my_standard": 47.94136434793472, "standard_compiled": 44.2262664437294, "sb": 37.37735003232956} +{"repeat": 32, "batch_size": 131072, "dim": 1664, "standard": 88.1955549120903, "my_standard": 93.6831533908844, "standard_compiled": 86.33609116077423, "sb": 71.23208791017532} +{"repeat": 32, "batch_size": 16384, "dim": 2048, "standard": 16.538940370082855, "my_standard": 17.607316374778748, "standard_compiled": 16.108587384223938, "sb": 14.030493795871735} +{"repeat": 32, "batch_size": 32768, "dim": 2048, "standard": 31.795650720596313, "my_standard": 33.57230871915817, "standard_compiled": 31.04180097579956, "sb": 25.971196591854095} +{"repeat": 32, "batch_size": 65536, "dim": 2048, "standard": 63.021354377269745, "my_standard": 66.8477788567543, "standard_compiled": 61.682507395744324, "sb": 50.138771533966064} +{"repeat": 32, "batch_size": 131072, "dim": 2048, "standard": 125.17062574625015, "my_standard": 133.60925763845444, "standard_compiled": 122.21191823482513, "sb": 98.40084612369537} +{"repeat": 32, "batch_size": 16384, "dim": 4096, "standard": 57.31645971536636, "my_standard": 60.84543466567993, "standard_compiled": 55.78199774026871, "sb": 45.43223977088928} +{"repeat": 32, "batch_size": 32768, "dim": 4096, "standard": 111.80306226015091, "my_standard": 119.0284714102745, "standard_compiled": 108.91905426979065, "sb": 85.4572057723999} +{"repeat": 32, "batch_size": 65536, "dim": 4096, "standard": 220.4471081495285, "my_standard": 233.0927476286888, "standard_compiled": 214.26431089639664, "sb": 163.30372542142868} diff --git a/tests/triton_tests/make_plot_with_info.py b/tests/triton_tests/make_plot_with_info.py new file mode 100644 index 0000000..116d1d1 --- /dev/null +++ b/tests/triton_tests/make_plot_with_info.py @@ -0,0 +1,137 @@ +import matplotlib.pyplot as plt +import pandas as pd +import numpy as np +import os + +import matplotlib.gridspec as gridspec + +cmap=plt.get_cmap('cool') + +if __name__ == '__main__': + + fig = plt.figure(tight_layout=True, figsize=(12,3.5)) + gs = gridspec.GridSpec(1, 2) + + + ax = fig.add_subplot(gs[0, 0]) + + rdf = pd.read_json('tests/triton_tests/info.jsonl', lines=True) + df = rdf[rdf.batch_size == 32768] + + for k, marker, ls, color, name in [ + ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'), + ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'), + + ('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'), + ('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'), + ('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'), + + ('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'), + ('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'), + + #### time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd'] + + ('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'), + ('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'), + ('w_quantize_global', '.', '--', 'C4', 'Quatnize global W (switchback)'), + ('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize gloabl and\ntranspose W (switchback)'), + #('standard_gw', '.', '--', 'C1', 'standard_gw'), + ]: + xs = [] + ys = [] + for embed_dim in [1024, 1280, 1408, 1664, 2048, 4096]: + df_ = df[df.dim_in == embed_dim] + df_ = df_[df_.dim_out == embed_dim * 4] + xs.append(embed_dim) + y_ = 0 + for k_ in k.split('+'): + y_ += df_[k_].values[0] + df_ = df[df.dim_in == embed_dim * 4] + df_ = df_[df_.dim_out == embed_dim] + for k_ in k.split('+'): + y_ += df_[k_].values[0] + ys.append(y_ * 0.5) + + + ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.) + + + + + ax.set_xlabel('dim', fontsize=13) + ax.set_ylabel('time (ms)', fontsize=13) + # make a legend which is below the plot + + + + ax.grid() + + ax.set_xscale('log') + #ax.set_yscale('log') + + ax.tick_params(axis='x', labelsize=11) + ax.tick_params(axis='y', labelsize=11) + + ax.set_xticks([1024, 2048, 4096]) + ax.set_xticklabels([1024, 2048, 4096]) + ax.set_xticks([], minor=True) + + leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10) + leg.get_texts()[0].set_fontweight('bold') + leg.get_texts()[1].set_fontweight('bold') + plt.subplots_adjust(left=0.1) + ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20) + + + ax = fig.add_subplot(gs[0, 1]) + + # now plot the % speedup for different batch sizes + for j, batch_size in enumerate([2**14, 2**15, 2**16, 2**17]): + all_xs, all_ys = [], [] + for k, marker, ls, color, name in [ + ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'), + ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'), + ]: + + xs, ys = [], [] + df = rdf[rdf.batch_size == batch_size] + for embed_dim in [1024, 1280, 1408, 1664, 2048, 4096]: + df_ = df[df.dim_in == embed_dim] + df_ = df_[df_.dim_out == embed_dim * 4] + xs.append(embed_dim) + y_ = 0 + for k_ in k.split('+'): + y_ += df_[k_].values[0] + df_ = df[df.dim_in == embed_dim * 4] + df_ = df_[df_.dim_out == embed_dim] + for k_ in k.split('+'): + y_ += df_[k_].values[0] + ys.append(y_ * 0.5) + all_xs.append(xs) + all_ys.append(ys) + + color = cmap(j * 0.25) + real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))] + markers = ['^', 'v', 'P', 'o'] + ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5) + + ax.legend() + ax.set_xlabel('dim', fontsize=13) + ax.set_xscale('log') + ax.grid() + ax.set_ylabel(r'% speedup', fontsize=13) + + + ax.tick_params(axis='x', labelsize=11) + ax.tick_params(axis='y', labelsize=11) + + ax.set_xticks([1024, 2048, 4096]) + ax.set_xticklabels([1024, 2048, 4096]) + ax.set_xticks([], minor=True) + + ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20) + + + + plt.savefig('tests/triton_tests/plot1.pdf', bbox_inches='tight') + diff --git a/tests/triton_tests/mlp.py b/tests/triton_tests/mlp.py new file mode 100644 index 0000000..1ec85b8 --- /dev/null +++ b/tests/triton_tests/mlp.py @@ -0,0 +1,64 @@ + +import time +import torch +import torch.nn as nn +import bitsandbytes.nn as bnn +from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, MyLinear + +def construct_model(dim, layers, module): + modules = [] + for _ in range(layers): + modules.append(module(dim, 4*dim)) + modules.append(module(4*dim, dim)) + return nn.Sequential(*modules).cuda().train() + +def get_time(model, x, name): + for _ in range(repeat // 2): + #with torch.cuda.amp.autocast(): + out = model(x) + #(2**16 * out.pow(2).mean()).backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + # with torch.cuda.amp.autocast(): + out = model(x) + #(2**16 * out.pow(2).mean()).backward() + + torch.cuda.synchronize() + end = time.time() + print(f"time {name}: {(end - start) / repeat * 1000:.3f} ms") + +if __name__ == '__main__': + torch.manual_seed(0) + + # hparams + repeat = 16 + dim=2048 + layers =4 + batch_size = 2 + sequence_length = 2**15 + + # construct models + standard = construct_model(dim, layers, nn.Linear).half() + my_standard = construct_model(dim, layers, MyLinear).half() + switchback = construct_model(dim, layers, SwitchBackLinear).half() + switchback_global = construct_model(dim, layers, SwitchBackGlobalLinear).half() + #bnb_8bitmixed = construct_model(dim, layers, bnn.Linear8bitLt) + + # simulate forward pass + x = torch.randn(batch_size * sequence_length, dim, dtype=torch.float16).cuda() + + # get time for forward and backward + get_time(standard, x, "standard") + get_time(my_standard, x, "my_standard") + get_time(switchback, x, "switchback") + get_time(switchback_global, x, "switchback_global") + #get_time(bnb_8bitmixed, x, "bnb_8bitmixed") + + + + + + + \ No newline at end of file diff --git a/tests/triton_tests/mlp_decomp_autocast.py b/tests/triton_tests/mlp_decomp_autocast.py new file mode 100644 index 0000000..3a1fc9e --- /dev/null +++ b/tests/triton_tests/mlp_decomp_autocast.py @@ -0,0 +1,166 @@ + +import torch +import json +from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear +import time + +if __name__ == '__main__': + + print('Startin') + + + for dim in [1024, 1280, 1408, 1664, 2048]: + for batch in [2**14, 2**15, 2**16, 2**17]: + + if dim != 4096 or batch != 2**17: + continue + + + x1 = torch.randn(batch, dim).cuda().requires_grad_(True) + d = 2 + + standard = torch.nn.Sequential( + torch.nn.Linear(dim, 4 * dim), + torch.nn.GELU(), + torch.nn.Linear(4 * dim, dim), + ).cuda() + + my_standard = torch.nn.Sequential( + MyLinear(dim, 4 * dim), + torch.nn.GELU(), + MyLinear(4 * dim, dim), + ).cuda() + + fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda() + + sb = torch.nn.Sequential( + SwitchBackGlobalLinear(dim, 4 * dim), + torch.nn.GELU(), + SwitchBackGlobalLinear(4 * dim, dim), + ).cuda() + + standard_compiled = torch.compile(standard) + + print('Model part 2') + + repeat = 32 + + + info = {'repeat' : repeat, 'batch_size' : batch, 'dim' : dim} + + # k = 'standard' + # for _ in range(repeat // 2): + # with torch.cuda.amp.autocast(): + # out_standard = standard(x1) + # ((2 ** 16) * out_standard).abs().mean().backward() + + # torch.cuda.synchronize() + # start = time.time() + # for _ in range(repeat): + # with torch.cuda.amp.autocast(): + # out_standard = standard(x1) + # ((2 ** 16) * out_standard).abs().mean().backward() + + # torch.cuda.synchronize() + # end = time.time() + # ms = (end - start) / repeat * 1000 + # print(f"time {k}: {ms:.3f} ms") + # info[k] = ms + + + # x1.grad.zero_() + + # k = 'my_standard' + # for _ in range(repeat // 2): + # with torch.cuda.amp.autocast(): + # out_my_standard = my_standard(x1) + # ((2 ** 16) * out_my_standard).abs().mean().backward() + + # torch.cuda.synchronize() + # start = time.time() + # for _ in range(repeat): + # with torch.cuda.amp.autocast(): + # out_my_standard = my_standard(x1) + # ((2 ** 16) * out_my_standard).abs().mean().backward() + + # torch.cuda.synchronize() + # end = time.time() + # ms = (end - start) / repeat * 1000 + # print(f"time {k}: {ms:.3f} ms") + # info[k] = ms + + # x1.grad.zero_() + + # k = 'standard_compiled' + # for _ in range(repeat // 2): + # with torch.cuda.amp.autocast(): + # out_standard_compiled = standard_compiled(x1) + # ((2 ** 16) * out_standard_compiled).abs().mean().backward() + + # torch.cuda.synchronize() + # start = time.time() + # for _ in range(repeat): + # with torch.cuda.amp.autocast(): + # out_standard_compiled = standard_compiled(x1) + # ((2 ** 16) * out_standard_compiled).abs().mean().backward() + + # torch.cuda.synchronize() + # end = time.time() + # ms = (end - start) / repeat * 1000 + # print(f"time {k}: {ms:.3f} ms") + # info[k] = ms + + # x1.grad.zero_() + + k = 'sb' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_sb = sb(x1) + ((2 ** 16) * out_sb).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_sb = sb(x1) + ((2 ** 16) * out_sb).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + info_json = json.dumps(info) + + + with open("tests/triton_tests/info_mlp_autocast.jsonl", "a") as file: + file.write(info_json + "\n") + + + #exit() + + # err_fused = (out_standard - out_fused).abs().mean() + # err_sb = (out_standard - out_sb).abs().mean() + # print('OUT', err_fused, err_sb) + + # err_fused = (standard[d].weight.grad - fused_mlp.linear2.weight.grad).abs().mean() + # err_sb = (standard[d].weight.grad - sb[d].weight.grad).abs().mean() + + # print('GW2', err_fused, err_sb) + + # err_fused = (standard[0].weight.grad - fused_mlp.linear1.weight.grad).abs().mean() + # err_sb = (standard[0].weight.grad - sb[0].weight.grad).abs().mean() + + # print('GW1', err_fused, err_sb) + + # err_fused = (x1.grad - x2.grad).abs().mean() + # err_sb = (x1.grad - x3.grad).abs().mean() + + # print('GX1', err_fused, err_sb) + + # import pdb; pdb.set_trace() + + + # # NO GELU, ST GRADIENTS, EVERYTHING FINE. \ No newline at end of file diff --git a/tests/triton_tests/mlp_decomp_autocast_ln.py b/tests/triton_tests/mlp_decomp_autocast_ln.py new file mode 100644 index 0000000..2596278 --- /dev/null +++ b/tests/triton_tests/mlp_decomp_autocast_ln.py @@ -0,0 +1,165 @@ + +import torch +import json +from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear +import time + +if __name__ == '__main__': + + print('Startin') + + + for dim in [1024, 1280, 1408, 1664, 2048]: + for batch in [2**14, 2**15, 2**16, 2**17]: + + x1 = torch.randn(batch, dim).cuda().requires_grad_(True) + d = 2 + + standard = torch.nn.Sequential( + torch.nn.LayerNorm(dim), + torch.nn.Linear(dim, 4 * dim), + torch.nn.GELU(), + torch.nn.Linear(4 * dim, dim), + ).cuda() + + my_standard = torch.nn.Sequential( + torch.nn.LayerNorm(dim), + MyLinear(dim, 4 * dim), + torch.nn.GELU(), + MyLinear(4 * dim, dim), + ).cuda() + + fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda() + + sb = torch.nn.Sequential( + torch.nn.LayerNorm(dim), + SwitchBackGlobalLinear(dim, 4 * dim), + torch.nn.GELU(), + SwitchBackGlobalLinear(4 * dim, dim), + ).cuda() + + standard_compiled = torch.compile(standard) + + print('Model part 2') + + repeat = 32 + + + info = {'repeat' : repeat, 'batch_size' : batch, 'dim' : dim} + + k = 'standard' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_standard = standard(x1) + ((2 ** 16) * out_standard).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_standard = standard(x1) + ((2 ** 16) * out_standard).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + x1.grad.zero_() + + k = 'my_standard' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_my_standard = my_standard(x1) + ((2 ** 16) * out_my_standard).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_my_standard = my_standard(x1) + ((2 ** 16) * out_my_standard).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + x1.grad.zero_() + + k = 'standard_compiled' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_standard_compiled = standard_compiled(x1) + ((2 ** 16) * out_standard_compiled).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_standard_compiled = standard_compiled(x1) + ((2 ** 16) * out_standard_compiled).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + x1.grad.zero_() + + k = 'sb' + for _ in range(repeat // 2): + with torch.cuda.amp.autocast(): + out_sb = sb(x1) + ((2 ** 16) * out_sb).abs().mean().backward() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + with torch.cuda.amp.autocast(): + out_sb = sb(x1) + ((2 ** 16) * out_sb).abs().mean().backward() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info[k] = ms + + + info_json = json.dumps(info) + + + with open("tests/triton_tests/info_mlp_autocast_ln.jsonl", "a") as file: + file.write(info_json + "\n") + + + #exit() + + # err_fused = (out_standard - out_fused).abs().mean() + # err_sb = (out_standard - out_sb).abs().mean() + # print('OUT', err_fused, err_sb) + + # err_fused = (standard[d].weight.grad - fused_mlp.linear2.weight.grad).abs().mean() + # err_sb = (standard[d].weight.grad - sb[d].weight.grad).abs().mean() + + # print('GW2', err_fused, err_sb) + + # err_fused = (standard[0].weight.grad - fused_mlp.linear1.weight.grad).abs().mean() + # err_sb = (standard[0].weight.grad - sb[0].weight.grad).abs().mean() + + # print('GW1', err_fused, err_sb) + + # err_fused = (x1.grad - x2.grad).abs().mean() + # err_sb = (x1.grad - x3.grad).abs().mean() + + # print('GX1', err_fused, err_sb) + + # import pdb; pdb.set_trace() + + + # # NO GELU, ST GRADIENTS, EVERYTHING FINE. \ No newline at end of file diff --git a/tests/triton_tests/plot1.pdf b/tests/triton_tests/plot1.pdf new file mode 100644 index 0000000..1fe7168 Binary files /dev/null and b/tests/triton_tests/plot1.pdf differ diff --git a/tests/triton_tests/plot1.png b/tests/triton_tests/plot1.png new file mode 100644 index 0000000..794c869 Binary files /dev/null and b/tests/triton_tests/plot1.png differ diff --git a/tests/triton_tests/plot2.pdf b/tests/triton_tests/plot2.pdf new file mode 100644 index 0000000..56b835e Binary files /dev/null and b/tests/triton_tests/plot2.pdf differ diff --git a/tests/triton_tests/plot2.png b/tests/triton_tests/plot2.png new file mode 100644 index 0000000..94659c0 Binary files /dev/null and b/tests/triton_tests/plot2.png differ diff --git a/tests/triton_tests/plot2.py b/tests/triton_tests/plot2.py new file mode 100644 index 0000000..d433548 --- /dev/null +++ b/tests/triton_tests/plot2.py @@ -0,0 +1,69 @@ +import matplotlib.pyplot as plt +import pandas as pd +import numpy as np +import os + +import matplotlib.gridspec as gridspec + +cmap=plt.get_cmap('cool') + +if __name__ == '__main__': + + fig = plt.figure(tight_layout=True, figsize=(6,3.5)) + gs = gridspec.GridSpec(1, 1) + + + rdf = pd.read_json('tests/triton_tests/info.jsonl', lines=True) + + ax = fig.add_subplot(gs[0, 0]) + + # now plot the % speedup for different batch sizes + for j, batch_size in enumerate([2**14, 2**15, 2**16, 2**17]): + all_xs, all_ys = [], [] + for k, marker, ls, color, name in [ + ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'), + ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose', 'o', '-', 'C4', 'SwitchBack int8 (total time)'), + ]: + + xs, ys = [], [] + df = rdf[rdf.batch_size == batch_size] + for embed_dim in [1024, 1280, 1408, 1664, 2048, 4096]: + df_ = df[df.dim_in == embed_dim] + df_ = df_[df_.dim_out == embed_dim * 4] + xs.append(embed_dim) + y_ = 0 + for k_ in k.split('+'): + y_ += df_[k_].values[0] + df_ = df[df.dim_in == embed_dim * 4] + df_ = df_[df_.dim_out == embed_dim] + for k_ in k.split('+'): + y_ += df_[k_].values[0] + ys.append(y_ * 0.5) + all_xs.append(xs) + all_ys.append(ys) + + color = cmap(j * 0.25) + real_ys = [100 * all_ys[1][i] / all_ys[0][i] for i in range(len(all_ys[0]))] + markers = ['^', 'v', 'P', 'o'] + ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5) + + ax.legend() + ax.set_xlabel('dim', fontsize=13) + ax.set_xscale('log') + ax.grid() + ax.set_ylabel(r'% time occupied by quantize ops', fontsize=12) + + + ax.tick_params(axis='x', labelsize=11) + ax.tick_params(axis='y', labelsize=11) + + ax.set_xticks([1024, 2048, 4096]) + ax.set_xticklabels([1024, 2048, 4096]) + ax.set_xticks([], minor=True) + + #ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20) + + + + plt.savefig('tests/triton_tests/plot2.pdf', bbox_inches='tight') + diff --git a/tests/triton_tests/plot3.pdf b/tests/triton_tests/plot3.pdf new file mode 100644 index 0000000..19e93a2 Binary files /dev/null and b/tests/triton_tests/plot3.pdf differ diff --git a/tests/triton_tests/plot3.png b/tests/triton_tests/plot3.png new file mode 100644 index 0000000..e83178d Binary files /dev/null and b/tests/triton_tests/plot3.png differ diff --git a/tests/triton_tests/plot3.py b/tests/triton_tests/plot3.py new file mode 100644 index 0000000..beaa811 --- /dev/null +++ b/tests/triton_tests/plot3.py @@ -0,0 +1,193 @@ +import matplotlib.pyplot as plt +import pandas as pd +import numpy as np +import os +import matplotlib.lines as mlines +import matplotlib.gridspec as gridspec + +cmap=plt.get_cmap('cool') + +if __name__ == '__main__': + + fig = plt.figure(tight_layout=True, figsize=(12,3.5)) + gs = gridspec.GridSpec(1, 3) + + + rdf1 = pd.read_json('tests/triton_tests/info_mlp_autocast_ln.jsonl', lines=True) + + ax = fig.add_subplot(gs[0, 0]) + + # now plot the % speedup for different batch sizes + for j, batch_size in enumerate([2**15, 2**17]):#, 2**15, 2**17, 2**17]): + all_xs, all_ys = {}, {} + for k, marker, ls, color, name in [ + ('standard_compiled', 'o', '-', 'C0', 'standard compiled (total time)'), + #('standard', 'o', '-', 'C1', 'standard (total time)'), + ('my_standard', 'o', '-', 'C2', 'my standard (total time)'), + ('sb', 'o', '-', 'C4', 'SwitchBack int8 (total time)'), + ]: + + xs, ys = [], [] + df = rdf1[rdf1.batch_size == batch_size] + for embed_dim in [1024, 1280, 1408, 1664, 2048]: + df_ = df[df.dim == embed_dim] + xs.append(embed_dim) + y_ = 0 + for k_ in k.split('+'): + y_ += df_[k_].values[0] + ys.append(y_) + + all_xs[k] = xs + all_ys[k] = ys + #ax.plot(xs, ys, color=color, label=f'batch * sequence length = {batch_size}', marker=marker, markersize=5 if marker=='s' else 5) + + + color= cmap(float(j)) + speedup_over_my_standard = [-100 * (all_ys['sb'][i] - all_ys['my_standard'][i]) / all_ys['my_standard'][i] for i in range(len(all_ys['my_standard']))] + speedup_over_compile = [-100 * (all_ys['sb'][i] - all_ys['standard_compiled'][i]) / all_ys['standard_compiled'][i] for i in range(len(all_ys['standard_compiled']))] + + ax.plot(xs, speedup_over_my_standard, color=color, label=f'batch * sequence length = {batch_size}', marker='o', markersize=5 if marker=='s' else 5) + ax.plot(xs, speedup_over_compile, color=color, label=f'batch * sequence length = {batch_size}', marker='o', markersize=5 if marker=='s' else 5, linestyle='--') + + + #ax.legend() + ax.set_xlabel('dim', fontsize=13) + ax.set_xscale('log') + ax.grid() + ax.set_ylabel(r'% speedup', fontsize=12) + + ax.tick_params(axis='x', labelsize=11) + ax.tick_params(axis='y', labelsize=11) + + ax.set_xticks([1024, 2048]) + ax.set_xticklabels([1024, 2048]) + ax.set_xticks([], minor=True) + ax.set_title('MLP Block', fontsize=10, loc='left', y=1.07, pad=-20) + + + ########################################## + + rdf2 = pd.read_json('tests/triton_tests/attn_info_ln.jsonl', lines=True) + + ax = fig.add_subplot(gs[0, 1]) + + for j, batch_size in enumerate([2**15, 2**17]):#, 2**15, 2**17, 2**17]): + all_xs, all_ys = {}, {} + for k, marker, ls, color, name in [ + ('standard_compiled', 'o', '-', 'C0', 'standard compiled (total time)'), + #('standard', 'o', '-', 'C1', 'standard (total time)'), + ('my_standard', 'o', '-', 'C2', 'my standard (total time)'), + ('sb', 'o', '-', 'C4', 'SwitchBack int8 (total time)'), + ]: + + xs, ys = [], [] + df = rdf2[rdf2.batch_size == batch_size] + for embed_dim in [1024, 1280, 1408, 1664, 2048]: + df_ = df[df.dim == embed_dim] + xs.append(embed_dim) + y_ = 0 + for k_ in k.split('+'): + y_ += df_[k_].values[0] + ys.append(y_) + + all_xs[k] = xs + all_ys[k] = ys + #ax.plot(xs, ys, color=color, label=f'batch * sequence length = {batch_size}', marker=marker, markersize=5 if marker=='s' else 5) + + color= cmap(float(j)) + speedup_over_my_standard = [-100 * (all_ys['sb'][i] - all_ys['my_standard'][i]) / all_ys['my_standard'][i] for i in range(len(all_ys['my_standard']))] + speedup_over_compile = [-100 * (all_ys['sb'][i] - all_ys['standard_compiled'][i]) / all_ys['standard_compiled'][i] for i in range(len(all_ys['standard_compiled']))] + + ax.plot(xs, speedup_over_my_standard, color=color, label=f'batch * sequence length = {batch_size}', marker='o', markersize=5 if marker=='s' else 5) + ax.plot(xs, speedup_over_compile, color=color, label=f'batch * sequence length = {batch_size}', marker='o', markersize=5 if marker=='s' else 5, linestyle='--') + + + speedup_compiled = mlines.Line2D([], [], linestyle='--', color='gray', label='speedup over compiled') + speedup_baseline = mlines.Line2D([], [], linestyle='-', color='gray', label='speedup over baseline') + batch_size_4 = mlines.Line2D([], [], linestyle='-', color=cmap(0.), label=f'batch = {int(2**15 // 256)}, sequence = {256}') + batch_size_8 = mlines.Line2D([], [], linestyle='-', color=cmap(1.), label=f'batch = {int(2**17 / 256)} sequence = {256}') + + # Create the legend with the proxy artists + + # adjust plots so that they dont get squished by putting the legend under both + + + plt.subplots_adjust(left=0.2) + plt.subplots_adjust(right=0.8) + + fig.legend(handles=[speedup_compiled, speedup_baseline, batch_size_4, batch_size_8], ncol=2, loc='upper center', bbox_to_anchor=(0.35, 0.255)) + + ax.set_xlabel('dim', fontsize=13) + ax.set_xscale('log') + ax.grid() + ax.set_ylabel(r'% speedup', fontsize=12) + + ax.tick_params(axis='x', labelsize=11) + ax.tick_params(axis='y', labelsize=11) + + ax.set_xticks([1024, 2048]) + ax.set_xticklabels([1024, 2048]) + ax.set_xticks([], minor=True) + + ax.set_title('Attention Block', fontsize=10, loc='left', y=1.07, pad=-20) + + + + ########################################## + + + + ax = fig.add_subplot(gs[0, 2]) + + for j, batch_size in enumerate([2**15]):#, 2**15, 2**17, 2**17]): + all_xs, all_ys = {}, {} + for k, marker, ls, color, name, b in [ + ('standard_compiled', 'o', '-', 'C0', 'standard compiled (total time)', False), + ('standard_compiled', 'o', '-', 'C0', 'standard compiled (total time)', True), + + #('standard', 'o', '-', 'C1', 'standard (total time)'), + #('my_standard', 'o', '-', 'C2', 'my standard (total time)'), + ('attn', 'o', '-', 'C4', 'SwitchBack int8 (total time)', True), + ]: + rdf = rdf2 if b else rdf1 + + xs, ys = [], [] + df = rdf[rdf.batch_size == batch_size] + for embed_dim in [1024, 1280, 1408, 1664, 2048]: + df_ = df[df.dim == embed_dim] + xs.append(embed_dim) + y_ = 0 + for k_ in k.split('+'): + y_ += df_[k_].values[0] + ys.append(y_) + + all_xs[k + str(int(b))] = xs + all_ys[k + str(int(b))] = ys + #ax.plot(xs, ys, color=color, label=f'batch * sequence length = {batch_size}', marker=marker, markersize=5 if marker=='s' else 5) + + + print(all_ys.keys()) + all_ys['standard_compiled'] = [x + y for x, y in zip(all_ys['standard_compiled0'], all_ys['standard_compiled1'])] + + speedup_over_my_standard = [100 * all_ys['attn1'][i] / (all_ys['standard_compiled'][i] + all_ys['attn1'][i]) for i in range(len(all_ys['standard_compiled']))] + ax.plot(xs, speedup_over_my_standard, color='gold', label=r'% time occupied by attention', marker='H', markersize=8) + + speedup_over_my_standard = [100 * all_ys['standard_compiled1'][i] / (all_ys['standard_compiled0'][i] + all_ys['standard_compiled1'][i]) for i in range(len(all_ys['standard_compiled']))] + ax.plot(xs, speedup_over_my_standard, color='indianred', label=r'% time occupied by attention block', marker='P', markersize=8) + + + ax.legend(bbox_to_anchor=(1.02, -0.27)) + ax.set_xlabel('dim', fontsize=13) + ax.set_xscale('log') + ax.grid() + ax.set_ylabel(r'% time', fontsize=12) + + ax.tick_params(axis='x', labelsize=11) + ax.tick_params(axis='y', labelsize=11) + + ax.set_xticks([1024, 2048]) + ax.set_xticklabels([1024, 2048]) + ax.set_xticks([], minor=True) + + plt.savefig('tests/triton_tests/plot3.pdf', bbox_inches='tight') + diff --git a/tests/triton_tests/rowwise.py b/tests/triton_tests/rowwise.py new file mode 100644 index 0000000..c5acb8e --- /dev/null +++ b/tests/triton_tests/rowwise.py @@ -0,0 +1,43 @@ + +import time +import torch +import torch +import torch.nn as nn +import bitsandbytes.nn as bnn +from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear + +from bitsandbytes.nn.triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup + + +# 256 * 256 * 4096 _> 0.7 +# 256 * 128 * 8192 -> 10 +if __name__ == '__main__': + torch.manual_seed(0) + + # hparams + repeat = 16 + dim=8192 + layers = 4 + + batch_size = 256 * 128 + + # simulate forward pass + x = torch.randn(batch_size, dim, dtype=torch.float16).cuda() + + for _ in range(repeat // 2): + quantize_rowwise_nogroup(x) + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + quantize_rowwise_nogroup(x) + torch.cuda.synchronize() + end = time.time() + + print(f"time: {(end - start) / repeat * 1000:.3f} ms") + + + + + + \ No newline at end of file