Added pre/post call to all lib calls. Fixes #120

This commit is contained in:
Tim Dettmers 2023-04-11 09:36:56 -07:00
parent 29ab3a6b14
commit 2bb5c00ba9

View File

@ -770,6 +770,8 @@ def optimizer_update_32bit(
f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
)
prev_device = pre_call(g.device)
is_on_gpu([g, p, state1, state2, unorm_vec])
if g.dtype == torch.float32 and state1.dtype == torch.float32:
str2optimizer32bit[optimizer_name][0](
get_ptr(g),
@ -812,6 +814,7 @@ def optimizer_update_32bit(
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
)
post_call(prev_device)
def optimizer_update_8bit(
@ -890,6 +893,8 @@ def optimizer_update_8bit(
if max_unorm > 0.0:
param_norm = torch.norm(p.data.float())
prev_device = pre_call(g.device)
is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2])
if g.dtype == torch.float32 and state1.dtype == torch.uint8:
str2optimizer8bit[optimizer_name][0](
get_ptr(p),
@ -942,6 +947,7 @@ def optimizer_update_8bit(
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
)
post_call(prev_device)
def optimizer_update_8bit_blockwise(
@ -964,6 +970,8 @@ def optimizer_update_8bit_blockwise(
skip_zeros=False,
) -> None:
prev_device = pre_call(g.device)
is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
if g.dtype == torch.float32 and state1.dtype == torch.uint8:
str2optimizer8bit_blockwise[optimizer_name][0](
get_ptr(p),
@ -1008,6 +1016,7 @@ def optimizer_update_8bit_blockwise(
raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
)
post_call(prev_device)
def percentile_clipping(
@ -1023,6 +1032,7 @@ def percentile_clipping(
The current optimiation steps (number of past gradient norms).
"""
prev_device = pre_call(grad.device)
is_on_gpu([grad, gnorm_vec])
if grad.dtype == torch.float32:
lib.cpercentile_clipping_g32(
@ -1040,6 +1050,7 @@ def percentile_clipping(
)
else:
raise ValueError(f"Gradient type {grad.dtype} not supported!")
post_call(prev_device)
current_gnorm = torch.sqrt(gnorm_vec[step % 100])
vals, idx = torch.sort(gnorm_vec)
@ -1796,6 +1807,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
(cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
)
nnz = cooA.nnz
prev_device = pre_call(B.device)
assert cooA.rowidx.numel() == nnz
assert cooA.colidx.numel() == nnz
assert cooA.values.numel() == nnz
@ -1872,6 +1884,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
ccolsB,
)
# else: assertion error
post_call(prev_device)
return out