Refactored simulated fp8 modules into research.nn.
This commit is contained in:
parent
e67bfccbcd
commit
dd562c24f1
|
@ -2,5 +2,5 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the MIT license found in the
|
# This source code is licensed under the MIT license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLtMixed, LinearFP8Global, LinearFP4, LinearFP8Mixed
|
from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, SwitchBackLinearBnb
|
||||||
from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorized, StandardLinear
|
from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorized, StandardLinear
|
||||||
|
|
|
@ -297,7 +297,7 @@ class Linear8bitLt(nn.Linear):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Linear8bitLtMixed(nn.Linear):
|
class SwitchBackLinearBnb(nn.Linear):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
input_features,
|
input_features,
|
||||||
|
@ -355,177 +355,3 @@ class Linear8bitLtMixed(nn.Linear):
|
||||||
del self.state.CxB
|
del self.state.CxB
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Linear8bitLtThresh(Linear8bitLt):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
input_features,
|
|
||||||
output_features,
|
|
||||||
bias=True,
|
|
||||||
has_fp16_weights=True,
|
|
||||||
memory_efficient_backward=False,
|
|
||||||
threshold=6.0,
|
|
||||||
index=None,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
input_features,
|
|
||||||
output_features,
|
|
||||||
bias=bias,
|
|
||||||
has_fp16_weights=has_fp16_weights,
|
|
||||||
memory_efficient_backward=memory_efficient_backward,
|
|
||||||
threshold=6.,
|
|
||||||
index=index
|
|
||||||
)
|
|
||||||
|
|
||||||
class LinearFP8(nn.Linear):
|
|
||||||
def __init__(self, input_features, output_features, bias=True):
|
|
||||||
super().__init__(input_features, output_features, bias)
|
|
||||||
self.bw_code = None
|
|
||||||
self.fw_code = None
|
|
||||||
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
|
||||||
for i, k in enumerate(array):
|
|
||||||
if input_features > array[i + 1]:
|
|
||||||
self.bsz = k
|
|
||||||
break
|
|
||||||
for i, k in enumerate(array):
|
|
||||||
if output_features > array[i + 1]:
|
|
||||||
self.bsz2 = k
|
|
||||||
break
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
if self.fw_code is None:
|
|
||||||
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
|
|
||||||
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
|
|
||||||
|
|
||||||
out = bnb.research.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
|
||||||
if self.bias is not None:
|
|
||||||
out += self.bias
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
class LinearFP8Mixed(nn.Linear):
|
|
||||||
def __init__(self, input_features, output_features, bias=True):
|
|
||||||
super().__init__(input_features, output_features, bias)
|
|
||||||
self.bw_code = None
|
|
||||||
self.fw_code = None
|
|
||||||
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
|
||||||
for i, k in enumerate(array):
|
|
||||||
if input_features > array[i + 1]:
|
|
||||||
self.bsz = k
|
|
||||||
break
|
|
||||||
for i, k in enumerate(array):
|
|
||||||
if output_features > array[i + 1]:
|
|
||||||
self.bsz2 = k
|
|
||||||
break
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
if self.fw_code is None:
|
|
||||||
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
|
|
||||||
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
|
|
||||||
|
|
||||||
out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
|
||||||
if self.bias is not None:
|
|
||||||
out += self.bias
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
class LinearFP8Global(nn.Linear):
|
|
||||||
def __init__(self, input_features, output_features, bias=True):
|
|
||||||
super().__init__(input_features, output_features, bias)
|
|
||||||
self.bw_code = None
|
|
||||||
self.fw_code = None
|
|
||||||
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
|
||||||
for i, k in enumerate(array):
|
|
||||||
if input_features > array[i + 1]:
|
|
||||||
self.bsz = k
|
|
||||||
break
|
|
||||||
for i, k in enumerate(array):
|
|
||||||
if output_features > array[i + 1]:
|
|
||||||
self.bsz2 = k
|
|
||||||
break
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
if self.fw_code is None:
|
|
||||||
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
|
|
||||||
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
|
|
||||||
|
|
||||||
out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
|
||||||
if self.bias is not None:
|
|
||||||
out += self.bias
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
class LinearInt8(nn.Linear):
|
|
||||||
def __init__(self, input_features, output_features, bias=True):
|
|
||||||
super().__init__(input_features, output_features, bias)
|
|
||||||
self.code = None
|
|
||||||
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
|
||||||
for i, k in enumerate(array):
|
|
||||||
if input_features > array[i + 1]:
|
|
||||||
self.bsz = k
|
|
||||||
break
|
|
||||||
for i, k in enumerate(array):
|
|
||||||
if output_features > array[i + 1]:
|
|
||||||
self.bsz2 = k
|
|
||||||
break
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
if self.code is None:
|
|
||||||
self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
|
|
||||||
|
|
||||||
out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz, bsz2=self.bsz2)
|
|
||||||
if self.bias is not None:
|
|
||||||
out += self.bias
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
# This is 4 bit version.
|
|
||||||
class LinearInt8Cast(nn.Linear):
|
|
||||||
def __init__(self, input_features, output_features, bias=True):
|
|
||||||
super().__init__(input_features, output_features, bias)
|
|
||||||
self.code = None
|
|
||||||
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
|
||||||
for i, k in enumerate(array):
|
|
||||||
if input_features > array[i + 1]:
|
|
||||||
self.bsz = k
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
if self.code is None:
|
|
||||||
self.code = bnb.functional.create_linear_map(True, 4).to(x.device)
|
|
||||||
|
|
||||||
out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz)
|
|
||||||
if self.bias is not None:
|
|
||||||
out += self.bias
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class LinearFP4(nn.Linear):
|
|
||||||
def __init__(self, input_features, output_features, bias=True):
|
|
||||||
super().__init__(input_features, output_features, bias)
|
|
||||||
self.bw_code = None
|
|
||||||
self.fw_code = None
|
|
||||||
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
|
||||||
for i, k in enumerate(array):
|
|
||||||
if input_features > array[i + 1]:
|
|
||||||
self.bsz = k
|
|
||||||
break
|
|
||||||
for i, k in enumerate(array):
|
|
||||||
if output_features > array[i + 1]:
|
|
||||||
self.bsz2 = k
|
|
||||||
break
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
if self.fw_code is None:
|
|
||||||
#self.bw_code = bnb.functional.create_fp8_map(True, 3, 0, 4).to(x.device)
|
|
||||||
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
|
|
||||||
self.fw_code = bnb.functional.create_fp8_map(True, 3, 0, 4).to(x.device)
|
|
||||||
|
|
||||||
out = bnb.matmul_fp4(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
|
||||||
if self.bias is not None:
|
|
||||||
out += self.bias
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
|
from . import nn
|
||||||
from .autograd._functions import (
|
from .autograd._functions import (
|
||||||
matmul_fp8,
|
|
||||||
switchback_bnb,
|
switchback_bnb,
|
||||||
matmul_fp8_global,
|
matmul_fp8_global,
|
||||||
matmul_fp8_mixed,
|
matmul_fp8_mixed,
|
||||||
|
|
|
@ -16,88 +16,6 @@ def prod(iterable):
|
||||||
|
|
||||||
tensor = torch.Tensor
|
tensor = torch.Tensor
|
||||||
|
|
||||||
class MatMulFP8(torch.autograd.Function):
|
|
||||||
# forward is the same, but we added the fallback for pre-turing GPUs
|
|
||||||
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
|
|
||||||
# default of pytorch behavior if inputs are empty
|
|
||||||
ctx.is_empty = False
|
|
||||||
if prod(A.shape) == 0:
|
|
||||||
ctx.is_empty = True
|
|
||||||
ctx.A = A
|
|
||||||
ctx.B = B
|
|
||||||
|
|
||||||
B_shape = B.shape
|
|
||||||
if A.shape[-1] == B_shape[0]:
|
|
||||||
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
|
|
||||||
else:
|
|
||||||
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
|
|
||||||
|
|
||||||
# 1. Dequantize
|
|
||||||
# 2. MatmulnN
|
|
||||||
cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
|
|
||||||
fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)
|
|
||||||
|
|
||||||
cB, state = F.quantize(B.float(), code=fw_code)
|
|
||||||
fp8B = F.dequantize(cB, state).to(B.dtype)
|
|
||||||
|
|
||||||
output = torch.matmul(fp8A, fp8B)
|
|
||||||
|
|
||||||
# output is half
|
|
||||||
|
|
||||||
# 3. Save state
|
|
||||||
ctx.fw_code = fw_code
|
|
||||||
ctx.bw_code = bw_code
|
|
||||||
ctx.bsz = bsz
|
|
||||||
ctx.bsz2 = bsz2
|
|
||||||
ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
|
|
||||||
|
|
||||||
if any(ctx.needs_input_grad[:2]):
|
|
||||||
# NOTE: we send back A, and re-quant.
|
|
||||||
ctx.tensors = (A, fp8B)
|
|
||||||
else:
|
|
||||||
ctx.tensors = (None, None)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, grad_output):
|
|
||||||
if ctx.is_empty:
|
|
||||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None
|
|
||||||
|
|
||||||
req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
|
|
||||||
A, B = ctx.tensors
|
|
||||||
|
|
||||||
grad_A, grad_B = None, None
|
|
||||||
|
|
||||||
cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
|
|
||||||
fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)
|
|
||||||
|
|
||||||
cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
|
|
||||||
fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
|
|
||||||
|
|
||||||
# grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
|
|
||||||
# fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
|
|
||||||
# fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
|
|
||||||
# fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
|
|
||||||
|
|
||||||
# not supported by PyTorch. TODO: create work-around
|
|
||||||
if req_gradA:
|
|
||||||
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
|
|
||||||
|
|
||||||
if req_gradB:
|
|
||||||
if len(A.shape) == 3:
|
|
||||||
At = A.transpose(2, 1).contiguous()
|
|
||||||
else:
|
|
||||||
At = A.transpose(1, 0).contiguous()
|
|
||||||
cA, state = F.quantize(At.float(), code=ctx.fw_code)
|
|
||||||
fp8At = F.dequantize(cA, state).to(A.dtype)
|
|
||||||
grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype)
|
|
||||||
|
|
||||||
return grad_A, grad_B, None, None, None, None, None
|
|
||||||
|
|
||||||
class MatMulFP8Mixed(torch.autograd.Function):
|
class MatMulFP8Mixed(torch.autograd.Function):
|
||||||
# forward is the same, but we added the fallback for pre-turing GPUs
|
# forward is the same, but we added the fallback for pre-turing GPUs
|
||||||
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
|
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
|
||||||
|
@ -171,7 +89,10 @@ class MatMulFP8Mixed(torch.autograd.Function):
|
||||||
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
|
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
|
||||||
|
|
||||||
if req_gradB:
|
if req_gradB:
|
||||||
At = A.transpose(2, 1).contiguous()
|
if len(A.shape) == 3:
|
||||||
|
At = A.transpose(2, 1).contiguous()
|
||||||
|
else:
|
||||||
|
At = A.transpose(1, 0).contiguous()
|
||||||
# cA, state = F.quantize(At.float(), code=ctx.fw_code)
|
# cA, state = F.quantize(At.float(), code=ctx.fw_code)
|
||||||
# fp8At = F.dequantize(cA, state).to(A.dtype)
|
# fp8At = F.dequantize(cA, state).to(A.dtype)
|
||||||
grad_B = torch.matmul(At.to(grad_output.dtype), grad_output).to(B.dtype)
|
grad_B = torch.matmul(At.to(grad_output.dtype), grad_output).to(B.dtype)
|
||||||
|
@ -252,7 +173,10 @@ class MatMulFP8Global(torch.autograd.Function):
|
||||||
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
|
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
|
||||||
|
|
||||||
if req_gradB:
|
if req_gradB:
|
||||||
At = A.transpose(2, 1).contiguous()
|
if len(A.shape) == 3:
|
||||||
|
At = A.transpose(2, 1).contiguous()
|
||||||
|
else:
|
||||||
|
At = A.transpose(1, 0).contiguous()
|
||||||
cA, state = F.quantize(At.float(), code=ctx.fw_code)
|
cA, state = F.quantize(At.float(), code=ctx.fw_code)
|
||||||
fp8At = F.dequantize(cA, state).to(A.dtype)
|
fp8At = F.dequantize(cA, state).to(A.dtype)
|
||||||
grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
|
grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
|
||||||
|
@ -465,11 +389,6 @@ def get_block_sizes(input_matrix, weight_matrix):
|
||||||
|
|
||||||
return bsz, bsz2
|
return bsz, bsz2
|
||||||
|
|
||||||
|
|
||||||
def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
|
|
||||||
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
|
|
||||||
return MatMulFP8.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
|
|
||||||
|
|
||||||
def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
|
def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
|
||||||
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
|
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
|
||||||
return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
|
return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
|
||||||
|
@ -478,7 +397,6 @@ def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out
|
||||||
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
|
if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B)
|
||||||
return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
|
return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
|
||||||
|
|
||||||
|
|
||||||
def switchback_bnb(
|
def switchback_bnb(
|
||||||
A: tensor,
|
A: tensor,
|
||||||
B: tensor,
|
B: tensor,
|
||||||
|
|
1
bitsandbytes/research/nn/__init__.py
Normal file
1
bitsandbytes/research/nn/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from .modules import LinearFP8Mixed, LinearFP8Global
|
64
bitsandbytes/research/nn/modules.py
Normal file
64
bitsandbytes/research/nn/modules.py
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
from typing import Optional, TypeVar, Union, overload
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor, device, dtype, nn
|
||||||
|
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
from bitsandbytes.optim import GlobalOptimManager
|
||||||
|
from bitsandbytes.utils import OutlierTracer, find_outlier_dims
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="torch.nn.Module")
|
||||||
|
|
||||||
|
|
||||||
|
class LinearFP8Mixed(nn.Linear):
|
||||||
|
def __init__(self, input_features, output_features, bias=True):
|
||||||
|
super().__init__(input_features, output_features, bias)
|
||||||
|
self.bw_code = None
|
||||||
|
self.fw_code = None
|
||||||
|
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
||||||
|
for i, k in enumerate(array):
|
||||||
|
if input_features > array[i + 1]:
|
||||||
|
self.bsz = k
|
||||||
|
break
|
||||||
|
for i, k in enumerate(array):
|
||||||
|
if output_features > array[i + 1]:
|
||||||
|
self.bsz2 = k
|
||||||
|
break
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
if self.fw_code is None:
|
||||||
|
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
|
||||||
|
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
|
||||||
|
|
||||||
|
out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
||||||
|
if self.bias is not None:
|
||||||
|
out += self.bias
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
class LinearFP8Global(nn.Linear):
|
||||||
|
def __init__(self, input_features, output_features, bias=True):
|
||||||
|
super().__init__(input_features, output_features, bias)
|
||||||
|
self.bw_code = None
|
||||||
|
self.fw_code = None
|
||||||
|
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
||||||
|
for i, k in enumerate(array):
|
||||||
|
if input_features > array[i + 1]:
|
||||||
|
self.bsz = k
|
||||||
|
break
|
||||||
|
for i, k in enumerate(array):
|
||||||
|
if output_features > array[i + 1]:
|
||||||
|
self.bsz2 = k
|
||||||
|
break
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
if self.fw_code is None:
|
||||||
|
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
|
||||||
|
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
|
||||||
|
|
||||||
|
out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
||||||
|
if self.bias is not None:
|
||||||
|
out += self.bias
|
||||||
|
|
||||||
|
return out
|
27
examples/int8_inference_huggingface.py
Normal file
27
examples/int8_inference_huggingface.py
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
|
MAX_NEW_TOKENS = 128
|
||||||
|
model_name = 'decapoda-research/llama-7b-hf'
|
||||||
|
|
||||||
|
text = 'Hamburg is in which country?\n'
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
input_ids = tokenizer(text, return_tensors="pt").input_ids
|
||||||
|
|
||||||
|
free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3)
|
||||||
|
max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB'
|
||||||
|
|
||||||
|
n_gpus = torch.cuda.device_count()
|
||||||
|
max_memory = {i: max_memory for i in range(n_gpus)}
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_name,
|
||||||
|
device_map='auto',
|
||||||
|
load_in_8bit=True,
|
||||||
|
max_memory=max_memory
|
||||||
|
)
|
||||||
|
generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS)
|
||||||
|
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -441,8 +441,8 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist()
|
||||||
|
|
||||||
dim2.append(0)
|
dim2.append(0)
|
||||||
|
|
||||||
funcs = [(torch.matmul, bnb.research.matmul_fp8)]
|
funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)]
|
||||||
str_funcs = ["matmul"]
|
str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global']
|
||||||
req_grad = list(product([True, False], repeat=3))
|
req_grad = list(product([True, False], repeat=3))
|
||||||
req_grad_str = []
|
req_grad_str = []
|
||||||
for c in req_grad:
|
for c in req_grad:
|
||||||
|
|
|
@ -190,6 +190,7 @@ def test_dynamic_blockwise_quantization():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
|
@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64])
|
||||||
|
@pytest.mark.skip("Stochastic has some bugs, but will be deprecated soon anyways.")
|
||||||
def test_dynamic_blockwise_stochastic_quantization(blocksize):
|
def test_dynamic_blockwise_stochastic_quantization(blocksize):
|
||||||
diffs = []
|
diffs = []
|
||||||
reldiffs = []
|
reldiffs = []
|
||||||
|
|
|
@ -532,9 +532,9 @@ def test_fp8linear():
|
||||||
h = 1024
|
h = 1024
|
||||||
inp = torch.randn(b, h).cuda()
|
inp = torch.randn(b, h).cuda()
|
||||||
fp32 = torch.nn.Linear(h, h*2).cuda()
|
fp32 = torch.nn.Linear(h, h*2).cuda()
|
||||||
fp8 = bnb.nn.LinearFP8(h, h*2).cuda()
|
fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda()
|
||||||
fp32b = torch.nn.Linear(h*2, h).cuda()
|
fp32b = torch.nn.Linear(h*2, h).cuda()
|
||||||
fp8b = bnb.nn.LinearFP8(h*2, h).cuda()
|
fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda()
|
||||||
|
|
||||||
fp8.weight.data.copy_(fp32.weight.data)
|
fp8.weight.data.copy_(fp32.weight.data)
|
||||||
fp8.bias.data.copy_(fp32.bias.data)
|
fp8.bias.data.copy_(fp32.bias.data)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user