adding half() cast

Mitchell Wortsman 2023-02-21 03:53:44 +00:00
parent 2489d819c5
commit 7b764d3569
3 changed files with 66 additions and 9 deletions

bitsandbytes/autograd/_functions.py

@@ -415,8 +415,8 @@ class MatMulFP8(torch.autograd.Function):
         cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
         fp8A = F.dequantize_blockwise(cA, state, blocksize=1024).to(A.dtype)
-        cB, state = F.quantize_blockwise(B, code=fw_code, blocksize=1024)
-        fp8B = F.dequantize_blockwise(cB, state, blocksize=1024).to(B.dtype)
+        cB, state = F.quantize(B.float(), code=fw_code)
+        fp8B = F.dequantize(cB, state).to(B.dtype)
         output = torch.matmul(fp8A, fp8B)
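A note on the hunk above: the forward pass simulates FP8 with a quantize/dequantize round-trip, and this commit switches B from blockwise quantization to plain 8-bit quantization of a float32 copy. A minimal sketch of that round-trip in isolation (the device, shapes, and the use of create_linear_map for fw_code are illustrative assumptions, not shown in this diff):

# Sketch only: reproduces the simulated-FP8 forward math from the hunk above.
import torch
import bitsandbytes.functional as F

fw_code = F.create_linear_map(True, 8).cuda()    # assumed 8-bit code, as used by LinearInt8 below
A = torch.randn(4, 32, device="cuda", dtype=torch.float16)
B = torch.randn(32, 16, device="cuda", dtype=torch.float16)

# A keeps the blockwise round-trip (unchanged by this commit)
cA, stateA = F.quantize_blockwise(A, code=fw_code, blocksize=1024)
fp8A = F.dequantize_blockwise(cA, stateA, blocksize=1024).to(A.dtype)

# B now uses the non-blockwise round-trip on a float32 copy (the new lines)
cB, stateB = F.quantize(B.float(), code=fw_code)
fp8B = F.dequantize(cB, stateB).to(B.dtype)

output = torch.matmul(fp8A, fp8B)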
@@ -450,9 +450,13 @@ class MatMulFP8(torch.autograd.Function):
         grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
         # not supported by PyTorch. TODO: create work-around
         #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
-        if req_gradA: grad_A = torch.matmul(fp8out, B.t())
-        if req_gradB: grad_B = torch.matmul(fp8A.t(), fp8out)
+        if req_gradA: grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(fp8A.dtype)
+        if req_gradB:
+            if fp8A.ndim == 3:
+                fp8At = fp8A.transpose(2, 1)
+            elif fp8A.ndim == 2:
+                fp8At = fp8A.t()
+            grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype)
 
         return grad_A, grad_B, None, None, None
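The rewritten grad_B path needs a rank-dependent transpose because Tensor.t() errors on 3-D tensors while activations can be batched. The same shape logic in isolation (shapes are illustrative):

# Sketch only: the ndim-dependent transpose used for grad_B above.
import torch

def transpose_last_two(t: torch.Tensor) -> torch.Tensor:
    if t.ndim == 3:
        return t.transpose(2, 1)   # (batch, seq, hidden) -> (batch, hidden, seq)
    elif t.ndim == 2:
        return t.t()               # (tokens, hidden) -> (hidden, tokens)
    raise ValueError(f"expected a 2-D or 3-D tensor, got ndim={t.ndim}")

print(transpose_last_two(torch.randn(10, 8)).shape)     # torch.Size([8, 10])
print(transpose_last_two(torch.randn(2, 5, 8)).shape)   # torch.Size([2, 8, 5])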

bitsandbytes/nn/__init__.py

@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear, LinearFP8, LinearInt8, Linear8bitLtThresh, LinearInt8Cast
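With the expanded export list the new layers become importable from the nn subpackage. A hedged usage sketch (the bnb.nn access path and the CUDA/half setup are assumptions for illustration, not shown in this diff):

# Sketch only: constructing one of the newly exported layers.
import torch
import bitsandbytes as bnb

layer = bnb.nn.LinearInt8(1024, 4096, bias=True).cuda().half()
x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
y = layer(x)   # forward goes through bnb.matmul_fp8 with an 8-bit linear code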

bitsandbytes/nn/modules.py

@@ -326,10 +326,11 @@ class Linear8bitLt(nn.Linear):
             self.init_8bit_state()
 
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
-        if self.bias is not None and self.bias.dtype != torch.float16:
-            self.bias.data = self.bias.data.half()
+        # if self.bias is not None and self.bias.dtype != torch.float16:
+        #     self.bias.data = self.bias.data.half()
 
         out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
-        #out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+        out = bnb.matmul(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
 
         if not self.state.has_fp16_weights:
             if not self.state.memory_efficient_backward and self.state.CB is not None:
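Two consequences of the uncommented line: it adds self.bias unconditionally, so it assumes bias=True, and since the bias cast above is now commented out, the output dtype is decided by PyTorch's type promotion (the matmul result is float16 because x and the weight are cast with .half()). A tiny illustration of that promotion rule:

# Sketch only: adding a float32 bias to a float16 result promotes to float32.
import torch

out_fp16 = torch.zeros(2, 4, dtype=torch.float16)
bias_fp32 = torch.zeros(4, dtype=torch.float32)
print((out_fp16 + bias_fp32).dtype)   # torch.float32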
@@ -344,6 +345,28 @@ class Linear8bitLt(nn.Linear):
         return out
 
+
+class Linear8bitLtThresh(Linear8bitLt):
+    def __init__(
+        self,
+        input_features,
+        output_features,
+        bias=True,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=6.0,
+        index=None,
+    ):
+        super().__init__(
+            input_features,
+            output_features,
+            bias=bias,
+            has_fp16_weights=has_fp16_weights,
+            memory_efficient_backward=memory_efficient_backward,
+            threshold=threshold,
+            index=index
+        )
+
 class LinearFP8(nn.Linear):
     def __init__(self, input_features, output_features, bias=True):
         super().__init__(input_features, output_features, bias)
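Linear8bitLtThresh is a thin wrapper: it forwards every argument to Linear8bitLt and only changes the default outlier threshold to 6.0. A usage sketch (sizes are illustrative; both lines build equivalent modules):

# Sketch only: the Thresh variant is Linear8bitLt with threshold defaulting to 6.0.
import bitsandbytes as bnb

a = bnb.nn.Linear8bitLtThresh(1024, 4096)
b = bnb.nn.Linear8bitLt(1024, 4096, threshold=6.0)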
@@ -361,3 +384,33 @@ class LinearFP8(nn.Linear):
         return out
 
+
+class LinearInt8(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.code = None
+
+    def forward(self, x: torch.Tensor):
+        if self.code is None:
+            self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
+
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.code, bw_code=self.code)
+        if self.bias is not None:
+            out += self.bias
+
+        return out
+
+
+class LinearInt8Cast(nn.Linear):
+    def __init__(self, input_features, output_features, bias=True):
+        super().__init__(input_features, output_features, bias)
+        self.code = None
+
+    def forward(self, x: torch.Tensor):
+        if self.code is None:
+            self.code = bnb.functional.create_linear_map(True, 8).to(x.device)
+
+        out = bnb.matmul_fp8(x.half(), self.weight.half().t(), fw_code=self.code, bw_code=self.code)
+        if self.bias is not None:
+            out += self.bias
+
+        return out
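LinearInt8 and LinearInt8Cast differ only in the .half() cast applied to the activations and the weight before bnb.matmul_fp8, which matches the commit title. A hedged comparison sketch (the CUDA device and float32 inputs are assumptions; the comments describe what each forward passes to matmul_fp8, not what every dtype combination supports):

# Sketch only: same layer shape, with and without the forced half() cast.
import torch
import bitsandbytes as bnb

x = torch.randn(8, 1024, device="cuda")           # float32 activations

plain = bnb.nn.LinearInt8(1024, 4096).cuda()      # matmul_fp8 receives float32 x and weight
cast = bnb.nn.LinearInt8Cast(1024, 4096).cuda()   # matmul_fp8 receives float16 x and weight

y_plain = plain(x)
y_cast = cast(x)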