diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index aef6971..6278db9 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -458,6 +458,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra """ + prev_device = pre_call(A.device) if code is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -479,6 +480,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra is_on_gpu([code, A, absmax, out, rand]) cblocksize = ct.c_int32(blocksize) if rand is not None: + is_on_gpu([code, A, out, absmax, rand]) assert blocksize==4096 assert rand.numel() >= 1024 rand_offset = random.randint(0, 1023) @@ -489,6 +491,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") else: + is_on_gpu([code, A, out, absmax]) if A.dtype == torch.float32: lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) elif A.dtype == torch.float16: @@ -499,6 +502,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra # cpu assert rand is None lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) + post_call(A.device) return out, (absmax, code) @@ -537,6 +541,7 @@ def dequantize_blockwise( Dequantized tensor (default: float32) """ assert quant_state is not None or absmax is not None + device = pre_call(A.device) if code is None and quant_state is None: if "dynamic" not in name2qmap: name2qmap["dynamic"] = create_dynamic_map().to(A.device) @@ -561,6 +566,7 @@ def dequantize_blockwise( raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") else: lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) + post_call(A.device) return out