This commit is contained in:
Tim Dettmers 2022-09-20 06:36:32 +03:00 committed by justheuristic
parent cff3a71599
commit 9b7d307b8c

View File

@ -381,7 +381,7 @@ class MatMul8bitLt(torch.autograd.Function):
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).div(127.0))
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
else:
raise Exception('State must contain either CBt or CB matrix for backward')