bitsandbytes-rocm/bitsandbytes/nn/triton_based_modules.py

259 lines
9.6 KiB
Python
Raw Normal View History

2023-03-29 06:47:08 +00:00
import torch
import torch.nn as nn
import time
2023-04-01 18:46:04 +00:00
from functools import partial
2023-03-29 06:47:08 +00:00
from bitsandbytes.triton.triton_utils import is_triton_available
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
2023-03-29 06:47:08 +00:00
2023-04-01 18:46:04 +00:00
class _switchback_global(torch.autograd.Function):
    """SwitchBack linear op: row-wise quantized input, globally quantized weight.

    Forward and the input-gradient matmul run through the int8 Triton kernels with
    fused dequantization; the weight gradient uses an ordinary matmul in the
    gradient's dtype.
    """

    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # Flatten all leading dims to [N * L, D]. ``reshape`` (rather than ``view``)
        # also accepts non-contiguous activations; it only copies when a view is
        # impossible, so contiguous inputs behave exactly as before.
        X = X_3D.reshape(-1, X_3D.size(-1))

        # rowwise quantize for X, global quantize for W
        X_int8, state_X = quantize_rowwise(X)
        W_int8, state_W = quantize_global(W)

        # Save the full-precision operands through the autograd API (not by
        # assigning to the ctx attribute) so autograd can track them and detect
        # in-place modification before backward runs.
        ctx.save_for_backward(X, W)

        # int8 matmul with fused dequantize and bias add. "mixed" because X is
        # row-wise quantized while W is globally quantized.
        return int8_matmul_mixed_dequanitze(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D.size()[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        # Flatten the incoming gradient to [N * L, D_out].
        G = G_3D.reshape(-1, G_3D.size(-1))
        grad_X = grad_W = grad_bias = None

        X, W = ctx.saved_tensors

        if ctx.needs_input_grad[0]:
            # Row-wise quantize G, global quantize W with a fused transpose: the
            # kernel only supports A @ B^T, so W is transposed once here and
            # ``.t()`` is applied at the call site.
            G_int8, state_G = quantize_rowwise(G)
            W_int8, state_W = quantize_global_transpose(W)
            grad_X = int8_matmul_mixed_dequanitze(
                G_int8, W_int8.t(), state_G, state_W, None
            ).view(*G_3D.size()[:-1], -1)
        if ctx.needs_input_grad[1]:
            # Weight gradient uses a standard (non-quantized) matmul.
            grad_W = torch.matmul(G.t(), X.to(G.dtype))
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)
        return grad_X, grad_W, grad_bias
2023-04-01 18:46:04 +00:00
class _switchback_vectorrize(torch.autograd.Function):
    """SwitchBack linear op with row-wise quantization of both input and weight.

    Forward and the input-gradient matmul run through the int8 row-wise Triton
    kernel with fused dequantization; the weight gradient uses an ordinary matmul
    in the gradient's dtype.
    """

    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # Flatten all leading dims to [N * L, D]; ``reshape`` also accepts
        # non-contiguous activations and is a view whenever ``view`` would succeed.
        X = X_3D.reshape(-1, X_3D.size(-1))

        # Save through the autograd API (not by overwriting the ctx attribute) so
        # the tensors are tracked and in-place modification is detected.
        ctx.save_for_backward(X, W)

        # rowwise quantize for X
        # columnwise quantize for W (first rowwise, transpose later)
        X_int8, state_X = quantize_rowwise(X)
        W_int8, state_W = quantize_rowwise(W)

        # int8 matmul with fused dequantize and bias add; the kernel expects
        # row-wise quantized X and W.
        return int8_matmul_rowwise_dequantize(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D.size()[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        X, W = ctx.saved_tensors

        # Flatten the incoming gradient to [N * L, D_out].
        G = G_3D.reshape(-1, G_3D.size(-1))
        grad_X = grad_W = grad_bias = None

        if ctx.needs_input_grad[0]:
            # Row-wise quantize G; column-wise quantize W with a fused transpose.
            # The kernel only supports A @ B^T, so ``.t()`` is applied at the call.
            G_int8, state_G = quantize_rowwise(G)
            W_int8, state_W = quantize_columnwise_and_transpose(W)
            grad_X = int8_matmul_rowwise_dequantize(
                G_int8, W_int8.t(), state_G, state_W, None
            ).view(*G_3D.size()[:-1], -1)
        if ctx.needs_input_grad[1]:
            # Weight gradient uses a standard (non-quantized) matmul.
            grad_W = torch.matmul(G.t(), X.to(G.dtype))
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)
        return grad_X, grad_W, grad_bias
2023-04-08 19:34:18 +00:00
class _switchback_global_mem_efficient(torch.autograd.Function):
    """Memory-saving variant of the global SwitchBack op.

    Instead of keeping the full-precision input for backward, it saves the int8
    row-wise quantized input (plus its state) and the globally quantized weight.
    The weight gradient is then computed from the dequantized int8 copy of X.
    """

    @staticmethod
    def forward(ctx, X_3D, W, bias):
        # Flatten all leading dims to [N * L, D]; ``reshape`` also accepts
        # non-contiguous activations.
        X = X_3D.reshape(-1, X_3D.size(-1))
        X_3D_sz = X_3D.size()

        # rowwise quantize for X, global quantize for W
        X_int8, state_X = quantize_rowwise(X)
        # Drop the local full-precision reference as soon as it is quantized.
        del X
        W_int8, state_W = quantize_global(W)

        # Save the quantized tensors through the autograd API rather than by
        # overwriting the ctx attribute. The state tensors are the same kind of
        # values SwitchBackLinear registers as buffers, i.e. tensors.
        ctx.save_for_backward(X_int8, state_X, W_int8, state_W)

        # int8 matmul with fused dequantize and bias add ("mixed": row-wise X,
        # global W).
        return int8_matmul_mixed_dequanitze(
            X_int8, W_int8.t(), state_X, state_W, bias
        ).view(*X_3D_sz[:-1], -1)

    @staticmethod
    def backward(ctx, G_3D):
        # Flatten the incoming gradient to [N * L, D_out].
        G = G_3D.reshape(-1, G_3D.size(-1))
        G_3D_sz = G_3D.size()
        grad_X = grad_W = grad_bias = None

        X_int8, state_X, W_int8, state_W = ctx.saved_tensors

        if ctx.needs_input_grad[1]:
            # Reconstruct X from its int8 copy for the weight gradient.
            real_X = dequantize_rowwise(X_int8, state_X)
            del X_int8
            grad_W = torch.matmul(G.t(), real_X.to(G.dtype))
            del real_X
        if ctx.needs_input_grad[2]:
            grad_bias = G.sum(dim=0)
        if ctx.needs_input_grad[0]:
            G_int8, state_G = quantize_rowwise(G)
            del G
            # Materialize W^T contiguously so the A @ B^T kernel computes G @ W.
            W_int8 = W_int8.t().contiguous()
            grad_X = int8_matmul_mixed_dequanitze(
                G_int8, W_int8.t(), state_G, state_W, None
            ).view(*G_3D_sz[:-1], -1)
        return grad_X, grad_W, grad_bias
2023-03-29 06:47:08 +00:00
2023-04-01 18:46:04 +00:00
class SwitchBackLinear(nn.Linear):
    """``nn.Linear`` whose matmuls run through the SwitchBack int8 Triton kernels.

    Args:
        in_features, out_features, bias, device, dtype: forwarded to ``nn.Linear``.
        vector_wise_quantization: if True, quantize the weight row-wise
            (``_switchback_vectorrize``); otherwise use global weight quantization.
        mem_efficient: with global quantization, save quantized activations for
            backward (``_switchback_global_mem_efficient``). Not supported together
            with ``vector_wise_quantization``.

    Raises:
        ImportError: if triton is not available.
        ValueError: if ``mem_efficient`` is combined with vector-wise quantization.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        vector_wise_quantization: bool = False,
        mem_efficient: bool = False,
    ):
        super().__init__(in_features, out_features, bias, device, dtype)

        # ``is_triton_available`` comes from triton_utils. A bare function object
        # is always truthy, so call it when it is callable; otherwise treat it as a
        # flag. Without the call, this guard could never fire.
        triton_ok = is_triton_available() if callable(is_triton_available) else is_triton_available
        if not triton_ok:
            raise ImportError(
                "Could not import triton. Please install triton to use SwitchBackLinear.\n"
                "Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower"
            )

        # By default, we use the global quantization.
        self.vector_wise_quantization = vector_wise_quantization
        if self.vector_wise_quantization:
            if mem_efficient:
                # Raise instead of print + exit(): library code must not terminate
                # the host process.
                raise ValueError('mem efficient is not supported for vector-wise quantization.')
            self._fn = _switchback_vectorrize
        elif mem_efficient:
            self._fn = _switchback_global_mem_efficient
        else:
            self._fn = _switchback_global

    def prepare_for_eval(self):
        """Pre-quantize the weight for inference and drop the float weight.

        Experimental. Must be called explicitly, e.g.::

            def cond_prepare(m):
                if hasattr(m, "prepare_for_eval"):
                    m.prepare_for_eval()
            model.apply(cond_prepare)

        After this call ``self.weight`` no longer exists, so the module can only
        run its eval-mode forward.
        """
        print('=> preparing for eval.')
        # Quantize without autograd so the buffers carry no graph references back
        # to the weight that is deleted below.
        with torch.no_grad():
            if self.vector_wise_quantization:
                W_int8, state_W = quantize_rowwise(self.weight)
            else:
                W_int8, state_W = quantize_global(self.weight)

        self.register_buffer("W_int8", W_int8)
        self.register_buffer("state_W", state_W)

        del self.weight

    def forward(self, x):
        # Training, or eval before prepare_for_eval(): quantize on the fly.
        if self.training or not hasattr(self, "W_int8"):
            return self._fn.apply(x, self.weight, self.bias)

        # Otherwise, use the pre-quantized weights. ``reshape`` also accepts
        # non-contiguous activations.
        X = x.reshape(-1, x.size(-1))
        X_int8, state_X = quantize_rowwise(X)

        if self.vector_wise_quantization:
            return int8_matmul_rowwise_dequantize(
                X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
            ).view(*x.size()[:-1], -1)
        return int8_matmul_mixed_dequanitze(
            X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
        ).view(*x.size()[:-1], -1)
2023-03-29 06:47:08 +00:00
# Pre-configured SwitchBackLinear constructors for each quantization mode.
SwitchBackLinearGlobal, SwitchBackLinearGlobalMemEfficient, SwitchBackLinearVectorwise = (
    partial(SwitchBackLinear, vector_wise_quantization=False),
    partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True),
    partial(SwitchBackLinear, vector_wise_quantization=True),
)
2023-03-29 06:47:08 +00:00
2023-04-01 18:46:04 +00:00
# This is just the standard linear function.
2023-03-31 18:20:54 +00:00
class StandardLinearFunction(torch.autograd.Function):
    """Plain (non-quantized) linear layer, ``input @ weight.T + bias``, with a
    hand-written backward; the reference counterpart of the SwitchBack ops."""

    @staticmethod
    def forward(ctx, input, weight, bias=None):
        # Flatten leading dims to [N * L, D] for the backward matmuls. ``reshape``
        # (not ``view``) so non-contiguous inputs are accepted; it is a view
        # whenever ``view`` would have succeeded.
        X = input.reshape(-1, input.size(-1))
        ctx.save_for_backward(X, weight, bias)
        output = input.matmul(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output.view(*input.size()[:-1], -1)

    @staticmethod
    def backward(ctx, grad_output_3D):
        X, weight, bias = ctx.saved_tensors

        grad_output = grad_output_3D.reshape(-1, grad_output_3D.size(-1))

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight.to(grad_output.dtype)).view(
                *grad_output_3D.size()[:-1], -1
            )
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().matmul(X.to(grad_output.dtype))
        # Check bias first: needs_input_grad has no slot 2 when bias was omitted.
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias
2023-03-31 18:20:54 +00:00
class StandardLinear(nn.Linear):
    """``nn.Linear`` whose forward pass goes through ``StandardLinearFunction``."""

    def forward(self, x):
        linear_fn = StandardLinearFunction
        w, b = self.weight, self.bias
        return linear_fn.apply(x, w, b)