cleaning and refactor

This commit is contained in:
Mitchell Wortsman 2023-04-01 18:46:04 +00:00
parent 30d21d585c
commit 7f87ba83ee
33 changed files with 420 additions and 2514 deletions

View File

@ -3,4 +3,4 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2, Linear8bitLtMixed, LinearFP8Global, LinearFP4, LinearFP8Mixed from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2, Linear8bitLtMixed, LinearFP8Global, LinearFP4, LinearFP8Mixed
from .triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorized, StandardLinear

View File

@ -1,26 +1,76 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import time import time
from functools import partial
from .triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup from .triton_utils.v0.quantize_rowwise import quantize_rowwise
from .triton_utils.v0.quantize_columnwise_nogroup_transpose import quantize_columnwise_nogroup_transpose from .triton_utils.v0.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from .triton_utils.v0.int8_matmul_rowwise_dequantize_bias import int8_matmul_rowwise_dequantize_bias
from .triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize from .triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from .triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose from .triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose
from .triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze, int8_matmul_mixed_dequanitze_bias from .triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
from .triton_utils.v0.fused_gelu_quantize import quantize_rowwise_nogroup_gelu, quantize_rowwise_nogroup_back_gelu
class _switchback(torch.autograd.Function):
class _switchback_global(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, X_3D, W, bias): def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D]
X = X_3D.view(-1, X_3D.size(-1))
# rowwise quantize for X, global quantize for W
X_int8, state_X = quantize_rowwise(X)
W_int8, state_W = quantize_global(W)
# save for backward.
ctx.save_for_backward = X, W
# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequanitze(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)
@staticmethod
def backward(ctx, G_3D):
# reshape input to [N_out * L, D]
G = G_3D.reshape(-1, G_3D.size(-1))
grad_X = grad_W = grad_bias = None
X, W = ctx.save_for_backward
if ctx.needs_input_grad[0]:
# rowwise quantize for G, global quantize for W
# for W, we also fuse the transpose operation because only A @ B^T is supported
# so we transpose once then call .t() in the matmul
G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_global_transpose(W)
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1
)
if ctx.needs_input_grad[1]:
# backward pass uses standard weight grad
grad_W = torch.matmul(G.t(), X.to(G.dtype))
if ctx.needs_input_grad[2]:
grad_bias = G.sum(dim=0)
return grad_X, grad_W, grad_bias
class _switchback_vectorrize(torch.autograd.Function):
@staticmethod
def forward(ctx, X_3D, W, bias):
# reshape input to [N * L, D]
X = X_3D.view(-1, X_3D.size(-1)) X = X_3D.view(-1, X_3D.size(-1))
ctx.save_for_backward = X, W ctx.save_for_backward = X, W
X_int8, state_X = quantize_rowwise_nogroup(X) # rowwise quantize for X
W_int8, state_W = quantize_rowwise_nogroup(W) # columnwise quantize for W (first rowwise, transpose later)
return int8_matmul_rowwise_dequantize_bias( X_int8, state_X = quantize_rowwise(X)
W_int8, state_W = quantize_rowwise(W)
# matmult, fused dequant and add bias
# call kernel which expects rowwise quantized X and W
return int8_matmul_rowwise_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1) ).view(*X_3D.size()[:-1], -1)
@ -33,12 +83,15 @@ class _switchback(torch.autograd.Function):
grad_X = grad_W = grad_bias = None grad_X = grad_W = grad_bias = None
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
G_int8, state_G = quantize_rowwise_nogroup(G) # rowwise quantize for G, columnwise quantize for W and fused transpose
W_int8, state_W = quantize_columnwise_nogroup_transpose(W) # we call .t() for weight later because only A @ B^T is supported
grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W).view( G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_columnwise_and_transpose(W)
grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1 *G_3D.size()[:-1], -1
) )
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
# backward pass uses standard weight grad
grad_W = torch.matmul(G.t(), X.to(G.dtype)) grad_W = torch.matmul(G.t(), X.to(G.dtype))
if ctx.needs_input_grad[2]: if ctx.needs_input_grad[2]:
grad_bias = G.sum(dim=0) grad_bias = G.sum(dim=0)
@ -46,11 +99,37 @@ class _switchback(torch.autograd.Function):
return grad_X, grad_W, grad_bias return grad_X, grad_W, grad_bias
class SwitchBackLinear(nn.Linear): class SwitchBackLinear(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
vectorize: bool = False
):
super().__init__(in_features, out_features, bias, device, dtype)
# By default, we use the global quantization.
self.vectorize = vectorize
if self.vectorize:
self._fn = _switchback_vectorrize
else:
self._fn = _switchback_global
def prepare_for_eval(self): def prepare_for_eval(self):
state_W = self.weight.abs().max(dim=1, keepdim=True)[0] # If we just want to do eval, we can pre-quantize the weights instead of doing it on the forward pass.
W_int8 = (127 * self.weight.float() / state_W).round().to(torch.int8) # Note this is experimental and not tested thoroughly.
state_W = state_W.squeeze() # Note this needs to be explicitly called with something like
# def cond_prepare(m):
# if hasattr(m, "prepare_for_eval"):
# m.prepare_for_eval()
# model.apply(cond_prepare)
print('=> preparing for eval.')
if self.vectorize:
W_int8, state_W = quantize_rowwise(self.weight)
else:
W_int8, state_W = quantize_global(self.weight)
self.register_buffer("W_int8", W_int8) self.register_buffer("W_int8", W_int8)
self.register_buffer("state_W", state_W) self.register_buffer("state_W", state_W)
@ -59,80 +138,29 @@ class SwitchBackLinear(nn.Linear):
def forward(self, x): def forward(self, x):
if self.training: if self.training:
return _switchback.apply(x, self.weight, self.bias) return self._fn.apply(x, self.weight, self.bias)
else: else:
if not hasattr(self, "state_W"): # If it hasn't been "prepared for eval", run the standard forward pass.
self.prepare_for_eval() if not hasattr(self, "W_int8"):
return self._fn.apply(x, self.weight, self.bias)
# Otherwise, use pre-computed weights.
X = x.view(-1, x.size(-1)) X = x.view(-1, x.size(-1))
X_int8, state_X = quantize_rowwise_nogroup(X) X_int8, state_X = quantize_rowwise(X)
return int8_matmul_rowwise_dequantize_bias(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
class _switchback_global(torch.autograd.Function):
@staticmethod
def forward(ctx, X_3D, W, bias):
X = X_3D.view(-1, X_3D.size(-1))
X_int8, state_X = quantize_rowwise_nogroup(X)
W_int8, state_W = quantize_global(W)
ctx.save_for_backward = X, W
return int8_matmul_mixed_dequanitze_bias(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)
@staticmethod
def backward(ctx, G_3D):
G = G_3D.reshape(-1, G_3D.size(-1))
grad_X = grad_W = grad_bias = None
X, W = ctx.save_for_backward
if ctx.needs_input_grad[0]:
G_int8, state_G = quantize_rowwise_nogroup(G)
W_int8, state_W = quantize_global_transpose(W)
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W).view(
*G_3D.size()[:-1], -1
)
if ctx.needs_input_grad[1]:
grad_W = torch.matmul(G.t(), X.to(G.dtype))
if ctx.needs_input_grad[2]:
grad_bias = G.sum(dim=0)
return grad_X, grad_W, grad_bias
class SwitchBackGlobalLinear(nn.Linear):
def prepare_for_eval(self):
state_W = self.weight.abs().max()
W_int8 = (127 * self.weight.float() / state_W).round().to(torch.int8)
self.register_buffer("W_int8", W_int8)
self.register_buffer("state_W", state_W)
del self.weight
def forward(self, x):
if self.training:
return _switchback_global.apply(x, self.weight, self.bias)
else:
if not hasattr(self, "state_W"):
self.prepare_for_eval()
X = x.view(-1, x.size(-1))
X_int8, state_X = quantize_rowwise_nogroup(X)
return int8_matmul_mixed_dequanitze_bias(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
if self.vectorize:
return int8_matmul_rowwise_dequantize(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
else:
return int8_matmul_mixed_dequanitze(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
SwitchBackLinearGlobal = partial(SwitchBackLinear, vectorize=False)
SwitchBackLinearVectorized = partial(SwitchBackLinear, vectorize=True)
# This is just the standard linear function.
class StandardLinearFunction(torch.autograd.Function): class StandardLinearFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, input, weight, bias=None): def forward(ctx, input, weight, bias=None):

View File

@ -1,190 +0,0 @@
import math
import torch
import time
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
tl.libdevice
# TODO: autotune this better.
@triton.autotune(
configs=[
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
)
@triton.jit
def _quantize_rowwise_nogroup_gelu(
x_ptr,
output_ptr,
output_maxs,
output_fp16,
n_elements,
BLOCK_SIZE: tl.constexpr,
P2: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
arange = tl.arange(0, P2)
offsets = block_start + arange
row_mask = arange < BLOCK_SIZE
x = tl.load(x_ptr + offsets, mask=row_mask)
cdf = 0.5 * (1.0 + tl.libdevice.tanh(x * 0.7978845608 * (1 + 0.044715 * x * x)))
x_new = x * cdf
tl.store(output_fp16 + offsets, x_new, mask=row_mask)
abs_x = tl.abs(x_new)
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127. * (x_new / max_val))
tl.store(output_ptr + offsets, output, mask=row_mask)
tl.store(output_maxs + pid, max_val)
def quantize_rowwise_nogroup_gelu(x: torch.Tensor):
output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
output_fp16 = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (x.shape[0],)
_quantize_rowwise_nogroup_gelu[grid](x, output, output_maxs, output_fp16, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
return output, output_maxs, output_fp16
# TODO: autotune this better.
@triton.autotune(
configs=[
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
)
@triton.jit
def _quantize_rowwise_nogroup_back_gelu(
x_ptr,
in_ptr,
output_ptr,
output_maxs,
output_fp16,
n_elements,
BLOCK_SIZE: tl.constexpr,
P2: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
arange = tl.arange(0, P2)
offsets = block_start + arange
row_mask = arange < BLOCK_SIZE
x_out = tl.load(x_ptr + offsets, mask=row_mask)
x_in = tl.load(in_ptr + offsets, mask=row_mask)
cdf = 0.5 * (1.0 + tl.libdevice.tanh(x_in * 0.7978845608 * (1 + 0.044715 * x_in * x_in)))
intermediate = tl.libdevice.tanh(x_in * 0.7978845608 * (1 + 0.044715 * x_in * x_in))
dcdf = 0.5 * (0.7978845608 + 0.1070322243 * x_in * x_in) * (1 - intermediate * intermediate)
x = x_out * (cdf + x_in * dcdf)
tl.store(output_fp16 + offsets, x, mask=row_mask)
abs_x = tl.abs(x)
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127. * (x / max_val))
tl.store(output_ptr + offsets, output, mask=row_mask)
tl.store(output_maxs + pid, max_val)
def quantize_rowwise_nogroup_back_gelu(x: torch.Tensor, y : torch.Tensor):
output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
output_fp16 = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (x.shape[0],)
_quantize_rowwise_nogroup_back_gelu[grid](x, y, output, output_maxs, output_fp16, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
return output, output_maxs, output_fp16
# if __name__ == '__main__':
# torch.manual_seed(0)
# x = torch.randn(1280, 768).cuda().to(torch.float16)
# out = quantize_rowwise_nogroup(x)
# x_real = (127 * x.float() / x.abs().max(dim=1, keepdim=True)[0]).round().to(torch.int8)
# max2 = x.abs().max(1)[0]
# print(torch.allclose(out[1], max2))
# print( (x_real == out[0]).float().mean() )
# # for i in range(x.shape[0]):
# # print( (x_real[i, :] == out[0][i, :]).float().mean() )
# # print(out[0])
# # print(x_real)
# # import pdb; pdb.set_trace()
# # print(out[2])
# # print(out[2][:10])
# sums = x.sum(dim=0)
# #print(sums[:10])
# #print( (sums == out[2]).float().mean() )
# import pdb; pdb.set_trace()
# # import pdb; pdb.set_trace()
# # exit()
# # repeat = 16
# # for _ in range(8):
# # out = quantize_rowwise_nogroup(x)
# # triton_graph = torch.cuda.CUDAGraph()
# # with torch.cuda.graph(triton_graph):
# # out = quantize_rowwise_nogroup(x)
# # triton_graph.replay()
# # torch.cuda.synchronize()
# # start = time.time()
# # for _ in range(repeat):
# # triton_graph.replay()
# # torch.cuda.synchronize()
# # end = time.time()
# # print(out[0])
# # print(out[1])
# # print(x / x.abs().max(dim=1, keepdim=True)[0])
# # max1 = out[1]
# # max2 = x.abs().max(1)[0]
# # print(max1, max2)
# # print(torch.allclose(max1, max2))
# #print(f"time: {(end - start) / repeat * 1000:.3f} ms")

View File

@ -5,10 +5,14 @@ import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# This is a matmul kernel based on triton.ops.matmul
# It is modified to support rowwise quantized input and global quantized weight
# It's purpose is fused matmul then dequantize
# It does support bias.
def init_to_zero(name): def init_to_zero(name):
return lambda nargs: nargs[name].zero_() return lambda nargs: nargs[name].zero_()
def get_configs_io_bound(): def get_configs_io_bound():
configs = [] configs = []
for num_stages in [2, 3, 4, 5, 6]: for num_stages in [2, 3, 4, 5, 6]:
@ -60,130 +64,7 @@ def get_configs_io_bound():
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
}) })
@triton.jit @triton.jit
def _kernel(A, B, C, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
w_factor = tl.load(state_w_ptr)
x_factor = tl.load(state_x_ptr + ram)[:, None]
# acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
acc += tl.dot(a, b)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = (w_factor * (x_factor * (acc * divfactor)))
acc = acc.to(C.dtype.element_ty)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else:
tl.atomic_add(C, acc, mask=mask)
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w):
device = a.device
divfactor = 1. / (127. * 127.)
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
_, N = b.shape
# allocates output
c = torch.empty((M, N), device=device, dtype=torch.float16)
# accumulator types
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, state_x, state_w, M, N, K, divfactor,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
GROUP_M=8, ACC_TYPE=ACC_TYPE)
return c
@triton.autotune(
configs=[
# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
] + get_configs_io_bound(),
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _kernel_bias(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr,
stride_am, stride_ak, stride_am, stride_ak,
stride_bk, stride_bn, stride_bk, stride_bn,
stride_cm, stride_cn, stride_cm, stride_cn,
@ -236,6 +117,7 @@ def _kernel_bias(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl
acc = (w_factor * (x_factor * (acc * divfactor))) acc = (w_factor * (x_factor * (acc * divfactor)))
acc = acc.to(C.dtype.element_ty) acc = acc.to(C.dtype.element_ty)
# conditionally add bias
if has_bias: if has_bias:
bias = tl.load(bias + rn).to(C.dtype.element_ty) bias = tl.load(bias + rn).to(C.dtype.element_ty)
acc = acc + bias[None, :] acc = acc + bias[None, :]
@ -249,7 +131,7 @@ def _kernel_bias(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl
tl.atomic_add(C, acc, mask=mask) tl.atomic_add(C, acc, mask=mask)
def int8_matmul_mixed_dequanitze_bias(a, b, state_x, state_w, bias): def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
device = a.device device = a.device
divfactor = 1. / (127. * 127.) divfactor = 1. / (127. * 127.)
has_bias = 0 if bias is None else 1 has_bias = 0 if bias is None else 1
@ -266,9 +148,9 @@ def int8_matmul_mixed_dequanitze_bias(a, b, state_x, state_w, bias):
c = torch.empty((M, N), device=device, dtype=torch.float16) c = torch.empty((M, N), device=device, dtype=torch.float16)
# accumulator types # accumulator types
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch kernel # launch int8_matmul_mixed_dequantize kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel_bias[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
a.stride(0), a.stride(1), a.stride(0), a.stride(1),
b.stride(0), b.stride(1), b.stride(0), b.stride(1),
c.stride(0), c.stride(1), c.stride(0), c.stride(1),

View File

@ -4,6 +4,10 @@ import triton
import triton.language as tl import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# This is a matmul kernel based on triton.ops.matmul
# It is modified to support rowwise quantized input and columnwise quantized weight
# It's purpose is fused matmul then dequantize
# It does support bias.
def init_to_zero(name): def init_to_zero(name):
return lambda nargs: nargs[name].zero_() return lambda nargs: nargs[name].zero_()
@ -60,7 +64,7 @@ def get_configs_io_bound():
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
}) })
@triton.jit @triton.jit
def _kernel(A, B, C, state_x_ptr, state_w_ptr, M, N, K, divfactor, def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
stride_am, stride_ak, stride_am, stride_ak,
stride_bk, stride_bn, stride_bk, stride_bn,
stride_cm, stride_cn, stride_cm, stride_cn,
@ -113,6 +117,10 @@ def _kernel(A, B, C, state_x_ptr, state_w_ptr, M, N, K, divfactor,
acc = (w_factor * (x_factor * (acc * divfactor))) acc = (w_factor * (x_factor * (acc * divfactor)))
acc = acc.to(C.dtype.element_ty) acc = acc.to(C.dtype.element_ty)
if has_bias:
bias = tl.load(bias + rn).to(C.dtype.element_ty)
acc = acc + bias[None, :]
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :] mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting # handles write-back with reduction-splitting
@ -122,9 +130,11 @@ def _kernel(A, B, C, state_x_ptr, state_w_ptr, M, N, K, divfactor,
tl.atomic_add(C, acc, mask=mask) tl.atomic_add(C, acc, mask=mask)
def int8_matmul_rowwise_dequantize(a, b, state_x, state_w): def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
divfactor = 1. / (127. * 127.) divfactor = 1. / (127. * 127.)
has_bias = 0 if bias is None else 1
device = a.device device = a.device
# handle non-contiguous inputs if necessary # handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1: if a.stride(0) > 1 and a.stride(1) > 1:
@ -139,9 +149,9 @@ def int8_matmul_rowwise_dequantize(a, b, state_x, state_w):
c = torch.empty((M, N), device=device, dtype=torch.float16) c = torch.empty((M, N), device=device, dtype=torch.float16)
# accumulator types # accumulator types
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch kernel # launch int8_matmul_rowwise_dequantize kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, state_x, state_w, M, N, K, divfactor, _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
a.stride(0), a.stride(1), a.stride(0), a.stride(1),
b.stride(0), b.stride(1), b.stride(0), b.stride(1),
c.stride(0), c.stride(1), c.stride(0), c.stride(1),

View File

@ -1,160 +0,0 @@
import torch
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
def init_to_zero(name):
return lambda nargs: nargs[name].zero_()
def get_configs_io_bound():
configs = []
for num_stages in [2, 3, 4, 5, 6]:
for block_m in [16, 32]:
for block_k in [32, 64]:
for block_n in [32, 64, 128, 256]:
num_warps = 2 if block_n <= 64 else 4
configs.append(
triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
num_stages=num_stages, num_warps=num_warps))
# split_k
for split_k in [2, 4, 8, 16]:
configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
return configs
@triton.autotune(
configs=[
# basic configs for compute-bound matmuls
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
# good for int8
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
] + get_configs_io_bound(),
key=['M', 'N', 'K'],
prune_configs_by={
'early_config_prune': early_config_prune,
'perf_model': estimate_matmul_time,
'top_k': 10
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _kernel(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
w_factor = tl.load(state_w_ptr + rbn)[None, :]
x_factor = tl.load(state_x_ptr + ram)[:, None]
# acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.)
acc += tl.dot(a, b)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = (w_factor * (x_factor * (acc * divfactor)))
acc = acc.to(C.dtype.element_ty)
if has_bias:
bias = tl.load(bias + rn).to(C.dtype.element_ty)
acc = acc + bias[None, :]
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else:
tl.atomic_add(C, acc, mask=mask)
def int8_matmul_rowwise_dequantize_bias(a, b, state_x, state_w, bias):
#print(bias)
divfactor = 1. / (127. * 127.)
has_bias = 0 if bias is None else 1
if bias is not None:
bias = bias.contiguous()
device = a.device
# handle non-contiguous inputs if necessary
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[1] == b.shape[0], "incompatible dimensions"
M, K = a.shape
_, N = b.shape
# allocates output
c = torch.empty((M, N), device=device, dtype=torch.float16)
# accumulator types
ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
_kernel[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
GROUP_M=8, ACC_TYPE=ACC_TYPE)
return c

View File

@ -5,6 +5,8 @@ import triton
import triton.language as tl import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# This kernel does fused columnwise quantization and transpose.
# TODO: autotune this better. # TODO: autotune this better.
@triton.autotune( @triton.autotune(
configs=[ configs=[
@ -26,7 +28,7 @@ from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_tim
key=['n_elements'] key=['n_elements']
) )
@triton.jit @triton.jit
def _quantize_columnwise_nogroup_transpose( def _quantize_columnwise_and_transpose(
x_ptr, x_ptr,
output_ptr, output_ptr,
output_maxs, output_maxs,
@ -51,7 +53,7 @@ def _quantize_columnwise_nogroup_transpose(
tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask) tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
tl.store(output_maxs + pid, max_val) tl.store(output_maxs + pid, max_val)
def quantize_columnwise_nogroup_transpose(x: torch.Tensor): def quantize_columnwise_and_transpose(x: torch.Tensor):
M, N = x.shape M, N = x.shape
output = torch.empty(N, M, device=x.device, dtype=torch.int8) output = torch.empty(N, M, device=x.device, dtype=torch.int8)
output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16) output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)
@ -61,62 +63,6 @@ def quantize_columnwise_nogroup_transpose(x: torch.Tensor):
assert x.is_cuda and output.is_cuda assert x.is_cuda and output.is_cuda
n_elements = output.numel() n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
_quantize_columnwise_nogroup_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2) _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
return output, output_maxs return output, output_maxs
if __name__ == '__main__':
torch.manual_seed(0)
x = torch.randn(1280, 768).cuda().to(torch.float16)
out = quantize_columnwise_nogroup_transpose(x)
x_real = x.t().float()
x_real_int8 = (127. * x_real / x_real.abs().max(dim=1, keepdim=True)[0]).round().to(torch.int8)
maxs = x_real.abs().max(dim=1, keepdim=True)[0].half()
#print(out[0][2,:])
print((out[0] == x_real_int8).float().mean())
print((out[1] == maxs[:, 0]).float().mean())
# print(out[0])
# print(out[1])
# print(out[0][2,:])
# print(x_real[2, :])
# print((out[0] != x_real).nonzero())
#import pdb; pdb.set_trace()
# repeat = 16
# for _ in range(8):
# out = quantize_columnwise_nogroup_transpose(x)
# triton_graph = torch.cuda.CUDAGraph()
# with torch.cuda.graph(triton_graph):
# out = quantize_columnwise_nogroup_transpose(x)
# triton_graph.replay()
# torch.cuda.synchronize()
# start = time.time()
# for _ in range(repeat):
# triton_graph.replay()
# torch.cuda.synchronize()
# end = time.time()
# print(out[0])
# print(out[1])
# print(x / x.abs().max(dim=0, keepdim=True)[0])
# x_real = (127 * (x / x.abs().max(dim=0, keepdim=True)[0])).round().to(torch.int8)
# max1 = out[1]
# max2 = x.abs().max(0)[0]
# print(max1, max2)
# import pdb; pdb.set_trace()
# print(torch.allclose(max1, max2))
# print(f"time: {(end - start) / repeat * 1000:.3f} ms")

View File

@ -5,7 +5,7 @@ import triton
import triton.language as tl import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# TODO: autotune this better. # global quantize
@triton.autotune( @triton.autotune(
configs=[ configs=[
triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4), triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
@ -42,6 +42,7 @@ def quantize_global(x: torch.Tensor):
return output, absmax return output, absmax
# global quantize and transpose
@triton.autotune( @triton.autotune(
configs=[ configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
@ -97,34 +98,3 @@ def quantize_global_transpose(input):
_quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N) _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N)
return out, absmax return out, absmax
if __name__ == '__main__':
w = torch.randn(768, 1280).cuda().to(torch.float16)
W_int8, state_w = quantize_global(w)
r_state_w = w.abs().max()
r_W_int8 = ((127 * w.float()) / state_w).round().to(torch.int8)
print((r_W_int8 == W_int8).float().mean())
# print(r_W_int8)
# print(W_int8)
exit()
repeat = 16
for _ in range(8):
out = quantize_global(w)
triton_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(triton_graph):
out = quantize_global(w)
triton_graph.replay()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
triton_graph.replay()
torch.cuda.synchronize()
end = time.time()
print(f"time: {(end - start) / repeat * 1000:.3f} ms")

View File

@ -0,0 +1,61 @@
import math
import torch
import time
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# rowwise quantize
# TODO: autotune this better.
@triton.autotune(
configs=[
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
)
@triton.jit
def _quantize_rowwise(
x_ptr,
output_ptr,
output_maxs,
n_elements,
BLOCK_SIZE: tl.constexpr,
P2: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
arange = tl.arange(0, P2)
offsets = block_start + arange
row_mask = arange < BLOCK_SIZE
x = tl.load(x_ptr + offsets, mask=row_mask)
abs_x = tl.abs(x)
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127. * (x / max_val))
tl.store(output_ptr + offsets, output, mask=row_mask)
tl.store(output_maxs + pid, max_val)
def quantize_rowwise(x: torch.Tensor):
output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (x.shape[0],)
_quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
return output, output_maxs

View File

@ -1,174 +0,0 @@
import math
import torch
import time
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
# TODO: autotune this better.
@triton.autotune(
configs=[
triton.Config({}, num_stages=1, num_warps=8),
triton.Config({}, num_stages=2, num_warps=8),
triton.Config({}, num_stages=4, num_warps=8),
triton.Config({}, num_stages=8, num_warps=8),
triton.Config({}, num_stages=1),
triton.Config({}, num_stages=2),
triton.Config({}, num_stages=4),
triton.Config({}, num_stages=8),
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
)
@triton.jit
def _quantize_rowwise_nogroup(
x_ptr,
output_ptr,
output_maxs,
n_elements,
BLOCK_SIZE: tl.constexpr,
P2: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
arange = tl.arange(0, P2)
offsets = block_start + arange
row_mask = arange < BLOCK_SIZE
x = tl.load(x_ptr + offsets, mask=row_mask)
abs_x = tl.abs(x)
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127. * (x / max_val))
tl.store(output_ptr + offsets, output, mask=row_mask)
tl.store(output_maxs + pid, max_val)
def quantize_rowwise_nogroup(x: torch.Tensor):
output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (x.shape[0],)
_quantize_rowwise_nogroup[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
return output, output_maxs
@triton.autotune(
configs=[
triton.Config({}, num_warps=1),
triton.Config({}, num_warps=2),
triton.Config({}, num_warps=4),
triton.Config({}, num_warps=8),
],
key=['n_elements']
)
@triton.jit
def _experimental_quantize_rowwise_nogroup(
x_ptr,
output_ptr,
bias_grad_ptr,
output_maxs,
n_elements,
M: tl.constexpr, N: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
P2: tl.constexpr,
P2M: tl.constexpr,
):
pid = tl.program_id(axis=0)
if pid < M:
block_start = pid * BLOCK_SIZE
arange = tl.arange(0, P2)
offsets = block_start + arange
row_mask = arange < BLOCK_SIZE
x = tl.load(x_ptr + offsets, mask=row_mask)
abs_x = tl.abs(x)
max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
output = tl.libdevice.llrint(127. * (x / max_val))
tl.store(output_ptr + offsets, output, mask=row_mask)
tl.store(output_maxs + pid, max_val)
else:
real_pid = pid - M
arange_new = tl.arange(0, P2M)
mask_new = arange_new < M
offsets_new = real_pid + arange_new * N
new_x = tl.load(x_ptr + offsets_new, mask=mask_new)
s = tl.sum(tl.where(mask_new, new_x, 0).to(tl.float32), axis=0)
tl.store(bias_grad_ptr + real_pid, s)
def experimental_quantize_rowwise_nogroup(x: torch.Tensor):
M, N = x.shape
output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
bias_grad = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)
P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
P2M = int(2 ** (math.ceil(math.log2(x.shape[0]))))
assert x.is_cuda and output.is_cuda
n_elements = output.numel()
grid = lambda meta: (x.shape[0] + x.shape[1],)
_experimental_quantize_rowwise_nogroup[grid](x, output, bias_grad, output_maxs, n_elements, M, N, BLOCK_SIZE=x.shape[1], P2=P2, P2M=P2M)
return output, output_maxs, bias_grad
if __name__ == '__main__':
torch.manual_seed(0)
x = torch.randn(1280, 768).cuda().to(torch.float16)
out = quantize_rowwise_nogroup(x)
x_real = (127 * x.float() / x.abs().max(dim=1, keepdim=True)[0]).round().to(torch.int8)
max2 = x.abs().max(1)[0]
print(torch.allclose(out[1], max2))
print( (x_real == out[0]).float().mean() )
# for i in range(x.shape[0]):
# print( (x_real[i, :] == out[0][i, :]).float().mean() )
# print(out[0])
# print(x_real)
# import pdb; pdb.set_trace()
# print(out[2])
# print(out[2][:10])
sums = x.sum(dim=0)
#print(sums[:10])
#print( (sums == out[2]).float().mean() )
import pdb; pdb.set_trace()
# import pdb; pdb.set_trace()
# exit()
# repeat = 16
# for _ in range(8):
# out = quantize_rowwise_nogroup(x)
# triton_graph = torch.cuda.CUDAGraph()
# with torch.cuda.graph(triton_graph):
# out = quantize_rowwise_nogroup(x)
# triton_graph.replay()
# torch.cuda.synchronize()
# start = time.time()
# for _ in range(repeat):
# triton_graph.replay()
# torch.cuda.synchronize()
# end = time.time()
# print(out[0])
# print(out[1])
# print(x / x.abs().max(dim=1, keepdim=True)[0])
# max1 = out[1]
# max2 = x.abs().max(1)[0]
# print(max1, max2)
# print(torch.allclose(max1, max2))
#print(f"time: {(end - start) / repeat * 1000:.3f} ms")

View File

@ -0,0 +1,60 @@
{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.28139352798461914, "standard_gw": 0.2811811864376068, "standard_gx": 0.30258670449256897, "rowwise_fwd": 0.1994594931602478, "rowwise_bwd": 0.16159191727638245, "global_fwd": 0.19502267241477966, "global_bwd": 0.16080215573310852, "x_quantize_rowwise": 0.03306940197944641, "g_quantize_rowwise": 0.08210167288780212, "w_quantize_rowwise": 0.03385916352272034, "w_quantize_colwise_transpose": 0.08635595440864563, "w_quantize_global": 0.09237229824066162, "w_quantize_global_transpose": 0.10007619857788086, "time_standard": 0.8651614189147949, "time_rowwise": 0.8776187896728516, "time_global": 0.944625586271286}
{"repeat": 64, "batch_size": 8192, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.262625515460968, "standard_gw": 0.2806223928928375, "standard_gx": 0.31118839979171753, "rowwise_fwd": 0.1828707754611969, "rowwise_bwd": 0.21236762404441833, "global_fwd": 0.16665831208229065, "global_bwd": 0.19929558038711548, "x_quantize_rowwise": 0.08227676153182983, "g_quantize_rowwise": 0.03310292959213257, "w_quantize_rowwise": 0.032648444175720215, "w_quantize_colwise_transpose": 0.09015202522277832, "w_quantize_global": 0.0988692045211792, "w_quantize_global_transpose": 0.10057538747787476, "time_standard": 0.8544363081455231, "time_rowwise": 0.9140409529209137, "time_global": 0.96140056848526}
{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.5731917917728424, "standard_gw": 0.5709454417228699, "standard_gx": 0.5963630974292755, "rowwise_fwd": 0.37662312388420105, "rowwise_bwd": 0.281747430562973, "global_fwd": 0.36768242716789246, "global_bwd": 0.28043612837791443, "x_quantize_rowwise": 0.046547502279281616, "g_quantize_rowwise": 0.15532970428466797, "w_quantize_rowwise": 0.032436102628707886, "w_quantize_colwise_transpose": 0.08635222911834717, "w_quantize_global": 0.0947415828704834, "w_quantize_global_transpose": 0.10129809379577637, "time_standard": 1.7405003309249878, "time_rowwise": 1.5499815344810486, "time_global": 1.616980880498886}
{"repeat": 64, "batch_size": 16384, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.5341619253158569, "standard_gw": 0.5690865218639374, "standard_gx": 0.599835067987442, "rowwise_fwd": 0.3233291208744049, "rowwise_bwd": 0.41359663009643555, "global_fwd": 0.2831108868122101, "global_bwd": 0.37280842661857605, "x_quantize_rowwise": 0.15563145279884338, "g_quantize_rowwise": 0.046741217374801636, "w_quantize_rowwise": 0.03306940197944641, "w_quantize_colwise_transpose": 0.09020790457725525, "w_quantize_global": 0.0925213098526001, "w_quantize_global_transpose": 0.09945780038833618, "time_standard": 1.7030835151672363, "time_rowwise": 1.6316622495651245, "time_global": 1.6193576157093048}
{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 1.2199915945529938, "standard_gw": 1.1069811880588531, "standard_gx": 1.09761580824852, "rowwise_fwd": 0.738043338060379, "rowwise_bwd": 0.5549229681491852, "global_fwd": 0.7219798862934113, "global_bwd": 0.5512163043022156, "x_quantize_rowwise": 0.08748471736907959, "g_quantize_rowwise": 0.3023110330104828, "w_quantize_rowwise": 0.03182142972946167, "w_quantize_colwise_transpose": 0.08632615208625793, "w_quantize_global": 0.09445473551750183, "w_quantize_global_transpose": 0.10032951831817627, "time_standard": 3.424588590860367, "time_rowwise": 2.9078908264636993, "time_global": 2.9647573828697205}
{"repeat": 64, "batch_size": 32768, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 1.1040829122066498, "standard_gw": 1.1221766471862793, "standard_gx": 1.1548101902008057, "rowwise_fwd": 0.581938773393631, "rowwise_bwd": 0.7480122148990631, "global_fwd": 0.5537159740924835, "global_bwd": 0.7232688367366791, "x_quantize_rowwise": 0.30193477869033813, "g_quantize_rowwise": 0.08745118975639343, "w_quantize_rowwise": 0.03374740481376648, "w_quantize_colwise_transpose": 0.09068101644515991, "w_quantize_global": 0.09645149111747742, "w_quantize_global_transpose": 0.10189786553382874, "time_standard": 3.3810697495937347, "time_rowwise": 2.9659420251846313, "time_global": 2.9868967831134796}
{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 2.4533793330192566, "standard_gw": 2.1938569843769073, "standard_gx": 2.179361879825592, "rowwise_fwd": 1.4615543186664581, "rowwise_bwd": 1.0522231459617615, "global_fwd": 1.4288239181041718, "global_bwd": 1.0450035333633423, "x_quantize_rowwise": 0.1691766083240509, "g_quantize_rowwise": 0.5951300263404846, "w_quantize_rowwise": 0.03337860107421875, "w_quantize_colwise_transpose": 0.08653849363327026, "w_quantize_global": 0.0940859317779541, "w_quantize_global_transpose": 0.09976327419281006, "time_standard": 6.826598197221756, "time_rowwise": 5.5918581783771515, "time_global": 5.625840276479721}
{"repeat": 64, "batch_size": 65536, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 2.1698065102100372, "standard_gw": 2.1875128149986267, "standard_gx": 2.2887587547302246, "rowwise_fwd": 1.0762326419353485, "rowwise_bwd": 1.4638006687164307, "global_fwd": 1.0450668632984161, "global_bwd": 1.4308765530586243, "x_quantize_rowwise": 0.5953535437583923, "g_quantize_rowwise": 0.16899779438972473, "w_quantize_rowwise": 0.03240257501602173, "w_quantize_colwise_transpose": 0.09106099605560303, "w_quantize_global": 0.09546056389808655, "w_quantize_global_transpose": 0.09852275252342224, "time_standard": 6.6460780799388885, "time_rowwise": 5.615361034870148, "time_global": 5.621790885925293}
{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 4.858218133449554, "standard_gw": 4.3631307780742645, "standard_gx": 4.404045641422272, "rowwise_fwd": 2.9063820838928223, "rowwise_bwd": 2.094462513923645, "global_fwd": 2.8426870703697205, "global_bwd": 2.0792782306671143, "x_quantize_rowwise": 0.33241137862205505, "g_quantize_rowwise": 1.1817105114459991, "w_quantize_rowwise": 0.03374367952346802, "w_quantize_colwise_transpose": 0.08633732795715332, "w_quantize_global": 0.09231641888618469, "w_quantize_global_transpose": 0.100012868642807, "time_standard": 13.62539455294609, "time_rowwise": 10.998178273439407, "time_global": 10.991547256708145}
{"repeat": 64, "batch_size": 131072, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 4.246581345796585, "standard_gw": 4.42587211728096, "standard_gx": 4.581417888402939, "rowwise_fwd": 2.1114833652973175, "rowwise_bwd": 2.9050447046756744, "global_fwd": 2.0806826651096344, "global_bwd": 2.85966694355011, "x_quantize_rowwise": 1.1816024780273438, "g_quantize_rowwise": 0.33330172300338745, "w_quantize_rowwise": 0.033445656299591064, "w_quantize_colwise_transpose": 0.09065866470336914, "w_quantize_global": 0.09239837527275085, "w_quantize_global_transpose": 0.09984523057937622, "time_standard": 13.253871351480484, "time_rowwise": 11.081408709287643, "time_global": 11.073369532823563}
{"repeat": 64, "batch_size": 8192, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.4859529435634613, "standard_gw": 0.46338513493537903, "standard_gx": 0.42321905493736267, "rowwise_fwd": 0.2761557698249817, "rowwise_bwd": 0.20775198936462402, "global_fwd": 0.2713911235332489, "global_bwd": 0.20639970898628235, "x_quantize_rowwise": 0.033095479011535645, "g_quantize_rowwise": 0.11894106864929199, "w_quantize_rowwise": 0.03125518560409546, "w_quantize_colwise_transpose": 0.1424551010131836, "w_quantize_global": 0.07288157939910889, "w_quantize_global_transpose": 0.08071959018707275, "time_standard": 1.372557133436203, "time_rowwise": 1.2730397284030914, "time_global": 1.2468136847019196}
{"repeat": 64, "batch_size": 8192, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.3920421004295349, "standard_gw": 0.44424086809158325, "standard_gx": 0.4759356379508972, "rowwise_fwd": 0.23231282830238342, "rowwise_bwd": 0.28430670499801636, "global_fwd": 0.20883232355117798, "global_bwd": 0.2741999924182892, "x_quantize_rowwise": 0.12018159031867981, "g_quantize_rowwise": 0.03195926547050476, "w_quantize_rowwise": 0.026017427444458008, "w_quantize_colwise_transpose": 0.14733895659446716, "w_quantize_global": 0.07734447717666626, "w_quantize_global_transpose": 0.0788569450378418, "time_standard": 1.3122186064720154, "time_rowwise": 1.2863576412200928, "time_global": 1.235615462064743}
{"repeat": 64, "batch_size": 16384, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 1.0111741721630096, "standard_gw": 0.9267590939998627, "standard_gx": 0.8254274725914001, "rowwise_fwd": 0.5434826016426086, "rowwise_bwd": 0.4077926278114319, "global_fwd": 0.5318708717823029, "global_bwd": 0.40537863969802856, "x_quantize_rowwise": 0.059738755226135254, "g_quantize_rowwise": 0.2299174666404724, "w_quantize_rowwise": 0.02545863389968872, "w_quantize_colwise_transpose": 0.14269724488258362, "w_quantize_global": 0.07300823926925659, "w_quantize_global_transpose": 0.07878988981246948, "time_standard": 2.7633607387542725, "time_rowwise": 2.335846424102783, "time_global": 2.305462956428528}
{"repeat": 64, "batch_size": 16384, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.8095316588878632, "standard_gw": 0.8607134222984314, "standard_gx": 0.9204968810081482, "rowwise_fwd": 0.4275888204574585, "rowwise_bwd": 0.5485899746417999, "global_fwd": 0.41000545024871826, "global_bwd": 0.5317628383636475, "x_quantize_rowwise": 0.2301819622516632, "g_quantize_rowwise": 0.059254467487335205, "w_quantize_rowwise": 0.02466142177581787, "w_quantize_colwise_transpose": 0.14865398406982422, "w_quantize_global": 0.07582828402519226, "w_quantize_global_transpose": 0.08231401443481445, "time_standard": 2.5907419621944427, "time_rowwise": 2.2996440529823303, "time_global": 2.2500604391098022}
{"repeat": 64, "batch_size": 32768, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 2.0658522844314575, "standard_gw": 1.718364655971527, "standard_gx": 1.6660578548908234, "rowwise_fwd": 1.066897064447403, "rowwise_bwd": 0.8070804178714752, "global_fwd": 1.0473169386386871, "global_bwd": 0.8021742105484009, "x_quantize_rowwise": 0.11274218559265137, "g_quantize_rowwise": 0.4518181085586548, "w_quantize_rowwise": 0.026501715183258057, "w_quantize_colwise_transpose": 0.14259666204452515, "w_quantize_global": 0.07484853267669678, "w_quantize_global_transpose": 0.07976219058036804, "time_standard": 5.450274795293808, "time_rowwise": 4.326000809669495, "time_global": 4.287026822566986}
{"repeat": 64, "batch_size": 32768, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 2.7549192309379578, "standard_gw": 1.6954988241195679, "standard_gx": 1.8179528415203094, "rowwise_fwd": 0.8649080991744995, "rowwise_bwd": 1.0746456682682037, "global_fwd": 0.8023083209991455, "global_bwd": 1.0471977293491364, "x_quantize_rowwise": 0.45225024223327637, "g_quantize_rowwise": 0.11286512017250061, "w_quantize_rowwise": 0.0252649188041687, "w_quantize_colwise_transpose": 0.14732033014297485, "w_quantize_global": 0.07537379860877991, "w_quantize_global_transpose": 0.0807642936706543, "time_standard": 6.268370896577835, "time_rowwise": 4.372753202915192, "time_global": 4.266258329153061}
{"repeat": 64, "batch_size": 65536, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 4.098430275917053, "standard_gw": 3.3501461148262024, "standard_gx": 5.560480058193207, "rowwise_fwd": 2.112947404384613, "rowwise_bwd": 1.605246216058731, "global_fwd": 2.0697638392448425, "global_bwd": 1.5953518450260162, "x_quantize_rowwise": 0.21921470761299133, "g_quantize_rowwise": 0.8956789970397949, "w_quantize_rowwise": 0.02710893750190735, "w_quantize_colwise_transpose": 0.14268234372138977, "w_quantize_global": 0.07259473204612732, "w_quantize_global_transpose": 0.07899105548858643, "time_standard": 13.009056448936462, "time_rowwise": 8.35302472114563, "time_global": 8.281741291284561}
{"repeat": 64, "batch_size": 65536, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 5.586959421634674, "standard_gw": 3.358360379934311, "standard_gx": 3.6434978246688843, "rowwise_fwd": 1.6269534826278687, "rowwise_bwd": 2.128206193447113, "global_fwd": 1.5950687229633331, "global_bwd": 2.0831897854804993, "x_quantize_rowwise": 0.8954145014286041, "g_quantize_rowwise": 0.21914392709732056, "w_quantize_rowwise": 0.026203691959381104, "w_quantize_colwise_transpose": 0.14658644795417786, "w_quantize_global": 0.07478520274162292, "w_quantize_global_transpose": 0.07964670658111572, "time_standard": 12.58881762623787, "time_rowwise": 8.400868624448776, "time_global": 8.305609226226807}
{"repeat": 64, "batch_size": 131072, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 8.229725062847137, "standard_gw": 6.791356950998306, "standard_gx": 6.806455552577972, "rowwise_fwd": 4.252471029758453, "rowwise_bwd": 3.2062679529190063, "global_fwd": 4.175614565610886, "global_bwd": 3.1837262213230133, "x_quantize_rowwise": 0.4321373999118805, "g_quantize_rowwise": 1.787092536687851, "w_quantize_rowwise": 0.0270158052444458, "w_quantize_colwise_transpose": 0.1424252986907959, "w_quantize_global": 0.07348507642745972, "w_quantize_global_transpose": 0.07829815149307251, "time_standard": 21.827537566423416, "time_rowwise": 16.63876697421074, "time_global": 16.52171090245247}
{"repeat": 64, "batch_size": 131072, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 11.279478669166565, "standard_gw": 6.7345499992370605, "standard_gx": 7.206875830888748, "rowwise_fwd": 3.209315240383148, "rowwise_bwd": 4.256397485733032, "global_fwd": 3.180190920829773, "global_bwd": 4.177983850240707, "x_quantize_rowwise": 1.7836056649684906, "g_quantize_rowwise": 0.4321075975894928, "w_quantize_rowwise": 0.03205239772796631, "w_quantize_colwise_transpose": 0.14675036072731018, "w_quantize_global": 0.09316205978393555, "w_quantize_global_transpose": 0.10086596012115479, "time_standard": 25.220904499292374, "time_rowwise": 16.5947787463665, "time_global": 16.502466052770615}
{"repeat": 64, "batch_size": 8192, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 0.5776733160018921, "standard_gw": 0.5300231277942657, "standard_gx": 0.6005913019180298, "rowwise_fwd": 0.33330172300338745, "rowwise_bwd": 0.2957060933113098, "global_fwd": 0.32876431941986084, "global_bwd": 0.29108673334121704, "x_quantize_rowwise": 0.03466755151748657, "g_quantize_rowwise": 0.12264400720596313, "w_quantize_rowwise": 0.033874064683914185, "w_quantize_colwise_transpose": 0.1775398850440979, "w_quantize_global": 0.09503215551376343, "w_quantize_global_transpose": 0.10617449879646301, "time_standard": 1.7082877457141876, "time_rowwise": 1.5277564525604248, "time_global": 1.5083923935890198}
{"repeat": 64, "batch_size": 8192, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.5164109170436859, "standard_gw": 0.5367249250411987, "standard_gx": 0.5876161158084869, "rowwise_fwd": 0.3132447600364685, "rowwise_bwd": 0.3396235406398773, "global_fwd": 0.2943649888038635, "global_bwd": 0.33209100365638733, "x_quantize_rowwise": 0.12357160449028015, "g_quantize_rowwise": 0.035997480154037476, "w_quantize_rowwise": 0.03213062882423401, "w_quantize_colwise_transpose": 0.17676874995231628, "w_quantize_global": 0.09861215949058533, "w_quantize_global_transpose": 0.0998862087726593, "time_standard": 1.6407519578933716, "time_rowwise": 1.5580616891384125, "time_global": 1.5212483704090118}
{"repeat": 64, "batch_size": 16384, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 1.2096501886844635, "standard_gw": 1.0663382709026337, "standard_gx": 1.0961703956127167, "rowwise_fwd": 0.6396733224391937, "rowwise_bwd": 0.5173943936824799, "global_fwd": 0.6296299397945404, "global_bwd": 0.5130060017108917, "x_quantize_rowwise": 0.06211921572685242, "g_quantize_rowwise": 0.2361498773097992, "w_quantize_rowwise": 0.03260001540184021, "w_quantize_colwise_transpose": 0.17679482698440552, "w_quantize_global": 0.09361281991004944, "w_quantize_global_transpose": 0.09913742542266846, "time_standard": 3.372158855199814, "time_rowwise": 2.7310699224472046, "time_global": 2.6999935507774353}
{"repeat": 64, "batch_size": 16384, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 1.1065565049648285, "standard_gw": 1.0664314031600952, "standard_gx": 1.1266544461250305, "rowwise_fwd": 0.5352050065994263, "rowwise_bwd": 0.6464086472988129, "global_fwd": 0.513765960931778, "global_bwd": 0.6284862756729126, "x_quantize_rowwise": 0.23620948195457458, "g_quantize_rowwise": 0.062271952629089355, "w_quantize_rowwise": 0.031460076570510864, "w_quantize_colwise_transpose": 0.17675384879112244, "w_quantize_global": 0.09486451745033264, "w_quantize_global_transpose": 0.09898096323013306, "time_standard": 3.2996423542499542, "time_rowwise": 2.7547404170036316, "time_global": 2.7010105550289154}
{"repeat": 64, "batch_size": 32768, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 2.4367496371269226, "standard_gw": 2.0806193351745605, "standard_gx": 2.19624862074852, "rowwise_fwd": 1.2554042041301727, "rowwise_bwd": 1.0227933526039124, "global_fwd": 1.2322552502155304, "global_bwd": 1.0152235627174377, "x_quantize_rowwise": 0.11792033910751343, "g_quantize_rowwise": 0.4639364778995514, "w_quantize_rowwise": 0.03241002559661865, "w_quantize_colwise_transpose": 0.17657503485679626, "w_quantize_global": 0.09655207395553589, "w_quantize_global_transpose": 0.09958073496818542, "time_standard": 6.713617593050003, "time_rowwise": 5.149658769369125, "time_global": 5.106087774038315}
{"repeat": 64, "batch_size": 32768, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 2.1935217082500458, "standard_gw": 2.0055584609508514, "standard_gx": 2.1882541477680206, "rowwise_fwd": 1.0396353900432587, "rowwise_bwd": 1.2542344629764557, "global_fwd": 1.0161921381950378, "global_bwd": 1.233428716659546, "x_quantize_rowwise": 0.4642195999622345, "g_quantize_rowwise": 0.11782720685005188, "w_quantize_rowwise": 0.033117830753326416, "w_quantize_colwise_transpose": 0.17696991562843323, "w_quantize_global": 0.09416043758392334, "w_quantize_global_transpose": 0.10101497173309326, "time_standard": 6.387334316968918, "time_rowwise": 5.091562867164612, "time_global": 5.032401531934738}
{"repeat": 64, "batch_size": 65536, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 4.804681986570358, "standard_gw": 4.763372242450714, "standard_gx": 4.064023494720459, "rowwise_fwd": 2.484843134880066, "rowwise_bwd": 1.9691288471221924, "global_fwd": 2.441786229610443, "global_bwd": 1.9574686884880066, "x_quantize_rowwise": 0.2294592559337616, "g_quantize_rowwise": 0.9196549654006958, "w_quantize_rowwise": 0.0313781201839447, "w_quantize_colwise_transpose": 0.1768544316291809, "w_quantize_global": 0.09644776582717896, "w_quantize_global_transpose": 0.09847059845924377, "time_standard": 13.632077723741531, "time_rowwise": 10.574690997600555, "time_global": 10.506659746170044}
{"repeat": 64, "batch_size": 65536, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 4.0907710790634155, "standard_gw": 3.9793066680431366, "standard_gx": 4.302978515625, "rowwise_fwd": 1.992940902709961, "rowwise_bwd": 2.4996213614940643, "global_fwd": 1.9551962614059448, "global_bwd": 2.457551658153534, "x_quantize_rowwise": 0.9200014173984528, "g_quantize_rowwise": 0.2293996512889862, "w_quantize_rowwise": 0.0313781201839447, "w_quantize_colwise_transpose": 0.17882883548736572, "w_quantize_global": 0.09540095925331116, "w_quantize_global_transpose": 0.09880587458610535, "time_standard": 12.373056262731552, "time_rowwise": 9.831476956605911, "time_global": 9.73566249012947}
{"repeat": 64, "batch_size": 131072, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 9.655728936195374, "standard_gw": 8.261296898126602, "standard_gx": 8.064884692430496, "rowwise_fwd": 5.007706582546234, "rowwise_bwd": 3.8615092635154724, "global_fwd": 4.920527338981628, "global_bwd": 3.8330331444740295, "x_quantize_rowwise": 0.45276060700416565, "g_quantize_rowwise": 1.8306002020835876, "w_quantize_rowwise": 0.031366944313049316, "w_quantize_colwise_transpose": 0.1766495406627655, "w_quantize_global": 0.09412690997123718, "w_quantize_global_transpose": 0.09780004620552063, "time_standard": 25.981910526752472, "time_rowwise": 19.621890038251877, "time_global": 19.49014514684677}
{"repeat": 64, "batch_size": 131072, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 8.033104240894318, "standard_gw": 8.2889124751091, "standard_gx": 8.622754365205765, "rowwise_fwd": 3.8747042417526245, "rowwise_bwd": 5.003921687602997, "global_fwd": 3.8315393030643463, "global_bwd": 4.9162134528160095, "x_quantize_rowwise": 1.8304847180843353, "g_quantize_rowwise": 0.4522763192653656, "w_quantize_rowwise": 0.03413110971450806, "w_quantize_colwise_transpose": 0.1771189272403717, "w_quantize_global": 0.09519979357719421, "w_quantize_global_transpose": 0.09930506348609924, "time_standard": 24.944771081209183, "time_rowwise": 19.661549478769302, "time_global": 19.51393112540245}
{"repeat": 64, "batch_size": 8192, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.7954612374305725, "standard_gw": 0.7456131279468536, "standard_gx": 0.8799619972705841, "rowwise_fwd": 0.43267011642456055, "rowwise_bwd": 0.34622475504875183, "global_fwd": 0.42615458369255066, "global_bwd": 0.344250351190567, "x_quantize_rowwise": 0.03748014569282532, "g_quantize_rowwise": 0.13304129242897034, "w_quantize_rowwise": 0.03294646739959717, "w_quantize_colwise_transpose": 0.2407953143119812, "w_quantize_global": 0.094633549451828, "w_quantize_global_transpose": 0.10305643081665039, "time_standard": 2.4210363626480103, "time_rowwise": 1.96877121925354, "time_global": 1.8842294812202454}
{"repeat": 64, "batch_size": 8192, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.7120333611965179, "standard_gw": 0.7622130215167999, "standard_gx": 0.8262209594249725, "rowwise_fwd": 0.3702230751514435, "rowwise_bwd": 0.4419572651386261, "global_fwd": 0.3479123115539551, "global_bwd": 0.4306286573410034, "x_quantize_rowwise": 0.13308599591255188, "g_quantize_rowwise": 0.037495046854019165, "w_quantize_rowwise": 0.03398209810256958, "w_quantize_colwise_transpose": 0.23782625794410706, "w_quantize_global": 0.09853765368461609, "w_quantize_global_transpose": 0.10247156023979187, "time_standard": 2.3004673421382904, "time_rowwise": 2.016782760620117, "time_global": 1.9123442471027374}
{"repeat": 64, "batch_size": 16384, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 1.6292817890644073, "standard_gw": 1.5109702944755554, "standard_gx": 1.482747495174408, "rowwise_fwd": 0.8386112749576569, "rowwise_bwd": 0.6844550371170044, "global_fwd": 0.8220970630645752, "global_bwd": 0.6802082061767578, "x_quantize_rowwise": 0.06883963942527771, "g_quantize_rowwise": 0.25641173124313354, "w_quantize_rowwise": 0.033054500818252563, "w_quantize_colwise_transpose": 0.24027004837989807, "w_quantize_global": 0.0967271625995636, "w_quantize_global_transpose": 0.102948397397995, "time_standard": 4.622999578714371, "time_rowwise": 3.6326125264167786, "time_global": 3.5382024943828583}
{"repeat": 64, "batch_size": 16384, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 1.4877021312713623, "standard_gw": 1.5015341341495514, "standard_gx": 1.529306173324585, "rowwise_fwd": 0.715944916009903, "rowwise_bwd": 0.8529908955097198, "global_fwd": 0.680088996887207, "global_bwd": 0.8224695920944214, "x_quantize_rowwise": 0.2568177878856659, "g_quantize_rowwise": 0.06864592432975769, "w_quantize_rowwise": 0.03343448042869568, "w_quantize_colwise_transpose": 0.23645907640457153, "w_quantize_global": 0.09399279952049255, "w_quantize_global_transpose": 0.10286271572113037, "time_standard": 4.518542438745499, "time_rowwise": 3.665827214717865, "time_global": 3.5264119505882263}
{"repeat": 64, "batch_size": 32768, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 3.261040896177292, "standard_gw": 2.8816498816013336, "standard_gx": 2.8357282280921936, "rowwise_fwd": 1.6594752669334412, "rowwise_bwd": 1.359265297651291, "global_fwd": 1.6287527978420258, "global_bwd": 1.3503879308700562, "x_quantize_rowwise": 0.13146549463272095, "g_quantize_rowwise": 0.5035959184169769, "w_quantize_rowwise": 0.03438442945480347, "w_quantize_colwise_transpose": 0.24086236953735352, "w_quantize_global": 0.0945068895816803, "w_quantize_global_transpose": 0.10332837700843811, "time_standard": 8.978419005870819, "time_rowwise": 6.8106986582279205, "time_global": 6.693687289953232}
{"repeat": 64, "batch_size": 32768, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 2.848360687494278, "standard_gw": 2.8955675661563873, "standard_gx": 3.0499882996082306, "rowwise_fwd": 1.3900883495807648, "rowwise_bwd": 1.6595833003520966, "global_fwd": 1.3514049351215363, "global_bwd": 1.629263162612915, "x_quantize_rowwise": 0.5036592483520508, "g_quantize_rowwise": 0.13118237257003784, "w_quantize_rowwise": 0.03438442945480347, "w_quantize_colwise_transpose": 0.23709610104560852, "w_quantize_global": 0.0951625406742096, "w_quantize_global_transpose": 0.10216236114501953, "time_standard": 8.793916553258896, "time_rowwise": 6.851561367511749, "time_global": 6.708402186632156}
{"repeat": 64, "batch_size": 65536, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 6.4978525042533875, "standard_gw": 6.462603807449341, "standard_gx": 5.5987648665905, "rowwise_fwd": 3.2996535301208496, "rowwise_bwd": 2.6320070028305054, "global_fwd": 3.2426007091999054, "global_bwd": 2.612769603729248, "x_quantize_rowwise": 0.2561397850513458, "g_quantize_rowwise": 0.9984448552131653, "w_quantize_rowwise": 0.033076852560043335, "w_quantize_colwise_transpose": 0.24232640862464905, "w_quantize_global": 0.09618699550628662, "w_quantize_global_transpose": 0.10257214307785034, "time_standard": 18.559221178293228, "time_rowwise": 13.9242522418499, "time_global": 13.771317899227142}
{"repeat": 64, "batch_size": 65536, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 5.5702440440654755, "standard_gw": 5.717620253562927, "standard_gx": 6.08203187584877, "rowwise_fwd": 2.649586647748947, "rowwise_bwd": 3.315173089504242, "global_fwd": 2.6132799685001373, "global_bwd": 3.257807344198227, "x_quantize_rowwise": 0.9980201721191406, "g_quantize_rowwise": 0.256560742855072, "w_quantize_rowwise": 0.03356859087944031, "w_quantize_colwise_transpose": 0.23729726672172546, "w_quantize_global": 0.09495764970779419, "w_quantize_global_transpose": 0.103779137134552, "time_standard": 17.369896173477173, "time_rowwise": 13.207826763391495, "time_global": 13.04202526807785}
{"repeat": 64, "batch_size": 131072, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 13.058379292488098, "standard_gw": 11.480242013931274, "standard_gx": 11.092845350503922, "rowwise_fwd": 6.637874990701675, "rowwise_bwd": 5.24790957570076, "global_fwd": 6.521012634038925, "global_bwd": 5.214303731918335, "x_quantize_rowwise": 0.5057565867900848, "g_quantize_rowwise": 1.989319920539856, "w_quantize_rowwise": 0.03439188003540039, "w_quantize_colwise_transpose": 0.24280324578285217, "w_quantize_global": 0.09520724415779114, "w_quantize_global_transpose": 0.10240450501441956, "time_standard": 35.631466656923294, "time_rowwise": 26.138298213481903, "time_global": 25.908246636390686}
{"repeat": 64, "batch_size": 131072, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 11.13397628068924, "standard_gw": 11.371888220310211, "standard_gx": 12.12756335735321, "rowwise_fwd": 5.2495077252388, "rowwise_bwd": 6.638709455728531, "global_fwd": 5.215313285589218, "global_bwd": 6.5222084522247314, "x_quantize_rowwise": 1.9870512187480927, "g_quantize_rowwise": 0.5058236420154572, "w_quantize_rowwise": 0.034634023904800415, "w_quantize_colwise_transpose": 0.23674964904785156, "w_quantize_global": 0.09457767009735107, "w_quantize_global_transpose": 0.10183081030845642, "time_standard": 34.63342785835266, "time_rowwise": 26.024363934993744, "time_global": 25.798693299293518}
{"repeat": 64, "batch_size": 8192, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 1.2125298380851746, "standard_gw": 1.1111274361610413, "standard_gx": 1.0840706527233124, "rowwise_fwd": 0.6057210266590118, "rowwise_bwd": 0.51865354180336, "global_fwd": 0.5952082574367523, "global_bwd": 0.5167685449123383, "x_quantize_rowwise": 0.045686960220336914, "g_quantize_rowwise": 0.15827640891075134, "w_quantize_rowwise": 0.04361197352409363, "w_quantize_colwise_transpose": 0.34067779779434204, "w_quantize_global": 0.13644620776176453, "w_quantize_global_transpose": 0.14925003051757812, "time_standard": 3.407727926969528, "time_rowwise": 2.823755145072937, "time_global": 2.7127638459205627}
{"repeat": 64, "batch_size": 8192, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 1.0731369256973267, "standard_gw": 1.1365897953510284, "standard_gx": 1.1498592793941498, "rowwise_fwd": 0.5573518574237823, "rowwise_bwd": 0.615488737821579, "global_fwd": 0.5220361053943634, "global_bwd": 0.5939789116382599, "x_quantize_rowwise": 0.15765801072120667, "g_quantize_rowwise": 0.04369020462036133, "w_quantize_rowwise": 0.047359615564346313, "w_quantize_colwise_transpose": 0.5526281893253326, "w_quantize_global": 0.13606995344161987, "w_quantize_global_transpose": 0.15017390251159668, "time_standard": 3.359586000442505, "time_rowwise": 3.1107664108276367, "time_global": 2.7401968836784363}
{"repeat": 64, "batch_size": 16384, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 2.4274885654449463, "standard_gw": 2.1799951791763306, "standard_gx": 2.1426528692245483, "rowwise_fwd": 1.195710152387619, "rowwise_bwd": 1.027170568704605, "global_fwd": 1.1747106909751892, "global_bwd": 1.0251589119434357, "x_quantize_rowwise": 0.08098781108856201, "g_quantize_rowwise": 0.3052949905395508, "w_quantize_rowwise": 0.043764710426330566, "w_quantize_colwise_transpose": 0.33987686038017273, "w_quantize_global": 0.13646483421325684, "w_quantize_global_transpose": 0.14739856123924255, "time_standard": 6.750136613845825, "time_rowwise": 5.172800272703171, "time_global": 5.050010979175568}
{"repeat": 64, "batch_size": 16384, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 2.1661892533302307, "standard_gw": 2.0948275923728943, "standard_gx": 2.306375652551651, "rowwise_fwd": 1.0587647557258606, "rowwise_bwd": 1.1999905109405518, "global_fwd": 1.0296404361724854, "global_bwd": 1.1749230325222015, "x_quantize_rowwise": 0.3054030239582062, "g_quantize_rowwise": 0.08077546954154968, "w_quantize_rowwise": 0.047225505113601685, "w_quantize_colwise_transpose": 0.600133091211319, "w_quantize_global": 0.13613328337669373, "w_quantize_global_transpose": 0.1484006643295288, "time_standard": 6.567392498254776, "time_rowwise": 5.387119948863983, "time_global": 4.97010350227356}
{"repeat": 64, "batch_size": 32768, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 4.807606339454651, "standard_gw": 4.170913249254227, "standard_gx": 4.117622971534729, "rowwise_fwd": 2.370934933423996, "rowwise_bwd": 1.9481778144836426, "global_fwd": 2.3383721709251404, "global_bwd": 1.9443817436695099, "x_quantize_rowwise": 0.1547597348690033, "g_quantize_rowwise": 0.6000511348247528, "w_quantize_rowwise": 0.04361942410469055, "w_quantize_colwise_transpose": 0.3403201699256897, "w_quantize_global": 0.13600289821624756, "w_quantize_global_transpose": 0.1474134624004364, "time_standard": 13.096142560243607, "time_rowwise": 9.628776460886002, "time_global": 9.491894394159317}
{"repeat": 64, "batch_size": 32768, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 4.1619837284088135, "standard_gw": 4.181284457445145, "standard_gx": 4.635505378246307, "rowwise_fwd": 1.9684135913848877, "rowwise_bwd": 2.3750364780426025, "global_fwd": 1.9445866346359253, "global_bwd": 2.3551955819129944, "x_quantize_rowwise": 0.6004162132740021, "g_quantize_rowwise": 0.15468522906303406, "w_quantize_rowwise": 0.04730746150016785, "w_quantize_colwise_transpose": 0.5999617278575897, "w_quantize_global": 0.1364201307296753, "w_quantize_global_transpose": 0.14847144484519958, "time_standard": 12.978773564100266, "time_rowwise": 9.927105158567429, "time_global": 9.521059691905975}
{"repeat": 64, "batch_size": 65536, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 9.52371209859848, "standard_gw": 8.354485034942627, "standard_gx": 8.69860127568245, "rowwise_fwd": 4.717472940683365, "rowwise_bwd": 3.8843750953674316, "global_fwd": 4.645414650440216, "global_bwd": 3.8761012256145477, "x_quantize_rowwise": 0.3024861216545105, "g_quantize_rowwise": 1.1897757649421692, "w_quantize_rowwise": 0.04366785287857056, "w_quantize_colwise_transpose": 0.33988431096076965, "w_quantize_global": 0.1359507441520691, "w_quantize_global_transpose": 0.14724582433700562, "time_standard": 26.576798409223557, "time_rowwise": 18.832147121429443, "time_global": 18.651459366083145}
{"repeat": 64, "batch_size": 65536, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 8.307881653308868, "standard_gw": 8.214320987462997, "standard_gx": 9.21182706952095, "rowwise_fwd": 3.8919784128665924, "rowwise_bwd": 4.72346693277359, "global_fwd": 3.8761794567108154, "global_bwd": 4.673641175031662, "x_quantize_rowwise": 1.1893920600414276, "g_quantize_rowwise": 0.3024972975254059, "w_quantize_rowwise": 0.04708021879196167, "w_quantize_colwise_transpose": 0.6039328873157501, "w_quantize_global": 0.13624504208564758, "w_quantize_global_transpose": 0.14867261052131653, "time_standard": 25.734029710292816, "time_rowwise": 18.972668796777725, "time_global": 18.540948629379272}
{"repeat": 64, "batch_size": 131072, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 19.30372044444084, "standard_gw": 16.480475664138794, "standard_gx": 17.61433482170105, "rowwise_fwd": 9.49602946639061, "rowwise_bwd": 7.768530398607254, "global_fwd": 9.3533955514431, "global_bwd": 7.749464362859726, "x_quantize_rowwise": 0.5977451801300049, "g_quantize_rowwise": 2.3684948682785034, "w_quantize_rowwise": 0.04375725984573364, "w_quantize_colwise_transpose": 0.34042075276374817, "w_quantize_global": 0.13628974556922913, "w_quantize_global_transpose": 0.14671683311462402, "time_standard": 53.398530930280685, "time_rowwise": 37.09545359015465, "time_global": 36.83258220553398}
{"repeat": 64, "batch_size": 131072, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 18.041003495454788, "standard_gw": 17.770148813724518, "standard_gx": 17.70009845495224, "rowwise_fwd": 7.756810635328293, "rowwise_bwd": 9.502101689577103, "global_fwd": 7.7384114265441895, "global_bwd": 9.36170294880867, "x_quantize_rowwise": 2.3686252534389496, "g_quantize_rowwise": 0.5980581045150757, "w_quantize_rowwise": 0.04723668098449707, "w_quantize_colwise_transpose": 0.6035342812538147, "w_quantize_global": 0.13603642582893372, "w_quantize_global_transpose": 0.1485198736190796, "time_standard": 53.511250764131546, "time_rowwise": 38.64651545882225, "time_global": 38.121502846479416}
{"repeat": 64, "batch_size": 8192, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 4.598241299390793, "standard_gw": 4.294309765100479, "standard_gx": 4.261095076799393, "rowwise_fwd": 2.0976848900318146, "rowwise_bwd": 1.9718967378139496, "global_fwd": 2.0763762295246124, "global_bwd": 1.9703581929206848, "x_quantize_rowwise": 0.08216872811317444, "g_quantize_rowwise": 0.4405900835990906, "w_quantize_rowwise": 0.1553371548652649, "w_quantize_colwise_transpose": 1.6110725700855255, "w_quantize_global": 0.481240451335907, "w_quantize_global_transpose": 0.5061514675617218, "time_standard": 13.153646141290665, "time_rowwise": 10.653059929609299, "time_global": 9.85119491815567}
{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 4.35885414481163, "standard_gw": 4.29583340883255, "standard_gx": 4.5370906591415405, "rowwise_fwd": 2.0015686750411987, "rowwise_bwd": 2.097565680742264, "global_fwd": 1.969795674085617, "global_bwd": 2.075403928756714, "x_quantize_rowwise": 0.43984130024909973, "g_quantize_rowwise": 0.08216127753257751, "w_quantize_rowwise": 0.22544339299201965, "w_quantize_colwise_transpose": 2.4342015385627747, "w_quantize_global": 0.48087164759635925, "w_quantize_global_transpose": 0.5099289119243622, "time_standard": 13.19177821278572, "time_rowwise": 11.576615273952484, "time_global": 9.85383614897728}
{"repeat": 64, "batch_size": 16384, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 9.09888744354248, "standard_gw": 8.230950683355331, "standard_gx": 8.465446531772614, "rowwise_fwd": 4.182614386081696, "rowwise_bwd": 3.747660666704178, "global_fwd": 4.138719290494919, "global_bwd": 3.74777615070343, "x_quantize_rowwise": 0.15515834093093872, "g_quantize_rowwise": 0.8699297904968262, "w_quantize_rowwise": 0.15544891357421875, "w_quantize_colwise_transpose": 1.6132444143295288, "w_quantize_global": 0.48100948333740234, "w_quantize_global_transpose": 0.5051903426647186, "time_standard": 25.795284658670425, "time_rowwise": 18.955007195472717, "time_global": 18.128734081983566}
{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 8.378107100725174, "standard_gw": 8.923027664422989, "standard_gx": 9.049762040376663, "rowwise_fwd": 3.765825182199478, "rowwise_bwd": 4.183519631624222, "global_fwd": 3.744799643754959, "global_bwd": 4.1590481996536255, "x_quantize_rowwise": 0.8693933486938477, "g_quantize_rowwise": 0.1553073525428772, "w_quantize_rowwise": 0.2258792519569397, "w_quantize_colwise_transpose": 2.4386271834373474, "w_quantize_global": 0.4811100661754608, "w_quantize_global_transpose": 0.5102269351482391, "time_standard": 26.350896805524826, "time_rowwise": 20.5615796148777, "time_global": 18.842913210392}
{"repeat": 64, "batch_size": 32768, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 18.266115337610245, "standard_gw": 17.671160399913788, "standard_gx": 17.10302010178566, "rowwise_fwd": 8.347474038600922, "rowwise_bwd": 7.514089345932007, "global_fwd": 8.263226598501205, "global_bwd": 7.487393915653229, "x_quantize_rowwise": 0.3021806478500366, "g_quantize_rowwise": 1.7319358885288239, "w_quantize_rowwise": 0.15519559383392334, "w_quantize_colwise_transpose": 1.6133114695549011, "w_quantize_global": 0.48247724771499634, "w_quantize_global_transpose": 0.506427139043808, "time_standard": 53.04029583930969, "time_rowwise": 37.3353473842144, "time_global": 36.44480183720589}
{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 17.73649826645851, "standard_gw": 16.359902918338776, "standard_gx": 18.0993489921093, "rowwise_fwd": 7.493957877159119, "rowwise_bwd": 8.352488279342651, "global_fwd": 7.486194372177124, "global_bwd": 8.28903540968895, "x_quantize_rowwise": 1.7313472926616669, "g_quantize_rowwise": 0.30205026268959045, "w_quantize_rowwise": 0.2255477011203766, "w_quantize_colwise_transpose": 2.4363920092582703, "w_quantize_global": 0.4815347492694855, "w_quantize_global_transpose": 0.5103759467601776, "time_standard": 52.195750176906586, "time_rowwise": 36.90168634057045, "time_global": 35.16044095158577}
{"repeat": 64, "batch_size": 65536, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 36.309611052274704, "standard_gw": 32.85098075866699, "standard_gx": 34.34552624821663, "rowwise_fwd": 16.74525812268257, "rowwise_bwd": 15.026237815618515, "global_fwd": 16.574162989854813, "global_bwd": 14.977734535932541, "x_quantize_rowwise": 0.5954466760158539, "g_quantize_rowwise": 3.4569576382637024, "w_quantize_rowwise": 0.15521422028541565, "w_quantize_colwise_transpose": 1.6133897006511688, "w_quantize_global": 0.4822872579097748, "w_quantize_global_transpose": 0.5065612494945526, "time_standard": 103.50611805915833, "time_rowwise": 70.44348493218422, "time_global": 69.44413110613823}
{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 35.40017828345299, "standard_gw": 33.037226647138596, "standard_gx": 36.30436211824417, "rowwise_fwd": 15.043705701828003, "rowwise_bwd": 16.756191849708557, "global_fwd": 15.011314302682877, "global_bwd": 16.580048948526382, "x_quantize_rowwise": 3.4548528492450714, "g_quantize_rowwise": 0.5951337516307831, "w_quantize_rowwise": 0.22584572434425354, "w_quantize_colwise_transpose": 2.4329908192157745, "w_quantize_global": 0.4813261330127716, "w_quantize_global_transpose": 0.5101598799228668, "time_standard": 104.74176704883575, "time_rowwise": 71.54594734311104, "time_global": 69.67006251215935}
{"repeat": 64, "batch_size": 131072, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 73.40333238244057, "standard_gw": 73.76311346888542, "standard_gx": 70.41774317622185, "rowwise_fwd": 33.37597846984863, "rowwise_bwd": 30.345775187015533, "global_fwd": 33.00366923213005, "global_bwd": 30.218638479709625, "x_quantize_rowwise": 1.1825822293758392, "g_quantize_rowwise": 6.902601569890976, "w_quantize_rowwise": 0.15529245138168335, "w_quantize_colwise_transpose": 1.6109198331832886, "w_quantize_global": 0.48149004578590393, "w_quantize_global_transpose": 0.5066059529781342, "time_standard": 217.58418902754784, "time_rowwise": 147.33626320958138, "time_global": 146.05870097875595}
{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 71.5160183608532, "standard_gw": 73.76786693930626, "standard_gx": 72.98104092478752, "rowwise_fwd": 30.291248112916946, "rowwise_bwd": 33.36654230952263, "global_fwd": 30.181586742401123, "global_bwd": 33.082425594329834, "x_quantize_rowwise": 6.902430206537247, "g_quantize_rowwise": 1.1815279722213745, "w_quantize_rowwise": 0.2262219786643982, "w_quantize_colwise_transpose": 2.4421699345111847, "w_quantize_global": 0.4816502332687378, "w_quantize_global_transpose": 0.5105249583721161, "time_standard": 218.26492622494698, "time_rowwise": 148.17800745368004, "time_global": 146.1080126464367}

View File

@ -12,12 +12,18 @@ if __name__ == '__main__':
fig = plt.figure(tight_layout=True, figsize=(12,3.5)) fig = plt.figure(tight_layout=True, figsize=(12,3.5))
gs = gridspec.GridSpec(1, 2) gs = gridspec.GridSpec(1, 2)
dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096]
batch_size_for_plot1 = 32768
batch_sizes_for_plot2 = [2**14, 2**15, 2**16, 2**17]
dims_to_xtick = [1024, 2048, 4096]
logscale_plot1 = True
ax = fig.add_subplot(gs[0, 0]) ax = fig.add_subplot(gs[0, 0])
rdf = pd.read_json('tests/triton_tests/info.jsonl', lines=True) rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True)
df = rdf[rdf.batch_size == 32768] df = rdf[rdf.batch_size == batch_size_for_plot1]
# first plot the time occupied by different operations
for k, marker, ls, color, name in [ for k, marker, ls, color, name in [
('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'), ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'),
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'), ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'),
@ -29,17 +35,15 @@ if __name__ == '__main__':
('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'), ('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'),
('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'), ('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'),
#### time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']
('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'), ('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'),
('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'), ('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'),
('w_quantize_global', '.', '--', 'C4', 'Quatnize global W (switchback)'), ('w_quantize_global', '.', '--', 'C4', 'Quatnize global W (switchback)'),
('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize gloabl and\ntranspose W (switchback)'), ('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize gloabl and\ntranspose W (switchback)'),
#('standard_gw', '.', '--', 'C1', 'standard_gw'),
]: ]:
xs = [] xs = []
ys = [] ys = []
for embed_dim in [1024, 1280, 1408, 1664, 2048, 4096]: for embed_dim in dims_to_consider:
# average over dim -> 4*dim and 4*dim -> dim
df_ = df[df.dim_in == embed_dim] df_ = df[df.dim_in == embed_dim]
df_ = df_[df_.dim_out == embed_dim * 4] df_ = df_[df_.dim_out == embed_dim * 4]
xs.append(embed_dim) xs.append(embed_dim)
@ -56,24 +60,20 @@ if __name__ == '__main__':
ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.) ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.)
ax.set_xlabel('dim', fontsize=13) ax.set_xlabel('dim', fontsize=13)
ax.set_ylabel('time (ms)', fontsize=13) ax.set_ylabel('time (ms)', fontsize=13)
# make a legend which is below the plot
ax.grid() ax.grid()
ax.set_xscale('log') ax.set_xscale('log')
#ax.set_yscale('log') if logscale_plot1:
ax.set_yscale('log')
ax.tick_params(axis='x', labelsize=11) ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11) ax.tick_params(axis='y', labelsize=11)
ax.set_xticks([1024, 2048, 4096]) ax.set_xticks(dims_to_xtick)
ax.set_xticklabels([1024, 2048, 4096]) ax.set_xticklabels(dims_to_xtick)
ax.set_xticks([], minor=True) ax.set_xticks([], minor=True)
leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10) leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10)
@ -86,7 +86,7 @@ if __name__ == '__main__':
ax = fig.add_subplot(gs[0, 1]) ax = fig.add_subplot(gs[0, 1])
# now plot the % speedup for different batch sizes # now plot the % speedup for different batch sizes
for j, batch_size in enumerate([2**14, 2**15, 2**16, 2**17]): for j, batch_size in enumerate(batch_sizes_for_plot2):
all_xs, all_ys = [], [] all_xs, all_ys = [], []
for k, marker, ls, color, name in [ for k, marker, ls, color, name in [
('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'), ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'),
@ -95,7 +95,7 @@ if __name__ == '__main__':
xs, ys = [], [] xs, ys = [], []
df = rdf[rdf.batch_size == batch_size] df = rdf[rdf.batch_size == batch_size]
for embed_dim in [1024, 1280, 1408, 1664, 2048, 4096]: for embed_dim in dims_to_consider:
df_ = df[df.dim_in == embed_dim] df_ = df[df.dim_in == embed_dim]
df_ = df_[df_.dim_out == embed_dim * 4] df_ = df_[df_.dim_out == embed_dim * 4]
xs.append(embed_dim) xs.append(embed_dim)
@ -125,13 +125,13 @@ if __name__ == '__main__':
ax.tick_params(axis='x', labelsize=11) ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11) ax.tick_params(axis='y', labelsize=11)
ax.set_xticks([1024, 2048, 4096]) ax.set_xticks(dims_to_xtick)
ax.set_xticklabels([1024, 2048, 4096]) ax.set_xticklabels(dims_to_xtick)
ax.set_xticks([], minor=True) ax.set_xticks([], minor=True)
ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20) ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20)
plt.savefig('tests/triton_tests/plot1.pdf', bbox_inches='tight') plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight')

View File

@ -0,0 +1,101 @@
import json
import time
import torch
import torch.nn as nn
from bitsandbytes.nn.triton_utils.v0.quantize_rowwise import quantize_rowwise
from bitsandbytes.nn.triton_utils.v0.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.nn.triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.nn.triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.nn.triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
# KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.
def get_time(k, fn, info_dict):
for _ in range(repeat // 2):
fn()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
fn()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info_dict[k] = ms
if __name__ == '__main__':
torch.manual_seed(0)
wm = 4
for dim in [1024, 1280, 1408, 1664, 2048, 4096]:
# note "batch_size" is actually "batch_size * embed_dim", which is why it's large
for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:
# switch switches dim_in and dim_out
for switch in [False, True]:
# hparams
repeat = 64
batch_size = batch_size
dim_out = dim * wm
dim_in = dim
if switch:
dim_out = dim
dim_in = wm * dim
dim_in = round(dim_in)
dim_out = round(dim_out)
# simulate forward pass
x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda()
g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda()
w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda()
x_int8 = x.clone().to(torch.int8)
g_int8 = g.clone().to(torch.int8)
w_int8 = w.clone().to(torch.int8)
wt_int8 = w.t().contiguous().clone().to(torch.int8)
state_x_rowwise = x.max(dim=1)[0]
state_g_rowwise = g.max(dim=1)[0]
state_w_columnwise = w.max(dim=0)[0]
state_w_rowwise = w.max(dim=1)[0]
state_w_global = w.max()
info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}
get_time('standard_fwd', lambda : x.matmul(w.t()), info)
get_time('standard_gw', lambda : g.t().matmul(x), info)
get_time('standard_gx', lambda : g.matmul(w), info)
get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info)
get_time('w_quantize_global', lambda : quantize_global(w), info)
get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info)
time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']
print('TOTAL STANDARD', time_standard)
print('TOTAL ROWWISE', time_rowwise)
print('TOTAL GLOBAL', time_global)
print('speedup', -100*(time_global - time_standard)/time_standard)
info['time_standard'] = time_standard
info['time_rowwise'] = time_rowwise
info['time_global'] = time_global
info_json = json.dumps(info)
with open("speed_benchmark/info_a100_py2.jsonl", "a") as file:
file.write(info_json + "\n")

View File

@ -1,44 +1,57 @@
import pytest import pytest
import torch import torch
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
from bitsandbytes.nn import Linear8bitLt
@pytest.mark.parametrize("vectorrize", [False, True])
@pytest.mark.parametrize("triton_module", [SwitchBackGlobalLinear, SwitchBackLinear]) def test_switchback(vectorrize):
def test_switchbatch(triton_module):
for dim in [83, 17, 128]: for dim in [83, 17, 128]:
for batch in [13, 128, 256]: for batch in [13, 128, 256]:
standard = torch.nn.Linear(dim, 4 * dim).cuda().half() standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
switchback = triton_module(dim, 4 * dim).cuda().half() print('vectorrize', vectorrize)
switchback = SwitchBackLinear(dim, 4 * dim, vectorize=vectorrize).cuda().half()
baseline = Linear8bitLt(dim, 4 * dim).cuda().half()
switchback.weight.data.copy_(standard.weight) switchback.weight.data.copy_(standard.weight)
switchback.bias.data.copy_(standard.bias) switchback.bias.data.copy_(standard.bias)
baseline.weight.data.copy_(standard.weight)
baseline.bias.data.copy_(standard.bias)
x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True)
x2 = x1.clone().detach().requires_grad_(True)
x3 = x1.clone().detach().requires_grad_(True)
for i in range(100): out_standard = standard(x1)
x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True) (2**10 * out_standard.abs().mean()).backward()
x2 = x1.clone().detach().requires_grad_(True)
print('standard')
out_standard = standard(x1)
print('switchback')
out_sb = switchback(x1)
(out_standard.abs().mean()).backward() out_sb = switchback(x2)
(out_sb.abs().mean()).backward() (2**10 * out_sb.abs().mean()).backward()
err_sb = (out_standard - out_sb).abs().mean() out_baseline = baseline(x3)
print('OUT', err_sb) (2**10 * out_baseline.abs().mean()).backward()
err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean() err_sb = (out_standard - out_sb).abs().mean()
err_baseline = (out_standard - out_baseline).abs().mean()
print('OUT', err_sb, err_baseline)
assert err_sb < 2 * err_baseline
print('GW2', err_sb) err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean()
err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean()
err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean() print('GW2', err_sb, err_baseline)
assert err_sb < 2 * err_baseline
print('GW1', err_sb) err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean()
err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean()
#err_sb = (x1.grad - x2.grad).abs().mean() print('GW1', err_sb, err_baseline)
assert err_sb < 2 * err_baseline
#print('GX1', err_sb) err_sb = (x1.grad - x2.grad).abs().mean()
err_baseline = (x1.grad - x3.grad).abs().mean()
print('GX1', err_sb, err_baseline)
assert err_sb < 2 * err_baseline

View File

@ -1,363 +0,0 @@
import torch
import json
from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
import time
# class AttentionOld(torch.nn.Module):
# def __init__(
# self,
# dim,
# num_heads=8,
# qkv_bias=True,
# scaled_cosine=False,
# scale_heads=False,
# attn_drop=0.,
# proj_drop=0.,
# linear_module=torch.nn.Linear,
# ):
# super().__init__()
# self.scaled_cosine = scaled_cosine
# self.scale_heads = scale_heads
# assert dim % num_heads == 0, 'dim should be divisible by num_heads'
# self.num_heads = num_heads
# self.head_dim = dim // num_heads
# self.scale = self.head_dim ** -0.5
# self.in_proj_linear = linear_module(dim, 3 * dim, bias = qkv_bias)
# self.attn_drop = torch.nn.Dropout(attn_drop)
# if self.scale_heads:
# self.head_scale = torch.nn.Parameter(torch.ones((num_heads, 1, 1)))
# else:
# self.head_scale = None
# self.out_proj = linear_module(dim, dim)
# self.out_drop = torch.nn.Dropout(proj_drop)
# def forward(self, x, attn_mask = None):
# L, N, C = x.shape
# q, k, v = self.in_proj_linear(x).chunk(3, dim=-1)
# q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
# k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
# v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
# q = q * self.scale
# attn = torch.bmm(q, k.transpose(-1, -2))
# if attn_mask is not None:
# if attn_mask.dtype == torch.bool:
# new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
# new_attn_mask.masked_fill_(attn_mask, float("-inf"))
# attn_mask = new_attn_mask
# attn += attn_mask
# attn = attn.softmax(dim=-1)
# attn = self.attn_drop(attn)
# x = torch.bmm(attn, v)
# x = x.transpose(0, 1).reshape(L, N, C)
# x = self.out_proj(x)
# x = self.out_drop(x)
# return x
class Attention(torch.nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=True,
scaled_cosine=False,
scale_heads=False,
attn_drop=0.,
proj_drop=0.,
linear_module=torch.nn.Linear,
):
super().__init__()
self.scaled_cosine = scaled_cosine
self.scale_heads = scale_heads
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.ln = torch.nn.LayerNorm(dim)
self.in_proj_linear = linear_module(dim, 3 * dim, bias = qkv_bias)
self.attn_drop = torch.nn.Dropout(attn_drop)
if self.scale_heads:
self.head_scale = torch.nn.Parameter(torch.ones((num_heads, 1, 1)))
else:
self.head_scale = None
self.out_proj = linear_module(dim, dim)
self.out_drop = torch.nn.Dropout(proj_drop)
def forward(self, x, attn_mask = None):
q, k, v = self.in_proj_linear(self.ln(x)).chunk(3, dim=-1)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)
x = self.out_proj(x)
return x
if __name__ == '__main__':
for dim in [1024, 1280, 1408, 1664, 2048]:
for batch in [2**14, 2**15, 2**16, 2**17]:
# if dim != 4096 or batch != 2**17:
# continue
x1 = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True)
qu = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True)
ke = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True)
va = torch.randn( batch // 256, 256, dim ).cuda().requires_grad_(True)
standard = Attention(dim).cuda()
my_standard = Attention(dim, linear_module=StandardLinear).cuda()
sb = Attention(dim, linear_module=SwitchBackGlobalLinear).cuda()
standard_compiled = torch.compile(standard)
ln_model = torch.nn.Sequential(
torch.nn.LayerNorm(dim),
torch.nn.LayerNorm(dim),
).cuda()
ln_model_compiled = torch.compile(
ln_model
)
gelu_model = torch.nn.Sequential(
torch.nn.GELU(),
).cuda()
gelu_model_compiled = torch.compile(
gelu_model
)
print('Model part 2')
repeat = 32
info = {'repeat' : repeat, 'batch_size' : batch, 'dim' : dim}
k = 'attn'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_attn = torch.nn.functional.scaled_dot_product_attention(qu, ke, va)
((2 ** 16) * out_attn).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_attn = torch.nn.functional.scaled_dot_product_attention(qu, ke, va)
((2 ** 16) * out_attn).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'ln'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out = ln_model(x1)
((2 ** 16) * out).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out = ln_model(x1)
((2 ** 16) * out).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
x1.grad.zero_()
k = 'ln_compiled'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out = ln_model_compiled(x1)
((2 ** 16) * out).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out = ln_model_compiled(x1)
((2 ** 16) * out).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'gelu'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out = gelu_model(x1)
((2 ** 16) * out).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out = gelu_model(x1)
((2 ** 16) * out).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
x1.grad.zero_()
k = 'gelu_compiled'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out = gelu_model_compiled(x1)
((2 ** 16) * out).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out = gelu_model_compiled(x1)
((2 ** 16) * out).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
x1.grad.zero_()
k = 'standard'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_standard = standard(x1)
((2 ** 16) * out_standard).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_standard = standard(x1)
((2 ** 16) * out_standard).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
x1.grad.zero_()
k = 'my_standard'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_my_standard = my_standard(x1)
((2 ** 16) * out_my_standard).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_my_standard = my_standard(x1)
((2 ** 16) * out_my_standard).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
#
#
x1.grad.zero_()
k = 'standard_compiled'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_standard_compiled = standard_compiled(x1)
((2 ** 16) * out_standard_compiled).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_standard_compiled = standard_compiled(x1)
((2 ** 16) * out_standard_compiled).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
x1.grad.zero_()
k = 'sb'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_sb = sb(x1)
((2 ** 16) * out_sb).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_sb = sb(x1)
((2 ** 16) * out_sb).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
info_json = json.dumps(info)
with open("tests/triton_tests/attn_info_ln.jsonl", "a") as file:
file.write(info_json + "\n")
#exit()
# err_fused = (out_standard - out_fused).abs().mean()
# err_sb = (out_standard - out_sb).abs().mean()
# print('OUT', err_fused, err_sb)
# err_fused = (standard[d].weight.grad - fused_mlp.linear2.weight.grad).abs().mean()
# err_sb = (standard[d].weight.grad - sb[d].weight.grad).abs().mean()
# print('GW2', err_fused, err_sb)
# err_fused = (standard[0].weight.grad - fused_mlp.linear1.weight.grad).abs().mean()
# err_sb = (standard[0].weight.grad - sb[0].weight.grad).abs().mean()
# print('GW1', err_fused, err_sb)
# err_fused = (x1.grad - x2.grad).abs().mean()
# err_sb = (x1.grad - x3.grad).abs().mean()
# print('GX1', err_fused, err_sb)
# import pdb; pdb.set_trace()
# # NO GELU, ST GRADIENTS, EVERYTHING FINE.

View File

@ -1,20 +0,0 @@
{"repeat": 32, "batch_size": 16384, "dim": 1024, "attn": 2.1414458751678467, "ln": 1.6365647315979004, "ln_compiled": 1.799367368221283, "gelu": 1.0930374264717102, "gelu_compiled": 1.094818115234375, "standard": 4.159651696681976, "my_standard": 4.696495831012726, "standard_compiled": 3.675594925880432, "sb": 4.1465312242507935}
{"repeat": 32, "batch_size": 32768, "dim": 1024, "attn": 4.100345075130463, "ln": 3.1594187021255493, "ln_compiled": 3.437422215938568, "gelu": 2.109348773956299, "gelu_compiled": 2.11450457572937, "standard": 7.706902921199799, "my_standard": 8.799396455287933, "standard_compiled": 6.735652685165405, "sb": 7.66376405954361}
{"repeat": 32, "batch_size": 65536, "dim": 1024, "attn": 7.953710854053497, "ln": 6.236426532268524, "ln_compiled": 6.746955215930939, "gelu": 4.164382815361023, "gelu_compiled": 4.171714186668396, "standard": 14.894917607307434, "my_standard": 17.042435705661774, "standard_compiled": 12.985721230506897, "sb": 14.6140456199646}
{"repeat": 32, "batch_size": 131072, "dim": 1024, "attn": 15.638880431652069, "ln": 12.333884835243225, "ln_compiled": 13.272866606712341, "gelu": 8.228793740272522, "gelu_compiled": 8.243747055530548, "standard": 29.425136744976044, "my_standard": 35.08377820253372, "standard_compiled": 25.69487690925598, "sb": 28.760001063346863}
{"repeat": 32, "batch_size": 16384, "dim": 1280, "attn": 2.627238631248474, "ln": 2.0098239183425903, "ln_compiled": 2.4197474122047424, "gelu": 1.3455823063850403, "gelu_compiled": 1.35069340467453, "standard": 5.554787814617157, "my_standard": 6.2290579080581665, "standard_compiled": 5.132324993610382, "sb": 5.4178386926651}
{"repeat": 32, "batch_size": 32768, "dim": 1280, "attn": 5.0596073269844055, "ln": 3.903590142726898, "ln_compiled": 4.719957709312439, "gelu": 2.6203468441963196, "gelu_compiled": 2.627365291118622, "standard": 10.546617209911346, "my_standard": 11.850126087665558, "standard_compiled": 9.685918688774109, "sb": 10.088451206684113}
{"repeat": 32, "batch_size": 65536, "dim": 1280, "attn": 9.845800697803497, "ln": 7.711298763751984, "ln_compiled": 9.292080998420715, "gelu": 5.172915756702423, "gelu_compiled": 5.180932581424713, "standard": 21.371990442276, "my_standard": 23.921720683574677, "standard_compiled": 19.669152796268463, "sb": 20.267993211746216}
{"repeat": 32, "batch_size": 131072, "dim": 1280, "attn": 19.375711679458618, "ln": 15.333592891693115, "ln_compiled": 18.245264887809753, "gelu": 10.264746844768524, "gelu_compiled": 10.283775627613068, "standard": 41.79700464010239, "my_standard": 45.84744572639465, "standard_compiled": 38.35208714008331, "sb": 38.35364431142807}
{"repeat": 32, "batch_size": 16384, "dim": 1408, "attn": 2.9110386967658997, "ln": 2.1998360753059387, "ln_compiled": 2.581551671028137, "gelu": 1.4731436967849731, "gelu_compiled": 1.478634774684906, "standard": 6.764143705368042, "my_standard": 7.331632077693939, "standard_compiled": 6.24605268239975, "sb": 6.325609982013702}
{"repeat": 32, "batch_size": 32768, "dim": 1408, "attn": 5.542516708374023, "ln": 4.289716482162476, "ln_compiled": 5.065307021141052, "gelu": 2.8742849826812744, "gelu_compiled": 2.882353961467743, "standard": 12.749537825584412, "my_standard": 13.79828155040741, "standard_compiled": 11.728867888450623, "sb": 11.642806231975555}
{"repeat": 32, "batch_size": 65536, "dim": 1408, "attn": 10.80312579870224, "ln": 8.471302688121796, "ln_compiled": 9.96796041727066, "gelu": 5.681410431861877, "gelu_compiled": 5.6905597448349, "standard": 25.19702911376953, "my_standard": 27.226239442825317, "standard_compiled": 23.22910726070404, "sb": 22.682294249534607}
{"repeat": 32, "batch_size": 131072, "dim": 1408, "attn": 21.284908056259155, "ln": 16.85701310634613, "ln_compiled": 19.643358886241913, "gelu": 11.292420327663422, "gelu_compiled": 11.314474046230316, "standard": 50.06787180900574, "my_standard": 54.29378151893616, "standard_compiled": 44.58653926849365, "sb": 45.359253883361816}
{"repeat": 32, "batch_size": 16384, "dim": 1664, "attn": 3.382459282875061, "ln": 2.6206374168395996, "ln_compiled": 2.9666870832443237, "gelu": 1.7263293266296387, "gelu_compiled": 1.7317384481430054, "standard": 8.414775133132935, "my_standard": 9.117811918258667, "standard_compiled": 7.7542513608932495, "sb": 7.70898163318634}
{"repeat": 32, "batch_size": 32768, "dim": 1664, "attn": 6.468378007411957, "ln": 5.125559866428375, "ln_compiled": 5.791269242763519, "gelu": 3.3864825963974, "gelu_compiled": 3.3920034766197205, "standard": 16.016244888305664, "my_standard": 17.25083589553833, "standard_compiled": 14.60808515548706, "sb": 14.347739517688751}
{"repeat": 32, "batch_size": 65536, "dim": 1664, "attn": 12.645229697227478, "ln": 10.13532280921936, "ln_compiled": 11.427387595176697, "gelu": 6.6957250237464905, "gelu_compiled": 6.711684167385101, "standard": 31.792201101779938, "my_standard": 34.31189805269241, "standard_compiled": 29.10037338733673, "sb": 28.3128023147583}
{"repeat": 32, "batch_size": 131072, "dim": 1664, "attn": 24.970605969429016, "ln": 20.182937383651733, "ln_compiled": 22.7489173412323, "gelu": 13.326868414878845, "gelu_compiled": 13.345755636692047, "standard": 63.46555054187775, "my_standard": 70.19880414009094, "standard_compiled": 56.40875548124313, "sb": 56.22846633195877}
{"repeat": 32, "batch_size": 16384, "dim": 2048, "attn": 4.080049693584442, "ln": 3.2655522227287292, "ln_compiled": 3.3329352736473083, "gelu": 2.108432352542877, "gelu_compiled": 2.114713191986084, "standard": 11.370822787284851, "my_standard": 12.234866619110107, "standard_compiled": 10.377615690231323, "sb": 10.209612548351288}
{"repeat": 32, "batch_size": 32768, "dim": 2048, "attn": 7.74645060300827, "ln": 6.418220698833466, "ln_compiled": 6.55733048915863, "gelu": 4.163652658462524, "gelu_compiled": 4.171028733253479, "standard": 21.39316499233246, "my_standard": 23.04024249315262, "standard_compiled": 19.431106746196747, "sb": 18.732361495494843}
{"repeat": 32, "batch_size": 65536, "dim": 2048, "attn": 15.235155820846558, "ln": 12.684382498264313, "ln_compiled": 12.895286083221436, "gelu": 8.228868246078491, "gelu_compiled": 8.242718875408173, "standard": 42.55136102437973, "my_standard": 45.82635313272476, "standard_compiled": 38.663335144519806, "sb": 36.76284849643707}
{"repeat": 32, "batch_size": 131072, "dim": 2048, "attn": 30.24454414844513, "ln": 25.25731921195984, "ln_compiled": 25.67601203918457, "gelu": 16.384944319725037, "gelu_compiled": 16.409948468208313, "standard": 84.26841348409653, "my_standard": 91.10662341117859, "standard_compiled": 76.89539343118668, "sb": 71.73164188861847}

View File

@ -1,353 +0,0 @@
import json
import time
import torch
import torch.nn as nn
import bitsandbytes.nn as bnn
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, StandardLinear
from bitsandbytes.nn.triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup
from bitsandbytes.nn.triton_utils.v0.quantize_columnwise_nogroup_transpose import quantize_columnwise_nogroup_transpose
from bitsandbytes.nn.triton_utils.v0.int8_matmul_rowwise_dequantize_bias import int8_matmul_rowwise_dequantize_bias
from bitsandbytes.nn.triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.nn.triton_utils.v0.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.nn.triton_utils.v0.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze, int8_matmul_mixed_dequanitze_bias
# KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.
# not that big of an issue.
def get_time_standard_fwd(k, v):
x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda()
g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda()
##### time matmul 1
for _ in range(repeat // 2):
g.t().matmul(x)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
g.t().matmul(x)
torch.cuda.synchronize()
end = time.time()
print(f"time {k}: {(end - start) / repeat * 1000:.3f} ms")
return (end - start) / repeat * 1000
if __name__ == '__main__':
torch.manual_seed(0)
#for (dim, wm) in [(1024, 4), (1280, 4), (1408, 4.3637), (1664, 4.9231), (2048, 4), (4096, 4), (8096, 4)]
for (dim, wm) in [(1408, 4), (1664, 4),]:
for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]:
#for batch_size in [256*256, 256*512]:
for switch in [False, True]:
# hparams
repeat = 64
batch_size = batch_size
dim_out = dim * wm
dim_in = dim
if switch:
dim_out = dim
dim_in = wm * dim
dim_in = round(dim_in)
dim_out = round(dim_out)
# simulate forward pass
x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda()
g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda()
w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda()
x_int8 = x.clone().to(torch.int8)
g_int8 = g.clone().to(torch.int8)
w_int8 = w.clone().to(torch.int8)
wt_int8 = w.t().contiguous().clone().to(torch.int8)
state_x_rowwise = x.max(dim=1)[0]
state_g_rowwise = g.max(dim=1)[0]
state_w_columnwise = w.max(dim=0)[0]
state_w_rowwise = w.max(dim=1)[0]
state_w_global = w.max()
info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch}
k = 'standard_fwd'
for _ in range(repeat // 2):
x.matmul(w.t())
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
x.matmul(w.t())
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'standard_gw'
for _ in range(repeat // 2):
g.t().matmul(x)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
g.t().matmul(x)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'standard_gx'
for _ in range(repeat // 2):
g.matmul(w)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
g.matmul(w)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'rowwise_fwd'
for _ in range(repeat // 2):
int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'rowwise_bwd'
for _ in range(repeat // 2):
int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'global_fwd'
for _ in range(repeat // 2):
int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'global_bwd'
for _ in range(repeat // 2):
int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'x_quantize_rowwise'
for _ in range(repeat // 2):
quantize_rowwise_nogroup(x)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
quantize_rowwise_nogroup(x)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'g_quantize_rowwise'
for _ in range(repeat // 2):
quantize_rowwise_nogroup(g)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
quantize_rowwise_nogroup(g)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'w_quantize_rowwise'
for _ in range(repeat // 2):
quantize_rowwise_nogroup(w)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
quantize_rowwise_nogroup(w)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'w_quantize_colwise_transpose'
for _ in range(repeat // 2):
quantize_columnwise_nogroup_transpose(w)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
quantize_columnwise_nogroup_transpose(w)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'w_quantize_global'
for _ in range(repeat // 2):
quantize_global(w)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
quantize_global(w)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'w_quantize_global_transpose'
for _ in range(repeat // 2):
quantize_global_transpose(w)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
quantize_global_transpose(w)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'cast_x'
for _ in range(repeat // 2):
newx = x.to(torch.int8)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
newx = x.to(torch.int8)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'cast_g'
for _ in range(repeat // 2):
newx = g.to(torch.int8)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
newx = g.to(torch.int8)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
k = 'cast_w'
for _ in range(repeat // 2):
newx = w.to(torch.int8)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
newx = w.to(torch.int8)
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw']
time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd']
time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd']
print('TOTAL STANDARD', time_standard)
print('TOTAL ROWWISE', time_rowwise)
print('TOTAL GLOBAL', time_global)
print('speedup', -100*(time_global - time_standard)/time_standard)
info['time_standard'] = time_standard
info['time_rowwise'] = time_rowwise
info['time_global'] = time_global
info_json = json.dumps(info)
with open("tests/triton_tests/info.jsonl", "a") as file:
file.write(info_json + "\n")

View File

@ -1,142 +0,0 @@
{"repeat": 64, "batch_size": 1024, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.047907233238220215, "standard_gw": 0.04326179623603821, "standard_gx": 0.042986124753952026, "rowwise_fwd": 0.03902614116668701, "rowwise_bwd": 0.038955360651016235, "global_fwd": 0.03974884748458862, "global_bwd": 0.0391639769077301, "x_quantize_rowwise": 0.02619624137878418, "g_quantize_rowwise": 0.02695620059967041, "w_quantize_rowwise": 0.02631545066833496, "w_quantize_colwise_transpose": 0.08677691221237183, "w_quantize_global": 0.07359683513641357, "w_quantize_global_transpose": 0.08226558566093445, "cast_x": 0.007815659046173096, "cast_g": 0.016041100025177002, "cast_w": 0.01600012183189392, "time_standard": 0.13415515422821045, "time_rowwise": 0.28748810291290283, "time_global": 0.33118948340415955}
{"repeat": 64, "batch_size": 1024, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.04236400127410889, "standard_gw": 0.04898756742477417, "standard_gx": 0.04731118679046631, "rowwise_fwd": 0.03933534026145935, "rowwise_bwd": 0.03947317600250244, "global_fwd": 0.03688037395477295, "global_bwd": 0.039167702198028564, "x_quantize_rowwise": 0.02533942461013794, "g_quantize_rowwise": 0.02516806125640869, "w_quantize_rowwise": 0.02528354525566101, "w_quantize_colwise_transpose": 0.0903792679309845, "w_quantize_global": 0.0997595489025116, "w_quantize_global_transpose": 0.10209530591964722, "cast_x": 0.01626834273338318, "cast_g": 0.011973083019256592, "cast_w": 0.016044825315475464, "time_standard": 0.13866275548934937, "time_rowwise": 0.2939663827419281, "time_global": 0.37739798426628113}
{"repeat": 64, "batch_size": 2048, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.07753819227218628, "standard_gw": 0.08026883006095886, "standard_gx": 0.0906921923160553, "rowwise_fwd": 0.0630207359790802, "rowwise_bwd": 0.058263540267944336, "global_fwd": 0.06167963147163391, "global_bwd": 0.05801767110824585, "x_quantize_rowwise": 0.034205615520477295, "g_quantize_rowwise": 0.03341957926750183, "w_quantize_rowwise": 0.03244727849960327, "w_quantize_colwise_transpose": 0.08665025234222412, "w_quantize_global": 0.09483471512794495, "w_quantize_global_transpose": 0.10108202695846558, "cast_x": 0.012032687664031982, "cast_g": 0.03752484917640686, "cast_w": 0.01605972647666931, "time_standard": 0.24849921464920044, "time_rowwise": 0.3882758319377899, "time_global": 0.46350806951522827}
{"repeat": 64, "batch_size": 2048, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.09099021553993225, "standard_gw": 0.0799819827079773, "standard_gx": 0.07644668221473694, "rowwise_fwd": 0.05840510129928589, "rowwise_bwd": 0.06359070539474487, "global_fwd": 0.057831406593322754, "global_bwd": 0.06148591637611389, "x_quantize_rowwise": 0.03434717655181885, "g_quantize_rowwise": 0.03361701965332031, "w_quantize_rowwise": 0.03209337592124939, "w_quantize_colwise_transpose": 0.09028613567352295, "w_quantize_global": 0.0944770872592926, "w_quantize_global_transpose": 0.0994168221950531, "cast_x": 0.03769621253013611, "cast_g": 0.012010335922241211, "cast_w": 0.01600012183189392, "time_standard": 0.24741888046264648, "time_rowwise": 0.39232149720191956, "time_global": 0.4611574113368988}
{"repeat": 64, "batch_size": 4096, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.14450401067733765, "standard_gw": 0.14326348900794983, "standard_gx": 0.14762207865715027, "rowwise_fwd": 0.10525062680244446, "rowwise_bwd": 0.09800493717193604, "global_fwd": 0.10229647159576416, "global_bwd": 0.09718164801597595, "x_quantize_rowwise": 0.03429874777793884, "g_quantize_rowwise": 0.04567950963973999, "w_quantize_rowwise": 0.03365054726600647, "w_quantize_colwise_transpose": 0.08654966950416565, "w_quantize_global": 0.09663775563240051, "w_quantize_global_transpose": 0.10383129119873047, "cast_x": 0.01605972647666931, "cast_g": 0.08305534720420837, "cast_w": 0.01624971628189087, "time_standard": 0.43538957834243774, "time_rowwise": 0.5466975271701813, "time_global": 0.6231889128684998}
{"repeat": 64, "batch_size": 4096, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.14496594667434692, "standard_gw": 0.1412704586982727, "standard_gx": 0.14446303248405457, "rowwise_fwd": 0.10041892528533936, "rowwise_bwd": 0.10674074292182922, "global_fwd": 0.09856373071670532, "global_bwd": 0.10319426655769348, "x_quantize_rowwise": 0.045571476221084595, "g_quantize_rowwise": 0.03273040056228638, "w_quantize_rowwise": 0.033464282751083374, "w_quantize_colwise_transpose": 0.09154900908470154, "w_quantize_global": 0.0964440405368805, "w_quantize_global_transpose": 0.1031048595905304, "cast_x": 0.0835023820400238, "cast_g": 0.016242265701293945, "cast_w": 0.016283243894577026, "time_standard": 0.4306994378566742, "time_rowwise": 0.5517452955245972, "time_global": 0.6208792328834534}
{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.28106942772865295, "standard_gw": 0.2841465175151825, "standard_gx": 0.301852822303772, "rowwise_fwd": 0.19879266619682312, "rowwise_bwd": 0.16228482127189636, "global_fwd": 0.19488856196403503, "global_bwd": 0.1607760787010193, "x_quantize_rowwise": 0.033974647521972656, "g_quantize_rowwise": 0.08221715688705444, "w_quantize_rowwise": 0.03248825669288635, "w_quantize_colwise_transpose": 0.08646398782730103, "w_quantize_global": 0.0939294695854187, "w_quantize_global_transpose": 0.09895861148834229, "cast_x": 0.03753975033760071, "cast_g": 0.15900656580924988, "cast_w": 0.01603737473487854, "time_standard": 0.8670687675476074, "time_rowwise": 0.8803680539131165, "time_global": 0.9488910436630249}
{"repeat": 64, "batch_size": 8192, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.26415660977363586, "standard_gw": 0.2679601311683655, "standard_gx": 0.30617788434028625, "rowwise_fwd": 0.180121511220932, "rowwise_bwd": 0.21555647253990173, "global_fwd": 0.17506256699562073, "global_bwd": 0.2116672694683075, "x_quantize_rowwise": 0.08289515972137451, "g_quantize_rowwise": 0.033795833587646484, "w_quantize_rowwise": 0.03366544842720032, "w_quantize_colwise_transpose": 0.09965524077415466, "w_quantize_global": 0.09595602750778198, "w_quantize_global_transpose": 0.1024976372718811, "cast_x": 0.1602955162525177, "cast_g": 0.03787502646446228, "cast_w": 0.016216188669204712, "time_standard": 0.8382946252822876, "time_rowwise": 0.9136497974395752, "time_global": 0.9698346257209778}
{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.5719438195228577, "standard_gw": 0.524863600730896, "standard_gx": 0.6005167961120605, "rowwise_fwd": 0.3750324249267578, "rowwise_bwd": 0.28166547417640686, "global_fwd": 0.3674700856208801, "global_bwd": 0.2798214554786682, "x_quantize_rowwise": 0.04655122756958008, "g_quantize_rowwise": 0.1555122435092926, "w_quantize_rowwise": 0.03437697887420654, "w_quantize_colwise_transpose": 0.08634477853775024, "w_quantize_global": 0.09759142994880676, "w_quantize_global_transpose": 0.10081753134727478, "cast_x": 0.0828765332698822, "cast_g": 0.31184032559394836, "cast_w": 0.016063451766967773, "time_standard": 1.6973242163658142, "time_rowwise": 1.5043467283248901, "time_global": 1.5726275742053986}
{"repeat": 64, "batch_size": 16384, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.5423910915851593, "standard_gw": 0.5674734711647034, "standard_gx": 0.5907565355300903, "rowwise_fwd": 0.3149174153804779, "rowwise_bwd": 0.3899820148944855, "global_fwd": 0.2909451723098755, "global_bwd": 0.3783814609050751, "x_quantize_rowwise": 0.15584751963615417, "g_quantize_rowwise": 0.04688650369644165, "w_quantize_rowwise": 0.031463801860809326, "w_quantize_colwise_transpose": 0.09072571992874146, "w_quantize_global": 0.09774044156074524, "w_quantize_global_transpose": 0.10405108332633972, "cast_x": 0.3111511468887329, "cast_g": 0.08282437920570374, "cast_w": 0.015992671251296997, "time_standard": 1.700621098279953, "time_rowwise": 1.5972964465618134, "time_global": 1.6413256525993347}
{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 1.2115389108657837, "standard_gw": 1.1259466409683228, "standard_gx": 1.1027492582798004, "rowwise_fwd": 0.7407031953334808, "rowwise_bwd": 0.5539208650588989, "global_fwd": 0.7214657962322235, "global_bwd": 0.5515590310096741, "x_quantize_rowwise": 0.08765608072280884, "g_quantize_rowwise": 0.3022328019142151, "w_quantize_rowwise": 0.03347545862197876, "w_quantize_colwise_transpose": 0.08694455027580261, "w_quantize_global": 0.09706243872642517, "w_quantize_global_transpose": 0.10102614760398865, "cast_x": 0.1592189073562622, "cast_g": 0.6166175007820129, "cast_w": 0.01607835292816162, "time_standard": 3.440234810113907, "time_rowwise": 2.930879592895508, "time_global": 2.986948937177658}
{"repeat": 64, "batch_size": 32768, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 1.1010989546775818, "standard_gw": 1.1352524161338806, "standard_gx": 1.1676251888275146, "rowwise_fwd": 0.5864761769771576, "rowwise_bwd": 0.7485374808311462, "global_fwd": 0.5547590553760529, "global_bwd": 0.7249303162097931, "x_quantize_rowwise": 0.3021731972694397, "g_quantize_rowwise": 0.08751824498176575, "w_quantize_rowwise": 0.033952295780181885, "w_quantize_colwise_transpose": 0.09011104702949524, "w_quantize_global": 0.09443238377571106, "w_quantize_global_transpose": 0.10376051068305969, "cast_x": 0.6167255342006683, "cast_g": 0.15922263264656067, "cast_w": 0.016070902347564697, "time_standard": 3.403976559638977, "time_rowwise": 2.984020859003067, "time_global": 3.0028261244297028}
{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 2.472013235092163, "standard_gw": 2.218998968601227, "standard_gx": 2.2116564214229584, "rowwise_fwd": 1.466125249862671, "rowwise_bwd": 1.0577328503131866, "global_fwd": 1.431729644536972, "global_bwd": 1.0476894676685333, "x_quantize_rowwise": 0.16929209232330322, "g_quantize_rowwise": 0.5952082574367523, "w_quantize_rowwise": 0.032100826501846313, "w_quantize_colwise_transpose": 0.08670613169670105, "w_quantize_global": 0.09590759873390198, "w_quantize_global_transpose": 0.10358169674873352, "cast_x": 0.31175464391708374, "cast_g": 1.2264922261238098, "cast_w": 0.016067177057266235, "time_standard": 6.902668625116348, "time_rowwise": 5.626164376735687, "time_global": 5.662407726049423}
{"repeat": 64, "batch_size": 65536, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 2.181064337491989, "standard_gw": 2.2256113588809967, "standard_gx": 2.3229196667671204, "rowwise_fwd": 1.0886266827583313, "rowwise_bwd": 1.4654062688350677, "global_fwd": 1.0472461581230164, "global_bwd": 1.433148980140686, "x_quantize_rowwise": 0.5954094231128693, "g_quantize_rowwise": 0.16921386122703552, "w_quantize_rowwise": 0.03442913293838501, "w_quantize_colwise_transpose": 0.09007751941680908, "w_quantize_global": 0.09575113654136658, "w_quantize_global_transpose": 0.10503828525543213, "cast_x": 1.2264810502529144, "cast_g": 0.3119036555290222, "cast_w": 0.01605600118637085, "time_standard": 6.729595363140106, "time_rowwise": 5.668774247169495, "time_global": 5.671419203281403}
{"repeat": 64, "batch_size": 1024, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 0.08157268166542053, "standard_gw": 0.07601454854011536, "standard_gx": 0.09059160947799683, "rowwise_fwd": 0.053066760301589966, "rowwise_bwd": 0.04787370562553406, "global_fwd": 0.05243346095085144, "global_bwd": 0.04809349775314331, "x_quantize_rowwise": 0.02571195363998413, "g_quantize_rowwise": 0.025898218154907227, "w_quantize_rowwise": 0.02714991569519043, "w_quantize_colwise_transpose": 0.19773468375205994, "w_quantize_global": 0.07273256778717041, "w_quantize_global_transpose": 0.08068978786468506, "cast_x": 0.008046627044677734, "cast_g": 0.0252649188041687, "cast_w": 0.0393986701965332, "time_standard": 0.24817883968353271, "time_rowwise": 0.4534497857093811, "time_global": 0.38157403469085693}
{"repeat": 64, "batch_size": 1024, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 0.09134411811828613, "standard_gw": 0.07602199912071228, "standard_gx": 0.09555742144584656, "rowwise_fwd": 0.047691166400909424, "rowwise_bwd": 0.05320459604263306, "global_fwd": 0.04759058356285095, "global_bwd": 0.0521540641784668, "x_quantize_rowwise": 0.025313347578048706, "g_quantize_rowwise": 0.025119632482528687, "w_quantize_rowwise": 0.0269375741481781, "w_quantize_colwise_transpose": 0.1857280731201172, "w_quantize_global": 0.07451698184013367, "w_quantize_global_transpose": 0.08009746670722961, "cast_x": 0.02547726035118103, "cast_g": 0.007897615432739258, "cast_w": 0.039536505937576294, "time_standard": 0.26292353868484497, "time_rowwise": 0.44001638889312744, "time_global": 0.3808140754699707}
{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 4.940010607242584, "standard_gw": 4.434864968061447, "standard_gx": 4.4097937643527985, "rowwise_fwd": 2.9467344284057617, "rowwise_bwd": 2.09181010723114, "global_fwd": 2.8806477785110474, "global_bwd": 2.0816922187805176, "x_quantize_rowwise": 0.33279508352279663, "g_quantize_rowwise": 1.1817067861557007, "w_quantize_rowwise": 0.03306567668914795, "w_quantize_colwise_transpose": 0.08666515350341797, "w_quantize_global": 0.0957287847995758, "w_quantize_global_transpose": 0.10242313146591187, "cast_x": 0.6165988743305206, "cast_g": 2.446405589580536, "cast_w": 0.016100704669952393, "time_standard": 13.78466933965683, "time_rowwise": 11.107642203569412, "time_global": 11.109858751296997}
{"repeat": 64, "batch_size": 131072, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 4.293464124202728, "standard_gw": 4.461295902729034, "standard_gx": 4.638340324163437, "rowwise_fwd": 2.116892486810684, "rowwise_bwd": 2.9479674994945526, "global_fwd": 2.0760856568813324, "global_bwd": 2.8755851089954376, "x_quantize_rowwise": 1.1818408966064453, "g_quantize_rowwise": 0.33276528120040894, "w_quantize_rowwise": 0.03287568688392639, "w_quantize_colwise_transpose": 0.09038299322128296, "w_quantize_global": 0.09598955512046814, "w_quantize_global_transpose": 0.100649893283844, "cast_x": 2.4467408657073975, "cast_g": 0.6165951490402222, "cast_w": 0.016082078218460083, "time_standard": 13.3931003510952, "time_rowwise": 11.164020746946335, "time_global": 11.12421229481697}
{"repeat": 64, "batch_size": 2048, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 0.1699887216091156, "standard_gw": 0.14045089483261108, "standard_gx": 0.17407909035682678, "rowwise_fwd": 0.10082125663757324, "rowwise_bwd": 0.08344277739524841, "global_fwd": 0.09941309690475464, "global_bwd": 0.08352473378181458, "x_quantize_rowwise": 0.025317072868347168, "g_quantize_rowwise": 0.03849714994430542, "w_quantize_rowwise": 0.02596527338027954, "w_quantize_colwise_transpose": 0.19767135381698608, "w_quantize_global": 0.07257238030433655, "w_quantize_global_transpose": 0.08127838373184204, "cast_x": 0.012032687664031982, "cast_g": 0.06345659494400024, "cast_w": 0.03953278064727783, "time_standard": 0.48451870679855347, "time_rowwise": 0.612165778875351, "time_global": 0.5410537123680115}
{"repeat": 64, "batch_size": 2048, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 0.14855340123176575, "standard_gw": 0.15553459525108337, "standard_gx": 0.16282498836517334, "rowwise_fwd": 0.09259581565856934, "rowwise_bwd": 0.11080875992774963, "global_fwd": 0.09166449308395386, "global_bwd": 0.10796263813972473, "x_quantize_rowwise": 0.03939121961593628, "g_quantize_rowwise": 0.025227665901184082, "w_quantize_rowwise": 0.027202069759368896, "w_quantize_colwise_transpose": 0.1940988004207611, "w_quantize_global": 0.07397681474685669, "w_quantize_global_transpose": 0.08178502321243286, "cast_x": 0.065632164478302, "cast_g": 0.01268833875656128, "cast_w": 0.04057586193084717, "time_standard": 0.46691298484802246, "time_rowwise": 0.6448589265346527, "time_global": 0.5755424499511719}
{"repeat": 64, "batch_size": 4096, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 0.32291561365127563, "standard_gw": 0.2875030040740967, "standard_gx": 0.3379322588443756, "rowwise_fwd": 0.19295886158943176, "rowwise_bwd": 0.16265735030174255, "global_fwd": 0.19031018018722534, "global_bwd": 0.16187503933906555, "x_quantize_rowwise": 0.02730637788772583, "g_quantize_rowwise": 0.06797909736633301, "w_quantize_rowwise": 0.02642720937728882, "w_quantize_colwise_transpose": 0.19745901226997375, "w_quantize_global": 0.07253512740135193, "w_quantize_global_transpose": 0.08047744631767273, "cast_x": 0.022336840629577637, "cast_g": 0.1209154725074768, "cast_w": 0.039268285036087036, "time_standard": 0.9483508765697479, "time_rowwise": 0.9622909128665924, "time_global": 0.8879862725734711}
{"repeat": 64, "batch_size": 4096, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 0.3019683063030243, "standard_gw": 0.288400799036026, "standard_gx": 0.3154948353767395, "rowwise_fwd": 0.18264353275299072, "rowwise_bwd": 0.2075284719467163, "global_fwd": 0.17072632908821106, "global_bwd": 0.1960061490535736, "x_quantize_rowwise": 0.06893649697303772, "g_quantize_rowwise": 0.02561509609222412, "w_quantize_rowwise": 0.026594847440719604, "w_quantize_colwise_transpose": 0.18575787544250488, "w_quantize_global": 0.07266923785209656, "w_quantize_global_transpose": 0.08060410618782043, "cast_x": 0.12182071805000305, "cast_g": 0.022590160369873047, "cast_w": 0.04000961780548096, "time_standard": 0.9058639407157898, "time_rowwise": 0.9854771196842194, "time_global": 0.9029582142829895}
{"repeat": 64, "batch_size": 8192, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 0.6489232182502747, "standard_gw": 0.5987770855426788, "standard_gx": 0.6644465029239655, "rowwise_fwd": 0.35867467522621155, "rowwise_bwd": 0.31855329871177673, "global_fwd": 0.353105366230011, "global_bwd": 0.31349435448646545, "x_quantize_rowwise": 0.03382191061973572, "g_quantize_rowwise": 0.12668967247009277, "w_quantize_rowwise": 0.02681836485862732, "w_quantize_colwise_transpose": 0.19756704568862915, "w_quantize_global": 0.07336586713790894, "w_quantize_global_transpose": 0.08036196231842041, "cast_x": 0.0583939254283905, "cast_g": 0.23520365357398987, "cast_w": 0.03935396671295166, "time_standard": 1.912146806716919, "time_rowwise": 1.660902053117752, "time_global": 1.579616218805313}
{"repeat": 64, "batch_size": 8192, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 0.5789436399936676, "standard_gw": 0.6130896508693695, "standard_gx": 0.6558857858181, "rowwise_fwd": 0.3464221954345703, "rowwise_bwd": 0.3650560975074768, "global_fwd": 0.3174394369125366, "global_bwd": 0.35758689045906067, "x_quantize_rowwise": 0.12686848640441895, "g_quantize_rowwise": 0.034302473068237305, "w_quantize_rowwise": 0.02745911478996277, "w_quantize_colwise_transpose": 0.1847483217716217, "w_quantize_global": 0.07192790508270264, "w_quantize_global_transpose": 0.08050352334976196, "cast_x": 0.23534893989562988, "cast_g": 0.05846098065376282, "cast_w": 0.03949552774429321, "time_standard": 1.847919076681137, "time_rowwise": 1.6979463398456573, "time_global": 1.6017183661460876}
{"repeat": 64, "batch_size": 1024, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.0573769211769104, "standard_gw": 0.061042606830596924, "standard_gx": 0.0783093273639679, "rowwise_fwd": 0.046797096729278564, "rowwise_bwd": 0.04620850086212158, "global_fwd": 0.04521384835243225, "global_bwd": 0.04425644874572754, "x_quantize_rowwise": 0.03257766366004944, "g_quantize_rowwise": 0.03449246287345886, "w_quantize_rowwise": 0.033657997846603394, "w_quantize_colwise_transpose": 0.1426301896572113, "w_quantize_global": 0.09257346391677856, "w_quantize_global_transpose": 0.10266527533531189, "cast_x": 0.011991709470748901, "cast_g": 0.020314007997512817, "cast_w": 0.027321279048919678, "time_standard": 0.19672885537147522, "time_rowwise": 0.39740651845932007, "time_global": 0.41282176971435547}
{"repeat": 64, "batch_size": 1024, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.07858872413635254, "standard_gw": 0.06122514605522156, "standard_gx": 0.05758553743362427, "rowwise_fwd": 0.04598498344421387, "rowwise_bwd": 0.04618242383003235, "global_fwd": 0.04597380757331848, "global_bwd": 0.046450644731521606, "x_quantize_rowwise": 0.03332272171974182, "g_quantize_rowwise": 0.033274292945861816, "w_quantize_rowwise": 0.0337548553943634, "w_quantize_colwise_transpose": 0.14807656407356262, "w_quantize_global": 0.09948387742042542, "w_quantize_global_transpose": 0.10120868682861328, "cast_x": 0.020120292901992798, "cast_g": 0.011488795280456543, "cast_w": 0.027466565370559692, "time_standard": 0.19739940762519836, "time_rowwise": 0.40182098746299744, "time_global": 0.420939177274704}
{"repeat": 64, "batch_size": 16384, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 1.3515166938304901, "standard_gw": 1.1536777019500732, "standard_gx": 1.224767416715622, "rowwise_fwd": 0.6912238895893097, "rowwise_bwd": 0.5562454462051392, "global_fwd": 0.67867711186409, "global_bwd": 0.5518943071365356, "x_quantize_rowwise": 0.06204098463058472, "g_quantize_rowwise": 0.24417787790298462, "w_quantize_rowwise": 0.025238841772079468, "w_quantize_colwise_transpose": 0.19756704568862915, "w_quantize_global": 0.07240846753120422, "w_quantize_global_transpose": 0.08046254515647888, "cast_x": 0.11138245463371277, "cast_g": 0.4637613892555237, "cast_w": 0.03935769200325012, "time_standard": 3.7299618124961853, "time_rowwise": 2.9301717877388, "time_global": 2.8433389961719513}
{"repeat": 64, "batch_size": 16384, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 1.2090615928173065, "standard_gw": 1.1396333575248718, "standard_gx": 1.2223869562149048, "rowwise_fwd": 0.5849376320838928, "rowwise_bwd": 0.6985403597354889, "global_fwd": 0.5565173923969269, "global_bwd": 0.6789751350879669, "x_quantize_rowwise": 0.2445802092552185, "g_quantize_rowwise": 0.06200745701789856, "w_quantize_rowwise": 0.027727335691452026, "w_quantize_colwise_transpose": 0.18501654267311096, "w_quantize_global": 0.07182732224464417, "w_quantize_global_transpose": 0.08069723844528198, "cast_x": 0.4638172686100006, "cast_g": 0.11136755347251892, "cast_w": 0.039517879486083984, "time_standard": 3.571081906557083, "time_rowwise": 2.9424428939819336, "time_global": 2.834238111972809}
{"repeat": 64, "batch_size": 32768, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 2.683013677597046, "standard_gw": 2.2987723350524902, "standard_gx": 2.4510622024536133, "rowwise_fwd": 1.359008252620697, "rowwise_bwd": 1.1018887162208557, "global_fwd": 1.3311207294464111, "global_bwd": 1.0954029858112335, "x_quantize_rowwise": 0.11804327368736267, "g_quantize_rowwise": 0.479232519865036, "w_quantize_rowwise": 0.026308000087738037, "w_quantize_colwise_transpose": 0.1975223422050476, "w_quantize_global": 0.07223710417747498, "w_quantize_global_transpose": 0.08019432425498962, "cast_x": 0.2161264419555664, "cast_g": 0.9207837283611298, "cast_w": 0.03929063677787781, "time_standard": 7.432848215103149, "time_rowwise": 5.580775439739227, "time_global": 5.475003272294998}
{"repeat": 64, "batch_size": 2048, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.11088326573371887, "standard_gw": 0.10994821786880493, "standard_gx": 0.12367218732833862, "rowwise_fwd": 0.07392093539237976, "rowwise_bwd": 0.07127970457077026, "global_fwd": 0.0730752944946289, "global_bwd": 0.07089227437973022, "x_quantize_rowwise": 0.03361701965332031, "g_quantize_rowwise": 0.03525242209434509, "w_quantize_rowwise": 0.03341585397720337, "w_quantize_colwise_transpose": 0.14318525791168213, "w_quantize_global": 0.09704753756523132, "w_quantize_global_transpose": 0.10221078991889954, "cast_x": 0.012002885341644287, "cast_g": 0.05240738391876221, "cast_w": 0.027313828468322754, "time_standard": 0.3445036709308624, "time_rowwise": 0.5006194114685059, "time_global": 0.5220435559749603}
{"repeat": 64, "batch_size": 32768, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 2.4625882506370544, "standard_gw": 2.421922981739044, "standard_gx": 2.380847930908203, "rowwise_fwd": 1.1231191456317902, "rowwise_bwd": 1.360483467578888, "global_fwd": 1.0947436094284058, "global_bwd": 1.3314113020896912, "x_quantize_rowwise": 0.4795975983142853, "g_quantize_rowwise": 0.11777132749557495, "w_quantize_rowwise": 0.02699345350265503, "w_quantize_colwise_transpose": 0.18484890460968018, "w_quantize_global": 0.07201358675956726, "w_quantize_global_transpose": 0.0803135335445404, "cast_x": 0.920858234167099, "cast_g": 0.21616369485855103, "cast_w": 0.03937259316444397, "time_standard": 7.265359163284302, "time_rowwise": 5.714736878871918, "time_global": 5.597773939371109}
{"repeat": 64, "batch_size": 2048, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.12437254190444946, "standard_gw": 0.11018291115760803, "standard_gx": 0.10970607399940491, "rowwise_fwd": 0.07167831063270569, "rowwise_bwd": 0.07583573460578918, "global_fwd": 0.07314234972000122, "global_bwd": 0.07501617074012756, "x_quantize_rowwise": 0.035624951124191284, "g_quantize_rowwise": 0.0333636999130249, "w_quantize_rowwise": 0.03264099359512329, "w_quantize_colwise_transpose": 0.14795735478401184, "w_quantize_global": 0.09621679782867432, "w_quantize_global_transpose": 0.10380148887634277, "cast_x": 0.05278363823890686, "cast_g": 0.01249462366104126, "cast_w": 0.02767890691757202, "time_standard": 0.3442615270614624, "time_rowwise": 0.5072839558124542, "time_global": 0.5273483693599701}
{"repeat": 64, "batch_size": 4096, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.21922588348388672, "standard_gw": 0.20731613039970398, "standard_gx": 0.23101642727851868, "rowwise_fwd": 0.1423358917236328, "rowwise_bwd": 0.1195073127746582, "global_fwd": 0.1401938498020172, "global_bwd": 0.11940300464630127, "x_quantize_rowwise": 0.03353878855705261, "g_quantize_rowwise": 0.06387382745742798, "w_quantize_rowwise": 0.03428757190704346, "w_quantize_colwise_transpose": 0.14376267790794373, "w_quantize_global": 0.09389594197273254, "w_quantize_global_transpose": 0.10196119546890259, "cast_x": 0.020060688257217407, "cast_g": 0.10236725211143494, "cast_w": 0.02732500433921814, "time_standard": 0.6575584411621094, "time_rowwise": 0.7446222007274628, "time_global": 0.7601827383041382}
{"repeat": 64, "batch_size": 4096, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.20026043057441711, "standard_gw": 0.21172687411308289, "standard_gx": 0.2276189625263214, "rowwise_fwd": 0.12956932187080383, "rowwise_bwd": 0.15310943126678467, "global_fwd": 0.12427568435668945, "global_bwd": 0.14432892203330994, "x_quantize_rowwise": 0.06471946835517883, "g_quantize_rowwise": 0.03309175372123718, "w_quantize_rowwise": 0.03242120146751404, "w_quantize_colwise_transpose": 0.14733895659446716, "w_quantize_global": 0.09280815720558167, "w_quantize_global_transpose": 0.10265037417411804, "cast_x": 0.10267645120620728, "cast_g": 0.020150095224380493, "cast_w": 0.027399510145187378, "time_standard": 0.6396062672138214, "time_rowwise": 0.7719770073890686, "time_global": 0.773601233959198}
{"repeat": 64, "batch_size": 65536, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 5.324859172105789, "standard_gw": 4.977177828550339, "standard_gx": 4.468705505132675, "rowwise_fwd": 2.7004145085811615, "rowwise_bwd": 2.121664583683014, "global_fwd": 2.648312598466873, "global_bwd": 2.111390233039856, "x_quantize_rowwise": 0.22934377193450928, "g_quantize_rowwise": 0.9496547281742096, "w_quantize_rowwise": 0.02555176615715027, "w_quantize_colwise_transpose": 0.1977868378162384, "w_quantize_global": 0.0727437436580658, "w_quantize_global_transpose": 0.08098781108856201, "cast_x": 0.4259459674358368, "cast_g": 1.8352754414081573, "cast_w": 0.039637088775634766, "time_standard": 14.770742505788803, "time_rowwise": 11.201594024896622, "time_global": 11.069610714912415}
{"repeat": 64, "batch_size": 8192, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.49151480197906494, "standard_gw": 0.4681535065174103, "standard_gx": 0.42366236448287964, "rowwise_fwd": 0.2766512334346771, "rowwise_bwd": 0.2083033323287964, "global_fwd": 0.2709813416004181, "global_bwd": 0.20718947052955627, "x_quantize_rowwise": 0.034555792808532715, "g_quantize_rowwise": 0.11969730257987976, "w_quantize_rowwise": 0.03300607204437256, "w_quantize_colwise_transpose": 0.14345720410346985, "w_quantize_global": 0.09280070662498474, "w_quantize_global_transpose": 0.10214745998382568, "cast_x": 0.052288174629211426, "cast_g": 0.19747763872146606, "cast_w": 0.027339905500411987, "time_standard": 1.3833306729793549, "time_rowwise": 1.2838244438171387, "time_global": 1.2955255806446075}
{"repeat": 64, "batch_size": 8192, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.39635971188545227, "standard_gw": 0.44353678822517395, "standard_gx": 0.4724152386188507, "rowwise_fwd": 0.22813305258750916, "rowwise_bwd": 0.2868436276912689, "global_fwd": 0.2119205892086029, "global_bwd": 0.2749413251876831, "x_quantize_rowwise": 0.12082979083061218, "g_quantize_rowwise": 0.03444403409957886, "w_quantize_rowwise": 0.03444403409957886, "w_quantize_colwise_transpose": 0.14675036072731018, "w_quantize_global": 0.09495392441749573, "w_quantize_global_transpose": 0.1009330153465271, "cast_x": 0.19745156168937683, "cast_g": 0.05227327346801758, "cast_w": 0.027336180210113525, "time_standard": 1.312311738729477, "time_rowwise": 1.294981688261032, "time_global": 1.2815594673156738}
{"repeat": 64, "batch_size": 16384, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 1.0207034647464752, "standard_gw": 0.897720456123352, "standard_gx": 0.8374936878681183, "rowwise_fwd": 0.5457103252410889, "rowwise_bwd": 0.4088357090950012, "global_fwd": 0.5308091640472412, "global_bwd": 0.40555745363235474, "x_quantize_rowwise": 0.05984678864479065, "g_quantize_rowwise": 0.2306811511516571, "w_quantize_rowwise": 0.0334717333316803, "w_quantize_colwise_transpose": 0.14356523752212524, "w_quantize_global": 0.09340420365333557, "w_quantize_global_transpose": 0.09996071457862854, "cast_x": 0.10207295417785645, "cast_g": 0.3880411386489868, "cast_w": 0.027671456336975098, "time_standard": 2.7559176087379456, "time_rowwise": 2.3198314011096954, "time_global": 2.31797993183136}
{"repeat": 64, "batch_size": 65536, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 4.502948373556137, "standard_gw": 4.418112337589264, "standard_gx": 4.748217761516571, "rowwise_fwd": 2.1329298615455627, "rowwise_bwd": 2.6968345046043396, "global_fwd": 2.102244645357132, "global_bwd": 2.6461556553840637, "x_quantize_rowwise": 0.9493157267570496, "g_quantize_rowwise": 0.2290569245815277, "w_quantize_rowwise": 0.02551451325416565, "w_quantize_colwise_transpose": 0.18491223454475403, "w_quantize_global": 0.07426366209983826, "w_quantize_global_transpose": 0.08058920502662659, "cast_x": 1.8352717161178589, "cast_g": 0.425681471824646, "cast_w": 0.039402395486831665, "time_standard": 13.669278472661972, "time_rowwise": 10.636676102876663, "time_global": 10.499738156795502}
{"repeat": 64, "batch_size": 16384, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.8179470896720886, "standard_gw": 0.8687414228916168, "standard_gx": 0.9276494383811951, "rowwise_fwd": 0.4481859505176544, "rowwise_bwd": 0.5557462573051453, "global_fwd": 0.4100687801837921, "global_bwd": 0.5317367613315582, "x_quantize_rowwise": 0.2301819622516632, "g_quantize_rowwise": 0.05963817238807678, "w_quantize_rowwise": 0.033523887395858765, "w_quantize_colwise_transpose": 0.14462321996688843, "w_quantize_global": 0.094633549451828, "w_quantize_global_transpose": 0.10088086128234863, "cast_x": 0.3879927098751068, "cast_g": 0.10205060243606567, "cast_w": 0.02714991569519043, "time_standard": 2.6143379509449005, "time_rowwise": 2.3406408727169037, "time_global": 2.295881509780884}
{"repeat": 64, "batch_size": 32768, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 2.0698904991149902, "standard_gw": 1.7200261354446411, "standard_gx": 1.663345843553543, "rowwise_fwd": 1.0664835572242737, "rowwise_bwd": 0.8059032261371613, "global_fwd": 1.0454729199409485, "global_bwd": 0.801432877779007, "x_quantize_rowwise": 0.1127384603023529, "g_quantize_rowwise": 0.4529319703578949, "w_quantize_rowwise": 0.03398582339286804, "w_quantize_colwise_transpose": 0.14343857765197754, "w_quantize_global": 0.09441003203392029, "w_quantize_global_transpose": 0.09993091225624084, "cast_x": 0.19744038581848145, "cast_g": 0.769149512052536, "cast_w": 0.02734735608100891, "time_standard": 5.453262478113174, "time_rowwise": 4.335507750511169, "time_global": 4.3269433081150055}
{"repeat": 64, "batch_size": 32768, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 2.758193761110306, "standard_gw": 1.6880109906196594, "standard_gx": 1.8163062632083893, "rowwise_fwd": 0.8343160152435303, "rowwise_bwd": 1.073598861694336, "global_fwd": 0.8045099675655365, "global_bwd": 1.0492689907550812, "x_quantize_rowwise": 0.453021377325058, "g_quantize_rowwise": 0.11304020881652832, "w_quantize_rowwise": 0.0337064266204834, "w_quantize_colwise_transpose": 0.1452416181564331, "w_quantize_global": 0.09451434016227722, "w_quantize_global_transpose": 0.0998079776763916, "cast_x": 0.769101083278656, "cast_g": 0.19731372594833374, "cast_w": 0.027332454919815063, "time_standard": 6.2625110149383545, "time_rowwise": 4.340935498476028, "time_global": 4.302173852920532}
{"repeat": 64, "batch_size": 131072, "dim_out": 6144, "dim_in": 1408, "wm": 4.3637, "switch": false, "standard_fwd": 10.728541761636734, "standard_gw": 9.228862822055817, "standard_gx": 8.837487548589706, "rowwise_fwd": 5.4414160549640656, "rowwise_bwd": 4.186157137155533, "global_fwd": 5.329187959432602, "global_bwd": 4.150416702032089, "x_quantize_rowwise": 0.4517659544944763, "g_quantize_rowwise": 1.890372484922409, "w_quantize_rowwise": 0.027563422918319702, "w_quantize_colwise_transpose": 0.1980513334274292, "w_quantize_global": 0.0733695924282074, "w_quantize_global_transpose": 0.08009746670722961, "cast_x": 0.8449330925941467, "cast_g": 3.6641769111156464, "cast_w": 0.03945454955101013, "time_standard": 28.794892132282257, "time_rowwise": 21.42418920993805, "time_global": 21.20407298207283}
{"repeat": 64, "batch_size": 65536, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 4.127204418182373, "standard_gw": 3.359321504831314, "standard_gx": 5.557261407375336, "rowwise_fwd": 2.1365806460380554, "rowwise_bwd": 1.6042962670326233, "global_fwd": 2.0923763513565063, "global_bwd": 1.5939176082611084, "x_quantize_rowwise": 0.21954253315925598, "g_quantize_rowwise": 0.8971206843852997, "w_quantize_rowwise": 0.03357976675033569, "w_quantize_colwise_transpose": 0.1431293785572052, "w_quantize_global": 0.10574981570243835, "w_quantize_global_transpose": 0.10281801223754883, "cast_x": 0.38795173168182373, "cast_g": 1.5318207442760468, "cast_w": 0.027142465114593506, "time_standard": 13.043787330389023, "time_rowwise": 8.39357078075409, "time_global": 8.370846509933472}
{"repeat": 64, "batch_size": 65536, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 5.576469004154205, "standard_gw": 3.361724317073822, "standard_gx": 3.6300085484981537, "rowwise_fwd": 1.6183294355869293, "rowwise_bwd": 2.1462254226207733, "global_fwd": 1.5953555703163147, "global_bwd": 2.0915642380714417, "x_quantize_rowwise": 0.8973218500614166, "g_quantize_rowwise": 0.2197064459323883, "w_quantize_rowwise": 0.03402307629585266, "w_quantize_colwise_transpose": 0.14822185039520264, "w_quantize_global": 0.09706616401672363, "w_quantize_global_transpose": 0.10339170694351196, "cast_x": 1.5312805771827698, "cast_g": 0.3879964351654053, "cast_w": 0.0269375741481781, "time_standard": 12.568201869726181, "time_rowwise": 8.425552397966385, "time_global": 8.366130292415619}
{"repeat": 64, "batch_size": 131072, "dim_out": 1408, "dim_in": 6144, "wm": 4.3637, "switch": true, "standard_fwd": 8.900497108697891, "standard_gw": 9.188394993543625, "standard_gx": 9.503517299890518, "rowwise_fwd": 4.189815372228622, "rowwise_bwd": 5.426768213510513, "global_fwd": 4.155576229095459, "global_bwd": 5.329132080078125, "x_quantize_rowwise": 1.8885880708694458, "g_quantize_rowwise": 0.45193731784820557, "w_quantize_rowwise": 0.025987625122070312, "w_quantize_colwise_transpose": 0.1842118799686432, "w_quantize_global": 0.07349997758865356, "w_quantize_global_transpose": 0.08074194192886353, "cast_x": 3.6639943718910217, "cast_g": 0.8447282016277313, "cast_w": 0.03973767161369324, "time_standard": 27.592409402132034, "time_rowwise": 21.355703473091125, "time_global": 21.167870610952377}
{"repeat": 64, "batch_size": 131072, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 8.2329623401165, "standard_gw": 6.799045950174332, "standard_gx": 6.893906742334366, "rowwise_fwd": 4.252739250659943, "rowwise_bwd": 3.2025352120399475, "global_fwd": 4.176046699285507, "global_bwd": 3.173377364873886, "x_quantize_rowwise": 0.43221935629844666, "g_quantize_rowwise": 1.7872042953968048, "w_quantize_rowwise": 0.03328174352645874, "w_quantize_colwise_transpose": 0.1431480050086975, "w_quantize_global": 0.09707733988761902, "w_quantize_global_transpose": 0.10161846876144409, "cast_x": 0.7692091166973114, "cast_g": 3.057178109884262, "cast_w": 0.027302652597427368, "time_standard": 21.9259150326252, "time_rowwise": 16.65017381310463, "time_global": 16.56658947467804}
{"repeat": 64, "batch_size": 131072, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 11.278409510850906, "standard_gw": 6.815284490585327, "standard_gx": 7.280956953763962, "rowwise_fwd": 3.206692636013031, "rowwise_bwd": 4.246953874826431, "global_fwd": 3.1801797449588776, "global_bwd": 4.169579595327377, "x_quantize_rowwise": 1.7862766981124878, "g_quantize_rowwise": 0.4329495131969452, "w_quantize_rowwise": 0.03413483500480652, "w_quantize_colwise_transpose": 0.14493241906166077, "w_quantize_global": 0.09881332516670227, "w_quantize_global_transpose": 0.10376423597335815, "cast_x": 3.057088702917099, "cast_g": 0.7693544030189514, "cast_w": 0.027261674404144287, "time_standard": 25.374650955200195, "time_rowwise": 16.66722446680069, "time_global": 16.586847603321075}
{"repeat": 64, "batch_size": 1024, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 0.11636316776275635, "standard_gw": 0.11816620826721191, "standard_gx": 0.11482089757919312, "rowwise_fwd": 0.08482113480567932, "rowwise_bwd": 0.06284937262535095, "global_fwd": 0.08296221494674683, "global_bwd": 0.061664730310440063, "x_quantize_rowwise": 0.026706606149673462, "g_quantize_rowwise": 0.025641173124313354, "w_quantize_rowwise": 0.03740563988685608, "w_quantize_colwise_transpose": 0.2965778112411499, "w_quantize_global": 0.11304393410682678, "w_quantize_global_transpose": 0.12390688061714172, "cast_x": 0.008635222911834717, "cast_g": 0.037532299757003784, "cast_w": 0.06856024265289307, "time_standard": 0.3493502736091614, "time_rowwise": 0.652167946100235, "time_global": 0.5520917475223541}
{"repeat": 64, "batch_size": 1024, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 0.11609122157096863, "standard_gw": 0.11704489588737488, "standard_gx": 0.11566653847694397, "rowwise_fwd": 0.06706640124320984, "rowwise_bwd": 0.09074807167053223, "global_fwd": 0.06621330976486206, "global_bwd": 0.0859871506690979, "x_quantize_rowwise": 0.027574598789215088, "g_quantize_rowwise": 0.02520531415939331, "w_quantize_rowwise": 0.04095584154129028, "w_quantize_colwise_transpose": 0.37036463618278503, "w_quantize_global": 0.11350959539413452, "w_quantize_global_transpose": 0.12202560901641846, "cast_x": 0.03780052065849304, "cast_g": 0.00860169529914856, "cast_w": 0.06864592432975769, "time_standard": 0.3488026559352875, "time_rowwise": 0.7389597594738007, "time_global": 0.5575604736804962}
{"repeat": 64, "batch_size": 2048, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 0.22610649466514587, "standard_gw": 0.2229548990726471, "standard_gx": 0.22150203585624695, "rowwise_fwd": 0.1421608030796051, "rowwise_bwd": 0.10771304368972778, "global_fwd": 0.13930723071098328, "global_bwd": 0.10715052485466003, "x_quantize_rowwise": 0.02812594175338745, "g_quantize_rowwise": 0.04733726382255554, "w_quantize_rowwise": 0.03758445382118225, "w_quantize_colwise_transpose": 0.29515475034713745, "w_quantize_global": 0.11344626545906067, "w_quantize_global_transpose": 0.12392178177833557, "cast_x": 0.013589859008789062, "cast_g": 0.08285418152809143, "cast_w": 0.06850436329841614, "time_standard": 0.6705634295940399, "time_rowwise": 0.8810311555862427, "time_global": 0.7822439074516296}
{"repeat": 64, "batch_size": 2048, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 0.20173192024230957, "standard_gw": 0.2351999282836914, "standard_gx": 0.24710968136787415, "rowwise_fwd": 0.12035667896270752, "rowwise_bwd": 0.153418630361557, "global_fwd": 0.11473894119262695, "global_bwd": 0.14553219079971313, "x_quantize_rowwise": 0.04762038588523865, "g_quantize_rowwise": 0.02557411789894104, "w_quantize_rowwise": 0.04055723547935486, "w_quantize_colwise_transpose": 0.32641738653182983, "w_quantize_global": 0.1138448715209961, "w_quantize_global_transpose": 0.12255832552909851, "cast_x": 0.08405372500419617, "cast_g": 0.013835728168487549, "cast_w": 0.06961449980735779, "time_standard": 0.6840415298938751, "time_rowwise": 0.9491443634033203, "time_global": 0.8050687611103058}
{"repeat": 64, "batch_size": 4096, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 0.48126280307769775, "standard_gw": 0.46824291348457336, "standard_gx": 0.45252591371536255, "rowwise_fwd": 0.2749897539615631, "rowwise_bwd": 0.2111680805683136, "global_fwd": 0.2689175307750702, "global_bwd": 0.2104043960571289, "x_quantize_rowwise": 0.02676248550415039, "g_quantize_rowwise": 0.0842660665512085, "w_quantize_rowwise": 0.037495046854019165, "w_quantize_colwise_transpose": 0.2952851355075836, "w_quantize_global": 0.11366978287696838, "w_quantize_global_transpose": 0.12461841106414795, "cast_x": 0.0283755362033844, "cast_g": 0.1590624451637268, "cast_w": 0.06854161620140076, "time_standard": 1.4020316302776337, "time_rowwise": 1.3982094824314117, "time_global": 1.2968815863132477}
{"repeat": 64, "batch_size": 4096, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 0.4076175391674042, "standard_gw": 0.45526400208473206, "standard_gx": 0.4996545612812042, "rowwise_fwd": 0.238761305809021, "rowwise_bwd": 0.2913624048233032, "global_fwd": 0.2149641513824463, "global_bwd": 0.2717897295951843, "x_quantize_rowwise": 0.0845976173877716, "g_quantize_rowwise": 0.0266246497631073, "w_quantize_rowwise": 0.04038959741592407, "w_quantize_colwise_transpose": 0.33299997448921204, "w_quantize_global": 0.11374801397323608, "w_quantize_global_transpose": 0.12202560901641846, "cast_x": 0.15895813703536987, "cast_g": 0.028312206268310547, "cast_w": 0.06841868162155151, "time_standard": 1.3625361025333405, "time_rowwise": 1.4699995517730713, "time_global": 1.2890137732028961}
{"repeat": 64, "batch_size": 8192, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 1.02214515209198, "standard_gw": 0.9412020444869995, "standard_gx": 0.883936882019043, "rowwise_fwd": 0.5209781229496002, "rowwise_bwd": 0.41617080569267273, "global_fwd": 0.5089044570922852, "global_bwd": 0.4142932593822479, "x_quantize_rowwise": 0.03763660788536072, "g_quantize_rowwise": 0.15798211097717285, "w_quantize_rowwise": 0.0375211238861084, "w_quantize_colwise_transpose": 0.2973228693008423, "w_quantize_global": 0.11317431926727295, "w_quantize_global_transpose": 0.12396648526191711, "cast_x": 0.0685863196849823, "cast_g": 0.311531126499176, "cast_w": 0.0685080885887146, "time_standard": 2.8472840785980225, "time_rowwise": 2.4088136851787567, "time_global": 2.2971592843532562}
{"repeat": 64, "batch_size": 8192, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 0.8539073169231415, "standard_gw": 0.9352751076221466, "standard_gx": 0.9567439556121826, "rowwise_fwd": 0.4599541425704956, "rowwise_bwd": 0.531073659658432, "global_fwd": 0.42063742876052856, "global_bwd": 0.5125999450683594, "x_quantize_rowwise": 0.1581348478794098, "g_quantize_rowwise": 0.03755837678909302, "w_quantize_rowwise": 0.04056468605995178, "w_quantize_colwise_transpose": 0.3295913338661194, "w_quantize_global": 0.11314079165458679, "w_quantize_global_transpose": 0.12153387069702148, "cast_x": 0.3114752471446991, "cast_g": 0.06850063800811768, "cast_w": 0.06839632987976074, "time_standard": 2.7459263801574707, "time_rowwise": 2.492152154445648, "time_global": 2.2988803684711456}
{"repeat": 64, "batch_size": 16384, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 2.0550191402435303, "standard_gw": 1.7850138247013092, "standard_gx": 1.7571337521076202, "rowwise_fwd": 1.026798039674759, "rowwise_bwd": 0.8242167532444, "global_fwd": 1.0042376816272736, "global_bwd": 0.8189938962459564, "x_quantize_rowwise": 0.0688992440700531, "g_quantize_rowwise": 0.3054179251194, "w_quantize_rowwise": 0.03757700324058533, "w_quantize_colwise_transpose": 0.2973712980747223, "w_quantize_global": 0.11324509978294373, "w_quantize_global_transpose": 0.12398511171340942, "cast_x": 0.13050436973571777, "cast_g": 0.6165280938148499, "cast_w": 0.06848573684692383, "time_standard": 5.59716671705246, "time_rowwise": 4.345294088125229, "time_global": 4.2197927832603455}
{"repeat": 64, "batch_size": 16384, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 1.79310142993927, "standard_gw": 1.7801076173782349, "standard_gx": 1.9140169024467468, "rowwise_fwd": 0.8629709482192993, "rowwise_bwd": 1.0353922843933105, "global_fwd": 0.8200556039810181, "global_bwd": 1.002725213766098, "x_quantize_rowwise": 0.30517578125, "g_quantize_rowwise": 0.06880238652229309, "w_quantize_rowwise": 0.040318816900253296, "w_quantize_colwise_transpose": 0.3413744270801544, "w_quantize_global": 0.11326000094413757, "w_quantize_global_transpose": 0.12197345495223999, "cast_x": 0.6162337958812714, "cast_g": 0.13053417205810547, "cast_w": 0.06848946213722229, "time_standard": 5.487225949764252, "time_rowwise": 4.4341422617435455, "time_global": 4.212100058794022}
{"repeat": 64, "batch_size": 32768, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 4.0736086666584015, "standard_gw": 3.595758229494095, "standard_gx": 3.7020929157733917, "rowwise_fwd": 2.0306408405303955, "rowwise_bwd": 1.635722815990448, "global_fwd": 1.9890740513801575, "global_bwd": 1.627359539270401, "x_quantize_rowwise": 0.13131648302078247, "g_quantize_rowwise": 0.6001107394695282, "w_quantize_rowwise": 0.03781542181968689, "w_quantize_colwise_transpose": 0.2975836396217346, "w_quantize_global": 0.11357292532920837, "w_quantize_global_transpose": 0.12416765093803406, "cast_x": 0.2544410526752472, "cast_g": 1.2265890836715698, "cast_w": 0.06866827607154846, "time_standard": 11.371459811925888, "time_rowwise": 8.32894816994667, "time_global": 8.181359618902206}
{"repeat": 64, "batch_size": 32768, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 3.525231033563614, "standard_gw": 3.489706665277481, "standard_gx": 3.9937011897563934, "rowwise_fwd": 1.6627348959445953, "rowwise_bwd": 2.0311400294303894, "global_fwd": 1.6270726919174194, "global_bwd": 1.988884061574936, "x_quantize_rowwise": 0.5999915301799774, "g_quantize_rowwise": 0.1310594379901886, "w_quantize_rowwise": 0.04043802618980408, "w_quantize_colwise_transpose": 0.32950565218925476, "w_quantize_global": 0.11298432946205139, "w_quantize_global_transpose": 0.12201443314552307, "cast_x": 1.2257546186447144, "cast_g": 0.25444477796554565, "cast_w": 0.06848573684692383, "time_standard": 11.008638888597488, "time_rowwise": 8.28457623720169, "time_global": 8.071713149547577}
{"repeat": 64, "batch_size": 65536, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 8.123598992824554, "standard_gw": 8.085217326879501, "standard_gx": 7.293816655874252, "rowwise_fwd": 4.07782569527626, "rowwise_bwd": 3.196723759174347, "global_fwd": 4.001103341579437, "global_bwd": 3.1843744218349457, "x_quantize_rowwise": 0.2560615539550781, "g_quantize_rowwise": 1.1893659830093384, "w_quantize_rowwise": 0.037297606468200684, "w_quantize_colwise_transpose": 0.29668211936950684, "w_quantize_global": 0.11358782649040222, "w_quantize_global_transpose": 0.12476742267608643, "cast_x": 0.5020052194595337, "cast_g": 2.4454034864902496, "cast_w": 0.0684782862663269, "time_standard": 23.502632975578308, "time_rowwise": 17.139174044132233, "time_global": 16.95447787642479}
{"repeat": 64, "batch_size": 65536, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 6.932958960533142, "standard_gw": 7.0609524846076965, "standard_gx": 7.460080087184906, "rowwise_fwd": 3.1809918582439423, "rowwise_bwd": 4.078391939401627, "global_fwd": 3.185112029314041, "global_bwd": 3.99089977145195, "x_quantize_rowwise": 1.1891834437847137, "g_quantize_rowwise": 0.25588274002075195, "w_quantize_rowwise": 0.0406019389629364, "w_quantize_colwise_transpose": 0.3389529883861542, "w_quantize_global": 0.11313334107398987, "w_quantize_global_transpose": 0.12241676449775696, "cast_x": 2.4446770548820496, "cast_g": 0.5022138357162476, "cast_w": 0.06857141852378845, "time_standard": 21.453991532325745, "time_rowwise": 16.14495739340782, "time_global": 15.9175805747509}
{"repeat": 64, "batch_size": 131072, "dim_out": 8192, "dim_in": 1664, "wm": 4.9231, "switch": false, "standard_fwd": 16.38999581336975, "standard_gw": 15.075922012329102, "standard_gx": 14.479495584964752, "rowwise_fwd": 8.128684014081955, "rowwise_bwd": 6.41091912984848, "global_fwd": 7.977847009897232, "global_bwd": 6.362702697515488, "x_quantize_rowwise": 0.5057230591773987, "g_quantize_rowwise": 2.3681968450546265, "w_quantize_rowwise": 0.037435442209243774, "w_quantize_colwise_transpose": 0.29555708169937134, "w_quantize_global": 0.11360272765159607, "w_quantize_global_transpose": 0.12426823377609253, "cast_x": 0.997692346572876, "cast_g": 4.8848651349544525, "cast_w": 0.0685565173625946, "time_standard": 45.945413410663605, "time_rowwise": 32.82243758440018, "time_global": 32.528262585401535}
{"repeat": 64, "batch_size": 131072, "dim_out": 1664, "dim_in": 8192, "wm": 4.9231, "switch": true, "standard_fwd": 14.838922768831253, "standard_gw": 15.112213790416718, "standard_gx": 14.869242906570435, "rowwise_fwd": 6.402213126420975, "rowwise_bwd": 8.132629096508026, "global_fwd": 6.36359304189682, "global_bwd": 7.9823993146419525, "x_quantize_rowwise": 2.367999404668808, "g_quantize_rowwise": 0.5056969821453094, "w_quantize_rowwise": 0.04053488373756409, "w_quantize_colwise_transpose": 0.3559887409210205, "w_quantize_global": 0.1136288046836853, "w_quantize_global_transpose": 0.125102698802948, "cast_x": 4.880473017692566, "cast_g": 0.9965412318706512, "cast_w": 0.06855279207229614, "time_standard": 44.820379465818405, "time_rowwise": 32.91727602481842, "time_global": 32.57063403725624}
{"repeat": 64, "batch_size": 1024, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 0.15426427125930786, "standard_gw": 0.14531239867210388, "standard_gx": 0.1703128218650818, "rowwise_fwd": 0.09618699550628662, "rowwise_bwd": 0.10633841156959534, "global_fwd": 0.09483471512794495, "global_bwd": 0.10636076331138611, "x_quantize_rowwise": 0.02434849739074707, "g_quantize_rowwise": 0.026009976863861084, "w_quantize_rowwise": 0.04366040229797363, "w_quantize_colwise_transpose": 0.34148991107940674, "w_quantize_global": 0.13587623834609985, "w_quantize_global_transpose": 0.14698877930641174, "cast_x": 0.009745359420776367, "cast_g": 0.03773719072341919, "cast_w": 0.08277222514152527, "time_standard": 0.46988949179649353, "time_rowwise": 0.7833465933799744, "time_global": 0.6797313690185547}
{"repeat": 64, "batch_size": 1024, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 0.16738846898078918, "standard_gw": 0.14199689030647278, "standard_gx": 0.15476346015930176, "rowwise_fwd": 0.11660531163215637, "rowwise_bwd": 0.1050308346748352, "global_fwd": 0.11050701141357422, "global_bwd": 0.09868666529655457, "x_quantize_rowwise": 0.02781301736831665, "g_quantize_rowwise": 0.024966895580291748, "w_quantize_rowwise": 0.047437846660614014, "w_quantize_colwise_transpose": 0.5995631217956543, "w_quantize_global": 0.1362822949886322, "w_quantize_global_transpose": 0.14807283878326416, "cast_x": 0.0377558171749115, "cast_g": 0.00973045825958252, "cast_w": 0.0828281044960022, "time_standard": 0.4641488194465637, "time_rowwise": 1.063413918018341, "time_global": 0.6883256137371063}
{"repeat": 64, "batch_size": 2048, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 0.2727396786212921, "standard_gw": 0.2711080014705658, "standard_gx": 0.3120154142379761, "rowwise_fwd": 0.16424059867858887, "rowwise_bwd": 0.17686933279037476, "global_fwd": 0.161685049533844, "global_bwd": 0.17517060041427612, "x_quantize_rowwise": 0.025484710931777954, "g_quantize_rowwise": 0.047635287046432495, "w_quantize_rowwise": 0.04380941390991211, "w_quantize_colwise_transpose": 0.3401711583137512, "w_quantize_global": 0.13605505228042603, "w_quantize_global_transpose": 0.14705583453178406, "cast_x": 0.01584365963935852, "cast_g": 0.08274242281913757, "cast_w": 0.08281320333480835, "time_standard": 0.855863094329834, "time_rowwise": 1.0693185031414032, "time_global": 0.9641945362091064}
{"repeat": 64, "batch_size": 2048, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 0.28916075825691223, "standard_gw": 0.29472261667251587, "standard_gx": 0.30096620321273804, "rowwise_fwd": 0.19618868827819824, "rowwise_bwd": 0.17556175589561462, "global_fwd": 0.18328800797462463, "global_bwd": 0.16647577285766602, "x_quantize_rowwise": 0.047441571950912476, "g_quantize_rowwise": 0.026609748601913452, "w_quantize_rowwise": 0.04766508936882019, "w_quantize_colwise_transpose": 0.6060972809791565, "w_quantize_global": 0.1363418996334076, "w_quantize_global_transpose": 0.14806538820266724, "cast_x": 0.08295103907585144, "cast_g": 0.015836209058761597, "cast_w": 0.08285045623779297, "time_standard": 0.8848495781421661, "time_rowwise": 1.3942867517471313, "time_global": 1.0029450058937073}
{"repeat": 64, "batch_size": 4096, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 0.6430819630622864, "standard_gw": 0.5622953176498413, "standard_gx": 0.5780421197414398, "rowwise_fwd": 0.318676233291626, "rowwise_bwd": 0.29438361525535583, "global_fwd": 0.31290948390960693, "global_bwd": 0.290747731924057, "x_quantize_rowwise": 0.027455389499664307, "g_quantize_rowwise": 0.08405372500419617, "w_quantize_rowwise": 0.04369765520095825, "w_quantize_colwise_transpose": 0.34110620617866516, "w_quantize_global": 0.1360774040222168, "w_quantize_global_transpose": 0.14697015285491943, "cast_x": 0.037614256143569946, "cast_g": 0.15922263264656067, "cast_w": 0.08288025856018066, "time_standard": 1.7834194004535675, "time_rowwise": 1.671668142080307, "time_global": 1.560509204864502}
{"repeat": 64, "batch_size": 4096, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 0.551275908946991, "standard_gw": 0.591665506362915, "standard_gx": 0.6067268550395966, "rowwise_fwd": 0.33493712544441223, "rowwise_bwd": 0.32918527722358704, "global_fwd": 0.29528141021728516, "global_bwd": 0.31659379601478577, "x_quantize_rowwise": 0.08441135287284851, "g_quantize_rowwise": 0.025656074285507202, "w_quantize_rowwise": 0.04745647311210632, "w_quantize_colwise_transpose": 0.5993843078613281, "w_quantize_global": 0.1359879970550537, "w_quantize_global_transpose": 0.14815106987953186, "cast_x": 0.15932321548461914, "cast_g": 0.037439167499542236, "cast_w": 0.08288398385047913, "time_standard": 1.7496682703495026, "time_rowwise": 2.0126961171627045, "time_global": 1.5977472066879272}
{"repeat": 64, "batch_size": 8192, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 1.2295916676521301, "standard_gw": 1.116037368774414, "standard_gx": 1.1164769530296326, "rowwise_fwd": 0.603698194026947, "rowwise_bwd": 0.5168020725250244, "global_fwd": 0.5922466516494751, "global_bwd": 0.5151033401489258, "x_quantize_rowwise": 0.0437907874584198, "g_quantize_rowwise": 0.157918781042099, "w_quantize_rowwise": 0.044032931327819824, "w_quantize_colwise_transpose": 0.34073740243911743, "w_quantize_global": 0.13559311628341675, "w_quantize_global_transpose": 0.14679506421089172, "cast_x": 0.08263811469078064, "cast_g": 0.3115162253379822, "cast_w": 0.08287280797958374, "time_standard": 3.4621059894561768, "time_rowwise": 2.8230175375938416, "time_global": 2.707485109567642}
{"repeat": 64, "batch_size": 8192, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 1.090865582227707, "standard_gw": 1.1468492448329926, "standard_gx": 1.1166594922542572, "rowwise_fwd": 0.5559474229812622, "rowwise_bwd": 0.6105974316596985, "global_fwd": 0.5200020968914032, "global_bwd": 0.592011958360672, "x_quantize_rowwise": 0.15802308917045593, "g_quantize_rowwise": 0.04357844591140747, "w_quantize_rowwise": 0.04709511995315552, "w_quantize_colwise_transpose": 0.5969703197479248, "w_quantize_global": 0.13620033860206604, "w_quantize_global_transpose": 0.148136168718338, "cast_x": 0.31115859746932983, "cast_g": 0.08263811469078064, "cast_w": 0.08268281817436218, "time_standard": 3.3543743193149567, "time_rowwise": 3.159061074256897, "time_global": 2.744801342487335}
{"repeat": 64, "batch_size": 16384, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 2.4665743112564087, "standard_gw": 2.1993443369865417, "standard_gx": 2.1993033587932587, "rowwise_fwd": 1.192428171634674, "rowwise_bwd": 1.023314893245697, "global_fwd": 1.1711902916431427, "global_bwd": 1.0202191770076752, "x_quantize_rowwise": 0.08077174425125122, "g_quantize_rowwise": 0.30520185828208923, "w_quantize_rowwise": 0.043783336877822876, "w_quantize_colwise_transpose": 0.339999794960022, "w_quantize_global": 0.13628602027893066, "w_quantize_global_transpose": 0.14696642756462097, "cast_x": 0.15902891755104065, "cast_g": 0.6164535880088806, "cast_w": 0.08285418152809143, "time_standard": 6.865222007036209, "time_rowwise": 5.184844136238098, "time_global": 5.059979856014252}
{"repeat": 64, "batch_size": 16384, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 2.1861791610717773, "standard_gw": 2.157818526029587, "standard_gx": 2.321537584066391, "rowwise_fwd": 1.0536126792430878, "rowwise_bwd": 1.1971630156040192, "global_fwd": 1.02127343416214, "global_bwd": 1.1707991361618042, "x_quantize_rowwise": 0.30522048473358154, "g_quantize_rowwise": 0.08065253496170044, "w_quantize_rowwise": 0.04741176962852478, "w_quantize_colwise_transpose": 0.5979575216770172, "w_quantize_global": 0.1362040638923645, "w_quantize_global_transpose": 0.14854222536087036, "cast_x": 0.6162486970424652, "cast_g": 0.1591891050338745, "cast_w": 0.08288398385047913, "time_standard": 6.665535271167755, "time_rowwise": 5.439836531877518, "time_global": 5.020510405302048}
{"repeat": 64, "batch_size": 32768, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 4.891645163297653, "standard_gw": 4.233300685882568, "standard_gx": 4.2071714997291565, "rowwise_fwd": 2.3616664111614227, "rowwise_bwd": 1.9419342279434204, "global_fwd": 2.3244209587574005, "global_bwd": 1.9598640501499176, "x_quantize_rowwise": 0.15483051538467407, "g_quantize_rowwise": 0.6008371710777283, "w_quantize_rowwise": 0.043839216232299805, "w_quantize_colwise_transpose": 0.3400743007659912, "w_quantize_global": 0.1362822949886322, "w_quantize_global_transpose": 0.14691054821014404, "cast_x": 0.31141936779022217, "cast_g": 1.2254081666469574, "cast_w": 0.08280202746391296, "time_standard": 13.332117348909378, "time_rowwise": 9.676482528448105, "time_global": 9.556446224451065}
{"repeat": 64, "batch_size": 32768, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 4.267625510692596, "standard_gw": 4.237007349729538, "standard_gx": 4.666488617658615, "rowwise_fwd": 1.9670464098453522, "rowwise_bwd": 2.362079918384552, "global_fwd": 1.9469596445560455, "global_bwd": 2.32585147023201, "x_quantize_rowwise": 0.6000921130180359, "g_quantize_rowwise": 0.15481188893318176, "w_quantize_rowwise": 0.04725530743598938, "w_quantize_colwise_transpose": 0.5976222455501556, "w_quantize_global": 0.13619661331176758, "w_quantize_global_transpose": 0.14815852046012878, "cast_x": 1.2261345982551575, "cast_g": 0.3117173910140991, "cast_w": 0.08279457688331604, "time_standard": 13.17112147808075, "time_rowwise": 9.965915232896805, "time_global": 9.549077600240707}
{"repeat": 64, "batch_size": 65536, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 9.787477552890778, "standard_gw": 8.533861488103867, "standard_gx": 8.979786187410355, "rowwise_fwd": 4.741787910461426, "rowwise_bwd": 3.871854394674301, "global_fwd": 4.674319177865982, "global_bwd": 3.9110779762268066, "x_quantize_rowwise": 0.3025829792022705, "g_quantize_rowwise": 1.1898204684257507, "w_quantize_rowwise": 0.043705105781555176, "w_quantize_colwise_transpose": 0.33997371792793274, "w_quantize_global": 0.13592839241027832, "w_quantize_global_transpose": 0.14724954962730408, "cast_x": 0.6160177290439606, "cast_g": 2.4440810084342957, "cast_w": 0.08280575275421143, "time_standard": 27.301125228405, "time_rowwise": 19.023586064577103, "time_global": 18.89484003186226}
{"repeat": 64, "batch_size": 65536, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 8.461769670248032, "standard_gw": 8.428700268268585, "standard_gx": 9.447630494832993, "rowwise_fwd": 3.881257027387619, "rowwise_bwd": 4.7471001744270325, "global_fwd": 3.9101652801036835, "global_bwd": 4.662122577428818, "x_quantize_rowwise": 1.1892355978488922, "g_quantize_rowwise": 0.3024376928806305, "w_quantize_rowwise": 0.04708021879196167, "w_quantize_colwise_transpose": 0.5982778966426849, "w_quantize_global": 0.13624131679534912, "w_quantize_global_transpose": 0.1484602689743042, "cast_x": 2.4463236331939697, "cast_g": 0.6163865327835083, "cast_w": 0.08278340101242065, "time_standard": 26.33810043334961, "time_rowwise": 19.194088876247406, "time_global": 18.777363002300262}
{"repeat": 64, "batch_size": 131072, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 19.699689000844955, "standard_gw": 16.89574122428894, "standard_gx": 17.907552421092987, "rowwise_fwd": 9.453803300857544, "rowwise_bwd": 7.8153833746910095, "global_fwd": 9.313825517892838, "global_bwd": 7.8215524554252625, "x_quantize_rowwise": 0.5986690521240234, "g_quantize_rowwise": 2.368006855249405, "w_quantize_rowwise": 0.043682754039764404, "w_quantize_colwise_transpose": 0.3406330943107605, "w_quantize_global": 0.13626739382743835, "w_quantize_global_transpose": 0.14715641736984253, "cast_x": 1.2262165546417236, "cast_g": 4.8834048211574554, "cast_w": 0.08272379636764526, "time_standard": 54.50298264622688, "time_rowwise": 37.51591965556145, "time_global": 37.28121891617775}
{"repeat": 64, "batch_size": 131072, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 18.66700127720833, "standard_gw": 18.56840029358864, "standard_gx": 18.049821257591248, "rowwise_fwd": 7.742393761873245, "rowwise_bwd": 9.479016065597534, "global_fwd": 7.806576788425446, "global_bwd": 9.328477084636688, "x_quantize_rowwise": 2.368297427892685, "g_quantize_rowwise": 0.5978643894195557, "w_quantize_rowwise": 0.047303736209869385, "w_quantize_colwise_transpose": 0.5982741713523865, "w_quantize_global": 0.13678893446922302, "w_quantize_global_transpose": 0.1488029956817627, "cast_x": 4.880513995885849, "cast_g": 1.2248307466506958, "cast_w": 0.08270144462585449, "time_standard": 55.285222828388214, "time_rowwise": 39.401549845933914, "time_global": 38.955207914114}
{"repeat": 64, "batch_size": 1024, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 0.529509037733078, "standard_gw": 0.5781911313533783, "standard_gx": 0.6095841526985168, "rowwise_fwd": 0.2811029553413391, "rowwise_bwd": 0.3345906734466553, "global_fwd": 0.27928128838539124, "global_bwd": 0.33126771450042725, "x_quantize_rowwise": 0.025760382413864136, "g_quantize_rowwise": 0.06494298577308655, "w_quantize_rowwise": 0.15570968389511108, "w_quantize_colwise_transpose": 1.6086548566818237, "w_quantize_global": 0.481434166431427, "w_quantize_global_transpose": 0.505443662405014, "cast_x": 0.01582130789756775, "cast_g": 0.08295103907585144, "cast_w": 0.311531126499176, "time_standard": 1.7172843217849731, "time_rowwise": 3.048952668905258, "time_global": 2.2663213312625885}
{"repeat": 64, "batch_size": 1024, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 0.5729459226131439, "standard_gw": 0.5789846181869507, "standard_gx": 0.5775243043899536, "rowwise_fwd": 0.36711618304252625, "rowwise_bwd": 0.2913735806941986, "global_fwd": 0.33703818917274475, "global_bwd": 0.2821236848831177, "x_quantize_rowwise": 0.064849853515625, "g_quantize_rowwise": 0.025060027837753296, "w_quantize_rowwise": 0.22537633776664734, "w_quantize_colwise_transpose": 3.6401040852069855, "w_quantize_global": 0.4818551242351532, "w_quantize_global_transpose": 0.5101114511489868, "cast_x": 0.08286535739898682, "cast_g": 0.015828758478164673, "cast_w": 0.3114677965641022, "time_standard": 1.7294548451900482, "time_rowwise": 5.192864686250687, "time_global": 2.2800229489803314}
{"repeat": 64, "batch_size": 2048, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 1.1735819280147552, "standard_gw": 1.121576875448227, "standard_gx": 1.1242404580116272, "rowwise_fwd": 0.5535706877708435, "rowwise_bwd": 0.5567893385887146, "global_fwd": 0.5486570298671722, "global_bwd": 0.551365315914154, "x_quantize_rowwise": 0.02710893750190735, "g_quantize_rowwise": 0.11784210801124573, "w_quantize_rowwise": 0.15565752983093262, "w_quantize_colwise_transpose": 1.607745885848999, "w_quantize_global": 0.4824437201023102, "w_quantize_global_transpose": 0.5060508847236633, "cast_x": 0.03808736801147461, "cast_g": 0.15912577509880066, "cast_w": 0.31150132417678833, "time_standard": 3.4193992614746094, "time_rowwise": 4.14029136300087, "time_global": 3.35504487156868}
{"repeat": 64, "batch_size": 2048, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 1.1169910430908203, "standard_gw": 1.1065900325775146, "standard_gx": 1.1815577745437622, "rowwise_fwd": 0.5917288362979889, "rowwise_bwd": 0.5614385008811951, "global_fwd": 0.5646944046020508, "global_bwd": 0.5500949919223785, "x_quantize_rowwise": 0.118207186460495, "g_quantize_rowwise": 0.025041401386260986, "w_quantize_rowwise": 0.22566691040992737, "w_quantize_colwise_transpose": 3.635551780462265, "w_quantize_global": 0.4815608263015747, "w_quantize_global_transpose": 0.509701669216156, "cast_x": 0.15912950038909912, "cast_g": 0.03797560930252075, "cast_w": 0.3114044666290283, "time_standard": 3.405138850212097, "time_rowwise": 6.264224648475647, "time_global": 3.3558905124664307}
{"repeat": 64, "batch_size": 4096, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 2.3259930312633514, "standard_gw": 2.1472275257110596, "standard_gx": 2.213582396507263, "rowwise_fwd": 1.0509602725505829, "rowwise_bwd": 0.9888559579849243, "global_fwd": 1.0398179292678833, "global_bwd": 0.9887740015983582, "x_quantize_rowwise": 0.04647299647331238, "g_quantize_rowwise": 0.22570788860321045, "w_quantize_rowwise": 0.1554824411869049, "w_quantize_colwise_transpose": 1.610085368156433, "w_quantize_global": 0.48134103417396545, "w_quantize_global_transpose": 0.5054809153079987, "cast_x": 0.08297711610794067, "cast_g": 0.3115646541118622, "cast_w": 0.31159818172454834, "time_standard": 6.686802953481674, "time_rowwise": 6.224792450666428, "time_global": 5.434822291135788}
{"repeat": 64, "batch_size": 4096, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 2.19760462641716, "standard_gw": 2.2860951721668243, "standard_gx": 2.290956676006317, "rowwise_fwd": 1.0311491787433624, "rowwise_bwd": 1.0555200278759003, "global_fwd": 0.9858310222625732, "global_bwd": 1.0394863784313202, "x_quantize_rowwise": 0.22591277956962585, "g_quantize_rowwise": 0.046234577894210815, "w_quantize_rowwise": 0.22603943943977356, "w_quantize_colwise_transpose": 3.628809005022049, "w_quantize_global": 0.4819147288799286, "w_quantize_global_transpose": 0.5104243755340576, "cast_x": 0.3114528954029083, "cast_g": 0.08296966552734375, "cast_w": 0.3116317093372345, "time_standard": 6.7746564745903015, "time_rowwise": 8.499760180711746, "time_global": 5.575899034738541}
{"repeat": 64, "batch_size": 8192, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 4.633370786905289, "standard_gw": 4.397690296173096, "standard_gx": 4.286538809537888, "rowwise_fwd": 2.089906483888626, "rowwise_bwd": 1.9657425582408905, "global_fwd": 2.0679645240306854, "global_bwd": 1.9629858434200287, "x_quantize_rowwise": 0.08271634578704834, "g_quantize_rowwise": 0.43905526399612427, "w_quantize_rowwise": 0.1551508903503418, "w_quantize_colwise_transpose": 1.6106180846691132, "w_quantize_global": 0.48185884952545166, "w_quantize_global_transpose": 0.506274402141571, "cast_x": 0.15918537974357605, "cast_g": 0.6163418292999268, "cast_w": 0.311531126499176, "time_standard": 13.317599892616272, "time_rowwise": 10.74087992310524, "time_global": 9.938545525074005}
{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 4.424266517162323, "standard_gw": 4.391487687826157, "standard_gx": 4.61186096072197, "rowwise_fwd": 1.9874684512615204, "rowwise_bwd": 2.093140035867691, "global_fwd": 1.9647255539894104, "global_bwd": 2.06940621137619, "x_quantize_rowwise": 0.43999403715133667, "g_quantize_rowwise": 0.08271634578704834, "w_quantize_rowwise": 0.22581592202186584, "w_quantize_colwise_transpose": 3.631964325904846, "w_quantize_global": 0.4821456968784332, "w_quantize_global_transpose": 0.5102343857288361, "cast_x": 0.6164386868476868, "cast_g": 0.1591108739376068, "cast_w": 0.31154975295066833, "time_standard": 13.42761516571045, "time_rowwise": 12.852586805820465, "time_global": 9.940709918737411}
{"repeat": 64, "batch_size": 16384, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 9.229827672243118, "standard_gw": 8.319318294525146, "standard_gx": 8.652344346046448, "rowwise_fwd": 4.163607954978943, "rowwise_bwd": 3.778301179409027, "global_fwd": 4.121184349060059, "global_bwd": 3.7708766758441925, "x_quantize_rowwise": 0.1553669571876526, "g_quantize_rowwise": 0.8715838193893433, "w_quantize_rowwise": 0.15540048480033875, "w_quantize_colwise_transpose": 1.6092769801616669, "w_quantize_global": 0.4813969135284424, "w_quantize_global_transpose": 0.5070343613624573, "cast_x": 0.31150132417678833, "cast_g": 1.2259706854820251, "cast_w": 0.311482697725296, "time_standard": 26.201490312814713, "time_rowwise": 19.052855670452118, "time_global": 18.226761370897293}
{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 8.577890694141388, "standard_gw": 9.073298424482346, "standard_gx": 9.210295975208282, "rowwise_fwd": 3.7784352898597717, "rowwise_bwd": 4.165928810834885, "global_fwd": 3.7702471017837524, "global_bwd": 4.121150821447372, "x_quantize_rowwise": 0.868629664182663, "g_quantize_rowwise": 0.1554340124130249, "w_quantize_rowwise": 0.22614002227783203, "w_quantize_colwise_transpose": 3.6367811262607574, "w_quantize_global": 0.4828609526157379, "w_quantize_global_transpose": 0.510137528181076, "cast_x": 1.2258104979991913, "cast_g": 0.31299516558647156, "cast_w": 0.3114677965641022, "time_standard": 26.861485093832016, "time_rowwise": 21.90464735031128, "time_global": 18.981758505105972}
{"repeat": 64, "batch_size": 32768, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 18.52763444185257, "standard_gw": 17.835520207881927, "standard_gx": 17.375655472278595, "rowwise_fwd": 8.35346058011055, "rowwise_bwd": 7.584303617477417, "global_fwd": 8.300606161355972, "global_bwd": 7.550913840532303, "x_quantize_rowwise": 0.3016740083694458, "g_quantize_rowwise": 1.7321519553661346, "w_quantize_rowwise": 0.15538185834884644, "w_quantize_colwise_transpose": 1.6110800206661224, "w_quantize_global": 0.4815198481082916, "w_quantize_global_transpose": 0.5066357553005219, "cast_x": 0.6163753569126129, "cast_g": 2.4452805519104004, "cast_w": 0.31156837940216064, "time_standard": 53.73881012201309, "time_rowwise": 37.573572248220444, "time_global": 36.7090217769146}
{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 18.073823302984238, "standard_gw": 16.71283319592476, "standard_gx": 18.46104860305786, "rowwise_fwd": 7.542364299297333, "rowwise_bwd": 8.374195545911789, "global_fwd": 7.5644850730896, "global_bwd": 8.26016440987587, "x_quantize_rowwise": 1.7326027154922485, "g_quantize_rowwise": 0.30233338475227356, "w_quantize_rowwise": 0.2259574830532074, "w_quantize_colwise_transpose": 3.634512424468994, "w_quantize_global": 0.48204511404037476, "w_quantize_global_transpose": 0.5093887448310852, "cast_x": 2.445656806230545, "cast_g": 0.6163381040096283, "cast_w": 0.31144917011260986, "time_standard": 53.24770510196686, "time_rowwise": 38.524799048900604, "time_global": 35.56385263800621}
{"repeat": 64, "batch_size": 65536, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 36.123402416706085, "standard_gw": 32.68447890877724, "standard_gx": 34.13737937808037, "rowwise_fwd": 16.65867120027542, "rowwise_bwd": 15.004873275756836, "global_fwd": 16.536589711904526, "global_bwd": 14.949381351470947, "x_quantize_rowwise": 0.5952902138233185, "g_quantize_rowwise": 3.4581348299980164, "w_quantize_rowwise": 0.15559792518615723, "w_quantize_colwise_transpose": 1.6055963933467865, "w_quantize_global": 0.48203766345977783, "w_quantize_global_transpose": 0.5048215389251709, "cast_x": 1.2256354093551636, "cast_g": 4.875503480434418, "cast_w": 0.3110244870185852, "time_standard": 102.94526070356369, "time_rowwise": 70.16264274716377, "time_global": 69.210734218359}
{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 35.0223146378994, "standard_gw": 32.84081444144249, "standard_gx": 35.984884947538376, "rowwise_fwd": 15.018381178379059, "rowwise_bwd": 16.69919490814209, "global_fwd": 14.942582696676254, "global_bwd": 16.529250890016556, "x_quantize_rowwise": 3.442291170358658, "g_quantize_rowwise": 0.5951747298240662, "w_quantize_rowwise": 0.22576376795768738, "w_quantize_colwise_transpose": 3.621157258749008, "w_quantize_global": 0.48135966062545776, "w_quantize_global_transpose": 0.5095489323139191, "cast_x": 4.875205457210541, "cast_g": 1.2237727642059326, "cast_w": 0.3110431134700775, "time_standard": 103.84801402688026, "time_rowwise": 72.44277745485306, "time_global": 69.3410225212574}
{"repeat": 64, "batch_size": 131072, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 72.33698666095734, "standard_gw": 71.31465151906013, "standard_gx": 69.32922825217247, "rowwise_fwd": 33.37707370519638, "rowwise_bwd": 30.1642008125782, "global_fwd": 33.002063632011414, "global_bwd": 30.003495514392853, "x_quantize_rowwise": 1.1819563806056976, "g_quantize_rowwise": 6.896954029798508, "w_quantize_rowwise": 0.15557929873466492, "w_quantize_colwise_transpose": 1.6083605587482452, "w_quantize_global": 0.48125162720680237, "w_quantize_global_transpose": 0.5055665969848633, "cast_x": 2.442535012960434, "cast_g": 9.750165045261383, "cast_w": 0.31094998121261597, "time_standard": 212.98086643218994, "time_rowwise": 144.69877630472183, "time_global": 143.38593930006027}
{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 70.24158909916878, "standard_gw": 72.03734293580055, "standard_gx": 72.01339676976204, "rowwise_fwd": 30.072908848524094, "rowwise_bwd": 33.376410603523254, "global_fwd": 29.965493828058243, "global_bwd": 33.01112726330757, "x_quantize_rowwise": 6.894122809171677, "g_quantize_rowwise": 1.1817142367362976, "w_quantize_rowwise": 0.22567808628082275, "w_quantize_colwise_transpose": 3.616899251937866, "w_quantize_global": 0.4819147288799286, "w_quantize_global_transpose": 0.5107112228870392, "cast_x": 9.750377386808395, "cast_g": 2.4411343038082123, "cast_w": 0.31099095940589905, "time_standard": 214.29232880473137, "time_rowwise": 147.40507677197456, "time_global": 144.0824270248413}
{"repeat": 64, "batch_size": 65536, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 138.23134452104568, "standard_gw": 131.48364424705505, "standard_gx": 141.09868183732033, "rowwise_fwd": 65.38830325007439, "rowwise_bwd": 58.39048698544502, "global_fwd": 65.2194656431675, "global_bwd": 58.58004465699196, "x_quantize_rowwise": 1.1899955570697784, "g_quantize_rowwise": 6.623774766921997, "w_quantize_rowwise": 0.5935952067375183, "w_quantize_colwise_transpose": 24.08137544989586, "w_quantize_global": 1.740824431180954, "w_quantize_global_transpose": 1.8664970993995667, "cast_x": 2.413548529148102, "cast_g": 9.63655486702919, "cast_w": 1.1956281960010529, "time_standard": 410.81367060542107, "time_rowwise": 287.7511754631996, "time_global": 266.7042464017868}
{"repeat": 64, "batch_size": 65536, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 141.08363911509514, "standard_gw": 133.26667994260788, "standard_gx": 136.0350362956524, "rowwise_fwd": 58.49892646074295, "rowwise_bwd": 65.34496694803238, "global_fwd": 58.73573571443558, "global_bwd": 65.30505418777466, "x_quantize_rowwise": 6.648071110248566, "g_quantize_rowwise": 1.1903978884220123, "w_quantize_rowwise": 0.8329600095748901, "w_quantize_colwise_transpose": 15.297897160053253, "w_quantize_global": 1.7403066158294678, "w_quantize_global_transpose": 1.8791332840919495, "cast_x": 9.636614471673965, "cast_g": 2.4122819304466248, "cast_w": 1.1954344809055328, "time_standard": 410.3853553533554, "time_rowwise": 281.07989951968193, "time_global": 268.7653787434101}
{"repeat": 64, "batch_size": 1024, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 2.535879611968994, "standard_gw": 2.249978482723236, "standard_gx": 2.2262558341026306, "rowwise_fwd": 1.085665076971054, "rowwise_bwd": 1.069542020559311, "global_fwd": 1.0830685496330261, "global_bwd": 1.0597631335258484, "x_quantize_rowwise": 0.02650916576385498, "g_quantize_rowwise": 0.1200847327709198, "w_quantize_rowwise": 0.5937665700912476, "w_quantize_colwise_transpose": 23.926906287670135, "w_quantize_global": 1.7397291958332062, "w_quantize_global_transpose": 1.8652454018592834, "cast_x": 0.03688782453536987, "cast_g": 0.15725940465927124, "cast_w": 1.1969134211540222, "time_standard": 7.012113928794861, "time_rowwise": 29.07245233654976, "time_global": 8.144378662109375}
{"repeat": 64, "batch_size": 1024, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 2.245493233203888, "standard_gw": 2.2966675460338593, "standard_gx": 2.216015011072159, "rowwise_fwd": 1.1000856757164001, "rowwise_bwd": 1.0902360081672668, "global_fwd": 1.0597333312034607, "global_bwd": 1.0812543332576752, "x_quantize_rowwise": 0.11992454528808594, "g_quantize_rowwise": 0.026784837245941162, "w_quantize_rowwise": 0.8310377597808838, "w_quantize_colwise_transpose": 15.30550792813301, "w_quantize_global": 1.7401352524757385, "w_quantize_global_transpose": 1.8841177225112915, "cast_x": 0.1573599874973297, "cast_g": 0.03676116466522217, "cast_w": 1.195952296257019, "time_standard": 6.758175790309906, "time_rowwise": 20.770244300365448, "time_global": 8.208617568016052}
{"repeat": 64, "batch_size": 2048, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 4.197858273983002, "standard_gw": 4.288379102945328, "standard_gx": 4.155721515417099, "rowwise_fwd": 2.0567886531352997, "rowwise_bwd": 1.9073635339736938, "global_fwd": 2.0506344735622406, "global_bwd": 1.9086338579654694, "x_quantize_rowwise": 0.04758685827255249, "g_quantize_rowwise": 0.22284314036369324, "w_quantize_rowwise": 0.5935467779636383, "w_quantize_colwise_transpose": 23.935042321681976, "w_quantize_global": 1.7397813498973846, "w_quantize_global_transpose": 1.8662959337234497, "cast_x": 0.08194148540496826, "cast_g": 0.3077872097492218, "cast_w": 1.1968687176704407, "time_standard": 12.641958892345428, "time_rowwise": 33.05155038833618, "time_global": 12.124154716730118}
{"repeat": 64, "batch_size": 2048, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 4.126541316509247, "standard_gw": 4.309836775064468, "standard_gx": 4.117351025342941, "rowwise_fwd": 1.9266381859779358, "rowwise_bwd": 2.0577237010002136, "global_fwd": 1.908630132675171, "global_bwd": 2.0505934953689575, "x_quantize_rowwise": 0.22304058074951172, "g_quantize_rowwise": 0.04766136407852173, "w_quantize_rowwise": 0.8306317031383514, "w_quantize_colwise_transpose": 15.309855341911316, "w_quantize_global": 1.7415396869182587, "w_quantize_global_transpose": 1.8827766180038452, "cast_x": 0.30782073736190796, "cast_g": 0.08186325430870056, "cast_w": 1.1955127120018005, "time_standard": 12.553729116916656, "time_rowwise": 24.70538765192032, "time_global": 12.164078652858734}
{"repeat": 64, "batch_size": 4096, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 8.298952132463455, "standard_gw": 8.345257490873337, "standard_gx": 8.647706359624863, "rowwise_fwd": 4.106882959604263, "rowwise_bwd": 3.8046911358833313, "global_fwd": 4.09451499581337, "global_bwd": 3.8078874349594116, "x_quantize_rowwise": 0.08447840809822083, "g_quantize_rowwise": 0.4291348159313202, "w_quantize_rowwise": 0.5934201180934906, "w_quantize_colwise_transpose": 23.843105882406235, "w_quantize_global": 1.7399191856384277, "w_quantize_global_transpose": 1.8653236329555511, "cast_x": 0.1577921211719513, "cast_g": 0.6089024245738983, "cast_w": 1.1952444911003113, "time_standard": 25.291915982961655, "time_rowwise": 41.2069708108902, "time_global": 20.366515964269638}
{"repeat": 64, "batch_size": 4096, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 8.323360234498978, "standard_gw": 8.433796465396881, "standard_gx": 8.236430585384369, "rowwise_fwd": 3.8114115595817566, "rowwise_bwd": 4.106346517801285, "global_fwd": 3.8080140948295593, "global_bwd": 4.094675183296204, "x_quantize_rowwise": 0.4288516938686371, "g_quantize_rowwise": 0.08437782526016235, "w_quantize_rowwise": 0.8310228586196899, "w_quantize_colwise_transpose": 15.306610614061356, "w_quantize_global": 1.741155982017517, "w_quantize_global_transpose": 1.8809586763381958, "cast_x": 0.6091706454753876, "cast_g": 0.157233327627182, "cast_w": 1.1953115463256836, "time_standard": 24.993587285280228, "time_rowwise": 33.00241753458977, "time_global": 20.471829921007156}
{"repeat": 64, "batch_size": 8192, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 16.656354069709778, "standard_gw": 17.066240310668945, "standard_gx": 17.252348363399506, "rowwise_fwd": 8.220307528972626, "rowwise_bwd": 7.2372183203697205, "global_fwd": 8.2036592066288, "global_bwd": 7.236208766698837, "x_quantize_rowwise": 0.15832111239433289, "g_quantize_rowwise": 0.8406005799770355, "w_quantize_rowwise": 0.5935393273830414, "w_quantize_colwise_transpose": 23.86143058538437, "w_quantize_global": 1.7401576042175293, "w_quantize_global_transpose": 1.8653534352779388, "cast_x": 0.3079026937484741, "cast_g": 1.209162175655365, "cast_w": 1.1951625347137451, "time_standard": 50.97494274377823, "time_rowwise": 57.97765776515007, "time_global": 37.11054101586342}
{"repeat": 64, "batch_size": 8192, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 17.398890107870102, "standard_gw": 18.470749258995056, "standard_gx": 16.520217061042786, "rowwise_fwd": 7.235266268253326, "rowwise_bwd": 8.207589387893677, "global_fwd": 7.235914468765259, "global_bwd": 8.204508572816849, "x_quantize_rowwise": 0.8409880101680756, "g_quantize_rowwise": 0.15821680426597595, "w_quantize_rowwise": 0.8324198424816132, "w_quantize_colwise_transpose": 15.305522829294205, "w_quantize_global": 1.7396919429302216, "w_quantize_global_transpose": 1.8805749714374542, "cast_x": 1.2103468179702759, "cast_g": 0.30729547142982483, "cast_w": 1.1953599750995636, "time_standard": 52.389856427907944, "time_rowwise": 51.05075240135193, "time_global": 38.53064402937889}
{"repeat": 64, "batch_size": 16384, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 33.533211797475815, "standard_gw": 33.00020843744278, "standard_gx": 34.614477306604385, "rowwise_fwd": 16.364943236112595, "rowwise_bwd": 14.551006257534027, "global_fwd": 16.33496955037117, "global_bwd": 14.513172209262848, "x_quantize_rowwise": 0.3053396940231323, "g_quantize_rowwise": 1.6693994402885437, "w_quantize_rowwise": 0.5936138331890106, "w_quantize_colwise_transpose": 23.89485388994217, "w_quantize_global": 1.741711050271988, "w_quantize_global_transpose": 1.8656104803085327, "cast_x": 0.6089657545089722, "cast_g": 2.4122074246406555, "cast_w": 1.1951886117458344, "time_standard": 101.14789754152298, "time_rowwise": 90.37936478853226, "time_global": 69.430410861969}
{"repeat": 64, "batch_size": 16384, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 33.65536406636238, "standard_gw": 33.02193805575371, "standard_gx": 33.10496360063553, "rowwise_fwd": 14.54489678144455, "rowwise_bwd": 16.36252924799919, "global_fwd": 14.50401172041893, "global_bwd": 16.33254438638687, "x_quantize_rowwise": 1.6695670783519745, "g_quantize_rowwise": 0.3054291009902954, "w_quantize_rowwise": 0.83121657371521, "w_quantize_colwise_transpose": 15.305932611227036, "w_quantize_global": 1.7382949590682983, "w_quantize_global_transpose": 1.880194991827011, "cast_x": 2.412091940641403, "cast_g": 0.6079599261283875, "cast_w": 1.1950358748435974, "time_standard": 99.78226572275162, "time_rowwise": 82.04150944948196, "time_global": 69.45198029279709}
{"repeat": 64, "batch_size": 32768, "dim_out": 32384, "dim_in": 8096, "wm": 4, "switch": false, "standard_fwd": 67.96638667583466, "standard_gw": 67.99514591693878, "standard_gx": 69.66376304626465, "rowwise_fwd": 33.51752087473869, "rowwise_bwd": 29.131878167390823, "global_fwd": 32.65715390443802, "global_bwd": 29.13403883576393, "x_quantize_rowwise": 0.6002038717269897, "g_quantize_rowwise": 3.3336542546749115, "w_quantize_rowwise": 0.5934685468673706, "w_quantize_colwise_transpose": 23.92345294356346, "w_quantize_global": 1.7405375838279724, "w_quantize_global_transpose": 1.8656738102436066, "cast_x": 1.2112446129322052, "cast_g": 4.81804832816124, "cast_w": 1.1952146887779236, "time_standard": 205.6252956390381, "time_rowwise": 159.09532457590103, "time_global": 137.3264081776142}
{"repeat": 64, "batch_size": 32768, "dim_out": 8096, "dim_in": 32384, "wm": 4, "switch": true, "standard_fwd": 68.2341456413269, "standard_gw": 65.5074268579483, "standard_gx": 67.13805347681046, "rowwise_fwd": 29.153641313314438, "rowwise_bwd": 32.71844983100891, "global_fwd": 29.124341905117035, "global_bwd": 32.65979886054993, "x_quantize_rowwise": 3.3318176865577698, "g_quantize_rowwise": 0.6004795432090759, "w_quantize_rowwise": 0.8309967815876007, "w_quantize_colwise_transpose": 15.305690467357635, "w_quantize_global": 1.7405711114406586, "w_quantize_global_transpose": 1.8802620470523834, "cast_x": 4.8183538019657135, "cast_g": 1.2096390128135681, "cast_w": 1.1951103806495667, "time_standard": 200.87962597608566, "time_rowwise": 147.44850248098373, "time_global": 134.84469801187515}
{"repeat": 64, "batch_size": 1024, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 0.07764250040054321, "standard_gw": 0.07398426532745361, "standard_gx": 0.08482858538627625, "rowwise_fwd": 0.05266070365905762, "rowwise_bwd": 0.04478543996810913, "global_fwd": 0.052012503147125244, "global_bwd": 0.044364482164382935, "x_quantize_rowwise": 0.02640858292579651, "g_quantize_rowwise": 0.02539902925491333, "w_quantize_rowwise": 0.026457011699676514, "w_quantize_colwise_transpose": 0.17770379781723022, "w_quantize_global": 0.07440149784088135, "w_quantize_global_transpose": 0.08142739534378052, "cast_x": 0.008150935173034668, "cast_g": 0.022415071725845337, "cast_w": 0.03479421138763428, "time_standard": 0.23645535111427307, "time_rowwise": 0.42739883065223694, "time_global": 0.3779977560043335}
{"repeat": 64, "batch_size": 1024, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.08524581789970398, "standard_gw": 0.07383152842521667, "standard_gx": 0.07564574480056763, "rowwise_fwd": 0.04478171467781067, "rowwise_bwd": 0.052671879529953, "global_fwd": 0.04452839493751526, "global_bwd": 0.05219504237174988, "x_quantize_rowwise": 0.025328248739242554, "g_quantize_rowwise": 0.027123838663101196, "w_quantize_rowwise": 0.025607645511627197, "w_quantize_colwise_transpose": 0.17121434211730957, "w_quantize_global": 0.07916614413261414, "w_quantize_global_transpose": 0.08177384734153748, "cast_x": 0.022619962692260742, "cast_g": 0.008556991815567017, "cast_w": 0.034421682357788086, "time_standard": 0.23472309112548828, "time_rowwise": 0.42055919766426086, "time_global": 0.3839470446109772}
{"repeat": 64, "batch_size": 2048, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 0.13731792569160461, "standard_gw": 0.13414397835731506, "standard_gx": 0.14049187302589417, "rowwise_fwd": 0.10158121585845947, "rowwise_bwd": 0.07804110646247864, "global_fwd": 0.09908527135848999, "global_bwd": 0.07766112685203552, "x_quantize_rowwise": 0.026516616344451904, "g_quantize_rowwise": 0.03666803240776062, "w_quantize_rowwise": 0.024981796741485596, "w_quantize_colwise_transpose": 0.17706677317619324, "w_quantize_global": 0.07443130016326904, "w_quantize_global_transpose": 0.07870793342590332, "cast_x": 0.01224130392074585, "cast_g": 0.05828961730003357, "cast_w": 0.03501400351524353, "time_standard": 0.41195377707481384, "time_rowwise": 0.5789995193481445, "time_global": 0.5272142589092255}
{"repeat": 64, "batch_size": 2048, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.14651194214820862, "standard_gw": 0.14011189341545105, "standard_gx": 0.140264630317688, "rowwise_fwd": 0.081576406955719, "rowwise_bwd": 0.10671466588973999, "global_fwd": 0.08158013224601746, "global_bwd": 0.10219961404800415, "x_quantize_rowwise": 0.03775954246520996, "g_quantize_rowwise": 0.026103109121322632, "w_quantize_rowwise": 0.02656877040863037, "w_quantize_colwise_transpose": 0.17822161316871643, "w_quantize_global": 0.07506832480430603, "w_quantize_global_transpose": 0.07928535342216492, "cast_x": 0.05893409252166748, "cast_g": 0.012326985597610474, "cast_w": 0.03498047590255737, "time_standard": 0.42688846588134766, "time_rowwise": 0.5970560014247894, "time_global": 0.5421079695224762}
{"repeat": 64, "batch_size": 4096, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 0.2734065055847168, "standard_gw": 0.25558844208717346, "standard_gx": 0.29174983501434326, "rowwise_fwd": 0.173322856426239, "rowwise_bwd": 0.1515895128250122, "global_fwd": 0.17048418521881104, "global_bwd": 0.1506991684436798, "x_quantize_rowwise": 0.025950372219085693, "g_quantize_rowwise": 0.0653192400932312, "w_quantize_rowwise": 0.027138739824295044, "w_quantize_colwise_transpose": 0.17699971795082092, "w_quantize_global": 0.07373467087745667, "w_quantize_global_transpose": 0.07901713252067566, "cast_x": 0.02214685082435608, "cast_g": 0.11127442121505737, "cast_w": 0.03481656312942505, "time_standard": 0.8207447826862335, "time_rowwise": 0.8759088814258575, "time_global": 0.8207932114601135}
{"repeat": 64, "batch_size": 4096, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.27839839458465576, "standard_gw": 0.2537444233894348, "standard_gx": 0.28207898139953613, "rowwise_fwd": 0.16542896628379822, "rowwise_bwd": 0.18540024757385254, "global_fwd": 0.15722215175628662, "global_bwd": 0.17368420958518982, "x_quantize_rowwise": 0.06661936640739441, "g_quantize_rowwise": 0.027049332857131958, "w_quantize_rowwise": 0.025507062673568726, "w_quantize_colwise_transpose": 0.1741349697113037, "w_quantize_global": 0.07463246583938599, "w_quantize_global_transpose": 0.07879361510276794, "cast_x": 0.11301413178443909, "cast_g": 0.023346394300460815, "cast_w": 0.03505498170852661, "time_standard": 0.8142217993736267, "time_rowwise": 0.8978843688964844, "time_global": 0.8317455649375916}
{"repeat": 64, "batch_size": 8192, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 0.5755424499511719, "standard_gw": 0.5219094455242157, "standard_gx": 0.5992203950881958, "rowwise_fwd": 0.33193081617355347, "rowwise_bwd": 0.295441597700119, "global_fwd": 0.32791122794151306, "global_bwd": 0.2906434237957001, "x_quantize_rowwise": 0.0337548553943634, "g_quantize_rowwise": 0.1225881278514862, "w_quantize_rowwise": 0.024937093257904053, "w_quantize_colwise_transpose": 0.17729029059410095, "w_quantize_global": 0.0730752944946289, "w_quantize_global_transpose": 0.07835403084754944, "cast_x": 0.058166682720184326, "cast_g": 0.21592900156974792, "cast_w": 0.03454089164733887, "time_standard": 1.6966722905635834, "time_rowwise": 1.5078522264957428, "time_global": 1.4482364058494568}
{"repeat": 64, "batch_size": 8192, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.5104020237922668, "standard_gw": 0.5302242934703827, "standard_gx": 0.5842559039592743, "rowwise_fwd": 0.32220035791397095, "rowwise_bwd": 0.3576017916202545, "global_fwd": 0.2939775586128235, "global_bwd": 0.3313682973384857, "x_quantize_rowwise": 0.12369826436042786, "g_quantize_rowwise": 0.03423169255256653, "w_quantize_rowwise": 0.026501715183258057, "w_quantize_colwise_transpose": 0.16975775361061096, "w_quantize_global": 0.0768713653087616, "w_quantize_global_transpose": 0.08094683289527893, "cast_x": 0.21589547395706177, "cast_g": 0.05825608968734741, "cast_w": 0.03466010093688965, "time_standard": 1.6248822212219238, "time_rowwise": 1.5642158687114716, "time_global": 1.4713183045387268}
{"repeat": 64, "batch_size": 16384, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 1.194491982460022, "standard_gw": 1.0553859174251556, "standard_gx": 1.0726377367973328, "rowwise_fwd": 0.636763870716095, "rowwise_bwd": 0.5154944956302643, "global_fwd": 0.6281323730945587, "global_bwd": 0.5117170512676239, "x_quantize_rowwise": 0.062175095081329346, "g_quantize_rowwise": 0.23643672466278076, "w_quantize_rowwise": 0.025566667318344116, "w_quantize_colwise_transpose": 0.17768144607543945, "w_quantize_global": 0.07302314043045044, "w_quantize_global_transpose": 0.07866695523262024, "cast_x": 0.11140108108520508, "cast_g": 0.42498111724853516, "cast_w": 0.034831464290618896, "time_standard": 3.3225156366825104, "time_rowwise": 2.7095042169094086, "time_global": 2.645537257194519}
{"repeat": 64, "batch_size": 16384, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 1.0797791182994843, "standard_gw": 1.062549650669098, "standard_gx": 1.104947179555893, "rowwise_fwd": 0.5390122532844543, "rowwise_bwd": 0.6449781358242035, "global_fwd": 0.5145668983459473, "global_bwd": 0.6276033818721771, "x_quantize_rowwise": 0.23603439331054688, "g_quantize_rowwise": 0.062234699726104736, "w_quantize_rowwise": 0.02781301736831665, "w_quantize_colwise_transpose": 0.1703314483165741, "w_quantize_global": 0.07431954145431519, "w_quantize_global_transpose": 0.08028373122215271, "cast_x": 0.4249885678291321, "cast_g": 0.1113303005695343, "cast_w": 0.0348016619682312, "time_standard": 3.247275948524475, "time_rowwise": 2.742953598499298, "time_global": 2.657592296600342}
{"repeat": 64, "batch_size": 32768, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 2.392485737800598, "standard_gw": 2.046734094619751, "standard_gx": 2.177651971578598, "rowwise_fwd": 1.252591609954834, "rowwise_bwd": 1.0205842554569244, "global_fwd": 1.230098307132721, "global_bwd": 1.0132193565368652, "x_quantize_rowwise": 0.11823698878288269, "g_quantize_rowwise": 0.4639141261577606, "w_quantize_rowwise": 0.02602487802505493, "w_quantize_colwise_transpose": 0.17801672220230103, "w_quantize_global": 0.07301196455955505, "w_quantize_global_transpose": 0.07893890142440796, "cast_x": 0.21591037511825562, "cast_g": 0.843394547700882, "cast_w": 0.03460049629211426, "time_standard": 6.616871803998947, "time_rowwise": 5.106102675199509, "time_global": 5.0241537392139435}
{"repeat": 64, "batch_size": 32768, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 2.205628901720047, "standard_gw": 1.9917488098144531, "standard_gx": 2.1518059074878693, "rowwise_fwd": 1.040138304233551, "rowwise_bwd": 1.2538731098175049, "global_fwd": 1.0131187736988068, "global_bwd": 1.2291893362998962, "x_quantize_rowwise": 0.46381354331970215, "g_quantize_rowwise": 0.11790916323661804, "w_quantize_rowwise": 0.027123838663101196, "w_quantize_colwise_transpose": 0.17021596431732178, "w_quantize_global": 0.0752471387386322, "w_quantize_global_transpose": 0.08159875869750977, "cast_x": 0.8433908224105835, "cast_g": 0.215873122215271, "cast_w": 0.03452599048614502, "time_standard": 6.349183619022369, "time_rowwise": 5.064822733402252, "time_global": 4.972625523805618}
{"repeat": 64, "batch_size": 65536, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 4.755370318889618, "standard_gw": 4.736289381980896, "standard_gx": 4.0378570556640625, "rowwise_fwd": 2.4783052504062653, "rowwise_bwd": 1.9634142518043518, "global_fwd": 2.435591071844101, "global_bwd": 1.9498206675052643, "x_quantize_rowwise": 0.22948533296585083, "g_quantize_rowwise": 0.9186491370201111, "w_quantize_rowwise": 0.028233975172042847, "w_quantize_colwise_transpose": 0.17858296632766724, "w_quantize_global": 0.07418543100357056, "w_quantize_global_transpose": 0.07958710193634033, "cast_x": 0.4257224500179291, "cast_g": 1.680031418800354, "cast_w": 0.03458559513092041, "time_standard": 13.529516756534576, "time_rowwise": 10.532960295677185, "time_global": 10.423608124256134}
{"repeat": 64, "batch_size": 65536, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 4.050172865390778, "standard_gw": 3.916766494512558, "standard_gx": 4.281226545572281, "rowwise_fwd": 1.9789263606071472, "rowwise_bwd": 2.477586269378662, "global_fwd": 1.9495487213134766, "global_bwd": 2.434592694044113, "x_quantize_rowwise": 0.918261706829071, "g_quantize_rowwise": 0.22961944341659546, "w_quantize_rowwise": 0.025540590286254883, "w_quantize_colwise_transpose": 0.17032772302627563, "w_quantize_global": 0.07384642958641052, "w_quantize_global_transpose": 0.08105114102363586, "cast_x": 1.679886132478714, "cast_g": 0.42508915066719055, "cast_w": 0.03442913293838501, "time_standard": 12.248165905475616, "time_rowwise": 9.717028588056564, "time_global": 9.60368663072586}
{"repeat": 64, "batch_size": 131072, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 9.53347235918045, "standard_gw": 8.138865232467651, "standard_gx": 7.9666972160339355, "rowwise_fwd": 4.984956234693527, "rowwise_bwd": 3.850068897008896, "global_fwd": 4.9025751650333405, "global_bwd": 3.820303827524185, "x_quantize_rowwise": 0.45222043991088867, "g_quantize_rowwise": 1.8290691077709198, "w_quantize_rowwise": 0.026736408472061157, "w_quantize_colwise_transpose": 0.17832592129707336, "w_quantize_global": 0.07471069693565369, "w_quantize_global_transpose": 0.08177757263183594, "cast_x": 0.8435025811195374, "cast_g": 3.3529214560985565, "cast_w": 0.03475695848464966, "time_standard": 25.639034807682037, "time_rowwise": 19.460242241621017, "time_global": 19.299522042274475}
{"repeat": 64, "batch_size": 131072, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 7.996037602424622, "standard_gw": 8.2748644053936, "standard_gx": 8.523400872945786, "rowwise_fwd": 3.8556940853595734, "rowwise_bwd": 4.966288805007935, "global_fwd": 3.820043057203293, "global_bwd": 4.882067441940308, "x_quantize_rowwise": 1.8279887735843658, "g_quantize_rowwise": 0.4520900547504425, "w_quantize_rowwise": 0.02676248550415039, "w_quantize_colwise_transpose": 0.17083808779716492, "w_quantize_global": 0.07691606879234314, "w_quantize_global_transpose": 0.08223950862884521, "cast_x": 3.3530443906784058, "cast_g": 0.8434318006038666, "cast_w": 0.034671276807785034, "time_standard": 24.794302880764008, "time_rowwise": 19.574526697397232, "time_global": 19.416209310293198}
{"repeat": 64, "batch_size": 1024, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.09413063526153564, "standard_gw": 0.10038167238235474, "standard_gx": 0.09725615382194519, "rowwise_fwd": 0.05979463458061218, "rowwise_bwd": 0.0525452196598053, "global_fwd": 0.059057027101516724, "global_bwd": 0.05194917321205139, "x_quantize_rowwise": 0.02664700150489807, "g_quantize_rowwise": 0.02642720937728882, "w_quantize_rowwise": 0.030562281608581543, "w_quantize_colwise_transpose": 0.2400912344455719, "w_quantize_global": 0.09407848119735718, "w_quantize_global_transpose": 0.10256841778755188, "cast_x": 0.008724629878997803, "cast_g": 0.028502196073532104, "cast_w": 0.05552172660827637, "time_standard": 0.29176846146583557, "time_rowwise": 0.5364492535591125, "time_global": 0.4611089825630188}
{"repeat": 64, "batch_size": 1024, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.09753555059432983, "standard_gw": 0.10102242231369019, "standard_gx": 0.09121373295783997, "rowwise_fwd": 0.052150338888168335, "rowwise_bwd": 0.059779733419418335, "global_fwd": 0.05161017179489136, "global_bwd": 0.05943328142166138, "x_quantize_rowwise": 0.026702880859375, "g_quantize_rowwise": 0.02469494938850403, "w_quantize_rowwise": 0.03324449062347412, "w_quantize_colwise_transpose": 0.23468583822250366, "w_quantize_global": 0.09394437074661255, "w_quantize_global_transpose": 0.10142102837562561, "cast_x": 0.028360635042190552, "cast_g": 0.008717179298400879, "cast_w": 0.05577504634857178, "time_standard": 0.28977170586586, "time_rowwise": 0.5322806537151337, "time_global": 0.4588291049003601}
{"repeat": 64, "batch_size": 2048, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.18056854605674744, "standard_gw": 0.18374621868133545, "standard_gx": 0.19219890236854553, "rowwise_fwd": 0.1150965690612793, "rowwise_bwd": 0.0903494656085968, "global_fwd": 0.11263042688369751, "global_bwd": 0.08984282612800598, "x_quantize_rowwise": 0.027067959308624268, "g_quantize_rowwise": 0.040043145418167114, "w_quantize_rowwise": 0.03063306212425232, "w_quantize_colwise_transpose": 0.24128705263137817, "w_quantize_global": 0.09361281991004944, "w_quantize_global_transpose": 0.1024976372718811, "cast_x": 0.01381710171699524, "cast_g": 0.06845593452453613, "cast_w": 0.05572289228439331, "time_standard": 0.5565136671066284, "time_rowwise": 0.7282234728336334, "time_global": 0.6494410336017609}
{"repeat": 64, "batch_size": 2048, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.16536936163902283, "standard_gw": 0.19479170441627502, "standard_gx": 0.18597766757011414, "rowwise_fwd": 0.09634345769882202, "rowwise_bwd": 0.11937320232391357, "global_fwd": 0.09264424443244934, "global_bwd": 0.11524930596351624, "x_quantize_rowwise": 0.04038214683532715, "g_quantize_rowwise": 0.025559216737747192, "w_quantize_rowwise": 0.03334507346153259, "w_quantize_colwise_transpose": 0.23956596851348877, "w_quantize_global": 0.09445473551750183, "w_quantize_global_transpose": 0.1020580530166626, "cast_x": 0.06891414523124695, "cast_g": 0.013861805200576782, "cast_w": 0.05607306957244873, "time_standard": 0.546138733625412, "time_rowwise": 0.7493607699871063, "time_global": 0.6651394069194794}
{"repeat": 64, "batch_size": 4096, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.36064907908439636, "standard_gw": 0.3711991012096405, "standard_gx": 0.3863237798213959, "rowwise_fwd": 0.22270530462265015, "rowwise_bwd": 0.1760348677635193, "global_fwd": 0.21781772375106812, "global_bwd": 0.17484650015830994, "x_quantize_rowwise": 0.02625212073326111, "g_quantize_rowwise": 0.07131323218345642, "w_quantize_rowwise": 0.030372291803359985, "w_quantize_colwise_transpose": 0.23974105715751648, "w_quantize_global": 0.09407475590705872, "w_quantize_global_transpose": 0.1024492084980011, "cast_x": 0.028584152460098267, "cast_g": 0.1303069293498993, "cast_w": 0.05582347512245178, "time_standard": 1.1181719601154327, "time_rowwise": 1.137617975473404, "time_global": 1.057952642440796}
{"repeat": 64, "batch_size": 4096, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.32703205943107605, "standard_gw": 0.3764517605304718, "standard_gx": 0.3938935697078705, "rowwise_fwd": 0.18771737813949585, "rowwise_bwd": 0.2374798059463501, "global_fwd": 0.1843757927417755, "global_bwd": 0.23005902767181396, "x_quantize_rowwise": 0.07155537605285645, "g_quantize_rowwise": 0.02625212073326111, "w_quantize_rowwise": 0.03294646739959717, "w_quantize_colwise_transpose": 0.23755058646202087, "w_quantize_global": 0.09388476610183716, "w_quantize_global_transpose": 0.10246038436889648, "cast_x": 0.13131648302078247, "cast_g": 0.028781592845916748, "cast_w": 0.05638599395751953, "time_standard": 1.0973773896694183, "time_rowwise": 1.1699534952640533, "time_global": 1.0850392282009125}
{"repeat": 64, "batch_size": 8192, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.7961541414260864, "standard_gw": 0.7424280047416687, "standard_gx": 0.8688867092132568, "rowwise_fwd": 0.432576984167099, "rowwise_bwd": 0.34543126821517944, "global_fwd": 0.4248805344104767, "global_bwd": 0.3432855010032654, "x_quantize_rowwise": 0.03750622272491455, "g_quantize_rowwise": 0.13292208313941956, "w_quantize_rowwise": 0.030599534511566162, "w_quantize_colwise_transpose": 0.24292618036270142, "w_quantize_global": 0.09351596236228943, "w_quantize_global_transpose": 0.1026056706905365, "cast_x": 0.06843730807304382, "cast_g": 0.2539418637752533, "cast_w": 0.05568563938140869, "time_standard": 2.407468855381012, "time_rowwise": 1.9643902778625488, "time_global": 1.8771439790725708}
{"repeat": 64, "batch_size": 8192, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.7150471210479736, "standard_gw": 0.7525831460952759, "standard_gx": 0.8075274527072906, "rowwise_fwd": 0.36595389246940613, "rowwise_bwd": 0.4404708743095398, "global_fwd": 0.3485158085823059, "global_bwd": 0.4275962710380554, "x_quantize_rowwise": 0.1329965889453888, "g_quantize_rowwise": 0.03767386078834534, "w_quantize_rowwise": 0.03295019268989563, "w_quantize_colwise_transpose": 0.23509934544563293, "w_quantize_global": 0.09398534893989563, "w_quantize_global_transpose": 0.10186433792114258, "cast_x": 0.2537667751312256, "cast_g": 0.06839632987976074, "cast_w": 0.05571544170379639, "time_standard": 2.27515771985054, "time_rowwise": 1.9977279007434845, "time_global": 1.8952153623104095}
{"repeat": 64, "batch_size": 16384, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 1.6392990946769714, "standard_gw": 1.4941170811653137, "standard_gx": 1.4451220631599426, "rowwise_fwd": 0.8369758725166321, "rowwise_bwd": 0.6830468773841858, "global_fwd": 0.8197203278541565, "global_bwd": 0.6782263517379761, "x_quantize_rowwise": 0.06883591413497925, "g_quantize_rowwise": 0.2565309405326843, "w_quantize_rowwise": 0.03046169877052307, "w_quantize_colwise_transpose": 0.2430342137813568, "w_quantize_global": 0.09346380829811096, "w_quantize_global_transpose": 0.10301917791366577, "cast_x": 0.13044849038124084, "cast_g": 0.5010999739170074, "cast_w": 0.05590170621871948, "time_standard": 4.578538239002228, "time_rowwise": 3.613002598285675, "time_global": 3.5139136016368866}
{"repeat": 64, "batch_size": 16384, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 1.4654621481895447, "standard_gw": 1.5012174844741821, "standard_gx": 1.5183314681053162, "rowwise_fwd": 0.7059797644615173, "rowwise_bwd": 0.8470229804515839, "global_fwd": 0.6788894534111023, "global_bwd": 0.8200779557228088, "x_quantize_rowwise": 0.2564750611782074, "g_quantize_rowwise": 0.06899237632751465, "w_quantize_rowwise": 0.03293529152870178, "w_quantize_colwise_transpose": 0.23559853434562683, "w_quantize_global": 0.09375810623168945, "w_quantize_global_transpose": 0.10203942656517029, "cast_x": 0.5010105669498444, "cast_g": 0.13037025928497314, "cast_w": 0.05577504634857178, "time_standard": 4.485011100769043, "time_rowwise": 3.648221492767334, "time_global": 3.521449863910675}
{"repeat": 64, "batch_size": 32768, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 3.236088901758194, "standard_gw": 2.8601549565792084, "standard_gx": 2.8000958263874054, "rowwise_fwd": 1.6548968851566315, "rowwise_bwd": 1.3559646904468536, "global_fwd": 1.6249343752861023, "global_bwd": 1.3474412262439728, "x_quantize_rowwise": 0.13122707605361938, "g_quantize_rowwise": 0.5038455128669739, "w_quantize_rowwise": 0.03061816096305847, "w_quantize_colwise_transpose": 0.24301931262016296, "w_quantize_global": 0.09343400597572327, "w_quantize_global_transpose": 0.10178983211517334, "cast_x": 0.25383010506629944, "cast_g": 0.9955987334251404, "cast_w": 0.05569681525230408, "time_standard": 8.896339684724808, "time_rowwise": 6.779726594686508, "time_global": 6.662826985120773}
{"repeat": 64, "batch_size": 32768, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 2.8433389961719513, "standard_gw": 2.861086279153824, "standard_gx": 3.0227042734622955, "rowwise_fwd": 1.4057457447052002, "rowwise_bwd": 1.6565024852752686, "global_fwd": 1.3475008308887482, "global_bwd": 1.6247481107711792, "x_quantize_rowwise": 0.5038045346736908, "g_quantize_rowwise": 0.13130158185958862, "w_quantize_rowwise": 0.03298744559288025, "w_quantize_colwise_transpose": 0.23539364337921143, "w_quantize_global": 0.09393692016601562, "w_quantize_global_transpose": 0.10208785533905029, "cast_x": 0.9952597320079803, "cast_g": 0.25385990738868713, "cast_w": 0.05589798092842102, "time_standard": 8.72712954878807, "time_rowwise": 6.826821714639664, "time_global": 6.664466112852097}
{"repeat": 64, "batch_size": 65536, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 6.449159234762192, "standard_gw": 6.384443491697311, "standard_gx": 5.543403327465057, "rowwise_fwd": 3.3065229654312134, "rowwise_bwd": 2.6249960064888, "global_fwd": 3.2497718930244446, "global_bwd": 2.6061534881591797, "x_quantize_rowwise": 0.25821104645729065, "g_quantize_rowwise": 0.9981803596019745, "w_quantize_rowwise": 0.030606985092163086, "w_quantize_colwise_transpose": 0.24094432592391968, "w_quantize_global": 0.09358301758766174, "w_quantize_global_transpose": 0.10264664888381958, "cast_x": 0.5018562078475952, "cast_g": 1.9840113818645477, "cast_w": 0.05584210157394409, "time_standard": 18.37700605392456, "time_rowwise": 13.843905180692673, "time_global": 13.692989945411682}
{"repeat": 64, "batch_size": 65536, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 5.508493632078171, "standard_gw": 5.689781159162521, "standard_gx": 6.020743399858475, "rowwise_fwd": 2.640843391418457, "rowwise_bwd": 3.3075474202632904, "global_fwd": 2.605751156806946, "global_bwd": 3.2674334943294525, "x_quantize_rowwise": 0.9983181953430176, "g_quantize_rowwise": 0.25597214698791504, "w_quantize_rowwise": 0.03277510404586792, "w_quantize_colwise_transpose": 0.23587048053741455, "w_quantize_global": 0.09367987513542175, "w_quantize_global_transpose": 0.10236725211143494, "cast_x": 1.9848868250846863, "cast_g": 0.5010329186916351, "cast_w": 0.055771321058273315, "time_standard": 17.219018191099167, "time_rowwise": 13.161107897758484, "time_global": 13.013303279876709}
{"repeat": 64, "batch_size": 131072, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 12.975204735994339, "standard_gw": 11.424731463193893, "standard_gx": 11.05477660894394, "rowwise_fwd": 6.623122841119766, "rowwise_bwd": 5.253363400697708, "global_fwd": 6.506938487291336, "global_bwd": 5.211424082517624, "x_quantize_rowwise": 0.5057789385318756, "g_quantize_rowwise": 1.9870363175868988, "w_quantize_rowwise": 0.030517578125, "w_quantize_colwise_transpose": 0.24361908435821533, "w_quantize_global": 0.09384006261825562, "w_quantize_global_transpose": 0.10285153985023499, "cast_x": 0.9967051446437836, "cast_g": 3.9620958268642426, "cast_w": 0.05599111318588257, "time_standard": 35.45471280813217, "time_rowwise": 26.068169623613358, "time_global": 25.83260089159012}
{"repeat": 64, "batch_size": 131072, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 11.05555146932602, "standard_gw": 11.32136583328247, "standard_gx": 12.035444378852844, "rowwise_fwd": 5.243867635726929, "rowwise_bwd": 6.622854620218277, "global_fwd": 5.209986120462418, "global_bwd": 6.507329642772675, "x_quantize_rowwise": 1.9862838089466095, "g_quantize_rowwise": 0.506080687046051, "w_quantize_rowwise": 0.03318488597869873, "w_quantize_colwise_transpose": 0.23682788014411926, "w_quantize_global": 0.09349361062049866, "w_quantize_global_transpose": 0.1023709774017334, "cast_x": 3.962486982345581, "cast_g": 0.9956248104572296, "cast_w": 0.05572289228439331, "time_standard": 34.412361681461334, "time_rowwise": 25.950465351343155, "time_global": 25.726910680532455}

View File

@ -1,20 +0,0 @@
{"repeat": 32, "batch_size": 16384, "dim": 1024, "standard": 3.807276487350464, "my_standard": 4.196919500827789, "standard_compiled": 3.771558403968811, "sb": 3.5132691264152527}
{"repeat": 32, "batch_size": 32768, "dim": 1024, "standard": 7.215872406959534, "my_standard": 7.991522550582886, "standard_compiled": 7.241688668727875, "sb": 6.581142544746399}
{"repeat": 32, "batch_size": 65536, "dim": 1024, "standard": 14.26444947719574, "my_standard": 15.685759484767914, "standard_compiled": 14.251746237277985, "sb": 12.735314667224884}
{"repeat": 32, "batch_size": 131072, "dim": 1024, "standard": 28.49559485912323, "my_standard": 31.26966953277588, "standard_compiled": 28.414390981197357, "sb": 25.319166481494904}
{"repeat": 32, "batch_size": 16384, "dim": 1280, "standard": 5.887262523174286, "my_standard": 6.132654845714569, "standard_compiled": 5.902409553527832, "sb": 4.947789013385773}
{"repeat": 32, "batch_size": 32768, "dim": 1280, "standard": 11.14131510257721, "my_standard": 12.859955430030823, "standard_compiled": 11.133037507534027, "sb": 9.303092956542969}
{"repeat": 32, "batch_size": 65536, "dim": 1280, "standard": 22.193141281604767, "my_standard": 25.66336840391159, "standard_compiled": 22.22583442926407, "sb": 18.285617232322693}
{"repeat": 32, "batch_size": 131072, "dim": 1280, "standard": 44.23898458480835, "my_standard": 51.30268633365631, "standard_compiled": 44.08355802297592, "sb": 35.999126732349396}
{"repeat": 32, "batch_size": 16384, "dim": 1408, "standard": 6.938718259334564, "my_standard": 7.269218564033508, "standard_compiled": 6.94604218006134, "sb": 5.764961242675781}
{"repeat": 32, "batch_size": 32768, "dim": 1408, "standard": 13.04878294467926, "my_standard": 13.742901384830475, "standard_compiled": 13.011425733566284, "sb": 10.774023830890656}
{"repeat": 32, "batch_size": 65536, "dim": 1408, "standard": 26.738539338111877, "my_standard": 27.739346027374268, "standard_compiled": 26.75659954547882, "sb": 21.882005035877228}
{"repeat": 32, "batch_size": 131072, "dim": 1408, "standard": 51.905401051044464, "my_standard": 53.98637801408768, "standard_compiled": 51.8316924571991, "sb": 41.67725890874863}
{"repeat": 32, "batch_size": 16384, "dim": 1664, "standard": 9.233824908733368, "my_standard": 9.619377553462982, "standard_compiled": 9.214423596858978, "sb": 7.557623088359833}
{"repeat": 32, "batch_size": 32768, "dim": 1664, "standard": 17.324909567832947, "my_standard": 17.996780574321747, "standard_compiled": 17.29544997215271, "sb": 14.035224914550781}
{"repeat": 32, "batch_size": 65536, "dim": 1664, "standard": 35.51657497882843, "my_standard": 36.674730479717255, "standard_compiled": 35.43049842119217, "sb": 28.38330715894699}
{"repeat": 32, "batch_size": 131072, "dim": 1664, "standard": 69.0087378025055, "my_standard": 71.56594842672348, "standard_compiled": 68.82885098457336, "sb": 54.01633679866791}
{"repeat": 32, "batch_size": 16384, "dim": 2048, "standard": 12.590140104293823, "my_standard": 13.106442987918854, "standard_compiled": 12.606985867023468, "sb": 10.286301374435425}
{"repeat": 32, "batch_size": 32768, "dim": 2048, "standard": 24.830535054206848, "my_standard": 25.563716888427734, "standard_compiled": 24.895809590816498, "sb": 19.559212028980255}
{"repeat": 32, "batch_size": 65536, "dim": 2048, "standard": 49.55078661441803, "my_standard": 51.16480588912964, "standard_compiled": 49.739621579647064, "sb": 38.29141706228256}
{"repeat": 32, "batch_size": 131072, "dim": 2048, "standard": 98.36294502019882, "my_standard": 102.69322991371155, "standard_compiled": 98.76712411642075, "sb": 75.88706165552139}

View File

@ -1,20 +0,0 @@
{"repeat": 32, "batch_size": 16384, "dim": 1024, "standard": 4.91420179605484, "my_standard": 5.577877163887024, "standard_compiled": 4.810944199562073, "sb": 4.512995481491089}
{"repeat": 32, "batch_size": 32768, "dim": 1024, "standard": 8.876129984855652, "my_standard": 10.154612362384796, "standard_compiled": 8.820965886116028, "sb": 8.367843925952911}
{"repeat": 32, "batch_size": 65536, "dim": 1024, "standard": 17.47015118598938, "my_standard": 19.857674837112427, "standard_compiled": 17.338842153549194, "sb": 15.992552042007446}
{"repeat": 32, "batch_size": 131072, "dim": 1024, "standard": 34.824438393116, "my_standard": 39.499424397945404, "standard_compiled": 34.56207364797592, "sb": 31.573951244354248}
{"repeat": 32, "batch_size": 16384, "dim": 1280, "standard": 7.342606782913208, "my_standard": 7.9323723912239075, "standard_compiled": 7.279552519321442, "sb": 6.395488977432251}
{"repeat": 32, "batch_size": 32768, "dim": 1280, "standard": 13.69999349117279, "my_standard": 16.0503089427948, "standard_compiled": 13.603456318378448, "sb": 11.813104152679443}
{"repeat": 32, "batch_size": 65536, "dim": 1280, "standard": 29.557034373283386, "my_standard": 34.2303067445755, "standard_compiled": 29.382556676864624, "sb": 22.882774472236633}
{"repeat": 32, "batch_size": 131072, "dim": 1280, "standard": 53.629085421562195, "my_standard": 63.07622790336609, "standard_compiled": 53.33048850297928, "sb": 44.76426541805267}
{"repeat": 32, "batch_size": 16384, "dim": 1408, "standard": 8.81417840719223, "my_standard": 9.477965533733368, "standard_compiled": 8.73943418264389, "sb": 7.479414343833923}
{"repeat": 32, "batch_size": 32768, "dim": 1408, "standard": 16.242466866970062, "my_standard": 17.616644501686096, "standard_compiled": 16.14125818014145, "sb": 13.665586709976196}
{"repeat": 32, "batch_size": 65536, "dim": 1408, "standard": 32.429613173007965, "my_standard": 34.80646014213562, "standard_compiled": 32.319076359272, "sb": 27.123987674713135}
{"repeat": 32, "batch_size": 131072, "dim": 1408, "standard": 62.85770237445831, "my_standard": 67.55391508340836, "standard_compiled": 62.453076243400574, "sb": 51.53566598892212}
{"repeat": 32, "batch_size": 16384, "dim": 1664, "standard": 11.585861444473267, "my_standard": 12.565858662128448, "standard_compiled": 11.504307389259338, "sb": 9.657211601734161}
{"repeat": 32, "batch_size": 32768, "dim": 1664, "standard": 21.261662244796753, "my_standard": 22.771358489990234, "standard_compiled": 21.12410217523575, "sb": 17.64291524887085}
{"repeat": 32, "batch_size": 65536, "dim": 1664, "standard": 42.85307973623276, "my_standard": 45.70870101451874, "standard_compiled": 42.57970303297043, "sb": 34.918561577796936}
{"repeat": 32, "batch_size": 131072, "dim": 1664, "standard": 83.56057852506638, "my_standard": 89.11971747875214, "standard_compiled": 83.05662125349045, "sb": 66.32210314273834}
{"repeat": 32, "batch_size": 16384, "dim": 2048, "standard": 15.7279372215271, "my_standard": 16.854502260684967, "standard_compiled": 15.655294060707092, "sb": 13.228952884674072}
{"repeat": 32, "batch_size": 32768, "dim": 2048, "standard": 30.42648732662201, "my_standard": 32.26502239704132, "standard_compiled": 30.239209532737732, "sb": 24.354808032512665}
{"repeat": 32, "batch_size": 65536, "dim": 2048, "standard": 60.779355466365814, "my_standard": 64.11923468112946, "standard_compiled": 60.89268624782562, "sb": 46.91776633262634}
{"repeat": 32, "batch_size": 131072, "dim": 2048, "standard": 119.93677169084549, "my_standard": 128.19699943065643, "standard_compiled": 120.20225822925568, "sb": 92.3452153801918}

View File

@ -1,23 +0,0 @@
{"repeat": 32, "batch_size": 16384, "dim": 1024, "standard": 5.171686410903931, "my_standard": 5.839601159095764, "standard_compiled": 5.032263696193695, "sb": 4.89344447851181}
{"repeat": 32, "batch_size": 32768, "dim": 1024, "standard": 9.605035185813904, "my_standard": 10.910414159297943, "standard_compiled": 9.230785071849823, "sb": 9.128175675868988}
{"repeat": 32, "batch_size": 65536, "dim": 1024, "standard": 18.802084028720856, "my_standard": 21.311581134796143, "standard_compiled": 18.105976283550262, "sb": 17.489850521087646}
{"repeat": 32, "batch_size": 131072, "dim": 1024, "standard": 37.49683499336243, "my_standard": 42.40527004003525, "standard_compiled": 36.13145649433136, "sb": 34.58733111619949}
{"repeat": 32, "batch_size": 16384, "dim": 1280, "standard": 7.709823548793793, "my_standard": 8.290477097034454, "standard_compiled": 7.564418017864227, "sb": 6.8823546171188354}
{"repeat": 32, "batch_size": 32768, "dim": 1280, "standard": 14.64156061410904, "my_standard": 16.996942460536957, "standard_compiled": 14.4081711769104, "sb": 12.761622667312622}
{"repeat": 32, "batch_size": 65536, "dim": 1280, "standard": 31.40200674533844, "my_standard": 36.074504256248474, "standard_compiled": 30.981406569480896, "sb": 24.76389706134796}
{"repeat": 32, "batch_size": 131072, "dim": 1280, "standard": 56.93405121564865, "my_standard": 66.35250151157379, "standard_compiled": 56.07586354017258, "sb": 48.49743843078613}
{"repeat": 32, "batch_size": 16384, "dim": 1408, "standard": 9.188003838062286, "my_standard": 9.84550267457962, "standard_compiled": 9.006097912788391, "sb": 7.9473331570625305}
{"repeat": 32, "batch_size": 32768, "dim": 1408, "standard": 17.268165946006775, "my_standard": 18.64910125732422, "standard_compiled": 16.983114182949066, "sb": 14.70106840133667}
{"repeat": 32, "batch_size": 65536, "dim": 1408, "standard": 34.39047932624817, "my_standard": 36.69705241918564, "standard_compiled": 33.8401272892952, "sb": 29.188089072704315}
{"repeat": 32, "batch_size": 131072, "dim": 1408, "standard": 66.70494377613068, "my_standard": 71.27603143453598, "standard_compiled": 65.56134670972824, "sb": 55.6538850069046}
{"repeat": 32, "batch_size": 16384, "dim": 1664, "standard": 12.10707426071167, "my_standard": 12.931793928146362, "standard_compiled": 11.76995038986206, "sb": 10.228671133518219}
{"repeat": 32, "batch_size": 32768, "dim": 1664, "standard": 22.5130096077919, "my_standard": 23.962542414665222, "standard_compiled": 21.997176110744476, "sb": 18.89890432357788}
{"repeat": 32, "batch_size": 65536, "dim": 1664, "standard": 45.210108160972595, "my_standard": 47.94136434793472, "standard_compiled": 44.2262664437294, "sb": 37.37735003232956}
{"repeat": 32, "batch_size": 131072, "dim": 1664, "standard": 88.1955549120903, "my_standard": 93.6831533908844, "standard_compiled": 86.33609116077423, "sb": 71.23208791017532}
{"repeat": 32, "batch_size": 16384, "dim": 2048, "standard": 16.538940370082855, "my_standard": 17.607316374778748, "standard_compiled": 16.108587384223938, "sb": 14.030493795871735}
{"repeat": 32, "batch_size": 32768, "dim": 2048, "standard": 31.795650720596313, "my_standard": 33.57230871915817, "standard_compiled": 31.04180097579956, "sb": 25.971196591854095}
{"repeat": 32, "batch_size": 65536, "dim": 2048, "standard": 63.021354377269745, "my_standard": 66.8477788567543, "standard_compiled": 61.682507395744324, "sb": 50.138771533966064}
{"repeat": 32, "batch_size": 131072, "dim": 2048, "standard": 125.17062574625015, "my_standard": 133.60925763845444, "standard_compiled": 122.21191823482513, "sb": 98.40084612369537}
{"repeat": 32, "batch_size": 16384, "dim": 4096, "standard": 57.31645971536636, "my_standard": 60.84543466567993, "standard_compiled": 55.78199774026871, "sb": 45.43223977088928}
{"repeat": 32, "batch_size": 32768, "dim": 4096, "standard": 111.80306226015091, "my_standard": 119.0284714102745, "standard_compiled": 108.91905426979065, "sb": 85.4572057723999}
{"repeat": 32, "batch_size": 65536, "dim": 4096, "standard": 220.4471081495285, "my_standard": 233.0927476286888, "standard_compiled": 214.26431089639664, "sb": 163.30372542142868}

View File

@ -1,64 +0,0 @@
import time
import torch
import torch.nn as nn
import bitsandbytes.nn as bnn
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear, StandardLinear
def construct_model(dim, layers, module):
modules = []
for _ in range(layers):
modules.append(module(dim, 4*dim))
modules.append(module(4*dim, dim))
return nn.Sequential(*modules).cuda().train()
def get_time(model, x, name):
for _ in range(repeat // 2):
#with torch.cuda.amp.autocast():
out = model(x)
#(2**16 * out.pow(2).mean()).backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
# with torch.cuda.amp.autocast():
out = model(x)
#(2**16 * out.pow(2).mean()).backward()
torch.cuda.synchronize()
end = time.time()
print(f"time {name}: {(end - start) / repeat * 1000:.3f} ms")
if __name__ == '__main__':
torch.manual_seed(0)
# hparams
repeat = 16
dim=2048
layers =4
batch_size = 2
sequence_length = 2**15
# construct models
standard = construct_model(dim, layers, nn.Linear).half()
my_standard = construct_model(dim, layers, StandardLinear).half()
switchback = construct_model(dim, layers, SwitchBackLinear).half()
switchback_global = construct_model(dim, layers, SwitchBackGlobalLinear).half()
#bnb_8bitmixed = construct_model(dim, layers, bnn.Linear8bitLt)
# simulate forward pass
x = torch.randn(batch_size * sequence_length, dim, dtype=torch.float16).cuda()
# get time for forward and backward
get_time(standard, x, "standard")
get_time(my_standard, x, "my_standard")
get_time(switchback, x, "switchback")
get_time(switchback_global, x, "switchback_global")
#get_time(bnb_8bitmixed, x, "bnb_8bitmixed")

View File

@ -1,166 +0,0 @@
import torch
import json
from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
import time
if __name__ == '__main__':
print('Startin')
for dim in [1024, 1280, 1408, 1664, 2048]:
for batch in [2**14, 2**15, 2**16, 2**17]:
if dim != 4096 or batch != 2**17:
continue
x1 = torch.randn(batch, dim).cuda().requires_grad_(True)
d = 2
standard = torch.nn.Sequential(
torch.nn.Linear(dim, 4 * dim),
torch.nn.GELU(),
torch.nn.Linear(4 * dim, dim),
).cuda()
my_standard = torch.nn.Sequential(
StandardLinear(dim, 4 * dim),
torch.nn.GELU(),
StandardLinear(4 * dim, dim),
).cuda()
fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda()
sb = torch.nn.Sequential(
SwitchBackGlobalLinear(dim, 4 * dim),
torch.nn.GELU(),
SwitchBackGlobalLinear(4 * dim, dim),
).cuda()
standard_compiled = torch.compile(standard)
print('Model part 2')
repeat = 32
info = {'repeat' : repeat, 'batch_size' : batch, 'dim' : dim}
# k = 'standard'
# for _ in range(repeat // 2):
# with torch.cuda.amp.autocast():
# out_standard = standard(x1)
# ((2 ** 16) * out_standard).abs().mean().backward()
# torch.cuda.synchronize()
# start = time.time()
# for _ in range(repeat):
# with torch.cuda.amp.autocast():
# out_standard = standard(x1)
# ((2 ** 16) * out_standard).abs().mean().backward()
# torch.cuda.synchronize()
# end = time.time()
# ms = (end - start) / repeat * 1000
# print(f"time {k}: {ms:.3f} ms")
# info[k] = ms
# x1.grad.zero_()
# k = 'my_standard'
# for _ in range(repeat // 2):
# with torch.cuda.amp.autocast():
# out_my_standard = my_standard(x1)
# ((2 ** 16) * out_my_standard).abs().mean().backward()
# torch.cuda.synchronize()
# start = time.time()
# for _ in range(repeat):
# with torch.cuda.amp.autocast():
# out_my_standard = my_standard(x1)
# ((2 ** 16) * out_my_standard).abs().mean().backward()
# torch.cuda.synchronize()
# end = time.time()
# ms = (end - start) / repeat * 1000
# print(f"time {k}: {ms:.3f} ms")
# info[k] = ms
# x1.grad.zero_()
# k = 'standard_compiled'
# for _ in range(repeat // 2):
# with torch.cuda.amp.autocast():
# out_standard_compiled = standard_compiled(x1)
# ((2 ** 16) * out_standard_compiled).abs().mean().backward()
# torch.cuda.synchronize()
# start = time.time()
# for _ in range(repeat):
# with torch.cuda.amp.autocast():
# out_standard_compiled = standard_compiled(x1)
# ((2 ** 16) * out_standard_compiled).abs().mean().backward()
# torch.cuda.synchronize()
# end = time.time()
# ms = (end - start) / repeat * 1000
# print(f"time {k}: {ms:.3f} ms")
# info[k] = ms
# x1.grad.zero_()
k = 'sb'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_sb = sb(x1)
((2 ** 16) * out_sb).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_sb = sb(x1)
((2 ** 16) * out_sb).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
info_json = json.dumps(info)
with open("tests/triton_tests/info_mlp_autocast.jsonl", "a") as file:
file.write(info_json + "\n")
#exit()
# err_fused = (out_standard - out_fused).abs().mean()
# err_sb = (out_standard - out_sb).abs().mean()
# print('OUT', err_fused, err_sb)
# err_fused = (standard[d].weight.grad - fused_mlp.linear2.weight.grad).abs().mean()
# err_sb = (standard[d].weight.grad - sb[d].weight.grad).abs().mean()
# print('GW2', err_fused, err_sb)
# err_fused = (standard[0].weight.grad - fused_mlp.linear1.weight.grad).abs().mean()
# err_sb = (standard[0].weight.grad - sb[0].weight.grad).abs().mean()
# print('GW1', err_fused, err_sb)
# err_fused = (x1.grad - x2.grad).abs().mean()
# err_sb = (x1.grad - x3.grad).abs().mean()
# print('GX1', err_fused, err_sb)
# import pdb; pdb.set_trace()
# # NO GELU, ST GRADIENTS, EVERYTHING FINE.

View File

@ -1,165 +0,0 @@
import torch
import json
from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, StandardLinear
import time
if __name__ == '__main__':
print('Startin')
for dim in [1024, 1280, 1408, 1664, 2048]:
for batch in [2**14, 2**15, 2**16, 2**17]:
x1 = torch.randn(batch, dim).cuda().requires_grad_(True)
d = 2
standard = torch.nn.Sequential(
torch.nn.LayerNorm(dim),
torch.nn.Linear(dim, 4 * dim),
torch.nn.GELU(),
torch.nn.Linear(4 * dim, dim),
).cuda()
my_standard = torch.nn.Sequential(
torch.nn.LayerNorm(dim),
StandardLinear(dim, 4 * dim),
torch.nn.GELU(),
StandardLinear(4 * dim, dim),
).cuda()
fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda()
sb = torch.nn.Sequential(
torch.nn.LayerNorm(dim),
SwitchBackGlobalLinear(dim, 4 * dim),
torch.nn.GELU(),
SwitchBackGlobalLinear(4 * dim, dim),
).cuda()
standard_compiled = torch.compile(standard)
print('Model part 2')
repeat = 32
info = {'repeat' : repeat, 'batch_size' : batch, 'dim' : dim}
k = 'standard'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_standard = standard(x1)
((2 ** 16) * out_standard).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_standard = standard(x1)
((2 ** 16) * out_standard).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
x1.grad.zero_()
k = 'my_standard'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_my_standard = my_standard(x1)
((2 ** 16) * out_my_standard).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_my_standard = my_standard(x1)
((2 ** 16) * out_my_standard).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
x1.grad.zero_()
k = 'standard_compiled'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_standard_compiled = standard_compiled(x1)
((2 ** 16) * out_standard_compiled).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_standard_compiled = standard_compiled(x1)
((2 ** 16) * out_standard_compiled).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
x1.grad.zero_()
k = 'sb'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_sb = sb(x1)
((2 ** 16) * out_sb).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_sb = sb(x1)
((2 ** 16) * out_sb).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
info_json = json.dumps(info)
with open("tests/triton_tests/info_mlp_autocast_ln.jsonl", "a") as file:
file.write(info_json + "\n")
#exit()
# err_fused = (out_standard - out_fused).abs().mean()
# err_sb = (out_standard - out_sb).abs().mean()
# print('OUT', err_fused, err_sb)
# err_fused = (standard[d].weight.grad - fused_mlp.linear2.weight.grad).abs().mean()
# err_sb = (standard[d].weight.grad - sb[d].weight.grad).abs().mean()
# print('GW2', err_fused, err_sb)
# err_fused = (standard[0].weight.grad - fused_mlp.linear1.weight.grad).abs().mean()
# err_sb = (standard[0].weight.grad - sb[0].weight.grad).abs().mean()
# print('GW1', err_fused, err_sb)
# err_fused = (x1.grad - x2.grad).abs().mean()
# err_sb = (x1.grad - x3.grad).abs().mean()
# print('GX1', err_fused, err_sb)
# import pdb; pdb.set_trace()
# # NO GELU, ST GRADIENTS, EVERYTHING FINE.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

Binary file not shown.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 51 KiB

View File

@ -1,69 +0,0 @@
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import matplotlib.gridspec as gridspec
cmap=plt.get_cmap('cool')
if __name__ == '__main__':
fig = plt.figure(tight_layout=True, figsize=(6,3.5))
gs = gridspec.GridSpec(1, 1)
rdf = pd.read_json('tests/triton_tests/info.jsonl', lines=True)
ax = fig.add_subplot(gs[0, 0])
# now plot the % speedup for different batch sizes
for j, batch_size in enumerate([2**14, 2**15, 2**16, 2**17]):
all_xs, all_ys = [], []
for k, marker, ls, color, name in [
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
]:
xs, ys = [], []
df = rdf[rdf.batch_size == batch_size]
for embed_dim in [1024, 1280, 1408, 1664, 2048, 4096]:
df_ = df[df.dim_in == embed_dim]
df_ = df_[df_.dim_out == embed_dim * 4]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
y_ += df_[k_].values[0]
df_ = df[df.dim_in == embed_dim * 4]
df_ = df_[df_.dim_out == embed_dim]
for k_ in k.split('+'):
y_ += df_[k_].values[0]
ys.append(y_ * 0.5)
all_xs.append(xs)
all_ys.append(ys)
color = cmap(j * 0.25)
real_ys = [100 * all_ys[1][i] / all_ys[0][i] for i in range(len(all_ys[0]))]
markers = ['^', 'v', 'P', 'o']
ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5)
ax.legend()
ax.set_xlabel('dim', fontsize=13)
ax.set_xscale('log')
ax.grid()
ax.set_ylabel(r'% time occupied by quantize ops', fontsize=12)
ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.set_xticks([1024, 2048, 4096])
ax.set_xticklabels([1024, 2048, 4096])
ax.set_xticks([], minor=True)
#ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20)
plt.savefig('tests/triton_tests/plot2.pdf', bbox_inches='tight')

Binary file not shown.

Binary file not shown.

Before

Width:  |  Height:  |  Size: 57 KiB

View File

@ -1,193 +0,0 @@
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import matplotlib.lines as mlines
import matplotlib.gridspec as gridspec
cmap=plt.get_cmap('cool')
if __name__ == '__main__':
fig = plt.figure(tight_layout=True, figsize=(12,3.5))
gs = gridspec.GridSpec(1, 3)
rdf1 = pd.read_json('tests/triton_tests/info_mlp_autocast_ln.jsonl', lines=True)
ax = fig.add_subplot(gs[0, 0])
# now plot the % speedup for different batch sizes
for j, batch_size in enumerate([2**15, 2**17]):#, 2**15, 2**17, 2**17]):
all_xs, all_ys = {}, {}
for k, marker, ls, color, name in [
('standard_compiled', 'o', '-', 'C0', 'standard compiled (total time)'),
#('standard', 'o', '-', 'C1', 'standard (total time)'),
('my_standard', 'o', '-', 'C2', 'my standard (total time)'),
('sb', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
]:
xs, ys = [], []
df = rdf1[rdf1.batch_size == batch_size]
for embed_dim in [1024, 1280, 1408, 1664, 2048]:
df_ = df[df.dim == embed_dim]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
y_ += df_[k_].values[0]
ys.append(y_)
all_xs[k] = xs
all_ys[k] = ys
#ax.plot(xs, ys, color=color, label=f'batch * sequence length = {batch_size}', marker=marker, markersize=5 if marker=='s' else 5)
color= cmap(float(j))
speedup_over_my_standard = [-100 * (all_ys['sb'][i] - all_ys['my_standard'][i]) / all_ys['my_standard'][i] for i in range(len(all_ys['my_standard']))]
speedup_over_compile = [-100 * (all_ys['sb'][i] - all_ys['standard_compiled'][i]) / all_ys['standard_compiled'][i] for i in range(len(all_ys['standard_compiled']))]
ax.plot(xs, speedup_over_my_standard, color=color, label=f'batch * sequence length = {batch_size}', marker='o', markersize=5 if marker=='s' else 5)
ax.plot(xs, speedup_over_compile, color=color, label=f'batch * sequence length = {batch_size}', marker='o', markersize=5 if marker=='s' else 5, linestyle='--')
#ax.legend()
ax.set_xlabel('dim', fontsize=13)
ax.set_xscale('log')
ax.grid()
ax.set_ylabel(r'% speedup', fontsize=12)
ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.set_xticks([1024, 2048])
ax.set_xticklabels([1024, 2048])
ax.set_xticks([], minor=True)
ax.set_title('MLP Block', fontsize=10, loc='left', y=1.07, pad=-20)
##########################################
rdf2 = pd.read_json('tests/triton_tests/attn_info_ln.jsonl', lines=True)
ax = fig.add_subplot(gs[0, 1])
for j, batch_size in enumerate([2**15, 2**17]):#, 2**15, 2**17, 2**17]):
all_xs, all_ys = {}, {}
for k, marker, ls, color, name in [
('standard_compiled', 'o', '-', 'C0', 'standard compiled (total time)'),
#('standard', 'o', '-', 'C1', 'standard (total time)'),
('my_standard', 'o', '-', 'C2', 'my standard (total time)'),
('sb', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
]:
xs, ys = [], []
df = rdf2[rdf2.batch_size == batch_size]
for embed_dim in [1024, 1280, 1408, 1664, 2048]:
df_ = df[df.dim == embed_dim]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
y_ += df_[k_].values[0]
ys.append(y_)
all_xs[k] = xs
all_ys[k] = ys
#ax.plot(xs, ys, color=color, label=f'batch * sequence length = {batch_size}', marker=marker, markersize=5 if marker=='s' else 5)
color= cmap(float(j))
speedup_over_my_standard = [-100 * (all_ys['sb'][i] - all_ys['my_standard'][i]) / all_ys['my_standard'][i] for i in range(len(all_ys['my_standard']))]
speedup_over_compile = [-100 * (all_ys['sb'][i] - all_ys['standard_compiled'][i]) / all_ys['standard_compiled'][i] for i in range(len(all_ys['standard_compiled']))]
ax.plot(xs, speedup_over_my_standard, color=color, label=f'batch * sequence length = {batch_size}', marker='o', markersize=5 if marker=='s' else 5)
ax.plot(xs, speedup_over_compile, color=color, label=f'batch * sequence length = {batch_size}', marker='o', markersize=5 if marker=='s' else 5, linestyle='--')
speedup_compiled = mlines.Line2D([], [], linestyle='--', color='gray', label='speedup over compiled')
speedup_baseline = mlines.Line2D([], [], linestyle='-', color='gray', label='speedup over baseline')
batch_size_4 = mlines.Line2D([], [], linestyle='-', color=cmap(0.), label=f'batch = {int(2**15 // 256)}, sequence = {256}')
batch_size_8 = mlines.Line2D([], [], linestyle='-', color=cmap(1.), label=f'batch = {int(2**17 / 256)} sequence = {256}')
# Create the legend with the proxy artists
# adjust plots so that they dont get squished by putting the legend under both
plt.subplots_adjust(left=0.2)
plt.subplots_adjust(right=0.8)
fig.legend(handles=[speedup_compiled, speedup_baseline, batch_size_4, batch_size_8], ncol=2, loc='upper center', bbox_to_anchor=(0.35, 0.255))
ax.set_xlabel('dim', fontsize=13)
ax.set_xscale('log')
ax.grid()
ax.set_ylabel(r'% speedup', fontsize=12)
ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.set_xticks([1024, 2048])
ax.set_xticklabels([1024, 2048])
ax.set_xticks([], minor=True)
ax.set_title('Attention Block', fontsize=10, loc='left', y=1.07, pad=-20)
##########################################
ax = fig.add_subplot(gs[0, 2])
for j, batch_size in enumerate([2**15]):#, 2**15, 2**17, 2**17]):
all_xs, all_ys = {}, {}
for k, marker, ls, color, name, b in [
('standard_compiled', 'o', '-', 'C0', 'standard compiled (total time)', False),
('standard_compiled', 'o', '-', 'C0', 'standard compiled (total time)', True),
#('standard', 'o', '-', 'C1', 'standard (total time)'),
#('my_standard', 'o', '-', 'C2', 'my standard (total time)'),
('attn', 'o', '-', 'C4', 'SwitchBack int8 (total time)', True),
]:
rdf = rdf2 if b else rdf1
xs, ys = [], []
df = rdf[rdf.batch_size == batch_size]
for embed_dim in [1024, 1280, 1408, 1664, 2048]:
df_ = df[df.dim == embed_dim]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
y_ += df_[k_].values[0]
ys.append(y_)
all_xs[k + str(int(b))] = xs
all_ys[k + str(int(b))] = ys
#ax.plot(xs, ys, color=color, label=f'batch * sequence length = {batch_size}', marker=marker, markersize=5 if marker=='s' else 5)
print(all_ys.keys())
all_ys['standard_compiled'] = [x + y for x, y in zip(all_ys['standard_compiled0'], all_ys['standard_compiled1'])]
speedup_over_my_standard = [100 * all_ys['attn1'][i] / (all_ys['standard_compiled'][i] + all_ys['attn1'][i]) for i in range(len(all_ys['standard_compiled']))]
ax.plot(xs, speedup_over_my_standard, color='gold', label=r'% time occupied by attention', marker='H', markersize=8)
speedup_over_my_standard = [100 * all_ys['standard_compiled1'][i] / (all_ys['standard_compiled0'][i] + all_ys['standard_compiled1'][i]) for i in range(len(all_ys['standard_compiled']))]
ax.plot(xs, speedup_over_my_standard, color='indianred', label=r'% time occupied by attention block', marker='P', markersize=8)
ax.legend(bbox_to_anchor=(1.02, -0.27))
ax.set_xlabel('dim', fontsize=13)
ax.set_xscale('log')
ax.grid()
ax.set_ylabel(r'% time', fontsize=12)
ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.set_xticks([1024, 2048])
ax.set_xticklabels([1024, 2048])
ax.set_xticks([], minor=True)
plt.savefig('tests/triton_tests/plot3.pdf', bbox_inches='tight')

View File

@ -1,43 +0,0 @@
import time
import torch
import torch
import torch.nn as nn
import bitsandbytes.nn as bnn
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear
from bitsandbytes.nn.triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup
# 256 * 256 * 4096 _> 0.7
# 256 * 128 * 8192 -> 10
if __name__ == '__main__':
torch.manual_seed(0)
# hparams
repeat = 16
dim=8192
layers = 4
batch_size = 256 * 128
# simulate forward pass
x = torch.randn(batch_size, dim, dtype=torch.float16).cuda()
for _ in range(repeat // 2):
quantize_rowwise_nogroup(x)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
quantize_rowwise_nogroup(x)
torch.cuda.synchronize()
end = time.time()
print(f"time: {(end - start) / repeat * 1000:.3f} ms")