import torch
import torch.nn as nn
import time
from functools import partial

from bitsandbytes.triton.triton_utils import is_triton_available

from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze


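# A brief illustration of the two quantization schemes used below. This is a sketch of
# the usual absmax int8 convention and not necessarily the kernels' exact bookkeeping:
#
#   rowwise quantization keeps one scale per row of the tensor,
#       state_X[i] = max(|X[i, :]|),    X_int8[i, :] ~= round(127 * X[i, :] / state_X[i])
#
#   global quantization keeps a single scale for the whole tensor,
#       state_W = max(|W|),             W_int8 ~= round(127 * W / state_W)
#
# _switchback_global mixes rowwise-quantized activations with a globally quantized
# weight; _switchback_vectorrize below quantizes the weight rowwise as well (the
# "vector-wise" variant).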
class _switchback_global(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # reshape input to [N * L, D]
        X = X_3D.view(-1, X_3D.size(-1))

        # rowwise quantize for X, global quantize for W
        X_int8, state_X = quantize_rowwise(X)
        W_int8, state_W = quantize_global(W)

        # save for backward. note: X and W are stashed directly on ctx as a plain
        # attribute rather than via ctx.save_for_backward(), so the full-precision
        # tensors are kept alive until backward.
        ctx.save_for_backward = X, W

        # int8 matmul with fused dequantize and bias add.
        # the kernel is called "mixed" because it mixes a rowwise-quantized operand (X)
        # with a globally quantized operand (W).
        return int8_matmul_mixed_dequanitze(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D.size()[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        # reshape grad output to [N_out * L, D]
        G = G_3D.reshape(-1, G_3D.size(-1))

        grad_X = grad_W = grad_bias = None

        X, W = ctx.save_for_backward
        if ctx.needs_input_grad[0]:
            # rowwise quantize for G, global quantize for W.
            # for W, the transpose is fused into the quantization because only A @ B^T is
            # supported, so we transpose once here and then call .t() in the matmul.
            G_int8, state_G = quantize_rowwise(G)
            W_int8, state_W = quantize_global_transpose(W)
            grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
                *G_3D.size()[:-1], -1
            )
        if ctx.needs_input_grad[1]:
            # the weight gradient uses a standard full-precision matmul
            grad_W = torch.matmul(G.t(), X.to(G.dtype))
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)

        return grad_X, grad_W, grad_bias


class _switchback_vectorrize(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # reshape input to [N * L, D]
        X = X_3D.view(-1, X_3D.size(-1))

        # save for backward (stored directly on ctx as a plain attribute).
        ctx.save_for_backward = X, W

        # rowwise quantize for X.
        # for W, quantize rowwise here as well; with the .t() in the matmul call this
        # amounts to a columnwise (per-output-feature) quantization of the weight.
        X_int8, state_X = quantize_rowwise(X)
        W_int8, state_W = quantize_rowwise(W)

        # int8 matmul with fused dequantize and bias add.
        # this kernel expects both operands to be rowwise quantized.
        return int8_matmul_rowwise_dequantize(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D.size()[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        X, W = ctx.save_for_backward

        # reshape grad output to [N_out * L, D]
        G = G_3D.reshape(-1, G_3D.size(-1))

        grad_X = grad_W = grad_bias = None

        if ctx.needs_input_grad[0]:
            # rowwise quantize for G; columnwise quantize for W with a fused transpose.
            # .t() is still called on the weight in the matmul because only A @ B^T is supported.
            G_int8, state_G = quantize_rowwise(G)
            W_int8, state_W = quantize_columnwise_and_transpose(W)
            grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
                *G_3D.size()[:-1], -1
            )
        if ctx.needs_input_grad[1]:
            # the weight gradient uses a standard full-precision matmul
            grad_W = torch.matmul(G.t(), X.to(G.dtype))
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)

        return grad_X, grad_W, grad_bias


class _switchback_global_mem_efficient(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # reshape input to [N * L, D]
        X = X_3D.view(-1, X_3D.size(-1))
        X_3D_sz = X_3D.size()

        # rowwise quantize for X, global quantize for W
        X_int8, state_X = quantize_rowwise(X)
        del X
        W_int8, state_W = quantize_global(W)

        # save for backward: only the int8 tensors and their quantization state are kept,
        # which is what makes this variant memory efficient.
        ctx.save_for_backward = X_int8, state_X, W_int8, state_W

        # int8 matmul with fused dequantize and bias add.
        # the kernel is called "mixed" because it mixes a rowwise-quantized operand (X)
        # with a globally quantized operand (W).
        return int8_matmul_mixed_dequanitze(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D_sz[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        # reshape grad output to [N_out * L, D]
        G = G_3D.reshape(-1, G_3D.size(-1))
        G_3D_sz = G_3D.size()

        grad_X = grad_W = grad_bias = None

        X_int8, state_X, W_int8, state_W = ctx.save_for_backward
        if ctx.needs_input_grad[1]:
            # reconstruct (an approximation of) X from its int8 form for the weight gradient
            real_X = dequantize_rowwise(X_int8, state_X)
            del X_int8
            grad_W = torch.matmul(G.t(), real_X.to(G.dtype))
            del real_X
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)
        if ctx.needs_input_grad[0]:
            G_int8, state_G = quantize_rowwise(G)
            del G
            W_int8 = W_int8.t().contiguous()
            grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
                *G_3D_sz[:-1], -1
            )

        return grad_X, grad_W, grad_bias


class SwitchBackLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        vector_wise_quantization: bool = False,
        mem_efficient: bool = False,
    ):
        super().__init__(in_features, out_features, bias, device, dtype)

        if not is_triton_available():
            raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear.
                              Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')

        # By default, we use global quantization.
        self.vector_wise_quantization = vector_wise_quantization
        if self.vector_wise_quantization:
            self._fn = _switchback_vectorrize
            if mem_efficient:
                raise ValueError('mem_efficient is not supported for vector-wise quantization.')
        else:
            if mem_efficient:
                self._fn = _switchback_global_mem_efficient
            else:
                self._fn = _switchback_global

    def prepare_for_eval(self):
        # If we just want to do eval, we can pre-quantize the weights instead of doing it on the forward pass.
        # Note this is experimental and not tested thoroughly.
        # Note this needs to be explicitly called with something like
        #     def cond_prepare(m):
        #         if hasattr(m, "prepare_for_eval"):
        #             m.prepare_for_eval()
        #     model.apply(cond_prepare)
        print('=> preparing for eval.')
        if self.vector_wise_quantization:
            W_int8, state_W = quantize_rowwise(self.weight)
        else:
            W_int8, state_W = quantize_global(self.weight)

        self.register_buffer("W_int8", W_int8)
        self.register_buffer("state_W", state_W)

        del self.weight

    def forward(self, x):
        if self.training:
            return self._fn.apply(x, self.weight, self.bias)
        else:
            # If it hasn't been "prepared for eval", run the standard forward pass.
            if not hasattr(self, "W_int8"):
                return self._fn.apply(x, self.weight, self.bias)

            # Otherwise, use the pre-quantized weights.
            X = x.view(-1, x.size(-1))
            X_int8, state_X = quantize_rowwise(X)

            if self.vector_wise_quantization:
                return int8_matmul_rowwise_dequantize(
                    X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
                ).view(*x.size()[:-1], -1)
            else:
                return int8_matmul_mixed_dequanitze(
                    X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
                ).view(*x.size()[:-1], -1)


SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)


# This is just the standard linear function.
class StandardLinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        X = input.view(-1, input.size(-1))

        ctx.save_for_backward(X, weight, bias)
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output.view(*input.size()[:-1], -1)

    @staticmethod
    def backward(ctx, grad_output_3D):
        input, weight, bias = ctx.saved_tensors

        grad_output = grad_output_3D.reshape(-1, grad_output_3D.size(-1))

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight.to(grad_output.dtype)).view(*grad_output_3D.size()[:-1], -1)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().matmul(input.to(grad_output.dtype))
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias


class StandardLinear(nn.Linear):
    def forward(self, x):
        return StandardLinearFunction.apply(x, self.weight, self.bias)
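

# Minimal usage sketch (an illustration, not part of the library API). The shapes,
# dtype, and device below are assumptions for the example only; SwitchBackLinear is a
# drop-in replacement for nn.Linear and requires a CUDA device with Triton installed.
if __name__ == "__main__":
    if torch.cuda.is_available() and is_triton_available():
        layer = SwitchBackLinear(64, 128).cuda().half()
        x = torch.randn(4, 16, 64, device="cuda", dtype=torch.float16, requires_grad=True)

        # training path: int8 matmul with fused dequantize and bias add via the
        # custom autograd.Function selected in __init__
        out = layer(x)
        out.sum().backward()

        # inference path: pre-quantize the weight once, then reuse it
        layer.eval()
        layer.prepare_for_eval()
        with torch.no_grad():
            _ = layer(x)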