refactoring

This commit is contained in:
dbaranchuk 2022-08-23 23:51:00 +03:00
parent 8ae9bb23ad
commit 1753aa0418

View File

@ -245,10 +245,11 @@ class MatMul8bitLt(torch.autograd.Function):
subA = A[:, idx] subA = A[:, idx]
state.subB = B[:, idx].t().contiguous() state.subB = B[:, idx].t().contiguous()
state.idx = idx state.idx = idx
elif state.CxB is None: else:
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions if state.CxB is None:
# we also need to convert it to the turing/ampere format # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
state.CxB, state.SB = F.transform(state.CB, to_order=formatB) # we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
else: else:
if not state.has_fp16_weights and state.CxB is None: if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB) state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
@ -355,19 +356,24 @@ class MatMul8bitLt(torch.autograd.Function):
if req_gradA: if req_gradA:
C32grad, Sgrad = F.transform(Cgrad, "col32") C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None and state.has_fp16_weights: if state.CxBt is None:
CBt = state.CBt if state.has_fp16_weights:
elif state.CxBt is None: CBt = state.CBt
assert state.CBt is None else:
CB = state.CB.half() # Restore CBt from CB
SCB = state.SCB.unsquezee(1).half() assert state.CBt is None, "CBt should not be stored in state"
SCBt = state.SCBt.unsquezee(1).half() CB = state.CB.half()
Bt = (CB * SCB).t().contiguous() SCB = state.SCB.unsquezee(1).half()
CBt = (Bt / SCBt).t().to(torch.int8) SCBt = state.SCBt.unsquezee(1).half()
Bt = (CB * SCB).t().contiguous()
CBt = (Bt / SCBt).t().to(torch.int8)
CxBt, SBt = F.transform( # intentionally, do not store CxBt into state
CBt, to_order=formatB, transpose=True CxBt, SBt = F.transform(
) CBt, to_order=formatB, transpose=True
)
else:
CxBt = state.CxBt
gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt) gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape) grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)