refactoring

dbaranchuk 2022-09-11 06:26:15 +03:00
parent ee325f0215
commit d358999e9e


@@ -185,11 +185,10 @@ class MatmulLtState:
     idx = None
     is_training = True
     has_fp16_weights = True
+    memory_efficient_backward = False
     use_pool = False
     formatB = F.get_special_format_str()

-    memory_efficient_backward = False
-
     def reset_grads(self):
         self.CB = None
         self.CxB = None

@@ -198,6 +197,7 @@ class MatmulLtState:
         self.CxBt = None
         self.SBt = None
         self.CBt = None
+

 class MatMul8bitLt(torch.autograd.Function):
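
The first two hunks are pure tidying: memory_efficient_backward moves up next to the other boolean flags, and a second blank line is restored before the class definition (PEP 8). For orientation, here is a minimal sketch of how calling code might configure this state object; the attribute names come from the hunks above, while the top-level bnb.MatmulLtState import path and the usage pattern are assumptions, not part of this commit:

    import bitsandbytes as bnb  # assumption: MatmulLtState is re-exported at top level

    # Each 8-bit linear layer keeps one MatmulLtState holding its quantized
    # weight buffers and the flags shown in the diff above.
    state = bnb.MatmulLtState()
    state.has_fp16_weights = False          # keep B as int8 instead of fp16
    state.memory_efficient_backward = True  # opt into the memory-efficient backward path (semantics assumed)
    state.use_pool = False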
@@ -232,10 +232,6 @@ class MatMul8bitLt(torch.autograd.Function):
-        A_dtype = A.dtype
-        A = A.to(torch.float16)
-        assert (
-            A.dtype == torch.float16
-        ), f"The input data type needs to be fp16 but {A.dtype} was found!"
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
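
The third hunk drops the temporary fp16 cast and its assert from forward, so the input dtype is no longer forced at this point. The surviving context shows the flatten-before-quantize step; a self-contained illustration of that reshape (plain PyTorch, not taken from the commit):

    import torch

    # MatMul8bitLt flattens a 3D activation [batch, seq, hidden] to 2D
    # before quantization, as in the "# 1. Quantize A" context above.
    A = torch.randn(4, 16, 64).half()
    shape = A.shape
    if len(A.shape) == 3:
        A = A.view(-1, A.shape[-1]).contiguous()  # -> [4 * 16, 64]
    assert A.shape == (64, 64)

    # After the matmul, the leading dimensions can be restored:
    A = A.view(shape)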
@@ -398,9 +394,6 @@ class MatMul8bitLt(torch.autograd.Function):
         return grad_A, grad_B, None, grad_bias, None

-matmul = MatMul8bitLt.apply
-

 def matmul(
     A: tensor,
     B: tensor,
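
The last hunk deletes the dead module-level alias matmul = MatMul8bitLt.apply, which was immediately shadowed by the def matmul(...) wrapper that follows it. A hypothetical usage sketch of that wrapper (requires a CUDA device; the state/threshold keywords and the [out_features, in_features] layout of B mirror how Linear8bitLt calls it, and are assumptions rather than content of this diff):

    import torch
    import bitsandbytes as bnb

    A = torch.randn(8, 64, dtype=torch.float16, device="cuda")
    B = torch.randn(32, 64, dtype=torch.float16, device="cuda", requires_grad=True)

    state = bnb.MatmulLtState()

    # The wrapper forwards to MatMul8bitLt.apply; threshold > 0 enables the
    # mixed-precision outlier decomposition (value here is illustrative).
    out = bnb.matmul(A, B, state=state, threshold=6.0)  # -> [8, 32]
    out.sum().backward()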