2022-08-01 10:31:48 +00:00
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
|
|
#
|
|
|
|
# This source code is licensed under the MIT license found in the
|
2021-10-06 02:16:20 +00:00
|
|
|
# LICENSE file in the root directory of this source tree.
|
2022-06-30 15:14:20 +00:00
|
|
|
import ctypes as ct
|
2022-08-08 16:13:22 +00:00
|
|
|
import operator
|
2021-10-06 02:16:20 +00:00
|
|
|
import random
|
|
|
|
import torch
|
2022-11-04 02:49:50 +00:00
|
|
|
import itertools
|
2022-08-03 18:54:01 +00:00
|
|
|
|
|
|
|
from typing import Tuple
|
2021-10-06 02:16:20 +00:00
|
|
|
from torch import Tensor
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
from .cextension import COMPILED_WITH_CUDA, lib
|
2022-08-08 16:13:22 +00:00
|
|
|
from functools import reduce # Required in Python 3
|
|
|
|
|
|
|
|
# math.prod is not available before Python 3.8, so provide a local equivalent
|
|
|
|
def prod(iterable):
    """Return the product of all items in ``iterable`` (1 for an empty iterable).

    Local stand-in for ``math.prod``, which only exists on Python 3.8+.
    """
    result = 1
    for factor in iterable:
        result = result * factor
    return result
|
2022-07-01 14:16:10 +00:00
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
# Cache of named quantization maps (e.g. "dynamic"), filled lazily by the
# quantize/dequantize helpers in this module.
name2qmap = {}

if COMPILED_WITH_CUDA:
    # C functions for the optimizers. Each entry maps an optimizer name to a
    # (fp32-gradient kernel, fp16-gradient kernel) pair exported by `lib`.
    # "lars" reuses the momentum kernels and "lamb" reuses the Adam kernels.
    str2optimizer32bit = {
        "adam": (lib.cadam32bit_g32, lib.cadam32bit_g16),
        "momentum": (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16),
        "rmsprop": (lib.crmsprop32bit_g32, lib.crmsprop32bit_g16),
        "adagrad": (lib.cadagrad32bit_g32, lib.cadagrad32bit_g16),
        "lars": (lib.cmomentum32bit_g32, lib.cmomentum32bit_g16),
        "lamb": (lib.cadam32bit_g32, lib.cadam32bit_g16),
    }

    # 8-bit state kernels using a single (static) scaling per tensor.
    str2optimizer8bit = {
        "adam": (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16),
        "momentum": (
            lib.cmomentum_static_8bit_g32,
            lib.cmomentum_static_8bit_g16,
        ),
        "rmsprop": (
            lib.crmsprop_static_8bit_g32,
            lib.crmsprop_static_8bit_g16,
        ),
        "lamb": (lib.cadam_static_8bit_g32, lib.cadam_static_8bit_g16),
        "lars": (
            lib.cmomentum_static_8bit_g32,
            lib.cmomentum_static_8bit_g16,
        ),
    }

    # 8-bit state kernels using blockwise scaling (fp32 / fp16 gradients).
    str2optimizer8bit_blockwise = {
        "adam": (
            lib.cadam_8bit_blockwise_fp32,
            lib.cadam_8bit_blockwise_fp16,
        ),
        "momentum": (
            lib.cmomentum_8bit_blockwise_fp32,
            lib.cmomentum_8bit_blockwise_fp16,
        ),
        "rmsprop": (
            lib.crmsprop_8bit_blockwise_fp32,
            lib.crmsprop_8bit_blockwise_fp16,
        ),
        "adagrad": (
            lib.cadagrad_8bit_blockwise_fp32,
            lib.cadagrad_8bit_blockwise_fp16,
        ),
    }
|
2021-10-06 02:16:20 +00:00
|
|
|
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
class CUBLAS_Context(object):
    """Process-wide singleton holding one cuBLAS context handle per CUDA device.

    Handles are created lazily by :meth:`get_context` and cached by device
    index. Obtain the object through :meth:`get_instance`; calling the
    constructor directly raises ``RuntimeError``.
    """

    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        # Maps device index -> ct.c_void_p handle returned by lib.get_context().
        self.context = {}

    @classmethod
    def get_instance(cls):
        """Return the singleton, creating and initializing it on first use."""
        if cls._instance is None:
            # Bypass __init__, which deliberately raises.
            cls._instance = cls.__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def get_context(self, device):
        """Return the cached handle for ``device``, creating it on first request.

        The handle is created while ``device`` is the current CUDA device. The
        previously current device is restored afterwards, even if creating the
        handle fails.
        """
        if device.index not in self.context:
            prev_device = torch.cuda.current_device()
            try:
                torch.cuda.set_device(device)
                self.context[device.index] = ct.c_void_p(lib.get_context())
            finally:
                # Without this, an exception would leave the caller on `device`.
                torch.cuda.set_device(prev_device)
        return self.context[device.index]
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
class Cusparse_Context(object):
    """Lazily created, process-wide holder for the handle from ``lib.get_cusparse()``.

    Use :meth:`get_instance`; direct construction raises ``RuntimeError``.
    """

    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        handle = lib.get_cusparse()
        self.context = ct.c_void_p(handle)

    @classmethod
    def get_instance(cls):
        """Return the shared instance, building it on the first call."""
        if cls._instance is not None:
            return cls._instance
        # __new__ skips the raising __init__; the instance is registered
        # before initialize() runs.
        instance = cls.__new__(cls)
        cls._instance = instance
        instance.initialize()
        return instance
|
2021-10-06 02:16:20 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
def create_linear_map(signed=True):
    """Return 256 evenly spaced values on [-1, 1] (signed) or [0, 1] (unsigned)."""
    lower = -1.0 if signed else 0.0
    return torch.linspace(lower, 1.0, 256)
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-11-04 02:49:50 +00:00
|
|
|
def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2):
    """Build a 256-entry, sorted, max-normalized 8-bit float code table.

    Every value ``±2**k * (1 + mantissa)`` is enumerated, where ``k`` ranges
    over ``[-2**(exponent_bits-1), 2**(exponent_bits-1))`` and the mantissa is
    every ``precision_bits``-bit binary fraction. The values are sorted and
    divided by the largest one, and entry 127 is then set to 0.

    ``exponent_bits + precision_bits`` must equal 7 (asserted).
    NOTE(review): ``signed`` is currently unused; the table is always symmetric.
    """
    assert exponent_bits + precision_bits == 7
    half_range = 2 ** (exponent_bits - 1)
    exponent_values = [2 ** k for k in range(-half_range, half_range)]

    mantissa_values = []
    for bits in itertools.product([0, 1], repeat=precision_bits):
        mantissa = 1
        for pos, bit in enumerate(bits):
            mantissa += bit * (2 ** -(pos + 1))
        mantissa_values.append(mantissa)

    assert len(exponent_values) * len(mantissa_values) == 128

    magnitudes = [ev * pv for ev in exponent_values for pv in mantissa_values]
    values = []
    for magnitude in magnitudes:
        values.append(-magnitude)
        values.append(magnitude)
    values.sort()

    code = torch.Tensor(values)
    code /= code.max()
    code[127] = 0
    return code
|
|
|
|
|
|
|
|
|
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
def create_dynamic_map(signed=True, n=7):
    """
    Creates the dynamic quantization map.

    The dynamic data type is made up of a dynamic exponent and
    fraction. As the exponent increases from 0 to -7 the number
    of bits available for the fraction shrinks.

    This is a generalization of the dynamic type where a certain
    number of the bits can be reserved for the linear quantization
    region (the fraction). n determines the maximum number of
    exponent bits.

    For more details see
    (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561]
    """
    values = []
    # Extra entries for the case where all exponent bits are zero and no
    # indicator bit is present.
    extra = 2 ** (7 - n) - 1
    if not signed:
        extra = 2 * extra

    for exp_idx in range(n):
        if signed:
            fraction_items = 2 ** (exp_idx + 7 - n) + 1
        else:
            fraction_items = 2 ** (exp_idx + 7 - n + 1) + 1
        edges = torch.linspace(0.1, 1, fraction_items)
        midpoints = (edges[:-1] + edges[1:]) / 2.0
        scale = 10 ** (-(n - 1) + exp_idx)
        values += (scale * midpoints).tolist()
        if signed:
            values += (-scale * midpoints).tolist()

    if extra > 0:
        edges = torch.linspace(0.1, 1, extra + 1)
        midpoints = (edges[:-1] + edges[1:]) / 2.0
        # Uses the exponent of the last loop iteration (exp_idx == n - 1).
        scale = 10 ** (-(n - 1) + exp_idx)
        values += (scale * midpoints).tolist()
        if signed:
            values += (-scale * midpoints).tolist()

    values.append(0)
    values.append(1.0)
    values.sort()
    return Tensor(values)
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
def get_special_format_str():
    """Return the tiled layout name for the current GPU.

    Returns "col_ampere" for compute capability 8.x and "col_turing" for
    everything else, including when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return 'col_turing'
    major, _minor = torch.cuda.get_device_capability()
    return "col_ampere" if major == 8 else "col_turing"
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
|
2022-08-03 16:05:37 +00:00
|
|
|
|
|
|
|
def is_on_gpu(tensors):
    """Return True if every non-None tensor in ``tensors`` is on a CUDA device.

    ``None`` entries are skipped (they are passed to the native kernels as NULL
    pointers), so an empty or all-None sequence yields True.
    """
    checks = [t.device.type == 'cuda' for t in tensors if t is not None]
    return all(checks)
|
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
def get_ptr(A: Tensor) -> ct.c_void_p:
    """
    Get the ctypes pointer from a PyTorch Tensor.

    Parameters
    ----------
    A : torch.tensor
        The PyTorch tensor, or None.

    Returns
    -------
    ctypes.c_void_p
        Address of the tensor's data (``A.data.data_ptr()``), or None when
        ``A`` is None so that optional kernel arguments become NULL pointers.
    """
    return None if A is None else ct.c_void_p(A.data.data_ptr())
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2021-10-06 02:16:20 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
def pre_call(device):
    """Make ``device`` the current CUDA device and return the one it replaced.

    The returned value can be handed to ``post_call`` to switch back.
    """
    previous = torch.cuda.current_device()
    torch.cuda.set_device(device)
    return previous
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
def post_call(prev_device):
    """Make ``prev_device`` (e.g. the value returned by ``pre_call``) current again."""
    torch.cuda.set_device(prev_device)
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
def get_transform_func(dtype, orderA, orderOut, transpose=False):
    """Look up the native layout-transform function for a conversion.

    The symbol looked up on ``lib`` is
    ``ctransform_{bits}_{orderA}_to_{orderOut}_{t|n}``. ``bits`` is 8 for
    ``torch.int8`` and 32 for every other dtype. The suffix is ``t`` when
    ``transpose`` is set and ``n`` otherwise.

    Raises
    ------
    ValueError
        If ``lib`` does not export a function with that name.
    """
    bits = 8 if dtype == torch.int8 else 32
    suffix = "t" if transpose else "n"
    name = f"ctransform_{bits}_{orderA}_to_{orderOut}_{suffix}"
    func = getattr(lib, name, None)
    if func is None:
        # Report the missing symbol in the exception rather than printing it
        # to stdout (the previous debug print).
        raise ValueError(
            f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose} (missing symbol: {name})"
        )
    return func
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
def get_transform_buffer(
    shape, dtype, device, to_order, from_order="row", transpose=False
):
    """Allocate a zero-filled output buffer for a layout transform.

    Parameters
    ----------
    shape : sequence of int
        Logical shape of the tensor being transformed. All leading dimensions
        are flattened into rows, and the last dimension gives the columns.
    dtype, device :
        Forwarded to ``torch.zeros``.
    to_order : str
        Target layout: "row", "col", "col32", "col_turing" or "col_ampere".
    from_order : str
        Source layout. Unused here; kept for API compatibility.
    transpose : bool
        Swap rows and columns before padding. The returned state then records
        the reversed shape.

    Returns
    -------
    (torch.Tensor, tuple)
        The buffer and a ``(shape, to_order)`` state tuple. For "row" and "col"
        the buffer has exactly ``shape``. The tiled layouts pad to multiples of
        32 columns, and "col_turing" / "col_ampere" also pad rows to multiples
        of 8 / 32.

    Raises
    ------
    NotImplementedError
        For an unknown ``to_order``.
    """
    init_func = torch.zeros

    # Flatten every leading dimension into rows. Previously only 2-D and 3-D
    # shapes bound `rows`; other ranks failed with UnboundLocalError.
    rows = 1
    for dim in shape[:-1]:
        rows *= dim
    cols = shape[-1]

    state = (shape, to_order)
    if transpose:
        rows, cols = cols, rows
        state = (shape[::-1], to_order)

    if to_order == "row" or to_order == "col":
        return init_func(shape, dtype=dtype, device=device), state
    elif to_order == "col32":
        # blocks of 32 columns (padded)
        cols = 32 * ((cols + 31) // 32)
        return init_func((rows, cols), dtype=dtype, device=device), state
    elif to_order == "col_turing":
        # blocks of 32 columns and 8 rows
        cols = 32 * ((cols + 31) // 32)
        rows = 8 * ((rows + 7) // 8)
        return init_func((rows, cols), dtype=dtype, device=device), state
    elif to_order == "col_ampere":
        # blocks of 32 columns and 32 rows
        cols = 32 * ((cols + 31) // 32)
        rows = 32 * ((rows + 31) // 32)
        return init_func((rows, cols), dtype=dtype, device=device), state
    else:
        raise NotImplementedError(f"To_order not supported: {to_order}")
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
def nvidia_transform(
    A,
    to_order,
    from_order="row",
    out=None,
    transpose=False,
    state=None,
    ld=None,
):
    """Convert ``A`` between memory layouts with a native ``ctransform_*`` function.

    Parameters
    ----------
    A : torch.Tensor
        Input tensor, laid out as ``from_order`` (or ``state[1]`` if ``state``
        is given).
    to_order : str
        Target layout, e.g. "row", "col32", "col_turing", "col_ampere".
    from_order : str
        Source layout. Ignored when ``state`` is given.
    out : torch.Tensor, optional
        Preallocated output. Allocated via ``get_transform_buffer`` when None.
    transpose : bool
        Use the transposing kernel variant.
    state : tuple, optional
        ``(shape, order)`` describing ``A``. Defaults to ``(A.shape, from_order)``.
    ld : sequence of int, optional
        For shapes that are not 2-D: indices of the dimensions whose product
        forms the first kernel dimension. The second kernel dimension is the
        remaining element count.

    Returns
    -------
    (torch.Tensor, tuple)
        The transformed tensor and its ``(shape, to_order)`` state.
    """
    if state is None:
        state = (A.shape, from_order)
    else:
        from_order = state[1]
    if out is None:
        # Forward `transpose` so the buffer shape and state match the
        # transposed output the kernel writes.
        out, new_state = get_transform_buffer(
            state[0], A.dtype, A.device, to_order, state[1], transpose
        )
    else:
        # State tuples are (shape, order). The source *order* used to be
        # stored in the shape slot here.
        new_state = (state[0], to_order)
    func = get_transform_func(A.dtype, from_order, to_order, transpose)

    shape = state[0]
    if len(shape) == 2:
        dim1 = ct.c_int32(shape[0])
        dim2 = ct.c_int32(shape[1])
    elif ld is not None:
        n = prod(shape)
        dim1 = prod([shape[i] for i in ld])
        dim2 = ct.c_int32(n // dim1)
        dim1 = ct.c_int32(dim1)
    else:
        dim1 = ct.c_int32(shape[0] * shape[1])
        dim2 = ct.c_int32(shape[2])

    ptr = CUBLAS_Context.get_instance().get_context(A.device)
    func(ptr, get_ptr(A), get_ptr(out), dim1, dim2)

    return out, new_state
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
def estimate_quantiles(
    A: Tensor, out: Tensor = None, offset: float = 1 / 512
) -> Tensor:
    '''
    Estimates 256 equidistant quantiles on the input tensor eCDF.

    Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
    via the eCDF of the input tensor `A`. This is a fast but approximate algorithm
    and the extreme quantiles close to 0 and 1 have high variance / large estimation
    errors. These large errors can be avoided by using the offset variable which trims
    the distribution. The default offset value of 1/512 ensures minimum entropy encoding -- it
    trims 1/512 = 0.2% from each side of the distribution. An offset value of 0.01 to 0.02
    usually has a much lower error but is not a minimum entropy encoding. Given an offset
    of 0.02, equidistant points in the range [0.02, 0.98] are used for the quantiles.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor. Any shape. Must be a CUDA tensor.
    out : torch.Tensor
        Tensor with the 256 estimated quantiles. Allocated as float32 on
        A's device when None.
    offset : float
        The offset for the first and last quantile from 0 and 1. Default: 1/512

    Returns
    -------
    torch.Tensor:
        The 256 quantiles in float32 datatype.

    Raises
    ------
    TypeError
        If A or out is not on a CUDA device.
    NotImplementedError
        If A is neither float32 nor float16.
    '''
    if out is None:
        out = torch.zeros((256,), dtype=torch.float32, device=A.device)
    # The is_on_gpu result used to be discarded, letting CPU pointers reach
    # the device kernel; fail early instead.
    if not is_on_gpu([A, out]):
        raise TypeError(
            f"estimate_quantiles requires CUDA tensors, but got A on {A.device} and out on {out.device}"
        )
    if A.dtype == torch.float32:
        lib.cestimate_quantiles_fp32(
            get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
        )
    elif A.dtype == torch.float16:
        lib.cestimate_quantiles_fp16(
            get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())
        )
    else:
        raise NotImplementedError(f"Not supported data type {A.dtype}")
    return out
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-09-11 18:55:09 +00:00
|
|
|
def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, rand=None, out: Tensor = None, blocksize=4096) -> Tensor:
    """
    Quantize tensor A blockwise to 8 bits.

    Quantizes tensor A by dividing it into blocks of values. The absolute
    maximum within each block is computed for the non-linear quantization.
    The CUDA kernels take no block size argument, and an ``absmax`` allocated
    here for a CUDA tensor has one entry per 4096 values. ``blocksize`` is
    only honoured for CPU tensors.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor (float32 or float16 on CUDA; float32 on CPU).
    code : torch.Tensor
        The quantization map. Defaults to the dynamic map.
    absmax : torch.Tensor
        The absmax values. Allocated (zero-filled) when None.
    rand : torch.Tensor
        The tensor for stochastic rounding (CUDA only, at least 1024 values).
    out : torch.Tensor
        The output tensor (8-bit). Allocated as uint8 when None.
    blocksize : int
        Values per block for CPU tensors (default 4096).

    Returns
    -------
    torch.Tensor:
        The 8-bit tensor.
    tuple(torch.Tensor, torch.Tensor):
        The quantization state ``(absmax, code)`` to undo the quantization.

    Raises
    ------
    TypeError
        If A is a CUDA tensor but code, absmax, out or rand is not.
    ValueError
        If A has an unsupported dtype.
    """
    if code is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]
        code = code.to(A.device)

    if absmax is None:
        n = A.numel()
        blocksize = (blocksize if A.device.type == 'cpu' else 4096)
        blocks = n // blocksize
        blocks += 1 if n % blocksize > 0 else 0
        absmax = torch.zeros((blocks,), device=A.device)

    if out is None:
        out = torch.zeros_like(A, dtype=torch.uint8)

    if A.device.type != 'cpu':
        # The is_on_gpu result used to be ignored, letting host pointers reach
        # the CUDA kernels.
        if not is_on_gpu([code, A, absmax, out, rand]):
            devices = [None if t is None else t.device for t in (code, A, absmax, out, rand)]
            raise TypeError(
                f"quantize_blockwise: code, A, absmax, out and rand must all be CUDA tensors, but got devices {devices}"
            )
        if rand is not None:
            assert rand.numel() >= 1024
            rand_offset = random.randint(0, 1023)
            if A.dtype == torch.float32:
                lib.cquantize_blockwise_stochastic_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
            elif A.dtype == torch.float16:
                lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
            else:
                raise ValueError(
                    f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
                )
        else:
            if A.dtype == torch.float32:
                lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
            elif A.dtype == torch.float16:
                lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(A.numel()))
            else:
                raise ValueError(
                    f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
                )
    else:
        # cpu
        assert rand is None
        # The CPU kernel reads A as float32; any other dtype would be read
        # past the end of its buffer.
        if A.dtype != torch.float32:
            raise ValueError(
                f"Blockwise quantization on the CPU only supports 32-bit floats, but got {A.dtype}"
            )
        lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))

    return out, (absmax, code)
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
def dequantize_blockwise(
    A: Tensor,
    quant_state: Tuple[Tensor, Tensor] = None,
    absmax: Tensor = None,
    code: Tensor = None,
    out: Tensor = None,
    blocksize: int = 4096,
) -> Tensor:
    """
    Dequantizes blockwise quantized values.

    Dequantizes the 8-bit tensor A using the per-block maximum absolute
    values (``quant_state[0]`` or ``absmax``) and the quantization map
    (``quant_state[1]`` or ``code``).

    Parameters
    ----------
    A : torch.Tensor
        The input 8-bit tensor.
    quant_state : tuple(torch.Tensor, torch.Tensor)
        Tuple of absmax values and code, as returned by quantize_blockwise.
    absmax : torch.Tensor
        The absmax values. Used when quant_state is None.
    code : torch.Tensor
        The quantization map. Used when quant_state is None; defaults to the
        dynamic map.
    out : torch.Tensor
        Dequantized output tensor (default: float32, same shape as A).
    blocksize : int
        Values per absmax block. CUDA tensors support 2048 and 4096.

    Returns
    -------
    torch.Tensor:
        Dequantized tensor (default: float32)

    Raises
    ------
    ValueError
        For an unsupported CUDA block size or an unsupported ``out`` dtype.
    TypeError
        If A is a CUDA tensor but out is not.
    """
    assert quant_state is not None or absmax is not None
    if code is None and quant_state is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"]
        code = code.to(A.device)

    if out is None:
        out = torch.zeros_like(A, dtype=torch.float32)
    if quant_state is None:
        quant_state = (absmax, code)

    if A.device.type != 'cpu':
        if blocksize not in [2048, 4096]:
            raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048 4096]")
        # The is_on_gpu result used to be ignored.
        if not is_on_gpu([A, out]):
            raise TypeError(
                f"dequantize_blockwise: A and out must both be CUDA tensors, but got A on {A.device} and out on {out.device}"
            )
        if out.dtype == torch.float32:
            lib.cdequantize_blockwise_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
        elif out.dtype == torch.float16:
            lib.cdequantize_blockwise_fp16(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel()))
        else:
            # Report the dtype that was actually checked (A is always 8-bit).
            raise ValueError(
                f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}"
            )
    else:
        # The CPU kernel writes float32 values; a narrower `out` would overflow.
        if out.dtype != torch.float32:
            raise ValueError(
                f"Blockwise dequantization on the CPU only supports 32-bit float outputs, but got {out.dtype}"
            )
        lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))

    return out
|
|
|
|
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor:
    """
    Quantize A to 8 bits after scaling it by its global absolute maximum.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor.
    code : torch.Tensor, optional
        The quantization map. Defaults to the cached dynamic map.
    out : torch.Tensor, optional
        Output tensor for the 8-bit values.

    Returns
    -------
    torch.Tensor:
        The quantized tensor.
    tuple(torch.Tensor, torch.Tensor):
        ``(absmax, code)``, the state consumed by ``dequantize``.

    NOTE(review): an all-zero A gives absmax == 0, so A / absmax is NaN.
    """
    if code is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"].to(A.device)

    scale = A.abs().max()
    quantized = quantize_no_absmax(A / scale, code, out)
    return quantized, (scale, code)
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
def dequantize(
    A: Tensor,
    quant_state: Tuple[Tensor, Tensor] = None,
    absmax: Tensor = None,
    code: Tensor = None,
    out: Tensor = None,
) -> Tensor:
    """
    Undo ``quantize``: map the 8-bit values through the code and rescale.

    Either ``quant_state`` (``(absmax, code)``) or ``absmax`` must be given.
    ``code`` defaults to the cached dynamic map when neither it nor
    ``quant_state`` is provided.

    Returns a new tensor equal to the dequantized values times absmax. If
    ``out`` is given it receives the values before that final scaling.
    """
    assert quant_state is not None or absmax is not None
    if code is None and quant_state is None:
        if "dynamic" not in name2qmap:
            name2qmap["dynamic"] = create_dynamic_map().to(A.device)
        code = name2qmap["dynamic"].to(A.device)

    state = (absmax, code) if quant_state is None else quant_state
    unscaled = dequantize_no_absmax(A, state[1], out)
    return unscaled * state[0]
|
2021-10-06 02:16:20 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
    '''
    Quantizes input tensor to 8-bit.

    Quantizes the 32-bit input tensor `A` to the 8-bit output tensor
    `out` using the quantization map `code`.

    Parameters
    ----------
    A : torch.Tensor
        The input tensor.
    code : torch.Tensor
        The quantization map.
    out : torch.Tensor, optional
        The output tensor. Needs to be of type byte.

    Returns
    -------
    torch.Tensor:
        Quantized 8-bit tensor.

    Raises
    ------
    TypeError
        If code, A or out is not on a CUDA device.
    '''
    if out is None:
        out = torch.zeros_like(A, dtype=torch.uint8)
    # The is_on_gpu result used to be discarded; check `code` too, like
    # dequantize_no_absmax does.
    if not is_on_gpu([code, A, out]):
        raise TypeError(
            f"quantize_no_absmax requires CUDA tensors, but got code on {code.device}, A on {A.device} and out on {out.device}"
        )
    lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
    return out
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor:
    '''
    Dequantizes the 8-bit tensor to 32-bit.

    Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via
    the quantization map `code`.

    Parameters
    ----------
    A : torch.Tensor
        The 8-bit input tensor.
    code : torch.Tensor
        The quantization map.
    out : torch.Tensor
        The 32-bit output tensor. Allocated as float32 when None.

    Returns
    -------
    torch.Tensor:
        32-bit output tensor.

    Raises
    ------
    TypeError
        If code, A or out is not on a CUDA device.
    '''
    if out is None:
        out = torch.zeros_like(A, dtype=torch.float32)
    # The is_on_gpu result used to be discarded, letting host pointers reach
    # the native kernel.
    if not is_on_gpu([code, A, out]):
        raise TypeError(
            f"dequantize_no_absmax requires CUDA tensors, but got code on {code.device}, A on {A.device} and out on {out.device}"
        )
    lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()))
    return out
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
def optimizer_update_32bit(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    beta1: float,
    eps: float,
    step: int,
    lr: float,
    state2: Tensor = None,
    beta2: float = 0.0,
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    unorm_vec: Tensor = None,
    max_unorm: float = 0.0,
    skip_zeros=False,
) -> None:
    """
    Performs an inplace optimizer update with one or two optimizer states.

    Universal optimizer update for 32-bit state and 32/16-bit gradients/weights.
    The native kernel is chosen from ``str2optimizer32bit[optimizer_name]``:
    the first entry for float32 gradients, the second for float16 gradients.

    Parameters
    ----------
    optimizer_name : str
        Key of ``str2optimizer32bit`` (adam, momentum, rmsprop, adagrad, lars, lamb).
    g : torch.Tensor
        Gradient tensor (float32 or float16).
    p : torch.Tensor
        Parameter tensor.
    state1 : torch.Tensor
        Optimizer state 1 (must be float32).
    beta1 : float
        Optimizer beta1.
    eps : float
        Optimizer epsilon.
    step : int
        Current optimizer step.
    lr : float
        The learning rate.
    state2 : torch.Tensor
        Optimizer state 2 (None becomes a NULL pointer).
    beta2 : float
        Optimizer beta2.
    weight_decay : float
        Weight decay.
    gnorm_scale : float
        The factor to rescale the gradient to the max clip value.
    unorm_vec : torch.Tensor
        The tensor for the update norm.
    max_unorm : float
        The maximum update norm relative to the weight norm. When > 0, the
        norm of p (as float32) is computed and passed to the kernel.
    skip_zeros : bool
        Whether to skip zero-valued gradients or not (default: False).

    Raises
    ------
    NotImplementedError
        If optimizer_name is unknown.
    ValueError
        For an unsupported gradient/state dtype combination.
    """
    param_norm = torch.norm(p.data.float()) if max_unorm > 0.0 else 0.0

    if optimizer_name not in str2optimizer32bit:
        raise NotImplementedError(
            f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer32bit.keys())}'
        )

    if g.dtype == torch.float32 and state1.dtype == torch.float32:
        kernel = str2optimizer32bit[optimizer_name][0]
    elif g.dtype == torch.float16 and state1.dtype == torch.float32:
        kernel = str2optimizer32bit[optimizer_name][1]
    else:
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )

    kernel(
        get_ptr(g),
        get_ptr(p),
        get_ptr(state1),
        get_ptr(state2),
        get_ptr(unorm_vec),
        ct.c_float(max_unorm),
        ct.c_float(param_norm),
        ct.c_float(beta1),
        ct.c_float(beta2),
        ct.c_float(eps),
        ct.c_float(weight_decay),
        ct.c_int32(step),
        ct.c_float(lr),
        ct.c_float(gnorm_scale),
        ct.c_bool(skip_zeros),
        ct.c_int32(g.numel()),
    )
|
|
|
|
|
|
|
|
|
|
|
|
def optimizer_update_8bit(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    state2: Tensor,
    beta1: float,
    beta2: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: Tensor,
    qmap2: Tensor,
    max1: Tensor,
    max2: Tensor,
    new_max1: Tensor,
    new_max2: Tensor,
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    unorm_vec: Tensor = None,
    max_unorm: float = 0.0,
) -> None:
    """
    Performs an inplace Adam update.

    Universal Adam update for 32/8-bit state and 32/16-bit gradients/weights.
    Uses AdamW formulation if weight decay > 0.0.

    Parameters
    ----------
    optimizer_name : str
        The name of the optimizer; must be a key of ``str2optimizer8bit``.
    g : torch.Tensor
        Gradient tensor (float32 or float16).
    p : torch.Tensor
        Parameter tensor.
    state1 : torch.Tensor
        Adam state 1 (uint8).
    state2 : torch.Tensor
        Adam state 2.
    beta1 : float
        Adam beta1.
    beta2 : float
        Adam beta2.
    eps : float
        Adam epsilon.
    step : int
        Current optimizer step.
    lr : float
        The learning rate.
    qmap1 : torch.Tensor
        Quantization map for first Adam state.
    qmap2 : torch.Tensor
        Quantization map for second Adam state.
    max1 : torch.Tensor
        Max value for first Adam state update.
    max2 : torch.Tensor
        Max value for second Adam state update.
    new_max1 : torch.Tensor
        Max value for the next Adam update of the first state.
    new_max2 : torch.Tensor
        Max value for the next Adam update of the second state.
    weight_decay : float
        Weight decay.
    gnorm_scale : float
        The factor to rescale the gradient to the max clip value.
    unorm_vec : torch.Tensor
        The tensor for the update norm.
    max_unorm : float
        The maximum update norm relative to the weight norm. When > 0.0 the
        parameter norm is computed here and forwarded to the kernel.

    Raises
    ------
    NotImplementedError
        If ``optimizer_name`` is not registered in ``str2optimizer8bit``.
    ValueError
        If the gradient/state dtype combination is unsupported. Only float32 or
        float16 gradients with uint8 ``state1`` are accepted.
    """
    # Report unknown optimizers with the available choices (as the 32-bit
    # update path does) instead of a bare KeyError from the lookup below.
    if optimizer_name not in str2optimizer8bit:
        raise NotImplementedError(
            f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer8bit.keys())}'
        )

    param_norm = 0.0
    if max_unorm > 0.0:
        # The weight norm is only needed when the update norm is clipped.
        param_norm = torch.norm(p.data.float())

    if state1.dtype != torch.uint8 or g.dtype not in (torch.float32, torch.float16):
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )

    # Slot 0 holds the float32-gradient kernel, slot 1 the float16 one.
    kernel = str2optimizer8bit[optimizer_name][0 if g.dtype == torch.float32 else 1]
    # NOTE: the 8-bit kernels take the parameter pointer before the gradient.
    kernel(
        get_ptr(p),
        get_ptr(g),
        get_ptr(state1),
        get_ptr(state2),
        get_ptr(unorm_vec),
        ct.c_float(max_unorm),
        ct.c_float(param_norm),
        ct.c_float(beta1),
        ct.c_float(beta2),
        ct.c_float(eps),
        ct.c_int32(step),
        ct.c_float(lr),
        get_ptr(qmap1),
        get_ptr(qmap2),
        get_ptr(max1),
        get_ptr(max2),
        get_ptr(new_max1),
        get_ptr(new_max2),
        ct.c_float(weight_decay),
        ct.c_float(gnorm_scale),
        ct.c_int32(g.numel()),
    )
|
|
|
|
|
|
|
|
|
|
|
|
def optimizer_update_8bit_blockwise(
    optimizer_name: str,
    g: Tensor,
    p: Tensor,
    state1: Tensor,
    state2: Tensor,
    beta1: float,
    beta2: float,
    eps: float,
    step: int,
    lr: float,
    qmap1: Tensor,
    qmap2: Tensor,
    absmax1: Tensor,
    absmax2: Tensor,
    weight_decay: float = 0.0,
    gnorm_scale: float = 1.0,
    skip_zeros=False,
) -> None:
    """
    Performs an inplace optimizer update with blockwise-quantized 8-bit state.

    Dispatches to the kernel registered under ``optimizer_name`` in
    ``str2optimizer8bit_blockwise``. Slot 0 is used for float32 gradients and
    slot 1 for float16 gradients. ``state1`` must be uint8.

    Parameters
    ----------
    optimizer_name : str
        The name of the optimizer; must be a key of ``str2optimizer8bit_blockwise``.
    g : torch.Tensor
        Gradient tensor (float32 or float16).
    p : torch.Tensor
        Parameter tensor.
    state1 : torch.Tensor
        Optimizer state 1 (uint8).
    state2 : torch.Tensor
        Optimizer state 2.
    beta1 : float
        Beta1 coefficient.
    beta2 : float
        Beta2 coefficient.
    eps : float
        Epsilon.
    step : int
        Current optimizer step.
    lr : float
        The learning rate.
    qmap1 : torch.Tensor
        Quantization map for the first state.
    qmap2 : torch.Tensor
        Quantization map for the second state.
    absmax1 : torch.Tensor
        Scaling tensor for the first state, passed straight to the kernel
        (presumably per-block absolute maxima; confirm against the kernel).
    absmax2 : torch.Tensor
        Scaling tensor for the second state (same caveat as ``absmax1``).
    weight_decay : float
        Weight decay.
    gnorm_scale : float
        The factor to rescale the gradient to the max clip value.
    skip_zeros : bool
        Whether to skip zero-valued gradients or not (default: False).

    Raises
    ------
    NotImplementedError
        If ``optimizer_name`` is not registered in ``str2optimizer8bit_blockwise``.
    ValueError
        If the gradient/state dtype combination is unsupported.
    """
    # Report unknown optimizers with the available choices (as the 32-bit
    # update path does) instead of a bare KeyError from the lookup below.
    if optimizer_name not in str2optimizer8bit_blockwise:
        raise NotImplementedError(
            f'Optimizer not implemented: {optimizer_name}. Choices: {",".join(str2optimizer8bit_blockwise.keys())}'
        )

    if state1.dtype != torch.uint8 or g.dtype not in (torch.float32, torch.float16):
        raise ValueError(
            f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}"
        )

    # Slot 0 holds the float32-gradient kernel, slot 1 the float16 one.
    kernel = str2optimizer8bit_blockwise[optimizer_name][0 if g.dtype == torch.float32 else 1]
    kernel(
        get_ptr(p),
        get_ptr(g),
        get_ptr(state1),
        get_ptr(state2),
        ct.c_float(beta1),
        ct.c_float(beta2),
        ct.c_float(eps),
        ct.c_int32(step),
        ct.c_float(lr),
        get_ptr(qmap1),
        get_ptr(qmap2),
        get_ptr(absmax1),
        get_ptr(absmax2),
        ct.c_float(weight_decay),
        ct.c_float(gnorm_scale),
        ct.c_bool(skip_zeros),
        ct.c_int32(g.numel()),
    )
|
2021-10-06 02:16:20 +00:00
|
|
|
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
def percentile_clipping(
    grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5
):
    """Applies percentile clipping

    The CUDA kernel records the current gradient norm into ``gnorm_vec``. The
    clip value is then taken from the sorted history of recorded norms.

    grad: torch.Tensor
        The gradient tensor (float32 or float16).
    gnorm_vec: torch.Tensor
        Vector of gradient norms. 100 elements expected. The entries are
        square-rooted before use, so they appear to hold squared norms.
    step: int
        The current optimization step (number of past gradient norms).
    percentile: int
        Index into the ascending-sorted ``gnorm_vec`` used as the clip value.

    Returns ``(current_gnorm, clip_value, gnorm_scale)``. ``gnorm_scale`` is
    ``clip_value / current_gnorm`` when the current norm exceeds the clip
    value, and 1.0 otherwise.
    """
    is_on_gpu([grad, gnorm_vec])

    if grad.dtype == torch.float32:
        record_norm = lib.cpercentile_clipping_g32
    elif grad.dtype == torch.float16:
        record_norm = lib.cpercentile_clipping_g16
    else:
        raise ValueError(f"Gradient type {grad.dtype} not supported!")

    record_norm(
        get_ptr(grad),
        get_ptr(gnorm_vec),
        ct.c_int32(step),
        ct.c_int32(grad.numel()),
    )

    # The history is treated as a ring buffer of 100 entries.
    current_gnorm = torch.sqrt(gnorm_vec[step % 100])
    ordered_norms = torch.sort(gnorm_vec).values
    clip_value = torch.sqrt(ordered_norms[percentile])

    gnorm_scale = clip_value / current_gnorm if current_gnorm > clip_value else 1.0
    return current_gnorm, clip_value, gnorm_scale
|
|
|
|
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
def histogram_scatter_add_2d(
    histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor
):
    """Run ``lib.chistogram_scatter_add_2d`` on a 2-D float32 histogram.

    Judging by its name, the kernel adds each ``source`` value into
    ``histogram`` at the position given by ``(index1, index2)``. The kernel
    itself is not visible here.

    Parameters
    ----------
    histogram : torch.Tensor
        2-D float32 CUDA tensor that receives the values.
    index1, index2 : torch.Tensor
        int32 CUDA index tensors. ``index1.numel()`` is passed as the element count.
    source : torch.Tensor
        float32 CUDA tensor of values to add.
    """
    assert len(histogram.shape) == 2
    assert histogram.dtype == torch.float32
    assert source.dtype == torch.float32
    assert index1.dtype == torch.int32
    assert index2.dtype == torch.int32

    assert histogram.device.type == "cuda"
    assert index1.device.type == "cuda"
    assert index2.device.type == "cuda"
    assert source.device.type == "cuda"

    maxdim1 = ct.c_int32(histogram.shape[0])
    n = ct.c_int32(index1.numel())

    # Check the actual argument tensors (``index2``, not an undefined name).
    is_on_gpu([histogram, index1, index2, source])
    lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
    """Validate the operands of an (optionally transposed) matmul and return the output shape.

    Supported rank combinations are 2-D x 2-D, 3-D x 2-D and 3-D x 3-D. For
    3-D operands the leading dimension is carried through to the output.

    Parameters
    ----------
    A, B : torch.Tensor
        Operands; both must have dtype ``expected_type``.
    out : torch.Tensor or None
        Optional output buffer. When given, its shape is returned as-is (it is
        not checked against A and B). A 3-D x 3-D inner-dimension mismatch is
        then accepted for the ``bsi,bso->io`` contraction common in backprop.
    transposed_A, transposed_B : bool
        Whether A / B are treated as transposed (last two dims swapped).
    expected_type : torch.dtype
        Required dtype of A and B (default ``torch.int8``).

    Returns
    -------
    tuple or torch.Size
        The output shape.

    Raises
    ------
    TypeError
        If A or B does not have ``expected_type``.
    ValueError
        If the inner dimensions do not match, or (when ``out`` is None) the
        rank combination is not supported.
    """
    # Only initialise CUDA when it exists, so that this pure shape check
    # does not fail with an unrelated error on CPU-only builds.
    if torch.cuda.is_available() and not torch.cuda.is_initialized():
        torch.cuda.init()
    if A.dtype != expected_type or B.dtype != expected_type:
        raise TypeError(
            f"Expected {expected_type} input tensors A and B, but got {A.dtype} and {B.dtype}"
        )

    sA = A.shape
    sB = B.shape
    tA = transposed_A
    tB = transposed_B

    # Inner-dimension check for every supported rank/transpose combination.
    correct = True

    if len(sA) == 2 and len(sB) == 2:
        if not tA and not tB and A.shape[1] != B.shape[0]:
            correct = False
        elif tA and not tB and A.shape[0] != B.shape[0]:
            correct = False
        elif tA and tB and A.shape[0] != B.shape[1]:
            correct = False
        elif not tA and tB and A.shape[1] != B.shape[1]:
            correct = False
    elif len(sA) == 3 and len(sB) == 2:
        if not tA and not tB and A.shape[2] != B.shape[0]:
            correct = False
        elif tA and not tB and A.shape[1] != B.shape[0]:
            correct = False
        elif tA and tB and A.shape[1] != B.shape[1]:
            correct = False
        elif not tA and tB and A.shape[2] != B.shape[1]:
            correct = False
    elif len(sA) == 3 and len(sB) == 3:
        if not tA and not tB and A.shape[2] != B.shape[1]:
            correct = False
        elif tA and not tB and A.shape[1] != B.shape[1]:
            correct = False
        elif tA and tB and A.shape[1] != B.shape[2]:
            correct = False
        elif not tA and tB and A.shape[2] != B.shape[2]:
            correct = False

    if out is not None:
        sout = out.shape
        # special case common in backprop: bsi,bso->io
        if not correct and len(sA) == 3 and len(sB) == 3:
            if (
                sout[0] == sA[2]
                and sout[1] == sB[2]
                and sA[0] == sB[0]
                and sA[1] == sB[1]
            ):
                correct = True
    else:
        if len(sA) == 2 and len(sB) == 2:
            if not tA and not tB:
                sout = (sA[0], sB[1])
            elif tA and tB:
                sout = (sA[1], sB[0])
            elif tA and not tB:
                sout = (sA[1], sB[1])
            elif not tA and tB:
                sout = (sA[0], sB[0])
        elif len(sA) == 3 and len(sB) == 2:
            if not tA and not tB:
                sout = (sA[0], sA[1], sB[1])
            elif tA and tB:
                sout = (sA[0], sA[2], sB[0])
            elif tA and not tB:
                sout = (sA[0], sA[2], sB[1])
            elif not tA and tB:
                sout = (sA[0], sA[1], sB[0])
        elif len(sA) == 3 and len(sB) == 3:
            if not tA and not tB:
                sout = (sA[0], sA[1], sB[2])
            elif tA and tB:
                sout = (sA[0], sA[2], sB[1])
            elif tA and not tB:
                sout = (sA[0], sA[2], sB[2])
            elif not tA and tB:
                sout = (sA[0], sA[1], sB[1])
        else:
            # Without this branch ``sout`` would be unbound at the return below.
            raise ValueError(
                f"Unsupported tensor ranks for matrix multiplication: A x B: {sA} x {sB}. "
                "Supported are 2D x 2D, 3D x 2D and 3D x 3D."
            )

    if not correct:
        raise ValueError(
            f"Tensor dimensions incorrect for matrix multiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}."
        )

    return sout
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
def igemm(
    A: Tensor,
    B: Tensor,
    out: Tensor = None,
    transposed_A=False,
    transposed_B=False,
):
    """Integer matrix multiplication ``A @ B`` via cuBLAS (``lib.cigemm``).

    Validates dtypes/shapes with ``check_matmul`` (int8 inputs by default) and
    allocates a zero-initialised int32 ``out`` when none is given. 3-D x 3-D
    products whose batch and inner dimensions line up are forwarded to
    ``batched_igemm``. The remaining cases are handled by one GEMM call:
    a 2-D ``B``, or the 3-D x 3-D ``bsi,bso->io`` contraction.

    Parameters
    ----------
    A, B : torch.Tensor
        Input operands.
    out : torch.Tensor, optional
        Output buffer; allocated if None.
    transposed_A, transposed_B : bool
        Whether A / B are treated as transposed. For a 2-D ``B`` these flags
        may be overwritten below based on the tensors' strides, and the 3-D
        ``B`` path hard-codes them.

    Returns
    -------
    torch.Tensor
        ``out``.
    """
    sout = check_matmul(A, B, out, transposed_A, transposed_B)
    if out is None:
        out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)
    if len(A.shape) == 3 and len(B.shape) == 3:
        if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]:
            # Plain batched matmul; note the transpose flags are not forwarded.
            return batched_igemm(A, B, out)

    sA = A.shape
    sB = B.shape
    if transposed_A and len(sA) == 2:
        sA = (sA[1], sA[0])
    elif transposed_A and len(sA) == 3:
        # NOTE(review): swapping the last two dims would give
        # (sA[0], sA[2], sA[1]); the repeated sA[0] looks like a typo — confirm.
        sA = (sA[0], sA[2], sA[0])
    if transposed_B and len(sB) == 2:
        sB = (sB[1], sB[0])
    elif transposed_B and len(sB) == 3:
        # NOTE(review): same suspicious repeated sB[0] as for sA above.
        sB = (sB[0], sB[2], sB[0])
    # this is a mess: cuBLAS expect column major, but PyTorch is row major.
    # So to perform the matrix multiplication, we have to treat A, B, and C matrices
    # (transpose of row major is column major)
    # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these

    # matrices in the input arguments for cuBLAS
    # column major: A @ B = C: [m, k] @ [k, n] = [m, n]
    # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n]
    # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m]
    if len(sB) == 2:
        # Re-derive the transpose flags from the actual memory layout: a
        # row-major-contiguous tensor is "not transposed", and one whose strides
        # match a transposed view is "transposed". Otherwise the flag is kept.
        if B.stride()[0] == B.shape[1]:
            transposed_B = False
        elif B.stride()[1] == B.shape[0]:
            transposed_B = True
        if len(A.shape) == 2:
            if A.stride()[0] == A.shape[1]:
                transposed_A = False
            elif A.stride()[1] == A.shape[0]:
                transposed_A = True
        else:
            if A.stride()[1] == A.shape[2]:
                transposed_A = False
            elif A.stride()[2] == A.shape[1]:
                transposed_A = True

        if len(sA) == 2:
            n = sA[0]
            ldb = A.stride()[1 if transposed_A else 0]
        elif len(sA) == 3 and len(sB) == 2:
            # 3-D A against 2-D B: fold the two leading dims of A into rows.
            n = sA[0] * sA[1]
            ldb = sA[2]

        m = sB[1]
        k = sB[0]
        lda = B.stride()[(1 if transposed_B else 0)]
        ldc = sB[1]
    elif len(sB) == 3:
        # special case
        # 3-D x 3-D with matching leading two dims: contract over them (bsi,bso->io).
        assert len(sA) == 3
        if not (sA[0] == sB[0] and sA[1] == sB[1]):
            raise ValueError(
                f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}"
            )

        transposed_A = True
        transposed_B = False

        m = sB[2]
        n = sA[2]
        k = sB[0] * sB[1]

        lda = m
        ldb = sA[2]
        ldc = m

    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    # B^T @ A^T = C^T
    # [km, nk -> mn]
    is_on_gpu([B, A, out])
    # Operands are passed in (B, A) order to match the B^T @ A^T formulation above.
    lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
               get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
    return out
|
|
|
|
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
def batched_igemm(
    A: Tensor,
    B: Tensor,
    out: Tensor = None,
    transposed_A=False,
    transposed_B=False,
):
    """Batched integer matrix multiplication of 3-D tensors via ``lib.cbatched_igemm``.

    Shapes/dtypes are validated with ``check_matmul`` (int8 inputs by default)
    and a zero-initialised int32 ``out`` is allocated when none is given.
    A non-contiguous operand is handled in one of two ways:
      - if its strides match a transposed layout, it is passed with a
        transpose flag;
      - otherwise it is copied with ``.contiguous()``.

    Parameters
    ----------
    A, B : torch.Tensor
        3-D operands; dimension 0 is the batch (``num_batch = A.shape[0]``).
    out : torch.Tensor, optional
        Output buffer; allocated if None.
    transposed_A, transposed_B : bool
        Used for shape checking. They are then mostly overwritten from the
        operands' memory layout; in one non-contiguous-B branch the caller's
        ``transposed_A`` survives.

    Returns
    -------
    torch.Tensor
        ``out``.

    Raises
    ------
    ValueError
        If A or B is not 3-dimensional.
    """
    if not len(A.shape) == 3 or not len(B.shape) == 3:
        raise ValueError(
            f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}"
        )
    sout = check_matmul(A, B, out, transposed_A, transposed_B)
    if out is None:
        out = torch.zeros(size=sout, dtype=torch.int32, device=A.device)

    # NOTE: the flag names are swapped relative to the tensor they describe:
    # ``transposed_A``/``lda`` describe B's layout and ``transposed_B``/``ldb``
    # describe A's, because cuBLAS is called with B first (B^T @ A^T = C^T).
    if B.is_contiguous():
        lda = B.stride()[1]
        transposed_A = False
    else:
        s = B.stride()
        # NOTE(review): compares a stride with a size; intent unclear — confirm.
        if s[0] != B.shape[0]:
            B = B.contiguous()
            lda = B.stride()[1]
        elif s[2] == B.shape[1]:
            transposed_A = True
            lda = B.stride()[2]
        else:
            # NOTE(review): all three branches below are identical.
            if s[2] == 1:
                B = B.contiguous()
                lda = B.stride()[1]
            elif s[1] == 1:
                B = B.contiguous()
                lda = B.stride()[1]
            else:
                B = B.contiguous()
                lda = B.stride()[1]

    if A.is_contiguous():
        ldb = A.stride()[1]
        transposed_B = False
    else:
        s = A.stride()
        if s[0] != A.shape[0]:
            A = A.contiguous()
            ldb = A.stride()[1]
            transposed_B = False
        elif s[2] == A.shape[1]:
            ldb = A.stride()[2]
            transposed_B = True
        else:
            A = A.contiguous()
            ldb = A.stride()[1]
            transposed_B = False

    # this is a mess: cuBLAS expect column major, but PyTorch is row major.
    # So to perform the matrix multiplication, we have to treat A, B, and C matrices
    # (transpose of row major is column major)
    # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these
    # matrices in the input arguments for cuBLAS

    # column major: A @ B = C: [batch, m, k] @ [batch, k, n] = [batch, m, n]
    # row major: B^T @ A^T = C^T: [batch, m, k] @ [batch, k, n] = [batch, m, n]
    # column major with row major layout: B^T @ A^T = C^T: [batch, k, m] @ [batch, n, k] = [batch, n, m]
    num_batch = A.shape[0]
    n = A.shape[1]
    m = B.shape[2]
    k = B.shape[1]

    ldc = m

    # Per-batch element strides are derived from the shapes. This presumably
    # assumes each batch slice is densely packed — confirm for the transposed
    # (non-copied) path.
    strideA = B.shape[1] * B.shape[2]
    strideB = A.shape[1] * A.shape[2]
    strideC = A.shape[1] * B.shape[2]

    ptr = CUBLAS_Context.get_instance().get_context(A.device)

    is_on_gpu([B, A, out])
    lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
                       get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc),
                       ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch))
    return out
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-26 00:27:57 +00:00
|
|
|
def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
    """Int8 matmul ``A @ B^T`` through cublasLt on tiled layouts.

    ``SA``/``SB`` are state tuples: element 0 is the logical shape and element 1
    the memory format. A must be in "col32", and B in "col_turing" or
    "col_ampere". B must be 2-D; A may be 2-D or 3-D (leading dims are folded
    into rows). When ``out`` is None, a "col32" buffer of ``dtype`` is
    allocated via ``get_transform_buffer``.

    Returns
    -------
    tuple
        ``(out, Sout)``.

    Raises
    ------
    AssertionError
        On unsupported shapes, devices, dtypes or formats.
    Exception
        If the cublasLt kernel reports an error (``has_error == 1``).
    """
    shapeA = SA[0]
    shapeB = SB[0]
    dimsA = len(shapeA)
    dimsB = len(shapeB)
    assert dimsB == 2, 'Only two dimensional matrices are supported for argument B'
    if dimsA == 2:
        m = shapeA[0]
    elif dimsA == 3:
        m = shapeA[0] * shapeA[1]

    rows = n = shapeB[0]
    assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}'

    # if the tensor is empty, return a transformed empty tensor with the right dimensions
    # NOTE(review): these branches are unreachable because the assert above
    # rejects any zero dimension. They also return a bare tensor rather than
    # (out, Sout). ``shapeA[:2] + [shapeB[0]]`` would likely raise TypeError
    # if shapeA is a tuple/torch.Size.
    if shapeA[0] == 0 and dimsA == 2:
        return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16)
    elif shapeA[1] == 0 and dimsA == 3:
        return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16)

    if dimsA == 2 and out is None:
        out, Sout = get_transform_buffer(
            (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row"
        )
    elif dimsA == 3 and out is None:
        out, Sout = get_transform_buffer(
            (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row"
        )

    assert dimsB != 3, "len(B.shape)==3 not supported"
    assert A.device.type == "cuda"
    assert B.device.type == "cuda"
    assert A.dtype == torch.int8
    assert B.dtype == torch.int8
    assert out.dtype == dtype
    assert SA[1] == "col32"
    assert SB[1] in ["col_turing", "col_ampere"]
    assert Sout[1] == "col32"
    assert (
        shapeA[-1] == shapeB[-1]
    ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}"
    formatB = SB[1]
    # NOTE(review): prev_device is A.device, so the restore at the end is a
    # no-op. The restore is also skipped if the kernel call raises. Capturing
    # torch.cuda.current_device() was presumably intended — confirm.
    prev_device = A.device
    torch.cuda.set_device(A.device)

    ptr = CUBLAS_Context.get_instance().get_context(A.device)
    ptrA = get_ptr(A)
    ptrB = get_ptr(B)
    ptrC = get_ptr(out)

    k = shapeA[-1]
    # col32 layout: leading dimension is rows * 32.
    lda = ct.c_int32(m * 32)
    if formatB == "col_turing":
        # turing: tiles with rows filled up to multiple of 8 rows by 32 columns
        # n = rows
        ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
    else:
        # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
        # n = rows
        ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)

    ldc = ct.c_int32(m * 32)
    m = ct.c_int32(m)
    n = ct.c_int32(n)
    k = ct.c_int32(k)

    has_error = 0
    # No row scaling is applied: a null pointer is passed.
    ptrRowScale = get_ptr(None)
    is_on_gpu([A, B, out])
    if formatB == 'col_turing':
        if dtype == torch.int32:
            has_error = lib.cigemmlt_turing_32(
                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
            )
        else:
            has_error = lib.cigemmlt_turing_8(
                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
            )
    elif formatB == "col_ampere":
        if dtype == torch.int32:
            has_error = lib.cigemmlt_ampere_32(
                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
            )
        else:
            has_error = lib.cigemmlt_ampere_8(
                ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc
            )

    if has_error == 1:
        print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}')
        raise Exception('cublasLt ran into an error!')

    torch.cuda.set_device(prev_device)

    return out, Sout
|
|
|
|
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
def mm_dequant(
    A,
    quant_state,
    row_stats,
    col_stats,
    out=None,
    new_row_stats=None,
    new_col_stats=None,
    bias=None
):
    """Dequantize an int32 matmul result into float16 with ``lib.cdequant_mm_int32_fp16``.

    The output shape comes from ``quant_state[0]``; a 3-D shape is flattened
    to ``(d0 * d1, d2)``. Missing ``out`` (float16), ``new_row_stats`` and
    ``new_col_stats`` (float32) buffers are allocated on ``A.device``.
    ``bias``, if given, must be float16 and its pointer is passed to the kernel.

    Returns
    -------
    torch.Tensor
        ``out``.
    """
    assert A.dtype == torch.int32
    if bias is not None:
        assert bias.dtype == torch.float16

    shape = quant_state[0]
    if len(shape) == 3:
        shape = (shape[0] * shape[1], shape[2])

    device = A.device
    if out is None:
        out = torch.empty(shape, dtype=torch.float16, device=device)
    if new_row_stats is None:
        new_row_stats = torch.empty(shape[0], dtype=torch.float32, device=device)
    if new_col_stats is None:
        new_col_stats = torch.empty(shape[1], dtype=torch.float32, device=device)

    assert (
        new_row_stats.shape[0] == row_stats.shape[0]
    ), f"{new_row_stats.shape} vs {row_stats.shape}"
    assert (
        new_col_stats.shape[0] == col_stats.shape[0]
    ), f"{new_col_stats.shape} vs {col_stats.shape}"

    prev_device = pre_call(A.device)
    # Pointers in the exact order the kernel expects them.
    kernel_ptrs = [
        get_ptr(tensor)
        for tensor in (A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias)
    ]
    kernel_dims = (ct.c_int32(shape[0]), ct.c_int32(shape[1]))

    is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias])
    lib.cdequant_mm_int32_fp16(*kernel_ptrs, *kernel_dims)
    post_call(prev_device)

    return out
|
|
|
|
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
def get_colrow_absmax(
    A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0
):
    """Compute per-row and per-column statistics of a float16 matrix with ``lib.cget_col_row_stats``.

    A 3-D ``A`` is treated as a matrix with ``A.shape[0] * A.shape[1]`` rows.
    Missing ``row_stats`` / ``col_stats`` buffers are created as float32 and
    filled with -50000.0. When ``threshold > 0`` an int32 ``nnz_block_ptr``
    buffer is also used. Its size is the row count rounded up to a multiple
    of 16, times the number of 256-wide column tiles, plus one. It is
    prefix-summed in place after the kernel runs.

    Returns
    -------
    tuple
        ``(row_stats, col_stats, nnz_block_ptr)``; ``nnz_block_ptr`` may be None.
    """
    assert A.dtype == torch.float16
    device = A.device

    n_cols = A.shape[-1]
    n_rows = A.shape[0] * A.shape[1] if len(A.shape) == 3 else A.shape[0]

    if row_stats is None:
        row_stats = torch.full((n_rows,), -50000.0, dtype=torch.float32, device=device)
    if col_stats is None:
        col_stats = torch.full((n_cols,), -50000.0, dtype=torch.float32, device=device)

    use_threshold = threshold > 0.0
    if nnz_block_ptr is None and use_threshold:
        col_tiles = (n_cols + 255) // 256
        tiled_rows = ((n_rows + 15) // 16) * 16
        nnz_block_ptr = torch.zeros(
            (tiled_rows * col_tiles + 1,), dtype=torch.int32, device=device
        )

    kernel_ptrs = (
        get_ptr(A),
        get_ptr(row_stats),
        get_ptr(col_stats),
        get_ptr(nnz_block_ptr),
    )
    c_rows = ct.c_int32(n_rows)
    c_cols = ct.c_int32(n_cols)

    prev_device = pre_call(A.device)
    is_on_gpu([A, row_stats, col_stats, nnz_block_ptr])
    lib.cget_col_row_stats(*kernel_ptrs, ct.c_float(threshold), c_rows, c_cols)
    post_call(prev_device)

    if use_threshold:
        nnz_block_ptr.cumsum_(0)

    return row_stats, col_stats, nnz_block_ptr
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
class COOSparseTensor:
    """Sparse matrix in coordinate (COO) form.

    Holds ``nnz`` entries: int32 row and column indices plus float16 values.
    """

    def __init__(self, rows, cols, nnz, rowidx, colidx, values):
        for index_tensor in (rowidx, colidx):
            assert index_tensor.dtype == torch.int32
        assert values.dtype == torch.float16
        for entry_tensor in (values, rowidx, colidx):
            assert entry_tensor.numel() == nnz

        self.rows, self.cols, self.nnz = rows, cols, nnz
        self.rowidx, self.colidx, self.values = rowidx, colidx, values
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
class CSRSparseTensor:
    """Sparse matrix in compressed-sparse-row (CSR) form.

    ``rowptr`` has ``rows + 1`` int32 entries. ``colidx`` (int32) and
    ``values`` (float16) each have ``nnz`` entries.
    """

    def __init__(self, rows, cols, nnz, rowptr, colidx, values):
        for index_tensor in (rowptr, colidx):
            assert index_tensor.dtype == torch.int32
        assert values.dtype == torch.float16
        for entry_tensor in (values, colidx):
            assert entry_tensor.numel() == nnz
        assert rowptr.numel() == rows + 1

        self.rows, self.cols, self.nnz = rows, cols, nnz
        self.rowptr, self.colidx, self.values = rowptr, colidx, values
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
class CSCSparseTensor:
    """Sparse matrix in compressed-sparse-column (CSC) form.

    ``colptr`` has ``cols + 1`` int32 entries. ``rowidx`` (int32) and
    ``values`` (float16) each have ``nnz`` entries.
    """

    def __init__(self, rows, cols, nnz, colptr, rowidx, values):
        for index_tensor in (colptr, rowidx):
            assert index_tensor.dtype == torch.int32
        assert values.dtype == torch.float16
        for entry_tensor in (values, rowidx):
            assert entry_tensor.numel() == nnz
        assert colptr.numel() == cols + 1

        self.rows, self.cols, self.nnz = rows, cols, nnz
        self.colptr, self.rowidx, self.values = colptr, rowidx, values
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
def coo2csr(cooA):
    """Convert a ``COOSparseTensor`` into a ``CSRSparseTensor``.

    The row pointer is built by counting entries per row and prefix-summing
    the counts. ``colidx`` and ``values`` are reused without reordering.
    NOTE(review): this presumably requires the COO entries to already be
    grouped by row in ascending order — confirm with callers.
    """
    present_rows, entries_per_row = torch.unique(cooA.rowidx, return_counts=True)
    row_ptr = torch.zeros(
        (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device
    )
    # Count for row r goes to slot r + 1 so the cumulative sum yields offsets.
    row_ptr.scatter_(dim=0, index=(present_rows + 1).long(), src=entries_per_row.int())
    row_ptr.cumsum_(0)
    return CSRSparseTensor(
        cooA.rows, cooA.cols, cooA.nnz, row_ptr, cooA.colidx, cooA.values
    )
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
def coo2csc(cooA):
    """Convert a ``COOSparseTensor`` into a ``CSCSparseTensor``.

    Entries are sorted by column index; row indices and values are permuted
    accordingly. The column pointer is built by counting entries per column
    and prefix-summing the counts.
    """
    sorted_cols, order = torch.sort(cooA.colidx)
    permuted_rows = cooA.rowidx[order]
    permuted_values = cooA.values[order]

    present_cols, entries_per_col = torch.unique(sorted_cols, return_counts=True)
    col_ptr = torch.zeros(
        (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device
    )
    # Count for column c goes to slot c + 1 so the cumulative sum yields offsets.
    col_ptr.scatter_(dim=0, index=(present_cols + 1).long(), src=entries_per_col.int())
    col_ptr.cumsum_(0)
    return CSCSparseTensor(
        cooA.rows, cooA.cols, cooA.nnz, col_ptr, permuted_rows, permuted_values
    )
|
2022-07-22 21:41:05 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
    """Create a ``COOSparseTensor`` with ``nnz`` zero-initialised entries.

    Row/column indices are int32 zeros and values are zeros of ``dtype``,
    all allocated on ``device``.
    """
    rowidx, colidx = (
        torch.zeros((nnz,), dtype=torch.int32, device=device) for _ in range(2)
    )
    values = torch.zeros((nnz,), dtype=dtype, device=device)
    return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
|
|
|
|
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
def double_quant(
    A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0
):
    """Quantize a float16 CUDA tensor to int8 both row-wise and column-wise.

    Both quantizations are produced by a single ``lib.cdouble_rowcol_quant`` call.
    When ``threshold > 0``, the kernel also records outlier entries into a COO
    tensor, which is returned sorted by row index. The exact outlier selection
    rule lives in the CUDA kernel.

    Args:
        A: float16 CUDA tensor. A 3D input is processed as an
            ``(A.shape[0] * A.shape[1], A.shape[-1])`` matrix.
        col_stats: Optional precomputed column statistics.
        row_stats: Optional precomputed row statistics. If either statistic is
            ``None``, both are recomputed with ``get_colrow_absmax``.
        out_col: Optional int8 output buffer shaped like ``A``; allocated
            zero-filled when ``None``.
        out_row: Optional int8 output buffer shaped like ``A``; allocated
            zero-filled when ``None``.
        threshold: Outlier threshold. Outlier extraction only happens when
            ``threshold > 0``.

    Returns:
        ``(out_row, out_col, row_stats, col_stats, coo_tensor)``. ``coo_tensor``
        is ``None`` unless ``threshold > 0`` and at least one outlier was found.
        Its ``rows``/``cols`` describe the (flattened) 2D view of ``A``.
    """
    device = A.device
    assert A.dtype == torch.half
    assert device.type == "cuda"

    prev_device = pre_call(A.device)
    try:
        cols = A.shape[-1]
        if len(A.shape) == 3:
            rows = A.shape[0] * A.shape[1]
        else:
            rows = A.shape[0]

        nnz_row_ptr = None
        if row_stats is None or col_stats is None:
            row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(
                A, threshold=threshold
            )
        elif threshold > 0.0:
            # Both statistics were supplied, but only get_colrow_absmax produces
            # the outlier row pointer needed for extraction. Keep the caller's
            # statistics and take just the pointer.
            _, _, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold)

        if out_col is None:
            out_col = torch.zeros(A.shape, device=device, dtype=torch.int8)
        if out_row is None:
            out_row = torch.zeros(A.shape, device=device, dtype=torch.int8)

        coo_tensor = None
        ptrA = get_ptr(A)
        ptrColStats = get_ptr(col_stats)
        ptrRowStats = get_ptr(row_stats)
        ptrOutCol = get_ptr(out_col)
        ptrOutRow = get_ptr(out_row)
        is_on_gpu([A, col_stats, row_stats, out_col, out_row])

        # Null COO pointers mean "do not record outliers".
        ptrRowIdx = ptrColIdx = ptrVal = ptrRowPtr = None
        kernel_threshold = threshold
        if threshold > 0.0:
            # The last entry of the row pointer holds the total outlier count.
            nnz = nnz_row_ptr[-1].item()
            if nnz > 0:
                # Use the flattened (rows, cols) dims the kernel operates on, so
                # that 3D inputs produce a COO tensor with the right shape.
                coo_tensor = coo_zeros(rows, cols, nnz, device)
                ptrRowIdx = get_ptr(coo_tensor.rowidx)
                ptrColIdx = get_ptr(coo_tensor.colidx)
                ptrVal = get_ptr(coo_tensor.values)
                ptrRowPtr = get_ptr(nnz_row_ptr)
            else:
                # No outliers found: quantize without thresholding.
                kernel_threshold = 0.0

        lib.cdouble_rowcol_quant(
            ptrA,
            ptrRowStats,
            ptrColStats,
            ptrOutCol,
            ptrOutRow,
            ptrRowIdx,
            ptrColIdx,
            ptrVal,
            ptrRowPtr,
            ct.c_float(kernel_threshold),
            ct.c_int32(rows),
            ct.c_int32(cols),
        )

        if coo_tensor is not None:
            # Return the outliers ordered by row index.
            val, idx = torch.sort(coo_tensor.rowidx)
            coo_tensor.rowidx = val
            coo_tensor.colidx = coo_tensor.colidx[idx]
            coo_tensor.values = coo_tensor.values[idx]
    finally:
        # Restore the previous device even if a call above raised.
        post_call(prev_device)

    return out_row, out_col, row_stats, col_stats, coo_tensor
|
|
|
|
|
|
|
|
|
|
|
|
def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None):
    """Convert a matrix between row-major and the tiled layouts used by the int8 kernels.

    Supported conversions:

    * ``row`` -> ``col32`` / ``col_turing`` / ``col_ampere``, optionally transposed;
    * ``col_turing`` / ``col_ampere`` -> ``row`` (``transpose`` is ignored).

    Args:
        A: Input tensor. The shape stored in ``state`` decides the kernel dims: 2D
            ``(d0, d1)``, or 3D ``(d0, d1, d2)`` treated as ``(d0 * d1, d2)``.
        to_order: Target layout name.
        from_order: Source layout name. Ignored when ``state`` is given, in which
            case ``state[1]`` is used.
        out: Optional output buffer. When ``None`` it is allocated with
            ``get_transform_buffer``.
        transpose: Transpose while converting (``row`` -> tiled paths only).
        state: ``(shape, order)`` describing ``A``; defaults to ``(A.shape, from_order)``.
        ld: Unused; kept for API compatibility.

    Returns:
        ``(out, new_state)``, where ``new_state`` is the ``(shape, order)`` of ``out``.

    Raises:
        NotImplementedError: If the requested conversion is not supported.
    """
    prev_device = pre_call(A.device)
    try:
        if state is None:
            state = (A.shape, from_order)
        else:
            from_order = state[1]
        if out is None:
            out, new_state = get_transform_buffer(
                state[0], A.dtype, A.device, to_order, state[1], transpose
            )
        else:
            new_state = (state[0], to_order)  # (shape, order)

        shape = state[0]
        if len(shape) == 2:
            dim1 = ct.c_int32(shape[0])
            dim2 = ct.c_int32(shape[1])
        else:
            dim1 = ct.c_int32(shape[0] * shape[1])
            dim2 = ct.c_int32(shape[2])

        ptrA = get_ptr(A)
        ptrOut = get_ptr(out)
        is_on_gpu([A, out])

        unsupported = NotImplementedError(
            f'Transform function not implemented: From {from_order} to {to_order}'
        )
        if to_order == 'col32':
            kernel = lib.ctransform_row2col32T if transpose else lib.ctransform_row2col32
        elif to_order == "col_turing":
            kernel = lib.ctransform_row2turingT if transpose else lib.ctransform_row2turing
        elif to_order == "col_ampere":
            kernel = lib.ctransform_row2ampereT if transpose else lib.ctransform_row2ampere
        elif to_order == "row":
            if from_order == "col_turing":
                kernel = lib.ctransform_turing2row
            elif from_order == "col_ampere":
                kernel = lib.ctransform_ampere2row
            else:
                raise unsupported
        else:
            # Previously an unknown target layout silently returned an unfilled buffer.
            raise unsupported

        kernel(ptrA, ptrOut, dim1, dim2)
    finally:
        # Restore the previous device on success and on error alike.
        post_call(prev_device)

    return out, new_state
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
def spmm_coo(cooA, B, out=None):
    """Compute ``cooA @ B`` for a COO sparse ``cooA`` and a dense ``B`` via ``lib.cspmm_coo``.

    Args:
        cooA: COO tensor. Its ``rowidx``, ``colidx`` and ``values`` must each hold
            ``cooA.nnz`` entries, and ``cooA.cols`` must equal ``B.shape[0]``.
        B: Dense 2D tensor. If it is not contiguous, the kernel is told it is
            transposed, and its leading dimension is taken from ``B.stride()[1]``
            instead of ``B.stride()[0]``.
        out: Optional output. When ``None``, an uninitialized
            ``(cooA.rows, B.shape[1])`` tensor with ``B``'s dtype and device is
            allocated. Its leading dimension is passed as ``B.shape[1]``.

    Returns:
        ``out``.
    """
    if out is None:
        out = torch.empty(
            (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype
        )

    nnz = cooA.nnz
    for component in (cooA.rowidx, cooA.colidx, cooA.values):
        assert component.numel() == nnz
    assert cooA.cols == B.shape[0]

    transposed_B = not B.is_contiguous()
    leading_dim_B = B.stride()[1] if transposed_B else B.stride()[0]
    leading_dim_out = B.shape[1]

    handle = Cusparse_Context.get_instance().context

    kernel_args = (
        handle,
        get_ptr(cooA.rowidx),
        get_ptr(cooA.colidx),
        get_ptr(cooA.values),
        ct.c_int32(cooA.nnz),
        ct.c_int32(cooA.rows),
        ct.c_int32(cooA.cols),
        ct.c_int32(B.shape[1]),
        ct.c_int32(leading_dim_B),
        get_ptr(B),
        ct.c_int32(leading_dim_out),
        get_ptr(out),
        ct.c_bool(transposed_B),
    )

    is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out])
    lib.cspmm_coo(*kernel_args)
    return out
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
    """Multiply a very sparse COO matrix by a dense float16 or int8 matrix.

    Dispatches to ``lib.cspmm_coo_very_sparse_naive_fp16`` or
    ``lib.cspmm_coo_very_sparse_naive_int8`` depending on ``B.dtype``. The kernels
    accept at most 32 non-zeros per row of ``cooA``.

    Args:
        cooA: COO tensor whose ``rowidx``/``colidx``/``values`` each hold
            ``cooA.nnz`` entries and whose ``cols`` equals ``B.shape[0]``.
        B: Dense float16 or int8 tensor of shape ``(cooA.cols, n)``. It is made
            contiguous first, because the kernels receive neither a stride nor a
            transpose flag.
        dequant_stats: Optional tensor forwarded unchanged to the kernel.
        out: Optional output of shape ``(cooA.rows, n)``. When ``None`` it is
            zero-initialized with ``cooA.values.dtype``.

    Returns:
        ``out``. It is returned untouched when ``cooA`` has no non-zeros.
    """
    if out is None:
        out = torch.zeros(
            (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype
        )
    nnz = cooA.nnz
    assert cooA.rowidx.numel() == nnz
    assert cooA.colidx.numel() == nnz
    assert cooA.values.numel() == nnz
    assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}"
    assert B.dtype in [torch.float16, torch.int8]

    if nnz == 0:
        # An empty sparse matrix contributes nothing. Returning here also avoids
        # indexing max_count[0] on an empty tensor below.
        return out

    # The kernels get only B's data pointer plus B.shape[1], with no stride or
    # transpose flag, so a non-contiguous view is materialized first.
    if not B.is_contiguous():
        B = B.contiguous()

    # Per-row non-zero counts (ordered by unique row id) and their running offsets.
    # NOTE(review): assumes cooA.rowidx is sorted so each row's entries are
    # contiguous — confirm all callers guarantee this.
    _, counts = torch.unique(cooA.rowidx, return_counts=True)
    offset = counts.cumsum(0).int()
    # Counts in decreasing order plus their row positions, handed to the kernel
    # (presumably so the heaviest rows are scheduled first).
    max_count, max_idx = torch.sort(counts, descending=True)
    max_idx = max_idx.int()
    max_count = max_count.int()
    assert (
        max_count[0] <= 32
    ), f"The kernel supports at most 32 non-zeros per row but found {max_count[0]}."

    ptrOffset = get_ptr(offset)
    ptrMaxCount = get_ptr(max_count)
    ptrMaxIdx = get_ptr(max_idx)

    ptrRowidx = get_ptr(cooA.rowidx)
    ptrColidx = get_ptr(cooA.colidx)
    ptrValues = get_ptr(cooA.values)
    ptrB = get_ptr(B)
    ptrC = get_ptr(out)
    ptrDequantStats = get_ptr(dequant_stats)
    cnnz_rows = ct.c_int32(counts.numel())
    cnnz = ct.c_int32(cooA.nnz)
    crowsA = ct.c_int32(cooA.rows)
    # NOTE(review): rowsB receives B.shape[1], not B.shape[0]. This is kept as-is
    # because the kernel's use of the argument is not visible here — confirm
    # against the CUDA source.
    crowsB = ct.c_int32(B.shape[1])
    ccolsB = ct.c_int32(B.shape[1])

    is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])

    kernel_args = (
        ptrMaxCount,
        ptrMaxIdx,
        ptrOffset,
        ptrRowidx,
        ptrColidx,
        ptrValues,
        ptrB,
        ptrC,
        ptrDequantStats,
        cnnz_rows,
        cnnz,
        crowsA,
        crowsB,
        ccolsB,
    )
    if B.dtype == torch.float16:
        lib.cspmm_coo_very_sparse_naive_fp16(*kernel_args)
    elif B.dtype == torch.int8:
        lib.cspmm_coo_very_sparse_naive_int8(*kernel_args)
    # Any other dtype is rejected by the assertion above.

    return out
|
|
|
|
|
|
|
|
|
|
|
|
# Int8 scaling constant (127, the largest positive int8 value) used by the
# vectorwise quantize/dequantize helpers defined after it.
C: float = 127.0
|
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
def vectorwise_quant(x, dim=1, quant_type="vector"):
    """Quantize ``x`` with one of several simple 8-bit schemes.

    Supported ``quant_type`` values:

    * ``"linear"``: one absmax scale for the whole tensor. Returns
      ``(int8 xq, max1)``, where ``max1`` is a float32 scalar tensor.
    * ``"vector"`` / ``"row"``: absmax along ``dim`` (kept as a size-1 dim).
      Returns ``(int8 xq, max1)``.
    * ``"zeropoint"``: asymmetric 0..255 range over the whole tensor. Returns
      ``(x, qx)``, where ``x`` is float32 and is *not* cast to an integer dtype.
    * ``"vector-zeropoint"`` / ``"row-zeropoint"``: as ``"zeropoint"`` but
      computed per ``dim``.
    * ``"truncated-vector"``: like ``"vector"``, but the scale is 0.7 * absmax and
      larger magnitudes are clipped to it first. Returns ``(int8 xq, scale)``.

    Any other ``quant_type`` returns ``None``.

    An all-zero tensor (or all-zero slice along ``dim``) quantizes to 0, while the
    returned scale stays 0, so dequantizing gives 0 again. ``x`` is never
    modified in place.
    """

    def _safe_divisor(scale):
        # A zero scale would give 0/0 = NaN, and NaN -> int8 is undefined.
        # Divide by 1 instead; the unmodified scale is still returned.
        return torch.where(scale == 0, torch.ones_like(scale), scale)

    if quant_type == "linear":
        max1 = torch.abs(x).max().float()
        xq = torch.round(x / _safe_divisor(max1) * 127).to(torch.int8)
        return xq, max1
    elif quant_type in ["vector", "row"]:
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        xq = torch.round(x * (C / _safe_divisor(max1))).to(torch.int8)
        return xq, max1
    elif quant_type == "zeropoint":
        x = x.float()
        dyna = x.max() - x.min()
        if dyna == 0:
            dyna = 1
        qx = 255.0 / dyna
        minx = x.min()
        zpx = torch.round(minx * qx)
        x = torch.round(qx * x - zpx) + zpx
        return x, qx
    elif quant_type in ["vector-zeropoint", "row-zeropoint"]:
        x = x.float()
        dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(
            x, dim=dim, keepdim=True
        )
        dyna[dyna == 0] = 1
        qx = 255.0 / dyna
        minx = torch.amin(x, dim=dim, keepdim=True)
        zpx = torch.round(minx * qx)
        x = torch.round(qx * x - zpx) + zpx
        return x, qx
    elif quant_type == "truncated-vector":
        with torch.no_grad():
            absx = torch.abs(x)
            max1 = torch.amax(absx, dim=dim, keepdim=True) * 0.7
            limit = max1.expand_as(absx)
            # Clip into a new tensor rather than assigning into ``x``, so the
            # caller's tensor is left untouched.
            clipped = torch.where(absx > limit, limit * torch.sign(x), x)
            xq = torch.round(clipped / _safe_divisor(max1) * C).to(torch.int8)
            return xq, max1
    else:
        return None
|
|
|
|
|
2022-07-22 21:41:05 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
def vectorwise_dequant(xq, max1, quant_type="vector"):
    """Undo ``"vector"`` quantization, returning ``xq / C * max1`` as float32.

    Args:
        xq: Quantized tensor.
        max1: Scale(s) returned by the matching quantization call.
        quant_type: Only ``"vector"`` is supported.

    Returns:
        The float32 dequantized tensor, or ``None`` for any other ``quant_type``.
    """
    if quant_type != "vector":
        return None
    return (xq / C * max1).to(torch.float32)
|
2022-07-22 21:41:05 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
|
|
|
def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
    """Dequantize the result ``xq`` of a quantized matmul and cast it to ``dtype``.

    ``S1`` and ``S2`` are the scales of the two matmul operands. For every type
    except ``"linear"`` and ``"zeropoint"``, a 3D scale used with a 2D ``xq`` first
    has its leading dim squeezed away so it broadcasts onto the 2D result.

    * ``"linear"``: ``xq * S1 * S2 / C**2``.
    * ``"zeropoint"`` / ``"row-zeropoint"``: ``xq / (S1 * S2)``.
    * ``"vector-zeropoint"``: ``xq / S1``, additionally ``/ S2.t()`` unless ``S1`` is 2D.
    * ``"row"``: ``xq * S1 * S2 / C**2``.
    * ``"vector"`` / ``"truncated-vector"``: ``xq * S1 / C``, additionally
      ``* S2 / C`` unless ``S1`` is 2D.

    Returns:
        The dequantized tensor in ``dtype``, or ``None`` for an unknown ``quant_type``.
    """
    if quant_type == "linear":
        norm = S1 * S2 / (C * C)
        # double cast needed to prevent overflows
        return (xq.float() * norm).to(dtype)
    if quant_type == "zeropoint":
        norm = 1.0 / (S1 * S2)
        return (xq.float() * norm).to(dtype)
    if quant_type not in (
        "row-zeropoint",
        "vector-zeropoint",
        "row",
        "truncated-vector",
        "vector",
    ):
        return None

    x = xq.float()
    if len(S1.shape) == 3 and len(x.shape) == 2:
        S1 = S1.squeeze(0)
    if len(S2.shape) == 3 and len(x.shape) == 2:
        S2 = S2.squeeze(0)

    if quant_type == "row-zeropoint":
        # The normalizer is built from the *squeezed* scales. Otherwise a 3D
        # scale makes it 3D, and the in-place multiply into a 2D x fails.
        x *= 1.0 / (S1 * S2)
    elif quant_type == "vector-zeropoint":
        x *= 1.0 / S1
        # NOTE(review): when S1 is 2D only S1 is applied — confirm this is intended.
        if len(S1.shape) != 2:
            x *= 1.0 / S2.t()
    elif quant_type == "row":
        x *= S1 * S2 / (C * C)
    else:  # "truncated-vector" or "vector"
        x *= S1 / C
        # NOTE(review): when S1 is 2D only S1 is applied — confirm this is intended.
        if len(S1.shape) != 2:
            x *= S2 / C
    return x.to(dtype)
|
2022-07-22 21:41:05 +00:00
|
|
|
|
|
|
|
|
|
|
|
def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
    """Dequantize a min/max-quantized matmul result and cast it to ``dtype``.

    The result is ``xq * (SB / 127) * (SA[1] / 127) + B.float().t().sum(0) * (SA[0] + SA[1])``.
    A 2D ``SB`` is transposed before scaling. A 3D ``SB`` used with a 2D ``xq``
    first has its leading dim squeezed away. ``A`` is accepted but not used.
    """
    bias = B.float().t().sum(0) * (SA[0] + SA[1])
    result = xq.float()

    if len(SB.shape) == 3 and len(xq.shape) == 2:
        SB = SB.squeeze(0)
    col_scale = SB.t() if len(SB.shape) == 2 else SB

    result *= col_scale / 127
    result *= SA[1] / 127
    result += bias
    return result.to(dtype)
|
2022-07-26 19:12:38 +00:00
|
|
|
|
2022-08-01 10:31:48 +00:00
|
|
|
|
2022-07-26 19:12:38 +00:00
|
|
|
def extract_outliers(A, SA, idx):
    """Copy the columns listed in ``idx`` out of a tiled matrix into a dense int8 tensor.

    Args:
        A: CUDA tensor stored in ``col_turing`` or ``col_ampere`` layout.
        SA: ``(shape, layout)`` state describing ``A``.
        idx: Column indices to extract; one output column per entry (dtype and
            device presumably int32 on A's device — TODO confirm).

    Returns:
        Zero-initialized int8 tensor of shape ``(SA[0][0], idx.numel())``, filled
        by the layout-specific kernel.
    """
    shape_a, layout = SA[0], SA[1]
    assert layout in ["col_turing", "col_ampere"]
    assert A.device.type == "cuda"

    out = torch.zeros(
        (shape_a[0], idx.numel()), dtype=torch.int8, device=A.device
    )

    kernel_args = (
        get_ptr(A),
        get_ptr(idx),
        get_ptr(out),
        ct.c_int32(idx.numel()),
        ct.c_int32(shape_a[0]),
        ct.c_int32(shape_a[1]),
    )

    prev_device = pre_call(A.device)
    if layout == 'col_turing':
        lib.cextractOutliers_turing(*kernel_args)
    elif layout == "col_ampere":
        lib.cextractOutliers_ampere(*kernel_args)
    post_call(prev_device)

    return out
|