adding half() cast
This commit is contained in:
parent
2489d819c5
commit
7b764d3569
|
@ -415,8 +415,8 @@ class MatMulFP8(torch.autograd.Function):
|
|||
cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
|
||||
fp8A = F.dequantize_blockwise(cA, state, blocksize=1024).to(A.dtype)
|
||||
|
||||
cB, state = F.quantize_blockwise(B, code=fw_code, blocksize=1024)
|
||||
fp8B = F.dequantize_blockwise(cB, state, blocksize=1024).to(B.dtype)
|
||||
cB, state = F.quantize(B.float(), code=fw_code)
|
||||
fp8B = F.dequantize(cB, state).to(B.dtype)
|
||||
|
||||
output = torch.matmul(fp8A, fp8B)
|
||||
|
||||
|
@ -450,9 +450,13 @@ class MatMulFP8(torch.autograd.Function):
|
|||
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
|
||||
|
||||
# not supported by PyTorch. TODO: create work-around
|
||||
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
|
||||
if req_gradA: grad_A = torch.matmul(fp8out, B.t())
|
||||
if req_gradB: grad_B = torch.matmul(fp8A.t(), fp8out)
|
||||
if req_gradA: grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(fp8A.dtype)
|
||||
if req_gradB:
|
||||
if fp8A.ndim == 3:
|
||||
fp8At = fp8A.transpose(2, 1)
|
||||
elif fp8A.ndim == 2:
|
||||
fp8At = fp8A.t()
|
||||
grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
|
||||
|
||||
return grad_A, grad_B, None, None, None
|
||||
|
||||
|
|
|
@ -2,4 +2,4 @@
|
|||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8
|
||||
from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast
|
||||
|
|
|
@ -326,10 +326,11 @@ class Linear8bitLt(nn.Linear):
|
|||
self.init_8bit_state()
|
||||
|
||||
# weights are cast automatically as Int8Params, but the bias has to be cast manually
|
||||
if self.bias is not None and self.bias.dtype != torch.float16:
|
||||
self.bias.data = self.bias.data.half()
|
||||
# if self.bias is not None and self.bias.dtype != torch.float16:
|
||||
# self.bias.data = self.bias.data.half()
|
||||
|
||||
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
|
||||
#out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
|
||||
out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
|
||||
|
||||
if not self.state.has_fp16_weights:
|
||||
if not self.state.memory_efficient_backward and self.state.CB is not None:
|
||||
|
@ -344,6 +345,28 @@ class Linear8bitLt(nn.Linear):
|
|||
|
||||
return out
|
||||
|
||||
|
||||
class Linear8bitLtThresh(Linear8bitLt):
|
||||
def __init__(
|
||||
self,
|
||||
input_features,
|
||||
output_features,
|
||||
bias=True,
|
||||
has_fp16_weights=True,
|
||||
memory_efficient_backward=False,
|
||||
threshold=6.0,
|
||||
index=None,
|
||||
):
|
||||
super().__init__(
|
||||
input_features,
|
||||
output_features,
|
||||
bias=bias,
|
||||
has_fp16_weights=has_fp16_weights,
|
||||
memory_efficient_backward=memory_efficient_backward,
|
||||
threshold=threshold,
|
||||
index=index
|
||||
)
|
||||
|
||||
class LinearFP8(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
|
@ -361,3 +384,33 @@ class LinearFP8(nn.Linear):
|
|||
|
||||
return out
|
||||
|
||||
class LinearInt8(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.code = None
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.code is None:
|
||||
self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
|
||||
|
||||
out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code)
|
||||
if self.bias is not None:
|
||||
out += self.bias
|
||||
|
||||
return out
|
||||
|
||||
class LinearInt8Cast(nn.Linear):
|
||||
def __init__(self, input_features, output_features, bias=True):
|
||||
super().__init__(input_features, output_features, bias)
|
||||
self.code = None
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.code is None:
|
||||
self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
|
||||
|
||||
out = bnb.matmul_fp8(x.half(), self.weight.half().t(), fw_code=self.code, bw_code=self.code)
|
||||
if self.bias is not None:
|
||||
out += self.bias
|
||||
|
||||
return out
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user