refactoring

dbaranchuk 2022-08-23 23:51:00 +03:00
parent 8ae9bb23ad
commit 1753aa0418


@@ -245,10 +245,11 @@ class MatMul8bitLt(torch.autograd.Function):
                 subA = A[:, idx]
                 state.subB = B[:, idx].t().contiguous()
                 state.idx = idx
-            elif state.CxB is None:
-                # B is in 8-bit row-major format; we can transform it back to 16-bit to extract outlier dimensions
-                # we also need to convert it to the turing/ampere format
-                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+            else:
+                if state.CxB is None:
+                    # B is in 8-bit row-major format; we can transform it back to 16-bit to extract outlier dimensions
+                    # we also need to convert it to the turing/ampere format
+                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
         else:
             if not state.has_fp16_weights and state.CxB is None:
                 state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
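
For context, here is a minimal sketch (not code from this commit) of what the comment above describes: recovering 16-bit values for the outlier dimensions from the 8-bit row-major weight. The names CB, SCB, and idx follow the diff; the shapes and the symmetric absmax/127 dequantization convention are assumptions.

import torch

CB = torch.randint(-127, 128, (16, 32), dtype=torch.int8)  # int8 row-major weight
SCB = torch.rand(16, dtype=torch.float16) * 3 + 0.5        # per-row absmax scales (assumed)
idx = torch.tensor([2, 7])                                 # outlier dimensions

# dequantize only the outlier columns back to fp16, mirroring
# state.subB = B[:, idx].t().contiguous() from the forward pass
subB = (CB[:, idx].half() * SCB.unsqueeze(1) / 127).t().contiguous()
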
@@ -355,19 +356,24 @@ class MatMul8bitLt(torch.autograd.Function):
         if req_gradA:
             C32grad, Sgrad = F.transform(Cgrad, "col32")
-            if state.CxBt is None and state.has_fp16_weights:
-                CBt = state.CBt
-            elif state.CxBt is None:
-                assert state.CBt is None
-                CB = state.CB.half()
-                SCB = state.SCB.unsqueeze(1).half()
-                SCBt = state.SCBt.unsqueeze(1).half()
-                Bt = (CB * SCB).t().contiguous()
-                CBt = (Bt / SCBt).t().to(torch.int8)
+            if state.CxBt is None:
+                if state.has_fp16_weights:
+                    CBt = state.CBt
+                else:
+                    # Restore CBt from CB
+                    assert state.CBt is None, "CBt should not be stored in state"
+                    CB = state.CB.half()
+                    SCB = state.SCB.unsqueeze(1).half()
+                    SCBt = state.SCBt.unsqueeze(1).half()
+                    Bt = (CB * SCB).t().contiguous()
+                    CBt = (Bt / SCBt).t().to(torch.int8)
-            CxBt, SBt = F.transform(
-                CBt, to_order=formatB, transpose=True
-            )
+                # intentionally do not store CxBt in state
+                CxBt, SBt = F.transform(
+                    CBt, to_order=formatB, transpose=True
+                )
+            else:
+                CxBt, SBt = state.CxBt, state.SBt
             gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt)
             grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
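
As a sanity check on the reconstruction above, a small self-contained sketch (assumed absmax/127 quantization convention, not code from this commit): CB * SCB is proportional to B, the 127 normalization factors in SCB and SCBt cancel, and dividing the transpose by SCBt therefore reproduces the column-wise quantization of B.t() up to rounding error.

import torch

torch.manual_seed(0)
B = torch.randn(8, 16)                                        # stand-in weight

# row-wise symmetric quantization of B and of B.t() (assumed convention)
SCB = B.abs().amax(dim=1)                                     # per-row scales of B
CB = (B * 127 / SCB.unsqueeze(1)).round().to(torch.int8)
SCBt = B.t().abs().amax(dim=1)                                # per-row scales of B.t()

# the reconstruction from the diff: CB * SCB ~= 127 * B, so the 127s cancel
Bt = (CB.half() * SCB.unsqueeze(1).half()).t().contiguous()
CBt = (Bt / SCBt.unsqueeze(1).half()).t().to(torch.int8)

# compare against quantizing B.t() directly
CBt_direct = (B.t() * 127 / SCBt.unsqueeze(1)).to(torch.int8)
print((CBt.t().int() - CBt_direct.int()).abs().max().item())  # small, typically 0 or 1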