Merge pull request #213 from tonylins/dev/fix_no_absmax
Gix a bug in (de)quantize_no_absmax with multiple GPUs
This commit is contained in:
commit
c7875533ce
|
@ -656,9 +656,11 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
|
|||
torch.Tensor:
|
||||
Quantized 8-bit tensor.
|
||||
'''
|
||||
prev_device = pre_call(A.device)
|
||||
if out is None: out = torch.zeros_like(A, dtype=torch.uint8)
|
||||
is_on_gpu([A, out])
|
||||
lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
|
||||
post_call(prev_device)
|
||||
return out
|
||||
|
||||
|
||||
|
@ -683,9 +685,11 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
|
|||
torch.Tensor:
|
||||
32-bit output tensor.
|
||||
'''
|
||||
prev_device = pre_call(A.device)
|
||||
if out is None: out = torch.zeros_like(A, dtype=torch.float32)
|
||||
is_on_gpu([code, A, out])
|
||||
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
|
||||
post_call(prev_device)
|
||||
return out
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user