mem efficient

This commit is contained in:
parent eb6c53cf55
commit da524d97c9
@@ -3,6 +3,7 @@ import torch.nn as nn
 import time
 from functools import partial
 
+from .triton_utils.v0.dequantize_rowwise import dequantize_rowwise
 from .triton_utils.v0.quantize_rowwise import quantize_rowwise
 from .triton_utils.v0.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
 from .triton_utils.v0.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
@@ -97,6 +98,56 @@ class _switchback_vectorrize(torch.autograd.Function):
         grad_bias = G.sum(dim=0)
 
         return grad_X, grad_W, grad_bias
 
+class _switchback_global_mem_efficient(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, X_3D, W, bias):
+        # reshape input to [N * L, D]
+        X = X_3D.view(-1, X_3D.size(-1))
+        X_3D_sz = X_3D.size()
+
+        # rowwise quantize for X, global quantize for W
+        X_int8, state_X = quantize_rowwise(X)
+        del X
+        W_int8, state_W = quantize_global(W)
+
+        print('in mem eff forward.')
+
+        # save for backward.
+        ctx.save_for_backward = X_int8, state_X, W_int8, state_W
+
+        # matmul with fused dequantize and bias add;
+        # called "mixed" because it mixes rowwise-quantized and globally-quantized operands
+        return int8_matmul_mixed_dequanitze(
+            X_int8, W_int8.t(), state_X, state_W, bias
+        ).view(*X_3D_sz[:-1], -1)
+
+    @staticmethod
+    def backward(ctx, G_3D):
+        # reshape grad_output to [N_out * L, D]
+        G = G_3D.reshape(-1, G_3D.size(-1))
+        G_3D_sz = G_3D.size()
+
+        grad_X = grad_W = grad_bias = None
+
+        X_int8, state_X, W_int8, state_W = ctx.save_for_backward
+        if ctx.needs_input_grad[1]:
+            real_X = dequantize_rowwise(X_int8, state_X)
+            del X_int8
+            grad_W = torch.matmul(G.t(), real_X.to(G.dtype))
+            del real_X
+        if ctx.needs_input_grad[2]:
+            grad_bias = G.sum(dim=0)
+        if ctx.needs_input_grad[0]:
+            G_int8, state_G = quantize_rowwise(G)
+            del G
+            W_int8 = W_int8.t().contiguous()
+            grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
+                *G_3D_sz[:-1], -1
+            )
+
+        return grad_X, grad_W, grad_bias
+
 class SwitchBackLinear(nn.Linear):
     def __init__(
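Note on the design: the memory saving comes from keeping only the int8 activations plus their per-row scales alive between forward and backward, instead of the fp16 activations, and dequantizing on demand only when grad_W is actually needed. A minimal sketch of the same idea in plain PyTorch (no Triton; quantize_rowwise_ref is a hypothetical stand-in for the Triton kernel, and bias is assumed to be a tensor):

import torch

def quantize_rowwise_ref(x):
    # symmetric per-row absmax int8 quantization (reference helper, not the Triton kernel)
    state = x.abs().amax(dim=1).clamp(min=1e-8)
    x_int8 = torch.round(x / state.unsqueeze(1) * 127).to(torch.int8)
    return x_int8, state

class MemEfficientLinearSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, W, bias):
        # keep only the int8 copy of X alive between forward and backward
        X_int8, state_X = quantize_rowwise_ref(X)
        ctx.save_for_backward(X_int8, state_X, W)
        return X @ W.t() + bias

    @staticmethod
    def backward(ctx, G):
        X_int8, state_X, W = ctx.saved_tensors
        # dequantize on demand; this is the only place the fp activations are rebuilt
        real_X = X_int8.to(G.dtype) * state_X.unsqueeze(1) / 127
        return G @ W, G.t() @ real_X, G.sum(dim=0)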
@@ -106,7 +157,8 @@ class SwitchBackLinear(nn.Linear):
         bias: bool = True,
         device=None,
         dtype=None,
-        vectorize: bool = False
+        vectorize: bool = False,
+        mem_efficient: bool = False,
     ):
         super().__init__(in_features, out_features, bias, device, dtype)
@@ -114,8 +166,14 @@ class SwitchBackLinear(nn.Linear):
         self.vectorize = vectorize
         if self.vectorize:
             self._fn = _switchback_vectorrize
+            if mem_efficient:
+                print('mem efficient is not supported for vectorize mode.')
+                exit(1)
         else:
-            self._fn = _switchback_global
+            if mem_efficient:
+                self._fn = _switchback_global_mem_efficient
+            else:
+                self._fn = _switchback_global
 
     def prepare_for_eval(self):
         # If we just want to do eval, we can pre-quantize the weights instead of doing it on the forward pass.
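One behavioral consequence worth flagging: the incompatible flag combination is rejected at construction time, not at forward time. A hypothetical illustration:

# vectorize + mem_efficient terminates the process when the layer is built:
layer = SwitchBackLinear(64, 64, vectorize=True, mem_efficient=True)
# prints 'mem efficient is not supported for vectorize mode.' and calls exit(1)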
@@ -158,6 +216,7 @@ class SwitchBackLinear(nn.Linear):
         ).view(*x.size()[:-1], -1)
 
 SwitchBackLinearGlobal = partial(SwitchBackLinear, vectorize=False)
+SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vectorize=False, mem_efficient=True)
 SwitchBackLinearVectorized = partial(SwitchBackLinear, vectorize=True)
 
 # This is just the standard linear function.
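For orientation, a hypothetical usage sketch of the new partial (assumes a CUDA device with Triton available and that the name is exported like its siblings; shapes are illustrative):

import torch

# drop-in replacement for nn.Linear; between forward and backward the input
# activations are held as int8 plus per-row scales, which is the memory saving
linear = SwitchBackLinearGlobalMemEfficient(1024, 4096).cuda().half()
x = torch.randn(8, 128, 1024, device='cuda', dtype=torch.float16, requires_grad=True)
y = linear(x)                # shape [8, 128, 4096]
y.float().mean().backward()  # grad_W path uses dequantize_rowwise internally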
bitsandbytes/nn/triton_utils/v0/dequantize_rowwise.py (new file, 58 lines)

@@ -0,0 +1,58 @@
+import math
+import torch
+import time
+import triton
+import triton.language as tl
+from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+
+# rowwise dequantize
+
+# TODO: autotune this better.
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_stages=1, num_warps=8),
+        triton.Config({}, num_stages=2, num_warps=8),
+        triton.Config({}, num_stages=4, num_warps=8),
+        triton.Config({}, num_stages=8, num_warps=8),
+        triton.Config({}, num_stages=1),
+        triton.Config({}, num_stages=2),
+        triton.Config({}, num_stages=4),
+        triton.Config({}, num_stages=8),
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+    ],
+    key=['n_elements']
+)
+@triton.jit
+def _dequantize_rowwise(
+    x_ptr,
+    state_x,
+    output_ptr,
+    inv_127,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+    P2: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    arange = tl.arange(0, P2)
+    offsets = block_start + arange
+    row_mask = arange < BLOCK_SIZE
+    x = tl.load(x_ptr + offsets, mask=row_mask)
+    max_val = tl.load(state_x + pid)
+    output = max_val * x * inv_127
+    tl.store(output_ptr + offsets, output, mask=row_mask)
+
+
+def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
+    output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
+
+    P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
+
+    assert x.is_cuda and output.is_cuda
+    n_elements = output.numel()
+    grid = lambda meta: (x.shape[0],)
+    _dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
+    return output
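Two implementation details worth noting: the kernel launches one program per row (the grid is the number of rows), loading that row's absmax from state_x; and since tl.arange requires a power-of-two extent, the wrapper rounds the row length up to P2 and the kernel masks off the lanes past BLOCK_SIZE. A hypothetical round-trip check (assumes a CUDA device and the companion quantize_rowwise kernel from the same directory):

import torch
from bitsandbytes.nn.triton_utils.v0.quantize_rowwise import quantize_rowwise
from bitsandbytes.nn.triton_utils.v0.dequantize_rowwise import dequantize_rowwise

x = torch.randn(64, 1000, device='cuda', dtype=torch.float16)  # 1000 rounds up to P2 = 1024
x_int8, state_x = quantize_rowwise(x)
x_rec = dequantize_rowwise(x_int8, state_x)

step = state_x.unsqueeze(1) / 127      # per-row quantization step
print((x - x_rec).abs().max().item())  # should be on the order of half a step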