forked from mrq/bitsandbytes-rocm

refactoring

commit d358999e9e
parent ee325f0215
@@ -185,11 +185,10 @@ class MatmulLtState:
     idx = None
     is_training = True
     has_fp16_weights = True
+    memory_efficient_backward = False
     use_pool = False
     formatB = F.get_special_format_str()
 
-    memory_efficient_backward = False
-
     def reset_grads(self):
         self.CB = None
         self.CxB = None
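Review note: this hunk is a pure move; memory_efficient_backward is relocated from its stray position below formatB into the block of training flags. The diff page does not name the file, but judging by the class and the F.get_special_format_str() call it is bitsandbytes' autograd module. A sketch of the attribute block after the move, reconstructed from the context lines above (not part of the diff itself):

    import bitsandbytes.functional as F

    class MatmulLtState:
        idx = None
        is_training = True
        has_fp16_weights = True
        memory_efficient_backward = False  # moved up beside the other flags
        use_pool = False
        formatB = F.get_special_format_str()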
@@ -198,6 +197,7 @@ class MatmulLtState:
 
         self.CxBt = None
         self.SBt = None
+        self.CBt = None
 
 
 class MatMul8bitLt(torch.autograd.Function):
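Review note: the single added line closes a stale-cache gap. reset_grads cleared CB, CxB, CxBt and SBt but left CBt (by naming convention, the transposed counterpart of CB) alive across resets. How the method reads after this hunk, pieced together from the context lines, with the lines the diff elides left as a placeholder:

    def reset_grads(self):
        self.CB = None
        self.CxB = None
        # ... lines elided by the diff ...
        self.CxBt = None
        self.SBt = None
        self.CBt = None  # added: previously CBt survived a reset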
@@ -232,10 +232,6 @@ class MatMul8bitLt(torch.autograd.Function):
         A_dtype = A.dtype
         A = A.to(torch.float16)
 
-        assert (
-            A.dtype == torch.float16
-        ), f"The input data type needs to be fp16 but {A.dtype} was found!"
-
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
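Review note: the deleted assertion was dead code. The cast on the line above it already guarantees the asserted condition, because Tensor.to(dtype) returns a tensor of exactly the requested dtype (A_dtype on the first context line presumably saves the original dtype so the output can be cast back later; that logic sits outside this hunk). A standalone check in plain PyTorch, independent of this repo:

    import torch

    A = torch.randn(4, 4, dtype=torch.float32)
    A = A.to(torch.float16)          # same cast as in the hunk
    assert A.dtype == torch.float16  # cannot fail after the cast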
@@ -398,9 +394,6 @@ class MatMul8bitLt(torch.autograd.Function):
         return grad_A, grad_B, None, grad_bias, None
 
 
-matmul = MatMul8bitLt.apply
-
-
 def matmul(
     A: tensor,
     B: tensor,
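Review note: the deleted matmul = MatMul8bitLt.apply was unreachable in practice, since the def matmul(...) directly below rebinds the same module-level name at import time. In the upstream source the surviving wrapper still delegates to MatMul8bitLt.apply internally (the diff truncates its signature after B: tensor), so removing the assignment changes no behavior. A minimal, hypothetical illustration of the shadowing:

    matmul = "first binding"     # analogous to matmul = MatMul8bitLt.apply

    def matmul():                # rebinds the same name at import time
        return "second binding"

    print(matmul())              # -> second binding; the string is shadowed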