forked from mrq/bitsandbytes-rocm
refactoring
This commit is contained in:
parent ee325f0215
commit d358999e9e
@@ -185,11 +185,10 @@ class MatmulLtState:
     idx = None
     is_training = True
     has_fp16_weights = True
+    memory_efficient_backward = False
     use_pool = False
     formatB = F.get_special_format_str()
 
-    memory_efficient_backward = False
-
     def reset_grads(self):
         self.CB = None
         self.CxB = None
@@ -198,6 +197,7 @@ class MatmulLtState:
 
         self.CxBt = None
         self.SBt = None
         self.CBt = None
 
+
 class MatMul8bitLt(torch.autograd.Function):
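For orientation, the two hunks above tidy MatmulLtState: the memory_efficient_backward flag moves up next to the other class-level flags, and reset_grads clears the cached quantized-weight buffers. A minimal self-contained sketch restricted to the fields visible in this diff (formatB is omitted because it calls into the library's F module, and the real class holds more state):

class MatmulLtStateSketch:
    """Sketch of MatmulLtState limited to the fields shown in the diff above."""
    idx = None
    is_training = True
    has_fp16_weights = True
    memory_efficient_backward = False
    use_pool = False

    def reset_grads(self):
        # Clear the cached quantized-weight buffers.
        self.CB = None
        self.CxB = None
        self.CxBt = None
        self.SBt = None
        self.CBt = None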
@@ -232,10 +232,6 @@ class MatMul8bitLt(torch.autograd.Function):
         A_dtype = A.dtype
         A = A.to(torch.float16)
 
-        assert (
-            A.dtype == torch.float16
-        ), f"The input data type needs to be fp16 but {A.dtype} was found!"
-
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
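The deleted assert had become unreachable: the unconditional A.to(torch.float16) two lines above guarantees the fp16 dtype, while A_dtype keeps the caller's original dtype, presumably so the output can be cast back. A minimal sketch of that cast-and-restore plus the 3-D flattening pattern (forward_sketch and the placeholder arithmetic are illustrative, not the library's forward pass):

import torch

def forward_sketch(A: torch.Tensor) -> torch.Tensor:
    A_dtype = A.dtype        # remember the caller's dtype
    A = A.to(torch.float16)  # unconditional cast: an fp16 assert after this can never fail
    if len(A.shape) == 3:    # collapse leading batch dims so quantization sees a 2-D matrix
        A = A.view(-1, A.shape[-1]).contiguous()
    out = A * 2.0            # placeholder for quantize -> int8 matmul -> dequantize
    return out.to(A_dtype)   # cast the result back to the caller's dtype

x = torch.randn(4, 8, 16)  # fp32 input
print(forward_sketch(x).shape, forward_sketch(x).dtype)  # torch.Size([32, 16]) torch.float32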
@@ -398,9 +394,6 @@ class MatMul8bitLt(torch.autograd.Function):
         return grad_A, grad_B, None, grad_bias, None
 
 
-matmul = MatMul8bitLt.apply
-
-
 def matmul(
     A: tensor,
     B: tensor,
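The removed matmul = MatMul8bitLt.apply binding was dead code: the def matmul(...) immediately below rebinds the same module-level name, so the assignment never survived import. A standalone sketch of that shadowing (the _Op name is hypothetical):

class _Op:
    @staticmethod
    def apply(x):
        return x + 1

matmul = _Op.apply  # this binding is discarded as soon as...

def matmul(x):      # ...this def rebinds the module-level name
    return _Op.apply(x)

print(matmul(1))  # 2 -- always resolves to the wrapper, never the bare .apply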