mem efficient

Mitchell Wortsman 2023-04-08 19:34:18 +00:00
parent eb6c53cf55
commit da524d97c9
2 changed files with 119 additions and 2 deletions

@@ -3,6 +3,7 @@ import torch.nn as nn
import time
from functools import partial
from .triton_utils.v0.dequantize_rowwise import dequantize_rowwise
from .triton_utils.v0.quantize_rowwise import quantize_rowwise
from .triton_utils.v0.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from .triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
@@ -97,6 +98,56 @@ class _switchback_vectorrize(torch.autograd.Function):
            grad_bias = G.sum(dim=0)
        return grad_X, grad_W, grad_bias

class _switchback_global_mem_efficient(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # reshape input to [N * L, D]
        X = X_3D.view(-1, X_3D.size(-1))
        X_3D_sz = X_3D.size()

        # rowwise quantize for X, global quantize for W
        X_int8, state_X = quantize_rowwise(X)
        del X
        W_int8, state_W = quantize_global(W)

        # save only the int8 tensors and their quantization states for backward;
        # this is what makes this variant memory efficient.
        ctx.save_for_backward = X_int8, state_X, W_int8, state_W

        # matmul with fused dequantize and bias add.
        # call "mixed" because we are mixing rowwise quantized and global quantized
        return int8_matmul_mixed_dequanitze(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D_sz[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        # reshape grad output to [N_out * L, D]
        G = G_3D.reshape(-1, G_3D.size(-1))
        G_3D_sz = G_3D.size()

        grad_X = grad_W = grad_bias = None
        X_int8, state_X, W_int8, state_W = ctx.save_for_backward

        if ctx.needs_input_grad[1]:
            # dequantize X on demand instead of keeping a fp16 copy alive
            real_X = dequantize_rowwise(X_int8, state_X)
            del X_int8
            grad_W = torch.matmul(G.t(), real_X.to(G.dtype))
            del real_X
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)
        if ctx.needs_input_grad[0]:
            G_int8, state_G = quantize_rowwise(G)
            del G
            W_int8 = W_int8.t().contiguous()
            grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
                *G_3D_sz[:-1], -1
            )

        return grad_X, grad_W, grad_bias
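
The memory saving comes from what forward caches for backward: int8 X plus one scale per row, rather than the fp16 activations the non-mem-efficient path holds onto. A rough back-of-envelope sketch, not part of this commit (it assumes fp32 rowwise scales and that the non-mem-efficient path keeps the fp16 input alive):

    # bytes cached for backward per linear input (sketch; fp32 scales assumed)
    def cached_activation_bytes(n_rows: int, n_cols: int):
        fp16_bytes = n_rows * n_cols * 2           # fp16 activations
        int8_bytes = n_rows * n_cols + n_rows * 4  # int8 X + one scale per row
        return fp16_bytes, int8_bytes

    # e.g. 4096 tokens with hidden size 4096:
    # fp16 ~ 33.6 MB vs. int8 + scales ~ 16.8 MB, roughly a 2x reduction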

class SwitchBackLinear(nn.Linear):
    def __init__(
@@ -106,7 +157,8 @@ class SwitchBackLinear(nn.Linear):
            bias: bool = True,
            device=None,
            dtype=None,
            vectorize: bool = False,
            mem_efficient: bool = False,
    ):
        super().__init__(in_features, out_features, bias, device, dtype)
@@ -114,8 +166,14 @@ class SwitchBackLinear(nn.Linear):
        self.vectorize = vectorize
        if self.vectorize:
            if mem_efficient:
                raise NotImplementedError('mem_efficient is not supported for vectorize mode.')
            self._fn = _switchback_vectorrize
        else:
            if mem_efficient:
                self._fn = _switchback_global_mem_efficient
            else:
                self._fn = _switchback_global

    def prepare_for_eval(self):
        # If we just want to do eval, we can pre-quantize the weights instead of doing it on the forward pass.
@@ -158,6 +216,7 @@ class SwitchBackLinear(nn.Linear):
        ).view(*x.size()[:-1], -1)


SwitchBackLinearGlobal = partial(SwitchBackLinear, vectorize=False)
SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vectorize=False, mem_efficient=True)
SwitchBackLinearVectorized = partial(SwitchBackLinear, vectorize=True)

# This is just the standard linear function.
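
For reference, a minimal usage sketch of the new mem-efficient variant (not part of this commit; it assumes a CUDA device, and the import path is a guess, since file paths are not shown in this diff):

    import torch
    from bitsandbytes.nn.triton_based_modules import SwitchBackLinearGlobalMemEfficient

    layer = SwitchBackLinearGlobalMemEfficient(
        4096, 4096, device='cuda', dtype=torch.float16
    )
    x = torch.randn(8, 128, 4096, device='cuda', dtype=torch.float16, requires_grad=True)

    out = layer(x)        # forward caches int8 activations + rowwise scales
    out.sum().backward()  # backward dequantizes them on demand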

@@ -0,0 +1,58 @@
import math
import torch
import time
import triton
import triton.language as tl
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

# rowwise dequantize
# TODO: autotune this better.
@triton.autotune(
    configs=[
        triton.Config({}, num_stages=1, num_warps=8),
        triton.Config({}, num_stages=2, num_warps=8),
        triton.Config({}, num_stages=4, num_warps=8),
        triton.Config({}, num_stages=8, num_warps=8),
        triton.Config({}, num_stages=1),
        triton.Config({}, num_stages=2),
        triton.Config({}, num_stages=4),
        triton.Config({}, num_stages=8),
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
    ],
    key=['n_elements']
)
@triton.jit
def _dequantize_rowwise(
    x_ptr,
    state_x,
    output_ptr,
    inv_127,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    P2: tl.constexpr,
):
    # one program per row: load up to P2 lanes, masking off the padding past BLOCK_SIZE
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    arange = tl.arange(0, P2)
    offsets = block_start + arange
    row_mask = arange < BLOCK_SIZE
    x = tl.load(x_ptr + offsets, mask=row_mask)
    max_val = tl.load(state_x + pid)
    # invert the rowwise quantization: int8 value * (row_max / 127)
    output = max_val * x * inv_127
    tl.store(output_ptr + offsets, output, mask=row_mask)


def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
    output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)

    # P2 is the row length rounded up to the next power of two, since
    # tl.arange requires a power-of-two extent; excess lanes are masked off above.
    P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))

    assert x.is_cuda and output.is_cuda
    n_elements = output.numel()
    grid = lambda meta: (x.shape[0],)  # launch one program per row
    _dequantize_rowwise[grid](x, state_x, output, 1. / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
    return output
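
A quick round-trip sanity check for this kernel (a sketch, not part of this commit; it assumes the sibling quantize_rowwise kernel returns an int8 tensor plus a per-row absmax state, as the imports in the module above suggest, and the bitsandbytes.nn.triton_utils.v0 import path is a guess):

    import torch
    from bitsandbytes.nn.triton_utils.v0.quantize_rowwise import quantize_rowwise
    from bitsandbytes.nn.triton_utils.v0.dequantize_rowwise import dequantize_rowwise

    x = torch.randn(64, 1000, device='cuda', dtype=torch.float16)  # 1000 exercises the P2 masking
    x_int8, state_x = quantize_rowwise(x)
    x_hat = dequantize_rowwise(x_int8, state_x)

    # rowwise int8 quantization is lossy: per-element error is on the order of
    # row_max / 127, so compare against a correspondingly loose bound.
    assert (x - x_hat).abs().max() <= state_x.max() / 127

Using a non-power-of-two row length here matters: it checks that the masked lanes between BLOCK_SIZE and P2 neither read nor write out of bounds.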