ran black and isort for coherent code formatting

Titus von Koeller 2022-08-01 03:31:48 -07:00
parent 597a8521b2
commit bfa0e33294
25 changed files with 3855 additions and 1987 deletions


@@ -1,16 +1,18 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .autograd._functions import (MatmulLtState, bmm_cublas, matmul,
                                  matmul_cublas, mm_cublas)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules

if COMPILED_WITH_CUDA:
    from .optim import adam

__pdoc__ = {
    "libbitsandbytes": False,
    "optim.optimizer.Optimizer8bit": False,
    "optim.optimizer.MockArgs": False,
}
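
The reordered imports above define the package's public surface. A minimal smoke test of what the top-level module re-exports (assuming bitsandbytes is installed) could look like this; it is a sketch, not part of the commit:

import bitsandbytes as bnb

print(bnb.COMPILED_WITH_CUDA)         # False on a CPU-only build
print(bnb.MatmulLtState, bnb.matmul)  # re-exported from autograd._functions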


@@ -1,21 +1,24 @@
from dataclasses import dataclass

import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F

tensor = torch.Tensor

"""
    This class pools outlier dimensions across layers.
    This is particularly important for small models where outlier features
    are less systematic and occur with low frequency.
"""


class GlobalOutlierPooler(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.outliers = set()
@@ -29,25 +32,29 @@ class GlobalOutlierPooler(object):
        return cls._instance

    def add_outliers(self, outlier_idx, feature_dim):
        if self.model_dim is None:
            self.model_dim = feature_dim
        if feature_dim != self.model_dim:
            return  # we do not encode outliers for the 2nd FFN layer

        self.outliers.update(outlier_idx.tolist())

    def get_current_outlier_idx(self):
        return torch.Tensor(list(self.outliers)).to(torch.int64)


class MatMul8bit(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):

        if precision[0] != 8:
            with torch.no_grad():
                output = torch.matmul(A, B)
        else:
            if len(B.shape) == 2:
                dim = 0
            else:
                dim = 1
            qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
            qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
            iout = F.igemm(qA, qB)
@@ -84,21 +91,41 @@ class MatMul8bit(torch.autograd.Function):
            else:
                if len(B.shape) == 2 and len(A.shape) == 3:
                    grad_output = grad_output.contiguous()
                    if not grad_output.is_contiguous():
                        grad_output.contiguous()
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output.view(-1, grad_output.shape[2]),
                        dim=0,
                        quant_type=quant_type,
                    )
                    if not A.is_contiguous():
                        A = A.contiguous()
                    qA, S2 = F.vectorwise_quant(
                        A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
                    )
                    igrad_B = F.igemm(qA.t(), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B, S2.t(), S1, grad_output.dtype, quant_type
                    )
                else:
                    qgrad_output, S1 = F.vectorwise_quant(
                        grad_output, dim=dims, quant_type=quant_type
                    )
                    qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
                    igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
                    grad_B = F.vectorwise_mm_dequant(
                        igrad_B,
                        S2.permute(permute_dim),
                        S1,
                        grad_output.dtype,
                        quant_type,
                    )

        if A.requires_grad:
            if len(grad_output.shape) == 3:
                dims = [2]
            else:
                dims = [1]

            if len(B.shape) == 3:
                # bio -> boi
@@ -113,10 +140,14 @@ class MatMul8bit(torch.autograd.Function):
                with torch.no_grad():
                    grad_A = torch.matmul(grad_output, B.permute(permute_dim))
            else:
                qgrad_output, S1 = F.vectorwise_quant(
                    grad_output, dim=dims, quant_type=quant_type
                )
                qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
                igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
                grad_A = F.vectorwise_mm_dequant(
                    igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type
                )

        return grad_A, grad_B, None, None, None
@@ -125,6 +156,7 @@ mm_cublas = MatMul8bit.apply
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply


@dataclass
class MatmulLtState:
    CB = None
@@ -159,7 +191,6 @@ class MatmulLtState:
class MatMul8bitLt(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, state=MatmulLtState()):
        # 1. Quantize A
@@ -171,11 +202,15 @@ class MatMul8bitLt(torch.autograd.Function):
        requires_gradB = B.requires_grad
        formatB = state.formatB
        input_shape = A.shape
        if state.outlier_pool is None:
            state.outlier_pool = GlobalOutlierPooler.get_instance()
        assert (
            A.dtype == torch.float16
        ), f"The input data type needs to be fp16 but {A.dtype} was found!"

        # 1. Quantize A
        if len(A.shape) == 3:
            A = A.view(-1, A.shape[-1]).contiguous()
        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)

        if state.threshold > 0.0 and coo_tensorA is not None:
@@ -191,8 +226,8 @@ class MatMul8bitLt(torch.autograd.Function):
                    # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
                    # we also need to convert it to the turing/ampere format
                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
                    # state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
                    # if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
                    #    # generate outlier index and subB
                    #    outlier_idx = torch.unique(coo_tensorA.colidx).long()
                    #    state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
@@ -203,24 +238,24 @@ class MatMul8bitLt(torch.autograd.Function):
                    #        state.idx = outlier_idx
                    #    state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()

            # if state.idx is not None:
            #    # extract outliers
            #    CA[:, state.idx] = 0
            #    CAt[:, state.idx] = 0
            #    subA = A[:, state.idx]
            # else:
            #    subA = None
        else:
            if not state.has_fp16_weights and state.CxB is None:
                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
            subA = None

        # 2. Quantize B
        if state.has_fp16_weights:
            has_grad = True if (getattr(B, "grad", None) is not None) else False
            is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
            if is_transposed:
                B = B.contiguous()

            if (state.is_training and not has_grad) or state.CxB is None:
                state.reset_grads()
@@ -234,14 +269,16 @@ class MatMul8bitLt(torch.autograd.Function):
                outlier_idx = torch.unique(coo_tensorA.colidx)
                state.idx = outlier_idx
                # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
                # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
                #    # do not use pool for 2nd FFN layer
                #    state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
                # else:
                #    state.idx = outlier_idx

                outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
                state.subB = (
                    (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
                )
                CA[:, state.idx.long()] = 0
                CAt[:, state.idx.long()] = 0
                subA = A[:, state.idx.long()]
@@ -254,7 +291,7 @@ class MatMul8bitLt(torch.autograd.Function):
            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
        C32A, SA = F.transform(CA, "col32")
        out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
        output = F.mm_dequant(out32, Sout32, SCA, state.SCB)
@@ -277,7 +314,7 @@ class MatMul8bitLt(torch.autograd.Function):
            ctx.tensor_states = (None, None)
            ctx.save_for_backward(None, None)

        # clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
        clone_func = torch.clone
        return clone_func(output.view(output_shape))
@@ -288,7 +325,7 @@ class MatMul8bitLt(torch.autograd.Function):
        SCAt, idx = ctx.tensor_states
        formatB = ctx.formatB
        state = ctx.state
        assert state.has_fp16_weights, "Backprop only supported for fp16 weights."

        if len(grad_output.shape) == 3:
            grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
@@ -298,18 +335,22 @@ class MatMul8bitLt(torch.autograd.Function):
        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
        if req_gradB:
            CxAt, SAt = F.transform(CAt, formatB, transpose=True)
            C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
            gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
            if state.threshold > 0.0 and subA is not None:
                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

        if req_gradA:
            C32grad, Sgrad = F.transform(Cgrad, "col32")
            if state.CxBt is None:
                state.CxBt, state.SBt = F.transform(
                    state.CBt, to_order=formatB, transpose=True
                )
            gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
            grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(
                ctx.grad_shape
            )

        return grad_A, grad_B, None, None, None, None, None
@@ -317,9 +358,10 @@ class MatMul8bitLt(torch.autograd.Function):
matmul = MatMul8bitLt.apply


def matmul(
    A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0
):
    state = state or MatmulLtState()
    if threshold > 0.0:
        state.threshold = threshold
    return MatMul8bitLt.apply(A, B, out, state)
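
The matmul wrapper above is the public entry point for MatMul8bitLt; it only forwards an optional threshold into a fresh MatmulLtState. A rough usage sketch, assuming a CUDA device and fp16 inputs (as the assert in forward() requires), with B in the [out_features, in_features] layout that Linear8bitLt passes in:

import torch
import bitsandbytes as bnb

A = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
B = torch.randn(2048, 1024, dtype=torch.float16, device="cuda")

# threshold > 0.0 enables the outlier decomposition path in forward()
out = bnb.matmul(A, B, threshold=6.0)
print(out.shape)  # torch.Size([4, 2048])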


@@ -1,6 +1,7 @@
import ctypes as ct
import os
from warnings import warn

from bitsandbytes.cuda_setup import evaluate_cuda_setup

@@ -8,17 +9,21 @@ class CUDALibrary_Singleton(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self):
        self.context = {}
        binary_name = evaluate_cuda_setup()

        if not os.path.exists(os.path.dirname(__file__) + f"/{binary_name}"):
            print(f"TODO: compile library for specific version: {binary_name}")
            print("defaulting to libbitsandbytes.so")
            self.lib = ct.cdll.LoadLibrary(
                os.path.dirname(__file__) + "/libbitsandbytes.so"
            )
        else:
            self.lib = ct.cdll.LoadLibrary(
                os.path.dirname(__file__) + f"/{binary_name}"
            )

    @classmethod
    def get_instance(cls):
@@ -35,6 +40,8 @@ try:
    lib.get_cusparse.restype = ct.c_void_p
    COMPILED_WITH_CUDA = True
except AttributeError:
    warn(
        "The installed version of bitsandbytes was compiled without GPU support. "
        "8-bit optimizers and GPU quantization are unavailable."
    )
    COMPILED_WITH_CUDA = False
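
The singleton above loads the shared library exactly once; the rest of the package obtains the handle through get_instance(). A small sketch of that access pattern (assuming a local build of the library is present):

from bitsandbytes.cextension import COMPILED_WITH_CUDA, CUDALibrary_Singleton

lib = CUDALibrary_Singleton.get_instance().lib  # ctypes CDLL handle
print(COMPILED_WITH_CUDA)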


@@ -18,31 +18,36 @@ evaluation:
     - based on that set the default path
"""

import ctypes
import shlex
import subprocess
from os import environ as env
from pathlib import Path
from typing import Set, Union

from .utils import print_err, warn_of_missing_prerequisite


def execute_and_return(strCMD):
    proc = subprocess.Popen(
        shlex.split(strCMD), stdout=subprocess.PIPE, stderr=subprocess.PIPE
    )
    out, err = proc.communicate()
    out, err = out.decode("UTF-8").strip(), err.decode("UTF-8").strip()
    return out, err


def check_cuda_result(cuda, result_val):
    if result_val != 0:
        cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
        print(f"Count not initialize CUDA - failure!")
        raise Exception("CUDA exception!")
    return result_val


# taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549
def get_compute_capability():
    libnames = ("libcuda.so", "libcuda.dylib", "cuda.dll")
    for libname in libnames:
        try:
            cuda = ctypes.CDLL(libname)
@@ -51,8 +56,7 @@ def get_compute_capability():
        else:
            break
    else:
        raise OSError("could not load any of: " + " ".join(libnames))

    nGpus = ctypes.c_int()
    cc_major = ctypes.c_int()
@@ -69,39 +73,43 @@ def get_compute_capability():
    ccs = []
    for i in range(nGpus.value):
        result = check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
        result = check_cuda_result(
            cuda,
            cuda.cuDeviceComputeCapability(
                ctypes.byref(cc_major), ctypes.byref(cc_minor), device
            ),
        )
        ccs.append(f"{cc_major.value}.{cc_minor.value}")

    # TODO: handle different compute capabilities; for now, take the max
    ccs.sort()
    # return ccs[-1]
    return ccs


CUDA_RUNTIME_LIB: str = "libcudart.so"


def tokenize_paths(paths: str) -> Set[Path]:
    return {Path(ld_path) for ld_path in paths.split(":") if ld_path}


def get_cuda_runtime_lib_path(
    # TODO: replace this with logic for all paths in env vars
    LD_LIBRARY_PATH: Union[str, None] = env.get("LD_LIBRARY_PATH")
) -> Union[Path, None]:
    """# TODO: add doc-string"""

    if not LD_LIBRARY_PATH:
        warn_of_missing_prerequisite(
            "LD_LIBRARY_PATH is completely missing from environment!"
        )
        return None

    ld_library_paths: Set[Path] = tokenize_paths(LD_LIBRARY_PATH)

    non_existent_directories: Set[Path] = {
        path for path in ld_library_paths if not path.exists()
    }

    if non_existent_directories:
@@ -111,7 +119,8 @@ def get_cuda_runtime_lib_path(
        )

    cuda_runtime_libs: Set[Path] = {
        path / CUDA_RUNTIME_LIB
        for path in ld_library_paths
        if (path / CUDA_RUNTIME_LIB).is_file()
    } - non_existent_directories
@@ -126,26 +135,31 @@ def get_cuda_runtime_lib_path(
    single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs))
    return single_cuda_runtime_lib_dir


def evaluate_cuda_setup():
    cuda_path = get_cuda_runtime_lib_path()
    cc = get_compute_capability()
    binary_name = "libbitsandbytes_cpu.so"

    if not (has_gpu := bool(cc)):
        print(
            "WARNING: No GPU detected! Check our CUDA paths. Processing to load CPU-only library..."
        )
        return binary_name

    has_cublaslt = cc in ["7.5", "8.0", "8.6"]

    # TODO:
    # (1) Model missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible)
    # (2) Multiple CUDA versions installed

    cuda_home = str(Path(cuda_path).parent.parent)
    ls_output, err = execute_and_return(f"{cuda_home}/bin/nvcc --version")
    cuda_version = (
        ls_output.split("\n")[3].split(",")[-1].strip().lower().replace("v", "")
    )
    major, minor, revision = cuda_version.split(".")
    cuda_version_string = f"{major}{minor}"

    binary_name = f'libbitsandbytes_cuda{cuda_version_string}_{("cublaslt" if has_cublaslt else "")}.so'
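
evaluate_cuda_setup() ties the pieces together: it locates libcudart via LD_LIBRARY_PATH, queries the compute capability through libcuda, parses the nvcc version, and returns the binary name that cextension.py will try to load. A usage sketch (the printed name, e.g. libbitsandbytes_cuda113_cublaslt.so for CUDA 11.3 on an Ampere GPU, depends entirely on the local setup):

from bitsandbytes.cuda_setup import evaluate_cuda_setup

print(evaluate_cuda_setup())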


@@ -1,6 +1,5 @@
import typer

cli = typer.Typer()
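
So far this module only instantiates the Typer app. A hypothetical sketch of how a diagnostic subcommand could be attached to it; the command below is illustrative only and is not part of this commit:

import typer

from bitsandbytes.cuda_setup import evaluate_cuda_setup

cli = typer.Typer()


@cli.command()
def cuda():
    """Print which bitsandbytes binary would be loaded on this machine."""
    typer.echo(evaluate_cuda_setup())


if __name__ == "__main__":
    cli()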

File diff suppressed because it is too large


@@ -1,5 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modules import Int8Params, Linear8bit, Linear8bitLt, StableEmbedding


@@ -1,39 +1,59 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import (Any, Callable, Dict, Iterator, Mapping, Optional, Set,
                    Tuple, TypeVar, Union, overload)

import torch
import torch.nn.functional as F
from torch import Tensor, device, dtype, nn
from torch.nn.parameter import Parameter

import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager

T = TypeVar("T", bound="torch.nn.Module")


class StableEmbedding(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        super(StableEmbedding, self).__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
        )
        self.norm = torch.nn.LayerNorm(embedding_dim)
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
        to make the Layer compatible with Pytorch < 1.9.
        This means that if this changes in future PyTorch releases this need to change too
        which is cumbersome. However, with this we can ensure compatibility with previous
        PyTorch releases.
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
@@ -41,29 +61,55 @@ class StableEmbedding(torch.nn.Embedding):
    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        return self.norm(emb)


class Embedding(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[Tensor] = None,
    ) -> None:
        super(Embedding, self).__init__(
            num_embeddings,
            embedding_dim,
            padding_idx,
            max_norm,
            norm_type,
            scale_grad_by_freq,
            sparse,
            _weight,
        )
        GlobalOptimManager.get_instance().register_module_override(
            self, "weight", {"optim_bits": 32}
        )

    def reset_parameters(self) -> None:
        torch.nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
        to make the Layer compatible with Pytorch < 1.9.
        This means that if this changes in future PyTorch releases this need to change too
        which is cumbersome. However, with this we can ensure compatibility with previous
        PyTorch releases.
    """

    def _fill_padding_idx_with_zero(self) -> None:
        if self.padding_idx is not None:
            with torch.no_grad():
@@ -71,13 +117,22 @@ class Embedding(torch.nn.Embedding):
    def forward(self, input: Tensor) -> Tensor:
        emb = F.embedding(
            input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

        return emb


class Int8Params(torch.nn.Parameter):
    def __new__(
        cls, data=None, requires_grad=True, has_fp16_weights=False, CB=None, SCB=None
    ):
        cls.has_fp16_weights = has_fp16_weights
        cls.CB = None
        cls.SCB = None
@@ -96,14 +151,18 @@ class Int8Params(torch.nn.Parameter):
        del CBt
        del SCBt
        self.data = CB
        setattr(self, "CB", CB)
        setattr(self, "SCB", SCB)

        return self

    @overload
    def to(
        self: T,
        device: Optional[Union[int, device]] = ...,
        dtype: Optional[Union[dtype, str]] = ...,
        non_blocking: bool = ...,
    ) -> T:
        ...

    @overload
@@ -115,23 +174,41 @@ class Int8Params(torch.nn.Parameter):
        ...

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
            *args, **kwargs
        )

        if (
            device is not None
            and device.type == "cuda"
            and self.data.device.type == "cpu"
        ):
            return self.cuda(device)
        else:
            new_param = Int8Params(
                super().to(device=device, dtype=dtype, non_blocking=non_blocking),
                requires_grad=self.requires_grad,
                has_fp16_weights=self.has_fp16_weights,
            )
            new_param.CB = self.CB
            new_param.SCB = self.SCB

            return new_param


class Linear8bitLt(nn.Linear):
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        has_fp16_weights=True,
        threshold=0.0,
        index=None,
    ):
        super(Linear8bitLt, self).__init__(input_features, output_features, bias)
        self.state = bnb.MatmulLtState()
        self.index = index

        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
@@ -149,9 +226,10 @@ class Linear8bitLt(nn.Linear):
    def forward(self, x):
        self.state.is_training = self.training

        if self.weight.CB is not None:
            self.init_8bit_state()
        # assert not self.state.has_fp16_weights
        # if not self.state.has_fp16_weights: assert self.state.CB is not None or self.state.CxB is not None

        out = bnb.matmul(x, self.weight, state=self.state)
@@ -166,8 +244,18 @@ class Linear8bitLt(nn.Linear):
        return out


class Linear8bit(nn.Linear):
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        quant_type="vector",
        index=None,
        args=None,
        sparse_decomp=False,
    ):
        super(Linear8bit, self).__init__(input_features, output_features, bias)
        self.quant_type = quant_type
        self.index = index
@@ -178,15 +266,24 @@ class Linear8bit(nn.Linear):
            self.iter += 1
            if self.iter % self.args.clip_freq == 0:
                with torch.no_grad():
                    maxval, maxidx = torch.topk(
                        torch.abs(self.weight.flatten()), k=self.args.clip_idx
                    )
                    if not dist.is_initialized() or dist.get_rank() == 0:
                        print("clip", maxval[-1].item())
                    self.weight.clip_(-maxval[-1], maxval[-1])

        if self.args is not None:
            out = bnb.nn.functional.sparse_decomposed_linear8bit(
                x,
                self.weight,
                self.bias,
                qval=self.args.sparse_decomp_val,
                quant_type=self.args.quant_type,
            )
        else:
            out = bnb.nn.functional.linear8bit(
                x, self.weight, self.bias, quant_type=self.args.quant_type
            )

        return out
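
Linear8bitLt and StableEmbedding are meant as drop-in replacements for their torch.nn counterparts. A usage sketch, assuming a CUDA device; threshold=6.0 is the commonly suggested outlier threshold and is not something this commit sets:

import torch
import bitsandbytes as bnb

model = torch.nn.Sequential(
    bnb.nn.StableEmbedding(10000, 512),
    torch.nn.Linear(512, 1024),
    bnb.nn.Linear8bitLt(1024, 1024, has_fp16_weights=False, threshold=6.0),
).half().cuda()  # the Linear8bitLt weights are quantized when moved to the GPU

tokens = torch.randint(0, 10000, (2, 16), device="cuda")
out = model(tokens)  # shape: (2, 16, 1024)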


@@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from bitsandbytes.cextension import COMPILED_WITH_CUDA


@@ -1,12 +1,25 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State


class Adagrad(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        lr_decay=0,
        weight_decay=0,
        initial_accumulator_value=0,
        eps=1e-10,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= weight_decay:
@@ -14,15 +27,39 @@ class Adagrad(Optimizer1State):
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if initial_accumulator_value != 0.0:
            raise ValueError("Initial accumulator value != 0.0 not supported!")
        if lr_decay != 0.0:
            raise ValueError("Lr Decay != 0.0 not supported!")
        super(Adagrad, self).__init__(
            "adagrad",
            params,
            lr,
            (0.0, 0.0),
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class Adagrad8bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        lr_decay=0,
        weight_decay=0,
        initial_accumulator_value=0,
        eps=1e-10,
        optim_bits=8,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= weight_decay:
@@ -30,16 +67,40 @@ class Adagrad8bit(Optimizer1State):
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if initial_accumulator_value != 0.0:
            raise ValueError("Initial accumulator value != 0.0 not supported!")
        if lr_decay != 0.0:
            raise ValueError("Lr Decay != 0.0 not supported!")
        assert block_wise
        super(Adagrad8bit, self).__init__(
            "adagrad",
            params,
            lr,
            (0.0, 0.0),
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class Adagrad32bit(Optimizer1State):
    def __init__(
        self,
        params,
        lr=1e-2,
        lr_decay=0,
        weight_decay=0,
        initial_accumulator_value=0,
        eps=1e-10,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= weight_decay:
@@ -47,8 +108,19 @@ class Adagrad32bit(Optimizer1State):
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if initial_accumulator_value != 0.0:
            raise ValueError("Initial accumulator value != 0.0 not supported!")
        if lr_decay != 0.0:
            raise ValueError("Lr Decay != 0.0 not supported!")
        super(Adagrad32bit, self).__init__(
            "adagrad",
            params,
            lr,
            (0.0, 0.0),
            eps,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


@@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
@@ -8,29 +8,97 @@ import os
import torch
import torch.distributed as dist

import bitsandbytes.functional as F
from bitsandbytes.optim.optimizer import Optimizer2State


class Adam(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(Adam, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class Adam8bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(Adam8bit, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class Adam32bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(Adam32bit, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class AnalysisAdam(torch.optim.Optimizer):
@@ -68,8 +136,8 @@ class AnalysisAdam(torch.optim.Optimizer):
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        bnb_analysis="dynamic-blockwise",
        savedir=None,
    ):
        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
@@ -124,9 +192,13 @@ class AnalysisAdam(torch.optim.Optimizer):
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                    state["abserrors"] = torch.zeros(
                        (256, 256), device=p_data_fp32.device
                    )
                    state["relerrors"] = torch.zeros(
                        (256, 256), device=p_data_fp32.device
                    )
                    state["counts"] = torch.zeros((256, 256), device=p_data_fp32.device)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
@@ -143,9 +215,9 @@ class AnalysisAdam(torch.optim.Optimizer):
                bias_correction1 = 1 - beta1 ** state["step"]
                bias_correction2 = 1 - beta2 ** state["step"]
                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
                e = state["abserrors"]
                rele = state["relerrors"]
                counts = state["counts"]

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(
@@ -156,77 +228,84 @@ class AnalysisAdam(torch.optim.Optimizer):
                if amsgrad:
                    max_exp_avg_sq = state["max_exp_avg_sq"]

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(group["eps"])
                update_fp32 = exp_avg / denom

                if p_data_fp32.numel() <= 8192 or p_data_fp32.numel() > 50000 * 1000:
                    # embedding layer or too small
                    p_data_fp32 += -step_size * update_fp32
                else:
                    if self.analysis == "dynamic-blockwise":
                        code1 = F.create_dynamic_map(signed=True).to(p.device)
                        code2 = F.create_dynamic_map(signed=False).to(p.device)
                        C1, S1 = F.quantize_blockwise(exp_avg, code=code1)
                        state1 = F.dequantize_blockwise(C1, S1)
                        C2, S2 = F.quantize_blockwise(exp_avg_sq, code=code2)
                        state2 = F.dequantize_blockwise(C2, S2)
                    elif self.analysis == "dynamic":
                        code1 = F.create_dynamic_map(signed=True).to(p.device)
                        code2 = F.create_dynamic_map(signed=False).to(p.device)
                        C1, S1 = F.quantize(exp_avg, code=code1)
                        state1 = F.dequantize(C1, S1)
                        C2, S2 = F.quantize(exp_avg_sq, code=code2)
                        state2 = F.dequantize(C2, S2)
                    elif self.analysis == "linear":
                        code1 = F.create_linear_map(signed=True).to(p.device)
                        code2 = F.create_linear_map(signed=False).to(p.device)
                        C1, S1 = F.quantize(exp_avg, code=code1)
                        state1 = F.dequantize(C1, S1)
                        C2, S2 = F.quantize(exp_avg_sq, code=code2)
                        state2 = F.dequantize(C2, S2)
                    elif self.analysis == "quantile":
                        code1 = F.estimate_quantiles(exp_avg)
                        code2 = F.estimate_quantiles(exp_avg_sq)
                        C1 = F.quantize_no_absmax(exp_avg, code=code1)
                        state1 = F.dequantize_no_absmax(C1, code1)
                        C2 = F.quantize_no_absmax(exp_avg_sq, code=code2)
                        state2 = F.dequantize_no_absmax(C2, code2)
                    elif self.analysis == "my-quantization-routine":
                        pass
                        # 1. get code
                        # 2. quantize
                        # 3. dequantize
                        # Error will be calculated automatically!
                    else:
                        raise ValueError(f"Invalid analysis value: {self.analysis}!")

                    denom = state2.sqrt().add_(group["eps"])
                    update_8bit = state1 / denom

                    abserr = torch.abs(update_8bit - update_fp32)
                    relerr = abserr / torch.abs(update_fp32 + 1e-6)

                    C1, C2 = C1.int(), C2.int()

                    F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr)
                    F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr)
                    F.histogram_scatter_add_2d(
                        counts, C1.int(), C2.int(), torch.ones_like(abserr)
                    )

                    p_data_fp32 += -step_size * update_fp32

                    if not dist.is_initialized() or dist.get_rank() == 0:
                        if self.savedir != "" and state["step"] % 100 == 0:
                            if not os.path.exists(self.savedir):
                                os.makedirs(self.savedir)
                            shapestr = "_".join([str(dim) for dim in p_data_fp32.shape])
                            pathe = os.path.join(
                                self.savedir, f"{p_id}_{shapestr}_abserr.pkl"
                            )
                            pathrele = os.path.join(
                                self.savedir, f"{p_id}_{shapestr}_relerr.pkl"
                            )
                            pathcounts = os.path.join(
                                self.savedir, f"{p_id}_{shapestr}_counts.pkl"
                            )
                            torch.save(e, pathe)
                            torch.save(rele, pathrele)
                            torch.save(counts, pathcounts)
@@ -234,6 +313,4 @@ class AnalysisAdam(torch.optim.Optimizer):
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

        return loss
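
The three thin wrappers at the top of the file only fix the optim_bits argument, while AnalysisAdam is a research tool that quantizes and dequantizes its own state to record the error histograms seen above. For the wrappers, usage mirrors torch.optim.Adam; a sketch, assuming a CUDA build of bitsandbytes:

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096).cuda()
# Parameters with fewer elements than min_8bit_size (default 4096) keep 32-bit state.
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

loss = model(torch.randn(8, 4096, device="cuda")).sum()
loss.backward()
optimizer.step()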


@@ -1,27 +1,93 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State


class AdamW(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(AdamW, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            optim_bits,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class AdamW8bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(AdamW8bit, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


class AdamW32bit(Optimizer2State):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
    ):
        super(AdamW32bit, self).__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            32,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
        )


@ -1,28 +1,105 @@
# Copyright (c) Facebook, Inc. and its affiliates. # Copyright (c) Facebook, Inc. and its affiliates.
# #
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer2State from bitsandbytes.optim.optimizer import Optimizer2State
class LAMB(Optimizer2State): class LAMB(Optimizer2State):
def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, def __init__(
weight_decay=0, amsgrad=False, adam_w_mode=True, optim_bits=32, args=None, self,
min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0): params,
super(LAMB, self).__init__('lamb', params, lr, betas, eps, lr=1e-3,
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0) bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
adam_w_mode=True,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=False,
max_unorm=1.0,
):
super(LAMB, self).__init__(
"lamb",
params,
lr,
betas,
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
max_unorm=1.0,
)
class LAMB8bit(Optimizer2State): class LAMB8bit(Optimizer2State):
def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, def __init__(
weight_decay=0, amsgrad=False, adam_w_mode=True, args=None, self,
min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0): params,
super(LAMB8bit, self).__init__('lamb', params, lr, betas, eps, lr=1e-3,
weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0) bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
adam_w_mode=True,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=False,
max_unorm=1.0,
):
super(LAMB8bit, self).__init__(
"lamb",
params,
lr,
betas,
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
max_unorm=1.0,
)
class LAMB32bit(Optimizer2State): class LAMB32bit(Optimizer2State):
def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8, def __init__(
weight_decay=0, amsgrad=False, adam_w_mode=True, args=None, self,
min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0): params,
super(LAMB32bit, self).__init__('lamb', params, lr, betas, eps, lr=1e-3,
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0) bias_correction=True,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
amsgrad=False,
adam_w_mode=True,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=False,
max_unorm=1.0,
):
super(LAMB32bit, self).__init__(
"lamb",
params,
lr,
betas,
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
max_unorm=1.0,
)
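The LAMB variants follow the same pattern but add a trust-ratio cap; note that, as written, all three constructors forward the literal max_unorm=1.0 to Optimizer2State rather than their max_unorm argument. A brief sketch under the same assumptions as the AdamW example above (hypothetical model, CUDA build):

    import bitsandbytes as bnb

    # 8-bit LAMB state; the per-parameter update norm is capped relative to the weight norm
    optimizer = bnb.optim.LAMB8bit(model.parameters(), lr=1e-3, weight_decay=0.01)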


@ -1,43 +1,121 @@
# Copyright (c) Facebook, Inc. and its affiliates. # Copyright (c) Facebook, Inc. and its affiliates.
# #
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch import torch
from torch.optim import Optimizer from torch.optim import Optimizer
from bitsandbytes.optim.optimizer import Optimizer1State from bitsandbytes.optim.optimizer import Optimizer1State
class LARS(Optimizer1State): class LARS(Optimizer1State):
def __init__(self, params, lr, momentum=0, dampening=0, def __init__(
weight_decay=0, nesterov=False, optim_bits=32, args=None, self,
min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02): params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
max_unorm=0.02,
):
if momentum == 0: if momentum == 0:
raise NotImplementedError(f'LARS without momentum is not supported!') raise NotImplementedError(f"LARS without momentum is not supported!")
super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0, super(LARS, self).__init__(
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False) "lars",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
max_unorm=max_unorm,
block_wise=False,
)
class LARS8bit(Optimizer1State): class LARS8bit(Optimizer1State):
def __init__(self, params, lr, momentum=0, dampening=0, def __init__(
weight_decay=0, nesterov=False, args=None, self,
min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02): params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
max_unorm=0.02,
):
if momentum == 0: if momentum == 0:
raise NotImplementedError(f'LARS without momentum is not supported!') raise NotImplementedError(f"LARS without momentum is not supported!")
super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0, super(LARS8bit, self).__init__(
weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False) "lars",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
max_unorm=max_unorm,
block_wise=False,
)
class LARS32bit(Optimizer1State): class LARS32bit(Optimizer1State):
def __init__(self, params, lr, momentum=0, dampening=0, def __init__(
weight_decay=0, nesterov=False, args=None, self,
min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02): params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
max_unorm=0.02,
):
if momentum == 0: if momentum == 0:
raise NotImplementedError(f'LARS without momentum is not supported!') raise NotImplementedError(f"LARS without momentum is not supported!")
super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0, super(LARS32bit, self).__init__(
weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False) "lars",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
max_unorm=max_unorm,
block_wise=False,
)
class PytorchLARS(Optimizer): class PytorchLARS(Optimizer):
def __init__(self, params, lr=0.01, momentum=0, dampening=0, def __init__(
weight_decay=0, nesterov=False, max_unorm=0.02): self,
params,
lr=0.01,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
max_unorm=0.02,
):
if lr < 0.0: if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr)) raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0: if momentum < 0.0:
@ -45,8 +123,14 @@ class PytorchLARS(Optimizer):
if weight_decay < 0.0: if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, momentum=momentum, dampening=dampening, defaults = dict(
weight_decay=weight_decay, nesterov=nesterov, max_unorm=max_unorm) lr=lr,
momentum=momentum,
dampening=dampening,
weight_decay=weight_decay,
nesterov=nesterov,
max_unorm=max_unorm,
)
if nesterov and (momentum <= 0 or dampening != 0): if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening") raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(PytorchLARS, self).__init__(params, defaults) super(PytorchLARS, self).__init__(params, defaults)
@ -54,7 +138,7 @@ class PytorchLARS(Optimizer):
def __setstate__(self, state): def __setstate__(self, state):
super(PytorchLARS, self).__setstate__(state) super(PytorchLARS, self).__setstate__(state)
for group in self.param_groups: for group in self.param_groups:
group.setdefault('nesterov', False) group.setdefault("nesterov", False)
@torch.no_grad() @torch.no_grad()
def step(self, closure=None): def step(self, closure=None):
@ -73,15 +157,16 @@ class PytorchLARS(Optimizer):
params_with_grad = [] params_with_grad = []
d_p_list = [] d_p_list = []
momentum_buffer_list = [] momentum_buffer_list = []
weight_decay = group['weight_decay'] weight_decay = group["weight_decay"]
momentum = group['momentum'] momentum = group["momentum"]
dampening = group['dampening'] dampening = group["dampening"]
nesterov = group['nesterov'] nesterov = group["nesterov"]
max_unorm = group['max_unorm'] max_unorm = group["max_unorm"]
lr = group['lr'] lr = group["lr"]
for p in group['params']: for p in group["params"]:
if p.grad is None: continue if p.grad is None:
continue
state = self.state[p] state = self.state[p]
d_p = p.grad d_p = p.grad
@ -89,16 +174,16 @@ class PytorchLARS(Optimizer):
d_p = d_p.add(param, alpha=weight_decay) d_p = d_p.add(param, alpha=weight_decay)
if momentum != 0: if momentum != 0:
buf = state.get('momentum_buffer', None) buf = state.get("momentum_buffer", None)
if buf is None: if buf is None:
buf = torch.clone(d_p).detach() buf = torch.clone(d_p).detach()
state['momentum_buffer']= buf state["momentum_buffer"] = buf
else: else:
buf.mul_(momentum).add_(d_p, alpha=1 - dampening) buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
if nesterov: if nesterov:
update = d_p + buf*momentum update = d_p + buf * momentum
else: else:
update = buf update = buf
@ -107,9 +192,9 @@ class PytorchLARS(Optimizer):
assert p.dtype == torch.float32 assert p.dtype == torch.float32
pnorm = torch.norm(p.detach()) pnorm = torch.norm(p.detach())
unorm = torch.norm(update) unorm = torch.norm(update)
if unorm > max_unorm*pnorm: if unorm > max_unorm * pnorm:
update_scale = max_unorm*pnorm/unorm update_scale = max_unorm * pnorm / unorm
p.add_(update, alpha=-lr*update_scale) p.add_(update, alpha=-lr * update_scale)
return loss return loss
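The tail of PytorchLARS.step() implements the trust-ratio clipping: the update is rescaled whenever its norm exceeds max_unorm times the parameter norm. An isolated sketch of just that rule on dummy tensors (illustration only, not part of the commit):

    import torch

    max_unorm, lr = 0.02, 0.01
    p = torch.randn(1024)       # hypothetical parameter
    update = torch.randn(1024)  # hypothetical momentum/gradient update

    pnorm, unorm = torch.norm(p), torch.norm(update)
    update_scale = 1.0
    if unorm > max_unorm * pnorm:
        # shrink so that ||update_scale * update|| == max_unorm * ||p||
        update_scale = (max_unorm * pnorm / unorm).item()
    p.add_(update, alpha=-lr * update_scale)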


@ -1,13 +1,16 @@
# Copyright (c) Facebook, Inc. and its affiliates. # Copyright (c) Facebook, Inc. and its affiliates.
# #
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch from collections import abc as container_abcs
import bitsandbytes.functional as F from collections import defaultdict
from copy import deepcopy from copy import deepcopy
from itertools import chain from itertools import chain
from collections import defaultdict, abc as container_abcs
import torch
import bitsandbytes.functional as F
class MockArgs(object): class MockArgs(object):
def __init__(self, initial_data): def __init__(self, initial_data):
@ -19,7 +22,7 @@ class GlobalOptimManager(object):
_instance = None _instance = None
def __init__(self): def __init__(self):
raise RuntimeError('Call get_instance() instead') raise RuntimeError("Call get_instance() instead")
def initialize(self): def initialize(self):
self.pid2config = {} self.pid2config = {}
@ -38,15 +41,15 @@ class GlobalOptimManager(object):
def register_parameters(self, params): def register_parameters(self, params):
param_groups = list(params) param_groups = list(params)
if not isinstance(param_groups[0], dict): if not isinstance(param_groups[0], dict):
param_groups = [{'params': param_groups}] param_groups = [{"params": param_groups}]
for group_index, group in enumerate(param_groups): for group_index, group in enumerate(param_groups):
for p_index, p in enumerate(group['params']): for p_index, p in enumerate(group["params"]):
if id(p) in self.pid2config: if id(p) in self.pid2config:
self.index2config[(group_index, p_index)] = self.pid2config[id(p)] self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
def override_config(self, parameters, key=None, value=None, key_value_dict=None): def override_config(self, parameters, key=None, value=None, key_value_dict=None):
''' """
Overrides initial optimizer config for specific parameters. Overrides initial optimizer config for specific parameters.
The key-values of the optimizer config for the input parameters are overridden The key-values of the optimizer config for the input parameters are overridden
@ -63,7 +66,7 @@ class GlobalOptimManager(object):
The value for the hyperparameters. The value for the hyperparameters.
key_value_dict : dict key_value_dict : dict
A dictionary with multiple key-values to override. A dictionary with multiple key-values to override.
''' """
self.uses_config_override = True self.uses_config_override = True
if isinstance(parameters, torch.nn.Parameter): if isinstance(parameters, torch.nn.Parameter):
parameters = [parameters] parameters = [parameters]
@ -75,16 +78,16 @@ class GlobalOptimManager(object):
if key_value_dict is not None: if key_value_dict is not None:
for p in parameters: for p in parameters:
if id(p) in self.pid2config:self.pid2config[id(p)].update(key_value_dict) if id(p) in self.pid2config:
else: self.pid2config[id(p)] = key_value_dict self.pid2config[id(p)].update(key_value_dict)
else:
self.pid2config[id(p)] = key_value_dict
def register_module_override(self, module, param_name, config): def register_module_override(self, module, param_name, config):
self.module_weight_config_triple.append((module, param_name, config)) self.module_weight_config_triple.append((module, param_name, config))
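A sketch of how register_parameters and override_config are meant to be combined, e.g. to keep 32-bit optimizer state for one parameter while the rest use 8-bit Adam (the model here is a hypothetical stand-in; assumes a CUDA build):

    import torch
    import bitsandbytes as bnb

    mng = bnb.optim.GlobalOptimManager.get_instance()

    model = torch.nn.Sequential(torch.nn.Embedding(1000, 64), torch.nn.Linear(64, 64))
    mng.register_parameters(model.parameters())  # register before the optimizer is created
    model = model.cuda()

    optimizer = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=8)
    # override: the embedding weight keeps 32-bit optimizer state
    mng.override_config(model[0].weight, "optim_bits", 32)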
class Optimizer8bit(torch.optim.Optimizer): class Optimizer8bit(torch.optim.Optimizer):
def __init__(self, params, defaults, optim_bits=32): def __init__(self, params, defaults, optim_bits=32):
super(Optimizer8bit, self).__init__(params, defaults) super(Optimizer8bit, self).__init__(params, defaults)
self.initialized = False self.initialized = False
@ -92,23 +95,32 @@ class Optimizer8bit(torch.optim.Optimizer):
self.mng = GlobalOptimManager.get_instance() self.mng = GlobalOptimManager.get_instance()
self.non_castable_tensor_keys = set( self.non_castable_tensor_keys = set(
['qmap1', 'qmap2', [
'max1', 'max2', "qmap1",
'new_max1', 'new_max2', "qmap2",
'state1', 'state2', "max1",
'gnorm_vec', 'absmax1', 'absmax2', "max2",
'unorm_vec']) "new_max1",
"new_max2",
"state1",
"state2",
"gnorm_vec",
"absmax1",
"absmax2",
"unorm_vec",
]
)
if optim_bits == 8: self.fill_qmap() if optim_bits == 8:
self.fill_qmap()
def fill_qmap(self): def fill_qmap(self):
self.name2qmap['dynamic'] = F.create_dynamic_map(signed=True) self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True)
self.name2qmap['udynamic'] = F.create_dynamic_map(signed=False) self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False)
def __setstate__(self, state): def __setstate__(self, state):
super(Optimizer8bit, self).__setstate__(state) super(Optimizer8bit, self).__setstate__(state)
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
r"""Loads the optimizer state. r"""Loads the optimizer state.
@ -120,21 +132,28 @@ class Optimizer8bit(torch.optim.Optimizer):
state_dict = deepcopy(state_dict) state_dict = deepcopy(state_dict)
# Validate the state_dict # Validate the state_dict
groups = self.param_groups groups = self.param_groups
saved_groups = state_dict['param_groups'] saved_groups = state_dict["param_groups"]
if len(groups) != len(saved_groups): if len(groups) != len(saved_groups):
raise ValueError("loaded state dict has a different number of " raise ValueError(
"parameter groups") "loaded state dict has a different number of " "parameter groups"
param_lens = (len(g['params']) for g in groups) )
saved_lens = (len(g['params']) for g in saved_groups) param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
raise ValueError("loaded state dict contains a parameter group " raise ValueError(
"that doesn't match the size of optimizer's group") "loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group"
)
# Update the state # Update the state
id_map = {old_id: p for old_id, p in id_map = {
zip(chain.from_iterable((g['params'] for g in saved_groups)), old_id: p
chain.from_iterable((g['params'] for g in groups)))} for old_id, p in zip(
chain.from_iterable((g["params"] for g in saved_groups)),
chain.from_iterable((g["params"] for g in groups)),
)
}
def cast(param, value): def cast(param, value):
r"""Make a deep copy of value, casting all tensors to device of param.""" r"""Make a deep copy of value, casting all tensors to device of param."""
@ -161,7 +180,7 @@ class Optimizer8bit(torch.optim.Optimizer):
# State that is not assigned to params is copied as is (needed for # State that is not assigned to params is copied as is (needed for
# backward compatibility). # backward compatibility).
state = defaultdict(dict) state = defaultdict(dict)
for k, v in state_dict['state'].items(): for k, v in state_dict["state"].items():
if k in id_map: if k in id_map:
param = id_map[k] param = id_map[k]
state[param] = cast(param, v) state[param] = cast(param, v)
@ -170,15 +189,15 @@ class Optimizer8bit(torch.optim.Optimizer):
# Update parameter groups, setting their 'params' value # Update parameter groups, setting their 'params' value
def update_group(group, new_group): def update_group(group, new_group):
new_group['params'] = group['params'] new_group["params"] = group["params"]
return new_group return new_group
param_groups = [
update_group(g, ng) for g, ng in zip(groups, saved_groups)] param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({'state': state, 'param_groups': param_groups}) self.__setstate__({"state": state, "param_groups": param_groups})
def to_gpu(self): def to_gpu(self):
for gindex, group in enumerate(self.param_groups): for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group['params']): for pindex, p in enumerate(group["params"]):
if p in self.state: if p in self.state:
values = self.state[p] values = self.state[p]
for k, v in values.items(): for k, v in values.items():
@ -189,17 +208,23 @@ class Optimizer8bit(torch.optim.Optimizer):
for module, attr, config in self.mng.module_weight_config_triple: for module, attr, config in self.mng.module_weight_config_triple:
pmodule = getattr(module, attr) pmodule = getattr(module, attr)
assert pmodule is not None assert pmodule is not None
assert isinstance(pmodule, torch.Tensor) or isinstance(pmodule, torch.Parameter) assert isinstance(pmodule, torch.Tensor) or isinstance(
pmodule, torch.Parameter
)
found = False found = False
for gindex, group in enumerate(self.param_groups): for gindex, group in enumerate(self.param_groups):
if found: break if found:
for pindex, p in enumerate(group['params']): break
if found: break for pindex, p in enumerate(group["params"]):
if found:
break
if id(p) == id(pmodule): if id(p) == id(pmodule):
# found the matching parameter # found the matching parameter
# init override # init override
self.mng.pid2config[id(p)] = config self.mng.pid2config[id(p)] = config
self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[id(p)] self.mng.index2config[(gindex, pindex)] = self.mng.pid2config[
id(p)
]
found = True found = True
@torch.no_grad() @torch.no_grad()
@ -219,11 +244,11 @@ class Optimizer8bit(torch.optim.Optimizer):
if not self.initialized: if not self.initialized:
self.check_overrides() self.check_overrides()
self.to_gpu() # needed for fairseq pure fp16 training self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True self.initialized = True
for gindex, group in enumerate(self.param_groups): for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group['params']): for pindex, p in enumerate(group["params"]):
if p.grad is None: if p.grad is None:
continue continue
state = self.state[p] state = self.state[p]
@ -236,58 +261,70 @@ class Optimizer8bit(torch.optim.Optimizer):
def get_config(self, gindex, pindex, group): def get_config(self, gindex, pindex, group):
config = {} config = {}
config['betas'] = group['betas'] config["betas"] = group["betas"]
config['eps'] = group['eps'] config["eps"] = group["eps"]
config['weight_decay'] = group['weight_decay'] config["weight_decay"] = group["weight_decay"]
config['lr'] = group['lr'] config["lr"] = group["lr"]
config['optim_bits'] = self.args.optim_bits config["optim_bits"] = self.args.optim_bits
config['min_8bit_size'] = self.args.min_8bit_size config["min_8bit_size"] = self.args.min_8bit_size
config['percentile_clipping'] = self.args.percentile_clipping config["percentile_clipping"] = self.args.percentile_clipping
config['block_wise'] = self.args.block_wise config["block_wise"] = self.args.block_wise
config['max_unorm'] = self.args.max_unorm config["max_unorm"] = self.args.max_unorm
config['skip_zeros'] = self.args.skip_zeros config["skip_zeros"] = self.args.skip_zeros
if (gindex, pindex) in self.mng.index2config: if (gindex, pindex) in self.mng.index2config:
config.update(self.mng.index2config[(gindex, pindex)]) config.update(self.mng.index2config[(gindex, pindex)])
return config return config
def init_state(self, group, p, gindex, pindex): def init_state(self, group, p, gindex, pindex):
raise NotImplementedError(f'init_state method needs to be overridden') raise NotImplementedError(f"init_state method needs to be overridden")
def update_step(self, group, p, gindex, pindex): def update_step(self, group, p, gindex, pindex):
raise NotImplementedError(f'The update_step method needs to be overridden') raise NotImplementedError(f"The update_step method needs to be overridden")
class Optimizer2State(Optimizer8bit): class Optimizer2State(Optimizer8bit):
def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, def __init__(
weight_decay=0.0, optim_bits=32, args=None, self,
min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0, optimizer_name,
skip_zeros=False): params,
lr=1e-3,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0.0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
max_unorm=0.0,
skip_zeros=False,
):
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr)) raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps: if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps)) raise ValueError("Invalid epsilon value: {}".format(eps))
if isinstance(betas, str): if isinstance(betas, str):
# format: '(beta1, beta2)' # format: '(beta1, beta2)'
betas = betas.replace('(', '').replace(')', '').strip().split(',') betas = betas.replace("(", "").replace(")", "").strip().split(",")
betas = [float(b) for b in betas] betas = [float(b) for b in betas]
for i in range(len(betas)): for i in range(len(betas)):
if not 0.0 <= betas[i] < 1.0: if not 0.0 <= betas[i] < 1.0:
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, betas=betas, eps=eps, defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
weight_decay=weight_decay)
super(Optimizer2State, self).__init__(params, defaults, optim_bits) super(Optimizer2State, self).__init__(params, defaults, optim_bits)
if args is None: if args is None:
args = {} args = {}
args['optim_bits'] = optim_bits args["optim_bits"] = optim_bits
args['percentile_clipping'] = 100 args["percentile_clipping"] = 100
args['min_8bit_size'] = min_8bit_size args["min_8bit_size"] = min_8bit_size
args['percentile_clipping'] = percentile_clipping args["percentile_clipping"] = percentile_clipping
args['block_wise'] = block_wise args["block_wise"] = block_wise
args['max_unorm'] = max_unorm args["max_unorm"] = max_unorm
args['skip_zeros'] = skip_zeros args["skip_zeros"] = skip_zeros
self.args = MockArgs(args) self.args = MockArgs(args)
else: else:
@ -299,50 +336,83 @@ class Optimizer2State(Optimizer8bit):
def init_state(self, group, p, gindex, pindex): def init_state(self, group, p, gindex, pindex):
config = self.get_config(gindex, pindex, group) config = self.get_config(gindex, pindex, group)
if config['optim_bits'] == 32: if config["optim_bits"] == 32:
dtype = torch.float32 dtype = torch.float32
elif config['optim_bits'] == 8: elif config["optim_bits"] == 8:
dtype = torch.uint8 dtype = torch.uint8
else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') else:
raise NotImplementedError(
f'Amount of optimizer bits not supported: {config["optim_bits"]}'
)
if p.numel() < config['min_8bit_size']: dtype = torch.float32 if p.numel() < config["min_8bit_size"]:
dtype = torch.float32
state = self.state[p] state = self.state[p]
state['step'] = 0 state["step"] = 0
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device) state["state1"] = torch.zeros_like(
state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device) p,
memory_format=torch.preserve_format,
dtype=torch.float32,
device=p.device,
)
state["state2"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.float32,
device=p.device,
)
elif dtype == torch.uint8: elif dtype == torch.uint8:
if state['step'] == 0: if state["step"] == 0:
if 'dynamic' not in self.name2qmap: self.fill_qmap() if "dynamic" not in self.name2qmap:
self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device) self.fill_qmap()
self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device) self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to(p.device)
state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device) state["state1"] = torch.zeros_like(
state['qmap1'] = self.name2qmap['dynamic'] p,
memory_format=torch.preserve_format,
dtype=torch.uint8,
device=p.device,
)
state["qmap1"] = self.name2qmap["dynamic"]
state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device) state["state2"] = torch.zeros_like(
state['qmap2'] = self.name2qmap['udynamic'] p,
memory_format=torch.preserve_format,
dtype=torch.uint8,
device=p.device,
)
state["qmap2"] = self.name2qmap["udynamic"]
if config['block_wise']: if config["block_wise"]:
n = p.numel() n = p.numel()
blocks = n//2048 blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0 blocks += 1 if n % 2048 > 0 else 0
state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) state["absmax1"] = torch.zeros(
state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) (blocks,), dtype=torch.float32, device=p.device
)
state["absmax2"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
else: else:
state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) state["new_max1"] = torch.zeros(
state['max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device) (1,), dtype=torch.float32, device=p.device
state['new_max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device) )
state["max2"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state["new_max2"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
if config['percentile_clipping'] < 100: if config["percentile_clipping"] < 100:
state['gnorm_vec'] = torch.zeros((100,), device=p.device) state["gnorm_vec"] = torch.zeros((100,), device=p.device)
if config['max_unorm'] > 0.0: if config["max_unorm"] > 0.0:
state['unorm_vec'] = torch.zeros((1,), device=p.device) state["unorm_vec"] = torch.zeros((1,), device=p.device)
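For the block-wise branch above, each 2048-element block of the uint8 state gets one float32 absmax scale, so the extra memory is small. A quick check of the block arithmetic (illustrative only, values are hypothetical):

    n = 10_000_000                                   # elements of a hypothetical parameter tensor
    blocks = n // 2048 + (1 if n % 2048 > 0 else 0)  # same rounding as init_state
    assert blocks == -(-n // 2048)                   # equivalent to ceil(n / 2048)
    overhead_bytes = 2 * blocks * 4                  # absmax1 + absmax2, one float32 per block
    print(blocks, overhead_bytes)                    # 4883 blocks, ~39 KB of scales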
@torch.no_grad() @torch.no_grad()
def update_step(self, group, p, gindex, pindex): def update_step(self, group, p, gindex, pindex):
@ -351,41 +421,101 @@ class Optimizer2State(Optimizer8bit):
config = self.get_config(gindex, pindex, group) config = self.get_config(gindex, pindex, group)
state['step'] += 1 state["step"] += 1
step = state['step'] step = state["step"]
if config['percentile_clipping'] < 100: if config["percentile_clipping"] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping']) current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
grad, state["gnorm_vec"], step, config["percentile_clipping"]
)
else: else:
gnorm_scale = 1.0 gnorm_scale = 1.0
if state['state1'].dtype == torch.float: if state["state1"].dtype == torch.float:
F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'], F.optimizer_update_32bit(
state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale, self.optimizer_name,
state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], skip_zeros=config['skip_zeros']) grad,
p,
state["state1"],
config["betas"][0],
config["eps"],
step,
config["lr"],
state["state2"],
config["betas"][1],
config["weight_decay"],
gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
skip_zeros=config["skip_zeros"],
)
elif state['state1'].dtype == torch.uint8 and not config['block_wise']: elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1], F.optimizer_update_8bit(
config['eps'], step, config['lr'], self.optimizer_name,
state['qmap1'], state['qmap2'], state['max1'], state['max2'], state['new_max1'], state['new_max2'], grad,
config['weight_decay'], gnorm_scale=gnorm_scale, p,
unorm_vec=state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm']) state["state1"],
state["state2"],
config["betas"][0],
config["betas"][1],
config["eps"],
step,
config["lr"],
state["qmap1"],
state["qmap2"],
state["max1"],
state["max2"],
state["new_max1"],
state["new_max2"],
config["weight_decay"],
gnorm_scale=gnorm_scale,
unorm_vec=state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
)
# swap maxes # swap maxes
state['max1'], state['new_max1'] = state['new_max1'], state['max1'] state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
state['max2'], state['new_max2'] = state['new_max2'], state['max2'] state["max2"], state["new_max2"] = state["new_max2"], state["max2"]
elif state['state1'].dtype == torch.uint8 and config['block_wise']: elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1], F.optimizer_update_8bit_blockwise(
config['eps'], step, config['lr'], self.optimizer_name,
state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'], grad,
config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros']) p,
state["state1"],
state["state2"],
config["betas"][0],
config["betas"][1],
config["eps"],
step,
config["lr"],
state["qmap1"],
state["qmap2"],
state["absmax1"],
state["absmax2"],
config["weight_decay"],
gnorm_scale=gnorm_scale,
skip_zeros=config["skip_zeros"],
)
class Optimizer1State(Optimizer8bit): class Optimizer1State(Optimizer8bit):
def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8, def __init__(
weight_decay=0.0, optim_bits=32, args=None, self,
min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0, optimizer_name,
skip_zeros=False): params,
lr=1e-3,
betas=(0.9, 0.0),
eps=1e-8,
weight_decay=0.0,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
max_unorm=0.0,
skip_zeros=False,
):
if not 0.0 <= lr: if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr)) raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps: if not 0.0 <= eps:
@ -395,19 +525,18 @@ class Optimizer1State(Optimizer8bit):
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}") raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
if not 0.0 <= weight_decay: if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, betas=betas, eps=eps, defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
weight_decay=weight_decay)
super(Optimizer1State, self).__init__(params, defaults, optim_bits) super(Optimizer1State, self).__init__(params, defaults, optim_bits)
if args is None: if args is None:
args = {} args = {}
args['optim_bits'] = optim_bits args["optim_bits"] = optim_bits
args['percentile_clipping'] = 100 args["percentile_clipping"] = 100
args['min_8bit_size'] = min_8bit_size args["min_8bit_size"] = min_8bit_size
args['percentile_clipping'] = percentile_clipping args["percentile_clipping"] = percentile_clipping
args['block_wise'] = block_wise args["block_wise"] = block_wise
args['max_unorm'] = max_unorm args["max_unorm"] = max_unorm
args['skip_zeros'] = skip_zeros args["skip_zeros"] = skip_zeros
self.args = MockArgs(args) self.args = MockArgs(args)
else: else:
@ -419,43 +548,61 @@ class Optimizer1State(Optimizer8bit):
def init_state(self, group, p, gindex, pindex): def init_state(self, group, p, gindex, pindex):
config = self.get_config(gindex, pindex, group) config = self.get_config(gindex, pindex, group)
if config['optim_bits'] == 32: if config["optim_bits"] == 32:
dtype = torch.float32 dtype = torch.float32
elif config['optim_bits'] == 8: elif config["optim_bits"] == 8:
dtype = torch.uint8 dtype = torch.uint8
else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}') else:
raise NotImplementedError(
f'Amount of optimizer bits not supported: {config["optim_bits"]}'
)
if p.numel() < config['min_8bit_size']: dtype = torch.float32 if p.numel() < config["min_8bit_size"]:
dtype = torch.float32
state = self.state[p] state = self.state[p]
state['step'] = 0 state["step"] = 0
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096): if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device) state["state1"] = torch.zeros_like(
p,
memory_format=torch.preserve_format,
dtype=torch.float32,
device=p.device,
)
elif dtype == torch.uint8: elif dtype == torch.uint8:
if state['step'] == 0: if state["step"] == 0:
if 'dynamic' not in self.name2qmap: self.fill_qmap() if "dynamic" not in self.name2qmap:
self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device) self.fill_qmap()
self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to(p.device)
state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device) state["state1"] = torch.zeros_like(
state['qmap1'] = self.name2qmap['dynamic'] p,
memory_format=torch.preserve_format,
dtype=torch.uint8,
device=p.device,
)
state["qmap1"] = self.name2qmap["dynamic"]
if config['block_wise']: if config["block_wise"]:
n = p.numel() n = p.numel()
blocks = n//2048 blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0 blocks += 1 if n % 2048 > 0 else 0
state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device) state["absmax1"] = torch.zeros(
(blocks,), dtype=torch.float32, device=p.device
)
else: else:
state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) state["max1"] = torch.zeros((1,), dtype=torch.float32, device=p.device)
state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device) state["new_max1"] = torch.zeros(
(1,), dtype=torch.float32, device=p.device
)
if config['percentile_clipping'] < 100: if config["percentile_clipping"] < 100:
state['gnorm_vec'] = torch.zeros((100,), device=p.device) state["gnorm_vec"] = torch.zeros((100,), device=p.device)
if config['max_unorm'] > 0.0:
state['unorm_vec'] = torch.zeros((1,), device=p.device)
if config["max_unorm"] > 0.0:
state["unorm_vec"] = torch.zeros((1,), device=p.device)
@torch.no_grad() @torch.no_grad()
def update_step(self, group, p, gindex, pindex): def update_step(self, group, p, gindex, pindex):
@ -464,29 +611,77 @@ class Optimizer1State(Optimizer8bit):
config = self.get_config(gindex, pindex, group) config = self.get_config(gindex, pindex, group)
state['step'] += 1 state["step"] += 1
step = state['step'] step = state["step"]
if config['percentile_clipping'] < 100: if config["percentile_clipping"] < 100:
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping']) current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(
grad, state["gnorm_vec"], step, config["percentile_clipping"]
)
else: else:
gnorm_scale = 1.0 gnorm_scale = 1.0
if state['state1'].dtype == torch.float: if state["state1"].dtype == torch.float:
F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'], F.optimizer_update_32bit(
None, 0.0, config['weight_decay'], gnorm_scale, self.optimizer_name,
state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], grad,
skip_zeros=config['skip_zeros']) p,
state["state1"],
config["betas"][0],
config["eps"],
step,
config["lr"],
None,
0.0,
config["weight_decay"],
gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
skip_zeros=config["skip_zeros"],
)
elif state['state1'].dtype == torch.uint8 and not config['block_wise']: elif state["state1"].dtype == torch.uint8 and not config["block_wise"]:
F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1], F.optimizer_update_8bit(
config['eps'], step, config['lr'], state['qmap1'], None, state['max1'], None, state['new_max1'], None, self.optimizer_name,
config['weight_decay'], gnorm_scale, grad,
state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm']) p,
state["state1"],
None,
config["betas"][0],
config["betas"][1],
config["eps"],
step,
config["lr"],
state["qmap1"],
None,
state["max1"],
None,
state["new_max1"],
None,
config["weight_decay"],
gnorm_scale,
state["unorm_vec"] if config["max_unorm"] > 0.0 else None,
max_unorm=config["max_unorm"],
)
state['max1'], state['new_max1'] = state['new_max1'], state['max1'] state["max1"], state["new_max1"] = state["new_max1"], state["max1"]
elif state['state1'].dtype == torch.uint8 and config['block_wise']: elif state["state1"].dtype == torch.uint8 and config["block_wise"]:
F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1], F.optimizer_update_8bit_blockwise(
config['eps'], step, config['lr'], self.optimizer_name,
state['qmap1'], None, state['absmax1'], None, grad,
config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros']) p,
state["state1"],
None,
config["betas"][0],
config["betas"][1],
config["eps"],
step,
config["lr"],
state["qmap1"],
None,
state["absmax1"],
None,
config["weight_decay"],
gnorm_scale=gnorm_scale,
skip_zeros=config["skip_zeros"],
)


@ -1,36 +1,109 @@
# Copyright (c) Facebook, Inc. and its affiliates. # Copyright (c) Facebook, Inc. and its affiliates.
# #
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State from bitsandbytes.optim.optimizer import Optimizer1State
class RMSprop(Optimizer1State): class RMSprop(Optimizer1State):
def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, optim_bits=32, args=None, def __init__(
min_8bit_size=4096, percentile_clipping=100, block_wise=True): self,
params,
lr=1e-2,
alpha=0.99,
eps=1e-8,
weight_decay=0,
momentum=0,
centered=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if alpha == 0: if alpha == 0:
raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!') raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
if centered: if centered:
raise NotImplementedError(f'Centered RMSprop is not supported!') raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop, self).__init__('rmsprop', params, lr, (alpha, momentum), eps, super(RMSprop, self).__init__(
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) "rmsprop",
params,
lr,
(alpha, momentum),
eps,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class RMSprop8bit(Optimizer1State): class RMSprop8bit(Optimizer1State):
def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None, def __init__(
min_8bit_size=4096, percentile_clipping=100, block_wise=True): self,
params,
lr=1e-2,
alpha=0.99,
eps=1e-8,
weight_decay=0,
momentum=0,
centered=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if alpha == 0: if alpha == 0:
raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!') raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
if centered: if centered:
raise NotImplementedError(f'Centered RMSprop is not supported!') raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop8bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps, super(RMSprop8bit, self).__init__(
weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) "rmsprop",
params,
lr,
(alpha, momentum),
eps,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class RMSprop32bit(Optimizer1State): class RMSprop32bit(Optimizer1State):
def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None, def __init__(
min_8bit_size=4096, percentile_clipping=100, block_wise=True): self,
params,
lr=1e-2,
alpha=0.99,
eps=1e-8,
weight_decay=0,
momentum=0,
centered=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if alpha == 0: if alpha == 0:
raise NotImplementedError(f'RMSprop with alpha==0.0 is not supported!') raise NotImplementedError(f"RMSprop with alpha==0.0 is not supported!")
if centered: if centered:
raise NotImplementedError(f'Centered RMSprop is not supported!') raise NotImplementedError(f"Centered RMSprop is not supported!")
super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps, super(RMSprop32bit, self).__init__(
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) "rmsprop",
params,
lr,
(alpha, momentum),
eps,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)


@ -1,32 +1,99 @@
# Copyright (c) Facebook, Inc. and its affiliates. # Copyright (c) Facebook, Inc. and its affiliates.
# #
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from bitsandbytes.optim.optimizer import Optimizer1State from bitsandbytes.optim.optimizer import Optimizer1State
class SGD(Optimizer1State): class SGD(Optimizer1State):
def __init__(self, params, lr, momentum=0, dampening=0, def __init__(
weight_decay=0, nesterov=False, optim_bits=32, args=None, self,
min_8bit_size=4096, percentile_clipping=100, block_wise=True): params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
optim_bits=32,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if momentum == 0: if momentum == 0:
raise NotImplementedError(f'SGD without momentum is not supported!') raise NotImplementedError(f"SGD without momentum is not supported!")
super(SGD, self).__init__('momentum', params, lr, (momentum, dampening), 0.0, super(SGD, self).__init__(
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise) "momentum",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
optim_bits,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class SGD8bit(Optimizer1State): class SGD8bit(Optimizer1State):
def __init__(self, params, lr, momentum=0, dampening=0, def __init__(
weight_decay=0, nesterov=False, args=None, self,
min_8bit_size=4096, percentile_clipping=100, block_wise=True): params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if momentum == 0: if momentum == 0:
raise NotImplementedError(f'SGD without momentum is not supported!') raise NotImplementedError(f"SGD without momentum is not supported!")
super(SGD8bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0, super(SGD8bit, self).__init__(
weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise) "momentum",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
8,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
class SGD32bit(Optimizer1State): class SGD32bit(Optimizer1State):
def __init__(self, params, lr, momentum=0, dampening=0, def __init__(
weight_decay=0, nesterov=False, args=None, self,
min_8bit_size=4096, percentile_clipping=100, block_wise=True): params,
lr,
momentum=0,
dampening=0,
weight_decay=0,
nesterov=False,
args=None,
min_8bit_size=4096,
percentile_clipping=100,
block_wise=True,
):
if momentum == 0: if momentum == 0:
raise NotImplementedError(f'SGD without momentum is not supported!') raise NotImplementedError(f"SGD without momentum is not supported!")
super(SGD32bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0, super(SGD32bit, self).__init__(
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise) "momentum",
params,
lr,
(momentum, dampening),
0.0,
weight_decay,
32,
args,
min_8bit_size,
percentile_clipping,
block_wise,
)
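All three SGD classes above reject momentum == 0, since they dispatch to the 'momentum' optimizer internally. A one-line sketch under the earlier assumptions (hypothetical model, CUDA build):

    import bitsandbytes as bnb

    # momentum must be > 0, otherwise SGD8bit raises NotImplementedError
    optimizer = bnb.optim.SGD8bit(model.parameters(), lr=0.1, momentum=0.9)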


@ -1,7 +1,9 @@
import sys import sys
def print_err(s: str) -> None: def print_err(s: str) -> None:
print(s, file=sys.stderr) print(s, file=sys.stderr)
def warn_of_missing_prerequisite(s: str) -> None: def warn_of_missing_prerequisite(s: str) -> None:
print_err('WARNING, missing pre-requisite: ' + s) print_err("WARNING, missing pre-requisite: " + s)


@ -1,31 +1,45 @@
from itertools import product
import torch import torch
import bitsandbytes as bnb import bitsandbytes as bnb
import bitsandbytes.functional as F import bitsandbytes.functional as F
from itertools import product
def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb): def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
k = 25 k = 25
for i in range(k): for i in range(k):
if dims == 2: if dims == 2:
A = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8) A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to(
torch.int8
)
elif dims == 3: elif dims == 3:
A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8) A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to(
B = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8) torch.int8
)
B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to(torch.int8)
C1 = torch.matmul(A.float(), B.t().float()) C1 = torch.matmul(A.float(), B.t().float())
A2, SA = F.transform(A, 'col32') A2, SA = F.transform(A, "col32")
B2, SB = F.transform(B, 'colx') B2, SB = F.transform(B, "colx")
if dims == 2: if dims == 2:
C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32') C2, SC = F.transform(
torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device="cuda"),
"col32",
)
else: else:
C2, SC = F.transform(torch.zeros(A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device='cuda'), 'col32') C2, SC = F.transform(
torch.zeros(
A.shape[0], A.shape[1], B.shape[0], dtype=torch.int32, device="cuda"
),
"col32",
)
F.igemmlt(A2, B2, C2, SA, SB, SC) F.igemmlt(A2, B2, C2, SA, SB, SC)
C3, S = F.transform(C2, 'row', state=SC) C3, S = F.transform(C2, "row", state=SC)
#torch.testing.assert_allclose(C1, C3.float()) # torch.testing.assert_allclose(C1, C3.float())
#print(C1) # print(C1)
#print(C2) # print(C2)
#print(C3) # print(C3)
allclose = torch.allclose(C1, C3.float()) allclose = torch.allclose(C1, C3.float())
if allclose: if allclose:
print(C1) print(C1)
@ -33,29 +47,29 @@ def test_igemmlt(dim1, dim2, dim3, dim4, dims, ldb):
print(C3) print(C3)
## transposed ## transposed
#A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8) # A = torch.randint(-128, 127, size=(dim4, dim3), device='cuda').to(torch.int8)
#if dims == 2: # if dims == 2:
# B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8) # B = torch.randint(-128, 127, size=(dim1, dim3), device='cuda').to(torch.int8)
# C1 = torch.matmul(A.float(), B.float().t()) # C1 = torch.matmul(A.float(), B.float().t())
#elif dims == 3: # elif dims == 3:
# B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8) # B = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
# C1 = torch.matmul(B.float(), A.t().float()) # C1 = torch.matmul(B.float(), A.t().float())
# C1 = C1.permute([2, 0, 1]) # C1 = C1.permute([2, 0, 1])
#A2, SA = F.transform(A, 'col32') # A2, SA = F.transform(A, 'col32')
#B2, SB = F.transform(B, 'colx') # B2, SB = F.transform(B, 'colx')
#if dims == 2: # if dims == 2:
# C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32') # C2, SC = F.transform(torch.zeros(A.shape[0], B.shape[0], dtype=torch.int32, device='cuda'), 'col32')
#else: # else:
# C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda') # C2 = torch.zeros(A.shape[0], B.shape[0], B.shape[1], dtype=torch.int32, device='cuda')
# state = (C2.shape, 'row', A.shape[0]) # state = (C2.shape, 'row', A.shape[0])
# C2, SC = F.transform(C2, 'col32', state=state) # C2, SC = F.transform(C2, 'col32', state=state)
#F.igemmlt(A2, B2, C2, SA, SB, SC) # F.igemmlt(A2, B2, C2, SA, SB, SC)
#C3, S = F.transform(C2, 'row', state=SC, ld=[0]) # C3, S = F.transform(C2, 'row', state=SC, ld=[0])
#torch.testing.assert_allclose(C1, C3.float()) # torch.testing.assert_allclose(C1, C3.float())
## weight update ## weight update
#if dims == 3: # if dims == 3:
# A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8) # A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device='cuda').to(torch.int8)
# B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8) # B = torch.randint(-128, 127, size=(dim1, dim2, dim4), device='cuda').to(torch.int8)
# C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float()) # C1 = torch.matmul(B.view(-1, B.shape[-1]).t().float(), A.view(-1, A.shape[-1]).float())
@ -73,18 +87,18 @@ dims = (2, 3)
ldb = [0] ldb = [0]
n = 2 n = 2
dim1 = torch.randint(1,256, size=(n,)).tolist() dim1 = torch.randint(1, 256, size=(n,)).tolist()
dim2 = torch.randint(32,512, size=(n,)).tolist() dim2 = torch.randint(32, 512, size=(n,)).tolist()
dim3 = torch.randint(32,1024, size=(n,)).tolist() dim3 = torch.randint(32, 1024, size=(n,)).tolist()
dim4 = torch.randint(32,1024, size=(n,)).tolist() dim4 = torch.randint(32, 1024, size=(n,)).tolist()
values = list(product(dim1,dim2,dim3,dim4,dims, ldb)) values = list(product(dim1, dim2, dim3, dim4, dims, ldb))
for ldb in range(32, 4096, 32): for ldb in range(32, 4096, 32):
#for ldb in [None]: # for ldb in [None]:
val = test_igemmlt(2, 2, 2, 2, 2, ldb) val = test_igemmlt(2, 2, 2, 2, 2, ldb)
if val: if val:
print(val, ldb) print(val, ldb)
else: else:
print('nope', ldb) print("nope", ldb)
#for val in values: # for val in values:
#test_igemmlt(*val) # test_igemmlt(*val)


@ -1,19 +1,21 @@
# Copyright (c) Facebook, Inc. and its affiliates. # Copyright (c) Facebook, Inc. and its affiliates.
# #
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import os
import glob import glob
from setuptools import setup, find_packages import os
from setuptools import find_packages, setup
libs = list(glob.glob('./bitsandbytes/libbitsandbytes*.so')) libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.so"))
libs = [os.path.basename(p) for p in libs] libs = [os.path.basename(p) for p in libs]
print('libs:', libs) print("libs:", libs)
def read(fname): def read(fname):
return open(os.path.join(os.path.dirname(__file__), fname)).read() return open(os.path.join(os.path.dirname(__file__), fname)).read()
setup( setup(
name=f"bitsandbytes", name=f"bitsandbytes",
version=f"0.31.0", version=f"0.31.0",
@ -27,11 +29,11 @@ setup(
entry_points={ entry_points={
"console_scripts": ["debug_cuda = bitsandbytes.debug_cli:cli"], "console_scripts": ["debug_cuda = bitsandbytes.debug_cli:cli"],
}, },
package_data={'': libs}, package_data={"": libs},
long_description=read('README.md'), long_description=read("README.md"),
long_description_content_type='text/markdown', long_description_content_type="text/markdown",
classifiers=[ classifiers=[
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
'Topic :: Scientific/Engineering :: Artificial Intelligence' "Topic :: Scientific/Engineering :: Artificial Intelligence",
], ],
) )


@ -1,27 +1,38 @@
import pytest
import torch
import bitsandbytes as bnb
from itertools import product from itertools import product
import pytest
import torch
import bitsandbytes as bnb
n = 1 n = 1
k = 25 k = 25
dim1 = torch.randint(16,64, size=(n,)).tolist() dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32,96, size=(n,)).tolist() dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32,96, size=(n,)).tolist() dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32,96, size=(n,)).tolist() dim4 = torch.randint(32, 96, size=(n,)).tolist()
funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)] funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
str_funcs = ['bmm', 'matmul'] str_funcs = ["bmm", "matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)] req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ['FF', 'TF', 'TT', 'FT'] req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, False), (False, True), (True, True), (True, False)] transpose = [(False, False), (False, True), (True, True), (True, False)]
str_transpose = ['FF', 'FT', 'TT', 'TF'] str_transpose = ["FF", "FT", "TT", "TF"]
dtype = [torch.float32, torch.float16] dtype = [torch.float32, torch.float16]
values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose)) values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose))
str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose)) str_values = list(
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values] product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose)
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names) )
names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
*vals
)
for vals in str_values
]
@pytest.mark.parametrize(
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names
)
def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
dim2 = dim2 - (dim2 % 16) dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16) dim3 = dim3 - (dim3 % 16)
@ -32,9 +43,11 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
if funcs[0] in [torch.mm, torch.matmul]: if funcs[0] in [torch.mm, torch.matmul]:
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0]) A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1]) B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1]) target = torch.randn(
size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
)
torch.nn.init.xavier_uniform_(B) torch.nn.init.xavier_uniform_(B)
if not transpose[0] and not transpose[1]: if not transpose[0] and not transpose[1]:
@ -52,9 +65,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
n = out_bnb.numel() n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx==0).sum().item() < n*0.0175 assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx==0).sum().item() < n*0.001 assert (idx == 0).sum().item() < n * 0.001
if any(req_grad): if any(req_grad):
out_bnb.data.copy_(out_torch) out_bnb.data.copy_(out_torch)
@ -78,16 +91,22 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
if req_grad[1]: if req_grad[1]:
n = gradB1.numel() n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx==0).sum().item() < n*0.1 assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx==0).sum().item() < n*0.02 assert (idx == 0).sum().item() < n * 0.02
torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3) torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
# batched matrix multiply # batched matrix multiply
if funcs[0] in [torch.bmm, torch.matmul]: if funcs[0] in [torch.bmm, torch.matmul]:
A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0]) A = torch.randn(
B = torch.randn(size=(dim1, dim3, dim4), device='cuda', requires_grad=req_grad[1]) size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1]) )
B = torch.randn(
size=(dim1, dim3, dim4), device="cuda", requires_grad=req_grad[1]
)
target = torch.randn(
size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
)
torch.nn.init.xavier_uniform_(B) torch.nn.init.xavier_uniform_(B)
out_torch = funcs[0](A, B) out_torch = funcs[0](A, B)
@ -95,7 +114,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
n = out_bnb.numel() n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx==0).sum().item() < n*0.01 assert (idx == 0).sum().item() < n * 0.01
torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2) torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)
if any(req_grad): if any(req_grad):
@ -120,16 +139,20 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
if req_grad[1]: if req_grad[1]:
n = gradB1.numel() n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx==0).sum().item() < n*0.1 assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx==0).sum().item() < n*0.02 assert (idx == 0).sum().item() < n * 0.02
if funcs[0] in [torch.matmul]: if funcs[0] in [torch.matmul]:
dim1 = dim1 - (dim1 % 16) dim1 = dim1 - (dim1 % 16)
A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0]) A = torch.randn(
size=(dim1, dim2, dim3), device="cuda", requires_grad=req_grad[0]
)
dimB = (dim4, dim3) if transpose[1] else (dim3, dim4) dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1]) B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1]) target = torch.randn(
size=(dim1, dim2, dim4), device="cuda", requires_grad=req_grad[1]
)
torch.nn.init.xavier_uniform_(B) torch.nn.init.xavier_uniform_(B)
if transpose[1]: if transpose[1]:
@ -141,9 +164,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
n = out_bnb.numel() n = out_bnb.numel()
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx==0).sum().item() < n*0.0175 assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx==0).sum().item() < n*0.001 assert (idx == 0).sum().item() < n * 0.001
if any(req_grad): if any(req_grad):
out_bnb.data.copy_(out_torch) out_bnb.data.copy_(out_torch)
@ -167,51 +190,96 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
if req_grad[1]: if req_grad[1]:
n = gradB1.numel() n = gradB1.numel()
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx==0).sum().item() < n*0.1 assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx==0).sum().item() < n*0.02 assert (idx == 0).sum().item() < n * 0.02
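The backward comparisons above (and the ones repeated below) follow one pattern that is easy to miss in the diff noise: after the forward pass, out_bnb.data.copy_(out_torch) overwrites the quantized output with the reference output, so both branches backpropagate from identical values and only the backward kernels are compared. A hedged sketch of that idea with generic ops follows; the loss construction is an assumption, since the hunks shown here omit it, and compare_backward is a hypothetical helper name.

    import torch

    def compare_backward(op_ref, op_test, A, B, target):
        # A and B are assumed to be leaf tensors with requires_grad=True
        out_ref = op_ref(A, B)
        out_test = op_test(A, B)
        # make the two outputs numerically identical before building the loss,
        # so gradient differences can only come from the backward implementations
        out_test.data.copy_(out_ref)

        torch.nn.functional.mse_loss(out_ref, target).backward()
        grad_A_ref, grad_B_ref = A.grad.clone(), B.grad.clone()
        A.grad, B.grad = None, None

        torch.nn.functional.mse_loss(out_test, target).backward()
        return (grad_A_ref, A.grad), (grad_B_ref, B.grad)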
n = 1 n = 1
k = 3 k = 3
dim1 = torch.randint(16,64, size=(n,)).tolist() dim1 = torch.randint(16, 64, size=(n,)).tolist()
dim2 = torch.randint(32,96, size=(n,)).tolist() dim2 = torch.randint(32, 96, size=(n,)).tolist()
dim3 = torch.randint(32,96, size=(n,)).tolist() dim3 = torch.randint(32, 96, size=(n,)).tolist()
dim4 = torch.randint(32,96, size=(n,)).tolist() dim4 = torch.randint(32, 96, size=(n,)).tolist()
#dim1 = (17,) # dim1 = (17,)
#dim2 = (7,) # dim2 = (7,)
#dim3 = (37,) # dim3 = (37,)
#dim4 = (23,) # dim4 = (23,)
decomp = [0.0, 6.0] decomp = [0.0, 6.0]
funcs = [(torch.matmul, bnb.matmul)] funcs = [(torch.matmul, bnb.matmul)]
str_funcs = ['matmul'] str_funcs = ["matmul"]
req_grad = [(False, False), (True, False), (True, True), (False, True)] req_grad = [(False, False), (True, False), (True, True), (False, True)]
req_grad_str = ['FF', 'TF', 'TT', 'FT'] req_grad_str = ["FF", "TF", "TT", "FT"]
transpose = [(False, True), (False, False)] transpose = [(False, True), (False, False)]
str_transpose = ['NT', 'NN'] str_transpose = ["NT", "NN"]
dtype = [torch.float16] dtype = [torch.float16]
has_fp16_weights = [True, False] has_fp16_weights = [True, False]
values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose, decomp, has_fp16_weights)) values = list(
str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose, decomp, has_fp16_weights)) product(
names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}'.format(*vals) for vals in str_values] dim1,
@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights", values, ids=names) dim2,
def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights): dim3,
dim4,
funcs,
dtype,
req_grad,
transpose,
decomp,
has_fp16_weights,
)
)
str_values = list(
product(
dim1,
dim2,
dim3,
dim4,
str_funcs,
dtype,
req_grad_str,
str_transpose,
decomp,
has_fp16_weights,
)
)
names = [
"dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}".format(
*vals
)
for vals in str_values
]
@pytest.mark.parametrize(
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights",
values,
ids=names,
)
def test_matmullt(
dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights
):
dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1]//8,), device='cuda') outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
for i in range(k): for i in range(k):
# normal multiply # normal multiply
if funcs[0] in [torch.mm, torch.matmul]: if funcs[0] in [torch.mm, torch.matmul]:
A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0], dtype=dtype) A = torch.randn(
size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
)
if decomp == 6.0: if decomp == 6.0:
with torch.no_grad(): with torch.no_grad():
A[:, outlier_dim] = 6.0 A[:, outlier_dim] = 6.0
B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1], dtype=dtype) B = torch.randn(
target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1], dtype=dtype) size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
)
target = torch.randn(
size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype
)
torch.nn.init.xavier_uniform_(B) torch.nn.init.xavier_uniform_(B)
B2 = B.clone() B2 = B.clone()
@ -219,8 +287,15 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
state.threshold = decomp state.threshold = decomp
state.has_fp16_weights = has_fp16_weights state.has_fp16_weights = has_fp16_weights
if not has_fp16_weights: if not has_fp16_weights:
if not transpose[0] and not transpose[1]: B2 = B2.t().contiguous() if not transpose[0] and not transpose[1]:
state.CB, CBt, state.SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B2) B2 = B2.t().contiguous()
(
state.CB,
CBt,
state.SCB,
SCBt,
coo_tensorB,
) = bnb.functional.double_quant(B2)
B2 = state.CB B2 = state.CB
if not transpose[0] and transpose[1]: if not transpose[0] and transpose[1]:
@ -231,12 +306,12 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
out_bnb = funcs[1](A, B2.t(), state=state) out_bnb = funcs[1](A, B2.t(), state=state)
n = out_bnb.numel() n = out_bnb.numel()
err = torch.abs(out_bnb-out_torch).mean().item() err = torch.abs(out_bnb - out_torch).mean().item()
#print(f'abs error {err:.4f}') # print(f'abs error {err:.4f}')
idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
assert (idx==0).sum().item() < n*0.0175 assert (idx == 0).sum().item() < n * 0.0175
idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
assert (idx==0).sum().item() < n*0.001 assert (idx == 0).sum().item() < n * 0.001
if has_fp16_weights: if has_fp16_weights:
if any(req_grad): if any(req_grad):
@ -263,8 +338,7 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
assert torch.abs(gradB1).sum() > 0.0 assert torch.abs(gradB1).sum() > 0.0
assert torch.abs(gradB2).sum() > 0.0 assert torch.abs(gradB2).sum() > 0.0
idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx==0).sum().item() < n*0.1 assert (idx == 0).sum().item() < n * 0.1
idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx==0).sum().item() < n*0.02 assert (idx == 0).sum().item() < n * 0.02
torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3) torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
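A second recurring pattern in the assertions above is the mismatch budget: instead of requiring every element to satisfy torch.isclose, only a small fraction of elements may fall outside the tolerances. A minimal standalone sketch, with hypothetical names (assert_mostly_close, max_mismatch_frac), looks like this:

    import torch

    def assert_mostly_close(a, b, atol=0.01, rtol=0.1, max_mismatch_frac=0.0175):
        # allow a bounded fraction of elements to violate the tolerances,
        # mirroring the (idx == 0).sum() checks in the tests above
        close = torch.isclose(a, b, atol=atol, rtol=rtol)
        n_bad = (close == 0).sum().item()
        assert n_bad < a.numel() * max_mismatch_frac, (
            f"{n_bad} of {a.numel()} elements outside atol={atol}, rtol={rtol}"
        )

    # e.g. assert_mostly_close(out_bnb, out_torch)  # up to 1.75% of elements may miss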
View File
@ -1,37 +1,45 @@
import pytest
import os import os
from typing import List, NamedTuple
from typing import List import pytest
from bitsandbytes.cuda_setup import ( from bitsandbytes.cuda_setup import (CUDA_RUNTIME_LIB, evaluate_cuda_setup,
CUDA_RUNTIME_LIB, get_cuda_runtime_lib_path, tokenize_paths)
get_cuda_runtime_lib_path,
evaluate_cuda_setup,
tokenize_paths,
)
HAPPY_PATH__LD_LIB_TEST_PATHS: List[tuple[str,str]] = [ class InputAndExpectedOutput(NamedTuple):
input: str
output: str
HAPPY_PATH__LD_LIB_TEST_PATHS: List[InputAndExpectedOutput] = [
(f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"), (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
(f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"), (f":some/other/dir:dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
(f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:", f"dir/with/{CUDA_RUNTIME_LIB}"), (f"some/other/dir:dir/with/{CUDA_RUNTIME_LIB}:", f"dir/with/{CUDA_RUNTIME_LIB}"),
(f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"), (f"some/other/dir::dir/with/{CUDA_RUNTIME_LIB}", f"dir/with/{CUDA_RUNTIME_LIB}"),
(f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir", f"dir/with/{CUDA_RUNTIME_LIB}"), (f"dir/with/{CUDA_RUNTIME_LIB}:some/other/dir", f"dir/with/{CUDA_RUNTIME_LIB}"),
(f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so", f"dir/with/{CUDA_RUNTIME_LIB}"), (
f"dir/with/{CUDA_RUNTIME_LIB}:other/dir/libcuda.so",
f"dir/with/{CUDA_RUNTIME_LIB}",
),
] ]
@pytest.mark.parametrize( @pytest.fixture(params=HAPPY_PATH__LD_LIB_TEST_PATHS)
"test_input, expected", def happy_path_path_string(tmpdir, request):
HAPPY_PATH__LD_LIB_TEST_PATHS for path in tokenize_paths(request.param):
) path.mkdir()
if CUDA_RUNTIME_LIB in path:
(path / CUDA_RUNTIME_LIB).touch()
@pytest.mark.parametrize("test_input, expected", HAPPY_PATH__LD_LIB_TEST_PATHS)
def test_get_cuda_runtime_lib_path__happy_path( def test_get_cuda_runtime_lib_path__happy_path(
tmp_path, test_input: str, expected: str tmp_path, test_input: str, expected: str
): ):
for path in tokenize_paths(test_input): for path in tokenize_paths(test_input):
assert False == tmp_path / test_input path.mkdir()
test_dir.mkdir() (path / CUDA_RUNTIME_LIB).touch()
(test_input / CUDA_RUNTIME_LIB).touch()
assert get_cuda_runtime_lib_path(test_input) == expected assert get_cuda_runtime_lib_path(test_input) == expected
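The happy-path cases only pin down observable behaviour: the LD_LIBRARY_PATH-style string is split on ':' and the entry that actually contains CUDA_RUNTIME_LIB wins. The toy sketch below restates that contract; it is not the library's implementation, the value of CUDA_RUNTIME_LIB is an assumption, and the duplicate/warning handling tested further down is omitted.

    from pathlib import Path

    CUDA_RUNTIME_LIB = "libcudart.so"  # assumed value of the constant used in these tests

    def tokenize_paths_sketch(paths: str):
        # split a ':'-separated search path, dropping empty entries
        return [Path(entry) for entry in paths.split(":") if entry]

    def find_cuda_runtime_lib_sketch(paths: str):
        # return the first entry that actually contains the runtime library
        for candidate in tokenize_paths_sketch(paths):
            if (candidate / CUDA_RUNTIME_LIB).is_file():
                return candidate / CUDA_RUNTIME_LIB
        return None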
@ -47,40 +55,33 @@ def test_get_cuda_runtime_lib_path__unhappy_path(tmp_path, test_input: str):
(test_input / CUDA_RUNTIME_LIB).touch() (test_input / CUDA_RUNTIME_LIB).touch()
with pytest.raises(FileNotFoundError) as err_info: with pytest.raises(FileNotFoundError) as err_info:
get_cuda_runtime_lib_path(test_input) get_cuda_runtime_lib_path(test_input)
assert all( assert all(match in err_info for match in {"duplicate", CUDA_RUNTIME_LIB})
match in err_info
for match in {"duplicate", CUDA_RUNTIME_LIB}
)
def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path): def test_get_cuda_runtime_lib_path__non_existent_dir(capsys, tmp_path):
existent_dir = tmp_path / 'a/b' existent_dir = tmp_path / "a/b"
existent_dir.mkdir() existent_dir.mkdir()
non_existent_dir = tmp_path / 'c/d' # non-existent dir non_existent_dir = tmp_path / "c/d" # non-existent dir
test_input = ":".join([str(existent_dir), str(non_existent_dir)]) test_input = ":".join([str(existent_dir), str(non_existent_dir)])
get_cuda_runtime_lib_path(test_input) get_cuda_runtime_lib_path(test_input)
std_err = capsys.readouterr().err std_err = capsys.readouterr().err
assert all( assert all(match in std_err for match in {"WARNING", "non-existent"})
match in std_err
for match in {"WARNING", "non-existent"}
)
def test_full_system(): def test_full_system():
## this only tests the cuda version and not compute capability ## this only tests the cuda version and not compute capability
ld_path = os.environ['LD_LIBRARY_PATH'] ld_path = os.environ["LD_LIBRARY_PATH"]
paths = ld_path.split(':') paths = ld_path.split(":")
version = '' version = ""
for p in paths: for p in paths:
if 'cuda' in p: if "cuda" in p:
idx = p.rfind('cuda-') idx = p.rfind("cuda-")
version = p[idx+5:idx+5+4].replace('/', '') version = p[idx + 5 : idx + 5 + 4].replace("/", "")
version = float(version) version = float(version)
break break
binary_name = evaluate_cuda_setup() binary_name = evaluate_cuda_setup()
binary_name = binary_name.replace('libbitsandbytes_cuda', '') binary_name = binary_name.replace("libbitsandbytes_cuda", "")
assert binary_name.startswith(str(version).replace('.', '')) assert binary_name.startswith(str(version).replace(".", ""))
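test_full_system derives the expected CUDA version purely from the directory names on LD_LIBRARY_PATH and then compares it against the suffix of the compiled library name. The version-parsing step, restated in isolation as a sketch:

    def cuda_version_from_ld_library_path(ld_path: str) -> float:
        # find an entry such as /usr/local/cuda-11.3/lib64 and pull out 11.3
        for p in ld_path.split(":"):
            if "cuda" in p:
                idx = p.rfind("cuda-")
                return float(p[idx + 5 : idx + 9].replace("/", ""))
        raise ValueError("no cuda-<version> entry found")

    # cuda_version_from_ld_library_path("/usr/local/cuda-11.3/lib64:/usr/lib") == 11.3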
File diff suppressed because it is too large

View File
@ -1,21 +1,27 @@
import math
from itertools import product
import einops
import pytest import pytest
import torch import torch
from itertools import product
from torch import nn from torch import nn
import bitsandbytes as bnb import bitsandbytes as bnb
class MockArgs(object): class MockArgs(object):
def __init__(self, initial_data): def __init__(self, initial_data):
for key in initial_data: for key in initial_data:
setattr(self, key, initial_data[key]) setattr(self, key, initial_data[key])
class MLP8bit(torch.nn.Module): class MLP8bit(torch.nn.Module):
def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0): def __init__(self, dim1, dim2, has_fp16_weights=True, threshold=0.0):
super(MLP8bit, self).__init__() super(MLP8bit, self).__init__()
self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold) self.fc1 = bnb.nn.Linear8bitLt(
self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold) dim1, dim2, has_fp16_weights=has_fp16_weights, threshold=threshold
)
self.fc2 = bnb.nn.Linear8bitLt(
dim2, dim1, has_fp16_weights=has_fp16_weights, threshold=threshold
)
def forward(self, x): def forward(self, x):
x = self.fc1(x) x = self.fc1(x)
@ -25,108 +31,114 @@ class MLP8bit(torch.nn.Module):
def get_args(): def get_args():
args = MockArgs([]) args = MockArgs([])
args.quant_type = 'vector' args.quant_type = "vector"
args.use_8bit_training = 'full' args.use_8bit_training = "full"
args.clip_freq = 9999 args.clip_freq = 9999
return args return args
def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10):
idx = torch.isclose(a, b, rtol, atol) idx = torch.isclose(a, b, rtol, atol)
sumval = (idx==0).sum().item() sumval = (idx == 0).sum().item()
if sumval > count: if sumval > count:
print(f'Too many values not close: assert {sumval} < {count}') print(f"Too many values not close: assert {sumval} < {count}")
torch.testing.assert_allclose(a, b, rtol, atol) torch.testing.assert_allclose(a, b, rtol, atol)
class LinearFunction(torch.autograd.Function):
class LinearFunction(torch.autograd.Function):
@staticmethod @staticmethod
def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0): def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round round_func = LinearFunction.round_stoachastic if stochastic else torch.round
norm = math.sqrt(math.pi)/math.sqrt(2.0) norm = math.sqrt(math.pi) / math.sqrt(2.0)
#std = torch.abs(x).mean()*norm # std = torch.abs(x).mean()*norm
std = torch.std(x) std = torch.std(x)
max1 = std*trim_value max1 = std * trim_value
x = x/max1*127 x = x / max1 * 127
x = round_func(x) x = round_func(x)
x[x > 127] = 127 x[x > 127] = 127
x[x < -127] = -127 x[x < -127] = -127
x = x/127*max1 x = x / 127 * max1
return x return x
def quant(x, quant_type, dim=1): def quant(x, quant_type, dim=1):
if quant_type == 'linear': if quant_type == "linear":
max1 = torch.abs(x).max().float() max1 = torch.abs(x).max().float()
xq = torch.round(x/max1*127).to(torch.int8) xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1 return xq, max1
elif quant_type == 'vector': elif quant_type == "vector":
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
xq = torch.round(x/max1*127).to(torch.int8) xq = torch.round(x / max1 * 127).to(torch.int8)
return xq, max1 return xq, max1
elif quant_type == 'min-max': elif quant_type == "min-max":
maxA = torch.amax(x, dim=dim, keepdim=True).float() maxA = torch.amax(x, dim=dim, keepdim=True).float()
minA = torch.amin(x, dim=dim, keepdim=True).float() minA = torch.amin(x, dim=dim, keepdim=True).float()
scale = (maxA-minA)/2.0 scale = (maxA - minA) / 2.0
xq = torch.round(127*(x-minA-scale)/scale).to(torch.int8) xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8)
return xq, (minA.float(), scale.float()) return xq, (minA.float(), scale.float())
else: return None else:
return None
def dequant(xq, S1, S2, dtype, quant_type): def dequant(xq, S1, S2, dtype, quant_type):
if quant_type == 'linear': if quant_type == "linear":
norm = S1*S2/(127*127) norm = S1 * S2 / (127 * 127)
# double cast needed to prevent overflows # double cast needed to prevent overflows
return (xq.float()*norm).to(dtype) return (xq.float() * norm).to(dtype)
elif quant_type == 'vector': elif quant_type == "vector":
x = xq.float() x = xq.float()
if len(xq.shape) == 2 and len(S1.shape) == 3: S1 = S1.squeeze(0) if len(xq.shape) == 2 and len(S1.shape) == 3:
if len(xq.shape) == 2 and len(S2.shape) == 3: S2 = S2.squeeze(0) S1 = S1.squeeze(0)
#print(x.shape, S1.shape, S2.shape) if len(xq.shape) == 2 and len(S2.shape) == 3:
S2 = S2.squeeze(0)
# print(x.shape, S1.shape, S2.shape)
if len(S1.shape) == 2: if len(S1.shape) == 2:
x *= S1.t()/127 x *= S1.t() / 127
else: else:
x *= S1/127 x *= S1 / 127
x *= S2/127 x *= S2 / 127
return x.to(dtype) return x.to(dtype)
else: return None else:
return None
def dequant_min_max(xq, A, B, SA, SB, dtype): def dequant_min_max(xq, A, B, SA, SB, dtype):
offset = B.float().t().sum(0)*(SA[0]+SA[1]) offset = B.float().t().sum(0) * (SA[0] + SA[1])
x = xq.float() x = xq.float()
if len(xq.shape) == 2 and len(SB.shape) == 3: SB = SB.squeeze(0) if len(xq.shape) == 2 and len(SB.shape) == 3:
if len(xq.shape) == 2 and len(SA.shape) == 3: SA = SA.squeeze(0) SB = SB.squeeze(0)
if len(xq.shape) == 2 and len(SA.shape) == 3:
SA = SA.squeeze(0)
if len(SB.shape) == 2: if len(SB.shape) == 2:
x *= SB.t()/127 x *= SB.t() / 127
else: else:
x *= SB/127 x *= SB / 127
x *= SA[1]/127 x *= SA[1] / 127
x +=offset x += offset
return x.to(dtype) return x.to(dtype)
def get_8bit_linear(x, stochastic=False): def get_8bit_linear(x, stochastic=False):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round round_func = LinearFunction.round_stoachastic if stochastic else torch.round
max1 = torch.abs(x).max() max1 = torch.abs(x).max()
x = x/max1*127 x = x / max1 * 127
x = round_func(x)/127*max1 x = round_func(x) / 127 * max1
#x = torch.round(x)/128*max1 # x = torch.round(x)/128*max1
return x return x
@staticmethod @staticmethod
def get_8bit_vector_wise(x, dim, stochastic=False): def get_8bit_vector_wise(x, dim, stochastic=False):
round_func = LinearFunction.round_stoachastic if stochastic else torch.round round_func = LinearFunction.round_stoachastic if stochastic else torch.round
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
max1[max1==0] = 1.0 max1[max1 == 0] = 1.0
x = (x*127)/max1 x = (x * 127) / max1
x = round_func(x)/127*max1 x = round_func(x) / 127 * max1
return x return x
@staticmethod @staticmethod
def round_stoachastic(x): def round_stoachastic(x):
sign = torch.sign(x) sign = torch.sign(x)
absx = torch.abs(x) absx = torch.abs(x)
decimal = absx-torch.floor(absx) decimal = absx - torch.floor(absx)
rdm = torch.rand_like(decimal) rdm = torch.rand_like(decimal)
return sign*(torch.floor(absx)+(rdm < decimal).to(x.dtype)) return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype))
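The stochastic rounding helper above is unbiased: a value is rounded up with probability equal to its fractional part, so the expected rounded value equals the input. A short numerical check of that property (positive values only, so the sign handling above is not exercised):

    import torch

    x = torch.full((100_000,), 2.3)
    decimal = x - torch.floor(x)
    rounded = torch.floor(x) + (torch.rand_like(x) < decimal).float()
    # about 30% of the entries become 3, the rest stay 2, so the mean is ~2.3
    assert abs(rounded.mean().item() - 2.3) < 0.01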
@staticmethod @staticmethod
def fake_8bit_storage(w, exponent_bits): def fake_8bit_storage(w, exponent_bits):
@ -140,10 +152,10 @@ class LinearFunction(torch.autograd.Function):
@staticmethod @staticmethod
def fake_8bit_storage_quantile(w, args): def fake_8bit_storage_quantile(w, args):
code = bnb.functional.estimate_quantiles(w.data, offset=args.offset) code = bnb.functional.estimate_quantiles(w.data, offset=args.offset)
#C = bnb.functional.quantize_no_absmax(code, w) # C = bnb.functional.quantize_no_absmax(code, w)
#out = bnb.functional.dequantize_no_absmax(code, C, out=w.data) # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data)
#print(out) # print(out)
#out = out.half() # out = out.half()
code /= torch.max(torch.abs(code)) code /= torch.max(torch.abs(code))
absmax, C = bnb.functional.quantize_blockwise(w.data, code=code) absmax, C = bnb.functional.quantize_blockwise(w.data, code=code)
out = bnb.functional.dequantize_blockwise(absmax, C, code) out = bnb.functional.dequantize_blockwise(absmax, C, code)
@ -162,7 +174,7 @@ class LinearFunction(torch.autograd.Function):
@staticmethod @staticmethod
def fake_8bit_storage_with_max(w, topk=8): def fake_8bit_storage_with_max(w, topk=8):
blocked_w = einops.rearrange(w.flatten(), '(h b) -> h b', b=256) blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256)
max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True) max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True)
idx = idx[:, :topk] idx = idx[:, :topk]
max_val = max_val[:, :topk] max_val = max_val[:, :topk]
@ -191,22 +203,21 @@ class LinearFunction(torch.autograd.Function):
w.copy_(unblocked_w) w.copy_(unblocked_w)
return unblocked_w return unblocked_w
@staticmethod @staticmethod
def forward(ctx, x, weight, bias=None, args=None): def forward(ctx, x, weight, bias=None, args=None):
if args.use_8bit_training != 'off': if args.use_8bit_training != "off":
weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1) weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2)
outputq = bnb.functional.igemm(x8, weight8.t()) outputq = bnb.functional.igemm(x8, weight8.t())
output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type) output = LinearFunction.dequant(outputq, S1, S2, x.dtype, args.quant_type)
#if torch.rand(1) < 0.01: # if torch.rand(1) < 0.01:
#output32 = torch.matmul(x, weight.t()) # output32 = torch.matmul(x, weight.t())
#err = torch.abs(output-output32).float() # err = torch.abs(output-output32).float()
#relerr = err/(torch.abs(output32).float()+1e-8) # relerr = err/(torch.abs(output32).float()+1e-8)
#print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy) # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy)
else: else:
#output = torch.matmul(x, weight.t()) # output = torch.matmul(x, weight.t())
output = torch.einsum('bsi,oi->bso', x, weight) output = torch.einsum("bsi,oi->bso", x, weight)
ctx.save_for_backward(x, weight, bias) ctx.save_for_backward(x, weight, bias)
ctx.args = args ctx.args = args
@ -221,37 +232,49 @@ class LinearFunction(torch.autograd.Function):
args = ctx.args args = ctx.args
stochastic = False stochastic = False
grad_input = grad_weight = grad_bias = None grad_input = grad_weight = grad_bias = None
if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0) if bias is not None and ctx.needs_input_grad[2]:
grad_bias = grad_output.sum(0)
# weight and x are already 8bit # weight and x are already 8bit
# -> transform grad_output to 8-bit # -> transform grad_output to 8-bit
if args.use_8bit_training == 'forward+wgrad': if args.use_8bit_training == "forward+wgrad":
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) grad_output8, S1 = LinearFunction.quant(
grad_output, args.quant_type, dim=[0, 1]
)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = bnb.functional.igemm(grad_output8, x8) grad_weight8 = bnb.functional.igemm(grad_output8, x8)
grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) grad_weight = LinearFunction.dequant(
grad_weight8, S1, S2, grad_output.dtype, args.quant_type
)
#grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x) # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x)
grad_input = grad_output.matmul(weight) grad_input = grad_output.matmul(weight)
elif args.use_8bit_training == 'full': elif args.use_8bit_training == "full":
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=[0, 1]) grad_output8, S1 = LinearFunction.quant(
grad_output, args.quant_type, dim=[0, 1]
)
x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1])
grad_weight8 = torch.zeros_like(weight, dtype=torch.int32) grad_weight8 = torch.zeros_like(weight, dtype=torch.int32)
bnb.functional.igemm(grad_output8, x8, out=grad_weight8) bnb.functional.igemm(grad_output8, x8, out=grad_weight8)
grad_weight = LinearFunction.dequant(grad_weight8, S1, S2, grad_output.dtype, args.quant_type) grad_weight = LinearFunction.dequant(
grad_weight8, S1, S2, grad_output.dtype, args.quant_type
)
grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2) grad_output8, S1 = LinearFunction.quant(grad_output, args.quant_type, dim=2)
weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0) weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0)
grad_input8 = bnb.functional.igemm(grad_output8, weight8) grad_input8 = bnb.functional.igemm(grad_output8, weight8)
grad_input = LinearFunction.dequant(grad_input8, S1, S3, grad_output.dtype, args.quant_type) grad_input = LinearFunction.dequant(
grad_input8, S1, S3, grad_output.dtype, args.quant_type
)
else: else:
grad_input = grad_output.matmul(weight) grad_input = grad_output.matmul(weight)
grad_weight = torch.einsum('bsi,bso->oi', x, grad_output) grad_weight = torch.einsum("bsi,bso->oi", x, grad_output)
return grad_input, grad_weight, grad_bias, None return grad_input, grad_weight, grad_bias, None
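Stripped of the autograd plumbing, the forward path above is: quantize activations and weights vector-wise to int8, run an integer matmul, then rescale by both per-vector maxima. The self-contained sketch below reproduces that round trip in pure torch on the CPU; the real code calls bnb.functional.igemm, and the float32 emulation used here is exact only because the intermediate integers stay far below 2**24.

    import torch

    torch.manual_seed(0)

    def vector_quant(x, dim):
        # per-row absmax scaling to int8, mirroring quant(..., "vector") above
        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
        return torch.round(x / max1 * 127).to(torch.int8), max1

    x = torch.randn(4, 8)    # activations (batch, in_features)
    w = torch.randn(16, 8)   # weight (out_features, in_features), as in nn.Linear

    x8, sx = vector_quant(x, dim=1)
    w8, sw = vector_quant(w, dim=1)

    out_int = x8.float() @ w8.float().t()        # stand-in for the int8 GEMM
    out = out_int * (sx / 127) * (sw.t() / 127)  # undo both scales

    ref = x @ w.t()
    assert (out - ref).abs().max() < 0.5 * ref.abs().max()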
class Linear8bit(nn.Module): class Linear8bit(nn.Module):
def __init__(self, input_features, output_features, bias=True, args=None): def __init__(self, input_features, output_features, bias=True, args=None):
super(Linear8bit, self).__init__() super(Linear8bit, self).__init__()
@ -263,7 +286,7 @@ class Linear8bit(nn.Module):
if bias: if bias:
self.bias = nn.Parameter(torch.empty(output_features)) self.bias = nn.Parameter(torch.empty(output_features))
else: else:
self.register_parameter('bias', None) self.register_parameter("bias", None)
torch.nn.init.xavier_uniform_(self.weight) torch.nn.init.xavier_uniform_(self.weight)
if self.bias is not None: if self.bias is not None:
@ -275,12 +298,11 @@ class Linear8bit(nn.Module):
return LinearFunction.apply(x, self.weight, self.bias, self.args) return LinearFunction.apply(x, self.weight, self.bias, self.args)
def test_linear8bit(): def test_linear8bit():
l0 = torch.nn.Linear(32, 64).cuda().half() l0 = torch.nn.Linear(32, 64).cuda().half()
l1 = bnb.nn.Linear8bit(32,64, args=get_args()).cuda().half() l1 = bnb.nn.Linear8bit(32, 64, args=get_args()).cuda().half()
l2 = Linear8bit(32, 64, args=get_args()).cuda().half() l2 = Linear8bit(32, 64, args=get_args()).cuda().half()
l3 = bnb.nn.Linear8bitLt(32,64).cuda().half() l3 = bnb.nn.Linear8bitLt(32, 64).cuda().half()
l0.weight.data = l2.weight.data.clone() l0.weight.data = l2.weight.data.clone()
l0.bias.data = l2.bias.data.clone() l0.bias.data = l2.bias.data.clone()
@ -292,8 +314,8 @@ def test_linear8bit():
l3.bias.data = l2.bias.data.clone() l3.bias.data = l2.bias.data.clone()
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
t = torch.randn(16, 8, 64, device='cuda').half() t = torch.randn(16, 8, 64, device="cuda").half()
b2 = b1.clone() b2 = b1.clone()
b3 = b1.clone() b3 = b1.clone()
b0 = b1.clone() b0 = b1.clone()
@ -318,16 +340,20 @@ def test_linear8bit():
assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2) assert_all_approx_close(l1.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2) assert_all_approx_close(l3.bias.grad, l2.bias.grad, atol=0.01, rtol=0, count=2)
assert_all_approx_close(l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2) assert_all_approx_close(
assert_all_approx_close(l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2) l1.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
)
assert_all_approx_close(
l3.weight.grad, l2.weight.grad, atol=0.013, rtol=0.05, count=2
)
err1 = torch.abs(l0.weight.grad-l1.weight.grad).mean().item() err1 = torch.abs(l0.weight.grad - l1.weight.grad).mean().item()
err2 = torch.abs(l0.weight.grad-l2.weight.grad).mean().item() err2 = torch.abs(l0.weight.grad - l2.weight.grad).mean().item()
err3 = torch.abs(l0.weight.grad-l3.weight.grad).mean().item() err3 = torch.abs(l0.weight.grad - l3.weight.grad).mean().item()
assert err1*0.8 < err2 assert err1 * 0.8 < err2
assert err2*0.8 < err3 assert err2 * 0.8 < err3
assert err3*0.8 < err1 assert err3 * 0.8 < err1
l0.weight.grad = None l0.weight.grad = None
l1.weight.grad = None l1.weight.grad = None
@ -341,23 +367,28 @@ def test_linear8bit():
threshold = [0.0, 3.0] threshold = [0.0, 3.0]
values = threshold values = threshold
names = ['threshold_{0}'.format(vals) for vals in values] names = ["threshold_{0}".format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names) @pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_inference(threshold): def test_linear8bitlt_inference(threshold):
l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold).cuda().half() l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half()
assert l1.weight.device.type == 'cuda' assert l1.weight.device.type == "cuda"
assert l1.weight.dtype == torch.float16 assert l1.weight.dtype == torch.float16
l1.eval() l1.eval()
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1) o1 = l1(b1)
if i == 1: if i == 1:
assert l1.state.CxB is not None assert l1.state.CxB is not None
def test_linear8bitlt_accumulated_gradient(): def test_linear8bitlt_accumulated_gradient():
l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32,32).cuda().half() for i in range(2)]) l1 = torch.nn.Sequential(
l2 = torch.nn.Sequential(*[torch.nn.Linear(32,32).cuda().half() for i in range(2)]) *[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]
)
l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)])
l2[0].weight = torch.nn.Parameter(l1[0].weight.clone()) l2[0].weight = torch.nn.Parameter(l1[0].weight.clone())
l2[0].bias = torch.nn.Parameter(l1[0].bias.clone()) l2[0].bias = torch.nn.Parameter(l1[0].bias.clone())
l2[1].weight = torch.nn.Parameter(l1[1].weight.clone()) l2[1].weight = torch.nn.Parameter(l1[1].weight.clone())
@ -367,9 +398,8 @@ def test_linear8bitlt_accumulated_gradient():
acc_steps = 10 acc_steps = 10
for i in range(10): for i in range(10):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1) o1 = l1(b1)
o2 = l2(b1) o2 = l2(b1)
loss1 = o1.mean() loss1 = o1.mean()
@ -385,8 +415,12 @@ def test_linear8bitlt_accumulated_gradient():
opt1.zero_grad(True) opt1.zero_grad(True)
opt2.step() opt2.step()
opt2.zero_grad(True) opt2.zero_grad(True)
assert_all_approx_close(l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2) assert_all_approx_close(
assert_all_approx_close(l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2) l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2
)
assert_all_approx_close(
l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2
)
# we do this copy because otherwise we have small divergences over time that add up # we do this copy because otherwise we have small divergences over time that add up
l1[0].weight.data.copy_(l2[0].weight.data) l1[0].weight.data.copy_(l2[0].weight.data)
l1[1].weight.data.copy_(l2[1].weight.data) l1[1].weight.data.copy_(l2[1].weight.data)
@ -397,15 +431,21 @@ def test_linear8bitlt_accumulated_gradient():
threshold = [0.0, 2.0] threshold = [0.0, 2.0]
values = threshold values = threshold
names = ['threshold_{0}'.format(vals) for vals in values] names = ["threshold_{0}".format(vals) for vals in values]
@pytest.mark.parametrize("threshold", values, ids=names) @pytest.mark.parametrize("threshold", values, ids=names)
def test_linear8bitlt_no_fp16_weights(threshold): def test_linear8bitlt_no_fp16_weights(threshold):
l1 = bnb.nn.Linear8bitLt(32,64, threshold=threshold, has_fp16_weights=False).cuda().half() l1 = (
bnb.nn.Linear8bitLt(32, 64, threshold=threshold, has_fp16_weights=False)
.cuda()
.half()
)
assert l1.weight.dtype == torch.int8 assert l1.weight.dtype == torch.int8
l1.eval() l1.eval()
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = l1(b1) o1 = l1(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
@ -414,57 +454,70 @@ def test_linear8bitlt_no_fp16_weights(threshold):
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1) o1 = mlp(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None if threshold > 0:
if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half() mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda().half()
assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1) o1 = mlp(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None if threshold > 0:
if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda() mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().cuda()
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1) o1 = mlp(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None if threshold > 0:
if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to("cuda")
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to('cuda')
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1) o1 = mlp(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None if threshold > 0:
if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == 'cuda' assert mlp.fc1.weight.device.type == "cuda"
assert mlp.fc2.weight.device.type == 'cuda' assert mlp.fc2.weight.device.type == "cuda"
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(torch.float16).to('cuda') mlp = (
MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False)
.to(torch.float16)
.to("cuda")
)
for i in range(100): for i in range(100):
b1 = torch.randn(16, 8, 32, device='cuda').half() b1 = torch.randn(16, 8, 32, device="cuda").half()
o1 = mlp(b1) o1 = mlp(b1)
assert o1.dtype == torch.float16 assert o1.dtype == torch.float16
if threshold > 0: assert mlp.fc1.state.idx is not None if threshold > 0:
if threshold > 0: assert mlp.fc2.state.idx is not None assert mlp.fc1.state.idx is not None
if threshold > 0:
assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8 assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8 assert mlp.fc2.weight.dtype == torch.int8
assert mlp.fc1.weight.device.type == 'cuda' assert mlp.fc1.weight.device.type == "cuda"
assert mlp.fc2.weight.device.type == 'cuda' assert mlp.fc2.weight.device.type == "cuda"
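All of the device/dtype permutations above boil down to the same inference-only recipe: build Linear8bitLt with has_fp16_weights=False, move it to the GPU in fp16 (in any order of .half()/.cuda()/.to()), and feed it fp16 activations. A condensed sketch of that recipe; it needs a CUDA device and a CUDA-enabled bitsandbytes build, and the threshold value is illustrative:

    import torch
    import bitsandbytes as bnb

    layer = bnb.nn.Linear8bitLt(32, 64, threshold=6.0, has_fp16_weights=False).cuda().half()
    layer.eval()

    x = torch.randn(16, 8, 32, device="cuda").half()
    with torch.no_grad():
        y = layer(x)

    assert layer.weight.dtype == torch.int8  # weights are stored as int8
    assert y.dtype == torch.float16          # activations stay in fp16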
View File
@ -1,81 +1,132 @@
import os
import time
import shutil
import uuid
import pytest
import ctypes import ctypes
import os
import shutil
import time
import uuid
from itertools import product
from os.path import join
import pytest
import torch import torch
import bitsandbytes as bnb import bitsandbytes as bnb
import bitsandbytes.functional as F import bitsandbytes.functional as F
from os.path import join # import apex
from itertools import product
#import apex
k = 20 k = 20
def get_temp_dir(): def get_temp_dir():
path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4())) path = "/tmp/autoswap/{0}".format(str(uuid.uuid4()))
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
return path return path
def rm_path(path): def rm_path(path):
shutil.rmtree(path) shutil.rmtree(path)
str2optimizers = {} str2optimizers = {}
str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam) str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam)
#str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) # str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
#str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) # str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam) str2optimizers["momentum_pytorch"] = (
#str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam) None,
#str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam) lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
bnb.optim.Adam,
)
# str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
# str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam) str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam)
#str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) # str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False)) str2optimizers["momentum"] = (
str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9)) lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
#str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB) lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False),
str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False)) )
str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False)) str2optimizers["lars"] = (
str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False)) lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False)) lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9),
#str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit) )
str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9)) # str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
str2optimizers["rmsprop"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["adam8bit"] = (
torch.optim.Adam,
lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False),
)
str2optimizers["momentum8bit"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False),
)
str2optimizers["rmsprop8bit"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False),
)
# str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
str2optimizers["lars8bit"] = (
lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9),
)
str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) str2optimizers["adam8bit_blockwise"] = (
str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True)) torch.optim.Adam,
str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True)) lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True),
)
str2optimizers["momentum8bit_blockwise"] = (
lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True),
)
str2optimizers["rmsprop8bit_blockwise"] = (
lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9),
lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True),
)
str2statenames = {} str2statenames = {}
str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames['momentum'] = [('momentum_buffer', 'state1')] str2statenames["momentum"] = [("momentum_buffer", "state1")]
str2statenames['lars'] = [('momentum_buffer', 'state1')] str2statenames["lars"] = [("momentum_buffer", "state1")]
str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')] str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")]
str2statenames['rmsprop'] = [('square_avg', 'state1')] str2statenames["rmsprop"] = [("square_avg", "state1")]
str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] str2statenames["adam8bit"] = [
str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')] ("exp_avg", "state1", "qmap1", "max1"),
str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')] ("exp_avg_sq", "state2", "qmap2", "max2"),
str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] ]
str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')] str2statenames["lamb8bit"] = [
str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')] ("exp_avg", "state1", "qmap1", "max1"),
str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')] ("exp_avg_sq", "state2", "qmap2", "max2"),
str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')] ]
str2statenames["adam8bit_blockwise"] = [
("exp_avg", "state1", "qmap1", "absmax1"),
("exp_avg_sq", "state2", "qmap2", "absmax2"),
]
str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["momentum8bit_blockwise"] = [
("momentum_buffer", "state1", "qmap1", "absmax1")
]
str2statenames["lars8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")]
str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")]
dim1 = [1024] dim1 = [1024]
dim2 = [32, 1024, 4097, 1] dim2 = [32, 1024, 4097, 1]
gtype = [torch.float32, torch.float16] gtype = [torch.float32, torch.float16]
optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb'] optimizer_names = ["adam", "momentum", "rmsprop", "lars", "lamb"]
values = list(product(dim1,dim2, gtype, optimizer_names)) values = list(product(dim1, dim2, gtype, optimizer_names))
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer32bit(dim1, dim2, gtype, optim_name): def test_optimizer32bit(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1: return if dim1 == 1 and dim2 == 1:
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
p2 = p1.clone() p2 = p1.clone()
p1 = p1.float() p1 = p1.float()
torch_optimizer = str2optimizers[optim_name][0]([p1]) torch_optimizer = str2optimizers[optim_name][0]([p1])
bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer = str2optimizers[optim_name][1]([p2])
@ -84,9 +135,8 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
else: else:
atol, rtol = 1e-4, 1e-3 atol, rtol = 1e-4, 1e-3
for i in range(k): for i in range(k):
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01 g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g.clone().float() p1.grad = g.clone().float()
p2.grad = g.clone() p2.grad = g.clone()
@ -94,21 +144,31 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
torch_optimizer.step() torch_optimizer.step()
for name1, name2 in str2statenames[optim_name]: for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol) torch.testing.assert_allclose(
torch_optimizer.state[p1][name1],
bnb_optimizer.state[p2][name2],
atol=atol,
rtol=rtol,
)
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
if i % (k//5) == 0 and i > 0: if i % (k // 5) == 0 and i > 0:
path = get_temp_dir() path = get_temp_dir()
torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt')) torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
del bnb_optimizer del bnb_optimizer
bnb_optimizer = None bnb_optimizer = None
bnb_optimizer = str2optimizers[optim_name][1]([p2]) bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt'))) bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path) rm_path(path)
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol) torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
for name1, name2 in str2statenames[optim_name]: for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol) torch.testing.assert_allclose(
torch_optimizer.state[p1][name1],
bnb_optimizer.state[p2][name2],
atol=atol,
rtol=rtol,
)
if gtype == torch.float16: if gtype == torch.float16:
# the adam buffers should also be close because they are 32-bit # the adam buffers should also be close because they are 32-bit
@ -118,20 +178,24 @@ def test_optimizer32bit(dim1, dim2, gtype, optim_name):
p1.data = p1.data.half().float() p1.data = p1.data.half().float()
p2.copy_(p1.data) p2.copy_(p1.data)
torch.testing.assert_allclose(p1.half(), p2) torch.testing.assert_allclose(p1.half(), p2)
if optim_name in ['lars', 'lamb']: if optim_name in ["lars", "lamb"]:
assert bnb_optimizer.state[p2]['unorm_vec'] > 0.0 assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0
dim1 = [1024] dim1 = [1024]
dim2 = [32, 1024, 4097] dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16] gtype = [torch.float32, torch.float16]
values = list(product(dim1,dim2, gtype)) values = list(product(dim1, dim2, gtype))
names = ['dim1_{0}_dim2_{1}_gtype_{2}'.format(*vals) for vals in values] names = ["dim1_{0}_dim2_{1}_gtype_{2}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
def test_global_config(dim1, dim2, gtype): def test_global_config(dim1, dim2, gtype):
if dim1 == 1 and dim2 == 1: return if dim1 == 1 and dim2 == 1:
p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 return
p2 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
p3 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1 p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
mask = torch.rand_like(p2) < 0.1 mask = torch.rand_like(p2) < 0.1
beta1 = 0.9 beta1 = 0.9
beta2 = 0.999 beta2 = 0.999
@ -139,7 +203,7 @@ def test_global_config(dim1, dim2, gtype):
eps = 1e-8 eps = 1e-8
bnb.optim.GlobalOptimManager.get_instance().initialize() bnb.optim.GlobalOptimManager.get_instance().initialize()
bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8) bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
p1 = p1.cuda() p1 = p1.cuda()
@ -154,30 +218,41 @@ def test_global_config(dim1, dim2, gtype):
atol, rtol = 1e-4, 1e-3 atol, rtol = 1e-4, 1e-3
for i in range(50): for i in range(50):
g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
g3 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001 g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
p1.grad = g1 p1.grad = g1
p2.grad = g2 p2.grad = g2
p3.grad = g3 p3.grad = g3
adam2.step() adam2.step()
assert adam2.state[p3]['state1'].dtype == torch.uint8 assert adam2.state[p3]["state1"].dtype == torch.uint8
assert adam2.state[p3]['state2'].dtype == torch.uint8 assert adam2.state[p3]["state2"].dtype == torch.uint8
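test_global_config exercises the per-parameter override path: most parameters keep the default 32-bit optimizer state while selected tensors are forced to 8-bit state. Below is a condensed sketch of that sequence, mirroring the ordering in the test (override and register while the tensors are still on the CPU, then move them and build the optimizer); it needs a CUDA device, and the tensors must be large enough for the 8-bit path to engage.

    import torch
    import bitsandbytes as bnb

    p_default = torch.randn(1024, 1024) * 0.1
    p_override = torch.randn(1024, 1024) * 0.1

    mngr = bnb.optim.GlobalOptimManager.get_instance()
    mngr.initialize()
    mngr.override_config(p_override, "optim_bits", 8)   # only this tensor gets 8-bit state
    mngr.register_parameters([p_default, p_override])   # register before moving to CUDA

    p_default, p_override = p_default.cuda(), p_override.cuda()
    adam = bnb.optim.Adam([p_default, p_override], lr=0.001, betas=(0.9, 0.999), eps=1e-8)

    p_default.grad = torch.randn_like(p_default) * 0.1
    p_override.grad = torch.randn_like(p_override) * 0.1
    adam.step()

    assert adam.state[p_override]["state1"].dtype == torch.uint8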
dim1 = [1024] dim1 = [1024]
dim2 = [32, 1024, 4097] dim2 = [32, 1024, 4097]
gtype = [torch.float32, torch.float16] gtype = [torch.float32, torch.float16]
optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise'] optimizer_names = [
values = list(product(dim1,dim2, gtype, optimizer_names)) "adam8bit",
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values] "momentum8bit",
"rmsprop8bit",
"adam8bit_blockwise",
"lamb8bit",
"lars8bit",
"momentum8bit_blockwise",
"rmsprop8bit_blockwise",
]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) @pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_optimizer8bit(dim1, dim2, gtype, optim_name): def test_optimizer8bit(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1: return if dim1 == 1 and dim2 == 1:
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
p2 = p1.clone()
p1 = p1.float()
blocksize = 2048
@@ -197,7 +272,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
relerrors = []
for i in range(50):
g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g.clone().float()
p2.grad = g.clone()
@@ -208,17 +283,31 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
dequant_states = []
for name1, name2, qmap, max_val in str2statenames[optim_name]:
# print(bnb_optimizer.state[p2][max_val], name1)
if "blockwise" in optim_name:
s1 = F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
blocksize=blocksize,
)
else:
s1 = F.dequantize(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
)
num_not_close = (
torch.isclose(
torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
)
== 0
)
assert num_not_close.sum().item() < 20
dequant_states.append(s1.clone())
err = torch.abs(p1 - p2)
relerr = err / torch.abs(p1)
assert err.mean() < 0.0001
assert relerr.mean() < 0.001
@@ -226,28 +315,44 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
relerrors.append(relerr.mean().item())
if i % 10 == 0 and i > 0:
for (name1, name2, qmap, max_val), s in zip(
str2statenames[optim_name], dequant_states
):
s1cpy = s.clone()
raws1cpy = bnb_optimizer.state[p2][name2].clone()
qmap1 = bnb_optimizer.state[p2][qmap].clone()
path = get_temp_dir()
torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt"))
del bnb_optimizer
bnb_optimizer = None
bnb_optimizer = str2optimizers[optim_name][1]([p2])
bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt")))
rm_path(path)
torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])
if "blockwise" in optim_name:
s1 = F.dequantize_blockwise(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
blocksize=blocksize,
)
else:
s1 = F.dequantize(
code=bnb_optimizer.state[p2][qmap],
absmax=bnb_optimizer.state[p2][max_val],
A=bnb_optimizer.state[p2][name2],
)
torch.testing.assert_allclose(s1cpy, s1)
num_not_close = (
torch.isclose(
torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol
)
== 0
)
assert num_not_close.sum().item() < 20
torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
@@ -256,24 +361,28 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
p1.data = p1.data.to(gtype).float()
p2.copy_(p1.data)
torch.testing.assert_allclose(p1.to(gtype), p2)
for (name1, name2, qmap, max_val), s in zip(
str2statenames[optim_name], dequant_states
):
torch_optimizer.state[p1][name1].copy_(s.data)
# print(sum(errors)/len(errors))
# print(sum(relerrors)/len(relerrors))
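# Illustrative sketch (not part of this commit): how the comparison above recovers a
# float view of 8-bit optimizer state with F.dequantize_blockwise. The helper name is
# hypothetical, and the state keys ("state1", "qmap1", "absmax1") plus the 2048
# blocksize are assumptions carried over from this test's str2statenames table and
# blocksize for block-wise 8-bit Adam. Requires CUDA.
def _example_inspect_8bit_state():
    import torch

    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    p = torch.randn(1024, 1024, device="cuda") * 0.1
    opt = bnb.optim.Adam([p], 0.001, (0.9, 0.999), 1e-8, optim_bits=8)
    p.grad = torch.randn_like(p) * 0.01
    opt.step()

    state = opt.state[p]
    # dequantize the quantized first moment back to float for inspection
    exp_avg = F.dequantize_blockwise(
        code=state["qmap1"],
        absmax=state["absmax1"],
        A=state["state1"],
        blocksize=2048,
    )
    return exp_avg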
dim1 = [1024]
dim2 = [32, 1024, 4097]
gtype = [torch.float32]
optim_bits = [32, 8]
values = list(product(dim1, dim2, gtype, optim_bits))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
beta1 = 0.9
beta2 = 0.999
lr = 0.001
@@ -281,19 +390,23 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
p1 = p1.cuda()
p2 = p1.clone()
adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
adam2 = bnb.optim.Adam(
[p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5
)
gnorm_vec = torch.zeros(100).cuda()
step = 0
for i in range(50):
step += 1
g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + (0.01 * i)
g2 = g1.clone()
p2.grad = g2
current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(
g1, gnorm_vec, step, 5
)
g1 = (g1.float() * gnorm_scale).to(gtype)
p1.grad = g1
adam1.step()
@@ -302,47 +415,69 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
# gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
if optim_bits == 32:
torch.testing.assert_allclose(p1, p2)
torch.testing.assert_allclose(
adam1.state[p1]["state1"],
adam2.state[p2]["state1"],
atol=5e-5,
rtol=1e-4,
)
torch.testing.assert_allclose(
adam1.state[p1]["state2"],
adam2.state[p2]["state2"],
atol=5e-5,
rtol=1e-4,
)
elif optim_bits == 8:
torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
torch.testing.assert_allclose(
adam1.state[p1]["state1"], adam2.state[p2]["state1"], atol=2, rtol=1e-3
)
torch.testing.assert_allclose(
adam1.state[p1]["state2"], adam2.state[p2]["state2"], atol=2, rtol=1e-3
)
adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"])
adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"])
if i % 10 == 0 and i > 0:
path = get_temp_dir()
torch.save(adam2.state_dict(), join(path, "opt.pt"))
del adam2
adam2 = None
adam2 = bnb.optim.Adam(
[p2],
lr,
(beta1, beta2),
eps,
optim_bits=optim_bits,
percentile_clipping=5,
)
adam2.load_state_dict(torch.load(join(path, "opt.pt")))
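# Illustrative sketch (not part of this commit): the manual percentile-clipping path
# that test_adam_percentile_clipping exercises above, as stand-alone code. The helper
# name, sizes, and loop length are arbitrary; the F.percentile_clipping call, the
# 100-entry gnorm history buffer, and the gnorm_scale rescaling mirror the test.
# Passing percentile_clipping=5 to bnb.optim.Adam applies the same scaling internally.
def _example_percentile_clipping(percentile=5, steps=20):
    import torch

    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    p = torch.randn(1024, 1024, device="cuda") * 0.1
    adam = bnb.optim.Adam([p], 0.001, (0.9, 0.999), 1e-8)
    gnorm_vec = torch.zeros(100).cuda()  # rolling history of gradient norms

    for step in range(1, steps + 1):
        g = torch.randn_like(p) * 0.1
        current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(
            g, gnorm_vec, step, percentile
        )
        # rescale gradients whose norm exceeds the chosen percentile of recent norms
        p.grad = (g.float() * gnorm_scale).to(g.dtype)
        adam.step()
    return adam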
dim1 = [4096]
dim2 = [4096]
gtype = [torch.float32, torch.float16]
# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
# optimizer_names = ['lamb_apex', 'lamb8bit']
# optimizer_names = ['lars_apex', 'lars8bit']
optimizer_names = ["adam8bit_blockwise"]
values = list(product(dim1, dim2, gtype, optimizer_names))
names = ["dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}".format(*vals) for vals in values]
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
bnb_optimizer = str2optimizers[optim_name][1]([p1])
g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g
for i in range(k):
if i == k // 5:
# 100 iterations for burn-in
torch.cuda.synchronize()
t0 = time.time()
@@ -350,10 +485,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
bnb_optimizer.step()
torch.cuda.synchronize()
s = time.time() - t0
print("")
params = (k - k // 5) * dim1 * dim2
print(optim_name, gtype, s / params)
# assert s < 3.9
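# Illustrative sketch (not part of this commit): the burn-in + synchronize timing
# pattern used by test_benchmark_blockwise above, reduced to a stand-alone helper.
# The helper name, iteration count, and tensor size are assumptions; the reported
# number is seconds per parameter per optimizer step, as in the test. Requires CUDA.
def _example_time_optimizer_step(k=500, dim=4096):
    import time

    import torch

    import bitsandbytes as bnb

    p = torch.randn(dim, dim, device="cuda") * 0.1
    opt = bnb.optim.Adam([p], 0.001, (0.9, 0.999), 1e-8, optim_bits=8)
    p.grad = torch.randn_like(p) * 0.01

    t0 = None
    for i in range(k):
        if i == k // 5:
            # discard the first 20% of the iterations as burn-in
            torch.cuda.synchronize()
            t0 = time.time()
        opt.step()
    torch.cuda.synchronize()
    s = time.time() - t0
    return s / ((k - k // 5) * p.numel())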