review
This commit is contained in:
parent
cff3a71599
commit
9b7d307b8c
|
@ -381,7 +381,7 @@ class MatMul8bitLt(torch.autograd.Function):
|
|||
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
|
||||
|
||||
elif state.CB is not None:
|
||||
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).div(127.0))
|
||||
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0))
|
||||
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
|
||||
else:
|
||||
raise Exception('State must contain either CBt or CB matrix for backward')
|
||||
|
|
Loading…
Reference in New Issue
Block a user