Initial commit
This commit is contained in:
commit
7439924891
0
.buckconfig
Normal file
0
.buckconfig
Normal file
135
.gitignore
vendored
Normal file
135
.gitignore
vendored
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
pip-wheel-metadata/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# vim
|
||||||
|
*.swp
|
||||||
|
|
||||||
|
dependencies
|
||||||
|
cuda_build
|
25
BUCK
Normal file
25
BUCK
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
prebuilt_python_library(
|
||||||
|
name = 'bnb-cuda102',
|
||||||
|
binary_src = ':bnb-cuda102-wheel',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
remote_file(
|
||||||
|
name = 'bnb-cuda102-wheel',
|
||||||
|
url = 'https://test-files.pythonhosted.org/packages/4e/69/025b08bf1b7e777ca3800dc79ebe9dfd7309931f0a5f3de132d1433076ff/bitsandbytes_cuda102-0.0.22-py3-none-any.whl',
|
||||||
|
sha1 = '8c89e640afab18cdc6b7c5924c70e25036811686',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
prebuilt_python_library(
|
||||||
|
name = 'bnb-cuda111',
|
||||||
|
binary_src = ':bnb-cuda111-wheel',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
remote_file(
|
||||||
|
name = 'bnb-cuda111-wheel',
|
||||||
|
url = 'https://test-files.pythonhosted.org/packages/f9/38/2179701c80ae2aa9606bce7d498f397bd94e7bb2ff7e7c30ed032a3a39c2/bitsandbytes_cuda111-0.0.22-py3-none-any.whl',
|
||||||
|
sha1 = '433f534b225bc29391782c8a9d82635bc0eb9d33',
|
||||||
|
)
|
||||||
|
|
23
CHANGELOG.md
Normal file
23
CHANGELOG.md
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
v0.0.21
|
||||||
|
- Ampere, RTX 30 series GPUs now compatible with the library.
|
||||||
|
|
||||||
|
v0.0.22:
|
||||||
|
|
||||||
|
- Fixed an error where a `reset_parameters()` call on the `StableEmbedding` would lead to an error in older PyTorch versions (from 1.7.0).
|
||||||
|
|
||||||
|
v0.0.23:
|
||||||
|
|
||||||
|
Bugs:
|
||||||
|
- Unified quantization API: each quantization function now returns `Q, S` where `Q` is the quantized tensor and `S` the quantization state which may hold absolute max values, a quantization map or more. For dequantization all functions now accept the inputs `Q, S` so that `Q` is dequantized with the quantization state `S`.
|
||||||
|
- Fixed an issue where the CUDA 11.1 binary was not compiled with the right headers
|
||||||
|
|
||||||
|
API changes:
|
||||||
|
- Block-wise quantization for optimizers now enabled by default
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- Block-wise quantization routines now support CPU Tensors.
|
||||||
|
|
||||||
|
|
||||||
|
v0.0.24:
|
||||||
|
|
||||||
|
- Fixed a bug where a float/half conversion led to a compilation error for CUDA 11.1 on Turning GPUs.
|
80
CODE_OF_CONDUCT.md
Normal file
80
CODE_OF_CONDUCT.md
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
# Code of Conduct
|
||||||
|
|
||||||
|
## Our Pledge
|
||||||
|
|
||||||
|
In the interest of fostering an open and welcoming environment, we as
|
||||||
|
contributors and maintainers pledge to make participation in our project and
|
||||||
|
our community a harassment-free experience for everyone, regardless of age, body
|
||||||
|
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
||||||
|
level of experience, education, socio-economic status, nationality, personal
|
||||||
|
appearance, race, religion, or sexual identity and orientation.
|
||||||
|
|
||||||
|
## Our Standards
|
||||||
|
|
||||||
|
Examples of behavior that contributes to creating a positive environment
|
||||||
|
include:
|
||||||
|
|
||||||
|
* Using welcoming and inclusive language
|
||||||
|
* Being respectful of differing viewpoints and experiences
|
||||||
|
* Gracefully accepting constructive criticism
|
||||||
|
* Focusing on what is best for the community
|
||||||
|
* Showing empathy towards other community members
|
||||||
|
|
||||||
|
Examples of unacceptable behavior by participants include:
|
||||||
|
|
||||||
|
* The use of sexualized language or imagery and unwelcome sexual attention or
|
||||||
|
advances
|
||||||
|
* Trolling, insulting/derogatory comments, and personal or political attacks
|
||||||
|
* Public or private harassment
|
||||||
|
* Publishing others' private information, such as a physical or electronic
|
||||||
|
address, without explicit permission
|
||||||
|
* Other conduct which could reasonably be considered inappropriate in a
|
||||||
|
professional setting
|
||||||
|
|
||||||
|
## Our Responsibilities
|
||||||
|
|
||||||
|
Project maintainers are responsible for clarifying the standards of acceptable
|
||||||
|
behavior and are expected to take appropriate and fair corrective action in
|
||||||
|
response to any instances of unacceptable behavior.
|
||||||
|
|
||||||
|
Project maintainers have the right and responsibility to remove, edit, or
|
||||||
|
reject comments, commits, code, wiki edits, issues, and other contributions
|
||||||
|
that are not aligned to this Code of Conduct, or to ban temporarily or
|
||||||
|
permanently any contributor for other behaviors that they deem inappropriate,
|
||||||
|
threatening, offensive, or harmful.
|
||||||
|
|
||||||
|
## Scope
|
||||||
|
|
||||||
|
This Code of Conduct applies within all project spaces, and it also applies when
|
||||||
|
an individual is representing the project or its community in public spaces.
|
||||||
|
Examples of representing a project or community include using an official
|
||||||
|
project e-mail address, posting via an official social media account, or acting
|
||||||
|
as an appointed representative at an online or offline event. Representation of
|
||||||
|
a project may be further defined and clarified by project maintainers.
|
||||||
|
|
||||||
|
This Code of Conduct also applies outside the project spaces when there is a
|
||||||
|
reasonable belief that an individual's behavior may have a negative impact on
|
||||||
|
the project or its community.
|
||||||
|
|
||||||
|
## Enforcement
|
||||||
|
|
||||||
|
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||||
|
reported by contacting the project team at <opensource-conduct@fb.com>. All
|
||||||
|
complaints will be reviewed and investigated and will result in a response that
|
||||||
|
is deemed necessary and appropriate to the circumstances. The project team is
|
||||||
|
obligated to maintain confidentiality with regard to the reporter of an incident.
|
||||||
|
Further details of specific enforcement policies may be posted separately.
|
||||||
|
|
||||||
|
Project maintainers who do not follow or enforce the Code of Conduct in good
|
||||||
|
faith may face temporary or permanent repercussions as determined by other
|
||||||
|
members of the project's leadership.
|
||||||
|
|
||||||
|
## Attribution
|
||||||
|
|
||||||
|
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
||||||
|
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
||||||
|
|
||||||
|
[homepage]: https://www.contributor-covenant.org
|
||||||
|
|
||||||
|
For answers to common questions about this code of conduct, see
|
||||||
|
https://www.contributor-covenant.org/faq
|
31
CONTRIBUTING.md
Normal file
31
CONTRIBUTING.md
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
# Contributing to bitsandbytes
|
||||||
|
We want to make contributing to this project as easy and transparent as
|
||||||
|
possible.
|
||||||
|
|
||||||
|
## Pull Requests
|
||||||
|
We actively welcome your pull requests.
|
||||||
|
|
||||||
|
1. Fork the repo and create your branch from `main`.
|
||||||
|
2. If you've added code that should be tested, add tests.
|
||||||
|
3. If you've changed APIs, update the documentation.
|
||||||
|
4. Ensure the test suite passes.
|
||||||
|
5. Make sure your code lints.
|
||||||
|
6. If you haven't already, complete the Contributor License Agreement ("CLA").
|
||||||
|
|
||||||
|
## Contributor License Agreement ("CLA")
|
||||||
|
In order to accept your pull request, we need you to submit a CLA. You only need
|
||||||
|
to do this once to work on any of Facebook's open source projects.
|
||||||
|
|
||||||
|
Complete your CLA here: <https://code.facebook.com/cla>
|
||||||
|
|
||||||
|
## Issues
|
||||||
|
We use GitHub issues to track public bugs. Please ensure your description is
|
||||||
|
clear and has sufficient instructions to be able to reproduce the issue.
|
||||||
|
|
||||||
|
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
|
||||||
|
disclosure of security bugs. In those cases, please go through the process
|
||||||
|
outlined on that page and do not file a public issue.
|
||||||
|
|
||||||
|
## License
|
||||||
|
By contributing to bitsandbytes, you agree that your contributions will be licensed
|
||||||
|
under the LICENSE file in the root directory of this source tree.
|
21
LICENSE
Normal file
21
LICENSE
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
60
Makefile
Normal file
60
Makefile
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
MKFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST)))
|
||||||
|
ROOT_DIR := $(patsubst %/,%,$(dir $(MKFILE_PATH)))
|
||||||
|
|
||||||
|
GPP:= /usr/bin/g++
|
||||||
|
NVCC := $(CUDA_HOME)/bin/nvcc
|
||||||
|
###########################################
|
||||||
|
|
||||||
|
CSRC := $(ROOT_DIR)/csrc
|
||||||
|
BUILD_DIR:= $(ROOT_DIR)/cuda_build
|
||||||
|
|
||||||
|
FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu
|
||||||
|
FILES_CPP := $(CSRC)/pythonInterface.c
|
||||||
|
|
||||||
|
INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include
|
||||||
|
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcuda -lcublas -lcurand -lcusparse -L $(CONDA_PREFIX)/lib
|
||||||
|
|
||||||
|
# NVIDIA NVCC compilation flags
|
||||||
|
COMPUTE_CAPABILITY := -gencode arch=compute_50,code=sm_50 # Maxwell
|
||||||
|
COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell
|
||||||
|
COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal
|
||||||
|
COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta
|
||||||
|
COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta
|
||||||
|
|
||||||
|
all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR)
|
||||||
|
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
|
||||||
|
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
|
||||||
|
$(GPP) -std=c++11 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
|
||||||
|
|
||||||
|
cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR)
|
||||||
|
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
|
||||||
|
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
|
||||||
|
$(GPP) -std=c++11 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
|
||||||
|
|
||||||
|
cuda10x: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR)
|
||||||
|
$(NVCC) $(COMPUTE_CAPABILITY) -gencode arch=compute_75,code=sm_75 -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
|
||||||
|
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
|
||||||
|
$(GPP) -std=c++11 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
|
||||||
|
|
||||||
|
cuda110: $(BUILD_DIR)
|
||||||
|
$(NVCC) $(COMPUTE_CAPABILITY) -gencode arch=compute_80,code=sm_80 -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
|
||||||
|
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
|
||||||
|
$(GPP) -std=c++11 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
|
||||||
|
|
||||||
|
cuda11x: $(BUILD_DIR)
|
||||||
|
$(NVCC) $(COMPUTE_CAPABILITY) -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
|
||||||
|
$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
|
||||||
|
$(GPP) -std=c++11 -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes.so $(LIB)
|
||||||
|
|
||||||
|
$(BUILD_DIR):
|
||||||
|
mkdir -p cuda_build
|
||||||
|
mkdir -p dependencies
|
||||||
|
|
||||||
|
$(ROOT_DIR)/dependencies/cub:
|
||||||
|
git clone https://github.com/NVlabs/cub $(ROOT_DIR)/dependencies/cub
|
||||||
|
|
||||||
|
clean:
|
||||||
|
rm cuda_build/* ./bitsandbytes/libbitsandbytes.so
|
||||||
|
|
||||||
|
cleaneggs:
|
||||||
|
rm -rf *.egg*
|
3
NOTICE.md
Normal file
3
NOTICE.md
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
The majority of bitsandbytes is licensed under MIT, however portions of the project are available under separate license terms: Pytorch is licensed under the BSD license.
|
||||||
|
|
||||||
|
We thank Fabio Cannizzo for this work on FastBinarySearch which is included in this project.
|
106
README.md
Normal file
106
README.md
Normal file
|
@ -0,0 +1,106 @@
|
||||||
|
# bitsandbytes
|
||||||
|
|
||||||
|
bitsandbytes is a lightweight wrapper around CUDA custom functions, in particular 8-bit optimizers and quantization functions.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB
|
||||||
|
- Percentile clipping: A gradient clipping technique that adjusts dynamically for each weight-tensor during training
|
||||||
|
- Stable Embedding Layer: Improved stability through better initialization, and normalization
|
||||||
|
- Fast quantile estimation: Up to 100x faster than other algorithms
|
||||||
|
- 8-bit quantization: Quantile, Linear, and Dynamic quantization
|
||||||
|
|
||||||
|
#### Details
|
||||||
|
- **8-bit Optimizers** use an 8-bit instead of 32-bit state and thus save 75% of memory.
|
||||||
|
- **Percentile Clipping** is an adaptive gradient clipping technique that adapts the clipping threshold automatically during training for each weight-tensor. It tracks a history of the past 100 gradient norms, and the gradient is clipped at a certain percentile p. For most tasks, p=5 works well and provides improved stability and, in some cases, even better performance (ResNet-50 ImageNet).
|
||||||
|
- The **Stable Embedding Layer** uses a less variable initialization coupled with layer norm for stability. Usually, dense optimizers are used in conjunction with sparse BPE/word embeddings, and these dense optimizers perform incorrect updates, leading to instability. The Stable Embedding Layer fixes this problem by performing sparse updates by default for any chosen bnb optimizer.
|
||||||
|
- Fast quantile estimation via **SRAM-Quantiles** algorithm, which is up to 100x faster than previous algorithms to estimate quantiles.
|
||||||
|
- Various **8-bit Quantization** schemes which are useful to compress data. For example, gradient communication or Mixture of Experts token routing can be improved by using 8-bit quantization before communication followed by decompression to 16/32-bit.
|
||||||
|
|
||||||
|
## Requirements & Installation
|
||||||
|
|
||||||
|
Requirements: anaconda, cudatoolkit, pytorch
|
||||||
|
Hardware requirements: NVIDIA Maxwell GPU or newer (>=GTX 9XX)
|
||||||
|
|
||||||
|
The requirements can best be fulfilled by installing pytorch via anaconda. You can install PyTorch by following the ["Get Started"](https://pytorch.org/get-started/locally/) instructions on the official website.
|
||||||
|
|
||||||
|
bitsandbytes is compatible with all major PyTorch releases and cudatoolkit versions, but for now, you need to select the right version manually. To do this run:
|
||||||
|
|
||||||
|
```conda list | grep cudatoolkit```
|
||||||
|
|
||||||
|
and take note of the Cuda version that you have installed. Then you can install bitsandbytes via:
|
||||||
|
```bash
|
||||||
|
# choices: {cuda92, cuda 100, cuda101, cuda102, cuda110, cuda111, cuda113}
|
||||||
|
# replace XXX with the respective number
|
||||||
|
pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX
|
||||||
|
```
|
||||||
|
|
||||||
|
To check if your installation was successful, you can execute the following command, which runs a single bnb Adam update.
|
||||||
|
```
|
||||||
|
wget https://gist.githubusercontent.com/TimDettmers/1f5188c6ee6ed69d211b7fe4e381e713/raw/4d17c3d09ccdb57e9ab7eca0171f2ace6e4d2858/check_bnb_install.py && python check_bnb_install.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Using bitsandbytes
|
||||||
|
|
||||||
|
### Using the 8-bit Optimizers
|
||||||
|
|
||||||
|
With bitsandbytes 8-bit optimizers can be used by changing a single line of code in your codebase. For NLP models we recommend also to use the StableEmbedding layers (see below) which improves results and helps with stable 8-bit optimization. To get started with 8-bit optimizers, it is sufficient to replace your old optimizer with the 8-bit optimizer in the following way:
|
||||||
|
```python
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
|
||||||
|
# adam = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # comment out old optimizer
|
||||||
|
adam = bnb.optim.Adam8bit(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # add bnb optimizer
|
||||||
|
adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=8) # equivalent
|
||||||
|
|
||||||
|
# use 32-bit Adam with 5th percentile clipping
|
||||||
|
adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995),
|
||||||
|
optim_bits=32, percentile_clipping=5)
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that by default all parameter tensors with less than 4096 elements are kept at 32-bit even if you initialize those parameters with 8-bit optimizers. This is done since such small tensors do not save much memory and often contain highly variable parameters (biases) or parameters that require high precision (batch norm, layer norm).
|
||||||
|
|
||||||
|
### Change Bits and other Hyperparameters for Individual Parameters
|
||||||
|
|
||||||
|
If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, with can use the `GlobalOptimManager`. With this, we can also configure specific parameters for sparse optimization, such as embedding layers. To do that, we need two things: (1) register the parameter while they are still on the CPU, (2) override the config with the new desired hyperparameters (anytime, anywhere).
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
|
||||||
|
mng = bnb.optim.GlobalOptimManager.get_instance()
|
||||||
|
|
||||||
|
model = MyModel()
|
||||||
|
mng.register_parameters(model.parameters()) # 1. register parameters while still on CPU
|
||||||
|
|
||||||
|
model = model.cuda()
|
||||||
|
# use 8-bit optimizer states for all parameters
|
||||||
|
adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)
|
||||||
|
|
||||||
|
# 2a. override: the parameter model.fc1.weight now uses 32-bit Adam
|
||||||
|
mng.override_config(model.fc1.weight, 'optim_bits', 32)
|
||||||
|
|
||||||
|
# 2b. override: the two special layers use
|
||||||
|
# sparse optimization + different learning rate + different Adam betas
|
||||||
|
mng.override_config([model.special.weight, model.also_special.weight],
|
||||||
|
key_value_dict ={'is_sparse': True, 'lr': 1e-5, 'betas'=(0.9, 0.98)})
|
||||||
|
```
|
||||||
|
|
||||||
|
### Stable Embedding Layer
|
||||||
|
|
||||||
|
To use the stable embedding layer, simply replace the PyTorch embedding layer with `bnb.nn.StableEmbedding`. By default, this layer is sparsely optimized.
|
||||||
|
|
||||||
|
### Fairseq Users
|
||||||
|
|
||||||
|
To use the Stable Embedding Layer, override the respective `build_embedding(...)` function of your model. Make sure to also use the `--no-scale-embedding` flag to disable scaling of the word embedding layer (nor replaced with layer norm). You can use the optimizers by replacing the optimizer in the respective file (`adam.py` etc.).
|
||||||
|
|
||||||
|
## Release and Feature History
|
||||||
|
|
||||||
|
Last release: v0.0.22:
|
||||||
|
- Fixed an error where a `reset_parameters()` call on the `StableEmbedding` would lead to an error in older PyTorch versions (from 1.7.0).
|
||||||
|
|
||||||
|
For upcoming features and changes and full history see [Patch Notes](PATCH_NOTES.md).
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
The majority of bitsandbytes is licensed under MIT, however portions of the project are available under separate license terms: Pytorch is licensed under the BSD license.
|
||||||
|
|
||||||
|
We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fabiocannizzo/FastBinarySearch) which we use for CPU quantization.
|
10
bitsandbytes/__init__.py
Normal file
10
bitsandbytes/__init__.py
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
from .optim import adam
|
||||||
|
from .nn import modules
|
||||||
|
__pdoc__ = {'libBitsNBytes' : False,
|
||||||
|
'optim.optimizer.Optimizer8bit': False,
|
||||||
|
'optim.optimizer.MockArgs': False
|
||||||
|
}
|
531
bitsandbytes/functional.py
Normal file
531
bitsandbytes/functional.py
Normal file
File diff suppressed because one or more lines are too long
5
bitsandbytes/nn/__init__.py
Normal file
5
bitsandbytes/nn/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
from .modules import StableEmbedding
|
44
bitsandbytes/nn/modules.py
Normal file
44
bitsandbytes/nn/modules.py
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from bitsandbytes.optim import GlobalOptimManager
|
||||||
|
|
||||||
|
class StableEmbedding(torch.nn.Embedding):
|
||||||
|
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
|
||||||
|
max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
|
||||||
|
sparse: bool = True, _weight: Optional[Tensor] = None) -> None:
|
||||||
|
super(StableEmbedding, self).__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, False, _weight)
|
||||||
|
self.norm = torch.nn.LayerNorm(embedding_dim)
|
||||||
|
GlobalOptimManager.get_instance().register_parameters(self.weight)
|
||||||
|
GlobalOptimManager.get_instance().override_config(self.weight, 'optim_bits', 32)
|
||||||
|
|
||||||
|
def reset_parameters(self) -> None:
|
||||||
|
torch.nn.init.xavier_uniform_(self.weight)
|
||||||
|
self._fill_padding_idx_with_zero()
|
||||||
|
|
||||||
|
''' !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding
|
||||||
|
to make the Layer compatible with Pytorch < 1.9.
|
||||||
|
This means that if this changes in future PyTorch releases this need to change too
|
||||||
|
which is cumbersome. However, with this we can ensure compatibility with previous
|
||||||
|
PyTorch releases.
|
||||||
|
'''
|
||||||
|
def _fill_padding_idx_with_zero(self) -> None:
|
||||||
|
if self.padding_idx is not None:
|
||||||
|
with torch.no_grad():
|
||||||
|
self.weight[self.padding_idx].fill_(0)
|
||||||
|
|
||||||
|
def forward(self, input: Tensor) -> Tensor:
|
||||||
|
emb = F.embedding(
|
||||||
|
input, self.weight, self.padding_idx, self.max_norm,
|
||||||
|
self.norm_type, self.scale_grad_by_freq, self.sparse)
|
||||||
|
|
||||||
|
return self.norm(emb)
|
10
bitsandbytes/optim/__init__.py
Normal file
10
bitsandbytes/optim/__init__.py
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
from .adam import Adam, Adam8bit, Adam32bit
|
||||||
|
from .sgd import SGD, SGD8bit, SGD32bit
|
||||||
|
from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
|
||||||
|
from .lamb import LAMB, LAMB8bit, LAMB32bit
|
||||||
|
from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit
|
||||||
|
from .optimizer import GlobalOptimManager
|
28
bitsandbytes/optim/adam.py
Normal file
28
bitsandbytes/optim/adam.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
from bitsandbytes.optim.optimizer import Optimizer2State
|
||||||
|
|
||||||
|
class Adam(Optimizer2State):
|
||||||
|
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
|
||||||
|
weight_decay=0, amsgrad=False, optim_bits=32, args=None,
|
||||||
|
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||||
|
super(Adam, self).__init__('adam', params, lr, betas, eps,
|
||||||
|
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
|
||||||
|
|
||||||
|
class Adam8bit(Optimizer2State):
|
||||||
|
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
|
||||||
|
weight_decay=0, amsgrad=False, args=None,
|
||||||
|
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||||
|
super(Adam8bit, self).__init__('adam', params, lr, betas, eps,
|
||||||
|
weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
|
||||||
|
|
||||||
|
class Adam32bit(Optimizer2State):
|
||||||
|
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
|
||||||
|
weight_decay=0, amsgrad=False, args=None,
|
||||||
|
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||||
|
super(Adam32bit, self).__init__('adam', params, lr, betas, eps,
|
||||||
|
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
|
||||||
|
|
||||||
|
|
29
bitsandbytes/optim/lamb.py
Normal file
29
bitsandbytes/optim/lamb.py
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
import apex
|
||||||
|
from bitsandbytes.optim.optimizer import Optimizer2State
|
||||||
|
|
||||||
|
class LAMB(Optimizer2State):
|
||||||
|
def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
|
||||||
|
weight_decay=0, amsgrad=False, adam_w_mode=True, optim_bits=32, args=None,
|
||||||
|
min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
|
||||||
|
super(LAMB, self).__init__('lamb', params, lr, betas, eps,
|
||||||
|
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
|
||||||
|
|
||||||
|
class LAMB8bit(Optimizer2State):
|
||||||
|
def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
|
||||||
|
weight_decay=0, amsgrad=False, adam_w_mode=True, args=None,
|
||||||
|
min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
|
||||||
|
super(LAMB8bit, self).__init__('lamb', params, lr, betas, eps,
|
||||||
|
weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
|
||||||
|
|
||||||
|
class LAMB32bit(Optimizer2State):
|
||||||
|
def __init__(self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-8,
|
||||||
|
weight_decay=0, amsgrad=False, adam_w_mode=True, args=None,
|
||||||
|
min_8bit_size=4096, percentile_clipping=100, block_wise=False, max_unorm=1.0):
|
||||||
|
super(LAMB32bit, self).__init__('lamb', params, lr, betas, eps,
|
||||||
|
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, max_unorm=1.0)
|
||||||
|
|
||||||
|
|
115
bitsandbytes/optim/lars.py
Normal file
115
bitsandbytes/optim/lars.py
Normal file
|
@ -0,0 +1,115 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from bitsandbytes.optim.optimizer import Optimizer1State
|
||||||
|
|
||||||
|
class LARS(Optimizer1State):
|
||||||
|
def __init__(self, params, lr, momentum=0, dampening=0,
|
||||||
|
weight_decay=0, nesterov=False, optim_bits=32, args=None,
|
||||||
|
min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
|
||||||
|
if momentum == 0:
|
||||||
|
raise NotImplementError(f'LARS without momentum is not supported!')
|
||||||
|
super(LARS, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
|
||||||
|
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
|
||||||
|
|
||||||
|
class LARS8bit(Optimizer1State):
|
||||||
|
def __init__(self, params, lr, momentum=0, dampening=0,
|
||||||
|
weight_decay=0, nesterov=False, args=None,
|
||||||
|
min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
|
||||||
|
if momentum == 0:
|
||||||
|
raise NotImplementError(f'LARS without momentum is not supported!')
|
||||||
|
super(LARS8bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
|
||||||
|
weight_decay, 8, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
|
||||||
|
|
||||||
|
class LARS32bit(Optimizer1State):
|
||||||
|
def __init__(self, params, lr, momentum=0, dampening=0,
|
||||||
|
weight_decay=0, nesterov=False, args=None,
|
||||||
|
min_8bit_size=4096, percentile_clipping=100, max_unorm=0.02):
|
||||||
|
if momentum == 0:
|
||||||
|
raise NotImplementError(f'LARS without momentum is not supported!')
|
||||||
|
super(LARS32bit, self).__init__('lars', params, lr, (momentum, dampening), 0.0,
|
||||||
|
weight_decay, 32, args, min_8bit_size, percentile_clipping, max_unorm=max_unorm, block_wise=False)
|
||||||
|
|
||||||
|
|
||||||
|
class PytorchLARS(Optimizer):
|
||||||
|
def __init__(self, params, lr=0.01, momentum=0, dampening=0,
|
||||||
|
weight_decay=0, nesterov=False, max_unorm=0.02):
|
||||||
|
if lr < 0.0:
|
||||||
|
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||||
|
if momentum < 0.0:
|
||||||
|
raise ValueError("Invalid momentum value: {}".format(momentum))
|
||||||
|
if weight_decay < 0.0:
|
||||||
|
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
||||||
|
|
||||||
|
defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
|
||||||
|
weight_decay=weight_decay, nesterov=nesterov, max_unorm=max_unorm)
|
||||||
|
if nesterov and (momentum <= 0 or dampening != 0):
|
||||||
|
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
|
||||||
|
super(PytorchLARS, self).__init__(params, defaults)
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
super(PytorchLARS, self).__setstate__(state)
|
||||||
|
for group in self.param_groups:
|
||||||
|
group.setdefault('nesterov', False)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step(self, closure=None):
|
||||||
|
"""Performs a single optimization step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
closure (callable, optional): A closure that reevaluates the model
|
||||||
|
and returns the loss.
|
||||||
|
"""
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
with torch.enable_grad():
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
for group in self.param_groups:
|
||||||
|
params_with_grad = []
|
||||||
|
d_p_list = []
|
||||||
|
momentum_buffer_list = []
|
||||||
|
weight_decay = group['weight_decay']
|
||||||
|
momentum = group['momentum']
|
||||||
|
dampening = group['dampening']
|
||||||
|
nesterov = group['nesterov']
|
||||||
|
max_unorm = group['max_unorm']
|
||||||
|
lr = group['lr']
|
||||||
|
|
||||||
|
for p in group['params']:
|
||||||
|
if p.grad is None: continue
|
||||||
|
|
||||||
|
state = self.state[p]
|
||||||
|
d_p = p.grad
|
||||||
|
if weight_decay != 0:
|
||||||
|
d_p = d_p.add(param, alpha=weight_decay)
|
||||||
|
|
||||||
|
if momentum != 0:
|
||||||
|
buf = state.get('momentum_buffer', None)
|
||||||
|
|
||||||
|
if buf is None:
|
||||||
|
buf = torch.clone(d_p).detach()
|
||||||
|
state['momentum_buffer']= buf
|
||||||
|
else:
|
||||||
|
buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
|
||||||
|
|
||||||
|
if nesterov:
|
||||||
|
update = d_p + buf*momentum
|
||||||
|
else:
|
||||||
|
update = buf
|
||||||
|
|
||||||
|
update_scale = 1.0
|
||||||
|
if max_unorm > 0.0:
|
||||||
|
assert p.dtype == torch.float32
|
||||||
|
pnorm = torch.norm(p.detach())
|
||||||
|
unorm = torch.norm(update)
|
||||||
|
if unorm > max_unorm*pnorm:
|
||||||
|
update_scale = max_unorm*pnorm/unorm
|
||||||
|
|
||||||
|
p.add_(update, alpha=-lr*update_scale)
|
||||||
|
|
||||||
|
return loss
|
460
bitsandbytes/optim/optimizer.py
Normal file
460
bitsandbytes/optim/optimizer.py
Normal file
|
@ -0,0 +1,460 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
import torch
|
||||||
|
import bitsandbytes.functional as F
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
from itertools import chain
|
||||||
|
from collections import defaultdict, abc as container_abcs
|
||||||
|
|
||||||
|
class MockArgs(object):
|
||||||
|
def __init__(self, initial_data):
|
||||||
|
for key in initial_data:
|
||||||
|
setattr(self, key, initial_data[key])
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalOptimManager(object):
|
||||||
|
_instance = None
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
raise RuntimeError('Call get_instance() instead')
|
||||||
|
|
||||||
|
def initialize(self):
|
||||||
|
self.pid2config = {}
|
||||||
|
self.index2config = {}
|
||||||
|
self.optimizer = None
|
||||||
|
self.uses_config_override = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = cls.__new__(cls)
|
||||||
|
cls._instance.initialize()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def register_parameters(self, params):
|
||||||
|
param_groups = list(params)
|
||||||
|
if not isinstance(param_groups[0], dict):
|
||||||
|
param_groups = [{'params': param_groups}]
|
||||||
|
|
||||||
|
for group_index, group in enumerate(param_groups):
|
||||||
|
for p_index, p in enumerate(group['params']):
|
||||||
|
if id(p) in self.pid2config:
|
||||||
|
self.index2config[(group_index, p_index)] = self.pid2config[id(p)]
|
||||||
|
|
||||||
|
def override_config(self, parameters, key=None, value=None, key_value_dict=None):
|
||||||
|
'''
|
||||||
|
Overrides initial optimizer config for specific parameters.
|
||||||
|
|
||||||
|
The key-values of the optimizer config for the input parameters are overidden
|
||||||
|
This can be both, optimizer parameters like "betas", or "lr" or it can be
|
||||||
|
8-bit specific paramters like "optim_bits", "percentile_clipping".
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
parameters : torch.Tensor or list(torch.Tensors)
|
||||||
|
The input parameters.
|
||||||
|
key : str
|
||||||
|
The hyperparamter to override.
|
||||||
|
value : object
|
||||||
|
The value for the hyperparamters.
|
||||||
|
key_value_dict : dict
|
||||||
|
A dictionary with multiple key-values to override.
|
||||||
|
'''
|
||||||
|
self.uses_config_override = True
|
||||||
|
if isinstance(parameters, torch.nn.Parameter):
|
||||||
|
parameters = [parameters]
|
||||||
|
if isinstance(parameters, torch.Tensor):
|
||||||
|
parameters = [parameters]
|
||||||
|
if key is not None and value is not None:
|
||||||
|
assert key_value_dict is None
|
||||||
|
key_value_dict = {key: value}
|
||||||
|
|
||||||
|
if key_value_dict is not None:
|
||||||
|
for p in parameters:
|
||||||
|
if id(p) in self.pid2config:self.pid2config[id(p)].update(key_value_dict)
|
||||||
|
else: self.pid2config[id(p)] = key_value_dict
|
||||||
|
|
||||||
|
|
||||||
|
class Optimizer8bit(torch.optim.Optimizer):
|
||||||
|
|
||||||
|
def __init__(self, params, defaults, optim_bits=32):
|
||||||
|
super(Optimizer8bit, self).__init__(params, defaults)
|
||||||
|
self.checked_if_on_gpu = False
|
||||||
|
self.name2qmap = {}
|
||||||
|
|
||||||
|
self.mng = GlobalOptimManager.get_instance()
|
||||||
|
self.non_castable_tensor_keys = set(
|
||||||
|
['qmap1', 'qmap2',
|
||||||
|
'max1', 'max2',
|
||||||
|
'new_max1', 'new_max2',
|
||||||
|
'state1', 'state2',
|
||||||
|
'gnorm_vec', 'absmax1', 'absmax2',
|
||||||
|
'unorm_vec'])
|
||||||
|
|
||||||
|
if optim_bits == 8: self.fill_qmap()
|
||||||
|
|
||||||
|
def fill_qmap(self):
|
||||||
|
self.name2qmap['dynamic'] = F.create_dynamic_map(signed=True)
|
||||||
|
self.name2qmap['udynamic'] = F.create_dynamic_map(signed=False)
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
super(Optimizer8bit, self).__setstate__(state)
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict):
|
||||||
|
r"""Loads the optimizer state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dict (dict): optimizer state. Should be an object returned
|
||||||
|
from a call to :meth:`state_dict`.
|
||||||
|
"""
|
||||||
|
# deepcopy, to be consistent with module API
|
||||||
|
state_dict = deepcopy(state_dict)
|
||||||
|
# Validate the state_dict
|
||||||
|
groups = self.param_groups
|
||||||
|
saved_groups = state_dict['param_groups']
|
||||||
|
|
||||||
|
if len(groups) != len(saved_groups):
|
||||||
|
raise ValueError("loaded state dict has a different number of "
|
||||||
|
"parameter groups")
|
||||||
|
param_lens = (len(g['params']) for g in groups)
|
||||||
|
saved_lens = (len(g['params']) for g in saved_groups)
|
||||||
|
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
|
||||||
|
raise ValueError("loaded state dict contains a parameter group "
|
||||||
|
"that doesn't match the size of optimizer's group")
|
||||||
|
|
||||||
|
# Update the state
|
||||||
|
id_map = {old_id: p for old_id, p in
|
||||||
|
zip(chain.from_iterable((g['params'] for g in saved_groups)),
|
||||||
|
chain.from_iterable((g['params'] for g in groups)))}
|
||||||
|
|
||||||
|
def cast(param, value):
|
||||||
|
r"""Make a deep copy of value, casting all tensors to device of param."""
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
# Floating-point types are a bit special here. They are the only ones
|
||||||
|
# that are assumed to always match the type of params.
|
||||||
|
if param.is_floating_point() and value.dtype != torch.uint8:
|
||||||
|
value = value.to(param.dtype)
|
||||||
|
return value
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
for k, v in value.items():
|
||||||
|
if k in self.non_castable_tensor_keys:
|
||||||
|
value[k] = v.to(param.device)
|
||||||
|
else:
|
||||||
|
value[k] = cast(param, v)
|
||||||
|
|
||||||
|
return value
|
||||||
|
elif isinstance(value, container_abcs.Iterable):
|
||||||
|
return type(value)(cast(param, v) for v in value)
|
||||||
|
else:
|
||||||
|
return value
|
||||||
|
|
||||||
|
# Copy state assigned to params (and cast tensors to appropriate types).
|
||||||
|
# State that is not assigned to params is copied as is (needed for
|
||||||
|
# backward compatibility).
|
||||||
|
state = defaultdict(dict)
|
||||||
|
for k, v in state_dict['state'].items():
|
||||||
|
if k in id_map:
|
||||||
|
param = id_map[k]
|
||||||
|
state[param] = cast(param, v)
|
||||||
|
else:
|
||||||
|
state[k] = v
|
||||||
|
|
||||||
|
# Update parameter groups, setting their 'params' value
|
||||||
|
def update_group(group, new_group):
|
||||||
|
new_group['params'] = group['params']
|
||||||
|
return new_group
|
||||||
|
param_groups = [
|
||||||
|
update_group(g, ng) for g, ng in zip(groups, saved_groups)]
|
||||||
|
self.__setstate__({'state': state, 'param_groups': param_groups})
|
||||||
|
|
||||||
|
def to_gpu(self):
|
||||||
|
self.checked_if_on_gpu = True
|
||||||
|
for gindex, group in enumerate(self.param_groups):
|
||||||
|
for pindex, p in enumerate(group['params']):
|
||||||
|
if p in self.state:
|
||||||
|
values = self.state[p]
|
||||||
|
for k, v in values.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
self.state[p][k] = v.to(p.device)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step(self, closure=None):
|
||||||
|
"""Performs a single optimization step.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
closure (callable, optional): A closure that reevaluates the model
|
||||||
|
and returns the loss.
|
||||||
|
"""
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
with torch.enable_grad():
|
||||||
|
loss = closure()
|
||||||
|
|
||||||
|
overflows = []
|
||||||
|
|
||||||
|
if not self.checked_if_on_gpu: self.to_gpu() # needed for fairseq pure fp16 training
|
||||||
|
for gindex, group in enumerate(self.param_groups):
|
||||||
|
for pindex, p in enumerate(group['params']):
|
||||||
|
if p.grad is None:
|
||||||
|
continue
|
||||||
|
state = self.state[p]
|
||||||
|
if len(state) == 0:
|
||||||
|
self.init_state(group, p, gindex, pindex)
|
||||||
|
|
||||||
|
self.update_step(group, p, gindex, pindex)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def get_config(self, gindex, pindex, group):
|
||||||
|
config = {}
|
||||||
|
config['betas'] = group['betas']
|
||||||
|
config['eps'] = group['eps']
|
||||||
|
config['weight_decay'] = group['weight_decay']
|
||||||
|
config['lr'] = group['lr']
|
||||||
|
config['optim_bits'] = self.args.optim_bits
|
||||||
|
config['min_8bit_size'] = self.args.min_8bit_size
|
||||||
|
config['percentile_clipping'] = self.args.percentile_clipping
|
||||||
|
config['block_wise'] = self.args.block_wise
|
||||||
|
config['max_unorm'] = self.args.max_unorm
|
||||||
|
|
||||||
|
if (gindex, pindex) in self.mng.index2config:
|
||||||
|
config.update(self.mng.index2config[(gindex, pindex)])
|
||||||
|
return config
|
||||||
|
|
||||||
|
def init_state(self, group, p, gindex, pindex):
|
||||||
|
raise NotImplementedError(f'init_state method needs to be overidden')
|
||||||
|
|
||||||
|
def update_step(self, group, p, gindex, pindex):
|
||||||
|
raise NotImplementedError(f'The update_step method needs to be overidden')
|
||||||
|
|
||||||
|
class Optimizer2State(Optimizer8bit):
|
||||||
|
def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
|
||||||
|
weight_decay=0.0, optim_bits=32, args=None,
|
||||||
|
min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0):
|
||||||
|
if not 0.0 <= lr:
|
||||||
|
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||||
|
if not 0.0 <= eps:
|
||||||
|
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||||
|
if isinstance(betas, str):
|
||||||
|
betas = eval(betas)
|
||||||
|
print(betas, 'parsed')
|
||||||
|
for i in range(len(betas)):
|
||||||
|
if not 0.0 <= betas[i] < 1.0:
|
||||||
|
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
|
||||||
|
if not 0.0 <= weight_decay:
|
||||||
|
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
||||||
|
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||||
|
weight_decay=weight_decay)
|
||||||
|
super(Optimizer2State, self).__init__(params, defaults, optim_bits)
|
||||||
|
|
||||||
|
if args is None:
|
||||||
|
args = {}
|
||||||
|
args['optim_bits'] = optim_bits
|
||||||
|
args['percentile_clipping'] = 100
|
||||||
|
args['min_8bit_size'] = min_8bit_size
|
||||||
|
args['percentile_clipping'] = percentile_clipping
|
||||||
|
args['block_wise'] = block_wise
|
||||||
|
args['max_unorm'] = max_unorm
|
||||||
|
|
||||||
|
self.args = MockArgs(args)
|
||||||
|
else:
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
self.optimizer_name = optimizer_name
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def init_state(self, group, p, gindex, pindex):
|
||||||
|
config = self.get_config(gindex, pindex, group)
|
||||||
|
|
||||||
|
if config['optim_bits'] == 32:
|
||||||
|
dtype = torch.float32
|
||||||
|
elif config['optim_bits'] == 8:
|
||||||
|
dtype = torch.uint8
|
||||||
|
else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
|
||||||
|
|
||||||
|
if p.numel() < config['min_8bit_size']: dtype = torch.float32
|
||||||
|
|
||||||
|
state = self.state[p]
|
||||||
|
state['step'] = 0
|
||||||
|
|
||||||
|
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
|
||||||
|
state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
|
||||||
|
state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
|
||||||
|
elif dtype == torch.uint8:
|
||||||
|
if state['step'] == 0:
|
||||||
|
if 'dynamic' not in self.name2qmap: self.fill_qmap()
|
||||||
|
self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
|
||||||
|
self.name2qmap['udynamic'] = self.name2qmap['udynamic'].to(p.device)
|
||||||
|
|
||||||
|
state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
|
||||||
|
state['qmap1'] = self.name2qmap['dynamic']
|
||||||
|
|
||||||
|
state['state2'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
|
||||||
|
state['qmap2'] = self.name2qmap['udynamic']
|
||||||
|
|
||||||
|
if config['block_wise']:
|
||||||
|
n = p.numel()
|
||||||
|
blocks = n//2048
|
||||||
|
blocks += 1 if n % 2048 > 0 else 0
|
||||||
|
|
||||||
|
state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
|
||||||
|
state['absmax2'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
|
||||||
|
else:
|
||||||
|
state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
|
||||||
|
state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
|
||||||
|
state['max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
|
||||||
|
state['new_max2'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
|
||||||
|
|
||||||
|
if config['percentile_clipping'] < 100:
|
||||||
|
state['gnorm_vec'] = torch.zeros((100,), device=p.device)
|
||||||
|
|
||||||
|
if config['max_unorm'] > 0.0:
|
||||||
|
state['unorm_vec'] = torch.zeros((1,), device=p.device)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def update_step(self, group, p, gindex, pindex):
|
||||||
|
state = self.state[p]
|
||||||
|
grad = p.grad
|
||||||
|
|
||||||
|
config = self.get_config(gindex, pindex, group)
|
||||||
|
|
||||||
|
state['step'] += 1
|
||||||
|
step = state['step']
|
||||||
|
|
||||||
|
if config['percentile_clipping'] < 100:
|
||||||
|
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
|
||||||
|
else:
|
||||||
|
gnorm_scale = 1.0
|
||||||
|
|
||||||
|
if state['state1'].dtype == torch.float:
|
||||||
|
F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
|
||||||
|
state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale,
|
||||||
|
state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
|
||||||
|
|
||||||
|
elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
|
||||||
|
F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
|
||||||
|
config['eps'], step, config['lr'],
|
||||||
|
state['qmap1'], state['qmap2'], state['max1'], state['max2'], state['new_max1'], state['new_max2'],
|
||||||
|
config['weight_decay'], gnorm_scale=gnorm_scale,
|
||||||
|
unorm_vec=state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
|
||||||
|
|
||||||
|
# swap maxes
|
||||||
|
state['max1'], state['new_max1'] = state['new_max1'], state['max1']
|
||||||
|
state['max2'], state['new_max2'] = state['new_max2'], state['max2']
|
||||||
|
elif state['state1'].dtype == torch.uint8 and config['block_wise']:
|
||||||
|
F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
|
||||||
|
config['eps'], step, config['lr'],
|
||||||
|
state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'],
|
||||||
|
config['weight_decay'], gnorm_scale=gnorm_scale)
|
||||||
|
|
||||||
|
|
||||||
|
class Optimizer1State(Optimizer8bit):
|
||||||
|
def __init__(self, optimizer_name, params, lr=1e-3, betas=(0.9, 0.0), eps=1e-8,
|
||||||
|
weight_decay=0.0, optim_bits=32, args=None,
|
||||||
|
min_8bit_size=4096, percentile_clipping=100, block_wise=True, max_unorm=0.0):
|
||||||
|
if not 0.0 <= lr:
|
||||||
|
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||||
|
if not 0.0 <= eps:
|
||||||
|
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||||
|
for i in range(len(betas)):
|
||||||
|
if not 0.0 <= betas[i] < 1.0:
|
||||||
|
raise ValueError(f"Invalid beta parameter at index {i}: {betas[i]}")
|
||||||
|
if not 0.0 <= weight_decay:
|
||||||
|
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
||||||
|
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||||
|
weight_decay=weight_decay)
|
||||||
|
super(Optimizer1State, self).__init__(params, defaults, optim_bits)
|
||||||
|
|
||||||
|
if args is None:
|
||||||
|
args = {}
|
||||||
|
args['optim_bits'] = optim_bits
|
||||||
|
args['percentile_clipping'] = 100
|
||||||
|
args['min_8bit_size'] = min_8bit_size
|
||||||
|
args['percentile_clipping'] = percentile_clipping
|
||||||
|
args['block_wise'] = block_wise
|
||||||
|
args['max_unorm'] = max_unorm
|
||||||
|
|
||||||
|
self.args = MockArgs(args)
|
||||||
|
else:
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
self.optimizer_name = optimizer_name
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def init_state(self, group, p, gindex, pindex):
|
||||||
|
config = self.get_config(gindex, pindex, group)
|
||||||
|
|
||||||
|
if config['optim_bits'] == 32:
|
||||||
|
dtype = torch.float32
|
||||||
|
elif config['optim_bits'] == 8:
|
||||||
|
dtype = torch.uint8
|
||||||
|
else: raise NotImplementedError(f'Amount of optimizer bits not supported: {config["optim_bits"]}')
|
||||||
|
|
||||||
|
if p.numel() < config['min_8bit_size']: dtype = torch.float32
|
||||||
|
|
||||||
|
state = self.state[p]
|
||||||
|
state['step'] = 0
|
||||||
|
|
||||||
|
if dtype == torch.float32 or (dtype == torch.uint8 and p.numel() < 4096):
|
||||||
|
state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.float32, device=p.device)
|
||||||
|
elif dtype == torch.uint8:
|
||||||
|
if state['step'] == 0:
|
||||||
|
if 'dynamic' not in self.name2qmap: self.fill_qmap()
|
||||||
|
self.name2qmap['dynamic'] = self.name2qmap['dynamic'].to(p.device)
|
||||||
|
|
||||||
|
state['state1'] = torch.zeros_like(p, memory_format=torch.preserve_format, dtype=torch.uint8, device=p.device)
|
||||||
|
state['qmap1'] = self.name2qmap['dynamic']
|
||||||
|
|
||||||
|
if config['block_wise']:
|
||||||
|
n = p.numel()
|
||||||
|
blocks = n//2048
|
||||||
|
blocks += 1 if n % 2048 > 0 else 0
|
||||||
|
|
||||||
|
state['absmax1'] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
|
||||||
|
else:
|
||||||
|
state['max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
|
||||||
|
state['new_max1'] = torch.zeros((1,), dtype=torch.float32, device=p.device)
|
||||||
|
|
||||||
|
if config['percentile_clipping'] < 100:
|
||||||
|
state['gnorm_vec'] = torch.zeros((100,), device=p.device)
|
||||||
|
|
||||||
|
if config['max_unorm'] > 0.0:
|
||||||
|
state['unorm_vec'] = torch.zeros((1,), device=p.device)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def update_step(self, group, p, gindex, pindex):
|
||||||
|
state = self.state[p]
|
||||||
|
grad = p.grad
|
||||||
|
|
||||||
|
config = self.get_config(gindex, pindex, group)
|
||||||
|
|
||||||
|
state['step'] += 1
|
||||||
|
step = state['step']
|
||||||
|
|
||||||
|
if config['percentile_clipping'] < 100:
|
||||||
|
current_gnorm, clip_value, gnorm_scale = F.percentile_clipping(grad, state['gnorm_vec'], step, config['percentile_clipping'])
|
||||||
|
else:
|
||||||
|
gnorm_scale = 1.0
|
||||||
|
|
||||||
|
if state['state1'].dtype == torch.float:
|
||||||
|
F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
|
||||||
|
None, 0.0, config['weight_decay'], gnorm_scale,
|
||||||
|
state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
|
||||||
|
|
||||||
|
elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
|
||||||
|
F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
|
||||||
|
config['eps'], step, config['lr'], state['qmap1'], None, state['max1'], None, state['new_max1'], None,
|
||||||
|
config['weight_decay'], gnorm_scale,
|
||||||
|
state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
|
||||||
|
|
||||||
|
state['max1'], state['new_max1'] = state['new_max1'], state['max1']
|
||||||
|
elif state['state1'].dtype == torch.uint8 and config['block_wise']:
|
||||||
|
F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
|
||||||
|
config['eps'], step, config['lr'],
|
||||||
|
state['qmap1'], None, state['absmax1'], None,
|
||||||
|
config['weight_decay'], gnorm_scale=gnorm_scale)
|
37
bitsandbytes/optim/rmsprop.py
Normal file
37
bitsandbytes/optim/rmsprop.py
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
import torch
|
||||||
|
from bitsandbytes.optim.optimizer import Optimizer1State
|
||||||
|
|
||||||
|
class RMSprop(Optimizer1State):
|
||||||
|
def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, optim_bits=32, args=None,
|
||||||
|
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||||
|
if alpha == 0:
|
||||||
|
raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!')
|
||||||
|
if centered:
|
||||||
|
raise NotImplementError(f'Centered RMSprop is not supported!')
|
||||||
|
super(RMSprop, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
|
||||||
|
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
|
||||||
|
|
||||||
|
class RMSprop8bit(Optimizer1State):
|
||||||
|
def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None,
|
||||||
|
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||||
|
if alpha == 0:
|
||||||
|
raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!')
|
||||||
|
if centered:
|
||||||
|
raise NotImplementError(f'Centered RMSprop is not supported!')
|
||||||
|
super(RMSprop8bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
|
||||||
|
weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
|
||||||
|
|
||||||
|
class RMSprop32bit(Optimizer1State):
|
||||||
|
def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False, args=None,
|
||||||
|
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||||
|
|
||||||
|
if alpha == 0:
|
||||||
|
raise NotImplementError(f'RMSprop with alpha==0.0 is not supported!')
|
||||||
|
if centered:
|
||||||
|
raise NotImplementError(f'Centered RMSprop is not supported!')
|
||||||
|
super(RMSprop32bit, self).__init__('rmsprop', params, lr, (alpha, momentum), eps,
|
||||||
|
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
|
32
bitsandbytes/optim/sgd.py
Normal file
32
bitsandbytes/optim/sgd.py
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
from bitsandbytes.optim.optimizer import Optimizer1State
|
||||||
|
|
||||||
|
class SGD(Optimizer1State):
|
||||||
|
def __init__(self, params, lr, momentum=0, dampening=0,
|
||||||
|
weight_decay=0, nesterov=False, optim_bits=32, args=None,
|
||||||
|
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||||
|
if momentum == 0:
|
||||||
|
raise NotImplementError(f'SGD without momentum is not supported!')
|
||||||
|
super(SGD, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
|
||||||
|
weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise)
|
||||||
|
|
||||||
|
class SGD8bit(Optimizer1State):
|
||||||
|
def __init__(self, params, lr, momentum=0, dampening=0,
|
||||||
|
weight_decay=0, nesterov=False, args=None,
|
||||||
|
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||||
|
if momentum == 0:
|
||||||
|
raise NotImplementError(f'SGD without momentum is not supported!')
|
||||||
|
super(SGD8bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
|
||||||
|
weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise)
|
||||||
|
|
||||||
|
class SGD32bit(Optimizer1State):
|
||||||
|
def __init__(self, params, lr, momentum=0, dampening=0,
|
||||||
|
weight_decay=0, nesterov=False, args=None,
|
||||||
|
min_8bit_size=4096, percentile_clipping=100, block_wise=True):
|
||||||
|
if momentum == 0:
|
||||||
|
raise NotImplementError(f'SGD without momentum is not supported!')
|
||||||
|
super(SGD32bit, self).__init__('momentum', params, lr, (momentum, dampening), 0.0,
|
||||||
|
weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise)
|
1846
csrc/kernels.cu
Normal file
1846
csrc/kernels.cu
Normal file
File diff suppressed because it is too large
Load Diff
111
csrc/kernels.cuh
Normal file
111
csrc/kernels.cuh
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
//
|
||||||
|
// This source code is licensed under the MIT license found in the
|
||||||
|
// LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
#include <float.h>
|
||||||
|
#include <ops.cuh>
|
||||||
|
|
||||||
|
#ifndef kernels
|
||||||
|
#define kernels
|
||||||
|
|
||||||
|
template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n);
|
||||||
|
|
||||||
|
__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n);
|
||||||
|
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);
|
||||||
|
|
||||||
|
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC> __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
|
||||||
|
template<typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH> __global__ void kDequantizeBlockwise(float *code, unsigned char * __restrict__ const A, float * __restrict__ const absmax, T *out, const int n);
|
||||||
|
|
||||||
|
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
|
||||||
|
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
|
||||||
|
float* state1, float* state2, float *unorm,
|
||||||
|
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||||
|
const int step, const float lr, const float gnorm_scale, const int n);
|
||||||
|
|
||||||
|
template<typename T, int OPTIMIZER>
|
||||||
|
__global__ void kOptimizer32bit2State(T* g, T* p,
|
||||||
|
float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
|
||||||
|
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||||
|
const int step, const float lr, const float gnorm_scale, const int n);
|
||||||
|
|
||||||
|
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
|
||||||
|
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
|
||||||
|
float* state1, float *unorm,
|
||||||
|
const float beta1, const float eps, const float weight_decay,
|
||||||
|
const int step, const float lr, const float gnorm_scale, const int n);
|
||||||
|
|
||||||
|
template<typename T, int OPTIMIZER>
|
||||||
|
__global__ void kOptimizer32bit1State(T* g, T* p,
|
||||||
|
float* state1, float *unorm, const float max_unorm, const float param_norm,
|
||||||
|
const float beta1, const float eps, const float weight_decay,
|
||||||
|
const int step, const float lr, const float gnorm_scale, const int n);
|
||||||
|
|
||||||
|
template<typename T, int OPTIMIZER>
|
||||||
|
__global__ void
|
||||||
|
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
|
||||||
|
float *unorm,
|
||||||
|
const float beta1,
|
||||||
|
const float eps, const int step,
|
||||||
|
float* __restrict__ const quantiles1,
|
||||||
|
float* max1, float* new_max1,
|
||||||
|
const float weight_decay,
|
||||||
|
const float gnorm_scale, const int n);
|
||||||
|
|
||||||
|
|
||||||
|
template<typename T, int OPTIMIZER>
|
||||||
|
__global__ void
|
||||||
|
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
|
||||||
|
const float *unorm, const float max_unorm, const float param_norm,
|
||||||
|
const float beta1,
|
||||||
|
const float eps, const int step, const float lr,
|
||||||
|
float* __restrict__ const quantiles1,
|
||||||
|
float* max1, float* new_max1,
|
||||||
|
float weight_decay, const float gnorm_scale, const int n);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template<typename T, int OPTIMIZER>
|
||||||
|
__global__ void
|
||||||
|
kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2,
|
||||||
|
float *unorm,
|
||||||
|
const float beta1, const float beta2,
|
||||||
|
const float eps, const int step,
|
||||||
|
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
|
||||||
|
float* max1, float* max2, float* new_max1, float* new_max2,
|
||||||
|
const float gnorm_scale, const int n);
|
||||||
|
|
||||||
|
|
||||||
|
template<typename T, int OPTIMIZER>
|
||||||
|
__global__ void
|
||||||
|
kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2,
|
||||||
|
const float *unorm, const float max_unorm, const float param_norm,
|
||||||
|
const float beta1, const float beta2,
|
||||||
|
const float eps, const int step, const float lr,
|
||||||
|
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
|
||||||
|
float* max1, float* max2, float* new_max1, float* new_max2,
|
||||||
|
float weight_decay, const float gnorm_scale, const int n);
|
||||||
|
|
||||||
|
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit2StateBlockwise(
|
||||||
|
T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
|
||||||
|
const float beta1, const float beta2, const float eps, const int step, const float lr,
|
||||||
|
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
|
||||||
|
float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const int n);
|
||||||
|
|
||||||
|
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit1StateBlockwise(
|
||||||
|
T* p, T* __restrict__ const g, unsigned char* state1,
|
||||||
|
const float beta1, const float beta2,
|
||||||
|
const float eps, const int step, const float lr,
|
||||||
|
float* __restrict__ const quantiles1,
|
||||||
|
float* absmax1,
|
||||||
|
float weight_decay,
|
||||||
|
const float gnorm_scale, const int n);
|
||||||
|
|
||||||
|
|
||||||
|
template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n);
|
||||||
|
|
||||||
|
__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
355
csrc/ops.cu
Normal file
355
csrc/ops.cu
Normal file
|
@ -0,0 +1,355 @@
|
||||||
|
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
//
|
||||||
|
// This source code is licensed under the MIT license found in the
|
||||||
|
// LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
#include <ops.cuh>
|
||||||
|
#include <kernels.cuh>
|
||||||
|
#include <cub/device/device_scan.cuh>
|
||||||
|
#include <limits>
|
||||||
|
#include <BinSearch.h>
|
||||||
|
|
||||||
|
|
||||||
|
using namespace BinSearch;
|
||||||
|
using std::cout;
|
||||||
|
using std::endl;
|
||||||
|
|
||||||
|
#define BLOCK_SIZE 4096
|
||||||
|
|
||||||
|
struct quantize_block_args
|
||||||
|
{
|
||||||
|
BinAlgo<Scalar, float, Direct2> *bin_searcher;
|
||||||
|
float *code;
|
||||||
|
float *A;
|
||||||
|
float *absmax;
|
||||||
|
unsigned char *out;
|
||||||
|
int block_end;
|
||||||
|
int block_idx;
|
||||||
|
int threadidx;
|
||||||
|
};
|
||||||
|
|
||||||
|
void *quantize_block(void *arguments)
|
||||||
|
{
|
||||||
|
// 1. find absmax in block
|
||||||
|
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
|
||||||
|
// 3. do binary search to find the closest value
|
||||||
|
// 4. check minimal distance
|
||||||
|
// 5. store index
|
||||||
|
|
||||||
|
struct quantize_block_args *args = (quantize_block_args*)arguments;
|
||||||
|
|
||||||
|
// 1. find absmax in block
|
||||||
|
float absmax_block = -FLT_MAX;
|
||||||
|
for (int i = args->block_idx; i < args->block_end; i++)
|
||||||
|
absmax_block = fmax(absmax_block, fabs(args->A[i]));
|
||||||
|
|
||||||
|
args->absmax[args->block_idx/BLOCK_SIZE] = absmax_block;
|
||||||
|
|
||||||
|
for (int i = args->block_idx; i < args->block_end; i++)
|
||||||
|
{
|
||||||
|
// 2. divide input value by absmax to normalize into [-1.0, 1.0]
|
||||||
|
// 3. do binary search to find the closest value
|
||||||
|
float normed_value = args->A[i]/absmax_block;
|
||||||
|
int idx = args->bin_searcher->scalar(normed_value);
|
||||||
|
|
||||||
|
// 4. check minimal distance
|
||||||
|
// The binary search returns always the value to the left, which might not be the closest value
|
||||||
|
if(idx < 255)
|
||||||
|
{
|
||||||
|
float dist_left = fabs(normed_value-(args->code[idx]));
|
||||||
|
float dist_right = fabs(normed_value-(args->code[idx+1]));
|
||||||
|
if(dist_right < dist_left){ idx+=1; }
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. store index
|
||||||
|
args->out[i] = (unsigned char)idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n)
|
||||||
|
{
|
||||||
|
|
||||||
|
// the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below
|
||||||
|
code[0] = -1.0f;
|
||||||
|
|
||||||
|
int num_blocks = n/BLOCK_SIZE;
|
||||||
|
num_blocks += n % BLOCK_SIZE == 0 ? 0 : 1;
|
||||||
|
|
||||||
|
pthread_t *threads = (pthread_t*)malloc(sizeof(pthread_t)*num_blocks);
|
||||||
|
struct quantize_block_args **args = (quantize_block_args**)malloc(num_blocks*sizeof(quantize_block_args*));
|
||||||
|
|
||||||
|
for(int i = 0; i < num_blocks; i++)
|
||||||
|
args[i] = (quantize_block_args*)malloc(sizeof(quantize_block_args));
|
||||||
|
|
||||||
|
const uint32 elements_code = 256;
|
||||||
|
BinAlgo<Scalar, float, Direct2> bin_searcher(code, elements_code);
|
||||||
|
|
||||||
|
for(int block_idx = 0; block_idx < n; block_idx+=BLOCK_SIZE)
|
||||||
|
{
|
||||||
|
int valid_items = n-block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
|
||||||
|
int block_end = block_idx + valid_items;
|
||||||
|
|
||||||
|
struct quantize_block_args *arg = args[block_idx/BLOCK_SIZE];
|
||||||
|
arg->bin_searcher = &bin_searcher;
|
||||||
|
arg->code = code;
|
||||||
|
arg->A = A;
|
||||||
|
arg->absmax = absmax;
|
||||||
|
arg->out = out;
|
||||||
|
arg->block_end = block_end;
|
||||||
|
arg->block_idx = block_idx;
|
||||||
|
arg->threadidx = block_idx/BLOCK_SIZE;
|
||||||
|
|
||||||
|
pthread_create(&threads[block_idx/BLOCK_SIZE], NULL, &quantize_block, (void *)arg);
|
||||||
|
}
|
||||||
|
|
||||||
|
for(int i = 0; i < num_blocks; i++)
|
||||||
|
int err = pthread_join(threads[i], NULL);
|
||||||
|
|
||||||
|
free(threads);
|
||||||
|
for(int i = 0; i < num_blocks; i++)
|
||||||
|
free(args[i]);
|
||||||
|
free(args);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n)
|
||||||
|
{
|
||||||
|
for(int block_idx = 0; block_idx < n; block_idx+=BLOCK_SIZE)
|
||||||
|
{
|
||||||
|
int valid_items = n-block_idx >= BLOCK_SIZE ? BLOCK_SIZE : n - block_idx;
|
||||||
|
int block_end = block_idx + valid_items;
|
||||||
|
for (int i = block_idx; i < block_end; i++)
|
||||||
|
out[i] = code[A[i]]*absmax[block_idx/BLOCK_SIZE];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n)
|
||||||
|
{
|
||||||
|
int threads = 512;
|
||||||
|
int blocks = n/threads;
|
||||||
|
blocks = n % threads == 0 ? blocks : blocks + 1;
|
||||||
|
kHistogramScatterAdd2D<<<blocks, 512>>>(histogram, index1, index2, src, maxidx1, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n)
|
||||||
|
{
|
||||||
|
int blocks = n/4096;
|
||||||
|
blocks = n % 4096 == 0 ? blocks : blocks + 1;
|
||||||
|
CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float)));
|
||||||
|
kEstimateQuantiles<T><<<blocks, 512>>>(A, code, offset, std::numeric_limits<T>::max(), n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
}
|
||||||
|
|
||||||
|
void quantize(float *code, float *A, unsigned char *out, int n)
|
||||||
|
{
|
||||||
|
int blocks = n/1024;
|
||||||
|
blocks = n % 1024 == 0 ? blocks : blocks + 1;
|
||||||
|
kQuantize<<<blocks, 1024>>>(code, A, out, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
}
|
||||||
|
|
||||||
|
void dequantize(float *code, unsigned char *A, float *out, int n)
|
||||||
|
{
|
||||||
|
int blocks = n/1024;
|
||||||
|
blocks = n % 1024 == 0 ? blocks : blocks + 1;
|
||||||
|
kDequantize<<<blocks, 1024>>>(code, A, out, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n)
|
||||||
|
{
|
||||||
|
int blocks = n/4096;
|
||||||
|
blocks = n % 4096 == 0 ? blocks : blocks + 1;
|
||||||
|
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC><<<blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n)
|
||||||
|
{
|
||||||
|
int blocks = n/blocksize;
|
||||||
|
blocks = n % blocksize == 0 ? blocks : blocks + 1;
|
||||||
|
if(blocksize == 4096)
|
||||||
|
kDequantizeBlockwise<T, 4096, 1024, 4><<<blocks, 4096/4>>>(code, A, absmax, out, n);
|
||||||
|
else if(blocksize == 2048)
|
||||||
|
kDequantizeBlockwise<T, 2048, 512, 4><<<blocks, 2048/4>>>(code, A, absmax, out, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
||||||
|
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
|
||||||
|
const float beta1, const float beta2, const float eps, const float weight_decay,
|
||||||
|
const int step, const float lr, const float gnorm_scale, const int n)
|
||||||
|
{
|
||||||
|
int blocks = n/4096;
|
||||||
|
blocks = n % 4096 == 0 ? blocks : blocks + 1;
|
||||||
|
switch(OPTIMIZER)
|
||||||
|
{
|
||||||
|
case ADAM:
|
||||||
|
if(max_unorm > 0.0f)
|
||||||
|
{
|
||||||
|
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||||
|
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
}
|
||||||
|
kOptimizer32bit2State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
break;
|
||||||
|
case MOMENTUM:
|
||||||
|
case RMSPROP:
|
||||||
|
if(max_unorm > 0.0f)
|
||||||
|
{
|
||||||
|
CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float)));
|
||||||
|
kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8><<<blocks, 512>>>(g, p, state1, unorm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
}
|
||||||
|
|
||||||
|
kOptimizer32bit1State<T, OPTIMIZER><<<blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, eps, weight_decay, step, lr, gnorm_scale, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
|
||||||
|
unsigned char* state1, unsigned char* state2,
|
||||||
|
float *unorm, float max_unorm, float param_norm,
|
||||||
|
float beta1, float beta2,
|
||||||
|
float eps, int step, float lr,
|
||||||
|
float* quantiles1, float* quantiles2,
|
||||||
|
float* max1, float* max2, float* new_max1, float* new_max2,
|
||||||
|
float weight_decay,
|
||||||
|
const float gnorm_scale, int n)
|
||||||
|
{
|
||||||
|
int blocks = n/4096;
|
||||||
|
blocks = n % 4096 == 0 ? blocks : blocks + 1;
|
||||||
|
|
||||||
|
if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); }
|
||||||
|
|
||||||
|
switch(OPTIMIZER)
|
||||||
|
{
|
||||||
|
case ADAM:
|
||||||
|
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||||
|
CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float)));
|
||||||
|
kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
kOptimizerStatic8bit2State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
|
||||||
|
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
break;
|
||||||
|
case MOMENTUM:
|
||||||
|
case RMSPROP:
|
||||||
|
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
|
||||||
|
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
kOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, eps, step, lr,
|
||||||
|
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#define BLOCKSIZE_2STATE 2048
|
||||||
|
#define NUM_2STATE 8
|
||||||
|
#define BLOCKSIZE_1STATE 2048
|
||||||
|
#define NUM_1STATE 8
|
||||||
|
|
||||||
|
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
|
||||||
|
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
|
||||||
|
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n)
|
||||||
|
{
|
||||||
|
|
||||||
|
int blocks = 0;
|
||||||
|
switch(OPTIMIZER)
|
||||||
|
{
|
||||||
|
case ADAM:
|
||||||
|
blocks = n/BLOCKSIZE_2STATE;
|
||||||
|
blocks = n % BLOCKSIZE_2STATE == 0 ? blocks : blocks + 1;
|
||||||
|
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, state1, state2, beta1, beta2, eps, step, lr,
|
||||||
|
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
break;
|
||||||
|
case MOMENTUM:
|
||||||
|
case RMSPROP:
|
||||||
|
blocks = n/BLOCKSIZE_1STATE;
|
||||||
|
blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1;
|
||||||
|
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
|
||||||
|
quantiles1, absmax1, weight_decay, gnorm_scale, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n)
|
||||||
|
{
|
||||||
|
int blocks = n/2048;
|
||||||
|
blocks = n % 2048 == 0 ? blocks : blocks + 1;
|
||||||
|
CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
|
||||||
|
kPercentileClipping<T, 2048, 4><<<blocks, 512>>>(g, gnorm_vec, step, n);
|
||||||
|
CUDA_CHECK_RETURN(cudaPeekAtLastError());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//==============================================================
|
||||||
|
// TEMPLATE DEFINITIONS
|
||||||
|
//==============================================================
|
||||||
|
|
||||||
|
template void estimateQuantiles(half *A, float *code, float offset, int n);
|
||||||
|
template void estimateQuantiles(float *A, float *code, float offset, int n);
|
||||||
|
|
||||||
|
template void quantizeBlockwise<half, 0>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||||
|
template void quantizeBlockwise<float, 0>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||||
|
template void quantizeBlockwise<half, 1>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||||
|
template void quantizeBlockwise<float, 1>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||||
|
template void dequantizeBlockwise<half>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n);
|
||||||
|
template void dequantizeBlockwise<float>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n);
|
||||||
|
|
||||||
|
#define MAKE_optimizer32bit(name, gtype) \
|
||||||
|
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
|
||||||
|
float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
|
||||||
|
const float beta1, const float beta2, const float eps, const float weight_decay, \
|
||||||
|
const int step, const float lr, const float gnorm_scale, const int n);
|
||||||
|
|
||||||
|
MAKE_optimizer32bit(ADAM, half)
|
||||||
|
MAKE_optimizer32bit(ADAM, float)
|
||||||
|
MAKE_optimizer32bit(MOMENTUM, half)
|
||||||
|
MAKE_optimizer32bit(MOMENTUM, float)
|
||||||
|
MAKE_optimizer32bit(RMSPROP, half)
|
||||||
|
MAKE_optimizer32bit(RMSPROP, float)
|
||||||
|
|
||||||
|
#define MAKE_optimizerStatic8bit(name, gtype) \
|
||||||
|
template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
|
||||||
|
float *unorm, float max_unorm, float param_norm, \
|
||||||
|
float beta1, float beta2, \
|
||||||
|
float eps, int step, float lr, \
|
||||||
|
float* quantiles1, float* quantiles2, \
|
||||||
|
float* max1, float* max2, float* new_max1, float* new_max2, \
|
||||||
|
float weight_decay, \
|
||||||
|
const float gnorm_scale, int n); \
|
||||||
|
|
||||||
|
MAKE_optimizerStatic8bit(ADAM, half)
|
||||||
|
MAKE_optimizerStatic8bit(ADAM, float)
|
||||||
|
MAKE_optimizerStatic8bit(MOMENTUM, half)
|
||||||
|
MAKE_optimizerStatic8bit(MOMENTUM, float)
|
||||||
|
MAKE_optimizerStatic8bit(RMSPROP, half)
|
||||||
|
MAKE_optimizerStatic8bit(RMSPROP, float)
|
||||||
|
|
||||||
|
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
|
||||||
|
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
|
||||||
|
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
|
||||||
|
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n); \
|
||||||
|
|
||||||
|
MAKE_optimizerStatic8bitBlockwise(half, ADAM);
|
||||||
|
MAKE_optimizerStatic8bitBlockwise(float, ADAM);
|
||||||
|
MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
|
||||||
|
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
|
||||||
|
MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
|
||||||
|
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
|
||||||
|
|
||||||
|
template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
|
||||||
|
template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
|
81
csrc/ops.cuh
Normal file
81
csrc/ops.cuh
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
//
|
||||||
|
// This source code is licensed under the MIT license found in the
|
||||||
|
// LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
#ifndef ops_H
|
||||||
|
#define ops_H
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <iostream>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include <cuda_runtime_api.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
#define CUDA_CHECK_RETURN(value) { \
|
||||||
|
cudaError_t _m_cudaStat = value; \
|
||||||
|
if (_m_cudaStat != cudaSuccess) { \
|
||||||
|
fprintf(stderr, "Error %s at line %d in file %s\n", \
|
||||||
|
cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
|
||||||
|
exit(1); \
|
||||||
|
} }
|
||||||
|
|
||||||
|
#define THREADS_PER_BLOCKS (512)
|
||||||
|
|
||||||
|
typedef enum Operations_t
|
||||||
|
{
|
||||||
|
ksmul = 0,
|
||||||
|
} Operations_t;
|
||||||
|
|
||||||
|
typedef enum Optimizer_t
|
||||||
|
{
|
||||||
|
ADAM = 0,
|
||||||
|
MOMENTUM = 1,
|
||||||
|
RMSPROP = 2,
|
||||||
|
LARS = 3,
|
||||||
|
} Optimizer_t;
|
||||||
|
|
||||||
|
|
||||||
|
template <typename T> void estimateQuantiles(T *A, float *code, float offset, int n);
|
||||||
|
|
||||||
|
void quantize(float *code, float *A, unsigned char *out, int n);
|
||||||
|
void dequantize(float *code, unsigned char *A, float *out, int n);
|
||||||
|
template <typename T, int STOCHASTIC> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n);
|
||||||
|
template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n);
|
||||||
|
|
||||||
|
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
|
||||||
|
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
|
||||||
|
float beta1, float beta2, float eps, float weight_decay,
|
||||||
|
int step, float lr, const float gnorm_scale, int n);
|
||||||
|
|
||||||
|
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2,
|
||||||
|
float *unorm, float max_unorm, float param_norm,
|
||||||
|
float beta1, float beta2,
|
||||||
|
float eps, int step, float lr,
|
||||||
|
float* quantiles1, float* quantiles2,
|
||||||
|
float* max1, float* max2, float* new_max1, float* new_max2,
|
||||||
|
float weight_decay,
|
||||||
|
const float gnorm_scale, int n);
|
||||||
|
|
||||||
|
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
|
||||||
|
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
|
||||||
|
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n);
|
||||||
|
|
||||||
|
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n);
|
||||||
|
|
||||||
|
void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, int n);
|
||||||
|
void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, int n);
|
||||||
|
|
||||||
|
void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n);
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
149
csrc/pythonInterface.c
Normal file
149
csrc/pythonInterface.c
Normal file
|
@ -0,0 +1,149 @@
|
||||||
|
// Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
//
|
||||||
|
// This source code is licensed under the MIT license found in the
|
||||||
|
// LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
#include <ops.cuh>
|
||||||
|
|
||||||
|
// We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary.
|
||||||
|
// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to
|
||||||
|
// maintain all that boilerplate
|
||||||
|
//===================================================================================
|
||||||
|
// UNMANGLED CALLS
|
||||||
|
//===================================================================================
|
||||||
|
|
||||||
|
void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles<float>(A, code, offset, n); }
|
||||||
|
void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles<half>(A, code, offset, n); }
|
||||||
|
|
||||||
|
|
||||||
|
#define MAKE_FUNC32(fname, oname, gtype, gbits) \
|
||||||
|
void fname##32bit_g##gbits(gtype *g, gtype *p, \
|
||||||
|
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
|
||||||
|
const float beta1, const float beta2, const float eps, const float weight_decay, \
|
||||||
|
const int step, const float lr, float gnorm_scale, const int n) \
|
||||||
|
{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); } \
|
||||||
|
|
||||||
|
MAKE_FUNC32(momentum, MOMENTUM, float, 32)
|
||||||
|
MAKE_FUNC32(momentum, MOMENTUM, half, 16)
|
||||||
|
MAKE_FUNC32(adam, ADAM, float, 32)
|
||||||
|
MAKE_FUNC32(adam, ADAM, half, 16)
|
||||||
|
MAKE_FUNC32(rmsprop, RMSPROP, float, 32)
|
||||||
|
MAKE_FUNC32(rmsprop, RMSPROP, half, 16)
|
||||||
|
|
||||||
|
#define MAKE_FUNC8(fname, oname, gtype, gbits) \
|
||||||
|
void fname##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
|
||||||
|
float *unorm, float max_unorm, float param_norm, \
|
||||||
|
float beta1, float beta2, \
|
||||||
|
float eps, int step, float lr, \
|
||||||
|
float* quantiles1, float* quantiles2, \
|
||||||
|
float* max1, float* max2, float* new_max1, float* new_max2, \
|
||||||
|
float weight_decay, float gnorm_scale, int n) \
|
||||||
|
{ \
|
||||||
|
optimizerStatic8bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
|
||||||
|
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
|
||||||
|
} \
|
||||||
|
|
||||||
|
MAKE_FUNC8(adam, ADAM, float, 32)
|
||||||
|
MAKE_FUNC8(adam, ADAM, half, 16)
|
||||||
|
MAKE_FUNC8(momentum, MOMENTUM, float, 32)
|
||||||
|
MAKE_FUNC8(momentum, MOMENTUM, half, 16)
|
||||||
|
MAKE_FUNC8(rmsprop, RMSPROP, float, 32)
|
||||||
|
MAKE_FUNC8(rmsprop, RMSPROP, half, 16)
|
||||||
|
|
||||||
|
#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
|
||||||
|
void fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
|
||||||
|
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
|
||||||
|
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n)\
|
||||||
|
{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n); }\
|
||||||
|
|
||||||
|
MAKE_BLOCKWISE8(adam, ADAM, half, 16)
|
||||||
|
MAKE_BLOCKWISE8(adam, ADAM, float, 32)
|
||||||
|
MAKE_BLOCKWISE8(momentum, MOMENTUM, half, 16)
|
||||||
|
MAKE_BLOCKWISE8(momentum, MOMENTUM, float, 32)
|
||||||
|
MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, 16)
|
||||||
|
MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, 32)
|
||||||
|
|
||||||
|
|
||||||
|
void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping<float>(g, gnorm_vec, step, n); }
|
||||||
|
void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping<half>(g, gnorm_vec, step, n); }
|
||||||
|
|
||||||
|
void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise<half, 0>(code, A, absmax, out, NULL, 0, n); }
|
||||||
|
void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise<float, 0>(code, A, absmax, out, NULL, 0, n); }
|
||||||
|
void quantizeBlockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<half, 1>(code, A, absmax, out, rand, rand_offset, n); }
|
||||||
|
void quantizeBlockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, const int n){ quantizeBlockwise<float, 1>(code, A, absmax, out, rand, rand_offset, n); }
|
||||||
|
|
||||||
|
void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise<half>(code, A, absmax, out, blocksize, n); } \
|
||||||
|
void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise<float>(code, A, absmax, out, blocksize, n); }
|
||||||
|
|
||||||
|
extern "C"
|
||||||
|
{
|
||||||
|
void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); }
|
||||||
|
void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); }
|
||||||
|
void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); }
|
||||||
|
void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); }
|
||||||
|
void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, n); }
|
||||||
|
void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, n); }
|
||||||
|
void cquantize_blockwise_stochastic_fp16(float * code, half *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp16(code, A, absmax, out, rand, rand_offset, n); }
|
||||||
|
void cquantize_blockwise_stochastic_fp32(float * code, float *A, float *absmax, unsigned char *out, float *rand, int rand_offset, const int n){ quantizeBlockwise_stochastic_fp32(code, A, absmax, out, rand, rand_offset, n); }
|
||||||
|
|
||||||
|
void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); }
|
||||||
|
void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); }
|
||||||
|
|
||||||
|
#define MAKE_CFUNC32(name, gtype, gbits) \
|
||||||
|
void c##name##32bit_g##gbits(gtype *g, gtype *p, \
|
||||||
|
float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \
|
||||||
|
const float beta1, const float beta2, const float eps, const float weight_decay, \
|
||||||
|
const int step, const float lr, const float gnorm_scale, const int n) \
|
||||||
|
{ name##32bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); } \
|
||||||
|
|
||||||
|
MAKE_CFUNC32(adam, float, 32)
|
||||||
|
MAKE_CFUNC32(adam, half, 16)
|
||||||
|
MAKE_CFUNC32(momentum, float, 32)
|
||||||
|
MAKE_CFUNC32(momentum, half, 16)
|
||||||
|
MAKE_CFUNC32(rmsprop, float, 32)
|
||||||
|
MAKE_CFUNC32(rmsprop, half, 16)
|
||||||
|
|
||||||
|
#define MAKE_CFUNC8(name, gtype, gbits) \
|
||||||
|
void c##name##_static_8bit_g##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
|
||||||
|
float *unorm, float max_unorm, float param_norm, \
|
||||||
|
float beta1, float beta2, \
|
||||||
|
float eps, int step, float lr, \
|
||||||
|
float* quantiles1, float* quantiles2, \
|
||||||
|
float* max1, float* max2, float* new_max1, float* new_max2, \
|
||||||
|
float weight_decay, float gnorm_scale, int n) \
|
||||||
|
{ \
|
||||||
|
name##_static_8bit_g##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \
|
||||||
|
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \
|
||||||
|
} \
|
||||||
|
|
||||||
|
MAKE_CFUNC8(adam, float, 32)
|
||||||
|
MAKE_CFUNC8(adam, half, 16)
|
||||||
|
MAKE_CFUNC8(momentum, float, 32)
|
||||||
|
MAKE_CFUNC8(momentum, half, 16)
|
||||||
|
MAKE_CFUNC8(rmsprop, float, 32)
|
||||||
|
MAKE_CFUNC8(rmsprop, half, 16)
|
||||||
|
|
||||||
|
#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
|
||||||
|
void c##fname##_8bit_blockwise_fp##gbits(gtype* p, gtype* g, \
|
||||||
|
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
|
||||||
|
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n) \
|
||||||
|
{ fname##_8bit_blockwise_fp##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, n); } \
|
||||||
|
|
||||||
|
MAKE_CBLOCKWISE8(adam, ADAM, half, 16)
|
||||||
|
MAKE_CBLOCKWISE8(adam, ADAM, float, 32)
|
||||||
|
MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, 16)
|
||||||
|
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, 32)
|
||||||
|
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, 16)
|
||||||
|
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, 32)
|
||||||
|
|
||||||
|
|
||||||
|
void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); }
|
||||||
|
void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); }
|
||||||
|
|
||||||
|
void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, const int n){ quantize_cpu(code, A, absmax, out, n); }
|
||||||
|
void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, const int n){ dequantize_cpu(code, A, absmax, out, n); }
|
||||||
|
|
||||||
|
void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); }
|
||||||
|
}
|
||||||
|
|
||||||
|
|
13
deploy.sh
Normal file
13
deploy.sh
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
rm -rf dist build
|
||||||
|
make clean
|
||||||
|
CUDA_HOME=/usr/local/cuda-10.2 make
|
||||||
|
CUDA_VERSION=102 python -m build
|
||||||
|
python -m twine upload --repository testpypi dist/* --verbose
|
||||||
|
|
||||||
|
rm -rf dist build
|
||||||
|
make clean
|
||||||
|
CUDA_HOME=/usr/local/cuda-11.1 make
|
||||||
|
CUDA_VERSION=111 python -m build
|
||||||
|
python -m twine upload --repository testpypi dist/* --verbose
|
86
deploy_from_slurm.sh
Normal file
86
deploy_from_slurm.sh
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
#!/bin/bash
|
||||||
|
module unload cuda
|
||||||
|
module unload gcc
|
||||||
|
|
||||||
|
rm -rf dist build
|
||||||
|
make clean
|
||||||
|
make cleaneggs
|
||||||
|
module load cuda/9.2
|
||||||
|
module load gcc/7.3.0
|
||||||
|
CUDA_HOME=/public/apps/cuda/9.2
|
||||||
|
make
|
||||||
|
CUDA_VERSION=92 python -m build
|
||||||
|
python -m twine upload --repository testpypi dist/* --verbose
|
||||||
|
module unload cuda
|
||||||
|
|
||||||
|
|
||||||
|
rm -rf dist build
|
||||||
|
make clean
|
||||||
|
make cleaneggs
|
||||||
|
module load cuda/10.0
|
||||||
|
CUDA_HOME=/public/apps/cuda/10.0
|
||||||
|
make cuda10x
|
||||||
|
CUDA_VERSION=100 python -m build
|
||||||
|
python -m twine upload --repository testpypi dist/* --verbose
|
||||||
|
module unload cuda
|
||||||
|
module unload gcc
|
||||||
|
module load gcc/8.4
|
||||||
|
|
||||||
|
rm -rf dist build
|
||||||
|
make clean
|
||||||
|
make cleaneggs
|
||||||
|
module load cuda/10.1
|
||||||
|
CUDA_HOME=/public/apps/cuda/10.1
|
||||||
|
make cuda10x
|
||||||
|
CUDA_VERSION=101 python -m build
|
||||||
|
python -m twine upload --repository testpypi dist/* --verbose
|
||||||
|
module unload cuda
|
||||||
|
|
||||||
|
rm -rf dist build
|
||||||
|
make clean
|
||||||
|
make cleaneggs
|
||||||
|
module load cuda/10.2
|
||||||
|
CUDA_HOME=/public/apps/cuda/10.2/
|
||||||
|
make cuda10x
|
||||||
|
CUDA_VERSION=102 python -m build
|
||||||
|
python -m twine upload --repository testpypi dist/* --verbose
|
||||||
|
module unload cuda
|
||||||
|
|
||||||
|
|
||||||
|
rm -rf dist build
|
||||||
|
make clean
|
||||||
|
make cleaneggs
|
||||||
|
module load cuda/11.0
|
||||||
|
CUDA_HOME=/public/apps/cuda/11.0
|
||||||
|
make cuda110
|
||||||
|
CUDA_VERSION=110 python -m build
|
||||||
|
python -m twine upload --repository testpypi dist/* --verbose
|
||||||
|
module unload cuda
|
||||||
|
|
||||||
|
rm -rf dist build
|
||||||
|
make clean
|
||||||
|
make cleaneggs
|
||||||
|
module load cuda/11.1
|
||||||
|
CUDA_HOME=/public/apps/cuda/11.1
|
||||||
|
make cuda11x
|
||||||
|
CUDA_VERSION=111 python -m build
|
||||||
|
python -m twine upload --repository testpypi dist/* --verbose
|
||||||
|
module unload cuda
|
||||||
|
|
||||||
|
rm -rf dist build
|
||||||
|
make clean
|
||||||
|
make cleaneggs
|
||||||
|
module load cuda/11.2
|
||||||
|
CUDA_HOME=/public/apps/cuda/11.2
|
||||||
|
make cuda11x
|
||||||
|
CUDA_VERSION=112 python -m build
|
||||||
|
python -m twine upload --repository testpypi dist/* --verbose
|
||||||
|
module unload cuda
|
||||||
|
|
||||||
|
rm -rf dist build
|
||||||
|
make clean
|
||||||
|
make cleaneggs
|
||||||
|
CUDA_HOME=/private/home/timdettmers/git/autoswap/local/cuda-11.3 make cuda11x
|
||||||
|
CUDA_VERSION=113 python -m build
|
||||||
|
python -m twine upload --repository testpypi dist/* --verbose
|
||||||
|
module unload cuda
|
86
include/AAlloc.h
Normal file
86
include/AAlloc.h
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "Portable.h"
|
||||||
|
|
||||||
|
namespace BinSearch {
|
||||||
|
namespace Details {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
bool isAligned(const T *p, size_t A)
|
||||||
|
{
|
||||||
|
return (reinterpret_cast<size_t>(p) % A) == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T, size_t A=64>
|
||||||
|
struct AlignedVec
|
||||||
|
{
|
||||||
|
AlignedVec()
|
||||||
|
: m_storage(0)
|
||||||
|
, m_data(0)
|
||||||
|
, m_sz(0)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
static size_t nBytes(size_t sz)
|
||||||
|
{
|
||||||
|
return sz * sizeof(T) + A;
|
||||||
|
}
|
||||||
|
|
||||||
|
static size_t shiftAmt(char *p)
|
||||||
|
{
|
||||||
|
return A>1? (A - (reinterpret_cast<size_t>(p) % A)) % A: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void setPtr(char *p, size_t sz)
|
||||||
|
{
|
||||||
|
m_sz = sz;
|
||||||
|
m_data = reinterpret_cast<T *>(p + shiftAmt(p));
|
||||||
|
}
|
||||||
|
|
||||||
|
//void setPtr(T *p, size_t sz)
|
||||||
|
//{
|
||||||
|
// m_sz = sz;
|
||||||
|
// if (A>1)
|
||||||
|
// myassert(((reinterpret_cast<size_t>(p) % A) == 0), "bad alignment");
|
||||||
|
// m_data = p;
|
||||||
|
//}
|
||||||
|
|
||||||
|
// internal allocation
|
||||||
|
void resize(size_t sz)
|
||||||
|
{
|
||||||
|
m_storage = new char[nBytes(sz)];
|
||||||
|
setPtr(m_storage, sz);
|
||||||
|
}
|
||||||
|
|
||||||
|
// external allocation
|
||||||
|
void set(char *storage, size_t sz)
|
||||||
|
{
|
||||||
|
setPtr(storage, sz);
|
||||||
|
}
|
||||||
|
|
||||||
|
~AlignedVec()
|
||||||
|
{
|
||||||
|
if (m_storage)
|
||||||
|
delete [] m_storage;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t size() const { return m_sz; }
|
||||||
|
T& operator[](size_t i) { return m_data[i]; }
|
||||||
|
const T& operator[](size_t i) const { return m_data[i]; }
|
||||||
|
T* begin() { return m_data; }
|
||||||
|
T* end() { return m_data+m_sz; }
|
||||||
|
const T* begin() const { return m_data; }
|
||||||
|
const T* end() const { return m_data+m_sz; }
|
||||||
|
T& front() { return m_data[0]; }
|
||||||
|
T& back() { return m_data[m_sz-1]; }
|
||||||
|
const T& front() const { return m_data[0]; }
|
||||||
|
const T& back() const { return m_data[m_sz - 1]; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
char *m_storage;
|
||||||
|
T *m_data;
|
||||||
|
size_t m_sz;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace Details
|
||||||
|
} // namespace BinSearch
|
341
include/Algo-Direct-Common.h
Normal file
341
include/Algo-Direct-Common.h
Normal file
|
@ -0,0 +1,341 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <limits>
|
||||||
|
#include <type_traits>
|
||||||
|
#include "AAlloc.h"
|
||||||
|
|
||||||
|
namespace BinSearch {
|
||||||
|
namespace Details {
|
||||||
|
|
||||||
|
namespace DirectAux {
|
||||||
|
|
||||||
|
#define SAFETY_MULTI_PASS true
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct HResults
|
||||||
|
{
|
||||||
|
HResults(T h, double ratio, size_t n) : H(h), hRatio(ratio), nInc(n) {}
|
||||||
|
T H;
|
||||||
|
double hRatio;
|
||||||
|
size_t nInc;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef USE_FMA
|
||||||
|
template <Algos A> struct IsDirect { static const bool value = (A == Direct) || (A == DirectFMA); };
|
||||||
|
template <Algos A> struct IsDirect2 { static const bool value = (A == Direct2) || (A == Direct2FMA); };
|
||||||
|
template <Algos A> struct IsDirectCache { static const bool value = (A == DirectCache) || (A == DirectCacheFMA); };
|
||||||
|
#else
|
||||||
|
template <Algos A> struct IsDirect { static const bool value = (A == Direct); };
|
||||||
|
template <Algos A> struct IsDirect2 { static const bool value = (A == Direct2); };
|
||||||
|
template <Algos A> struct IsDirectCache { static const bool value = (A == DirectCache); };
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// general definition
|
||||||
|
template <Algos A, typename T, typename Enable = void>
|
||||||
|
struct BucketElem
|
||||||
|
{
|
||||||
|
FORCE_INLINE void set( uint32 b, const T *)
|
||||||
|
{
|
||||||
|
m_b = b;
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE uint32 index() const { return m_b; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
uint32 m_b;
|
||||||
|
};
|
||||||
|
|
||||||
|
// specialization for DirectCache methods
|
||||||
|
|
||||||
|
template <typename T> struct MatchingIntType;
|
||||||
|
template <> struct MatchingIntType<double> { typedef uint64 type; };
|
||||||
|
template <> struct MatchingIntType<float> { typedef uint32 type; };
|
||||||
|
|
||||||
|
template <Algos A, typename T>
|
||||||
|
struct BucketElem<A, T, typename std::enable_if< IsDirectCache<A>::value >::type >
|
||||||
|
{
|
||||||
|
typedef typename MatchingIntType<T>::type I;
|
||||||
|
|
||||||
|
void set(uint32 b, const T *xi)
|
||||||
|
{
|
||||||
|
u.u.x = xi[b];
|
||||||
|
u.u.b = b;
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE I index() const { return u.u.b; }
|
||||||
|
FORCE_INLINE T x() const { return u.u.x; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
union {
|
||||||
|
double dummy;
|
||||||
|
struct
|
||||||
|
{
|
||||||
|
T x;
|
||||||
|
I b;
|
||||||
|
} u;
|
||||||
|
} u;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
template <bool UseFMA, unsigned char Gap, typename T>
|
||||||
|
struct DirectTraits
|
||||||
|
{
|
||||||
|
static void checkH(T scaler, T x0, T xN)
|
||||||
|
{
|
||||||
|
T Dn = xN - x0;
|
||||||
|
T ifmax = Dn * scaler;
|
||||||
|
myassert((ifmax < std::numeric_limits<uint32>::max() - (Gap - 1)),
|
||||||
|
"Problem unfeasible: index size exceeds uint32 capacity:"
|
||||||
|
<< " D[N] =" << Dn
|
||||||
|
<< ", H =" << scaler
|
||||||
|
<< ", H D[n] =" << ifmax << "\n"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE static uint32 f(T scaler, T x0, T z)
|
||||||
|
{
|
||||||
|
T tmp = scaler * (z - x0);
|
||||||
|
#ifdef USE_SSE2
|
||||||
|
return ftoi(FVec1<SSE,T>(tmp));
|
||||||
|
#else
|
||||||
|
return static_cast<uint32>(tmp);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template <InstrSet I>
|
||||||
|
FORCE_INLINE static typename FTOITraits<I, T>::vec_t f(const FVec<I, T>& scaler, const FVec<I, T>& x0, const FVec<I, T>& z)
|
||||||
|
{
|
||||||
|
return ftoi(scaler*(z-x0));
|
||||||
|
}
|
||||||
|
|
||||||
|
static T cst0(T scaler, T x0)
|
||||||
|
{
|
||||||
|
return x0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifdef USE_FMA
|
||||||
|
template <unsigned char Gap, typename T>
|
||||||
|
struct DirectTraits<true,Gap,T>
|
||||||
|
{
|
||||||
|
typedef FVec1<SSE, T> fVec1;
|
||||||
|
|
||||||
|
static void checkH(T scaler, T H_Times_x0, T xN)
|
||||||
|
{
|
||||||
|
union {
|
||||||
|
typename FVec1<SSE, T>::vec_t v;
|
||||||
|
T s;
|
||||||
|
} ifmax;
|
||||||
|
ifmax.v = mulSub(fVec1(scaler), fVec1(xN), fVec1(H_Times_x0));
|
||||||
|
myassert((ifmax.s < std::numeric_limits<uint32>::max() - (Gap - 1)),
|
||||||
|
"Problem unfeasible: index size exceeds uint32 capacity:"
|
||||||
|
<< " H X[0] =" << H_Times_x0
|
||||||
|
<< ", H =" << scaler
|
||||||
|
<< ", X[N] =" << xN
|
||||||
|
<< ", H X[N] - H X[0] =" << ifmax.s << "\n"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE static uint32 f(T scaler, T Hx0, T xi)
|
||||||
|
{
|
||||||
|
return ftoi(mulSub(fVec1(scaler), fVec1(xi), fVec1(Hx0)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <InstrSet I>
|
||||||
|
FORCE_INLINE static typename FTOITraits<I,T>::vec_t f(const FVec<I,T>& scaler, const FVec<I, T>& H_Times_X0, const FVec<I, T>& z)
|
||||||
|
{
|
||||||
|
return ftoi(mulSub(scaler, z, H_Times_X0));
|
||||||
|
}
|
||||||
|
|
||||||
|
static T cst0(T scaler, T x0)
|
||||||
|
{
|
||||||
|
return scaler*x0;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <unsigned char Gap, typename T, Algos A>
|
||||||
|
struct DirectInfo
|
||||||
|
{
|
||||||
|
static const bool UseFMA = (A == DirectFMA) || (A == Direct2FMA) || (A == DirectCacheFMA);
|
||||||
|
typedef DirectTraits<UseFMA, Gap, T> fun_t;
|
||||||
|
typedef BucketElem<A,T> bucket_t;
|
||||||
|
typedef AlignedVec<bucket_t> bucketvec_t;
|
||||||
|
|
||||||
|
struct Data {
|
||||||
|
Data() : buckets(0), xi(0), scaler(0), cst0(0) {}
|
||||||
|
Data( const T *x // for Direct must persist if xws=NULL
|
||||||
|
, uint32 n
|
||||||
|
, T H
|
||||||
|
, bucket_t *bws // assumed to gave size nb, as computed below
|
||||||
|
, T *xws = NULL // assumed to have size (n+Gap-1). Optional for Direct, unused for DirectCache, required for DirectGap
|
||||||
|
)
|
||||||
|
: buckets(bws)
|
||||||
|
, scaler(H)
|
||||||
|
, cst0(fun_t::cst0(H, x[0]))
|
||||||
|
{
|
||||||
|
myassert(((bws != NULL) && (isAligned(bws,64))), "bucket pointer not allocated or incorrectly aligned");
|
||||||
|
|
||||||
|
uint32 nb = 1 + fun_t::f(H, cst0, x[n-1]);
|
||||||
|
|
||||||
|
const uint32 npad = Gap-1;
|
||||||
|
const uint32 n_sz = n + npad; // size of padded vector
|
||||||
|
|
||||||
|
if (xws) {
|
||||||
|
myassert(isAligned(xws,8), "x pointer not allocated or incorrectly aligned");
|
||||||
|
std::fill_n(xws, npad, x[0]); // pad in front with x[0]
|
||||||
|
std::copy(x, x+n, xws + npad);
|
||||||
|
xi = xws;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
myassert(Gap==1, "if Gap>1 then X workspace must be provided");
|
||||||
|
xi = x;
|
||||||
|
}
|
||||||
|
|
||||||
|
populateIndex(bws, nb, xi, n_sz, scaler, cst0);
|
||||||
|
}
|
||||||
|
|
||||||
|
const bucket_t *buckets;
|
||||||
|
const T *xi;
|
||||||
|
T scaler;
|
||||||
|
T cst0; // could be x0 or (scaler*x0), depending if we are using FMA or not
|
||||||
|
} data;
|
||||||
|
|
||||||
|
static T growStep(T H)
|
||||||
|
{
|
||||||
|
T step;
|
||||||
|
T P = next(H);
|
||||||
|
while ((step = P - H) == 0)
|
||||||
|
P = next(P);
|
||||||
|
return step;
|
||||||
|
}
|
||||||
|
|
||||||
|
static HResults<T> computeH(const T *px, uint32 nx)
|
||||||
|
{
|
||||||
|
myassert((nx > Gap), "Array X too small");
|
||||||
|
myassert(((Gap == 1) || (Gap == 2)), "Only tested for these values of Gap");
|
||||||
|
|
||||||
|
const T x0 = px[0];
|
||||||
|
const T xN = px[nx-1];
|
||||||
|
|
||||||
|
const T range = xN - x0;
|
||||||
|
myassert((range < std::numeric_limits<T>::max()), "range too large");
|
||||||
|
|
||||||
|
// check that D_i are strictly increasing and compute minimum value D_{i+Offset}-D_i
|
||||||
|
T deltaDMin = range;
|
||||||
|
for (uint32 i = Gap; i < nx; ++i) {
|
||||||
|
T Dnew = px[i] - x0;
|
||||||
|
T Dold = px[i - Gap] - x0;
|
||||||
|
myassert((Dnew > Dold),
|
||||||
|
"Problem unfeasible: D_i sequence not strictly increasing"
|
||||||
|
<< " X[" << 0 << "]=" << x0
|
||||||
|
<< " X[" << i - Gap << "]=" << px[i - Gap]
|
||||||
|
<< " X[" << i << "]=" << px[i]
|
||||||
|
<< "\n"
|
||||||
|
);
|
||||||
|
T deltaD = Dnew - Dold;
|
||||||
|
if (deltaD < deltaDMin)
|
||||||
|
deltaDMin = deltaD;
|
||||||
|
}
|
||||||
|
|
||||||
|
// initial guess for H
|
||||||
|
const T H0 = T(1.0) / deltaDMin;
|
||||||
|
T H = H0;
|
||||||
|
|
||||||
|
T cst0 = fun_t::cst0(H, x0);
|
||||||
|
fun_t::checkH(H, cst0, xN);
|
||||||
|
|
||||||
|
// adjust H by trial and error until succeed
|
||||||
|
size_t nInc = 0;
|
||||||
|
bool modified = false;
|
||||||
|
size_t npasses = 0;
|
||||||
|
T step = growStep(H);
|
||||||
|
uint32 seg_already_checked_from = nx;
|
||||||
|
do {
|
||||||
|
myassert((npasses++ < 2), "verification failed\n");
|
||||||
|
// if there has been an increase, then check only up to that point
|
||||||
|
uint32 last_seg_to_be_checked = seg_already_checked_from - 1;
|
||||||
|
modified = false;
|
||||||
|
uint32 inew = 0;
|
||||||
|
for (uint32 i = Gap; i <= last_seg_to_be_checked; ++i) {
|
||||||
|
uint32 iold = fun_t::f(H, cst0, px[i-Gap]);
|
||||||
|
uint32 inew = fun_t::f(H, cst0, px[i]);
|
||||||
|
while (inew == iold) {
|
||||||
|
seg_already_checked_from = i;
|
||||||
|
last_seg_to_be_checked = nx-1; // everything needs to be checked
|
||||||
|
modified = true;
|
||||||
|
H = H + step;
|
||||||
|
step *= 2;
|
||||||
|
// recalculate all constants and indices
|
||||||
|
cst0 = fun_t::cst0(H, x0);
|
||||||
|
fun_t::checkH(H, cst0, xN);
|
||||||
|
iold = fun_t::f(H, cst0, px[i - Gap]);
|
||||||
|
inew = fun_t::f(H, cst0, px[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} while (SAFETY_MULTI_PASS && modified);
|
||||||
|
|
||||||
|
return HResults<T>(H, (((double)H) / H0) - 1.0, nInc);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void populateIndex(BucketElem<A, T> *buckets, uint32 index_size, const T *px, uint32 x_size, T scaler, T cst0)
|
||||||
|
{
|
||||||
|
for (uint32 i = x_size-1, b = index_size-1, j=0; ; --i) {
|
||||||
|
uint32 idx = fun_t::f(scaler, cst0, px[i]);
|
||||||
|
while (b > idx) { // in the 1st iteration it is j=0 but this condition is always false
|
||||||
|
buckets[b].set( j, px );
|
||||||
|
--b;
|
||||||
|
}
|
||||||
|
if (Gap==1 || b == idx) { // if Gap==1, which is known at compile time, the check b==idx is redundant
|
||||||
|
j = i - (Gap-1); // subtracting (Gap-1) points to the index of the first X-element to check
|
||||||
|
buckets[b].set(j, px);
|
||||||
|
if (b-- == 0)
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DirectInfo(const Data& d)
|
||||||
|
: data(d)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
DirectInfo(const T* px, const uint32 n)
|
||||||
|
{
|
||||||
|
HResults<T> res = computeH(px, n);
|
||||||
|
|
||||||
|
#ifdef PAPER_TEST
|
||||||
|
nInc = res.nInc;
|
||||||
|
hRatio = res.hRatio;
|
||||||
|
#endif
|
||||||
|
const uint32 npad = Gap-1;
|
||||||
|
const uint32 n_sz = n + npad; // size of padded vector
|
||||||
|
|
||||||
|
if (npad)
|
||||||
|
xi.resize(n_sz);
|
||||||
|
|
||||||
|
T H = res.H;
|
||||||
|
T cst0 = fun_t::cst0(H, px[0]);
|
||||||
|
const uint32 maxIndex = fun_t::f(H, cst0, px[n-1]);
|
||||||
|
buckets.resize(maxIndex + 1);
|
||||||
|
|
||||||
|
data = Data(px, n, H, buckets.begin(), (npad? xi.begin(): NULL));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
bucketvec_t buckets;
|
||||||
|
AlignedVec<T,8> xi;
|
||||||
|
|
||||||
|
#ifdef PAPER_TEST
|
||||||
|
public:
|
||||||
|
double hRatio;
|
||||||
|
size_t nInc;
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace DirectAux
|
||||||
|
} // namespace Details
|
||||||
|
} // namespace BinSearch
|
305
include/Algo-Direct2.h
Normal file
305
include/Algo-Direct2.h
Normal file
|
@ -0,0 +1,305 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "Algo-Direct-Common.h"
|
||||||
|
|
||||||
|
namespace BinSearch {
|
||||||
|
namespace Details {
|
||||||
|
|
||||||
|
template <typename T, Algos A>
|
||||||
|
struct AlgoScalarBase<T, A, typename std::enable_if<DirectAux::IsDirect2<A>::value>::type> : DirectAux::DirectInfo<2, T, A>
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
typedef DirectAux::DirectInfo<2, T, A> base_t;
|
||||||
|
static const size_t Offset=2;
|
||||||
|
|
||||||
|
public:
|
||||||
|
AlgoScalarBase(const T* x, const uint32 n)
|
||||||
|
: base_t(x, n)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE uint32 scalar(T z) const
|
||||||
|
{
|
||||||
|
const T* px = base_t::data.xi;
|
||||||
|
const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
|
||||||
|
uint32 bidx = base_t::fun_t::f(base_t::data.scaler, base_t::data.cst0, z);
|
||||||
|
uint32 iidx = buckets[bidx];
|
||||||
|
px += iidx;
|
||||||
|
if (z < *px)
|
||||||
|
--iidx;
|
||||||
|
if (z < *(px+1))
|
||||||
|
--iidx;
|
||||||
|
return iidx;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
template <InstrSet I, typename T, Algos A>
|
||||||
|
struct AlgoVecBase<I, T, A, typename std::enable_if<DirectAux::IsDirect2<A>::value>::type> : AlgoScalarBase<T, A>
|
||||||
|
{
|
||||||
|
static const uint32 nElem = sizeof(typename InstrFloatTraits<I, T>::vec_t) / sizeof(T);
|
||||||
|
|
||||||
|
typedef FVec<I, T> fVec;
|
||||||
|
typedef IVec<SSE, T> i128;
|
||||||
|
|
||||||
|
struct Constants
|
||||||
|
{
|
||||||
|
fVec vscaler;
|
||||||
|
fVec vcst0;
|
||||||
|
IVec<I, T> one;
|
||||||
|
};
|
||||||
|
|
||||||
|
private:
|
||||||
|
typedef AlgoScalarBase<T, A> base_t;
|
||||||
|
|
||||||
|
FORCE_INLINE
|
||||||
|
//NO_INLINE
|
||||||
|
void resolve(const FVec<SSE, float>& vz, const IVec<SSE, float>& bidx, uint32 *pr) const
|
||||||
|
{
|
||||||
|
union U {
|
||||||
|
__m128i vec;
|
||||||
|
uint32 ui32[4];
|
||||||
|
} u;
|
||||||
|
|
||||||
|
const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
|
||||||
|
const float *xi = base_t::data.xi;
|
||||||
|
|
||||||
|
// read indices t
|
||||||
|
const double *p3 = reinterpret_cast<const double *>(&xi[(u.ui32[3] = buckets[bidx.get3()])]);
|
||||||
|
const double *p2 = reinterpret_cast<const double *>(&xi[(u.ui32[2] = buckets[bidx.get2()])]);
|
||||||
|
const double *p1 = reinterpret_cast<const double *>(&xi[(u.ui32[1] = buckets[bidx.get1()])]);
|
||||||
|
const double *p0 = reinterpret_cast<const double *>(&xi[(u.ui32[0] = buckets[bidx.get0()])]);
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
// read pairs ( X(t-1), X(t) )
|
||||||
|
__m128 xp3 = _mm_castpd_ps(_mm_load_sd(p3));
|
||||||
|
__m128 xp2 = _mm_castpd_ps(_mm_load_sd(p2));
|
||||||
|
__m128 xp1 = _mm_castpd_ps(_mm_load_sd(p1));
|
||||||
|
__m128 xp0 = _mm_castpd_ps(_mm_load_sd(p0));
|
||||||
|
|
||||||
|
// build:
|
||||||
|
// { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) }
|
||||||
|
// { X(t(0)), X(t(1)), X(t(2)), X(t(3)) }
|
||||||
|
__m128 h13 = _mm_shuffle_ps(xp1, xp3, (1 << 2) + (1 << 6));
|
||||||
|
__m128 h02 = _mm_shuffle_ps(xp0, xp2, (1 << 2) + (1 << 6));
|
||||||
|
__m128 u01 = _mm_unpacklo_ps(h02, h13);
|
||||||
|
__m128 u23 = _mm_unpackhi_ps(h02, h13);
|
||||||
|
__m128 vxm = _mm_shuffle_ps(u01, u23, (0) + (1 << 2) + (0 << 4) + (1 << 6));
|
||||||
|
__m128 vxp = _mm_shuffle_ps(u01, u23, (2) + (3 << 2) + (2 << 4) + (3 << 6));
|
||||||
|
#else
|
||||||
|
__m128 xp23 = _mm_castpd_ps(_mm_set_pd(*p3, *p2));
|
||||||
|
__m128 xp01 = _mm_castpd_ps(_mm_set_pd(*p1, *p0));
|
||||||
|
__m128 vxm = _mm_shuffle_ps(xp01, xp23, (0) + (2 << 2) + (0 << 4) + (2 << 6));
|
||||||
|
__m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6));
|
||||||
|
#endif
|
||||||
|
IVec<SSE, float> i(u.vec);
|
||||||
|
IVec<SSE, float> vlem = vz < vxm;
|
||||||
|
IVec<SSE, float> vlep = vz < vxp;
|
||||||
|
i = i + vlem + vlep;
|
||||||
|
i.store(pr);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE
|
||||||
|
//NO_INLINE
|
||||||
|
void resolve(const FVec<SSE, double>& vz, const IVec<SSE, float>& bidx, uint32 *pr) const
|
||||||
|
{
|
||||||
|
const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
|
||||||
|
const double *xi = base_t::data.xi;
|
||||||
|
|
||||||
|
uint32 b1 = buckets[bidx.get1()];
|
||||||
|
uint32 b0 = buckets[bidx.get0()];
|
||||||
|
|
||||||
|
const double *p1 = &xi[b1];
|
||||||
|
const double *p0 = &xi[b0];
|
||||||
|
|
||||||
|
// read pairs ( X(t-1), X(t) )
|
||||||
|
__m128d vx1 = _mm_loadu_pd(p1);
|
||||||
|
__m128d vx0 = _mm_loadu_pd(p0);
|
||||||
|
|
||||||
|
// build:
|
||||||
|
// { X(t(0)-1), X(t(1)-1) }
|
||||||
|
// { X(t(0)), X(t(1)) }
|
||||||
|
__m128d vxm = _mm_shuffle_pd(vx0, vx1, 0);
|
||||||
|
__m128d vxp = _mm_shuffle_pd(vx0, vx1, 3);
|
||||||
|
|
||||||
|
IVec<SSE, double> i(b1, b0);
|
||||||
|
IVec<SSE, double> vlem = (vz < vxm);
|
||||||
|
IVec<SSE, double> vlep = (vz < vxp);
|
||||||
|
i = i + vlem + vlep;
|
||||||
|
|
||||||
|
union {
|
||||||
|
__m128i vec;
|
||||||
|
uint32 ui32[4];
|
||||||
|
} u;
|
||||||
|
u.vec = i;
|
||||||
|
pr[0] = u.ui32[0];
|
||||||
|
pr[1] = u.ui32[2];
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef USE_AVX
|
||||||
|
|
||||||
|
FORCE_INLINE
|
||||||
|
//NO_INLINE
|
||||||
|
void resolve(const FVec<AVX, float>& vz, const IVec<AVX, float>& bidx, uint32 *pr) const
|
||||||
|
{
|
||||||
|
const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
|
||||||
|
const float *xi = base_t::data.xi;
|
||||||
|
|
||||||
|
#if 0 // use gather instructions
|
||||||
|
|
||||||
|
IVec<AVX,float> idxm;
|
||||||
|
idxm.setidx(buckets, bidx);
|
||||||
|
__m256i z = _mm256_setzero_si256();
|
||||||
|
IVec<AVX,float> minusone = _mm256_cmpeq_epi32(z,z);
|
||||||
|
IVec<AVX,float> idxp = idxm - minusone;
|
||||||
|
|
||||||
|
FVec<AVX, float> vxm = _mm256_i32gather_ps(xi, idxm, sizeof(float));
|
||||||
|
FVec<AVX, float> vxp = _mm256_i32gather_ps(xi, idxp, sizeof(float));
|
||||||
|
IVec<AVX, float> ip = idxm;
|
||||||
|
|
||||||
|
#else // do not use gather instrucions
|
||||||
|
|
||||||
|
union U {
|
||||||
|
__m256i vec;
|
||||||
|
uint32 ui32[8];
|
||||||
|
} u;
|
||||||
|
|
||||||
|
// read indices t
|
||||||
|
|
||||||
|
const double *p7 = reinterpret_cast<const double *>(&xi[(u.ui32[7] = buckets[bidx.get7()])]);
|
||||||
|
const double *p6 = reinterpret_cast<const double *>(&xi[(u.ui32[6] = buckets[bidx.get6()])]);
|
||||||
|
const double *p5 = reinterpret_cast<const double *>(&xi[(u.ui32[5] = buckets[bidx.get5()])]);
|
||||||
|
const double *p4 = reinterpret_cast<const double *>(&xi[(u.ui32[4] = buckets[bidx.get4()])]);
|
||||||
|
const double *p3 = reinterpret_cast<const double *>(&xi[(u.ui32[3] = buckets[bidx.get3()])]);
|
||||||
|
const double *p2 = reinterpret_cast<const double *>(&xi[(u.ui32[2] = buckets[bidx.get2()])]);
|
||||||
|
const double *p1 = reinterpret_cast<const double *>(&xi[(u.ui32[1] = buckets[bidx.get1()])]);
|
||||||
|
const double *p0 = reinterpret_cast<const double *>(&xi[(u.ui32[0] = buckets[bidx.get0()])]);
|
||||||
|
|
||||||
|
#if 0 // perform 8 loads in double precision
|
||||||
|
|
||||||
|
// read pairs ( X(t-1), X(t) )
|
||||||
|
__m128 xp7 = _mm_castpd_ps(_mm_load_sd(p7));
|
||||||
|
__m128 xp6 = _mm_castpd_ps(_mm_load_sd(p6));
|
||||||
|
__m128 xp5 = _mm_castpd_ps(_mm_load_sd(p5));
|
||||||
|
__m128 xp4 = _mm_castpd_ps(_mm_load_sd(p4));
|
||||||
|
__m128 xp3 = _mm_castpd_ps(_mm_load_sd(p3));
|
||||||
|
__m128 xp2 = _mm_castpd_ps(_mm_load_sd(p2));
|
||||||
|
__m128 xp1 = _mm_castpd_ps(_mm_load_sd(p1));
|
||||||
|
__m128 xp0 = _mm_castpd_ps(_mm_load_sd(p0));
|
||||||
|
|
||||||
|
// build:
|
||||||
|
// { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) }
|
||||||
|
// { X(t(0)), X(t(1)), X(t(2)), X(t(3)) }
|
||||||
|
__m128 h57 = _mm_shuffle_ps(xp5, xp7, (1 << 2) + (1 << 6)); // F- F+ H- H+
|
||||||
|
__m128 h46 = _mm_shuffle_ps(xp4, xp6, (1 << 2) + (1 << 6)); // E- E+ G- G+
|
||||||
|
__m128 h13 = _mm_shuffle_ps(xp1, xp3, (1 << 2) + (1 << 6)); // B- B+ D- D+
|
||||||
|
__m128 h02 = _mm_shuffle_ps(xp0, xp2, (1 << 2) + (1 << 6)); // A- A+ C- C+
|
||||||
|
|
||||||
|
__m128 u01 = _mm_unpacklo_ps(h02, h13); // A- B- A+ B+
|
||||||
|
__m128 u23 = _mm_unpackhi_ps(h02, h13); // C- D- C+ D+
|
||||||
|
__m128 u45 = _mm_unpacklo_ps(h46, h57); // E- F- E+ F+
|
||||||
|
__m128 u67 = _mm_unpackhi_ps(h46, h57); // G- H- G+ H+
|
||||||
|
|
||||||
|
__m128 abcdm = _mm_shuffle_ps(u01, u23, (0) + (1 << 2) + (0 << 4) + (1 << 6)); // A- B- C- D-
|
||||||
|
__m128 abcdp = _mm_shuffle_ps(u01, u23, (2) + (3 << 2) + (2 << 4) + (3 << 6)); // A+ B+ C+ D+
|
||||||
|
__m128 efghm = _mm_shuffle_ps(u45, u67, (0) + (1 << 2) + (0 << 4) + (1 << 6)); // E- F- G- H-
|
||||||
|
__m128 efghp = _mm_shuffle_ps(u45, u67, (2) + (3 << 2) + (2 << 4) + (3 << 6)); // E+ F+ G+ H+
|
||||||
|
|
||||||
|
FVec<AVX, float> vxp = _mm256_insertf128_ps(_mm256_castps128_ps256(abcdm), efghm, 1);
|
||||||
|
FVec<AVX, float> vxm = _mm256_insertf128_ps(_mm256_castps128_ps256(abcdp), efghp, 1);
|
||||||
|
|
||||||
|
IVec<AVX, float> ip(u.vec);
|
||||||
|
|
||||||
|
#else // use __mm256_set_pd
|
||||||
|
|
||||||
|
// read pairs ( X(t-1), X(t) )
|
||||||
|
__m256 x0145 = _mm256_castpd_ps(_mm256_set_pd(*p5, *p4, *p1, *p0)); // { x0(t-1), x0(t), x1(t-1), x1(t), x4(t-1), x4(t), x5(t-1), x5(t) }
|
||||||
|
__m256 x2367 = _mm256_castpd_ps(_mm256_set_pd(*p7, *p6, *p3, *p2)); // { x2(t-1), x2(t), x3(t-1), x3(t), x6(t-1), x6(t), x7(t-1), x7(t) }
|
||||||
|
|
||||||
|
// { x0(t-1), x1(t-1), x2(t-1), 3(t-1, x4(t-1), x5(t-1), x6(t-1), xt(t-1) }
|
||||||
|
FVec<AVX, float> vxm = _mm256_shuffle_ps(x0145, x2367, 0 + (2 << 2) + (0 << 4) + (2 << 6) );
|
||||||
|
// { x0(t), x1(t), x2(t), 3(t, x4(t), x5(t), x6(t), xt(t) }
|
||||||
|
FVec<AVX, float> vxp = _mm256_shuffle_ps(x0145, x2367, 1 + (3 << 2) + (1 << 4) + (3 << 6) );
|
||||||
|
|
||||||
|
IVec<AVX, float> ip(u.vec);
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
IVec<AVX, float> vlem = vz < vxm;
|
||||||
|
IVec<AVX, float> vlep = vz < vxp;
|
||||||
|
ip = ip + vlem + vlep;
|
||||||
|
|
||||||
|
ip.store(pr);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
FORCE_INLINE
|
||||||
|
//NO_INLINE
|
||||||
|
void resolve(const FVec<AVX, double>& vz, const IVec<SSE, float>& bidx, uint32 *pr) const
|
||||||
|
{
|
||||||
|
union {
|
||||||
|
__m256i vec;
|
||||||
|
uint64 ui64[4];
|
||||||
|
} u;
|
||||||
|
|
||||||
|
const uint32* buckets = reinterpret_cast<const uint32 *>(base_t::data.buckets);
|
||||||
|
const double *xi = base_t::data.xi;
|
||||||
|
|
||||||
|
// read indices t
|
||||||
|
const double *p3 = &xi[(u.ui64[3] = buckets[bidx.get3()])];
|
||||||
|
const double *p2 = &xi[(u.ui64[2] = buckets[bidx.get2()])];
|
||||||
|
const double *p1 = &xi[(u.ui64[1] = buckets[bidx.get1()])];
|
||||||
|
const double *p0 = &xi[(u.ui64[0] = buckets[bidx.get0()])];
|
||||||
|
|
||||||
|
// read pairs ( X(t-1), X(t) )
|
||||||
|
__m128d xp3 = _mm_loadu_pd(p3);
|
||||||
|
__m128d xp2 = _mm_loadu_pd(p2);
|
||||||
|
__m128d xp1 = _mm_loadu_pd(p1);
|
||||||
|
__m128d xp0 = _mm_loadu_pd(p0);
|
||||||
|
|
||||||
|
// build:
|
||||||
|
// { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) }
|
||||||
|
// { X(t(0)), X(t(1)), X(t(2)), X(t(3)) }
|
||||||
|
__m256d x02 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp0), xp2, 1);
|
||||||
|
__m256d x13 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp1), xp3, 1);
|
||||||
|
FVec<AVX, double> vxm = _mm256_unpacklo_pd(x02,x13);
|
||||||
|
FVec<AVX, double> vxp = _mm256_unpackhi_pd(x02,x13);
|
||||||
|
|
||||||
|
|
||||||
|
// __m128d h01m = _mm_shuffle_pd(xp0, xp1, 0);
|
||||||
|
// __m128d h23m = _mm_shuffle_pd(xp2, xp3, 0);
|
||||||
|
// __m128d h01p = _mm_shuffle_pd(xp0, xp1, 3);
|
||||||
|
// __m128d h23p = _mm_shuffle_pd(xp2, xp3, 3);
|
||||||
|
// FVec<AVX, double> vxm = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01m), h23m, 1);
|
||||||
|
// FVec<AVX, double> vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1);
|
||||||
|
|
||||||
|
IVec<AVX, double> i(u.vec);
|
||||||
|
IVec<AVX, double> vlem = vz < vxm;
|
||||||
|
IVec<AVX, double> vlep = vz < vxp;
|
||||||
|
i = i + vlem + vlep;
|
||||||
|
i.extractLo32s().store(pr);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
public:
|
||||||
|
|
||||||
|
AlgoVecBase(const T* x, const uint32 n) : base_t(x, n) {}
|
||||||
|
|
||||||
|
void initConstants(Constants& cst) const
|
||||||
|
{
|
||||||
|
cst.vscaler.setN(base_t::data.scaler);
|
||||||
|
cst.vcst0.setN(base_t::data.cst0);
|
||||||
|
cst.one.setN(uint32(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
void vectorial(uint32 *pr, const T *pz, const Constants& cst) const
|
||||||
|
{
|
||||||
|
fVec vz(pz);
|
||||||
|
resolve(vz, base_t::fun_t::f(cst.vscaler, cst.vcst0, vz), pr);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace Details
|
||||||
|
} // namespace BinSearch
|
23
include/AlgoXCodes.h
Normal file
23
include/AlgoXCodes.h
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
ALGOENUM(DirectCacheFMA, 5)
|
||||||
|
ALGOENUM(DirectFMA, 15)
|
||||||
|
ALGOENUM(Direct2FMA, 25)
|
||||||
|
ALGOENUM(DirectCache, 10)
|
||||||
|
ALGOENUM(Direct, 20)
|
||||||
|
ALGOENUM(Direct2, 30)
|
||||||
|
ALGOENUM(Nonary, 40)
|
||||||
|
ALGOENUM(Pentary, 50)
|
||||||
|
ALGOENUM(Ternary, 60)
|
||||||
|
ALGOENUM(Eytzinger, 70)
|
||||||
|
ALGOENUM(BitSet, 80)
|
||||||
|
ALGOENUM(ClassicOffset, 90)
|
||||||
|
#ifdef PAPER_TEST
|
||||||
|
ALGOENUM(MorinOffset, 100)
|
||||||
|
ALGOENUM(BitSetNoPad, 110)
|
||||||
|
ALGOENUM(ClassicMod, 120)
|
||||||
|
ALGOENUM(MorinBranchy, 130)
|
||||||
|
ALGOENUM(Classic, 140)
|
||||||
|
ALGOENUM(LowerBound, 145)
|
||||||
|
#ifdef USE_MKL
|
||||||
|
ALGOENUM(MKL, 150)
|
||||||
|
#endif
|
||||||
|
#endif
|
77
include/BinAlgo.h
Normal file
77
include/BinAlgo.h
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "Type.h"
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
namespace BinSearch {
|
||||||
|
|
||||||
|
template <InstrSet I, typename T, Algos A, bool L=false, bool R=false>
|
||||||
|
struct BinAlgo : Details::BinAlgoBase<I,T,A>
|
||||||
|
{
|
||||||
|
typedef Details::BinAlgoBase<I,T,A> base_t;
|
||||||
|
|
||||||
|
BinAlgo(const T* px, const uint32 n) : base_t(px, n), x0(px[0]), xN(px[n-1]), N(n) {}
|
||||||
|
BinAlgo(const T* px, const uint32 n, const typename base_t::Data& d) : base_t(d), x0(px[0]), xN(px[n-1]), N(n) {}
|
||||||
|
|
||||||
|
FORCE_INLINE
|
||||||
|
uint32 scalar(T z) const
|
||||||
|
{
|
||||||
|
if (!L || z >= x0)
|
||||||
|
if (!R || z < xN)
|
||||||
|
return base_t::scalar(z);
|
||||||
|
else
|
||||||
|
return N;
|
||||||
|
else
|
||||||
|
return std::numeric_limits<uint32>::max();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
FORCE_INLINE
|
||||||
|
void vectorial(uint32 *pr, const T *pz, uint32 n) const
|
||||||
|
{
|
||||||
|
if (!L && !R) {
|
||||||
|
Details::Loop<T,base_t>::loop(*this, pr, pz, n);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
const uint32 nElem = base_t::nElem;
|
||||||
|
const uint32 idealbufsize = 256;
|
||||||
|
const uint32 bufsize = nElem * (idealbufsize / nElem + ((idealbufsize % nElem) ? 1 : 0));
|
||||||
|
T databuf[bufsize];
|
||||||
|
uint32 resbuf[bufsize];
|
||||||
|
uint32 indexbuf[bufsize];
|
||||||
|
|
||||||
|
uint32 *prend = pr + n;
|
||||||
|
while(pr != prend) {
|
||||||
|
uint32 cnt = 0;
|
||||||
|
uint32 niter = std::min(bufsize, (uint32)std::distance(pr,prend));
|
||||||
|
for (uint32 j = 0; j < niter; ++j) {
|
||||||
|
T z = pz[j];
|
||||||
|
// FIXME: use SSE2?
|
||||||
|
if (!L || z >= x0)
|
||||||
|
if (!R || z < xN) {
|
||||||
|
databuf[cnt] = z;
|
||||||
|
indexbuf[cnt] = j;
|
||||||
|
++cnt;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
pr[j] = N;
|
||||||
|
else
|
||||||
|
pr[j] = std::numeric_limits<uint32>::max();
|
||||||
|
}
|
||||||
|
// FIXME: merge these two loops
|
||||||
|
Details::Loop<T,base_t>::loop(*this, resbuf, databuf, cnt);
|
||||||
|
for (uint32 j = 0; j < cnt; ++j)
|
||||||
|
pr[indexbuf[j]] = resbuf[j];
|
||||||
|
pr += niter;
|
||||||
|
pz += niter;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Details::CondData<T,L> x0;
|
||||||
|
Details::CondData<T,R> xN;
|
||||||
|
Details::CondData<uint32,R> N;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
} // namespace BinSearch
|
11
include/BinSearch.h
Normal file
11
include/BinSearch.h
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "AAlloc.h"
|
||||||
|
#include "BinAlgo.h"
|
||||||
|
#include "SIMD.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
|
||||||
|
#include "Algo-Direct2.h"
|
151
include/Portable.h
Normal file
151
include/Portable.h
Normal file
|
@ -0,0 +1,151 @@
|
||||||
|
#pragma once
|
||||||
|
#include <limits>
|
||||||
|
#include <cmath>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
|
#ifdef __FMA__
|
||||||
|
#define USE_FMA
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef __AVX2__
|
||||||
|
#define USE_AVX2
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef __AVX__
|
||||||
|
#define USE_AVX
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
#ifdef __SSE4_1__
|
||||||
|
#define USE_SSE41
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef __SSE4_2__
|
||||||
|
#define USE_SSE42
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
#ifndef _MSC_VER
|
||||||
|
#include <stdint.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace BinSearch {
|
||||||
|
|
||||||
|
#ifndef _MSC_VER
|
||||||
|
typedef int8_t int8;
|
||||||
|
typedef uint8_t uint8;
|
||||||
|
typedef int32_t int32;
|
||||||
|
typedef uint32_t uint32;
|
||||||
|
typedef int64_t int64;
|
||||||
|
typedef uint64_t uint64;
|
||||||
|
#else
|
||||||
|
typedef __int8 int8;
|
||||||
|
typedef unsigned __int8 uint8;
|
||||||
|
typedef __int32 int32;
|
||||||
|
typedef unsigned __int32 uint32;
|
||||||
|
typedef __int64 int64;
|
||||||
|
typedef unsigned __int64 uint64;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace Details {
|
||||||
|
|
||||||
|
#define myassert(cond, msg) if (!cond){ std::ostringstream os; os << "\nassertion failed: " << #cond << ", " << msg << "\n"; throw std::invalid_argument(os.str()); }
|
||||||
|
|
||||||
|
// log2 is not defined in VS2008
|
||||||
|
#if defined(_MSC_VER)
|
||||||
|
inline uint32 log2 (uint32 val) {
|
||||||
|
if (val == 1) return 0;
|
||||||
|
uint32 ret = 0;
|
||||||
|
do {
|
||||||
|
ret++;
|
||||||
|
val >>= 1;
|
||||||
|
} while (val > 1);
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef _DEBUG
|
||||||
|
#define DEBUG
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
# define FORCE_INLINE __forceinline
|
||||||
|
# define NO_INLINE __declspec(noinline)
|
||||||
|
#else
|
||||||
|
# define NO_INLINE __attribute__((noinline))
|
||||||
|
# ifdef DEBUG
|
||||||
|
# define FORCE_INLINE NO_INLINE
|
||||||
|
# else
|
||||||
|
# define FORCE_INLINE __attribute__((always_inline)) inline
|
||||||
|
# endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef USE_AVX
|
||||||
|
#define COMISS "vcomiss"
|
||||||
|
#define COMISD "vcomisd"
|
||||||
|
#else
|
||||||
|
#define COMISS "comiss"
|
||||||
|
#define COMISD "comisd"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// nextafter is not defined in VS2008
|
||||||
|
#if defined(_MSC_VER) && (_MSC_VER <= 1500)
|
||||||
|
#include <float.h>
|
||||||
|
inline float mynext(float x)
|
||||||
|
{
|
||||||
|
return _nextafterf(x, std::numeric_limits<float>::max());
|
||||||
|
}
|
||||||
|
|
||||||
|
inline double mynext(double x)
|
||||||
|
{
|
||||||
|
return _nextafter(x, std::numeric_limits<double>::max());
|
||||||
|
}
|
||||||
|
inline float myprev(float x)
|
||||||
|
{
|
||||||
|
return _nextafterf(x, -std::numeric_limits<float>::max());
|
||||||
|
}
|
||||||
|
|
||||||
|
inline double myprev(double x)
|
||||||
|
{
|
||||||
|
return _nextafter(x, -std::numeric_limits<double>::max());
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
inline float mynext(float x)
|
||||||
|
{
|
||||||
|
return std::nextafterf(x, std::numeric_limits<float>::max());
|
||||||
|
}
|
||||||
|
|
||||||
|
inline double mynext(double x)
|
||||||
|
{
|
||||||
|
return std::nextafter(x, std::numeric_limits<double>::max());
|
||||||
|
}
|
||||||
|
inline float myprev(float x)
|
||||||
|
{
|
||||||
|
return std::nextafterf(x, -std::numeric_limits<float>::max());
|
||||||
|
}
|
||||||
|
|
||||||
|
inline double myprev(double x)
|
||||||
|
{
|
||||||
|
return std::nextafter(x, -std::numeric_limits<double>::max());
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline T next(T x)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < 4; ++i)
|
||||||
|
x = mynext(x);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline T prev(T x)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < 4; ++i)
|
||||||
|
x = myprev(x);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namepsace Details
|
||||||
|
} // namespace BinSearch
|
562
include/SIMD.h
Normal file
562
include/SIMD.h
Normal file
|
@ -0,0 +1,562 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "Portable.h"
|
||||||
|
|
||||||
|
#ifdef USE_SSE42
|
||||||
|
#ifndef _MSC_VER
|
||||||
|
#include <popcntintrin.h>
|
||||||
|
#define popcnt32 _mm_popcnt_u32
|
||||||
|
#else
|
||||||
|
#include <intrin.h>
|
||||||
|
#define popcnt32 __popcnt
|
||||||
|
#endif
|
||||||
|
#else // USE_SSE42
|
||||||
|
namespace BinSearch {
|
||||||
|
FORCE_INLINE int popcnt32(int x32)
|
||||||
|
{
|
||||||
|
// strictly speaking this is not correct, as it ignores higher order bits
|
||||||
|
// however, this is only used on the resuot of movemask on a 128-bit register, which is 8 at most, so it is ok
|
||||||
|
// with 256-bit registers, SSE42 is defined, and we do not use this function
|
||||||
|
uint8 x = static_cast<uint8>(x32);
|
||||||
|
x = (x & 0x55) + (x >> 1 & 0x55);
|
||||||
|
x = (x & 0x33) + (x >> 2 & 0x33);
|
||||||
|
x = (x & 0x0f) + (x >> 4 & 0x0f);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined(USE_AVX) || defined(USE_AVX2)
|
||||||
|
#include <immintrin.h>
|
||||||
|
#else
|
||||||
|
#include <emmintrin.h>
|
||||||
|
#ifdef USE_SSE41
|
||||||
|
#include <smmintrin.h>
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "Type.h"
|
||||||
|
|
||||||
|
namespace BinSearch {
|
||||||
|
namespace Details {
|
||||||
|
|
||||||
|
template <InstrSet I, class T>
|
||||||
|
struct FVec;
|
||||||
|
|
||||||
|
template <InstrSet I, class T>
|
||||||
|
struct IVec;
|
||||||
|
|
||||||
|
template <InstrSet I, class T>
|
||||||
|
struct FVec1;
|
||||||
|
|
||||||
|
template <> struct InstrIntTraits<SSE>
|
||||||
|
{
|
||||||
|
typedef __m128i vec_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <> struct InstrFloatTraits<SSE, float>
|
||||||
|
{
|
||||||
|
typedef __m128 vec_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <> struct InstrFloatTraits<SSE, double>
|
||||||
|
{
|
||||||
|
typedef __m128d vec_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <InstrSet I, typename T>
|
||||||
|
struct FTOITraits
|
||||||
|
{
|
||||||
|
typedef IVec<SSE, float> vec_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifdef USE_AVX
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct FTOITraits<AVX, float>
|
||||||
|
{
|
||||||
|
typedef IVec<AVX, float> vec_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <> struct InstrIntTraits<AVX>
|
||||||
|
{
|
||||||
|
typedef __m256i vec_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <> struct InstrFloatTraits<AVX, float>
|
||||||
|
{
|
||||||
|
typedef __m256 vec_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <> struct InstrFloatTraits<AVX, double>
|
||||||
|
{
|
||||||
|
typedef __m256d vec_t;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
template <typename TR>
|
||||||
|
struct VecStorage
|
||||||
|
{
|
||||||
|
typedef typename TR::vec_t vec_t;
|
||||||
|
|
||||||
|
FORCE_INLINE operator vec_t&() { return vec; }
|
||||||
|
FORCE_INLINE operator const vec_t&() const { return vec; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
FORCE_INLINE VecStorage() {}
|
||||||
|
FORCE_INLINE VecStorage(const vec_t& v) : vec( v ) {}
|
||||||
|
|
||||||
|
vec_t vec;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <InstrSet>
|
||||||
|
struct IVecBase;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct IVecBase<SSE> : VecStorage<InstrIntTraits<SSE>>
|
||||||
|
{
|
||||||
|
protected:
|
||||||
|
FORCE_INLINE IVecBase() {}
|
||||||
|
FORCE_INLINE IVecBase( const vec_t& v) : VecStorage<InstrIntTraits<SSE>>( v ) {}
|
||||||
|
public:
|
||||||
|
FORCE_INLINE static vec_t zero() { return _mm_setzero_si128(); }
|
||||||
|
|
||||||
|
FORCE_INLINE int32 get0() const { return _mm_cvtsi128_si32( vec ); }
|
||||||
|
|
||||||
|
FORCE_INLINE void assignIf( const vec_t& val, const vec_t& mask )
|
||||||
|
{
|
||||||
|
#ifdef USE_SSE41
|
||||||
|
vec = _mm_blendv_epi8(vec, val, mask);
|
||||||
|
#else
|
||||||
|
vec = _mm_or_si128(_mm_andnot_si128(mask,vec), _mm_and_si128(mask,val));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
FORCE_INLINE void orIf(const vec_t& val, const vec_t& mask)
|
||||||
|
{
|
||||||
|
vec = _mm_or_si128(vec, _mm_and_si128(val,mask));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct IVec<SSE, float> : IVecBase<SSE>
|
||||||
|
{
|
||||||
|
FORCE_INLINE IVec() {}
|
||||||
|
FORCE_INLINE IVec( int32 i ) : IVecBase<SSE>( _mm_set1_epi32( i ) ) {}
|
||||||
|
FORCE_INLINE IVec( const vec_t& v) : IVecBase<SSE>( v ) {}
|
||||||
|
FORCE_INLINE IVec( uint32 u3, uint32 u2, uint32 u1, uint32 u0) : IVecBase<SSE>( _mm_set_epi32( u3, u2, u1, u0 ) ) {}
|
||||||
|
|
||||||
|
void setN( int32 i ) { vec = _mm_set1_epi32( i ); }
|
||||||
|
|
||||||
|
#ifdef USE_SSE41
|
||||||
|
FORCE_INLINE int32 get1() const { return _mm_extract_epi32(vec, 1); }
|
||||||
|
FORCE_INLINE int32 get2() const { return _mm_extract_epi32(vec, 2); }
|
||||||
|
FORCE_INLINE int32 get3() const { return _mm_extract_epi32(vec, 3); }
|
||||||
|
#else
|
||||||
|
FORCE_INLINE int32 get1() const { return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 1 ) ); }
|
||||||
|
FORCE_INLINE int32 get2() const { return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 2 ) ); }
|
||||||
|
FORCE_INLINE int32 get3() const { return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 3 ) ); }
|
||||||
|
#endif
|
||||||
|
|
||||||
|
FORCE_INLINE void store( uint32 *pi ) const { _mm_storeu_si128( reinterpret_cast<vec_t*>(pi), vec ); }
|
||||||
|
|
||||||
|
FORCE_INLINE int countbit()
|
||||||
|
{
|
||||||
|
return popcnt32(_mm_movemask_ps(_mm_castsi128_ps(vec)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct IVec<SSE, double> : IVecBase<SSE>
|
||||||
|
{
|
||||||
|
FORCE_INLINE IVec() {}
|
||||||
|
FORCE_INLINE IVec( int32 i ) : IVecBase<SSE>( _mm_set1_epi64x( i ) ) {}
|
||||||
|
FORCE_INLINE IVec( const vec_t& v) : IVecBase<SSE>( v ) {}
|
||||||
|
FORCE_INLINE IVec( uint64 u1, uint64 u0 ) : IVecBase<SSE>( _mm_set_epi64x(u1, u0) ) {}
|
||||||
|
|
||||||
|
void setN( int32 i ) { vec = _mm_set1_epi64x( i ); }
|
||||||
|
|
||||||
|
FORCE_INLINE int32 get1() const
|
||||||
|
{
|
||||||
|
#ifdef USE_SSE41
|
||||||
|
return _mm_extract_epi32(vec, 2);
|
||||||
|
#else
|
||||||
|
return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 2 ) );
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract the 2 32 bits integers no. 0, 2 and store them in a __m128i
|
||||||
|
FORCE_INLINE IVec<SSE,float> extractLo32s() const
|
||||||
|
{
|
||||||
|
return _mm_shuffle_epi32(vec, ((2 << 2) | 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE void store( uint32 *pi ) const
|
||||||
|
{
|
||||||
|
pi[0] = get0();
|
||||||
|
pi[1] = get1();
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE int countbit()
|
||||||
|
{
|
||||||
|
#if 1
|
||||||
|
// takes 4 cycles
|
||||||
|
__m128i hi = _mm_shuffle_epi32(vec, 2); // 1 cycle
|
||||||
|
__m128i s = _mm_add_epi32(vec, hi);
|
||||||
|
int32 x = _mm_cvtsi128_si32(s);
|
||||||
|
return -x;
|
||||||
|
#else
|
||||||
|
// takes 6 cycles
|
||||||
|
return popcnt32(_mm_movemask_pd(_mm_castsi128_pd(vec)));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FORCE_INLINE IVec<SSE,T> operator>> (const IVec<SSE,T>& a, unsigned n) { return _mm_srli_epi32(a, n); }
|
||||||
|
template <typename T>
|
||||||
|
FORCE_INLINE IVec<SSE,T> operator<< (const IVec<SSE,T>& a, unsigned n) { return _mm_slli_epi32(a, n); }
|
||||||
|
template <typename T>
|
||||||
|
FORCE_INLINE IVec<SSE,T> operator& (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_and_si128( a, b ); }
|
||||||
|
template <typename T>
|
||||||
|
FORCE_INLINE IVec<SSE,T> operator| (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_or_si128( a, b ); }
|
||||||
|
template <typename T>
|
||||||
|
FORCE_INLINE IVec<SSE,T> operator^ (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_xor_si128( a, b ); }
|
||||||
|
template <typename T>
|
||||||
|
FORCE_INLINE IVec<SSE,T> operator+ (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_add_epi32( a, b ); }
|
||||||
|
template <typename T>
|
||||||
|
FORCE_INLINE IVec<SSE,T> operator- (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_sub_epi32( a, b ); }
|
||||||
|
#ifdef USE_SSE41
|
||||||
|
template <typename T>
|
||||||
|
FORCE_INLINE IVec<SSE,T> min (const IVec<SSE,T>& a, const IVec<SSE,T>& b ) { return _mm_min_epi32( a, b ); }
|
||||||
|
#endif
|
||||||
|
|
||||||
|
typedef VecStorage<InstrFloatTraits<SSE,float>> FVec128Float;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct FVec1<SSE, float> : FVec128Float
|
||||||
|
{
|
||||||
|
FORCE_INLINE FVec1() {}
|
||||||
|
FORCE_INLINE FVec1( float f ) : FVec128Float( _mm_load_ss( &f ) ) {}
|
||||||
|
FORCE_INLINE FVec1( const vec_t& v ): FVec128Float( v ) {}
|
||||||
|
|
||||||
|
FORCE_INLINE float get0() const { return _mm_cvtss_f32( vec ); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct FVec<SSE, float> : FVec128Float
|
||||||
|
{
|
||||||
|
FORCE_INLINE FVec() {}
|
||||||
|
FORCE_INLINE FVec( float f ) : FVec128Float( _mm_set1_ps( f ) ) {}
|
||||||
|
FORCE_INLINE FVec( const float *v ) : FVec128Float( _mm_loadu_ps( v ) ) {}
|
||||||
|
FORCE_INLINE FVec( const vec_t& v) : FVec128Float(v) {}
|
||||||
|
FORCE_INLINE FVec( float f3, float f2, float f1, float f0 ) : FVec128Float( _mm_set_ps(f3, f2, f1, f0) ) {}
|
||||||
|
|
||||||
|
void set0( float f ) { vec = _mm_load_ss( &f ); }
|
||||||
|
void setN( float f ) { vec = _mm_set1_ps( f ); }
|
||||||
|
|
||||||
|
FORCE_INLINE void setidx( const float *xi, const IVec<SSE,float>& idx )
|
||||||
|
{
|
||||||
|
uint32 i0 = idx.get0();
|
||||||
|
uint32 i1 = idx.get1();
|
||||||
|
uint32 i2 = idx.get2();
|
||||||
|
uint32 i3 = idx.get3();
|
||||||
|
vec = _mm_set_ps( xi[i3], xi[i2], xi[i1], xi[i0] );
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE float get0() const { return _mm_cvtss_f32( vec ); }
|
||||||
|
FORCE_INLINE float get1() const { return _mm_cvtss_f32( _mm_shuffle_ps( vec, vec, 1 ) ); }
|
||||||
|
FORCE_INLINE float get2() const { return _mm_cvtss_f32( _mm_shuffle_ps( vec, vec, 2 ) ); }
|
||||||
|
FORCE_INLINE float get3() const { return _mm_cvtss_f32( _mm_shuffle_ps( vec, vec, 3 ) ); }
|
||||||
|
};
|
||||||
|
|
||||||
|
FORCE_INLINE FVec1<SSE,float> operator+ (const FVec1<SSE,float>& a, const FVec1<SSE,float>& b) { return _mm_add_ss( a, b ); }
|
||||||
|
FORCE_INLINE FVec1<SSE,float> operator- (const FVec1<SSE,float>& a, const FVec1<SSE,float>& b) { return _mm_sub_ss( a, b ); }
|
||||||
|
FORCE_INLINE FVec1<SSE,float> operator* (const FVec1<SSE,float>& a, const FVec1<SSE,float>& b) { return _mm_mul_ss( a, b ); }
|
||||||
|
FORCE_INLINE FVec1<SSE,float> operator/ (const FVec1<SSE,float>& a, const FVec1<SSE,float>& b) { return _mm_div_ss( a, b ); }
|
||||||
|
FORCE_INLINE int ftoi (const FVec1<SSE,float>& a) { return _mm_cvttss_si32(a); }
|
||||||
|
FORCE_INLINE IVec<SSE,float> operator> (const FVec1<SSE,float>& a, const FVec1<SSE,float>& b) { return _mm_castps_si128( _mm_cmpgt_ss( a, b ) ); }
|
||||||
|
#ifdef USE_FMA
|
||||||
|
FORCE_INLINE FVec1<SSE, float> mulSub(const FVec1<SSE, float>& a, const FVec1<SSE, float>& b, const FVec1<SSE, float>& c) { return _mm_fmsub_ss(a, b, c); }
|
||||||
|
#endif
|
||||||
|
|
||||||
|
FORCE_INLINE FVec<SSE,float> operator- (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_sub_ps( a, b ); }
|
||||||
|
FORCE_INLINE FVec<SSE,float> operator* (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_mul_ps( a, b ); }
|
||||||
|
FORCE_INLINE FVec<SSE,float> operator/ (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_div_ps( a, b ); }
|
||||||
|
FORCE_INLINE IVec<SSE,float> ftoi (const FVec<SSE,float>& a) { return _mm_cvttps_epi32(a); }
|
||||||
|
FORCE_INLINE IVec<SSE,float> operator<= (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_castps_si128( _mm_cmple_ps( a, b ) ); }
|
||||||
|
FORCE_INLINE IVec<SSE,float> operator>= (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_castps_si128( _mm_cmpge_ps( a, b ) ); }
|
||||||
|
FORCE_INLINE IVec<SSE,float> operator< (const FVec<SSE,float>& a, const FVec<SSE,float>& b) { return _mm_castps_si128(_mm_cmplt_ps(a, b)); }
|
||||||
|
#ifdef USE_FMA
|
||||||
|
FORCE_INLINE FVec<SSE, float> mulSub(const FVec<SSE, float>& a, const FVec<SSE, float>& b, const FVec<SSE, float>& c) { return _mm_fmsub_ps(a, b, c); }
|
||||||
|
#endif
|
||||||
|
|
||||||
|
typedef VecStorage<InstrFloatTraits<SSE,double>> FVec128Double;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct FVec1<SSE, double> : FVec128Double
|
||||||
|
{
|
||||||
|
FORCE_INLINE FVec1() {}
|
||||||
|
FORCE_INLINE FVec1( double f ) : FVec128Double( _mm_load_sd( &f ) ) {}
|
||||||
|
FORCE_INLINE FVec1( const vec_t& v ) : FVec128Double( v ) {}
|
||||||
|
|
||||||
|
FORCE_INLINE double get0() const { return _mm_cvtsd_f64( vec ); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct FVec<SSE, double> : FVec128Double
|
||||||
|
{
|
||||||
|
FORCE_INLINE FVec() {}
|
||||||
|
FORCE_INLINE FVec( double d ) : FVec128Double( _mm_set1_pd( d ) ) {}
|
||||||
|
FORCE_INLINE FVec( const double *v ) : FVec128Double( _mm_loadu_pd( v ) ) {}
|
||||||
|
FORCE_INLINE FVec( const vec_t& v) : FVec128Double( v ) {}
|
||||||
|
FORCE_INLINE FVec( double f1, double f0 ) : FVec128Double( _mm_set_pd(f1, f0) ) {}
|
||||||
|
|
||||||
|
void set0( double f ) { vec = _mm_load_sd( &f ); }
|
||||||
|
void setN( double f ) { vec = _mm_set1_pd( f ); }
|
||||||
|
|
||||||
|
FORCE_INLINE void setidx( const double *xi, const IVec<SSE,double>& idx )
|
||||||
|
{
|
||||||
|
vec = _mm_set_pd( xi[idx.get1()], xi[idx.get0()] );
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE double get0() const { return _mm_cvtsd_f64( vec ); }
|
||||||
|
FORCE_INLINE double get1() const { return _mm_cvtsd_f64( _mm_shuffle_pd( vec, vec, 1 ) ); };
|
||||||
|
};
|
||||||
|
|
||||||
|
FORCE_INLINE FVec1<SSE,double> operator+ (const FVec1<SSE,double>& a, const FVec1<SSE,double>& b) { return _mm_add_sd( a, b ); }
|
||||||
|
FORCE_INLINE FVec1<SSE,double> operator- (const FVec1<SSE,double>& a, const FVec1<SSE,double>& b) { return _mm_sub_sd( a, b ); }
|
||||||
|
FORCE_INLINE FVec1<SSE,double> operator* (const FVec1<SSE,double>& a, const FVec1<SSE,double>& b) { return _mm_mul_sd( a, b ); }
|
||||||
|
FORCE_INLINE FVec1<SSE,double> operator/ (const FVec1<SSE,double>& a, const FVec1<SSE,double>& b) { return _mm_div_sd( a, b ); }
|
||||||
|
FORCE_INLINE int ftoi (const FVec1<SSE,double>& a) { return _mm_cvttsd_si32(a); }
|
||||||
|
FORCE_INLINE IVec<SSE,double> operator> (const FVec1<SSE,double>& a, const FVec1<SSE,double>& b) { return _mm_castpd_si128( _mm_cmpgt_sd( a, b ) ); }
|
||||||
|
#ifdef USE_FMA
|
||||||
|
FORCE_INLINE FVec1<SSE, double> mulSub(const FVec1<SSE, double>& a, const FVec1<SSE, double>& b, const FVec1<SSE, double>& c) { return _mm_fmsub_sd(a, b, c); }
|
||||||
|
#endif
|
||||||
|
|
||||||
|
FORCE_INLINE FVec<SSE,double> operator- (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_sub_pd( a, b ); }
|
||||||
|
FORCE_INLINE FVec<SSE,double> operator* (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_mul_pd( a, b ); }
|
||||||
|
FORCE_INLINE FVec<SSE,double> operator/ (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_div_pd( a, b ); }
|
||||||
|
FORCE_INLINE IVec<SSE,float> ftoi (const FVec<SSE,double>& a) { return _mm_cvttpd_epi32(a); }
|
||||||
|
FORCE_INLINE IVec<SSE,double> operator<= (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_castpd_si128( _mm_cmple_pd( a, b ) ); }
|
||||||
|
FORCE_INLINE IVec<SSE,double> operator< (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_castpd_si128(_mm_cmplt_pd(a, b)); }
|
||||||
|
FORCE_INLINE IVec<SSE,double> operator>= (const FVec<SSE,double>& a, const FVec<SSE,double>& b) { return _mm_castpd_si128( _mm_cmpge_pd( a, b ) ); }
|
||||||
|
#ifdef USE_FMA
|
||||||
|
FORCE_INLINE FVec<SSE, double> mulSub(const FVec<SSE, double>& a, const FVec<SSE, double>& b, const FVec<SSE, double>& c ) { return _mm_fmsub_pd(a, b, c); }
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef USE_AVX
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct IVecBase<AVX> : VecStorage<InstrIntTraits<AVX>>
|
||||||
|
{
|
||||||
|
protected:
|
||||||
|
FORCE_INLINE IVecBase() {}
|
||||||
|
FORCE_INLINE IVecBase( const vec_t& v) : VecStorage<InstrIntTraits<AVX>>( v ) {}
|
||||||
|
public:
|
||||||
|
FORCE_INLINE static vec_t zero() { return _mm256_setzero_si256(); }
|
||||||
|
|
||||||
|
FORCE_INLINE int32 get0() const { return _mm_cvtsi128_si32(_mm256_castsi256_si128(vec)); }
|
||||||
|
|
||||||
|
FORCE_INLINE void assignIf( const vec_t& val, const vec_t& mask ) { vec = _mm256_blendv_epi8(vec, val, mask); }
|
||||||
|
FORCE_INLINE void orIf(const vec_t& val, const vec_t& mask)
|
||||||
|
{
|
||||||
|
vec = _mm256_blendv_epi8(vec, val, mask);
|
||||||
|
//vec = _mm256_or_si256(vec, _mm256_and_si256(val,mask));
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE __m128i lo128() const { return _mm256_castsi256_si128(vec); }
|
||||||
|
FORCE_INLINE __m128i hi128() const { return _mm256_extractf128_si256(vec, 1); }
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct IVec<AVX, float> : IVecBase<AVX>
|
||||||
|
{
|
||||||
|
FORCE_INLINE IVec() {}
|
||||||
|
FORCE_INLINE IVec( int32 i ) : IVecBase<AVX>( _mm256_set1_epi32( i ) ) {}
|
||||||
|
FORCE_INLINE IVec( const vec_t& v) : IVecBase<AVX>( v ) {}
|
||||||
|
FORCE_INLINE IVec(uint32 u7, uint32 u6, uint32 u5, uint32 u4, uint32 u3, uint32 u2, uint32 u1, uint32 u0) : IVecBase<AVX>(_mm256_set_epi32(u7, u6, u5, u4, u3, u2, u1, u0)) {}
|
||||||
|
|
||||||
|
void setN( int32 i ) { vec = _mm256_set1_epi32( i ); }
|
||||||
|
|
||||||
|
FORCE_INLINE int32 get1() const { return _mm256_extract_epi32(vec, 1); }
|
||||||
|
FORCE_INLINE int32 get2() const { return _mm256_extract_epi32(vec, 2); }
|
||||||
|
FORCE_INLINE int32 get3() const { return _mm256_extract_epi32(vec, 3); }
|
||||||
|
FORCE_INLINE int32 get4() const { return _mm256_extract_epi32(vec, 4); }
|
||||||
|
FORCE_INLINE int32 get5() const { return _mm256_extract_epi32(vec, 5); }
|
||||||
|
FORCE_INLINE int32 get6() const { return _mm256_extract_epi32(vec, 6); }
|
||||||
|
FORCE_INLINE int32 get7() const { return _mm256_extract_epi32(vec, 7); }
|
||||||
|
|
||||||
|
FORCE_INLINE void setidx( const uint32 *bi, const IVec<AVX,float>& idx )
|
||||||
|
{
|
||||||
|
vec = _mm256_i32gather_epi32(reinterpret_cast<const int32 *>(bi), idx, sizeof(uint32));
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE void store( uint32 *pi ) const { _mm256_storeu_si256( reinterpret_cast<vec_t*>(pi), vec ); }
|
||||||
|
|
||||||
|
FORCE_INLINE int countbit()
|
||||||
|
{
|
||||||
|
return popcnt32(_mm256_movemask_ps(_mm256_castsi256_ps(vec)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct IVec<AVX, double> : IVecBase<AVX>
|
||||||
|
{
|
||||||
|
FORCE_INLINE IVec() {}
|
||||||
|
FORCE_INLINE IVec( int32 i ) : IVecBase<AVX>( _mm256_set1_epi64x( i ) ) {}
|
||||||
|
FORCE_INLINE IVec( const vec_t& v) : IVecBase<AVX>( v ) {}
|
||||||
|
FORCE_INLINE IVec(uint64 u3, uint64 u2, uint64 u1, uint64 u0) : IVecBase<AVX>(_mm256_set_epi64x(u3, u2, u1, u0)) {}
|
||||||
|
|
||||||
|
void setN( int32 i ) { vec = _mm256_set1_epi64x( i ); }
|
||||||
|
|
||||||
|
// extract the 4 32 bits integers no. 0, 2, 4, 6 and store them in a __m128i
|
||||||
|
FORCE_INLINE IVec<SSE,float> extractLo32s() const
|
||||||
|
{
|
||||||
|
union {
|
||||||
|
uint32 u32[4];
|
||||||
|
__m128i u;
|
||||||
|
} mask = {0,2,4,6};
|
||||||
|
//__m256 ps256 = _mm256_castsi256_ps(vec);
|
||||||
|
//__m128 lo128 = _mm256_castps256_ps128(ps256);
|
||||||
|
//__m128 hi128 = _mm256_extractf128_ps(ps256, 1);
|
||||||
|
//__m128 blend = _mm_shuffle_ps(lo128, hi128, 0 + (2<<2) + (0<<4) + (2<<6));
|
||||||
|
__m256i blend = _mm256_permutevar8x32_epi32(vec, _mm256_castsi128_si256(mask.u));
|
||||||
|
return _mm256_castsi256_si128(blend);
|
||||||
|
}
|
||||||
|
|
||||||
|
//int32 get1() const { return _mm256_cvtsi256_si32( _mm256_shuffle_epi32( vec, 2 ) ); };
|
||||||
|
FORCE_INLINE int32 get1() const { return _mm256_extract_epi32(vec, 2); }
|
||||||
|
|
||||||
|
FORCE_INLINE void store( uint32 *pi ) const
|
||||||
|
{
|
||||||
|
extractLo32s().store(pi);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE int countbit()
|
||||||
|
{
|
||||||
|
return popcnt32(_mm256_movemask_pd(_mm256_castsi256_pd(vec)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
FORCE_INLINE IVec<AVX,T> operator>> (const IVec<AVX,T>& a, unsigned n) { return _mm256_srli_epi32(a, n); }
|
||||||
|
template <typename T>
|
||||||
|
FORCE_INLINE IVec<AVX,T> operator<< (const IVec<AVX,T>& a, unsigned n) { return _mm256_slli_epi32(a, n); }
|
||||||
|
template <typename T>
|
||||||
|
FORCE_INLINE IVec<AVX,T> operator& (const IVec<AVX,T>& a, const IVec<AVX,T>& b ) { return _mm256_and_si256( a, b ); }
|
||||||
|
template <typename T>
|
||||||
|
FORCE_INLINE IVec<AVX,T> operator| (const IVec<AVX,T>& a, const IVec<AVX,T>& b ) { return _mm256_or_si256( a, b ); }
|
||||||
|
template <typename T>
|
||||||
|
FORCE_INLINE IVec<AVX,T> operator^ (const IVec<AVX,T>& a, const IVec<AVX,T>& b ) { return _mm256_xor_si256( a, b ); }
|
||||||
|
template <typename T>
|
||||||
|
FORCE_INLINE IVec<AVX,T> min (const IVec<AVX,T>& a, const IVec<AVX,T>& b ) { return _mm256_min_epi32( a, b ); }
|
||||||
|
|
||||||
|
FORCE_INLINE IVec<AVX,float> operator+ (const IVec<AVX,float>& a, const IVec<AVX,float>& b ) { return _mm256_add_epi32( a, b ); }
|
||||||
|
FORCE_INLINE IVec<AVX,float> operator- (const IVec<AVX,float>& a, const IVec<AVX,float>& b ) { return _mm256_sub_epi32( a, b ); }
|
||||||
|
FORCE_INLINE IVec<AVX,double> operator+ (const IVec<AVX,double>& a, const IVec<AVX,double>& b ) { return _mm256_add_epi64( a, b ); }
|
||||||
|
FORCE_INLINE IVec<AVX,double> operator- (const IVec<AVX,double>& a, const IVec<AVX,double>& b ) { return _mm256_sub_epi64( a, b ); }
|
||||||
|
|
||||||
|
|
||||||
|
typedef VecStorage<InstrFloatTraits<AVX,float>> FVec256Float;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct FVec<AVX, float> : FVec256Float
|
||||||
|
{
|
||||||
|
FORCE_INLINE FVec() {}
|
||||||
|
FORCE_INLINE FVec( float f ) : FVec256Float( _mm256_set1_ps( f ) ) {}
|
||||||
|
FORCE_INLINE FVec( const float *v ) : FVec256Float( _mm256_loadu_ps( v ) ) {}
|
||||||
|
FORCE_INLINE FVec( const vec_t& v) : FVec256Float(v) {}
|
||||||
|
FORCE_INLINE FVec(float f7, float f6, float f5, float f4, float f3, float f2, float f1, float f0) : FVec256Float(_mm256_set_ps(f7, f6, f5, f4, f3, f2, f1, f0)) {}
|
||||||
|
|
||||||
|
//void set0( float f ) { vec = _mm256_load_ss( &f ); }
|
||||||
|
void setN( float f ) { vec = _mm256_set1_ps( f ); }
|
||||||
|
|
||||||
|
FORCE_INLINE void setidx( const float *xi, const IVec<AVX,float>& idx )
|
||||||
|
{
|
||||||
|
#if 1 // use gather primitives
|
||||||
|
vec = _mm256_i32gather_ps (xi, idx, 4);
|
||||||
|
#elif 0
|
||||||
|
uint32 i0 = idx.get0();
|
||||||
|
uint32 i1 = idx.get1();
|
||||||
|
uint32 i2 = idx.get2();
|
||||||
|
uint32 i3 = idx.get3();
|
||||||
|
uint32 i4 = idx.get4();
|
||||||
|
uint32 i5 = idx.get5();
|
||||||
|
uint32 i6 = idx.get6();
|
||||||
|
uint32 i7 = idx.get7();
|
||||||
|
vec = _mm256_set_ps( xi[i7], xi[i6], xi[i5], xi[i4], xi[i3], xi[i2], xi[i1], xi[i0] );
|
||||||
|
#else
|
||||||
|
union {
|
||||||
|
__m256i vec;
|
||||||
|
uint32 ui32[8];
|
||||||
|
} i;
|
||||||
|
i.vec = static_cast<const __m256i&>(idx);
|
||||||
|
vec = _mm256_set_ps(xi[i.ui32[7]], xi[i.ui32[6]], xi[i.ui32[5]], xi[i.ui32[4]], xi[i.ui32[3]], xi[i.ui32[2]], xi[i.ui32[1]], xi[i.ui32[0]]);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE FVec<SSE, float> lo128() const { return _mm256_castps256_ps128(vec); }
|
||||||
|
FORCE_INLINE FVec<SSE, float> hi128() const { return _mm256_extractf128_ps(vec, 1); }
|
||||||
|
|
||||||
|
//FORCE_INLINE float get0() const { return _mm256_cvtss_f32( vec ); }
|
||||||
|
//FORCE_INLINE float get1() const { return _mm256_cvtss_f32( _mm256_shuffle_ps( vec, vec, 1 ) ); }
|
||||||
|
//FORCE_INLINE float get2() const { return _mm256_cvtss_f32( _mm256_shuffle_ps( vec, vec, 2 ) ); }
|
||||||
|
//FORCE_INLINE float get3() const { return _mm256_cvtss_f32( _mm256_shuffle_ps( vec, vec, 3 ) ); }
|
||||||
|
};
|
||||||
|
|
||||||
|
FORCE_INLINE FVec<AVX,float> operator- (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_sub_ps( a, b ); }
|
||||||
|
FORCE_INLINE FVec<AVX,float> operator* (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_mul_ps( a, b ); }
|
||||||
|
FORCE_INLINE FVec<AVX,float> operator/ (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_div_ps( a, b ); }
|
||||||
|
FORCE_INLINE IVec<AVX,float> ftoi (const FVec<AVX,float>& a) { return _mm256_cvttps_epi32(a); }
|
||||||
|
FORCE_INLINE IVec<AVX,float> operator<= (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_castps_si256( _mm256_cmp_ps( a, b, _CMP_LE_OS) ); }
|
||||||
|
FORCE_INLINE IVec<AVX,float> operator>= (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_castps_si256( _mm256_cmp_ps( a, b, _CMP_GE_OS ) ); }
|
||||||
|
FORCE_INLINE IVec<AVX,float> operator< (const FVec<AVX,float>& a, const FVec<AVX,float>& b) { return _mm256_castps_si256(_mm256_cmp_ps(a, b, _CMP_LT_OS )); }
|
||||||
|
#ifdef USE_FMA
|
||||||
|
FORCE_INLINE FVec<AVX, float> mulSub(const FVec<AVX, float>& a, const FVec<AVX, float>& b, const FVec<AVX, float>& c) { return _mm256_fmsub_ps(a, b, c); }
|
||||||
|
#endif
|
||||||
|
|
||||||
|
typedef VecStorage<InstrFloatTraits<AVX,double>> FVec256Double;
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct FVec<AVX, double> : FVec256Double
|
||||||
|
{
|
||||||
|
FORCE_INLINE FVec() {}
|
||||||
|
FORCE_INLINE FVec( double d ) : FVec256Double( _mm256_set1_pd( d ) ) {}
|
||||||
|
FORCE_INLINE FVec( const double *v ) : FVec256Double( _mm256_loadu_pd( v ) ) {}
|
||||||
|
FORCE_INLINE FVec( const vec_t& v) : FVec256Double( v ) {}
|
||||||
|
FORCE_INLINE FVec(double d3, double d2, double d1, double d0) : FVec256Double(_mm256_set_pd(d3, d2, d1, d0)) {}
|
||||||
|
|
||||||
|
//void set0( double f ) { vec = _mm256_load_sd( &f ); }
|
||||||
|
void setN( double f ) { vec = _mm256_set1_pd( f ); }
|
||||||
|
|
||||||
|
FORCE_INLINE void setidx( const double *xi, const IVec<SSE,float>& idx )
|
||||||
|
{
|
||||||
|
vec = _mm256_i32gather_pd(xi, idx, 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE void setidx( const double *xi, const IVec<AVX,double>& idx )
|
||||||
|
{
|
||||||
|
vec = _mm256_i64gather_pd(xi, idx, 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
// FORCE_INLINE double get0() const { return _mm256_cvtsd_f64( vec ); }
|
||||||
|
// FORCE_INLINE double get1() const { return _mm256_cvtsd_f64( _mm256_shuffle_pd( vec, vec, 1 ) ); };
|
||||||
|
};
|
||||||
|
|
||||||
|
FORCE_INLINE FVec<AVX,double> operator- (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_sub_pd( a, b ); }
|
||||||
|
FORCE_INLINE FVec<AVX,double> operator* (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_mul_pd( a, b ); }
|
||||||
|
FORCE_INLINE FVec<AVX,double> operator/ (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_div_pd( a, b ); }
|
||||||
|
FORCE_INLINE IVec<SSE,float> ftoi (const FVec<AVX,double>& a) { return _mm256_cvttpd_epi32(a); }
|
||||||
|
FORCE_INLINE IVec<AVX,double> operator<= (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_castpd_si256(_mm256_cmp_pd( a, b, _CMP_LE_OS ) ); }
|
||||||
|
FORCE_INLINE IVec<AVX,double> operator< (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_castpd_si256(_mm256_cmp_pd(a, b, _CMP_LT_OS)); }
|
||||||
|
FORCE_INLINE IVec<AVX,double> operator>= (const FVec<AVX,double>& a, const FVec<AVX,double>& b) { return _mm256_castpd_si256(_mm256_cmp_pd( a, b, _CMP_GE_OS ) ); }
|
||||||
|
#ifdef USE_FMA
|
||||||
|
FORCE_INLINE FVec<AVX, double> mulSub(const FVec<AVX, double>& a, const FVec<AVX, double>& b, const FVec<AVX, double>& c) { return _mm256_fmsub_pd(a, b, c); }
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namepsace Details
|
||||||
|
} // namespace BinSearch
|
221
include/Type.h
Normal file
221
include/Type.h
Normal file
|
@ -0,0 +1,221 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <vector>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include "Portable.h"
|
||||||
|
|
||||||
|
using std::size_t;
|
||||||
|
|
||||||
|
namespace BinSearch {
|
||||||
|
|
||||||
|
enum InstrSet { Scalar, SSE, AVX };
|
||||||
|
|
||||||
|
#define ALGOENUM(x, b) x,
|
||||||
|
enum Algos
|
||||||
|
{
|
||||||
|
#include "AlgoXCodes.h"
|
||||||
|
};
|
||||||
|
#undef ALGOENUM
|
||||||
|
|
||||||
|
namespace Details {
|
||||||
|
|
||||||
|
template <InstrSet I>
|
||||||
|
struct InstrIntTraits;
|
||||||
|
|
||||||
|
template <InstrSet I, typename T>
|
||||||
|
struct InstrFloatTraits;
|
||||||
|
|
||||||
|
// base class for algorithm supporting the method:
|
||||||
|
// uint32 scalar(T z) const
|
||||||
|
template <typename T, Algos A, typename Enable=void>
|
||||||
|
struct AlgoScalarBase;
|
||||||
|
|
||||||
|
// base class for algorithm supporting the following methods, constants and definitions:
|
||||||
|
// static const uint32 nElem
|
||||||
|
// struct Constants;
|
||||||
|
// void initConstants(Constants& cst) const
|
||||||
|
// void vectorial(uint32 *pr, const T *pz, const Constants& cst) const
|
||||||
|
// The function vectorial processes nElem items
|
||||||
|
template <InstrSet I, typename T, Algos A, typename Enable=void>
|
||||||
|
struct AlgoVecBase;
|
||||||
|
|
||||||
|
template <typename T> struct IntTraits;
|
||||||
|
|
||||||
|
template <> struct IntTraits<float>
|
||||||
|
{
|
||||||
|
typedef uint32 itype;
|
||||||
|
};
|
||||||
|
template <> struct IntTraits<double>
|
||||||
|
{
|
||||||
|
typedef uint64 itype;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int N>
|
||||||
|
struct Body
|
||||||
|
{
|
||||||
|
template <uint32 D, typename T, typename Expr>
|
||||||
|
FORCE_INLINE static void iteration(const Expr& e, uint32 *ri, const T* zi, const typename Expr::Constants& cst)
|
||||||
|
{
|
||||||
|
e.vectorial(ri, zi, cst);
|
||||||
|
Body<N - 1>::template iteration<D>(e, ri + D, zi + D, cst);
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct Body<0>
|
||||||
|
{
|
||||||
|
template <uint32 D, typename T, typename Expr, typename H>
|
||||||
|
FORCE_INLINE static void iteration(const Expr& e, uint32 *ri, const T* zi, const H&)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T, typename Algo>
|
||||||
|
struct Loop
|
||||||
|
{
|
||||||
|
typedef Algo algo_type;
|
||||||
|
static const uint32 M = 4;
|
||||||
|
static const uint32 D = algo_type::nElem;
|
||||||
|
|
||||||
|
FORCE_INLINE static void loop(const algo_type& e, uint32 *ri, const T* zi, uint32 n)
|
||||||
|
{
|
||||||
|
typename algo_type::Constants cst;
|
||||||
|
e.initConstants(cst);
|
||||||
|
|
||||||
|
uint32 j = 0;
|
||||||
|
while (j + (D*M) <= n) {
|
||||||
|
Details::Body<M>::template iteration<D>(e, ri + j, zi + j, cst);
|
||||||
|
j += (D*M);
|
||||||
|
}
|
||||||
|
while (j + D <= n) {
|
||||||
|
e.vectorial(ri + j, zi + j, cst);
|
||||||
|
j += D;
|
||||||
|
}
|
||||||
|
while (D > 1 && j < n) {
|
||||||
|
ri[j] = e.scalar(zi[j]);
|
||||||
|
j += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <uint32 nIterTot, uint32 nIterLeft>
|
||||||
|
struct _Pipeliner
|
||||||
|
{
|
||||||
|
template <typename Expr, typename Data>
|
||||||
|
FORCE_INLINE static void go(const Expr& e, Data* d)
|
||||||
|
{
|
||||||
|
e.template run<nIterTot - nIterLeft>(d);
|
||||||
|
_Pipeliner<nIterTot, nIterLeft - 1>::go(e, d);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <uint32 nIterTot>
|
||||||
|
struct _Pipeliner<nIterTot, 0>
|
||||||
|
{
|
||||||
|
template <typename Expr, typename Data>
|
||||||
|
FORCE_INLINE static void go(const Expr& e, Data* d)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <uint32 nIter>
|
||||||
|
struct Pipeliner
|
||||||
|
{
|
||||||
|
template <typename Expr, typename Data>
|
||||||
|
FORCE_INLINE static void go(const Expr& e, Data* d)
|
||||||
|
{
|
||||||
|
_Pipeliner<nIter, nIter>::go(e, d);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
#if 1
|
||||||
|
template <class T>
|
||||||
|
char is_complete_impl(char (*)[sizeof(T)]);
|
||||||
|
|
||||||
|
template <class>
|
||||||
|
long is_complete_impl(...);
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
struct IsComplete
|
||||||
|
{
|
||||||
|
static const bool value = sizeof(is_complete_impl<T>(0)) == sizeof(char);
|
||||||
|
};
|
||||||
|
#else
|
||||||
|
template <class T, std::size_t = sizeof(T)>
|
||||||
|
std::true_type is_complete_impl(T *);
|
||||||
|
|
||||||
|
std::false_type is_complete_impl(...);
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
struct IsComplete : decltype(is_complete_impl(std::declval<T*>())) {};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <typename T, Algos A>
|
||||||
|
struct AlgoScalarToVec : AlgoScalarBase<T,A>
|
||||||
|
{
|
||||||
|
typedef AlgoScalarBase<T, A> base_t;
|
||||||
|
|
||||||
|
AlgoScalarToVec(const typename base_t::Data& d) : base_t(d) {}
|
||||||
|
AlgoScalarToVec(const T* px, const uint32 n) : base_t(px, n) {}
|
||||||
|
|
||||||
|
static const uint32 nElem = 1;
|
||||||
|
|
||||||
|
struct Constants
|
||||||
|
{
|
||||||
|
};
|
||||||
|
|
||||||
|
void initConstants(Constants& cst) const
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCE_INLINE
|
||||||
|
void vectorial(uint32 *pr, const T *pz, const Constants& cst) const
|
||||||
|
{
|
||||||
|
*pr = base_t::scalar(*pz);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<bool B, class T, class F>
|
||||||
|
struct conditional { typedef T type; };
|
||||||
|
|
||||||
|
template<class T, class F>
|
||||||
|
struct conditional<false, T, F> { typedef F type; };
|
||||||
|
|
||||||
|
template <typename T, bool C>
|
||||||
|
struct CondData
|
||||||
|
{
|
||||||
|
FORCE_INLINE CondData(T x) : v(x) {}
|
||||||
|
FORCE_INLINE operator const T&() const { return v;}
|
||||||
|
private:
|
||||||
|
T v;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct CondData<T,false>
|
||||||
|
{
|
||||||
|
FORCE_INLINE CondData(T) {}
|
||||||
|
FORCE_INLINE operator const T() const { return 0;}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <InstrSet I, typename T, Algos A, bool L=false>
|
||||||
|
struct BinAlgoBase : Details::conditional< Details::IsComplete<Details::AlgoVecBase<I, T, A>>::value
|
||||||
|
, Details::AlgoVecBase<I, T, A>
|
||||||
|
, Details::AlgoScalarToVec<T,A>
|
||||||
|
>::type
|
||||||
|
{
|
||||||
|
typedef typename Details::conditional< Details::IsComplete<Details::AlgoVecBase<I, T, A>>::value
|
||||||
|
, Details::AlgoVecBase<I, T, A>
|
||||||
|
, Details::AlgoScalarToVec<T,A>
|
||||||
|
>::type base_t;
|
||||||
|
|
||||||
|
BinAlgoBase(const T* px, const uint32 n) : base_t(px, n) {}
|
||||||
|
BinAlgoBase(const typename base_t::Data& d) : base_t(d) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace Details
|
||||||
|
|
||||||
|
} // namespace BinSearch
|
6
pyproject.toml
Normal file
6
pyproject.toml
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
[build-system]
|
||||||
|
requires = [
|
||||||
|
"setuptools>=42",
|
||||||
|
"wheel"
|
||||||
|
]
|
||||||
|
build-backend = "setuptools.build_meta"
|
1
requirements.txt
Normal file
1
requirements.txt
Normal file
|
@ -0,0 +1 @@
|
||||||
|
pytest
|
32
setup.py
Normal file
32
setup.py
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
import os
|
||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def read(fname):
|
||||||
|
return open(os.path.join(os.path.dirname(__file__), fname)).read()
|
||||||
|
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name = f"bitsandbytes-cuda{os.environ['CUDA_VERSION']}",
|
||||||
|
version = "0.0.23",
|
||||||
|
author = "Tim Dettmers",
|
||||||
|
author_email = "tim.dettmers@gmail.com",
|
||||||
|
description = ("Numpy-like library for GPUs."),
|
||||||
|
license = "MIT",
|
||||||
|
keywords = "gpu",
|
||||||
|
url = "http://packages.python.org/bitsandbytes",
|
||||||
|
packages=find_packages(),
|
||||||
|
package_data={'': ['libbitsandbytes.so']},
|
||||||
|
long_description=read('README.md'),
|
||||||
|
long_description_content_type = 'text/markdown',
|
||||||
|
classifiers=[
|
||||||
|
"Development Status :: 1 - Planning",
|
||||||
|
'Topic :: Scientific/Engineering :: Artificial Intelligence'
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
213
tests/test_functional.py
Normal file
213
tests/test_functional.py
Normal file
|
@ -0,0 +1,213 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
|
||||||
|
from itertools import product
|
||||||
|
|
||||||
|
from bitsandbytes import functional as F
|
||||||
|
|
||||||
|
def setup():
|
||||||
|
pass
|
||||||
|
|
||||||
|
def teardown():
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['float', 'half'])
|
||||||
|
def test_estimate_quantiles(dtype):
|
||||||
|
A = torch.rand(1024, 1024, device='cuda')
|
||||||
|
A = A.to(dtype)
|
||||||
|
code = F.estimate_quantiles(A)
|
||||||
|
|
||||||
|
percs = torch.linspace(1/512, 511/512, 256, device=A.device)
|
||||||
|
torch.testing.assert_allclose(percs, code, atol=1e-3, rtol=1e-2)
|
||||||
|
|
||||||
|
A = torch.randn(1024, 1024, device='cuda')
|
||||||
|
A = A.to(dtype)
|
||||||
|
code = F.estimate_quantiles(A)
|
||||||
|
|
||||||
|
quantiles = torch.quantile(A.float(), percs)
|
||||||
|
diff = torch.abs(code-quantiles)
|
||||||
|
assert (diff > 5e-02).sum().item() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_quantile_quantization():
|
||||||
|
for i in range(100):
|
||||||
|
A1 = torch.randn(1024, 1024, device='cuda')
|
||||||
|
code = F.estimate_quantiles(A1)
|
||||||
|
C = F.quantize_no_absmax(A1, code)
|
||||||
|
A2 = F.dequantize_no_absmax(C, code)
|
||||||
|
diff = torch.abs(A1-A2).mean().item()
|
||||||
|
assert diff < 0.0075
|
||||||
|
|
||||||
|
A1 = torch.rand(1024, 1024, device='cuda')
|
||||||
|
code = F.estimate_quantiles(A1)
|
||||||
|
C = F.quantize_no_absmax(A1, code)
|
||||||
|
A2 = F.dequantize_no_absmax(C, code)
|
||||||
|
diff = torch.abs(A1-A2).mean().item()
|
||||||
|
torch.testing.assert_allclose(A1, A2, atol=5e-3, rtol=0)
|
||||||
|
assert diff < 0.001
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamic_quantization():
|
||||||
|
diffs = []
|
||||||
|
reldiffs = []
|
||||||
|
for i in range(100):
|
||||||
|
A1 = torch.randn(1024, 1024, device='cuda')
|
||||||
|
C, S = F.quantize(A1)
|
||||||
|
A2 = F.dequantize(C, S)
|
||||||
|
diff = torch.abs(A1-A2)
|
||||||
|
reldiff = diff/torch.abs(A1+1e-8)
|
||||||
|
diffs.append(diff.mean().item())
|
||||||
|
reldiffs.append(reldiff.mean().item())
|
||||||
|
assert diff.mean().item() < 0.0135
|
||||||
|
print(sum(diffs)/len(diffs))
|
||||||
|
print(sum(reldiffs)/len(reldiffs))
|
||||||
|
|
||||||
|
for i in range(100):
|
||||||
|
A1 = torch.rand(1024, 1024, device='cuda')
|
||||||
|
C, S = F.quantize(A1)
|
||||||
|
A2 = F.dequantize(C, S)
|
||||||
|
diff = torch.abs(A1-A2).mean().item()
|
||||||
|
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
|
||||||
|
assert diff < 0.004
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamic_blockwise_quantization():
|
||||||
|
diffs = []
|
||||||
|
reldiffs = []
|
||||||
|
for i in range(100):
|
||||||
|
A1 = torch.randn(1024, 1024, device='cuda')
|
||||||
|
C, S = F.quantize_blockwise(A1)
|
||||||
|
A2 = F.dequantize_blockwise(C, S)
|
||||||
|
diff = torch.abs(A1-A2)
|
||||||
|
reldiff = diff/torch.abs(A1+1e-8)
|
||||||
|
diffs.append(diff.mean().item())
|
||||||
|
reldiffs.append(reldiff.mean().item())
|
||||||
|
assert diffs[-1] < 0.011
|
||||||
|
print(sum(diffs)/len(diffs))
|
||||||
|
print(sum(reldiffs)/len(reldiffs))
|
||||||
|
|
||||||
|
diffs = []
|
||||||
|
for i in range(100):
|
||||||
|
A1 = torch.rand(1024, 1024, device='cuda')
|
||||||
|
C, S = F.quantize_blockwise(A1)
|
||||||
|
A2 = F.dequantize_blockwise(C, S)
|
||||||
|
diff = torch.abs(A1-A2).mean().item()
|
||||||
|
assert diff < 0.0033
|
||||||
|
diffs.append(diff)
|
||||||
|
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
|
||||||
|
#print(sum(diffs)/len(diffs))
|
||||||
|
|
||||||
|
def test_dynamic_blockwise_stochastic_quantization():
|
||||||
|
diffs = []
|
||||||
|
reldiffs = []
|
||||||
|
rand = torch.rand(1024).cuda()
|
||||||
|
for i in range(100):
|
||||||
|
A1 = torch.randn(1024, 1024, device='cuda')
|
||||||
|
C1, S1 = F.quantize_blockwise(A1, rand=rand)
|
||||||
|
C2, S2 = F.quantize_blockwise(A1)
|
||||||
|
# a maximunm distance of quantized values of 1
|
||||||
|
torch.testing.assert_allclose(C1, C2, atol=1, rtol=0)
|
||||||
|
fraction_smaller = (C1<C2).float().sum()/C1.numel()
|
||||||
|
fraction_larger = (C1>C2).float().sum()/C1.numel()
|
||||||
|
torch.testing.assert_allclose(fraction_larger, fraction_smaller, atol=0.01, rtol=0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=['float', 'half'])
|
||||||
|
def test_percentile_clipping(gtype):
|
||||||
|
gnorm_vec1 = torch.zeros(100, device='cuda')
|
||||||
|
gnorm_vec2 = torch.zeros(100, device='cuda')
|
||||||
|
n = 4
|
||||||
|
step = 0
|
||||||
|
percentile=5
|
||||||
|
for i in range(1000):
|
||||||
|
step += 1
|
||||||
|
g = torch.randn(n, n, dtype=gtype, device='cuda')
|
||||||
|
gnorm1, clip2, gnorm_scale = F.percentile_clipping(g, gnorm_vec2, step, percentile=percentile)
|
||||||
|
assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2/gnorm1
|
||||||
|
|
||||||
|
gnorm2 = torch.norm(g.float())
|
||||||
|
if step == 1:
|
||||||
|
gnorm_vec1[:] = gnorm2
|
||||||
|
else:
|
||||||
|
gnorm_vec1[step % 100] = gnorm2
|
||||||
|
|
||||||
|
vals, idx = torch.sort(gnorm_vec1)
|
||||||
|
clip1 = vals[percentile]
|
||||||
|
|
||||||
|
torch.testing.assert_allclose(gnorm_vec1, torch.sqrt(gnorm_vec2))
|
||||||
|
torch.testing.assert_allclose(clip1, clip2)
|
||||||
|
torch.testing.assert_allclose(gnorm1, gnorm2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_stable_embedding():
|
||||||
|
layer = bnb.nn.StableEmbedding(1024, 1024)
|
||||||
|
layer.reset_parameters()
|
||||||
|
|
||||||
|
|
||||||
|
def test_dynamic_blockwise_quantization_cpu():
|
||||||
|
#A1 = torch.randn(1024, 1024, device='cpu')
|
||||||
|
#code = F.create_dynamic_map()
|
||||||
|
#for i in range(1000):
|
||||||
|
# C, S = F.quantize_blockwise(A1, code=code)
|
||||||
|
# A2 = F.dequantize_blockwise(C, S)
|
||||||
|
|
||||||
|
for i in range(10):
|
||||||
|
# equivalence with GPU blockwise quantization
|
||||||
|
A1 = torch.randn(1024, 1024, device='cpu')
|
||||||
|
C1, S1 = F.quantize_blockwise(A1)
|
||||||
|
C2, S2 = F.quantize_blockwise(A1.cuda())
|
||||||
|
torch.testing.assert_allclose(S1[0], S2[0].cpu())
|
||||||
|
# there seems to be some issues with precision in CUDA vs CPU
|
||||||
|
# not all elements are usually close, with couple off elements in a million
|
||||||
|
idx = torch.isclose(C1, C2.cpu())
|
||||||
|
assert (idx==0).sum().item() < 15
|
||||||
|
|
||||||
|
|
||||||
|
diffs = []
|
||||||
|
reldiffs = []
|
||||||
|
for i in range(10):
|
||||||
|
A1 = torch.randn(1024, 1024, device='cpu')
|
||||||
|
C, S = F.quantize_blockwise(A1)
|
||||||
|
A2 = F.dequantize_blockwise(C, S)
|
||||||
|
diff = torch.abs(A1-A2)
|
||||||
|
reldiff = diff/torch.abs(A1+1e-8)
|
||||||
|
diffs.append(diff.mean().item())
|
||||||
|
reldiffs.append(reldiff.mean().item())
|
||||||
|
assert diffs[-1] < 0.011
|
||||||
|
#print(sum(diffs)/len(diffs))
|
||||||
|
#print(sum(reldiffs)/len(reldiffs))
|
||||||
|
|
||||||
|
diffs = []
|
||||||
|
for i in range(10):
|
||||||
|
A1 = torch.rand(1024, 1024, device='cpu')
|
||||||
|
C, S = F.quantize_blockwise(A1)
|
||||||
|
A2 = F.dequantize_blockwise(C, S)
|
||||||
|
diff = torch.abs(A1-A2).mean().item()
|
||||||
|
assert diff < 0.0033
|
||||||
|
diffs.append(diff)
|
||||||
|
torch.testing.assert_allclose(A1, A2, atol=1e-2, rtol=0)
|
||||||
|
#print(sum(diffs)/len(diffs))
|
||||||
|
|
||||||
|
|
||||||
|
def test_histogram():
|
||||||
|
dim1, dim2 = 32, 32
|
||||||
|
source = torch.rand(dim1, dim2, device='cuda')
|
||||||
|
idx1 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int()
|
||||||
|
idx2 = torch.randint(0, 255, size=(dim1, dim2), device='cuda').int()
|
||||||
|
histogram1 = torch.zeros((256, 256)).cuda()
|
||||||
|
histogram2 = torch.zeros((256, 256)).cuda()
|
||||||
|
|
||||||
|
F.histogram_scatter_add_2d(histogram2, idx1, idx2, source)
|
||||||
|
|
||||||
|
for i in range(dim1):
|
||||||
|
for j in range(dim2):
|
||||||
|
histogram1[idx1[i, j].item(), idx2[i, j].item()] += source[i, j]
|
||||||
|
|
||||||
|
torch.testing.assert_allclose(histogram1, histogram2)
|
||||||
|
torch.testing.assert_allclose(histogram1.sum(), source.sum())
|
362
tests/test_optim.py
Normal file
362
tests/test_optim.py
Normal file
|
@ -0,0 +1,362 @@
|
||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import shutil
|
||||||
|
import uuid
|
||||||
|
import pytest
|
||||||
|
import ctypes
|
||||||
|
import torch
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
import bitsandbytes.functional as F
|
||||||
|
|
||||||
|
from os.path import join
|
||||||
|
from itertools import product
|
||||||
|
|
||||||
|
import apex
|
||||||
|
|
||||||
|
def get_temp_dir():
|
||||||
|
path = '/tmp/autoswap/{0}'.format(str(uuid.uuid4()))
|
||||||
|
os.makedirs(path, exist_ok=True)
|
||||||
|
return path
|
||||||
|
|
||||||
|
def rm_path(path):
|
||||||
|
shutil.rmtree(path)
|
||||||
|
|
||||||
|
str2optimizers = {}
|
||||||
|
str2optimizers['adam_pytorch'] = (None, torch.optim.Adam, bnb.optim.Adam)
|
||||||
|
str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam)
|
||||||
|
str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam)
|
||||||
|
str2optimizers['momentum_pytorch'] = (None, lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), bnb.optim.Adam)
|
||||||
|
str2optimizers['lamb_apex'] = (None, lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.00, use_nvlamb=True), bnb.optim.Adam)
|
||||||
|
str2optimizers['lars_apex'] = (None, lambda pxx: apex.parallel.LARC.LARC(apex.optimizers.FusedSGD(pxx, 0.01, 0.9)), bnb.optim.Adam)
|
||||||
|
|
||||||
|
str2optimizers['adam'] = (torch.optim.Adam, bnb.optim.Adam)
|
||||||
|
str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam)
|
||||||
|
str2optimizers['momentum'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False))
|
||||||
|
str2optimizers['lars'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS(pxx, 0.01, 0.9))
|
||||||
|
str2optimizers['lamb'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB)
|
||||||
|
str2optimizers['rmsprop'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False))
|
||||||
|
str2optimizers['adam8bit'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False))
|
||||||
|
str2optimizers['momentum8bit'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False))
|
||||||
|
str2optimizers['rmsprop8bit'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False))
|
||||||
|
str2optimizers['lamb8bit'] = (lambda pxx: apex.optimizers.FusedLAMB(pxx, weight_decay=0.0, max_grad_norm=10000.0, eps=1e-8, use_nvlamb=True), bnb.optim.LAMB8bit)
|
||||||
|
str2optimizers['lars8bit'] = (lambda pxx: bnb.optim.PytorchLARS(pxx, 0.01, 0.9), lambda pxx: bnb.optim.LARS8bit(pxx, 0.01, 0.9))
|
||||||
|
|
||||||
|
str2optimizers['adam8bit_blockwise'] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True))
|
||||||
|
str2optimizers['momentum8bit_blockwise'] = (lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True))
|
||||||
|
str2optimizers['rmsprop8bit_blockwise'] = (lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True))
|
||||||
|
|
||||||
|
str2statenames = {}
|
||||||
|
str2statenames['adam'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
|
||||||
|
str2statenames['momentum'] = [('momentum_buffer', 'state1')]
|
||||||
|
str2statenames['lars'] = [('momentum_buffer', 'state1')]
|
||||||
|
str2statenames['lamb'] = [('exp_avg', 'state1'), ('exp_avg_sq', 'state2')]
|
||||||
|
str2statenames['rmsprop'] = [('square_avg', 'state1')]
|
||||||
|
str2statenames['adam8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
|
||||||
|
str2statenames['lamb8bit'] = [('exp_avg', 'state1', 'qmap1', 'max1'), ('exp_avg_sq', 'state2', 'qmap2', 'max2')]
|
||||||
|
str2statenames['adam8bit_blockwise'] = [('exp_avg', 'state1', 'qmap1', 'absmax1'), ('exp_avg_sq', 'state2', 'qmap2', 'absmax2')]
|
||||||
|
str2statenames['momentum8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
|
||||||
|
str2statenames['momentum8bit_blockwise'] = [('momentum_buffer', 'state1', 'qmap1', 'absmax1')]
|
||||||
|
str2statenames['lars8bit'] = [('momentum_buffer', 'state1', 'qmap1', 'max1')]
|
||||||
|
str2statenames['rmsprop8bit'] = [('square_avg', 'state1', 'qmap1', 'max1')]
|
||||||
|
str2statenames['rmsprop8bit_blockwise'] = [('square_avg', 'state1', 'qmap1', 'absmax1')]
|
||||||
|
|
||||||
|
dim1 = [1024]
|
||||||
|
dim2 = [32, 1024, 4097, 1]
|
||||||
|
gtype = [torch.float32, torch.float16]
|
||||||
|
optimizer_names = ['adam', 'momentum', 'rmsprop', 'lars', 'lamb']
|
||||||
|
values = list(product(dim1,dim2, gtype, optimizer_names))
|
||||||
|
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
|
||||||
|
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
|
||||||
|
def test_optimizer32bit(dim1, dim2, gtype, optim_name):
|
||||||
|
if dim1 == 1 and dim2 == 1: return
|
||||||
|
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
|
||||||
|
p2 = p1.clone()
|
||||||
|
p1 = p1.float()
|
||||||
|
|
||||||
|
|
||||||
|
torch_optimizer = str2optimizers[optim_name][0]([p1])
|
||||||
|
bnb_optimizer = str2optimizers[optim_name][1]([p2])
|
||||||
|
|
||||||
|
if gtype == torch.float32:
|
||||||
|
atol, rtol = 1e-6, 1e-5
|
||||||
|
else:
|
||||||
|
atol, rtol = 1e-4, 1e-3
|
||||||
|
|
||||||
|
|
||||||
|
for i in range(50):
|
||||||
|
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
|
||||||
|
p1.grad = g.clone().float()
|
||||||
|
p2.grad = g.clone()
|
||||||
|
|
||||||
|
bnb_optimizer.step()
|
||||||
|
torch_optimizer.step()
|
||||||
|
|
||||||
|
for name1, name2 in str2statenames[optim_name]:
|
||||||
|
torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol)
|
||||||
|
|
||||||
|
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
|
||||||
|
|
||||||
|
if i % 10 == 0 and i > 0:
|
||||||
|
path = get_temp_dir()
|
||||||
|
torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt'))
|
||||||
|
del bnb_optimizer
|
||||||
|
bnb_optimizer = None
|
||||||
|
bnb_optimizer = str2optimizers[optim_name][1]([p2])
|
||||||
|
bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt')))
|
||||||
|
rm_path(path)
|
||||||
|
torch.testing.assert_allclose(p1, p2.float(), atol=atol, rtol=rtol)
|
||||||
|
for name1, name2 in str2statenames[optim_name]:
|
||||||
|
torch.testing.assert_allclose(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], atol=atol, rtol=rtol)
|
||||||
|
|
||||||
|
if gtype == torch.float16:
|
||||||
|
# the adam buffers should also be close because they are 32-bit
|
||||||
|
# but the paramters can diverge because they are 16-bit
|
||||||
|
# the difference grow larger and larger with each update
|
||||||
|
# --> copy the state to keep weights close
|
||||||
|
p1.data = p1.data.half().float()
|
||||||
|
p2.copy_(p1.data)
|
||||||
|
torch.testing.assert_allclose(p1.half(), p2)
|
||||||
|
if optim_name in ['lars', 'lamb']:
|
||||||
|
assert bnb_optimizer.state[p2]['unorm_vec'] > 0.0
|
||||||
|
|
||||||
|
dim1 = [1024]
|
||||||
|
dim2 = [32, 1024, 4097]
|
||||||
|
gtype = [torch.float32, torch.float16]
|
||||||
|
values = list(product(dim1,dim2, gtype))
|
||||||
|
names = ['dim1_{0}_dim2_{1}_gtype_{2}'.format(*vals) for vals in values]
|
||||||
|
@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names)
|
||||||
|
def test_global_config(dim1, dim2, gtype):
|
||||||
|
if dim1 == 1 and dim2 == 1: return
|
||||||
|
p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
|
||||||
|
p2 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
|
||||||
|
p3 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
|
||||||
|
mask = torch.rand_like(p2) < 0.1
|
||||||
|
beta1 = 0.9
|
||||||
|
beta2 = 0.999
|
||||||
|
lr = 0.001
|
||||||
|
eps = 1e-8
|
||||||
|
|
||||||
|
bnb.optim.GlobalOptimManager.get_instance().initialize()
|
||||||
|
bnb.optim.GlobalOptimManager.get_instance().override_config(p3, 'optim_bits', 8)
|
||||||
|
|
||||||
|
bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
|
||||||
|
p1 = p1.cuda()
|
||||||
|
p2 = p2.cuda()
|
||||||
|
p3 = p3.cuda()
|
||||||
|
|
||||||
|
adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps)
|
||||||
|
|
||||||
|
if gtype == torch.float32:
|
||||||
|
atol, rtol = 1e-6, 1e-5
|
||||||
|
else:
|
||||||
|
atol, rtol = 1e-4, 1e-3
|
||||||
|
|
||||||
|
for i in range(50):
|
||||||
|
g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
|
||||||
|
g2 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
|
||||||
|
g3 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + 0.001
|
||||||
|
p1.grad = g1
|
||||||
|
p2.grad = g2
|
||||||
|
p3.grad = g3
|
||||||
|
|
||||||
|
adam2.step()
|
||||||
|
|
||||||
|
assert adam2.state[p3]['state1'].dtype == torch.uint8
|
||||||
|
assert adam2.state[p3]['state2'].dtype == torch.uint8
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
dim1 = [1024]
|
||||||
|
dim2 = [32, 1024, 4097]
|
||||||
|
gtype = [torch.float32, torch.float16]
|
||||||
|
optimizer_names = ['adam8bit', 'momentum8bit', 'rmsprop8bit', 'adam8bit_blockwise', 'lamb8bit', 'lars8bit', 'momentum8bit_blockwise', 'rmsprop8bit_blockwise']
|
||||||
|
values = list(product(dim1,dim2, gtype, optimizer_names))
|
||||||
|
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
|
||||||
|
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
|
||||||
|
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
|
||||||
|
if dim1 == 1 and dim2 == 1: return
|
||||||
|
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
|
||||||
|
p2 = p1.clone()
|
||||||
|
p1 = p1.float()
|
||||||
|
blocksize = 2048
|
||||||
|
|
||||||
|
torch_optimizer = str2optimizers[optim_name][0]([p1])
|
||||||
|
bnb_optimizer = str2optimizers[optim_name][1]([p2])
|
||||||
|
|
||||||
|
if gtype == torch.float32:
|
||||||
|
atol, rtol = 3e-3, 1e-3
|
||||||
|
patol, prtol = 1e-5, 1e-3
|
||||||
|
|
||||||
|
else:
|
||||||
|
atol, rtol = 3e-3, 1e-3
|
||||||
|
patol, prtol = 1e-5, 1e-3
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
relerrors = []
|
||||||
|
|
||||||
|
for i in range(50):
|
||||||
|
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
|
||||||
|
p1.grad = g.clone().float()
|
||||||
|
p2.grad = g.clone()
|
||||||
|
|
||||||
|
bnb_optimizer.step()
|
||||||
|
torch_optimizer.step()
|
||||||
|
|
||||||
|
torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
|
||||||
|
|
||||||
|
dequant_states = []
|
||||||
|
for name1, name2, qmap, max_val in str2statenames[optim_name]:
|
||||||
|
#print(bnb_optimizer.state[p2][max_val], name1)
|
||||||
|
if 'blockwise' in optim_name:
|
||||||
|
s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize)
|
||||||
|
else:
|
||||||
|
s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2])
|
||||||
|
num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0
|
||||||
|
assert num_not_close.sum().item() < 20
|
||||||
|
dequant_states.append(s1.clone())
|
||||||
|
|
||||||
|
err = torch.abs(p1-p2)
|
||||||
|
relerr = err/torch.abs(p1)
|
||||||
|
assert err.mean() < 0.0001
|
||||||
|
assert relerr.mean() < 0.001
|
||||||
|
|
||||||
|
errors.append(err.mean().item())
|
||||||
|
relerrors.append(relerr.mean().item())
|
||||||
|
|
||||||
|
if i % 10 == 0 and i > 0:
|
||||||
|
for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
|
||||||
|
s1cpy = s.clone()
|
||||||
|
raws1cpy = bnb_optimizer.state[p2][name2].clone()
|
||||||
|
qmap1 = bnb_optimizer.state[p2][qmap].clone()
|
||||||
|
|
||||||
|
path = get_temp_dir()
|
||||||
|
torch.save(bnb_optimizer.state_dict(),join(path, 'opt.pt'))
|
||||||
|
del bnb_optimizer
|
||||||
|
bnb_optimizer = None
|
||||||
|
bnb_optimizer = str2optimizers[optim_name][1]([p2])
|
||||||
|
bnb_optimizer.load_state_dict(torch.load(join(path, 'opt.pt')))
|
||||||
|
rm_path(path)
|
||||||
|
torch.testing.assert_allclose(raws1cpy, bnb_optimizer.state[p2][name2])
|
||||||
|
torch.testing.assert_allclose(qmap1, bnb_optimizer.state[p2][qmap])
|
||||||
|
|
||||||
|
if 'blockwise' in optim_name:
|
||||||
|
s1 = F.dequantize_blockwise(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2], blocksize=blocksize)
|
||||||
|
else:
|
||||||
|
s1 = F.dequantize(code=bnb_optimizer.state[p2][qmap], absmax=bnb_optimizer.state[p2][max_val], A=bnb_optimizer.state[p2][name2])
|
||||||
|
torch.testing.assert_allclose(s1cpy, s1)
|
||||||
|
|
||||||
|
num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol)==0
|
||||||
|
assert num_not_close.sum().item() < 20
|
||||||
|
torch.testing.assert_allclose(p1, p2.float(), atol=patol, rtol=prtol)
|
||||||
|
|
||||||
|
# the parameters diverge quickly. Here we keep them close
|
||||||
|
# together so we can test against the Adam error
|
||||||
|
p1.data = p1.data.to(gtype).float()
|
||||||
|
p2.copy_(p1.data)
|
||||||
|
torch.testing.assert_allclose(p1.to(gtype), p2)
|
||||||
|
for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states):
|
||||||
|
torch_optimizer.state[p1][name1].copy_(s.data)
|
||||||
|
|
||||||
|
#print(sum(errors)/len(errors))
|
||||||
|
#print(sum(relerrors)/len(relerrors))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
dim1 = [1024]
|
||||||
|
dim2 = [32, 1024, 4097]
|
||||||
|
gtype = [torch.float32]
|
||||||
|
optim_bits = [32, 8]
|
||||||
|
values = list(product(dim1,dim2, gtype, optim_bits))
|
||||||
|
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_bits_{3}'.format(*vals) for vals in values]
|
||||||
|
@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names)
|
||||||
|
def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
|
||||||
|
if dim1 == 1 and dim2 == 1: return
|
||||||
|
p1 = torch.randn(dim1,dim2, device='cpu', dtype=gtype)*0.1
|
||||||
|
beta1 = 0.9
|
||||||
|
beta2 = 0.999
|
||||||
|
lr = 0.001
|
||||||
|
eps = 1e-8
|
||||||
|
p1 = p1.cuda()
|
||||||
|
p2 = p1.clone()
|
||||||
|
adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits)
|
||||||
|
adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5)
|
||||||
|
|
||||||
|
gnorm_vec = torch.zeros(100).cuda()
|
||||||
|
step = 0
|
||||||
|
|
||||||
|
for i in range(50):
|
||||||
|
step += 1
|
||||||
|
g1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1 + (0.01*i)
|
||||||
|
g2 = g1.clone()
|
||||||
|
p2.grad = g2
|
||||||
|
|
||||||
|
current_gnorm, clip_val, gnorm_scale = F.percentile_clipping(g1, gnorm_vec, step, 5)
|
||||||
|
g1 = (g1.float()*gnorm_scale).to(gtype)
|
||||||
|
p1.grad = g1
|
||||||
|
|
||||||
|
adam1.step()
|
||||||
|
adam2.step()
|
||||||
|
|
||||||
|
# gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state
|
||||||
|
if optim_bits == 32:
|
||||||
|
torch.testing.assert_allclose(p1, p2)
|
||||||
|
torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=5e-5, rtol=1e-4)
|
||||||
|
torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=5e-5, rtol=1e-4)
|
||||||
|
elif optim_bits == 8:
|
||||||
|
torch.testing.assert_allclose(p1, p2, atol=1e-4, rtol=1e-3)
|
||||||
|
torch.testing.assert_allclose(adam1.state[p1]['state1'], adam2.state[p2]['state1'], atol=2, rtol=1e-3)
|
||||||
|
torch.testing.assert_allclose(adam1.state[p1]['state2'], adam2.state[p2]['state2'], atol=2, rtol=1e-3)
|
||||||
|
adam1.state[p1]['state1'].copy_(adam2.state[p2]['state1'])
|
||||||
|
adam1.state[p1]['state2'].copy_(adam2.state[p2]['state2'])
|
||||||
|
if i % 10 == 0 and i > 0:
|
||||||
|
path = get_temp_dir()
|
||||||
|
torch.save(adam2.state_dict(),join(path, 'opt.pt'))
|
||||||
|
del adam2
|
||||||
|
adam2 = None
|
||||||
|
adam2 = bnb.optim.Adam([p2], lr, (beta1, beta2), eps, optim_bits=optim_bits, percentile_clipping=5)
|
||||||
|
adam2.load_state_dict(torch.load(join(path, 'opt.pt')))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
dim1 = [4096]
|
||||||
|
dim2 = [4096]
|
||||||
|
gtype = [torch.float32, torch.float16]
|
||||||
|
#optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit']
|
||||||
|
#optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch']
|
||||||
|
#optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch']
|
||||||
|
#optimizer_names = ['lamb_apex', 'lamb8bit']
|
||||||
|
#optimizer_names = ['lars_apex', 'lars8bit']
|
||||||
|
optimizer_names = ['adam8bit_blockwise']
|
||||||
|
values = list(product(dim1,dim2, gtype, optimizer_names))
|
||||||
|
names = ['dim1_{0}_dim2_{1}_gtype_{2}_optim_{3}'.format(*vals) for vals in values]
|
||||||
|
@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names)
|
||||||
|
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
|
||||||
|
if dim1 == 1 and dim2 == 1: return
|
||||||
|
p1 = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.1
|
||||||
|
|
||||||
|
|
||||||
|
bnb_optimizer = str2optimizers[optim_name][1]([p1])
|
||||||
|
|
||||||
|
g = torch.randn(dim1,dim2, device='cuda', dtype=gtype)*0.01
|
||||||
|
p1.grad = g
|
||||||
|
for i in range(5000):
|
||||||
|
if i == 500:
|
||||||
|
# 100 iterations for burn-in
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
t0 = time.time()
|
||||||
|
|
||||||
|
bnb_optimizer.step()
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
s = time.time()-t0
|
||||||
|
print('')
|
||||||
|
params = 4500*4096*4096
|
||||||
|
print(optim_name, gtype, s/params)
|
||||||
|
#assert s < 3.9
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user