Added pre/post call to all lib calls. Fixes #120
parent 29ab3a6b14
commit 2bb5c00ba9
@@ -770,6 +770,8 @@ def optimizer_update_32bit(
            f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
        )

    prev_device = pre_call(g.device)
    is_on_gpu([g, p, state1, state2, unorm_vec])
    if g.dtype == torch.float32 and state1.dtype == torch.float32:
        str2optimizer32bit[optimizer_name][0](
            get_ptr(g),
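Note: the str2optimizer32bit[optimizer_name][0] lookup above is a dispatch table keyed by optimizer name, indexed by gradient dtype. A toy sketch of that pattern, with hypothetical names and stub kernels rather than the real C entry points:

# Hypothetical dispatch table: optimizer name -> (fp32 kernel, fp16 kernel).
# The real tables in bitsandbytes point at functions exported by the compiled library.
def _adam32bit_grad_fp32(*args): ...
def _adam32bit_grad_fp16(*args): ...

example_str2optimizer32bit = {
    "adam": (_adam32bit_grad_fp32, _adam32bit_grad_fp16),
}

kernel = example_str2optimizer32bit["adam"][0]  # index 0 selects the float32 variant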
@@ -812,6 +814,7 @@ def optimizer_update_32bit(
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )
    post_call(prev_device)


def optimizer_update_8bit(
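For context on the pattern added throughout this commit: pre_call/post_call bracket each native library call as a device guard, so the kernel launches on the tensor's GPU and the caller's active device is restored afterwards. A minimal sketch of such a pair, assuming it only needs to save and switch the current CUDA device (the actual bitsandbytes helpers may do more):

import torch

def pre_call(device):
    # Remember which device was active, then make `device` current
    # so the subsequent C/CUDA call runs on the right GPU.
    prev_device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    return prev_device

def post_call(prev_device):
    # Restore the device the caller had active before the lib call.
    torch.cuda.set_device(prev_device)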
@@ -890,6 +893,8 @@ def optimizer_update_8bit(
    if max_unorm > 0.0:
        param_norm = torch.norm(p.data.float())

    prev_device = pre_call(g.device)
    is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2])
    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
        str2optimizer8bit[optimizer_name][0](
            get_ptr(p),
@@ -942,6 +947,7 @@ def optimizer_update_8bit(
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )
    post_call(prev_device)


def optimizer_update_8bit_blockwise(
@@ -964,6 +970,8 @@ def optimizer_update_8bit_blockwise(
    skip_zeros=False,
) -> None:

    prev_device = pre_call(g.device)
    is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
    if g.dtype == torch.float32 and state1.dtype == torch.uint8:
        str2optimizer8bit_blockwise[optimizer_name][0](
            get_ptr(p),
@@ -1008,6 +1016,7 @@ def optimizer_update_8bit_blockwise(
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )
    post_call(prev_device)


def percentile_clipping(
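A possible alternative to repeating pre_call/post_call in every function would be a context manager that restores the device even on an early return or exception; a sketch of that idea (not what this commit does):

import contextlib
import torch

@contextlib.contextmanager
def cuda_device_guard(device):
    # `with`-block equivalent of pre_call/post_call: switch to `device`,
    # run the wrapped native call, then always switch back.
    prev_device = torch.cuda.current_device()
    torch.cuda.set_device(device)
    try:
        yield
    finally:
        torch.cuda.set_device(prev_device)

# Hypothetical usage around a lib call:
#     with cuda_device_guard(g.device):
#         str2optimizer8bit_blockwise[optimizer_name][0](...)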
@@ -1023,6 +1032,7 @@ def percentile_clipping(
        The current optimization steps (number of past gradient norms).

    """
    prev_device = pre_call(grad.device)
    is_on_gpu([grad, gnorm_vec])
    if grad.dtype == torch.float32:
        lib.cpercentile_clipping_g32(
@@ -1040,6 +1050,7 @@ def percentile_clipping(
        )
    else:
        raise ValueError(f"Gradient type {grad.dtype} not supported!")
    post_call(prev_device)

    current_gnorm = torch.sqrt(gnorm_vec[step % 100])
    vals, idx = torch.sort(gnorm_vec)
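For reference, percentile_clipping keeps a ring buffer of the last 100 squared gradient norms (gnorm_vec, updated in place by the CUDA kernel) and derives a clipping scale from a chosen percentile. A rough pure-PyTorch sketch of that idea, with the buffer update done in Python for illustration:

import torch

def percentile_clip_sketch(gnorm_vec, grad, step, percentile=5):
    # gnorm_vec: tensor of 100 squared L2 norms of past gradients (ring buffer).
    gnorm_vec[step % 100] = grad.float().pow(2).sum()
    current_gnorm = torch.sqrt(gnorm_vec[step % 100])
    vals, _ = torch.sort(gnorm_vec)
    clip_value = torch.sqrt(vals[percentile])  # percentile-th smallest past norm
    gnorm_scale = torch.clamp(clip_value / current_gnorm, max=1.0)
    return current_gnorm, clip_value, gnorm_scale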
@@ -1796,6 +1807,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
            (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
        )
    nnz = cooA.nnz
    prev_device = pre_call(B.device)
    assert cooA.rowidx.numel() == nnz
    assert cooA.colidx.numel() == nnz
    assert cooA.values.numel() == nnz
@@ -1872,6 +1884,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
            ccolsB,
        )
    # else: assertion error
    post_call(prev_device)

    return out
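As a plain-PyTorch reference for what the sparse kernel above computes, spmm_coo_very_sparse multiplies a COO matrix (rowidx, colidx, values) by a dense matrix B; a dense-output sketch of that product (field names follow the cooA attributes used above, everything else is illustrative):

import torch

def spmm_coo_reference(rowidx, colidx, values, B, rows):
    # out[rowidx[i], :] += values[i] * B[colidx[i], :] for every nonzero i.
    out = torch.zeros((rows, B.shape[1]), device=B.device, dtype=values.dtype)
    out.index_add_(0, rowidx, values.unsqueeze(1) * B[colidx].to(values.dtype))
    return out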