mem efficient"
This commit is contained in:
parent
eb6c53cf55
commit
da524d97c9
|
@@ -3,6 +3,7 @@ import torch.nn as nn
 import time
 from functools import partial

+from .triton_utils.v0.dequantize_rowwise import dequantize_rowwise
 from .triton_utils.v0.quantize_rowwise import quantize_rowwise
 from .triton_utils.v0.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
 from .triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
@@ -98,6 +99,56 @@ class _switchback_vectorrize(torch.autograd.Function):
         return grad_X, grad_W, grad_bias

+
+class _switchback_global_mem_efficient(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, X_3D, W, bias):
+        # reshape input to [N * L, D]
+        X = X_3D.view(-1, X_3D.size(-1))
+        X_3D_sz = X_3D.size()
+
+        # rowwise quantize for X, global quantize for W
+        X_int8, state_X = quantize_rowwise(X)
+        del X
+        W_int8, state_W = quantize_global(W)
+
+        print('in mem eff backward.')
+
+        # save for backward.
+        ctx.save_for_backward = X_int8, state_X, W_int8, state_W
+
+        # matmul, fused dequant and add bias
+        # call "mixed" because we are mixing rowwise quantized and global quantized
+        return int8_matmul_mixed_dequanitze(
+            X_int8, W_int8.t(), state_X, state_W, bias
+        ).view(*X_3D_sz[:-1], -1)
+
+    @staticmethod
+    def backward(ctx, G_3D):
+        # reshape input to [N_out * L, D]
+        G = G_3D.reshape(-1, G_3D.size(-1))
+        G_3D_sz = G_3D.size()
+
+        grad_X = grad_W = grad_bias = None
+
+        X_int8, state_X, W_int8, state_W = ctx.save_for_backward
+        if ctx.needs_input_grad[1]:
+            real_X = dequantize_rowwise(X_int8, state_X)
+            del X_int8
+            grad_W = torch.matmul(G.t(), real_X.to(G.dtype))
+            del real_X
+        if ctx.needs_input_grad[2]:
+            grad_bias = G.sum(dim=0)
+        if ctx.needs_input_grad[0]:
+            G_int8, state_G = quantize_rowwise(G)
+            del G
+            W_int8 = W_int8.t().contiguous()
+            grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
+                *G_3D_sz[:-1], -1
+            )
+
+        return grad_X, grad_W, grad_bias
+
 class SwitchBackLinear(nn.Linear):
     def __init__(
         self,
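The memory saving comes from what forward stashes for backward: instead of the fp16 activations X (two bytes per element), it keeps the int8 copy plus one scale per row, roughly halving activation memory, and backward reconstructs X via dequantize_rowwise only when grad_W is actually needed. A minimal PyTorch-only sketch of the same save-int8-for-backward pattern (illustrative names, not the kernels above):

import torch

class Int8SavedLinear(torch.autograd.Function):
    # Sketch: rowwise absmax quantization of x, saved as int8 + per-row scales.
    @staticmethod
    def forward(ctx, x, w):
        scale = x.abs().amax(dim=1, keepdim=True).clamp_min(1e-8)
        x_int8 = (x * (127.0 / scale)).round().to(torch.int8)
        ctx.save_for_backward(x_int8, scale, w)       # int8 + scale, not fp16 x
        return x @ w.t()

    @staticmethod
    def backward(ctx, g):
        x_int8, scale, w = ctx.saved_tensors
        x_hat = x_int8.to(g.dtype) * (scale / 127.0)  # dequantize on demand
        return g @ w, g.t() @ x_hat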
@@ -106,7 +157,8 @@ class SwitchBackLinear(nn.Linear):
         bias: bool = True,
         device=None,
         dtype=None,
-        vectorize: bool = False
+        vectorize: bool = False,
+        mem_efficient: bool = False,
     ):
         super().__init__(in_features, out_features, bias, device, dtype)
@@ -114,8 +166,14 @@ class SwitchBackLinear(nn.Linear):
         self.vectorize = vectorize
         if self.vectorize:
             self._fn = _switchback_vectorrize
+            if mem_efficient:
+                print('mem efficient is not supported for vectorize mode.')
+                exit(1)
 else:
-            self._fn = _switchback_global
+            if mem_efficient:
+                self._fn = _switchback_global_mem_efficient
+            else:
+                self._fn = _switchback_global

     def prepare_for_eval(self):
         # If we just want to do eval, we can pre-quantize the weights instead of doing it on the forward pass.
@@ -158,6 +216,7 @@ class SwitchBackLinear(nn.Linear):
         ).view(*x.size()[:-1], -1)

 SwitchBackLinearGlobal = partial(SwitchBackLinear, vectorize=False)
+SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vectorize=False, mem_efficient=True)
 SwitchBackLinearVectorized = partial(SwitchBackLinear, vectorize=True)

 # This is just the standard linear function.
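For reference, a hedged usage sketch of the new flag (assumes the partial is importable from bitsandbytes.nn and a CUDA device is available; the sizes are made up):

import torch
from bitsandbytes.nn import SwitchBackLinearGlobalMemEfficient

linear = SwitchBackLinearGlobalMemEfficient(1024, 4096).cuda().half()
x = torch.randn(8, 128, 1024, device='cuda', dtype=torch.float16, requires_grad=True)
y = linear(x)        # forward: int8 matmul with fused dequantize + bias add
y.sum().backward()   # backward: dequantize_rowwise recovers X for grad_W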
58  bitsandbytes/nn/triton_utils/v0/dequantize_rowwise.py  Normal file

@@ -0,0 +1,58 @@
+import math
+import torch
+import time
+import triton
+import triton.language as tl
+from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+
+# rowwise dequantize
+
+# TODO: autotune this better.
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_stages=1, num_warps=8),
+        triton.Config({}, num_stages=2, num_warps=8),
+        triton.Config({}, num_stages=4, num_warps=8),
+        triton.Config({}, num_stages=8, num_warps=8),
+        triton.Config({}, num_stages=1),
+        triton.Config({}, num_stages=2),
+        triton.Config({}, num_stages=4),
+        triton.Config({}, num_stages=8),
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+    ],
+    key=['n_elements']
+)
+@triton.jit
+def _dequantize_rowwise(
+    x_ptr,
+    state_x,
+    output_ptr,
+    inv_127,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+    P2: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    arange = tl.arange(0, P2)
+    offsets = block_start + arange
+    row_mask = arange < BLOCK_SIZE
+    x = tl.load(x_ptr + offsets, mask=row_mask)
+    max_val = tl.load(state_x + pid)
+    output = max_val * x * inv_127
+    tl.store(output_ptr + offsets, output, mask=row_mask)
+
+
+def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
+    output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
+
+    P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
+
+    assert x.is_cuda and output.is_cuda
+    n_elements = output.numel()
+    grid = lambda meta: (x.shape[0],)
+    _dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
+    return output
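As a quick sanity check for the new file, a round-trip sketch (assumes the companion quantize_rowwise in the same v0 utils returns the int8 tensor plus per-row absmax state, and a CUDA device):

import torch
from bitsandbytes.nn.triton_utils.v0.quantize_rowwise import quantize_rowwise
from bitsandbytes.nn.triton_utils.v0.dequantize_rowwise import dequantize_rowwise

x = torch.randn(16, 1024, device='cuda', dtype=torch.float16)
x_int8, state_x = quantize_rowwise(x)        # int8 values + per-row absmax
x_hat = dequantize_rowwise(x_int8, state_x)  # x_hat ≈ x up to rounding error
print((x - x_hat).abs().max())               # bounded by half a quantization step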