pre-triton update
This commit is contained in:
parent
75377d125e
commit
51f8bb7133
|
@ -11,7 +11,10 @@ from .autograd._functions import (
|
|||
matmul_cublas,
|
||||
mm_cublas,
|
||||
matmul_fp8,
|
||||
matmul_mixed
|
||||
matmul_mixed,
|
||||
matmul_fp8_global,
|
||||
matmul_fp4,
|
||||
matmul_fp8_mixed,
|
||||
)
|
||||
from .cextension import COMPILED_WITH_CUDA
|
||||
from .nn import modules
|
||||
|
|
|
@ -395,7 +395,7 @@ class MatMulFP8(torch.autograd.Function):
|
|||
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024):
|
||||
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
|
||||
# default of pytorch behavior if inputs are empty
|
||||
ctx.is_empty = False
|
||||
if prod(A.shape) == 0:
|
||||
|
@ -425,6 +425,7 @@ class MatMulFP8(torch.autograd.Function):
|
|||
ctx.fw_code = fw_code
|
||||
ctx.bw_code = bw_code
|
||||
ctx.bsz = bsz
|
||||
ctx.bsz2 = bsz2
|
||||
ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
|
||||
|
||||
if any(ctx.needs_input_grad[:2]):
|
||||
|
@ -440,14 +441,13 @@ class MatMulFP8(torch.autograd.Function):
|
|||
if ctx.is_empty:
|
||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None
|
||||
|
||||
req_gradA, req_gradB, _, _, _, _ = ctx.needs_input_grad
|
||||
req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
|
||||
A, B = ctx.tensors
|
||||
|
||||
grad_A, grad_B = None, None
|
||||
|
||||
# TODO: Fix blocksize to be output_dim
|
||||
cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz)
|
||||
fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz).to(grad_output.dtype)
|
||||
cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
|
||||
fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)
|
||||
|
||||
cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
|
||||
fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
|
||||
|
@ -467,7 +467,249 @@ class MatMulFP8(torch.autograd.Function):
|
|||
fp8At = F.dequantize(cA, state).to(A.dtype)
|
||||
grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype)
|
||||
|
||||
return grad_A, grad_B, None, None, None, None
|
||||
return grad_A, grad_B, None, None, None, None, None
|
||||
|
||||
class MatMulFP8Mixed(torch.autograd.Function):
|
||||
# forward is the same, but we added the fallback for pre-turing GPUs
|
||||
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
|
||||
# default of pytorch behavior if inputs are empty
|
||||
ctx.is_empty = False
|
||||
if prod(A.shape) == 0:
|
||||
ctx.is_empty = True
|
||||
ctx.A = A
|
||||
ctx.B = B
|
||||
|
||||
B_shape = B.shape
|
||||
if A.shape[-1] == B_shape[0]:
|
||||
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
|
||||
else:
|
||||
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
|
||||
|
||||
# 1. Dequantize
|
||||
# 2. MatmulnN
|
||||
cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
|
||||
fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)
|
||||
|
||||
cB, state = F.quantize(B.float(), code=fw_code)
|
||||
fp8B = F.dequantize(cB, state).to(B.dtype)
|
||||
|
||||
output = torch.matmul(fp8A, fp8B)
|
||||
|
||||
# output is half
|
||||
|
||||
# 3. Save state
|
||||
ctx.fw_code = fw_code
|
||||
ctx.bw_code = bw_code
|
||||
ctx.bsz = bsz
|
||||
ctx.bsz2 = bsz2
|
||||
ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
|
||||
|
||||
if any(ctx.needs_input_grad[:2]):
|
||||
# NOTE: we send back A, and re-quant.
|
||||
ctx.tensors = (A, fp8B)
|
||||
else:
|
||||
ctx.tensors = (None, None)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.is_empty:
|
||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None
|
||||
|
||||
req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
|
||||
A, B = ctx.tensors
|
||||
|
||||
grad_A, grad_B = None, None
|
||||
|
||||
# TODO: Fix blocksize to be output_dim
|
||||
cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
|
||||
fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)
|
||||
|
||||
# cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
|
||||
# fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
|
||||
|
||||
# grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
|
||||
# fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
|
||||
# fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
|
||||
# fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
|
||||
|
||||
# not supported by PyTorch. TODO: create work-around
|
||||
if req_gradA:
|
||||
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
|
||||
|
||||
if req_gradB:
|
||||
At = A.transpose(2, 1).contiguous()
|
||||
# cA, state = F.quantize(At.float(), code=ctx.fw_code)
|
||||
# fp8At = F.dequantize(cA, state).to(A.dtype)
|
||||
grad_B = torch.matmul(At.to(grad_output.dtype), grad_output).to(B.dtype)
|
||||
|
||||
return grad_A, grad_B, None, None, None, None, None
|
||||
|
||||
class MatMulFP4(torch.autograd.Function):
|
||||
# forward is the same, but we added the fallback for pre-turing GPUs
|
||||
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
|
||||
# default of pytorch behavior if inputs are empty
|
||||
ctx.is_empty = False
|
||||
if prod(A.shape) == 0:
|
||||
ctx.is_empty = True
|
||||
ctx.A = A
|
||||
ctx.B = B
|
||||
|
||||
B_shape = B.shape
|
||||
if A.shape[-1] == B_shape[0]:
|
||||
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
|
||||
else:
|
||||
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
|
||||
|
||||
# 1. Dequantize
|
||||
# 2. MatmulnN
|
||||
cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz)
|
||||
fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype)
|
||||
|
||||
cB, state = F.quantize(B.float(), code=fw_code)
|
||||
fp8B = F.dequantize(cB, state).to(B.dtype)
|
||||
|
||||
output = torch.matmul(fp8A, fp8B)
|
||||
|
||||
# output is half
|
||||
|
||||
# 3. Save state
|
||||
ctx.fw_code = fw_code
|
||||
ctx.bw_code = bw_code
|
||||
ctx.bsz = bsz
|
||||
ctx.bsz2 = bsz2
|
||||
ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
|
||||
|
||||
if any(ctx.needs_input_grad[:2]):
|
||||
# NOTE: we send back A, and re-quant.
|
||||
ctx.tensors = (A, fp8B)
|
||||
else:
|
||||
ctx.tensors = (None, None)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.is_empty:
|
||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None
|
||||
|
||||
req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
|
||||
A, B = ctx.tensors
|
||||
|
||||
grad_A, grad_B = None, None
|
||||
|
||||
# TODO: Fix blocksize to be output_dim
|
||||
cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2)
|
||||
fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype)
|
||||
|
||||
cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
|
||||
fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
|
||||
|
||||
# grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
|
||||
# fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
|
||||
# fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
|
||||
# fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
|
||||
|
||||
# not supported by PyTorch. TODO: create work-around
|
||||
if req_gradA:
|
||||
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
|
||||
|
||||
if req_gradB:
|
||||
At = A.transpose(2, 1).contiguous()
|
||||
cA, state = F.quantize(At.float(), code=ctx.bw_code)
|
||||
fp8At = F.dequantize(cA, state).to(A.dtype)
|
||||
grad_B = torch.matmul(fp8At.to(fp8out_2.dtype), fp8out_2).to(B.dtype)
|
||||
|
||||
return grad_A, grad_B, None, None, None, None, None
|
||||
|
||||
|
||||
|
||||
class MatMulFP8Global(torch.autograd.Function):
|
||||
# forward is the same, but we added the fallback for pre-turing GPUs
|
||||
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
|
||||
# default of pytorch behavior if inputs are empty
|
||||
ctx.is_empty = False
|
||||
if prod(A.shape) == 0:
|
||||
ctx.is_empty = True
|
||||
ctx.A = A
|
||||
ctx.B = B
|
||||
|
||||
B_shape = B.shape
|
||||
if A.shape[-1] == B_shape[0]:
|
||||
return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device)
|
||||
else:
|
||||
return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
|
||||
|
||||
# 1. Dequantize
|
||||
# 2. MatmulnN
|
||||
cA, state = F.quantize(A.float(), code=fw_code)
|
||||
fp8A = F.dequantize(cA, state).to(A.dtype)
|
||||
|
||||
cB, state = F.quantize(B.float(), code=fw_code)
|
||||
fp8B = F.dequantize(cB, state).to(B.dtype)
|
||||
|
||||
output = torch.matmul(fp8A, fp8B)
|
||||
|
||||
# output is half
|
||||
|
||||
# 3. Save state
|
||||
ctx.fw_code = fw_code
|
||||
ctx.bw_code = bw_code
|
||||
ctx.bsz = bsz
|
||||
ctx.bsz2 = bsz2
|
||||
ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype
|
||||
|
||||
if any(ctx.needs_input_grad[:2]):
|
||||
# NOTE: we send back A, and re-quant.
|
||||
ctx.tensors = (A, fp8B)
|
||||
else:
|
||||
ctx.tensors = (None, None)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.is_empty:
|
||||
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None
|
||||
|
||||
req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad
|
||||
A, B = ctx.tensors
|
||||
|
||||
grad_A, grad_B = None, None
|
||||
|
||||
# TODO: Fix blocksize to be output_dim
|
||||
cgrad_out, state = F.quantize(grad_output.float(), code=ctx.bw_code)
|
||||
fp8out = F.dequantize(cgrad_out, state).to(grad_output.dtype)
|
||||
|
||||
# cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code)
|
||||
# fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype)
|
||||
|
||||
# grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
|
||||
# fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
|
||||
# fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
|
||||
# fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
|
||||
|
||||
# not supported by PyTorch. TODO: create work-around
|
||||
if req_gradA:
|
||||
grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype)
|
||||
|
||||
if req_gradB:
|
||||
At = A.transpose(2, 1).contiguous()
|
||||
cA, state = F.quantize(At.float(), code=ctx.fw_code)
|
||||
fp8At = F.dequantize(cA, state).to(A.dtype)
|
||||
grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
|
||||
|
||||
return grad_A, grad_B, None, None, None, None, None
|
||||
|
||||
|
||||
class MatMul8bitMixed(torch.autograd.Function):
|
||||
|
@ -520,12 +762,14 @@ class MatMul8bitMixed(torch.autograd.Function):
|
|||
# we also need to convert it to the turing/ampere format
|
||||
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
||||
else:
|
||||
#print('A shape', A.shape)
|
||||
if not state.has_fp16_weights and state.CxB is None:
|
||||
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
||||
subA = None
|
||||
|
||||
# 2. Quantize B
|
||||
if state.has_fp16_weights:
|
||||
#print('B shape', B.shape)
|
||||
has_grad = True if (getattr(B, "grad", None) is not None) else False
|
||||
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
|
||||
if is_transposed:
|
||||
|
@ -633,6 +877,8 @@ class MatMul8bitMixed(torch.autograd.Function):
|
|||
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
|
||||
|
||||
if req_gradB:
|
||||
# print('back A shape', A.shape)
|
||||
# print('grad output t shape', grad_output.t().shape)
|
||||
grad_B = torch.matmul(grad_output.t(), A)
|
||||
|
||||
if req_gradA:
|
||||
|
@ -642,6 +888,8 @@ class MatMul8bitMixed(torch.autograd.Function):
|
|||
state.CxBt, state.SBt = F.transform(
|
||||
state.CBt, to_order=formatB, transpose=True
|
||||
)
|
||||
# print('back B shape', state.CxBt.shape)
|
||||
# print('back grad shape', C32grad.shape)
|
||||
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
|
||||
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
|
||||
|
||||
|
@ -668,8 +916,18 @@ def matmul(
|
|||
return MatMul8bitLt.apply(A, B, out, bias, state)
|
||||
|
||||
|
||||
def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1):
|
||||
return MatMulFP8.apply(A, B, out, fw_code, bw_code, bsz)
|
||||
def matmul_fp8(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
|
||||
return MatMulFP8.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
|
||||
|
||||
def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
|
||||
return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
|
||||
|
||||
def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
|
||||
return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
|
||||
|
||||
|
||||
def matmul_fp4(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1):
|
||||
return MatMulFP4.apply(A, B, out, fw_code, bw_code, bsz, bsz2)
|
||||
|
||||
|
||||
def matmul_mixed(
|
||||
|
|
|
@ -2,4 +2,4 @@
|
|||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2, Linear8bitLtMixed
|
||||
from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast, Linear8bitLt2, Linear8bitLtMixed, LinearFP8Global, LinearFP4, LinearFP8Mixed
|
||||
|
|
|
@ -498,14 +498,69 @@ class LinearFP8(nn.Linear):
|
|||
if input_features > array[i + 1]:
|
||||
self.bsz = k
|
||||
break
|
||||
print('block size is', self.bsz)
|
||||
for i, k in enumerate(array):
|
||||
if output_features > array[i + 1]:
|
||||
self.bsz2 = k
|
||||
break
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.fw_code is None:
|
||||
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
|
||||
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
|
||||
|
||||
out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz)
|
||||
out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
||||
if self.bias is not None:
|
||||
out += self.bias
|
||||
|
||||
return out
|
||||
|
||||
class LinearFP8Mixed(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.bw_code = None
|
||||
self.fw_code = None
|
||||
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
||||
for i, k in enumerate(array):
|
||||
if input_features > array[i + 1]:
|
||||
self.bsz = k
|
||||
break
|
||||
for i, k in enumerate(array):
|
||||
if output_features > array[i + 1]:
|
||||
self.bsz2 = k
|
||||
break
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.fw_code is None:
|
||||
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
|
||||
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
|
||||
|
||||
out = bnb.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
||||
if self.bias is not None:
|
||||
out += self.bias
|
||||
|
||||
return out
|
||||
|
||||
class LinearFP8Global(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.bw_code = None
|
||||
self.fw_code = None
|
||||
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
||||
for i, k in enumerate(array):
|
||||
if input_features > array[i + 1]:
|
||||
self.bsz = k
|
||||
break
|
||||
for i, k in enumerate(array):
|
||||
if output_features > array[i + 1]:
|
||||
self.bsz2 = k
|
||||
break
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.fw_code is None:
|
||||
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
|
||||
self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
|
||||
|
||||
out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
||||
if self.bias is not None:
|
||||
out += self.bias
|
||||
|
||||
|
@ -520,12 +575,16 @@ class LinearInt8(nn.Linear):
|
|||
if input_features > array[i + 1]:
|
||||
self.bsz = k
|
||||
break
|
||||
for i, k in enumerate(array):
|
||||
if output_features > array[i + 1]:
|
||||
self.bsz2 = k
|
||||
break
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.code is None:
|
||||
self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
|
||||
|
||||
out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz)
|
||||
out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code, bsz=self.bsz, bsz2=self.bsz2)
|
||||
if self.bias is not None:
|
||||
out += self.bias
|
||||
|
||||
|
@ -553,3 +612,30 @@ class LinearInt8Cast(nn.Linear):
|
|||
|
||||
return out
|
||||
|
||||
|
||||
class LinearFP4(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.bw_code = None
|
||||
self.fw_code = None
|
||||
array = [4096, 2048, 1024, 512, 256, 128, 64, 0]
|
||||
for i, k in enumerate(array):
|
||||
if input_features > array[i + 1]:
|
||||
self.bsz = k
|
||||
break
|
||||
for i, k in enumerate(array):
|
||||
if output_features > array[i + 1]:
|
||||
self.bsz2 = k
|
||||
break
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.fw_code is None:
|
||||
#self.bw_code = bnb.functional.create_fp8_map(True, 3, 0, 4).to(x.device)
|
||||
self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
|
||||
self.fw_code = bnb.functional.create_fp8_map(True, 3, 0, 4).to(x.device)
|
||||
|
||||
out = bnb.matmul_fp4(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2)
|
||||
if self.bias is not None:
|
||||
out += self.bias
|
||||
|
||||
return out
|
Loading…
Reference in New Issue
Block a user