adding half() cast

parent 2489d819c5
commit 7b764d3569

@@ -415,8 +415,8 @@ class MatMulFP8(torch.autograd.Function):
         cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
         fp8A = F.dequantize_blockwise(cA, state, blocksize=1024).to(A.dtype)
 
-        cB, state = F.quantize_blockwise(B, code=fw_code, blocksize=1024)
-        fp8B = F.dequantize_blockwise(cB, state, blocksize=1024).to(B.dtype)
+        cB, state = F.quantize(B.float(), code=fw_code)
+        fp8B = F.dequantize(cB, state).to(B.dtype)
 
         output = torch.matmul(fp8A, fp8B)
 
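For orientation, the forward pass simulates low-precision matmul by a quantize/dequantize round trip before the actual matmul; the change above switches B from blockwise to plain per-tensor quantization. A minimal pure-PyTorch sketch of that round-trip pattern, with an illustrative fake_quantize helper that is not a bitsandbytes API:

import torch

def fake_quantize(x: torch.Tensor, bits: int = 8) -> torch.Tensor:
    # Illustrative per-tensor fake quantization: snap values to a symmetric
    # low-precision grid, then return them in the original dtype. Stands in
    # for the quantize/dequantize round trips in the hunk above.
    scale = x.abs().max().clamp(min=1e-8) / (2 ** (bits - 1) - 1)
    q = (x / scale).round().clamp(-(2 ** (bits - 1)), 2 ** (bits - 1) - 1)
    return q * scale

A = torch.randn(4, 16)
B = torch.randn(16, 8)
fp8A = fake_quantize(A.float()).to(A.dtype)  # analogous to the blockwise round trip for A
fp8B = fake_quantize(B.float()).to(B.dtype)  # B now uses a per-tensor code, as in the new lines
output = torch.matmul(fp8A, fp8B)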
@@ -450,9 +450,13 @@ class MatMulFP8(torch.autograd.Function):
         grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
 
         # not supported by PyTorch. TODO: create work-around
-        #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
-        if req_gradA: grad_A = torch.matmul(fp8out, B.t())
-        if req_gradB: grad_B = torch.matmul(fp8A.t(), fp8out)
+        if req_gradA: grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(fp8A.dtype)
+        if req_gradB:
+            if fp8A.ndim == 3:
+                fp8At = fp8A.transpose(2, 1)
+            elif fp8A.ndim == 2:
+                fp8At = fp8A.t()
+            grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
 
         return grad_A, grad_B, None, None, None
 
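The rewritten backward forms grad_A from fp8out (the quantized grad_output) and B, and grad_B from the transposed activations and fp8out, transposing differently for batched (3D) and flat (2D) activations. A small stand-alone sketch of that grad_B shape logic, using plain torch.matmul rather than the FP8 path:

import torch

def grad_wrt_B(fp8A: torch.Tensor, fp8out: torch.Tensor) -> torch.Tensor:
    # Mirror of the new req_gradB branch: swap the last two dims for 3D
    # activations, use a plain transpose for 2D ones.
    if fp8A.ndim == 3:
        fp8At = fp8A.transpose(2, 1)
    elif fp8A.ndim == 2:
        fp8At = fp8A.t()
    else:
        raise ValueError("expected a 2D or 3D activation tensor")
    return torch.matmul(fp8At.to(fp8out.dtype), fp8out)

print(grad_wrt_B(torch.randn(4, 16), torch.randn(4, 8)).shape)        # torch.Size([16, 8])
print(grad_wrt_B(torch.randn(2, 4, 16), torch.randn(2, 4, 8)).shape)  # torch.Size([2, 16, 8])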
@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast
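With the widened export list, the new layers become importable from bitsandbytes.nn directly; a brief usage sketch, assuming a build of bitsandbytes that contains this branch:

# Assumes a bitsandbytes build that includes this branch's modules.
from bitsandbytes.nn import LinearInt8, LinearInt8Cast, Linear8bitLtThresh

layer = LinearInt8(768, 3072)  # in/out feature sizes are placeholder values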
@@ -326,10 +326,11 @@ class Linear8bitLt(nn.Linear):
             self.init_8bit_state()
 
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
-        if self.bias is not None and self.bias.dtype != torch.float16:
-            self.bias.data = self.bias.data.half()
+        # if self.bias is not None and self.bias.dtype != torch.float16:
+        #     self.bias.data = self.bias.data.half()
 
-        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
+        #out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+        out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
 
         if not self.state.has_fp16_weights:
             if not self.state.memory_efficient_backward and self.state.CB is not None:
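The edited forward now forces the matmul operands to fp16 (x.half(), self.weight.half()) and adds the bias outside the quantized matmul instead of passing it in. A minimal stand-alone sketch of that cast-then-add-bias pattern in plain PyTorch (not the bnb.matmul kernel; on older PyTorch builds fp16 matmul may require a CUDA device):

import torch

def forward_half_cast(x, weight, bias):
    # Cast activations and weights to fp16 before the matmul, then add the
    # (already fp16) bias afterwards, mirroring the edited forward().
    out = torch.matmul(x.half(), weight.half().t())
    if bias is not None:
        out = out + bias
    return out

x = torch.randn(4, 16)
w = torch.randn(8, 16)       # standard nn.Linear layout: (out_features, in_features)
b = torch.randn(8).half()
print(forward_half_cast(x, w, b).dtype)  # torch.float16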
@@ -344,6 +345,28 @@ class Linear8bitLt(nn.Linear):
 
         return out
 
+
+class Linear8bitLtThresh(Linear8bitLt):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=6.0,
+        index=None,
+    ):
+        super().__init__(
+            input_features,
+            output_features,
+            bias=bias,
+            has_fp16_weights=has_fp16_weights,
+            memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold,
+            index=index
+        )
+
 class LinearFP8(nn.Linear):
     def __init__(self, input_features, output_features, bias=True):
         super().__init__(input_features, output_features, bias)
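Linear8bitLtThresh is a thin wrapper whose only effect is to default the outlier threshold to 6.0; constructing it is equivalent to passing threshold=6.0 to Linear8bitLt. A brief usage sketch (dimensions are placeholders, assuming this branch is installed):

import bitsandbytes as bnb

# The two layers below end up configured identically.
layer_a = bnb.nn.Linear8bitLtThresh(768, 3072)
layer_b = bnb.nn.Linear8bitLt(768, 3072, threshold=6.0)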
@@ -361,3 +384,33 @@ class LinearFP8(nn.Linear):
 
         return out
 
+class LinearInt8(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.code = None
+
+    def forward(self, x: torch.Tensor):
+        if self.code is None:
+            self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
+
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code)
+        if self.bias is not None:
+            out += self.bias
+
+        return out
+
+
+class LinearInt8Cast(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.code = None
+
+    def forward(self, x: torch.Tensor):
+        if self.code is None:
+            self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
+
+        out = bnb.matmul_fp8(x.half(), self.weight.half().t(), fw_code=self.code, bw_code=self.code)
+        if self.bias is not None:
+            out += self.bias
+
+        return out
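Both new modules lazily build an 8-bit linear quantization code on the first forward call via bnb.functional.create_linear_map(True, 8) and route the matmul through bnb.matmul_fp8; LinearInt8Cast additionally casts input and weight to fp16, matching the Linear8bitLt change above. As a rough mental model of that code (an illustrative stand-in, not the exact bitsandbytes implementation), a signed 8-bit linear map is an evenly spaced grid over [-1, 1]:

import torch

def linear_code(signed: bool = True, total_bits: int = 8) -> torch.Tensor:
    # Illustrative stand-in for create_linear_map: 2**total_bits evenly
    # spaced quantization levels over [-1, 1] (signed) or [0, 1] (unsigned).
    lo = -1.0 if signed else 0.0
    return torch.linspace(lo, 1.0, 2 ** total_bits)

code = linear_code(True, 8)
print(code.shape, code.min().item(), code.max().item())  # torch.Size([256]) -1.0 1.0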