This commit is contained in:
dbaranchuk 2022-09-11 21:41:46 +03:00
parent 4dd475ced4
commit e2a75769f2

View File

@ -388,7 +388,7 @@ class MatMul8bitLt(torch.autograd.Function):
grad_bias = grad_output.sum(0)
# Cast grad_A back to grad_output_dtype
grad_output.to(grad_output_dtype)
grad_output = grad_output.to(grad_output_dtype)
return grad_A, grad_B, None, grad_bias, None