forked from mrq/bitsandbytes-rocm
refactoring
This commit is contained in:
parent
8ae9bb23ad
commit
1753aa0418
|
@ -245,10 +245,11 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
subA = A[:, idx]
|
||||
state.subB = B[:, idx].t().contiguous()
|
||||
state.idx = idx
|
||||
elif state.CxB is None:
|
||||
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
|
||||
# we also need to convert it to the turing/ampere format
|
||||
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
||||
else:
|
||||
if state.CxB is None:
|
||||
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
|
||||
# we also need to convert it to the turing/ampere format
|
||||
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
||||
else:
|
||||
if not state.has_fp16_weights and state.CxB is None:
|
||||
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
|
||||
|
@ -355,19 +356,24 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
|
||||
if req_gradA:
|
||||
C32grad, Sgrad = F.transform(Cgrad, "col32")
|
||||
if state.CxBt is None and state.has_fp16_weights:
|
||||
CBt = state.CBt
|
||||
elif state.CxBt is None:
|
||||
assert state.CBt is None
|
||||
CB = state.CB.half()
|
||||
SCB = state.SCB.unsquezee(1).half()
|
||||
SCBt = state.SCBt.unsquezee(1).half()
|
||||
Bt = (CB * SCB).t().contiguous()
|
||||
CBt = (Bt / SCBt).t().to(torch.int8)
|
||||
if state.CxBt is None:
|
||||
if state.has_fp16_weights:
|
||||
CBt = state.CBt
|
||||
else:
|
||||
# Restore CBt from CB
|
||||
assert state.CBt is None, "CBt should not be stored in state"
|
||||
CB = state.CB.half()
|
||||
SCB = state.SCB.unsquezee(1).half()
|
||||
SCBt = state.SCBt.unsquezee(1).half()
|
||||
Bt = (CB * SCB).t().contiguous()
|
||||
CBt = (Bt / SCBt).t().to(torch.int8)
|
||||
|
||||
CxBt, SBt = F.transform(
|
||||
CBt, to_order=formatB, transpose=True
|
||||
)
|
||||
# intentionally, do not store CxBt into state
|
||||
CxBt, SBt = F.transform(
|
||||
CBt, to_order=formatB, transpose=True
|
||||
)
|
||||
else:
|
||||
CxBt = state.CxBt
|
||||
gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt)
|
||||
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user