Remove trailing whitespace & ensure newline at EOF

Tom Aarsen 2022-10-27 13:11:29 +02:00
parent 31f6689504
commit 1eec77d34c
19 changed files with 121 additions and 127 deletions

View File

@@ -49,7 +49,7 @@ Features:
Bug fixes:
- Fixed a bug where weight decay was incorrectly applied to 32-bit Adam. #13
- Fixed an unsafe use of eval. #8
- Fixed a bug where the StableEmbedding layer 32-bit optimizer override would not work without registering the whole model first (`bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())`). #13 #15
Docs:
- Added instructions how to solve "\_\_fatbinwrap_" errors.
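The StableEmbedding entry above refers to the per-parameter override API of `GlobalOptimManager`. A minimal sketch of that workflow, with a placeholder model (layer names and sizes are illustrative, not taken from the changelog):

```python
import torch
import bitsandbytes as bnb

model = torch.nn.Sequential(torch.nn.Embedding(1000, 64), torch.nn.Linear(64, 2))

mng = bnb.optim.GlobalOptimManager.get_instance()
mng.register_parameters(model.parameters())  # register while the model is still on the CPU

model = model.cuda()
adam = bnb.optim.Adam(model.parameters(), lr=1e-3, optim_bits=8)

# keep 32-bit optimizer states for the embedding weight, as StableEmbedding does internally
mng.override_config(model[0].weight, 'optim_bits', 32)
```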

View File

@@ -28,4 +28,4 @@ outlined on that page and do not file a public issue.
## License
By contributing to bitsandbytes, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.

View File

@@ -26,14 +26,14 @@ INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/inclu
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
# NVIDIA NVCC compilation flags
COMPUTE_CAPABILITY := -gencode arch=compute_35,code=sm_35 # Kepler
COMPUTE_CAPABILITY += -gencode arch=compute_37,code=sm_37 # Kepler
COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell
COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
# CUDA 9.2 supports CC 3.0, but CUDA >= 11.0 does not
CC_CUDA92 := -gencode arch=compute_30,code=sm_30
@@ -58,38 +58,38 @@ CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)
cuda10x_nomatmul: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)
cuda110_nomatmul: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)
cuda11x_nomatmul: $(BUILD_DIR) env
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT
$(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB)
cuda110: $(BUILD_DIR) env
$(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
cuda11x: $(BUILD_DIR) env
$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
$(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)
cpuonly: $(BUILD_DIR) env
@@ -117,7 +117,7 @@ $(ROOT_DIR)/dependencies/cub:
cd dependencies/cub; git checkout 1.11.0
clean:
rm build/*
cleaneggs:
rm -rf *.egg*

View File

@@ -1,6 +1,6 @@
# bitsandbytes
bitsandbytes is a lightweight wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and quantization functions.
@@ -48,7 +48,7 @@ out = linear(x.to(torch.float16))
Requirements: anaconda, cudatoolkit, pytorch
Hardware requirements:
- LLM.int8(): NVIDIA Turing (RTX 20xx; T4) or Ampere GPU (RTX 30xx; A4-A100); (a GPU from 2018 or newer).
- 8-bit optimizers and quantization: NVIDIA Maxwell GPU or newer (>=GTX 9XX).
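A quick way to check these requirements on a given machine (a sketch using PyTorch's reported compute capability; Turing corresponds to 7.5 and Maxwell to 5.x):

```python
import torch

major, minor = torch.cuda.get_device_capability()
cc = major + minor / 10
print(f"compute capability: {cc}")
print("LLM.int8() supported:", cc >= 7.5)        # Turing (7.5) / Ampere (8.x) or newer
print("8-bit optimizers supported:", cc >= 5.0)  # Maxwell (5.x) or newer
```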
@@ -87,7 +87,7 @@ Note that by default all parameter tensors with less than 4096 elements are kept
```
# parameter tensors with less than 16384 values are optimized in 32-bit
# it is recommended to use multiples of 4096
adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384)
```
### Change Bits and other Hyperparameters for Individual Parameters

View File

@@ -15,7 +15,7 @@ tensor = torch.Tensor
"""
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
"""
class GlobalOutlierPooler(object):

View File

@@ -1133,7 +1133,7 @@ def igemm(
ptr = CUBLAS_Context.get_instance().get_context(A.device)
# B^T @ A^T = C^T
# [km, nk -> mn]
is_on_gpu([B, A, out])
lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k),
get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc))
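The `# B^T @ A^T = C^T` comment refers to the usual trick for driving a column-major GEMM (cuBLAS) from row-major buffers: passing the operands swapped produces the transposed product, which in memory is exactly the row-major C. A small NumPy check of the identity (illustration only, not the ctypes call):

```python
import numpy as np

A = np.random.rand(4, 3)   # row-major (m x k)
B = np.random.rand(3, 5)   # row-major (k x n)

# A column-major GEMM fed B and A computes B^T @ A^T = (A @ B)^T;
# reinterpreting that buffer as row-major yields C = A @ B.
assert np.allclose((B.T @ A.T).T, A @ B)
```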

View File

@@ -267,7 +267,7 @@ class Linear8bitLt(nn.Linear):
self.weight.data = self.state.CxB
elif self.state.memory_efficient_backward and self.state.CxB is not None:
# For memory efficient backward, we convert 8-bit row major to turing/ampere format at each inference pass.
# Thus, we delete CxB from the state.
del self.state.CxB
return out
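For context, the `memory_efficient_backward` state handled above belongs to `bnb.nn.Linear8bitLt`, the drop-in int8 replacement for `torch.nn.Linear` used in the README snippet earlier in this diff. A hedged usage sketch; the `memory_efficient_backward` constructor keyword is assumed to exist in this version:

```python
import torch
import bitsandbytes as bnb

# has_fp16_weights=False keeps int8 weights for inference; threshold=6.0 enables
# the LLM.int8() outlier decomposition.
linear = bnb.nn.Linear8bitLt(
    1024, 1024, bias=True, has_fp16_weights=False, threshold=6.0,
    memory_efficient_backward=True,  # assumed flag backing self.state.memory_efficient_backward
).cuda()

x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
out = linear(x)
```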

View File

@@ -4,7 +4,7 @@ Basic steps.
1. `make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cpuonly`
2. `CUDA_VERSION=XXX python setup.py install`
To run these steps you will need to have the nvcc compiler installed that comes with a CUDA installation. If you use anaconda (recommended) then you can figure out which version of CUDA you are using with PyTorch via the command `conda list | grep cudatoolkit`. Then you can install the nvcc compiler by downloading and installing the same CUDA version from the [CUDA toolkit archive](https://developer.nvidia.com/cuda-toolkit-archive).
For your convenience, there is an installation script in the root directory that installs CUDA 11.1 locally and configures it automatically. After installing you should add the `bin` sub-directory to the `$PATH` variable to make the compiler visible to your system. To do this you can add this to your `.bashrc` by executing these commands:
```bash
@@ -13,7 +13,7 @@ echo "export PATH=$PATH:/usr/local/cuda/bin/" >> ~/.bashrc
source ~/.bashrc
```
By default, the Makefile will look at your `CUDA_HOME` environmental variable to find your CUDA version for compiling the library. If this path is not set it is inferred from the path of your `nvcc` compiler.
Either `nvcc` needs to be on the path or the `CUDA_HOME` variable needs to be set to the CUDA directory root (e.g. `/usr/local/cuda`) in order for compilation to succeed.
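A minimal sketch of the lookup order that paragraph describes (an illustration of the behaviour, not the Makefile's actual code): prefer `CUDA_HOME`, otherwise derive the CUDA root from wherever `nvcc` is found on the path.

```python
import os
import shutil

def guess_cuda_home():
    """Return $CUDA_HOME if set, else infer the CUDA root from the nvcc location."""
    cuda_home = os.environ.get("CUDA_HOME")
    if cuda_home:
        return cuda_home
    nvcc = shutil.which("nvcc")  # e.g. /usr/local/cuda/bin/nvcc
    if nvcc:
        return os.path.dirname(os.path.dirname(nvcc))  # strip the trailing /bin/nvcc
    return None

print(guess_cuda_home())
```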

View File

@@ -62,7 +62,7 @@ void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long
for (int i = 0; i < valid_chunks; i++)
int err = pthread_join(threads[i], NULL);
free(threads);
for (int i = 0; i < valid_chunks; i++)
free(args[i]);
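For reference, `quantize_cpu` splits the input into chunks that worker threads quantize blockwise: each block records its absolute maximum and maps every value to the nearest entry of the 256-entry `code` table. A single-threaded NumPy sketch of that per-block math (a simplification, not the pthread implementation above):

```python
import numpy as np

def quantize_blockwise_ref(code, A, blocksize=4096):
    """Blockwise 8-bit quantization: per-block absmax scaling + nearest-code lookup."""
    A = np.asarray(A, dtype=np.float32).ravel()
    n_blocks = (A.size + blocksize - 1) // blocksize
    absmax = np.zeros(n_blocks, dtype=np.float32)
    out = np.zeros(A.size, dtype=np.uint8)
    for b in range(n_blocks):
        block = A[b * blocksize:(b + 1) * blocksize]
        absmax[b] = np.abs(block).max() or 1.0
        # index of the nearest code entry for each normalized value
        idx = np.abs(block[:, None] / absmax[b] - code[None, :]).argmin(axis=1)
        out[b * blocksize:b * blocksize + block.size] = idx
    return out, absmax

# dequantization inverts this per block: code[out] * absmax[block]
```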

View File

@@ -1,6 +1,6 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include <kernels.cuh>
@@ -303,7 +303,7 @@ __global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* ou
if(threadIdx.x % 32 < 8)
{
// offset: 8 values per 256 input values
//
int offset = BLOCK_SIZE*blockIdx.x*BLOCK_SIZE/32*8;
}
@@ -572,7 +572,7 @@ __global__ void kDequantize(float *code, unsigned char *A, float *out, const int
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
float* state1, float* state2, float *unorm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n)
@@ -620,7 +620,7 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
{
switch(OPTIMIZER)
{
case ADAM:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
s1_vals[j] *= correction1;
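The ADAM branch above keeps the two 32-bit moment estimates and rescales them by precomputed bias-correction factors (`correction1`/`correction2` are not shown in this hunk). A NumPy restatement of the per-element math, assuming the standard Adam bias corrections:

```python
import numpy as np

def adam32bit_ref(p, g, s1, s2, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Reference Adam step mirroring the kernel's state updates (bias corrections assumed)."""
    s1 = s1 * beta1 + (1.0 - beta1) * g          # first moment, as in the kernel
    s2 = s2 * beta2 + (1.0 - beta2) * g * g      # second moment, as in the kernel
    correction1 = 1.0 / (1.0 - beta1 ** step)    # assumed standard bias correction
    correction2 = 1.0 / (1.0 - beta2 ** step)
    p = p - lr * (s1 * correction1) / (np.sqrt(s2 * correction2) + eps)
    return p, s1, s2
```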
@@ -651,7 +651,7 @@ __global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1)
__global__ void kOptimizer32bit2State(T* g, T* p,
float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
@@ -714,7 +714,7 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
{
switch(OPTIMIZER)
{
case ADAM:
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
@@ -739,7 +739,7 @@ __global__ void kOptimizer32bit2State(T* g, T* p,
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm,
const float beta1, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n)
@@ -781,19 +781,19 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
{
switch(OPTIMIZER)
{
case MOMENTUM:
if(step == 1)
s1_vals[j] = (float)g_vals[j]; // state update
else
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
break;
case ADAGRAD:
s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
@@ -817,7 +817,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1)
__global__ void kOptimizer32bit1State(T *g, T *p,
float *state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
@@ -880,7 +880,7 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
{
switch(OPTIMIZER)
{
case MOMENTUM:
if(step == 1)
s1_vals[j] = (float)g_vals[j];
else
@@ -888,11 +888,11 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
break;
case ADAGRAD:
s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]);
p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
break;
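The single-state branches above are plain momentum, RMSProp, and Adagrad; the preconditioner variant additionally squares each update value to accumulate the `unorm` used for update clipping. A NumPy restatement of the parameter updates shown in this hunk:

```python
import numpy as np

def one_state_update_ref(optimizer, p, g, s1, step, lr, beta1, eps, update_scale=1.0):
    """Reference updates mirroring the MOMENTUM / RMSPROP / ADAGRAD branches."""
    if optimizer == "momentum":
        s1 = g.copy() if step == 1 else s1 * beta1 + g
        p = p + update_scale * (-lr * s1)
    elif optimizer == "rmsprop":
        s1 = s1 * beta1 + (1.0 - beta1) * g * g
        p = p - update_scale * lr * g / (np.sqrt(s1) + eps)
    elif optimizer == "adagrad":
        s1 = s1 + g * g
        p = p - lr * g / (np.sqrt(s1) + eps)
    return p, s1
```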
@@ -1154,12 +1154,12 @@ kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned cha
template<typename T, int OPTIMIZER>
__global__ void
__launch_bounds__(NUM_THREADS, 2)
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
float *unorm,
const float beta1,
const float eps, const int step,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
const float weight_decay,
const float gnorm_scale, const int n)
{
@@ -1209,7 +1209,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0];
switch(OPTIMIZER)
{
case MOMENTUM:
if(step == 1)
s1_vals[j] = (float)g_vals[j];
else
@@ -1217,7 +1217,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
if(unorm != NULL)
local_unorm += s1_vals[j]*s1_vals[j];
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
break;
}
@@ -1242,10 +1242,10 @@ template<typename T, int OPTIMIZER>
__global__ void
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm,
const float beta1,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
float weight_decay,
const float gnorm_scale, const int n)
{
@@ -1311,7 +1311,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
switch(OPTIMIZER)
{
case MOMENTUM:
if(step == 1)
s1_vals[j] = g_vals[j];
else
@@ -1319,7 +1319,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
break;
@@ -1399,7 +1399,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* absmax1, float* absmax2,
float weight_decay,
const float gnorm_scale, const bool skip_zeros, const int n)
{
@@ -1543,7 +1543,7 @@ kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char
StoreT(temp_storage.storeh).Store(&(p[i]), g_vals, valid_items);
// quantization: 2.67/1.70 -> 3.4/3.3
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1));
@@ -1656,16 +1656,16 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
switch(OPTIMIZER)
{
case MOMENTUM:
if(step == 1)
s1_vals[j] = g_val;
else
s1_vals[j] = (s1_vals[j]*beta1) + g_val;
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
break;
case ADAGRAD:
s1_vals[j] = s1_vals[j] + (g_val*g_val);
break;
}
@@ -1696,14 +1696,14 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
{
switch(OPTIMIZER)
{
case MOMENTUM:
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
break;
case RMSPROP:
g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
break;
case ADAGRAD:
g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
break;
@@ -1716,7 +1716,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
// quantization: 2.67/1.70 -> 3.4/3.3
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1));
@@ -1893,9 +1893,9 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
{
// Strategy: To dequantize we need to load col/row statistics. This can be very expensive
// since different row/col stats need to be loaded with each thread.
// (1, bad algorithm) Loading 32 items per thread would only incur 1 row load, but this increases register pressure
// and would lead to low global load utilization.
// (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do a lot of row loads
// for each thread and this is duplicated by a factor of 32/num-cols-per-thread.
// (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock.
@@ -1903,7 +1903,7 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
// We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has
// the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts
// we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the
// shared memory loads.
// data is in 32 column-tile major with tile width 32 columns and numRows rows
// L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
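However the tiling is organized, the per-element result of this dequantization kernel is just the int32 matmul output rescaled by the product of its row and column statistics; the `127*127` factor below is an assumption based on the int8 quantization range rather than something visible in this hunk. A NumPy sketch:

```python
import numpy as np

def dequant_mm_int32_ref(C_int32, row_stats, col_stats):
    """Dequantize an int32 A@B result back to float using per-row/per-col absmax stats."""
    # each int8 operand was scaled by 127/absmax, so the int32 product carries 127*127
    scale = np.outer(row_stats, col_stats) / (127.0 * 127.0)
    return C_int32.astype(np.float32) * scale
```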
@@ -2140,7 +2140,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
// To have efficient loads and stores if we transpose we need 128 consecutive bytes which at 1 byte are 128 values
// As such we need:
// at least 32*4 shared memory tiles for col32; preferably 32*32
// at least 32*6 shared memory tiles for col32_ampere: preferably 32*32
// at least 32*8 shared memory tiles for col4_turing: preferably 32*32
@@ -2150,7 +2150,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
// we have 64k shared mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy
// for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough
// register pressure should be low with: 8 registers from local memory per block and 64 registers per SM
//
// to make the shared memory work with that occupancy we might need to union the block loads/stores
// each block loads TILE_COLs columns and TILE_ROW rows
@@ -2239,7 +2239,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
switch(FORMAT)
{
case COL32:
if(TRANSPOSE)
{
// data lies in shared memory in the following way:
@@ -2264,7 +2264,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
// each 32 columns we have new tile
// each tile has size outRows*32 and base_row is done in increments of 32
offset = base_row*outRows;
out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data;
}
}
@@ -2310,7 +2310,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
// we increase by row_tile_column every 32 columns
// base_row increase in increments of 32
//int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements
//int col_offset = (base_row/32)*row_tile_column;
// -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
// 256*outRows/8*base_row/32 = outRows*base_row
int col_offset = outRows*base_row;
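The strength reduction in the comment holds because 256/8 = 32 cancels the division by 32 (and 1024/32 = 32 plays the same role in the col32_ampere variant further below, so one check covers both). A quick verification with hypothetical tile dimensions:

```python
# Verify the identities used to remove divisions from col_offset,
# assuming outRows is a multiple of 8 resp. 32 and base_row a multiple of 32:
#   256*outRows/8 * base_row/32  == outRows*base_row   (col4_turing, 256-element tiles)
#   1024*outRows/32 * base_row/32 == outRows*base_row  (col32_ampere, 1024-element tiles)
for out_rows, base_row in [(64, 32), (128, 96), (256, 4096)]:
    assert 256 * (out_rows // 8) * (base_row // 32) == out_rows * base_row
    assert 1024 * (out_rows // 32) * (base_row // 32) == out_rows * base_row
```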
@@ -2347,7 +2347,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
// this happens every 8 rows anew (subrow % 8)
// one writes 4 columns at once that is (col % 4) for the particular index in the subtile
int subcol = warp_lane;
// add local offset (4x4 sub-tile)
if(subrow % 2 == 1)
// odd
@@ -2387,7 +2387,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
// we increase by row_tile_column every 32 columns
// base_row increase in increments of 32
//int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements
//int col_offset = (base_row/32)*row_tile_column;
// -> we can remove the divisions to speed up compute since outRows is always a multiple of 8
// 1024*outRows/32*base_row/32 = outRows*base_row
int col_offset = outRows*base_row;
@@ -2445,7 +2445,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
#define C 1.0f/127.0f
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256
template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
{
@@ -2575,7 +2575,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
#pragma unroll num_items
for(int k = 0; k < num_items; k++)
local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k];
reinterpret_cast<float4*>(out)[idx_val/num_items] = reinterpret_cast<float4(&)[num_items]>(local_valC)[j/num_items];
}
else
@@ -2589,11 +2589,11 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
idx_col_B += blockDim.x*SPMM_ITEMS;
local_idx_col_B_offset += blockDim.x*SPMM_ITEMS;
}
}
template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA)
{
int local_colidx = idx[blockIdx.x];
if(FORMAT==COL_TURING)
@@ -2653,7 +2653,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
out[out_idx] = val;
}
}
}
//==============================================================
// TEMPLATE DEFINITIONS

View File

@@ -1,6 +1,6 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include <float.h>
@@ -18,49 +18,49 @@ template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> __global__
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n);
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
float* state1, float* state2, float *unorm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n);
template<typename T, int OPTIMIZER>
__global__ void kOptimizer32bit2State(T* g, T* p,
float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm,
const float beta1, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n);
template<typename T, int OPTIMIZER>
__global__ void kOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template<typename T, int OPTIMIZER>
__global__ void
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
float *unorm,
const float beta1,
const float eps, const int step,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
const float weight_decay,
const float gnorm_scale, const int n);
template<typename T, int OPTIMIZER>
__global__ void
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm,
const float beta1,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
float weight_decay, const float gnorm_scale, const int n);
@@ -70,7 +70,7 @@ __global__ void
kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2,
float *unorm,
const float beta1, const float beta2,
const float eps, const int step,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
const float gnorm_scale, const int n);
@@ -81,7 +81,7 @@ __global__ void
kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2,
const float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
float weight_decay, const float gnorm_scale, const int n);
@@ -121,5 +121,3 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
#endif

View File

@@ -1,6 +1,6 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include <ops.cuh>
@@ -212,7 +212,7 @@ void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, in
}
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
long long int strideA, long long int strideB, long long int strideC, int batchCount)
{
const int falpha = 1;
@@ -322,7 +322,7 @@ template <typename T, int SRC, int TARGET, bool transpose, int DTYPE> void trans
cublasLtOrder_t orderOut = get_order<TARGET>();
int ldA = get_leading_dim<SRC>(dim1, dim2);
int ldOut = get_leading_dim<TARGET>(dim1, dim2);
cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL;
cublasLtMatrixTransformDesc_t A2Out_desc = NULL;
cublasOperation_t opTranspose = CUBLAS_OP_T;
@@ -368,7 +368,7 @@ template void transform<int8_t, ROW, COL_AMPERE, false, 8>(cublasLtHandle_t ltHa
template void transform<int8_t, COL32, ROW, false, 8>(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2);
template void transform<int32_t, COL32, ROW, false, 32>(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2);
template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc)
{
#ifdef NO_CUBLASLT
cout << "" << endl;

View File

@@ -1,6 +1,6 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
@@ -131,7 +131,7 @@ void dequantize(float *code, unsigned char *A, float *out, int n);
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
float beta1, float beta2, float eps, float weight_decay,
int step, float lr, const float gnorm_scale, bool skip_zeros, int n);
@@ -139,15 +139,15 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2,
float *unorm, float max_unorm, float param_norm,
float beta1, float beta2,
float eps, int step, float lr,
float* quantiles1, float* quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
float weight_decay,
const float gnorm_scale, int n);
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale,
bool skip_zeros, int n);
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);
@@ -155,7 +155,7 @@ template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step,
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);
void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc);
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
long long int strideA, long long int strideB, long long int strideC, int batchCount);

View File

@@ -1,6 +1,6 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#if BUILD_CUDA
@@ -9,7 +9,7 @@
#include <cpu_ops.h>
// We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary.
// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but it's better than to
// maintain all that boilerplate
//===================================================================================
// UNMANGLED CALLS
@@ -290,4 +290,3 @@ extern "C"
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); }
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); }
}

View File

@@ -76,6 +76,3 @@ if [[ -n "$CUDA_VERSION" ]]; then
else
echo ""
fi

View File

@@ -14,16 +14,16 @@ mng.register_parameters(model.parameters()) # 1. register parameters while still
model = model.cuda()
# use 8-bit optimizer states for all parameters
adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)
# 2a. override: the parameter model.fc1.weight now uses 32-bit Adam
mng.override_config(model.fc1.weight, 'optim_bits', 32)
# 2b. override: the two special layers use
# sparse optimization + different learning rate + different Adam betas
mng.override_config([model.special.weight, model.also_special.weight],
                    key_value_dict={'is_sparse': True, 'lr': 1e-5, 'betas': (0.9, 0.98)})
```
Possible options for the config override are: `betas, eps, weight_decay, lr, optim_bits, min_8bit_size, percentile_clipping, block_wise, max_unorm`
For overrides for particular layers we recommend overriding locally in each module. You can do this by passing the module, the parameter, and its attribute name to the GlobalOptimManager:
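The module-local override mentioned in the last sentence can be sketched as follows; `register_module_override` and its signature are assumed from the bitsandbytes README rather than shown in this diff:

```python
import torch
import bitsandbytes as bnb

class MyModule(torch.nn.Module):
    def __init__(self, din, dout):
        super().__init__()
        self.linear = torch.nn.Linear(din, dout)
        # keep this layer's weight in 32-bit optimizer states even when the rest is 8-bit
        # (register_module_override(module, attribute_name, config) is assumed here)
        bnb.optim.GlobalOptimManager.get_instance().register_module_override(
            self.linear, 'weight', {'optim_bits': 32}
        )
```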

View File

@@ -121,7 +121,7 @@ template <unsigned char Gap, typename T>
struct DirectTraits<true,Gap,T>
{
typedef FVec1<SSE, T> fVec1;
static void checkH(T scaler, T H_Times_x0, T xN)
{
union {
@@ -177,9 +177,9 @@ struct DirectInfo
, cst0(fun_t::cst0(H, x[0]))
{
myassert(((bws != NULL) && (isAligned(bws,64))), "bucket pointer not allocated or incorrectly aligned");
uint32 nb = 1 + fun_t::f(H, cst0, x[n-1]);
const uint32 npad = Gap-1;
const uint32 n_sz = n + npad; // size of padded vector
@@ -320,7 +320,7 @@ struct DirectInfo
T cst0 = fun_t::cst0(H, px[0]);
const uint32 maxIndex = fun_t::f(H, cst0, px[n-1]);
buckets.resize(maxIndex + 1);
data = Data(px, n, H, buckets.begin(), (npad? xi.begin(): NULL));
}

View File

@@ -203,7 +203,7 @@ struct IVec<SSE, double> : IVecBase<SSE>
#if 1
// takes 4 cycles
__m128i hi = _mm_shuffle_epi32(vec, 2); // 1 cycle
__m128i s = _mm_add_epi32(vec, hi);
int32 x = _mm_cvtsi128_si32(s);
return -x;
#else

View File

@@ -336,7 +336,7 @@ def test_matmullt(
)
bias = None
bias2 = None
if has_bias:
bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2])
bias2 = bias.clone()
torch.nn.init.xavier_uniform_(B)