debug
This commit is contained in:
parent
4da2227fcb
commit
5d65817101
|
@ -370,8 +370,6 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
if state.threshold > 0.0 and subA is not None:
|
||||
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
|
||||
|
||||
raise NotImplementedError("!!")
|
||||
|
||||
if req_gradA:
|
||||
if state.CBt is not None:
|
||||
C32grad, Sgrad = F.transform(Cgrad, "col32")
|
||||
|
|
|
@ -237,7 +237,9 @@ class Linear8bitLt(nn.Linear):
|
|||
if threshold > 0.0 and not has_fp16_weights:
|
||||
self.state.use_pool = True
|
||||
|
||||
self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights)
|
||||
self.weight = Int8Params(
|
||||
self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights
|
||||
)
|
||||
|
||||
def init_8bit_state(self):
|
||||
self.state.CB = self.weight.CB
|
||||
|
|
Loading…
Reference in New Issue
Block a user