From 408dc6cf9260dbfa382c06c7518b1bbb910dcc0a Mon Sep 17 00:00:00 2001 From: Tim Dettmers Date: Sun, 7 May 2023 16:38:55 -0700 Subject: [PATCH] 4-bit beta initial. --- .buckconfig | 0 .gitignore | 135 + CHANGELOG.md | 230 ++ CODE_OF_CONDUCT.md | 80 + CONTRIBUTING.md | 31 + LICENSE | 21 + Makefile | 145 + NOTICE.md | 3 + README.md | 191 + benchmarking/switchback/README.md | 4 + benchmarking/switchback/info_a100_py2.jsonl | 60 + .../switchback/make_plot_with_jsonl.py | 138 + benchmarking/switchback/plot_with_info.pdf | Bin 0 -> 34876 bytes benchmarking/switchback/speed_benchmark.py | 102 + bitsandbytes/__init__.py | 27 + bitsandbytes/__main__.py | 154 + bitsandbytes/autograd/__init__.py | 1 + bitsandbytes/autograd/_functions.py | 564 +++ bitsandbytes/cextension.py | 42 + bitsandbytes/cuda_setup/__init__.py | 0 bitsandbytes/cuda_setup/env_vars.py | 52 + bitsandbytes/cuda_setup/main.py | 427 ++ bitsandbytes/functional.py | 2464 +++++++++++ bitsandbytes/nn/__init__.py | 6 + bitsandbytes/nn/modules.py | 464 +++ bitsandbytes/nn/triton_based_modules.py | 258 ++ bitsandbytes/optim/__init__.py | 16 + bitsandbytes/optim/adagrad.py | 132 + bitsandbytes/optim/adam.py | 273 ++ bitsandbytes/optim/adamw.py | 39 + bitsandbytes/optim/lamb.py | 105 + bitsandbytes/optim/lars.py | 210 + bitsandbytes/optim/lion.py | 87 + bitsandbytes/optim/optimizer.py | 724 ++++ bitsandbytes/optim/rmsprop.py | 115 + bitsandbytes/optim/sgd.py | 99 + bitsandbytes/research/__init__.py | 6 + bitsandbytes/research/autograd/__init__.py | 0 bitsandbytes/research/autograd/_functions.py | 411 ++ bitsandbytes/research/nn/__init__.py | 1 + bitsandbytes/research/nn/modules.py | 64 + bitsandbytes/triton/__init__.py | 0 bitsandbytes/triton/dequantize_rowwise.py | 64 + .../triton/int8_matmul_mixed_dequanitze.py | 163 + .../triton/int8_matmul_rowwise_dequantize.py | 164 + .../quantize_columnwise_and_transpose.py | 74 + bitsandbytes/triton/quantize_global.py | 107 + bitsandbytes/triton/quantize_rowwise.py | 68 + bitsandbytes/triton/triton_utils.py | 4 + bitsandbytes/utils.py | 199 + check_bnb_install.py | 20 + compile_from_source.md | 35 + csrc/common.cpp | 39 + csrc/common.h | 25 + csrc/cpu_ops.cpp | 73 + csrc/cpu_ops.h | 10 + csrc/kernels.cu | 3605 +++++++++++++++++ csrc/kernels.cuh | 130 + csrc/ops.cu | 846 ++++ csrc/ops.cuh | 206 + csrc/pythonInterface.c | 370 ++ cuda_install.sh | 89 + deploy.sh | 265 ++ environment.yml | 15 + errors_and_solutions.md | 21 + examples/int8_inference_huggingface.py | 27 + howto_config_override.md | 40 + include/AAlloc.h | 86 + include/Algo-Direct-Common.h | 341 ++ include/Algo-Direct2.h | 305 ++ include/AlgoXCodes.h | 23 + include/BinAlgo.h | 77 + include/BinSearch.h | 11 + include/Portable.h | 151 + include/SIMD.h | 562 +++ include/Type.h | 221 + pyproject.toml | 6 + requirements.txt | 2 + setup.py | 36 + tests/test_autograd.py | 627 +++ tests/test_cuda_setup_evaluator.py | 40 + tests/test_functional.py | 2514 ++++++++++++ tests/test_linear8bitlt.py | 143 + tests/test_modules.py | 618 +++ tests/test_optim.py | 562 +++ tests/test_triton.py | 59 + 86 files changed, 20924 insertions(+) create mode 100644 .buckconfig create mode 100644 .gitignore create mode 100644 CHANGELOG.md create mode 100644 CODE_OF_CONDUCT.md create mode 100644 CONTRIBUTING.md create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 NOTICE.md create mode 100644 README.md create mode 100644 benchmarking/switchback/README.md create mode 100644 benchmarking/switchback/info_a100_py2.jsonl create mode 100644 benchmarking/switchback/make_plot_with_jsonl.py create mode 100644 benchmarking/switchback/plot_with_info.pdf create mode 100644 benchmarking/switchback/speed_benchmark.py create mode 100644 bitsandbytes/__init__.py create mode 100644 bitsandbytes/__main__.py create mode 100644 bitsandbytes/autograd/__init__.py create mode 100644 bitsandbytes/autograd/_functions.py create mode 100644 bitsandbytes/cextension.py create mode 100644 bitsandbytes/cuda_setup/__init__.py create mode 100644 bitsandbytes/cuda_setup/env_vars.py create mode 100644 bitsandbytes/cuda_setup/main.py create mode 100644 bitsandbytes/functional.py create mode 100644 bitsandbytes/nn/__init__.py create mode 100644 bitsandbytes/nn/modules.py create mode 100644 bitsandbytes/nn/triton_based_modules.py create mode 100644 bitsandbytes/optim/__init__.py create mode 100644 bitsandbytes/optim/adagrad.py create mode 100644 bitsandbytes/optim/adam.py create mode 100644 bitsandbytes/optim/adamw.py create mode 100644 bitsandbytes/optim/lamb.py create mode 100644 bitsandbytes/optim/lars.py create mode 100644 bitsandbytes/optim/lion.py create mode 100644 bitsandbytes/optim/optimizer.py create mode 100644 bitsandbytes/optim/rmsprop.py create mode 100644 bitsandbytes/optim/sgd.py create mode 100644 bitsandbytes/research/__init__.py create mode 100644 bitsandbytes/research/autograd/__init__.py create mode 100644 bitsandbytes/research/autograd/_functions.py create mode 100644 bitsandbytes/research/nn/__init__.py create mode 100644 bitsandbytes/research/nn/modules.py create mode 100644 bitsandbytes/triton/__init__.py create mode 100644 bitsandbytes/triton/dequantize_rowwise.py create mode 100644 bitsandbytes/triton/int8_matmul_mixed_dequanitze.py create mode 100644 bitsandbytes/triton/int8_matmul_rowwise_dequantize.py create mode 100644 bitsandbytes/triton/quantize_columnwise_and_transpose.py create mode 100644 bitsandbytes/triton/quantize_global.py create mode 100644 bitsandbytes/triton/quantize_rowwise.py create mode 100644 bitsandbytes/triton/triton_utils.py create mode 100644 bitsandbytes/utils.py create mode 100644 check_bnb_install.py create mode 100644 compile_from_source.md create mode 100644 csrc/common.cpp create mode 100644 csrc/common.h create mode 100644 csrc/cpu_ops.cpp create mode 100644 csrc/cpu_ops.h create mode 100644 csrc/kernels.cu create mode 100644 csrc/kernels.cuh create mode 100644 csrc/ops.cu create mode 100644 csrc/ops.cuh create mode 100644 csrc/pythonInterface.c create mode 100644 cuda_install.sh create mode 100644 deploy.sh create mode 100644 environment.yml create mode 100644 errors_and_solutions.md create mode 100644 examples/int8_inference_huggingface.py create mode 100644 howto_config_override.md create mode 100644 include/AAlloc.h create mode 100644 include/Algo-Direct-Common.h create mode 100644 include/Algo-Direct2.h create mode 100644 include/AlgoXCodes.h create mode 100644 include/BinAlgo.h create mode 100644 include/BinSearch.h create mode 100644 include/Portable.h create mode 100644 include/SIMD.h create mode 100644 include/Type.h create mode 100644 pyproject.toml create mode 100644 requirements.txt create mode 100644 setup.py create mode 100644 tests/test_autograd.py create mode 100644 tests/test_cuda_setup_evaluator.py create mode 100644 tests/test_functional.py create mode 100644 tests/test_linear8bitlt.py create mode 100644 tests/test_modules.py create mode 100644 tests/test_optim.py create mode 100644 tests/test_triton.py diff --git a/.buckconfig b/.buckconfig new file mode 100644 index 000000000..e69de29bb diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..f8ebf71af --- /dev/null +++ b/.gitignore @@ -0,0 +1,135 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# vim +*.swp + +dependencies +cuda_build diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..2de70d371 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,230 @@ +### 0.0.21 +- Ampere, RTX 30 series GPUs now compatible with the library. + +### 0.0.22: + +- Fixed an error where a `reset_parameters()` call on the `StableEmbedding` would lead to an error in older PyTorch versions (from 1.7.0). + +### 0.0.23: + +Bugs: + - Unified quantization API: each quantization function now returns `Q, S` where `Q` is the quantized tensor and `S` the quantization state which may hold absolute max values, a quantization map or more. For dequantization all functions now accept the inputs `Q, S` so that `Q` is dequantized with the quantization state `S`. + - Fixed an issue where the CUDA 11.1 binary was not compiled with the right headers + +API changes: + - Block-wise quantization for optimizers now enabled by default + +Features: + - Block-wise quantization routines now support CPU Tensors. + + +### 0.0.24: + +- Fixed a bug where a float/half conversion led to a compilation error for CUDA 11.1 on Turning GPUs. +- removed Apex dependency for bnb LAMB + +### 0.0.25: + +Features: + - Added `skip_zeros` for block-wise and 32-bit optimizers. This ensures correct updates for sparse gradients and sparse models. + - Added support for Kepler GPUs. (#4) + - Added Analysis Adam to track 8-bit vs 32-bit quantization errors over time. + - Make compilation more user friendly. + +Bug fixes: + - fixed "undefined symbol: \_\_fatbinwrap_38" error for P100 GPUs on CUDA 10.1 (#5) + +Docs: + - Added docs with instructions to compile from source. + + +### 0.26.0: + +Features: + - Added Adagrad (without grad clipping) as 32-bit and 8-bit block-wise optimizer. + - Added AdamW (copy of Adam with weight decay init 1e-2). #10 + - Introduced ModuleConfig overrides which can be seamlessly be used at initialization time of a module. + - Added `bnb.nn.Embedding` layer which runs at 32-bit but without the layernorm. This works well if you need to fine-tune pretrained models that do not have a embedding layer norm. #19 + +Bug fixes: + - Fixed a bug where weight decay was incorrectly applied to 32-bit Adam. #13 + - Fixed an unsafe use of eval. #8 + - Fixed a bug where the StableEmbedding layer 32-bit optimizer override would not work without registering the whole model first (`bnb.optim.GlobalOptimManager.get_instance().register_parameters(model.parameters())`). #13 #15 + +Docs: + - Added instructions how to solve "\_\_fatbinwrap_" errors. + + +### 0.30.0 + +#### 8-bit Inference Update + +Features: + - Added 8-bit matrix multiplication form cuBLAS, and cuBLASLt as well as multiple GEMM kernels (GEMM, GEMMEx, GEMMLt) + - Added 8-bit Linear layers with 8-bit Params that perform memory efficient inference with an option for 8-bit mixed precision matrix decomposition for inference without performance degradation + - Added quantization methods for "fake" quantization as well as optimized kernels vector-wise quantization and equalization as well as optimized cuBLASLt transformations + - CPU only build now available (Thank you, @mryab) + +Deprecated: + - Pre-compiled release for CUDA 9.2, 10.0, 10.2 no longer available + +### 0.31.0 + +#### 8-bit Inference and Packaging Update + +Features: + - added direct outlier extraction. This enables outlier extraction without fp16 weights without performance degradation. + - Added automatic CUDA SETUP procedure and packaging all binaries into a single bitsandbytes package. + +### 0.32.0 + +#### 8-bit Inference Performance Enhancements + +We added performance enhancements for small models. This makes small models about 2x faster for LLM.int8() inference. + +Features: + - Int32 dequantization now supports fused biases. + - Linear8bitLt now uses a fused bias implementation. + - Change `.data.storage().data_ptr()` to `.data.data_ptr()` to enhance inference performance. + +Bug fixes: + - Now throws and error if LLM.int8() is used on a GPU that is not supported. + - Enhances error messaging if CUDA SETUP fails. + + +### 0.33.0 + +#### Various bug fixes + +Features: + - CPU quantization now supports a variable `blocksize` variable to enhance quantization speed or precision. + +Bug fixes: + - fixed an issue in CPU quantization where tensors with more than 2^31 elements would fail 19a7adca7a6c9bf7061a384d7e9d9b13676a1a88 + - fixed a bug where cpu binaries would fail if no GPU would be detected eab4d8232d558f2e6bd7f7cc3d00e2e6e94f4e80 + - fixed an issue where cpu binaries cause additional stdout messages 92a3363096e10ad6a5c4e944af898bd1186d806a + - fixed an import of bnb.utils 2e630b55f51d454f3bd723dffda68a07ef93190c + +We thank @mryab, @mbrukman, @chessgecko, @dbaranchuk for pull request with bug fixes and new features. + + +### 0.34.0 + +#### Bug fixes and memory efficient backprop + +Features: + - Linear8bitLt layer now supports `memory_efficient_backward=True` which enables backprop of gradients through frozen weights. + +Bug fixes: + - fixed an issue where too many threads were created in blockwise quantization on the CPU for large tensors + + +### 0.35.0 + +#### CUDA 11.8 support and bug fixes + +Features: + - CUDA 11.8 support added and binaries added to the PyPI release. + +Bug fixes: + - fixed a bug where too long directory names would crash the CUDA SETUP #35 (thank you @tomaarsen) + - fixed a bug where CPU installations on Colab would run into an error #34 (thank you @tomaarsen) + - fixed an issue where the default CUDA version with fast-DreamBooth was not supported #52 + +### 0.35.1 + +Features: + - Added CUDA instruction generator to fix some installations. + +Bug fixes: + - Fixed a problem where warning messages would be displayed even though everything worked correctly. + +### 0.35.2 + +Bug fixes: + - Fixed a bug where the CUDA setup failed due to a wrong function call. + +### 0.35.3 + +Bug fixes: + - Fixed a bug in the CUDA Setup which led to an incomprehensible error if no GPU was detected. + +### 0.35.4 + +Bug fixes: + - Fixed a bug in the CUDA Setup failed with the cuda runtime was found, but not the cuda library. + - Fixed a bug where not finding the cuda runtime led to an incomprehensible error. + + +### 0.36.0 + +#### Improvements, Ada/Hopper support, fake k-bit quantization. + +Features: + - CUDA 11.8 and 12.0 support added + - support for Ada and Hopper GPUs added (compute capability 8.9 and 9.0) + - support for fake k-bit block-wise quantization for Int, Float, quantile quantization, and dynamic exponent data types added + - Added CUDA instruction generator to fix some installations. + - Added additional block sizes for quantization {64, 128, 256, 512, 1024} + - Added SRAM Quantile algorithm to quickly estimate less than 256 quantiles + - Added option to suppress the bitsandbytes welcome message (@Cyberes) + +Regression: + - Compute capability 3.0 removed: GTX 600s and 700s series is no longer supported (except GTX 780 and GTX 780 Ti) + +Bug fixes: + - fixed a bug where too long directory names would crash the CUDA SETUP #35 (@tomaarsen) + - fixed a bug where CPU installations on Colab would run into an error #34 (@tomaarsen) + - fixed an issue where the default CUDA version with fast-DreamBooth was not supported #52 + - fixed a bug where the CUDA setup failed due to a wrong function call. + - fixed a bug in the CUDA Setup which led to an incomprehensible error if no GPU was detected. + - fixed a bug in the CUDA Setup failed with the cuda runtime was found, but not the cuda library. + - fixed a bug where not finding the cuda runtime led to an incomprehensible error. + - fixed a bug where with missing CUDA the default was an error instead of the loading the CPU library + - fixed a bug where the CC version of the GPU was not detected appropriately (@BlackHC) + - fixed a bug in CPU quantization which lead to errors when the input buffer exceeded 2^31 elements + +Improvements: + - multiple improvements in formatting, removal of unused imports, and slight performance improvements (@tomaarsen) + - StableEmbedding layer now has device and dtype parameters to make it 1:1 replaceable with regular Embedding layers (@lostmsu) + - runtime performance of block-wise quantization slightly improved + - added error message for the case multiple libcudart.so are installed and bitsandbytes picks the wrong one + + +### 0.37.0 + +#### Int8 Matmul + backward support for all GPUs + +Features: + - Int8 MatmulLt now supports backward through inversion of the ColTuring/ColAmpere format. Slow, but memory efficient. Big thanks to @borzunov + - Int8 now supported on all GPUs. On devices with compute capability < 7.5, the Int weights are cast to 16/32-bit for the matrix multiplication. Contributed by @borzunov + +Improvements: + - Improved logging for the CUDA detection mechanism. + +### 0.38.0 + +#### 8-bit Lion, Load/Store 8-bit Models directly from/to HF Hub + +Features: + - Support for 32 and 8-bit Lion has been added. Thank you @lucidrains + - Support for serialization of Linear8bitLt layers (LLM.int8()). This allows to store and load 8-bit weights directly from the HuggingFace Hub. Thank you @myrab + - New bug report features `python -m bitsandbytes` now gives extensive debugging details to debug CUDA setup failures. + +Bug fixes: + - Fixed a bug where some bitsandbytes methods failed in a model-parallel setup on multiple GPUs. Thank you @tonylins + - Fixed a bug where cudart.so libraries could not be found in newer PyTorch releases. + +Improvements: + - Improved the CUDA Setup procedure by doing a more extensive search for CUDA libraries + +Deprecated: + - Devices with compute capability 3.0 (GTX 700s, K10) and 3.2 (Tegra K1, Jetson TK1) are now deprecated and support will be removed in 0.39.0. + - Support for CUDA 10.0 and 10.2 will be removed in bitsandbytes 0.39.0 + + +### 0.38.1 + +Features: + - Added Int8 SwitchBack layers + - Added Fake FP8 layers for research purposes (available under `bnb.research.nn. ...`) diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 000000000..08b500a22 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..0fae0ace5 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to bitsandbytes +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Facebook's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to bitsandbytes, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..b96dcb048 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Facebook, Inc. and its affiliates. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 000000000..ea6ee87d5 --- /dev/null +++ b/Makefile @@ -0,0 +1,145 @@ +MKFILE_PATH := $(abspath $(lastword $(MAKEFILE_LIST))) +ROOT_DIR := $(patsubst %/,%,$(dir $(MKFILE_PATH))) + +GPP:= /usr/bin/g++ +#GPP:= /sw/gcc/11.2.0/bin/g++ +ifeq ($(CUDA_HOME),) + CUDA_HOME:= $(shell which nvcc | rev | cut -d'/' -f3- | rev) +endif + +ifndef CUDA_VERSION +$(warning WARNING: CUDA_VERSION not set. Call make with CUDA string, for example: make cuda11x CUDA_VERSION=115 or make cpuonly CUDA_VERSION=CPU) +CUDA_VERSION:= +endif + + + +NVCC := $(CUDA_HOME)/bin/nvcc + +########################################### + +CSRC := $(ROOT_DIR)/csrc +BUILD_DIR:= $(ROOT_DIR)/build + +FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu +FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c + +INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include +INCLUDE_10x := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/dependencies/cub -I $(ROOT_DIR)/include +LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcurand -lcusparse -L $(CONDA_PREFIX)/lib + +# NVIDIA NVCC compilation flags +COMPUTE_CAPABILITY += -gencode arch=compute_50,code=sm_50 # Maxwell +COMPUTE_CAPABILITY += -gencode arch=compute_52,code=sm_52 # Maxwell +COMPUTE_CAPABILITY += -gencode arch=compute_60,code=sm_60 # Pascal +COMPUTE_CAPABILITY += -gencode arch=compute_61,code=sm_61 # Pascal +COMPUTE_CAPABILITY += -gencode arch=compute_70,code=sm_70 # Volta +COMPUTE_CAPABILITY += -gencode arch=compute_72,code=sm_72 # Volta + +CC_KEPLER := -gencode arch=compute_35,code=sm_35 # Kepler +CC_KEPLER += -gencode arch=compute_37,code=sm_37 # Kepler + +# Later versions of CUDA support the new architectures +CC_CUDA10x += -gencode arch=compute_75,code=sm_75 + +CC_CUDA110 := -gencode arch=compute_75,code=sm_75 +CC_CUDA110 += -gencode arch=compute_80,code=sm_80 + +CC_CUDA11x := -gencode arch=compute_75,code=sm_75 +CC_CUDA11x += -gencode arch=compute_80,code=sm_80 +CC_CUDA11x += -gencode arch=compute_86,code=sm_86 + + +CC_cublasLt110 := -gencode arch=compute_75,code=sm_75 +CC_cublasLt110 += -gencode arch=compute_80,code=sm_80 + +CC_cublasLt111 := -gencode arch=compute_75,code=sm_75 +#CC_cublasLt111 += -gencode arch=compute_80,code=sm_80 +#CC_cublasLt111 += -gencode arch=compute_86,code=sm_86 + +CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89 +CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90 + + +all: $(BUILD_DIR) env + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) + +cuda92: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA92) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) + +cuda10x_nomatmul: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE_10x) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA10x) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) + +cuda110_nomatmul: $(BUILD_DIR) env + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA110) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) + +cuda11x_nomatmul: $(BUILD_DIR) env + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_KEPLER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_KEPLER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) + +cuda12x_nomatmul: $(BUILD_DIR) env + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) -D NO_CUBLASLT + $(NVCC) $(COMPUTE_CAPABILITY) $(CC_CUDA11x) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION)_nocublaslt.so $(LIB) + +cuda110: $(BUILD_DIR) env + $(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_cublasLt110) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) + +cuda11x: $(BUILD_DIR) env + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_cublasLt111) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) + +cuda12x: $(BUILD_DIR) env + $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR) + $(NVCC) $(CC_cublasLt111) $(CC_ADA_HOPPER) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o + $(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB) + +cpuonly: $(BUILD_DIR) env + $(GPP) -std=c++14 -shared -fPIC -I $(ROOT_DIR)/csrc -I $(ROOT_DIR)/include $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cpu.so + +env: + @echo "ENVIRONMENT" + @echo "============================" + @echo "CUDA_VERSION: $(CUDA_VERSION)" + @echo "============================" + @echo "NVCC path: $(NVCC)" + @echo "GPP path: $(GPP) VERSION: `$(GPP) --version | head -n 1`" + @echo "CUDA_HOME: $(CUDA_HOME)" + @echo "CONDA_PREFIX: $(CONDA_PREFIX)" + @echo "PATH: $(PATH)" + @echo "LD_LIBRARY_PATH: $(LD_LIBRARY_PATH)" + @echo "============================" + +cutlass: + if [ ! -d "$(ROOT_DIR)/dependencies/cutlass" ]; then \ + git clone https://github.com/NVIDIA/cutlass.git $(ROOT_DIR)/dependencies/cutlass; \ + fi \ + +$(BUILD_DIR): + mkdir -p build + mkdir -p dependencies + +$(ROOT_DIR)/dependencies/cub: + git clone https://github.com/NVlabs/cub $(ROOT_DIR)/dependencies/cub + cd dependencies/cub; git checkout 1.11.0 + +clean: + rm build/* + +cleaneggs: + rm -rf *.egg* + +cleanlibs: + rm ./bitsandbytes/libbitsandbytes*.so diff --git a/NOTICE.md b/NOTICE.md new file mode 100644 index 000000000..660658b05 --- /dev/null +++ b/NOTICE.md @@ -0,0 +1,3 @@ +The majority of bitsandbytes is licensed under MIT, however portions of the project are available under separate license terms: Pytorch is licensed under the BSD license. + +We thank Fabio Cannizzo for this work on FastBinarySearch which is included in this project. diff --git a/README.md b/README.md new file mode 100644 index 000000000..727a86cb5 --- /dev/null +++ b/README.md @@ -0,0 +1,191 @@ +# bitsandbytes + +The bitsandbytes is a lightweight wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and quantization functions. + + + +Resources: +- [8-bit Optimizer Paper](https://arxiv.org/abs/2110.02861) -- [Video](https://www.youtube.com/watch?v=IxrlHAJtqKE) -- [Docs](https://bitsandbytes.readthedocs.io/en/latest/) + +- [LLM.int8() Paper](https://arxiv.org/abs/2208.07339) -- [LLM.int8() Software Blog Post](https://huggingface.co/blog/hf-bitsandbytes-integration) -- [LLM.int8() Emergent Features Blog Post](https://timdettmers.com/2022/08/17/llm-int8-and-emergent-features/) + +## TL;DR +**Requirements** +Python >=3.8. Linux distribution (Ubuntu, MacOS, etc.) + CUDA > 10.0. + +(Deprecated: CUDA 10.0 is deprecated and only CUDA >= 11.0) will be supported with release 0.39.0) + +**Installation**: + +``pip install bitsandbytes`` + +In some cases it can happen that you need to compile from source. If this happens please consider submitting a bug report with `python -m bitsandbytes` information. What now follows is some short instructions which might work out of the box if `nvcc` is installed. If these do not work see further below. + +Compilation quickstart: +```bash +git clone https://github.com/timdettmers/bitsandbytes.git +cd bitsandbytes + +# CUDA_VERSIONS in {110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 120} +# make argument in {cuda110, cuda11x, cuda12x} +# if you do not know what CUDA you have, try looking at the output of: python -m bitsandbytes +CUDA_VERSION=117 make cuda11x +python setup.py install +``` + +**Using Int8 inference with HuggingFace Transformers** + +```python +from transformers import AutoModelForCausalLM +model = AutoModelForCausalLM.from_pretrained( + 'decapoda-research/llama-7b-hf, + device_map='auto', + load_in_8bit=True, + max_memory=f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB') +``` + +A more detailed example, can be found in [examples/int8_inference_huggingface.py](examples/int8_inference_huggingface.py). + +**Using 8-bit optimizer**: +1. Comment out optimizer: ``#torch.optim.Adam(....)`` +2. Add 8-bit optimizer of your choice ``bnb.optim.Adam8bit(....)`` (arguments stay the same) +3. Replace embedding layer if necessary: ``torch.nn.Embedding(..) -> bnb.nn.Embedding(..)`` + + +**Using 8-bit Inference**: +1. Comment out torch.nn.Linear: ``#linear = torch.nn.Linear(...)`` +2. Add bnb 8-bit linear light module: ``linear = bnb.nn.Linear8bitLt(...)`` (base arguments stay the same) +3. There are two modes: + - Mixed 8-bit training with 16-bit main weights. Pass the argument ``has_fp16_weights=True`` (default) + - Int8 inference. Pass the argument ``has_fp16_weights=False`` +4. To use the full LLM.int8() method, use the ``threshold=k`` argument. We recommend ``k=6.0``. +```python +# LLM.int8() +linear = bnb.nn.Linear8bitLt(dim1, dim2, bias=True, has_fp16_weights=False, threshold=6.0) +# inputs need to be fp16 +out = linear(x.to(torch.float16)) +``` + + +## Features +- 8-bit Matrix multiplication with mixed precision decomposition +- LLM.int8() inference +- 8-bit Optimizers: Adam, AdamW, RMSProp, LARS, LAMB, Lion (saves 75% memory) +- Stable Embedding Layer: Improved stability through better initialization, and normalization +- 8-bit quantization: Quantile, Linear, and Dynamic quantization +- Fast quantile estimation: Up to 100x faster than other algorithms + +## Requirements & Installation + +Requirements: anaconda, cudatoolkit, pytorch + +Hardware requirements: + - LLM.int8(): NVIDIA Turing (RTX 20xx; T4) or Ampere GPU (RTX 30xx; A4-A100); (a GPU from 2018 or older). + - 8-bit optimizers and quantization: NVIDIA Kepler GPU or newer (>=GTX 78X). + +Supported CUDA versions: 10.2 - 12.0 + +The bitsandbytes library is currently only supported on Linux distributions. Windows is not supported at the moment. + +The requirements can best be fulfilled by installing pytorch via anaconda. You can install PyTorch by following the ["Get Started"](https://pytorch.org/get-started/locally/) instructions on the official website. + +To install run: + +``pip install bitsandbytes`` + +## Using bitsandbytes + +### Using Int8 Matrix Multiplication + +For straight Int8 matrix multiplication with mixed precision decomposition you can use ``bnb.matmul(...)``. To enable mixed precision decomposition, use the threshold parameter: +```python +bnb.matmul(..., threshold=6.0) +``` + +For instructions how to use LLM.int8() inference layers in your own code, see the TL;DR above or for extended instruction see [this blog post](https://github.com/huggingface/transformers). + +### Using the 8-bit Optimizers + +With bitsandbytes 8-bit optimizers can be used by changing a single line of code in your codebase. For NLP models we recommend also to use the StableEmbedding layers (see below) which improves results and helps with stable 8-bit optimization. To get started with 8-bit optimizers, it is sufficient to replace your old optimizer with the 8-bit optimizer in the following way: +```python +import bitsandbytes as bnb + +# adam = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # comment out old optimizer +adam = bnb.optim.Adam8bit(model.parameters(), lr=0.001, betas=(0.9, 0.995)) # add bnb optimizer +adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=8) # equivalent + + +torch.nn.Embedding(...) -> bnb.nn.StableEmbedding(...) # recommended for NLP models +``` + +Note that by default all parameter tensors with less than 4096 elements are kept at 32-bit even if you initialize those parameters with 8-bit optimizers. This is done since such small tensors do not save much memory and often contain highly variable parameters (biases) or parameters that require high precision (batch norm, layer norm). You can change this behavior like so: +``` +# parameter tensors with less than 16384 values are optimized in 32-bit +# it is recommended to use multiplies of 4096 +adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384) +``` + +### Change Bits and other Hyperparameters for Individual Parameters + +If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, you can use the `GlobalOptimManager`. With this, we can also configure specific hyperparameters for particular layers, such as embedding layers. To do that, we need two things: (1) register the parameter while they are still on the CPU, (2) override the config with the new desired hyperparameters (anytime, anywhere). See our [guide](howto_config_override.md) for more details + +### Fairseq Users + +To use the Stable Embedding Layer, override the respective `build_embedding(...)` function of your model. Make sure to also use the `--no-scale-embedding` flag to disable scaling of the word embedding layer (nor replaced with layer norm). You can use the optimizers by replacing the optimizer in the respective file (`adam.py` etc.). + +## Release and Feature History + +For upcoming features and changes and full history see [Patch Notes](CHANGELOG.md). + +## Errors + +1. RuntimeError: CUDA error: no kernel image is available for execution on the device. [Solution](errors_and_solutions.md#No-kernel-image-available) +2. __fatbinwrap_.. [Solution](errors_and_solutions.md#fatbinwrap_) + +## Compile from source +To compile from source, you need an installation of CUDA. If `nvcc` is not installed, you can install the CUDA Toolkit with nvcc through the following commands. + +```bash +wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh +# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH +# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121} +# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True + +# For example, the following installs CUDA 11.8 to ~/local/cuda-11.8 and exports the path to your .bashrc +bash cuda install 118 ~/local 1 +``` + +To use a specific CUDA version just for a single compile run, you can set the variable `CUDA_HOME`, for example the following command compiles `libbitsandbytes_cuda117.so` using compiler flags for cuda11x with the cuda version at `~/local/cuda-11.7`: + +``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x`` + +For more detailed instruction, please follow the [compile_from_source.md](compile_from_source.md) instructions. + +## License + +The majority of bitsandbytes is licensed under MIT, however portions of the project are available under separate license terms: Pytorch is licensed under the BSD license. + +We thank Fabio Cannizzo for his work on [FastBinarySearch](https://github.com/fabiocannizzo/FastBinarySearch) which we use for CPU quantization. + +## How to cite us +If you found this library and found LLM.int8() useful, please consider citing our work: + +```bibtex +@article{dettmers2022llmint8, + title={LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale}, + author={Dettmers, Tim and Lewis, Mike and Belkada, Younes and Zettlemoyer, Luke}, + journal={arXiv preprint arXiv:2208.07339}, + year={2022} +} +``` + +For 8-bit optimizers or quantization routines, please consider citing the following work: + +```bibtex +@article{dettmers2022optimizers, + title={8-bit Optimizers via Block-wise Quantization}, + author={Dettmers, Tim and Lewis, Mike and Shleifer, Sam and Zettlemoyer, Luke}, + journal={9th International Conference on Learning Representations, ICLR}, + year={2022} +} +``` diff --git a/benchmarking/switchback/README.md b/benchmarking/switchback/README.md new file mode 100644 index 000000000..bb33b5bbd --- /dev/null +++ b/benchmarking/switchback/README.md @@ -0,0 +1,4 @@ +Steps: + +1. Run `python speed_benchmark/speed_benchmark.py` which times operations and writes their time to `speed_benchmark/info_a100_py2.jsonl` (change the name of the jsonl to a different name for your profiling). +2. Run `python speed_benchmark/make_plot_with_jsonl.py`, which produces the `speed_benchmark/plot_with_info.pdf`. Again make sure you change the jsonl which is being processed. \ No newline at end of file diff --git a/benchmarking/switchback/info_a100_py2.jsonl b/benchmarking/switchback/info_a100_py2.jsonl new file mode 100644 index 000000000..53cda62cf --- /dev/null +++ b/benchmarking/switchback/info_a100_py2.jsonl @@ -0,0 +1,60 @@ +{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.28139352798461914, "standard_gw": 0.2811811864376068, "standard_gx": 0.30258670449256897, "rowwise_fwd": 0.1994594931602478, "rowwise_bwd": 0.16159191727638245, "global_fwd": 0.19502267241477966, "global_bwd": 0.16080215573310852, "x_quantize_rowwise": 0.03306940197944641, "g_quantize_rowwise": 0.08210167288780212, "w_quantize_rowwise": 0.03385916352272034, "w_quantize_colwise_transpose": 0.08635595440864563, "w_quantize_global": 0.09237229824066162, "w_quantize_global_transpose": 0.10007619857788086, "time_standard": 0.8651614189147949, "time_rowwise": 0.8776187896728516, "time_global": 0.944625586271286} +{"repeat": 64, "batch_size": 8192, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.262625515460968, "standard_gw": 0.2806223928928375, "standard_gx": 0.31118839979171753, "rowwise_fwd": 0.1828707754611969, "rowwise_bwd": 0.21236762404441833, "global_fwd": 0.16665831208229065, "global_bwd": 0.19929558038711548, "x_quantize_rowwise": 0.08227676153182983, "g_quantize_rowwise": 0.03310292959213257, "w_quantize_rowwise": 0.032648444175720215, "w_quantize_colwise_transpose": 0.09015202522277832, "w_quantize_global": 0.0988692045211792, "w_quantize_global_transpose": 0.10057538747787476, "time_standard": 0.8544363081455231, "time_rowwise": 0.9140409529209137, "time_global": 0.96140056848526} +{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 0.5731917917728424, "standard_gw": 0.5709454417228699, "standard_gx": 0.5963630974292755, "rowwise_fwd": 0.37662312388420105, "rowwise_bwd": 0.281747430562973, "global_fwd": 0.36768242716789246, "global_bwd": 0.28043612837791443, "x_quantize_rowwise": 0.046547502279281616, "g_quantize_rowwise": 0.15532970428466797, "w_quantize_rowwise": 0.032436102628707886, "w_quantize_colwise_transpose": 0.08635222911834717, "w_quantize_global": 0.0947415828704834, "w_quantize_global_transpose": 0.10129809379577637, "time_standard": 1.7405003309249878, "time_rowwise": 1.5499815344810486, "time_global": 1.616980880498886} +{"repeat": 64, "batch_size": 16384, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 0.5341619253158569, "standard_gw": 0.5690865218639374, "standard_gx": 0.599835067987442, "rowwise_fwd": 0.3233291208744049, "rowwise_bwd": 0.41359663009643555, "global_fwd": 0.2831108868122101, "global_bwd": 0.37280842661857605, "x_quantize_rowwise": 0.15563145279884338, "g_quantize_rowwise": 0.046741217374801636, "w_quantize_rowwise": 0.03306940197944641, "w_quantize_colwise_transpose": 0.09020790457725525, "w_quantize_global": 0.0925213098526001, "w_quantize_global_transpose": 0.09945780038833618, "time_standard": 1.7030835151672363, "time_rowwise": 1.6316622495651245, "time_global": 1.6193576157093048} +{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 1.2199915945529938, "standard_gw": 1.1069811880588531, "standard_gx": 1.09761580824852, "rowwise_fwd": 0.738043338060379, "rowwise_bwd": 0.5549229681491852, "global_fwd": 0.7219798862934113, "global_bwd": 0.5512163043022156, "x_quantize_rowwise": 0.08748471736907959, "g_quantize_rowwise": 0.3023110330104828, "w_quantize_rowwise": 0.03182142972946167, "w_quantize_colwise_transpose": 0.08632615208625793, "w_quantize_global": 0.09445473551750183, "w_quantize_global_transpose": 0.10032951831817627, "time_standard": 3.424588590860367, "time_rowwise": 2.9078908264636993, "time_global": 2.9647573828697205} +{"repeat": 64, "batch_size": 32768, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 1.1040829122066498, "standard_gw": 1.1221766471862793, "standard_gx": 1.1548101902008057, "rowwise_fwd": 0.581938773393631, "rowwise_bwd": 0.7480122148990631, "global_fwd": 0.5537159740924835, "global_bwd": 0.7232688367366791, "x_quantize_rowwise": 0.30193477869033813, "g_quantize_rowwise": 0.08745118975639343, "w_quantize_rowwise": 0.03374740481376648, "w_quantize_colwise_transpose": 0.09068101644515991, "w_quantize_global": 0.09645149111747742, "w_quantize_global_transpose": 0.10189786553382874, "time_standard": 3.3810697495937347, "time_rowwise": 2.9659420251846313, "time_global": 2.9868967831134796} +{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 2.4533793330192566, "standard_gw": 2.1938569843769073, "standard_gx": 2.179361879825592, "rowwise_fwd": 1.4615543186664581, "rowwise_bwd": 1.0522231459617615, "global_fwd": 1.4288239181041718, "global_bwd": 1.0450035333633423, "x_quantize_rowwise": 0.1691766083240509, "g_quantize_rowwise": 0.5951300263404846, "w_quantize_rowwise": 0.03337860107421875, "w_quantize_colwise_transpose": 0.08653849363327026, "w_quantize_global": 0.0940859317779541, "w_quantize_global_transpose": 0.09976327419281006, "time_standard": 6.826598197221756, "time_rowwise": 5.5918581783771515, "time_global": 5.625840276479721} +{"repeat": 64, "batch_size": 65536, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 2.1698065102100372, "standard_gw": 2.1875128149986267, "standard_gx": 2.2887587547302246, "rowwise_fwd": 1.0762326419353485, "rowwise_bwd": 1.4638006687164307, "global_fwd": 1.0450668632984161, "global_bwd": 1.4308765530586243, "x_quantize_rowwise": 0.5953535437583923, "g_quantize_rowwise": 0.16899779438972473, "w_quantize_rowwise": 0.03240257501602173, "w_quantize_colwise_transpose": 0.09106099605560303, "w_quantize_global": 0.09546056389808655, "w_quantize_global_transpose": 0.09852275252342224, "time_standard": 6.6460780799388885, "time_rowwise": 5.615361034870148, "time_global": 5.621790885925293} +{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 1024, "wm": 4, "switch": false, "standard_fwd": 4.858218133449554, "standard_gw": 4.3631307780742645, "standard_gx": 4.404045641422272, "rowwise_fwd": 2.9063820838928223, "rowwise_bwd": 2.094462513923645, "global_fwd": 2.8426870703697205, "global_bwd": 2.0792782306671143, "x_quantize_rowwise": 0.33241137862205505, "g_quantize_rowwise": 1.1817105114459991, "w_quantize_rowwise": 0.03374367952346802, "w_quantize_colwise_transpose": 0.08633732795715332, "w_quantize_global": 0.09231641888618469, "w_quantize_global_transpose": 0.100012868642807, "time_standard": 13.62539455294609, "time_rowwise": 10.998178273439407, "time_global": 10.991547256708145} +{"repeat": 64, "batch_size": 131072, "dim_out": 1024, "dim_in": 4096, "wm": 4, "switch": true, "standard_fwd": 4.246581345796585, "standard_gw": 4.42587211728096, "standard_gx": 4.581417888402939, "rowwise_fwd": 2.1114833652973175, "rowwise_bwd": 2.9050447046756744, "global_fwd": 2.0806826651096344, "global_bwd": 2.85966694355011, "x_quantize_rowwise": 1.1816024780273438, "g_quantize_rowwise": 0.33330172300338745, "w_quantize_rowwise": 0.033445656299591064, "w_quantize_colwise_transpose": 0.09065866470336914, "w_quantize_global": 0.09239837527275085, "w_quantize_global_transpose": 0.09984523057937622, "time_standard": 13.253871351480484, "time_rowwise": 11.081408709287643, "time_global": 11.073369532823563} +{"repeat": 64, "batch_size": 8192, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 0.4859529435634613, "standard_gw": 0.46338513493537903, "standard_gx": 0.42321905493736267, "rowwise_fwd": 0.2761557698249817, "rowwise_bwd": 0.20775198936462402, "global_fwd": 0.2713911235332489, "global_bwd": 0.20639970898628235, "x_quantize_rowwise": 0.033095479011535645, "g_quantize_rowwise": 0.11894106864929199, "w_quantize_rowwise": 0.03125518560409546, "w_quantize_colwise_transpose": 0.1424551010131836, "w_quantize_global": 0.07288157939910889, "w_quantize_global_transpose": 0.08071959018707275, "time_standard": 1.372557133436203, "time_rowwise": 1.2730397284030914, "time_global": 1.2468136847019196} +{"repeat": 64, "batch_size": 8192, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.3920421004295349, "standard_gw": 0.44424086809158325, "standard_gx": 0.4759356379508972, "rowwise_fwd": 0.23231282830238342, "rowwise_bwd": 0.28430670499801636, "global_fwd": 0.20883232355117798, "global_bwd": 0.2741999924182892, "x_quantize_rowwise": 0.12018159031867981, "g_quantize_rowwise": 0.03195926547050476, "w_quantize_rowwise": 0.026017427444458008, "w_quantize_colwise_transpose": 0.14733895659446716, "w_quantize_global": 0.07734447717666626, "w_quantize_global_transpose": 0.0788569450378418, "time_standard": 1.3122186064720154, "time_rowwise": 1.2863576412200928, "time_global": 1.235615462064743} +{"repeat": 64, "batch_size": 16384, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 1.0111741721630096, "standard_gw": 0.9267590939998627, "standard_gx": 0.8254274725914001, "rowwise_fwd": 0.5434826016426086, "rowwise_bwd": 0.4077926278114319, "global_fwd": 0.5318708717823029, "global_bwd": 0.40537863969802856, "x_quantize_rowwise": 0.059738755226135254, "g_quantize_rowwise": 0.2299174666404724, "w_quantize_rowwise": 0.02545863389968872, "w_quantize_colwise_transpose": 0.14269724488258362, "w_quantize_global": 0.07300823926925659, "w_quantize_global_transpose": 0.07878988981246948, "time_standard": 2.7633607387542725, "time_rowwise": 2.335846424102783, "time_global": 2.305462956428528} +{"repeat": 64, "batch_size": 16384, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 0.8095316588878632, "standard_gw": 0.8607134222984314, "standard_gx": 0.9204968810081482, "rowwise_fwd": 0.4275888204574585, "rowwise_bwd": 0.5485899746417999, "global_fwd": 0.41000545024871826, "global_bwd": 0.5317628383636475, "x_quantize_rowwise": 0.2301819622516632, "g_quantize_rowwise": 0.059254467487335205, "w_quantize_rowwise": 0.02466142177581787, "w_quantize_colwise_transpose": 0.14865398406982422, "w_quantize_global": 0.07582828402519226, "w_quantize_global_transpose": 0.08231401443481445, "time_standard": 2.5907419621944427, "time_rowwise": 2.2996440529823303, "time_global": 2.2500604391098022} +{"repeat": 64, "batch_size": 32768, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 2.0658522844314575, "standard_gw": 1.718364655971527, "standard_gx": 1.6660578548908234, "rowwise_fwd": 1.066897064447403, "rowwise_bwd": 0.8070804178714752, "global_fwd": 1.0473169386386871, "global_bwd": 0.8021742105484009, "x_quantize_rowwise": 0.11274218559265137, "g_quantize_rowwise": 0.4518181085586548, "w_quantize_rowwise": 0.026501715183258057, "w_quantize_colwise_transpose": 0.14259666204452515, "w_quantize_global": 0.07484853267669678, "w_quantize_global_transpose": 0.07976219058036804, "time_standard": 5.450274795293808, "time_rowwise": 4.326000809669495, "time_global": 4.287026822566986} +{"repeat": 64, "batch_size": 32768, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 2.7549192309379578, "standard_gw": 1.6954988241195679, "standard_gx": 1.8179528415203094, "rowwise_fwd": 0.8649080991744995, "rowwise_bwd": 1.0746456682682037, "global_fwd": 0.8023083209991455, "global_bwd": 1.0471977293491364, "x_quantize_rowwise": 0.45225024223327637, "g_quantize_rowwise": 0.11286512017250061, "w_quantize_rowwise": 0.0252649188041687, "w_quantize_colwise_transpose": 0.14732033014297485, "w_quantize_global": 0.07537379860877991, "w_quantize_global_transpose": 0.0807642936706543, "time_standard": 6.268370896577835, "time_rowwise": 4.372753202915192, "time_global": 4.266258329153061} +{"repeat": 64, "batch_size": 65536, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 4.098430275917053, "standard_gw": 3.3501461148262024, "standard_gx": 5.560480058193207, "rowwise_fwd": 2.112947404384613, "rowwise_bwd": 1.605246216058731, "global_fwd": 2.0697638392448425, "global_bwd": 1.5953518450260162, "x_quantize_rowwise": 0.21921470761299133, "g_quantize_rowwise": 0.8956789970397949, "w_quantize_rowwise": 0.02710893750190735, "w_quantize_colwise_transpose": 0.14268234372138977, "w_quantize_global": 0.07259473204612732, "w_quantize_global_transpose": 0.07899105548858643, "time_standard": 13.009056448936462, "time_rowwise": 8.35302472114563, "time_global": 8.281741291284561} +{"repeat": 64, "batch_size": 65536, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 5.586959421634674, "standard_gw": 3.358360379934311, "standard_gx": 3.6434978246688843, "rowwise_fwd": 1.6269534826278687, "rowwise_bwd": 2.128206193447113, "global_fwd": 1.5950687229633331, "global_bwd": 2.0831897854804993, "x_quantize_rowwise": 0.8954145014286041, "g_quantize_rowwise": 0.21914392709732056, "w_quantize_rowwise": 0.026203691959381104, "w_quantize_colwise_transpose": 0.14658644795417786, "w_quantize_global": 0.07478520274162292, "w_quantize_global_transpose": 0.07964670658111572, "time_standard": 12.58881762623787, "time_rowwise": 8.400868624448776, "time_global": 8.305609226226807} +{"repeat": 64, "batch_size": 131072, "dim_out": 5120, "dim_in": 1280, "wm": 4, "switch": false, "standard_fwd": 8.229725062847137, "standard_gw": 6.791356950998306, "standard_gx": 6.806455552577972, "rowwise_fwd": 4.252471029758453, "rowwise_bwd": 3.2062679529190063, "global_fwd": 4.175614565610886, "global_bwd": 3.1837262213230133, "x_quantize_rowwise": 0.4321373999118805, "g_quantize_rowwise": 1.787092536687851, "w_quantize_rowwise": 0.0270158052444458, "w_quantize_colwise_transpose": 0.1424252986907959, "w_quantize_global": 0.07348507642745972, "w_quantize_global_transpose": 0.07829815149307251, "time_standard": 21.827537566423416, "time_rowwise": 16.63876697421074, "time_global": 16.52171090245247} +{"repeat": 64, "batch_size": 131072, "dim_out": 1280, "dim_in": 5120, "wm": 4, "switch": true, "standard_fwd": 11.279478669166565, "standard_gw": 6.7345499992370605, "standard_gx": 7.206875830888748, "rowwise_fwd": 3.209315240383148, "rowwise_bwd": 4.256397485733032, "global_fwd": 3.180190920829773, "global_bwd": 4.177983850240707, "x_quantize_rowwise": 1.7836056649684906, "g_quantize_rowwise": 0.4321075975894928, "w_quantize_rowwise": 0.03205239772796631, "w_quantize_colwise_transpose": 0.14675036072731018, "w_quantize_global": 0.09316205978393555, "w_quantize_global_transpose": 0.10086596012115479, "time_standard": 25.220904499292374, "time_rowwise": 16.5947787463665, "time_global": 16.502466052770615} +{"repeat": 64, "batch_size": 8192, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 0.5776733160018921, "standard_gw": 0.5300231277942657, "standard_gx": 0.6005913019180298, "rowwise_fwd": 0.33330172300338745, "rowwise_bwd": 0.2957060933113098, "global_fwd": 0.32876431941986084, "global_bwd": 0.29108673334121704, "x_quantize_rowwise": 0.03466755151748657, "g_quantize_rowwise": 0.12264400720596313, "w_quantize_rowwise": 0.033874064683914185, "w_quantize_colwise_transpose": 0.1775398850440979, "w_quantize_global": 0.09503215551376343, "w_quantize_global_transpose": 0.10617449879646301, "time_standard": 1.7082877457141876, "time_rowwise": 1.5277564525604248, "time_global": 1.5083923935890198} +{"repeat": 64, "batch_size": 8192, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 0.5164109170436859, "standard_gw": 0.5367249250411987, "standard_gx": 0.5876161158084869, "rowwise_fwd": 0.3132447600364685, "rowwise_bwd": 0.3396235406398773, "global_fwd": 0.2943649888038635, "global_bwd": 0.33209100365638733, "x_quantize_rowwise": 0.12357160449028015, "g_quantize_rowwise": 0.035997480154037476, "w_quantize_rowwise": 0.03213062882423401, "w_quantize_colwise_transpose": 0.17676874995231628, "w_quantize_global": 0.09861215949058533, "w_quantize_global_transpose": 0.0998862087726593, "time_standard": 1.6407519578933716, "time_rowwise": 1.5580616891384125, "time_global": 1.5212483704090118} +{"repeat": 64, "batch_size": 16384, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 1.2096501886844635, "standard_gw": 1.0663382709026337, "standard_gx": 1.0961703956127167, "rowwise_fwd": 0.6396733224391937, "rowwise_bwd": 0.5173943936824799, "global_fwd": 0.6296299397945404, "global_bwd": 0.5130060017108917, "x_quantize_rowwise": 0.06211921572685242, "g_quantize_rowwise": 0.2361498773097992, "w_quantize_rowwise": 0.03260001540184021, "w_quantize_colwise_transpose": 0.17679482698440552, "w_quantize_global": 0.09361281991004944, "w_quantize_global_transpose": 0.09913742542266846, "time_standard": 3.372158855199814, "time_rowwise": 2.7310699224472046, "time_global": 2.6999935507774353} +{"repeat": 64, "batch_size": 16384, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 1.1065565049648285, "standard_gw": 1.0664314031600952, "standard_gx": 1.1266544461250305, "rowwise_fwd": 0.5352050065994263, "rowwise_bwd": 0.6464086472988129, "global_fwd": 0.513765960931778, "global_bwd": 0.6284862756729126, "x_quantize_rowwise": 0.23620948195457458, "g_quantize_rowwise": 0.062271952629089355, "w_quantize_rowwise": 0.031460076570510864, "w_quantize_colwise_transpose": 0.17675384879112244, "w_quantize_global": 0.09486451745033264, "w_quantize_global_transpose": 0.09898096323013306, "time_standard": 3.2996423542499542, "time_rowwise": 2.7547404170036316, "time_global": 2.7010105550289154} +{"repeat": 64, "batch_size": 32768, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 2.4367496371269226, "standard_gw": 2.0806193351745605, "standard_gx": 2.19624862074852, "rowwise_fwd": 1.2554042041301727, "rowwise_bwd": 1.0227933526039124, "global_fwd": 1.2322552502155304, "global_bwd": 1.0152235627174377, "x_quantize_rowwise": 0.11792033910751343, "g_quantize_rowwise": 0.4639364778995514, "w_quantize_rowwise": 0.03241002559661865, "w_quantize_colwise_transpose": 0.17657503485679626, "w_quantize_global": 0.09655207395553589, "w_quantize_global_transpose": 0.09958073496818542, "time_standard": 6.713617593050003, "time_rowwise": 5.149658769369125, "time_global": 5.106087774038315} +{"repeat": 64, "batch_size": 32768, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 2.1935217082500458, "standard_gw": 2.0055584609508514, "standard_gx": 2.1882541477680206, "rowwise_fwd": 1.0396353900432587, "rowwise_bwd": 1.2542344629764557, "global_fwd": 1.0161921381950378, "global_bwd": 1.233428716659546, "x_quantize_rowwise": 0.4642195999622345, "g_quantize_rowwise": 0.11782720685005188, "w_quantize_rowwise": 0.033117830753326416, "w_quantize_colwise_transpose": 0.17696991562843323, "w_quantize_global": 0.09416043758392334, "w_quantize_global_transpose": 0.10101497173309326, "time_standard": 6.387334316968918, "time_rowwise": 5.091562867164612, "time_global": 5.032401531934738} +{"repeat": 64, "batch_size": 65536, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 4.804681986570358, "standard_gw": 4.763372242450714, "standard_gx": 4.064023494720459, "rowwise_fwd": 2.484843134880066, "rowwise_bwd": 1.9691288471221924, "global_fwd": 2.441786229610443, "global_bwd": 1.9574686884880066, "x_quantize_rowwise": 0.2294592559337616, "g_quantize_rowwise": 0.9196549654006958, "w_quantize_rowwise": 0.0313781201839447, "w_quantize_colwise_transpose": 0.1768544316291809, "w_quantize_global": 0.09644776582717896, "w_quantize_global_transpose": 0.09847059845924377, "time_standard": 13.632077723741531, "time_rowwise": 10.574690997600555, "time_global": 10.506659746170044} +{"repeat": 64, "batch_size": 65536, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 4.0907710790634155, "standard_gw": 3.9793066680431366, "standard_gx": 4.302978515625, "rowwise_fwd": 1.992940902709961, "rowwise_bwd": 2.4996213614940643, "global_fwd": 1.9551962614059448, "global_bwd": 2.457551658153534, "x_quantize_rowwise": 0.9200014173984528, "g_quantize_rowwise": 0.2293996512889862, "w_quantize_rowwise": 0.0313781201839447, "w_quantize_colwise_transpose": 0.17882883548736572, "w_quantize_global": 0.09540095925331116, "w_quantize_global_transpose": 0.09880587458610535, "time_standard": 12.373056262731552, "time_rowwise": 9.831476956605911, "time_global": 9.73566249012947} +{"repeat": 64, "batch_size": 131072, "dim_out": 5632, "dim_in": 1408, "wm": 4, "switch": false, "standard_fwd": 9.655728936195374, "standard_gw": 8.261296898126602, "standard_gx": 8.064884692430496, "rowwise_fwd": 5.007706582546234, "rowwise_bwd": 3.8615092635154724, "global_fwd": 4.920527338981628, "global_bwd": 3.8330331444740295, "x_quantize_rowwise": 0.45276060700416565, "g_quantize_rowwise": 1.8306002020835876, "w_quantize_rowwise": 0.031366944313049316, "w_quantize_colwise_transpose": 0.1766495406627655, "w_quantize_global": 0.09412690997123718, "w_quantize_global_transpose": 0.09780004620552063, "time_standard": 25.981910526752472, "time_rowwise": 19.621890038251877, "time_global": 19.49014514684677} +{"repeat": 64, "batch_size": 131072, "dim_out": 1408, "dim_in": 5632, "wm": 4, "switch": true, "standard_fwd": 8.033104240894318, "standard_gw": 8.2889124751091, "standard_gx": 8.622754365205765, "rowwise_fwd": 3.8747042417526245, "rowwise_bwd": 5.003921687602997, "global_fwd": 3.8315393030643463, "global_bwd": 4.9162134528160095, "x_quantize_rowwise": 1.8304847180843353, "g_quantize_rowwise": 0.4522763192653656, "w_quantize_rowwise": 0.03413110971450806, "w_quantize_colwise_transpose": 0.1771189272403717, "w_quantize_global": 0.09519979357719421, "w_quantize_global_transpose": 0.09930506348609924, "time_standard": 24.944771081209183, "time_rowwise": 19.661549478769302, "time_global": 19.51393112540245} +{"repeat": 64, "batch_size": 8192, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 0.7954612374305725, "standard_gw": 0.7456131279468536, "standard_gx": 0.8799619972705841, "rowwise_fwd": 0.43267011642456055, "rowwise_bwd": 0.34622475504875183, "global_fwd": 0.42615458369255066, "global_bwd": 0.344250351190567, "x_quantize_rowwise": 0.03748014569282532, "g_quantize_rowwise": 0.13304129242897034, "w_quantize_rowwise": 0.03294646739959717, "w_quantize_colwise_transpose": 0.2407953143119812, "w_quantize_global": 0.094633549451828, "w_quantize_global_transpose": 0.10305643081665039, "time_standard": 2.4210363626480103, "time_rowwise": 1.96877121925354, "time_global": 1.8842294812202454} +{"repeat": 64, "batch_size": 8192, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 0.7120333611965179, "standard_gw": 0.7622130215167999, "standard_gx": 0.8262209594249725, "rowwise_fwd": 0.3702230751514435, "rowwise_bwd": 0.4419572651386261, "global_fwd": 0.3479123115539551, "global_bwd": 0.4306286573410034, "x_quantize_rowwise": 0.13308599591255188, "g_quantize_rowwise": 0.037495046854019165, "w_quantize_rowwise": 0.03398209810256958, "w_quantize_colwise_transpose": 0.23782625794410706, "w_quantize_global": 0.09853765368461609, "w_quantize_global_transpose": 0.10247156023979187, "time_standard": 2.3004673421382904, "time_rowwise": 2.016782760620117, "time_global": 1.9123442471027374} +{"repeat": 64, "batch_size": 16384, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 1.6292817890644073, "standard_gw": 1.5109702944755554, "standard_gx": 1.482747495174408, "rowwise_fwd": 0.8386112749576569, "rowwise_bwd": 0.6844550371170044, "global_fwd": 0.8220970630645752, "global_bwd": 0.6802082061767578, "x_quantize_rowwise": 0.06883963942527771, "g_quantize_rowwise": 0.25641173124313354, "w_quantize_rowwise": 0.033054500818252563, "w_quantize_colwise_transpose": 0.24027004837989807, "w_quantize_global": 0.0967271625995636, "w_quantize_global_transpose": 0.102948397397995, "time_standard": 4.622999578714371, "time_rowwise": 3.6326125264167786, "time_global": 3.5382024943828583} +{"repeat": 64, "batch_size": 16384, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 1.4877021312713623, "standard_gw": 1.5015341341495514, "standard_gx": 1.529306173324585, "rowwise_fwd": 0.715944916009903, "rowwise_bwd": 0.8529908955097198, "global_fwd": 0.680088996887207, "global_bwd": 0.8224695920944214, "x_quantize_rowwise": 0.2568177878856659, "g_quantize_rowwise": 0.06864592432975769, "w_quantize_rowwise": 0.03343448042869568, "w_quantize_colwise_transpose": 0.23645907640457153, "w_quantize_global": 0.09399279952049255, "w_quantize_global_transpose": 0.10286271572113037, "time_standard": 4.518542438745499, "time_rowwise": 3.665827214717865, "time_global": 3.5264119505882263} +{"repeat": 64, "batch_size": 32768, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 3.261040896177292, "standard_gw": 2.8816498816013336, "standard_gx": 2.8357282280921936, "rowwise_fwd": 1.6594752669334412, "rowwise_bwd": 1.359265297651291, "global_fwd": 1.6287527978420258, "global_bwd": 1.3503879308700562, "x_quantize_rowwise": 0.13146549463272095, "g_quantize_rowwise": 0.5035959184169769, "w_quantize_rowwise": 0.03438442945480347, "w_quantize_colwise_transpose": 0.24086236953735352, "w_quantize_global": 0.0945068895816803, "w_quantize_global_transpose": 0.10332837700843811, "time_standard": 8.978419005870819, "time_rowwise": 6.8106986582279205, "time_global": 6.693687289953232} +{"repeat": 64, "batch_size": 32768, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 2.848360687494278, "standard_gw": 2.8955675661563873, "standard_gx": 3.0499882996082306, "rowwise_fwd": 1.3900883495807648, "rowwise_bwd": 1.6595833003520966, "global_fwd": 1.3514049351215363, "global_bwd": 1.629263162612915, "x_quantize_rowwise": 0.5036592483520508, "g_quantize_rowwise": 0.13118237257003784, "w_quantize_rowwise": 0.03438442945480347, "w_quantize_colwise_transpose": 0.23709610104560852, "w_quantize_global": 0.0951625406742096, "w_quantize_global_transpose": 0.10216236114501953, "time_standard": 8.793916553258896, "time_rowwise": 6.851561367511749, "time_global": 6.708402186632156} +{"repeat": 64, "batch_size": 65536, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 6.4978525042533875, "standard_gw": 6.462603807449341, "standard_gx": 5.5987648665905, "rowwise_fwd": 3.2996535301208496, "rowwise_bwd": 2.6320070028305054, "global_fwd": 3.2426007091999054, "global_bwd": 2.612769603729248, "x_quantize_rowwise": 0.2561397850513458, "g_quantize_rowwise": 0.9984448552131653, "w_quantize_rowwise": 0.033076852560043335, "w_quantize_colwise_transpose": 0.24232640862464905, "w_quantize_global": 0.09618699550628662, "w_quantize_global_transpose": 0.10257214307785034, "time_standard": 18.559221178293228, "time_rowwise": 13.9242522418499, "time_global": 13.771317899227142} +{"repeat": 64, "batch_size": 65536, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 5.5702440440654755, "standard_gw": 5.717620253562927, "standard_gx": 6.08203187584877, "rowwise_fwd": 2.649586647748947, "rowwise_bwd": 3.315173089504242, "global_fwd": 2.6132799685001373, "global_bwd": 3.257807344198227, "x_quantize_rowwise": 0.9980201721191406, "g_quantize_rowwise": 0.256560742855072, "w_quantize_rowwise": 0.03356859087944031, "w_quantize_colwise_transpose": 0.23729726672172546, "w_quantize_global": 0.09495764970779419, "w_quantize_global_transpose": 0.103779137134552, "time_standard": 17.369896173477173, "time_rowwise": 13.207826763391495, "time_global": 13.04202526807785} +{"repeat": 64, "batch_size": 131072, "dim_out": 6656, "dim_in": 1664, "wm": 4, "switch": false, "standard_fwd": 13.058379292488098, "standard_gw": 11.480242013931274, "standard_gx": 11.092845350503922, "rowwise_fwd": 6.637874990701675, "rowwise_bwd": 5.24790957570076, "global_fwd": 6.521012634038925, "global_bwd": 5.214303731918335, "x_quantize_rowwise": 0.5057565867900848, "g_quantize_rowwise": 1.989319920539856, "w_quantize_rowwise": 0.03439188003540039, "w_quantize_colwise_transpose": 0.24280324578285217, "w_quantize_global": 0.09520724415779114, "w_quantize_global_transpose": 0.10240450501441956, "time_standard": 35.631466656923294, "time_rowwise": 26.138298213481903, "time_global": 25.908246636390686} +{"repeat": 64, "batch_size": 131072, "dim_out": 1664, "dim_in": 6656, "wm": 4, "switch": true, "standard_fwd": 11.13397628068924, "standard_gw": 11.371888220310211, "standard_gx": 12.12756335735321, "rowwise_fwd": 5.2495077252388, "rowwise_bwd": 6.638709455728531, "global_fwd": 5.215313285589218, "global_bwd": 6.5222084522247314, "x_quantize_rowwise": 1.9870512187480927, "g_quantize_rowwise": 0.5058236420154572, "w_quantize_rowwise": 0.034634023904800415, "w_quantize_colwise_transpose": 0.23674964904785156, "w_quantize_global": 0.09457767009735107, "w_quantize_global_transpose": 0.10183081030845642, "time_standard": 34.63342785835266, "time_rowwise": 26.024363934993744, "time_global": 25.798693299293518} +{"repeat": 64, "batch_size": 8192, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 1.2125298380851746, "standard_gw": 1.1111274361610413, "standard_gx": 1.0840706527233124, "rowwise_fwd": 0.6057210266590118, "rowwise_bwd": 0.51865354180336, "global_fwd": 0.5952082574367523, "global_bwd": 0.5167685449123383, "x_quantize_rowwise": 0.045686960220336914, "g_quantize_rowwise": 0.15827640891075134, "w_quantize_rowwise": 0.04361197352409363, "w_quantize_colwise_transpose": 0.34067779779434204, "w_quantize_global": 0.13644620776176453, "w_quantize_global_transpose": 0.14925003051757812, "time_standard": 3.407727926969528, "time_rowwise": 2.823755145072937, "time_global": 2.7127638459205627} +{"repeat": 64, "batch_size": 8192, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 1.0731369256973267, "standard_gw": 1.1365897953510284, "standard_gx": 1.1498592793941498, "rowwise_fwd": 0.5573518574237823, "rowwise_bwd": 0.615488737821579, "global_fwd": 0.5220361053943634, "global_bwd": 0.5939789116382599, "x_quantize_rowwise": 0.15765801072120667, "g_quantize_rowwise": 0.04369020462036133, "w_quantize_rowwise": 0.047359615564346313, "w_quantize_colwise_transpose": 0.5526281893253326, "w_quantize_global": 0.13606995344161987, "w_quantize_global_transpose": 0.15017390251159668, "time_standard": 3.359586000442505, "time_rowwise": 3.1107664108276367, "time_global": 2.7401968836784363} +{"repeat": 64, "batch_size": 16384, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 2.4274885654449463, "standard_gw": 2.1799951791763306, "standard_gx": 2.1426528692245483, "rowwise_fwd": 1.195710152387619, "rowwise_bwd": 1.027170568704605, "global_fwd": 1.1747106909751892, "global_bwd": 1.0251589119434357, "x_quantize_rowwise": 0.08098781108856201, "g_quantize_rowwise": 0.3052949905395508, "w_quantize_rowwise": 0.043764710426330566, "w_quantize_colwise_transpose": 0.33987686038017273, "w_quantize_global": 0.13646483421325684, "w_quantize_global_transpose": 0.14739856123924255, "time_standard": 6.750136613845825, "time_rowwise": 5.172800272703171, "time_global": 5.050010979175568} +{"repeat": 64, "batch_size": 16384, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 2.1661892533302307, "standard_gw": 2.0948275923728943, "standard_gx": 2.306375652551651, "rowwise_fwd": 1.0587647557258606, "rowwise_bwd": 1.1999905109405518, "global_fwd": 1.0296404361724854, "global_bwd": 1.1749230325222015, "x_quantize_rowwise": 0.3054030239582062, "g_quantize_rowwise": 0.08077546954154968, "w_quantize_rowwise": 0.047225505113601685, "w_quantize_colwise_transpose": 0.600133091211319, "w_quantize_global": 0.13613328337669373, "w_quantize_global_transpose": 0.1484006643295288, "time_standard": 6.567392498254776, "time_rowwise": 5.387119948863983, "time_global": 4.97010350227356} +{"repeat": 64, "batch_size": 32768, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 4.807606339454651, "standard_gw": 4.170913249254227, "standard_gx": 4.117622971534729, "rowwise_fwd": 2.370934933423996, "rowwise_bwd": 1.9481778144836426, "global_fwd": 2.3383721709251404, "global_bwd": 1.9443817436695099, "x_quantize_rowwise": 0.1547597348690033, "g_quantize_rowwise": 0.6000511348247528, "w_quantize_rowwise": 0.04361942410469055, "w_quantize_colwise_transpose": 0.3403201699256897, "w_quantize_global": 0.13600289821624756, "w_quantize_global_transpose": 0.1474134624004364, "time_standard": 13.096142560243607, "time_rowwise": 9.628776460886002, "time_global": 9.491894394159317} +{"repeat": 64, "batch_size": 32768, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 4.1619837284088135, "standard_gw": 4.181284457445145, "standard_gx": 4.635505378246307, "rowwise_fwd": 1.9684135913848877, "rowwise_bwd": 2.3750364780426025, "global_fwd": 1.9445866346359253, "global_bwd": 2.3551955819129944, "x_quantize_rowwise": 0.6004162132740021, "g_quantize_rowwise": 0.15468522906303406, "w_quantize_rowwise": 0.04730746150016785, "w_quantize_colwise_transpose": 0.5999617278575897, "w_quantize_global": 0.1364201307296753, "w_quantize_global_transpose": 0.14847144484519958, "time_standard": 12.978773564100266, "time_rowwise": 9.927105158567429, "time_global": 9.521059691905975} +{"repeat": 64, "batch_size": 65536, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 9.52371209859848, "standard_gw": 8.354485034942627, "standard_gx": 8.69860127568245, "rowwise_fwd": 4.717472940683365, "rowwise_bwd": 3.8843750953674316, "global_fwd": 4.645414650440216, "global_bwd": 3.8761012256145477, "x_quantize_rowwise": 0.3024861216545105, "g_quantize_rowwise": 1.1897757649421692, "w_quantize_rowwise": 0.04366785287857056, "w_quantize_colwise_transpose": 0.33988431096076965, "w_quantize_global": 0.1359507441520691, "w_quantize_global_transpose": 0.14724582433700562, "time_standard": 26.576798409223557, "time_rowwise": 18.832147121429443, "time_global": 18.651459366083145} +{"repeat": 64, "batch_size": 65536, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 8.307881653308868, "standard_gw": 8.214320987462997, "standard_gx": 9.21182706952095, "rowwise_fwd": 3.8919784128665924, "rowwise_bwd": 4.72346693277359, "global_fwd": 3.8761794567108154, "global_bwd": 4.673641175031662, "x_quantize_rowwise": 1.1893920600414276, "g_quantize_rowwise": 0.3024972975254059, "w_quantize_rowwise": 0.04708021879196167, "w_quantize_colwise_transpose": 0.6039328873157501, "w_quantize_global": 0.13624504208564758, "w_quantize_global_transpose": 0.14867261052131653, "time_standard": 25.734029710292816, "time_rowwise": 18.972668796777725, "time_global": 18.540948629379272} +{"repeat": 64, "batch_size": 131072, "dim_out": 8192, "dim_in": 2048, "wm": 4, "switch": false, "standard_fwd": 19.30372044444084, "standard_gw": 16.480475664138794, "standard_gx": 17.61433482170105, "rowwise_fwd": 9.49602946639061, "rowwise_bwd": 7.768530398607254, "global_fwd": 9.3533955514431, "global_bwd": 7.749464362859726, "x_quantize_rowwise": 0.5977451801300049, "g_quantize_rowwise": 2.3684948682785034, "w_quantize_rowwise": 0.04375725984573364, "w_quantize_colwise_transpose": 0.34042075276374817, "w_quantize_global": 0.13628974556922913, "w_quantize_global_transpose": 0.14671683311462402, "time_standard": 53.398530930280685, "time_rowwise": 37.09545359015465, "time_global": 36.83258220553398} +{"repeat": 64, "batch_size": 131072, "dim_out": 2048, "dim_in": 8192, "wm": 4, "switch": true, "standard_fwd": 18.041003495454788, "standard_gw": 17.770148813724518, "standard_gx": 17.70009845495224, "rowwise_fwd": 7.756810635328293, "rowwise_bwd": 9.502101689577103, "global_fwd": 7.7384114265441895, "global_bwd": 9.36170294880867, "x_quantize_rowwise": 2.3686252534389496, "g_quantize_rowwise": 0.5980581045150757, "w_quantize_rowwise": 0.04723668098449707, "w_quantize_colwise_transpose": 0.6035342812538147, "w_quantize_global": 0.13603642582893372, "w_quantize_global_transpose": 0.1485198736190796, "time_standard": 53.511250764131546, "time_rowwise": 38.64651545882225, "time_global": 38.121502846479416} +{"repeat": 64, "batch_size": 8192, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 4.598241299390793, "standard_gw": 4.294309765100479, "standard_gx": 4.261095076799393, "rowwise_fwd": 2.0976848900318146, "rowwise_bwd": 1.9718967378139496, "global_fwd": 2.0763762295246124, "global_bwd": 1.9703581929206848, "x_quantize_rowwise": 0.08216872811317444, "g_quantize_rowwise": 0.4405900835990906, "w_quantize_rowwise": 0.1553371548652649, "w_quantize_colwise_transpose": 1.6110725700855255, "w_quantize_global": 0.481240451335907, "w_quantize_global_transpose": 0.5061514675617218, "time_standard": 13.153646141290665, "time_rowwise": 10.653059929609299, "time_global": 9.85119491815567} +{"repeat": 64, "batch_size": 8192, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 4.35885414481163, "standard_gw": 4.29583340883255, "standard_gx": 4.5370906591415405, "rowwise_fwd": 2.0015686750411987, "rowwise_bwd": 2.097565680742264, "global_fwd": 1.969795674085617, "global_bwd": 2.075403928756714, "x_quantize_rowwise": 0.43984130024909973, "g_quantize_rowwise": 0.08216127753257751, "w_quantize_rowwise": 0.22544339299201965, "w_quantize_colwise_transpose": 2.4342015385627747, "w_quantize_global": 0.48087164759635925, "w_quantize_global_transpose": 0.5099289119243622, "time_standard": 13.19177821278572, "time_rowwise": 11.576615273952484, "time_global": 9.85383614897728} +{"repeat": 64, "batch_size": 16384, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 9.09888744354248, "standard_gw": 8.230950683355331, "standard_gx": 8.465446531772614, "rowwise_fwd": 4.182614386081696, "rowwise_bwd": 3.747660666704178, "global_fwd": 4.138719290494919, "global_bwd": 3.74777615070343, "x_quantize_rowwise": 0.15515834093093872, "g_quantize_rowwise": 0.8699297904968262, "w_quantize_rowwise": 0.15544891357421875, "w_quantize_colwise_transpose": 1.6132444143295288, "w_quantize_global": 0.48100948333740234, "w_quantize_global_transpose": 0.5051903426647186, "time_standard": 25.795284658670425, "time_rowwise": 18.955007195472717, "time_global": 18.128734081983566} +{"repeat": 64, "batch_size": 16384, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 8.378107100725174, "standard_gw": 8.923027664422989, "standard_gx": 9.049762040376663, "rowwise_fwd": 3.765825182199478, "rowwise_bwd": 4.183519631624222, "global_fwd": 3.744799643754959, "global_bwd": 4.1590481996536255, "x_quantize_rowwise": 0.8693933486938477, "g_quantize_rowwise": 0.1553073525428772, "w_quantize_rowwise": 0.2258792519569397, "w_quantize_colwise_transpose": 2.4386271834373474, "w_quantize_global": 0.4811100661754608, "w_quantize_global_transpose": 0.5102269351482391, "time_standard": 26.350896805524826, "time_rowwise": 20.5615796148777, "time_global": 18.842913210392} +{"repeat": 64, "batch_size": 32768, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 18.266115337610245, "standard_gw": 17.671160399913788, "standard_gx": 17.10302010178566, "rowwise_fwd": 8.347474038600922, "rowwise_bwd": 7.514089345932007, "global_fwd": 8.263226598501205, "global_bwd": 7.487393915653229, "x_quantize_rowwise": 0.3021806478500366, "g_quantize_rowwise": 1.7319358885288239, "w_quantize_rowwise": 0.15519559383392334, "w_quantize_colwise_transpose": 1.6133114695549011, "w_quantize_global": 0.48247724771499634, "w_quantize_global_transpose": 0.506427139043808, "time_standard": 53.04029583930969, "time_rowwise": 37.3353473842144, "time_global": 36.44480183720589} +{"repeat": 64, "batch_size": 32768, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 17.73649826645851, "standard_gw": 16.359902918338776, "standard_gx": 18.0993489921093, "rowwise_fwd": 7.493957877159119, "rowwise_bwd": 8.352488279342651, "global_fwd": 7.486194372177124, "global_bwd": 8.28903540968895, "x_quantize_rowwise": 1.7313472926616669, "g_quantize_rowwise": 0.30205026268959045, "w_quantize_rowwise": 0.2255477011203766, "w_quantize_colwise_transpose": 2.4363920092582703, "w_quantize_global": 0.4815347492694855, "w_quantize_global_transpose": 0.5103759467601776, "time_standard": 52.195750176906586, "time_rowwise": 36.90168634057045, "time_global": 35.16044095158577} +{"repeat": 64, "batch_size": 65536, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 36.309611052274704, "standard_gw": 32.85098075866699, "standard_gx": 34.34552624821663, "rowwise_fwd": 16.74525812268257, "rowwise_bwd": 15.026237815618515, "global_fwd": 16.574162989854813, "global_bwd": 14.977734535932541, "x_quantize_rowwise": 0.5954466760158539, "g_quantize_rowwise": 3.4569576382637024, "w_quantize_rowwise": 0.15521422028541565, "w_quantize_colwise_transpose": 1.6133897006511688, "w_quantize_global": 0.4822872579097748, "w_quantize_global_transpose": 0.5065612494945526, "time_standard": 103.50611805915833, "time_rowwise": 70.44348493218422, "time_global": 69.44413110613823} +{"repeat": 64, "batch_size": 65536, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 35.40017828345299, "standard_gw": 33.037226647138596, "standard_gx": 36.30436211824417, "rowwise_fwd": 15.043705701828003, "rowwise_bwd": 16.756191849708557, "global_fwd": 15.011314302682877, "global_bwd": 16.580048948526382, "x_quantize_rowwise": 3.4548528492450714, "g_quantize_rowwise": 0.5951337516307831, "w_quantize_rowwise": 0.22584572434425354, "w_quantize_colwise_transpose": 2.4329908192157745, "w_quantize_global": 0.4813261330127716, "w_quantize_global_transpose": 0.5101598799228668, "time_standard": 104.74176704883575, "time_rowwise": 71.54594734311104, "time_global": 69.67006251215935} +{"repeat": 64, "batch_size": 131072, "dim_out": 16384, "dim_in": 4096, "wm": 4, "switch": false, "standard_fwd": 73.40333238244057, "standard_gw": 73.76311346888542, "standard_gx": 70.41774317622185, "rowwise_fwd": 33.37597846984863, "rowwise_bwd": 30.345775187015533, "global_fwd": 33.00366923213005, "global_bwd": 30.218638479709625, "x_quantize_rowwise": 1.1825822293758392, "g_quantize_rowwise": 6.902601569890976, "w_quantize_rowwise": 0.15529245138168335, "w_quantize_colwise_transpose": 1.6109198331832886, "w_quantize_global": 0.48149004578590393, "w_quantize_global_transpose": 0.5066059529781342, "time_standard": 217.58418902754784, "time_rowwise": 147.33626320958138, "time_global": 146.05870097875595} +{"repeat": 64, "batch_size": 131072, "dim_out": 4096, "dim_in": 16384, "wm": 4, "switch": true, "standard_fwd": 71.5160183608532, "standard_gw": 73.76786693930626, "standard_gx": 72.98104092478752, "rowwise_fwd": 30.291248112916946, "rowwise_bwd": 33.36654230952263, "global_fwd": 30.181586742401123, "global_bwd": 33.082425594329834, "x_quantize_rowwise": 6.902430206537247, "g_quantize_rowwise": 1.1815279722213745, "w_quantize_rowwise": 0.2262219786643982, "w_quantize_colwise_transpose": 2.4421699345111847, "w_quantize_global": 0.4816502332687378, "w_quantize_global_transpose": 0.5105249583721161, "time_standard": 218.26492622494698, "time_rowwise": 148.17800745368004, "time_global": 146.1080126464367} diff --git a/benchmarking/switchback/make_plot_with_jsonl.py b/benchmarking/switchback/make_plot_with_jsonl.py new file mode 100644 index 000000000..8897564e7 --- /dev/null +++ b/benchmarking/switchback/make_plot_with_jsonl.py @@ -0,0 +1,138 @@ +import matplotlib.pyplot as plt +import pandas as pd +import numpy as np +import os + +import matplotlib.gridspec as gridspec + +cmap=plt.get_cmap('cool') + +if __name__ == '__main__': + + fig = plt.figure(tight_layout=True, figsize=(12,3.5)) + gs = gridspec.GridSpec(1, 2) + + dims_to_consider = [1024, 1280, 1408, 1664, 2048, 4096] + batch_size_for_plot1 = 32768 + batch_sizes_for_plot2 = [2**14, 2**15, 2**16, 2**17] + dims_to_xtick = [1024, 2048, 4096] + logscale_plot1 = True + + ax = fig.add_subplot(gs[0, 0]) + + # TODO: change this to what you want. + rdf = pd.read_json('speed_benchmark/info_a100_py2.jsonl', lines=True) + df = rdf[rdf.batch_size == batch_size_for_plot1] + + # first plot the time occupied by different operations + for k, marker, ls, color, name in [ + ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (sum of parts)'), + ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (sum of parts)'), + + ('standard_fwd', '^', '--', 'C2', 'Matmul XW (standard)'), + ('standard_gw', '^', '-.', 'C2', 'Matmul GW (standard)'), + ('standard_gx', '^', ':', 'gray', 'Matmul GX (both)'), + + ('global_fwd', '^', '--', 'C4', 'Int8 Matmul XW (switchback)'), + ('global_bwd', '^', '-.', 'C4', 'Int8 Matmul GW (switchback)'), + + ('x_quantize_rowwise', 'P', '--', 'C4', 'Quantize rowwise X (switchback)'), + ('g_quantize_rowwise', 'P', '-.', 'C4', 'Quantize rowwise G (switchback)'), + ('w_quantize_global', '.', '--', 'C4', 'Quatnize global W (switchback)'), + ('w_quantize_global_transpose', '.', '-.', 'C4', 'Quantize gloabl and\ntranspose W (switchback)'), + ]: + xs = [] + ys = [] + for embed_dim in dims_to_consider: + # average over dim -> 4*dim and 4*dim -> dim + df_ = df[df.dim_in == embed_dim] + df_ = df_[df_.dim_out == embed_dim * 4] + xs.append(embed_dim) + y_ = 0 + for k_ in k.split('+'): + y_ += df_[k_].values[0] + df_ = df[df.dim_in == embed_dim * 4] + df_ = df_[df_.dim_out == embed_dim] + for k_ in k.split('+'): + y_ += df_[k_].values[0] + ys.append(y_ * 0.5) + + + ax.plot(xs, ys, color=color, label=name, marker=marker, markersize=5 if marker=='s' else 5, linestyle=ls, linewidth=2 if '+' in k else 1.) + + + ax.set_xlabel('dim', fontsize=13) + ax.set_ylabel('time (ms)', fontsize=13) + + ax.grid() + + ax.set_xscale('log') + if logscale_plot1: + ax.set_yscale('log') + + ax.tick_params(axis='x', labelsize=11) + ax.tick_params(axis='y', labelsize=11) + + ax.set_xticks(dims_to_xtick) + ax.set_xticklabels(dims_to_xtick) + ax.set_xticks([], minor=True) + + leg = ax.legend(loc='upper center', bbox_to_anchor=(-0.64, 1.), ncol=1, fontsize=10) + leg.get_texts()[0].set_fontweight('bold') + leg.get_texts()[1].set_fontweight('bold') + plt.subplots_adjust(left=0.1) + ax.set_title(' Linear layer, batch * sequence length = 32k', fontsize=10, loc='left', y=1.05, pad=-20) + + + ax = fig.add_subplot(gs[0, 1]) + + # now plot the % speedup for different batch sizes + for j, batch_size in enumerate(batch_sizes_for_plot2): + all_xs, all_ys = [], [] + for k, marker, ls, color, name in [ + ('standard_gx+standard_gw+standard_fwd', 's', '-', 'C2', 'Standard fp16 (total time)'), + ('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'), + ]: + + xs, ys = [], [] + df = rdf[rdf.batch_size == batch_size] + for embed_dim in dims_to_consider: + df_ = df[df.dim_in == embed_dim] + df_ = df_[df_.dim_out == embed_dim * 4] + xs.append(embed_dim) + y_ = 0 + for k_ in k.split('+'): + y_ += df_[k_].values[0] + df_ = df[df.dim_in == embed_dim * 4] + df_ = df_[df_.dim_out == embed_dim] + for k_ in k.split('+'): + y_ += df_[k_].values[0] + ys.append(y_ * 0.5) + all_xs.append(xs) + all_ys.append(ys) + + color = cmap(j * 0.25) + real_ys = [-((all_ys[1][i] - all_ys[0][i]) / all_ys[0][i]) * 100 for i in range(len(all_ys[0]))] + markers = ['^', 'v', 'P', 'o'] + ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5) + + ax.legend() + ax.set_xlabel('dim', fontsize=13) + ax.set_xscale('log') + ax.grid() + ax.set_ylabel(r'% speedup', fontsize=13) + + + ax.tick_params(axis='x', labelsize=11) + ax.tick_params(axis='y', labelsize=11) + + ax.set_xticks(dims_to_xtick) + ax.set_xticklabels(dims_to_xtick) + ax.set_xticks([], minor=True) + + ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20) + + + + plt.savefig('speed_benchmark/plot_with_info.pdf', bbox_inches='tight') + diff --git a/benchmarking/switchback/plot_with_info.pdf b/benchmarking/switchback/plot_with_info.pdf new file mode 100644 index 0000000000000000000000000000000000000000..d186e91b7d96c6e605fd2802ee37881e6294cdd7 GIT binary patch literal 34876 zcmb@u1yq(x*8mFA-CZx;`O>9ycSv_iN~hA@4T4A`A*Cptl2U>+NOy;Tpxo!hbN=uA z-{HTy?z+o0%)HM$Gkee8`!-4rKTbbU6RL00`a?>|HDY?0%uou5Rh-4!0@JvH1b*Qu=IH1I{NjX+ z`6pE%-G7r&(bCb{%?8Bz7iei)dtl)pc4>Qn8j_agP8OCB{9WB#EKMDdy|TX;%DIyD zCbqU2kEKVxqWff}HTf(fm_N%j#@z4KM?&V#6KGgUWG*N+sG;iTKm2qT40E&=s=oTt zAgZG#1}aY^4L;7)Eq{6Ad>lV`8vXF0@9dAw_$K<(sjGcu&$i{mU1fLRZ?CfF$kD~Q z^75_^!MlB!&?PF-sx^AuIhh2_tL>?UgY<>uL$xdX(A<5GHFKw&FKa^OzUAi5HlJ;G zV&9EzXZp~-wh%uQ!~R0Aez0Q5cPmZev}x6sB}#|ao;X%B-hUzZ#AV?+`ha~w>0-^- zY@_tryw{Am`bUjMiwd%oRs??S)NEdYk9F#aS`?@h$GVBabEL^r0|OcPOT}m&0yoOQ za?hwX@p+9ea=G7pOIKT&Z^bCFH(%N~I`MIfOaH6AtMh~H=!N9l$(+fBXye*tqvqKk z?|cjKW)2(M+}VGAxUM2-+RW?-*T|78<71rBkz+0`DLpVTe^CbRh+mJPc1<~uow+=B z+e|s=&l}Q=&Cc|RIzZ>Ge=S7PrOGBq65o}h)vvIil@`&$OXri!n)tmC)lQz zvd^blE*&(~@ zz&H}@aS@p#(dn|@^{`n!#vo|R`SCTJV#Lvr!q|Ax@5BBxuW^yvEZ#36Cj)fGb8d9G z_8K$!)wW+&J{Pm>vouY%zMm8z>_1^DAaxSH(bXM2~EyP-hPOq}^G*_TI4jl0k?ZJ;gp-kC4eEz8v%M%3`d^zoQ6~U1D z8%c*~p`hpVoA!p!=-8R~(J_(b+t_#3i|l*&Zg(A@`q;xyI=| z&ggf5i{Z(pIw|~Sq*6;WHv2KI@O<<`afLziuj~kMfGIark( zr)wWYdzRd5wOz@Hmp&j39N89emyZx{YKi+DqMAj*>{KE@E8^G?hJF&92*Xu7O;Uu! z_T^0~Pci+6zyxj=@<-M0djt!&vi68c7y>NXv7?<%Bw5cH8 zUAkJh!dnP-%)LTi!NQ|9%W*hDlnv~%AQpCcfbaeC^Aaa!h8WzFBSNJ9B-a@S1>qe; zG`okF8%PNyaralHROUq-Nn@EPL~0LA z%A&R?yxc_UDB_`GVhDPk^XltXj_zu{7*+9=ECFs9gBud*kRRd-UL5I-dF2c6Jh zTutF!JseZO-%iwuB@f|krb*aNMPRe7<=PM5Fjv+L)==GA`)-4iJ-NPOnwl7z>jqyd z_+gMo8vhwrCsKGov#yKF$d+Vn;@InlI0a2`(t*AN;%+?BrlhO25)K?u5&M`chQVsl zxMbhyzM4$X5_X1U=p|4!J7-=f{mHmBNKJj5BWE|4l@KTVfn<{DW=fM;3g2J$fWW z-CY7G#cU%rlt&?E(kZq%G>YUcBa788Nyy1p&E(Ti98}oE;0h`ZD7ZG(fwd~$ABOm_ zYyz>WwJ%h~Z)HzXFaWp=-qo>`QSC}rMb0m9nQ*%a;mjtdZSq23VF$FD6YD#BKhH@? zbadquGulk>k%9&{7kx~M!HNXhZ%rE{$X^C0&U)4u zq`E0#Pt6JYjgD?kuoKfqD&oo42jX>SVUNq|&Tn|d(VjUE(7jjp2~;)fDo1+tjO0}W zI{s-P4;!&xai79MpTs^f4vJqEAuipnrS$r9f?*ux7~xuLLul}cG_RueYSMcfJa*5C z>X^u@YKiHWYMCtTU$&cxXvO-=pJm9G9l47MAfa#KORt5%WN7wKJ7wOeOvlSTl6bK- z_ACORud2FJ&_=+4ZE8GnYKGAhcD0x)%)6l4Wqb(CjvS7CZAzvQxZ18D?2s+_&#mXAyONSh4U2WML-Vq+HkUQ^xU-M^gB}bG|TkS z?g|kviKnr#aNOQj8{Rfm&NU%)(z4=%>YYxD#)2d2g*sJn=3$-V9aaw6gT!3ClL@dh zG6e<|l-nXf7K*QRBPNjRNI87CXbv&rq^v}StUo{F77ZgzAa70b%uXwsn~4wox*B@R zXsLL;d`YN0O$SA1N>$$@`@%#p9n}R<36N6gd2w)*eYLHgxJn!K;k`z7ZC?{Ekr)CNx_!?KzuD znd)MGhVtt@5g6)`5VrW%&Llf}>#8Z`;ru;DYp2~63F_e|*l5@lb! z)*P0bnT;rpHO|THr+yr+4}F}uKG@mIB{;ktZ4qt~x_x(jaqw+rAw2rUY4pL)%EtLZ z)xvYPu%B$6wz+;zVQRwU#$j>cU4u;%8sX4|bp8*-;%Z?SCYV*1AGY+IxffIEL??-g zr>SC%>m2#fms(6XTj(ZTo*~|DOFnGg9`ZeLWa|kz{d9D6`i1Qk_s@_-%#$j4yRMRf zcbA);V_%MRKd=qqE;ChTpW5z}H2HW|UuQ2osm$c5Z5f2Oe)kyGt19gXN@yTsKq6VP zCF=th-trTw5OKotPazerU6$;HpU}m>vpN+s&=SkRSG_$x*xUFrbNzUM-3cq?`7^ic zX!7dw5H zV}rqM7g2kGX$X@UA*AiwzX3I+o3^Ykug_2*DhaOpESav}0ru>~uOsZViFFv+G_CXO zV=$IKqyST}V`02}nRDCH)Z7sLqF|%7KIL?4tp8lktM)`Bj`a8k)SN9fO0#B~?f~zX z$gsn)=*YtYKYXhViK`!DS1Q(?QxtC+bANV5QOr+ylC2n^Q^gy-``$L9G~R)n7|2Q; z^-3`SV~-lu2SUw_eLjNu(O64B)6Cy610L60|_A{U-? z0%fO1O_EWg7=zaM{%OAAKzhpRTc<8f^S1A|;r^cMinYeik=mYRV_d!Mzcj%R5Ir3_XV|=UP%-kyl|x|&BOdM5JfnLr8K@M zlF119KsaSieV6+bmLAwNB#p*z`O9P?jA*7{-P$E-1fHgLsuJnQqb^dF<30R{h~_u_ z+yn{zai=%fY`{j80x9TA$ zh9N8|a%is9$@v?@h2qxrR$oT;6maB($T`*G`Uvo+-lVUtAB6P|?w0ehkGyF_7_NEn zS(m=$NRco+ibl9k^?3bR11S^Oy~ko{JSVLRXLt+EPmM0IGF*O%DE8Qql;m=4oPj$w zRkz2sBEYO*9Soa~6{gn8t0icjOO{d=j$Dd+Rt+msgOSC$mwcHbA%)A7fg6i=roDil zqOym!YJ&gDj{r+%pdzFpHov3+{CYIAQujRc$AIMLI2qJb6W`gFktyv#fiB-p%DQg6 zzq-QBgEsP+$TH^VTOKww_+M{)GjpAYjW6ii)Sh^2zAVh4*v*I49j3G<8FD^XFtdG#{CDLB?)hwX@tKOM#U z%Mtji{BfWJ>2X7ChEt;dfIcGz|IbjEI=_^e&B3`k1m#opz`uuU{h4QP{O+YBGH^F6qzOo8uIl;tV*`!4;O23v76VZ*=iWZM| ze-VBrYr&xuiy4rQqVMD8zNq;U`YAjPZw^&)dT+P=5ZYKZmcy{ zcBo-RJ?@^Ht-4cl%%hupTuOwnv>UzSjR;{)6Q#FfWgi9<{AqEDL{7S1`RiEvd&@5- z(zcQjTO6|o5=9J^U@8Wkd*|2@FNPINVlpF{ zp7SaMyI9cAU#p2&-p>tDB4=mZ;+nQJ?#uyi4x@>xq}G<}yr%$x^W@m307sppNxRwm z9I80Zj~VJM;pD*7hgE6=tko%>&sM1~h`qKxK{0bGhA*E_4k8d9QA$tK5tn7$QnlmA z@bOxAx7XFEs5-YU=tE-^@%YxjdquM62&vaY8$IJ0%TjGa zse2rR?4|$l8%$Lnl3dAdvd<|kjG~UI{JvS%0ozV&FDkit%vDNTrIOt`$yOf!*QGj- zZ}KGU#79Xkqr|7C#`EWs$?)BLD{84<44>XgRU=rZoFuiH_nHK|aMrR5g@|xn`PEhB z7H4g0!-z7sJ4tu9o4drmz;|C;_U=$@ zXstQ_4t*$8a^n_Xi^nHpT%%*GYply1zMZzaKQ(yn(Hiyi?1{gjSEE7AdKUw3c+r>?$N8ynv~_~IQIh&&V+*Jan1>0)@tt{x6s z<#;91kV&jgJQ!*e`xqCQF)no2Gqk4!`OM2r9h0(N&W}#)tb<@CJU=uu(@oS>1gGke zWUimJaoR_y7xHP(V27gVWXzUtG7Y6&bx?h{YWgg0DX- z%yys}b4#|65PLGNhaU5}>Ggiy#uJ0CqfBC^{vSa6!MI;G+PZ4Wj+wXmrHg#KKlg9&BI39 zqb3>CU$)24;#PZ9BVR%E0uFzz0D7&fZ?I$uA+q@R`p8wh=kd|1L9RN=t4sz=n#hZ^ zf=E#`o-UFD$@|CP7Ph4F@8~8GM8jWDUfcu`sb&f?s${Qc(y$kv zR-ilND{H2%iOFj6W#($s1vKyFVC^2|wJg1D_e!&&L+Yn}I+0+nV&ihq^x4N$gU}|e zmxY$EZcV>eePDXovO$7gj<@rrn7a6v?`sFQ(2975fe=gbU)ImRY|3}0Qi%DK2h0op z+t7SxV}(4_2O0kT3=FZU{+FGViw_JKSASz<6-iV`>f%6@yby$w!Caw!N?xK~=X}tX zfVK=iIkG&V`RfoKAc!% zsH;!%wFA);YO(|@w2sgW6vuE#l@xTsk16wgB6N;_swFUR30{L;#P_L@h@)AFj`ZAA znhh*P^d=Aax0Q>ula@pn8;ad9YM!EueCxKX+9dz_0}2M_87@g4^21tAO?1Ba z8mJ-DOdV(CUBaG^?}DDY^Esrd%c)e_t$91Q;P%BGy-PrbJeaV4T zRqXTNGSY`Ccxvpp);q8#{5yi8DpqFf5^uU1v2rvXs$7z-c}ge_CkqbY&|ABAFzUSf zN~}OHUO(N?_u?pLLO$nA%mbuNITKqLoun+N#@1jUshsM7=-n1N-B`Cx+i9)z@LZdq zD!~?^R#mpXcADMEzB9`!--qcv7YU8EF@hrkm!ktehdLLfUTha#`9d-0d92)AYxN2T z5{pCrbT-w%FuTz)P(bLOSO(`@TakoNhj4)&XutCy!tODaDn-L&jcPd1xGke#JNq}*7oE}de@haMcl_0pX*C#+R_M{#86)G!!CQh zrmjq&z5L8G3fC5VV}DbDD-l8cP0WUh3vS$6`JM)~zn65S5Gp&!GUeQGAU9S%3 z!z1L3L?^yk$9H1IHeuvB)oD*?xcD&)wC9qA39%99=X^z9Q4KE zeDd@aJzt6>#ylT7q!)Fo``HfM3y*Z(Kf=((2GJye7E`2IyT z{y(5vyn-VPCm_16$c7+5boWjbRowuxk*E(tWkw{>p1gJ&NRah@J?NI4C9&uur5vey z@-=FxZP!~gD%DA)QYkQ;Au`d7%}AquKE2>v{%LF+=Ud&w!Ja2Ykr)rjxahxMYAker zey3A`>VPPNZPZm;Qm{30?nc&cVi?3Ja}kdZC2uR|E=?S^h>M_l1}*dj)ODiG2E@L<#VU=;SvtH9(x#ruT=UcCCU}_h& z&v;?kR}&fUR5M}DLljto^EDwa=9TmUA#_wwRH>kNpM##!10+H$N>&Q~+$O`Eq z5#fE#AXFyPp)qOPGFqlDZ9nFxzFL-f8E+}&F^2_(`&G+V#Il5#D|Z;L#L|(CI3s@& z>72iHmhX${S!>?)|{`((hsnj<@71qU3RM z)V)ESMwnt4q7p9X$FLNjFijwmu|LdE7C51h3nV0N1@rK{Wshh>`II2^zDojy zT!cK8BX|~hW)z<0pzMozhSS5n{ZkBI;{45f82e8b`rpoC@E>gDO7UXYFqDX}hs{Cs zP(dFKXEdK+Y}`)8WbybuZ6qxsX7A|qS#tIqX1Y;tJO;n^aV}Z$PM#^G3dF64(M?gQ zg%`I|YLFjBosWeNLCR_F3ns0)w!fv3O>K1Lah|XK8Gj(g`b=}VqN|2yvNSu3W7NaB zbJj%IgSqi!e$JzD>jzvMHTd&A3%bwyEx`GENtH?vbLrwhg*o(rd-*64MBCPzql~eA zTLqm;Co&yPX+b=ih2B+cgxAf~HATuw{h&>Th&;IAq2f=Z+foh$*KO*kNZDRidd$H` zMiFHBlSaLfa9MP%Vp2;iB28rD@Q1D*^E&Qc@e#W`nJ#q8f>=#!UbeUW)Wf^g)-HvD z-!@1Vj-GDZM!nP9hpk#%lJ$g@TA#>GPaz(|)?KhBU!+d0*J_ee6Z+g}s>D;l|6X3- z%XnMlSEblx4J&wL#xR3Nt*;{+ z#MuaBQM(7A>rex@b!^cUrO+S4M<9KMS3@-u!+k-NB{9)KgxP@lh8v0|1;6JWjQ$xN z@t3IabNwD8o}nhRb`C_i_Lq!=9nd$`7m@k6-+nr9P@xI_e6-oqMFq+)IQCbwO+L%i zW8?{8IdLR^@U{?#<A< zY`Vr1?y%Yv5#dSg`gB``a%08%+cB}VU$T_GlO!I4N<=CcxFQvt z6eqZGw#bFqEkuYe<1LaTl}ujlgDPYR7MZBBaR>R#$eSHD@V)YU&BUB6yu>Q|2og4> zd&m=@?cOteTY9>nIPpqnPV7P$-HjX}sM6!%(-yLb`NY?A&?R5dr(K0iZf3<+pN>`t zmH4%DKN9?Dx$|O~To2lxI#}LmvAd~`eIC21bK%-$z~ig>k9ld1OyO=)JbQJ%s&VOO0V zM_OgWp^d4DHq>&|d_(-TOS?UK16*;r{bA=zQPZNQY@RM#%SZ+P>&EFx<>-1|3x%(; zml_ONCS?gFK5KNi4l-7bUO)Q%SA9LAa%WywE=crRIVObg@sViA4o(z5U+j9{@*LXn znbG+@T)jtiuj%`E+67lO0 zWksHa5@-z0T_=~e_UdUO zN=x2_<9D79vR|<31koz(14}xY-^+2pxrY`ARyjYKAm34R6SdO~g zO)bmptLi*$I9Xd7*di+9DI*ra7eZT8ZNa9nOtM4GBEr#&Zl$$F{fwp3kXTT?7( z?vBDjB9D?A8cZ!@=J>c=37O%m*7w}Nk8y7-3$Qa>#R)(4Crg*@UT6dul{%!VAxogs zM$(soH!CZ;SAW89)VCRbm@OB1=;Xl9^6zEx+@mppf$#pe_B@^>IY&uOz@8@xM_|*h zNIra8ppBn%?Q{wnK(=>Q)8cm;Na6N=tG7QT zW1&v!(?j?a0x)iURpBe)CH@jp-J(dL{LU+|V}x14gG#HI#r#=1PCYD#JW1IHkFj)j zavE|EjtrW0wM>bgFg|@+ zlI+`aAERBnbu`-|UaJ|#ugo3%fjQ*rz6uA@_EhJoRJkEyKJ8WXL7WquuC4Ys!Re8F zI|z>UHBj_>92vw18k~_&Ic36`ncO4c55~~a*q*93s1y#oWT~OYCMCgB(v;6jU~%S0 zkda1d>{l#r$6#}`Deoplj^wq$;8^Y7MxxGKxA5VNbPfEO*DFkx)iLK6w8;{H&L6)4 z>t7@wn)?8ANad;TPjwRkY6}xTqaE1-K2z6t#z=GC;T~#k^o!>p^4%XOhD-kUY<@iV z=+2yfus48tYg9P`x^N|c`vN%;%sr+W;Eed!4Tcjiy8rfjl!_Nam!Jf`$Lt_#M6P;mYbT+;+pi(~BQeVGr^BH|E7gZ&WPYxB{iFuzA{6B+M;eB0=tT(ZR8f$tpRo)EX zyVWX5$!)oCtkZc6T~@4^NV$gkL$8+0U|XU< z%cM!PecH?E$3pAeCJ;rUc%SUtqwop*iDnQph!p_b2oE*Ikd_I=AI7axetV9WfBXRD za0s567k7|T`qk@Iu>{XCw0Wbf_GQ_{q_tfojv3ocd7ggTlGVe8w1?Nlc6QddhwJj@ z{(a{A#$Fv)P?(G4ME~X-u)X|841Z-M+@%_@i<`Py{%T2S%Sni-FiToKG1Ye0Fm-fg z6?d|?fbiDE)lI_2)CI(W%&utqSDTB62j~J~`6Vq~&0TDr-JD!NT#zGH9mKBTZsvBE z_y7R}MrKzwbpX-_{?GY=NyLBU4X|=@@Ns}xxdAIYrvM-Dhm(&7a8=#GQ8IONvAs(! zVB_F`?4ZEE|7k(;3IGgMft&|d5E#-6BqT^V0+|7jjEMhaV*LFGnO)M>%E}VRXMiLy z=z}01q^q;3Igncc?8DBGLFyR3XRm-~` zUZ5oo;s;i$0TKXg*QOu=;9YYNy9JQe0L*L!RKOq`5W6kV0$#8KD&P$V5W6EVCjbRl zDe!^|FfoW7fa;FS$qDp(0&55I8rUD(TDaK&c_Vj;7XQf9xTE*q){*}(^&cF$qug)E zZ*OW1(2VOBmH%HE0b&<(HHV~t2mp}TAz&asSh*mUHVIQ_*)oeNdBK>n7c(n)}#;O1x&O5KNRHuuXBI_m6#jg zG7lFgz+8S1uK+JHCkMb#em*V`n1dJSqPmz~05j%LU>Ca|4|KMfkyYh5rsKK4dTi8K3}}HXvTWb0Q$X z1H8uv;s-QAKEMZgR|8)a9~WdSfIQ$DhRn&y$pHo$oInQvfQttivKk21kQ#yo4*(oe zc!7!efG?5<*wi33FCRb_@ULflz+@1gCZzBJV}TkV5jOw<=-}n&yBh!*#0?DQ0WjkR z2H*7oC$0bhn3o&C8ffzXnB7eVOaiF^1b*=iLP7wEJJJF`c>(0^$nrOKkor!9AnW4> zMsjlU^4$#q*7UddAm4EWpeSVcT>%T+t?hrN`qh9S2$_OI0Kg2gdz80M&^Ue0u%^xkbX!X&tDS&%z=#l z`wv7K?gRlM4F6R)Irw>ixd6a_Nygvx-;(in{Y$idNd`nJ{#OBD-~Hy~LWW4l-{np? zfIdE8KY|FteFb1RfH=Tlh-CbuApb2Hf7ibR;+J4R+W#yNAKAw%l_ z1TFkigBFME-oMTr(4Eg6IDt$7mi>BW4gzR+*J}Z=>sRa7J_URucgKVkz<^)RZ9srn z+`VfHIS1}qc7R0uYB_)axxRbm2(ao`%L&AN2M77?WB_FOuHPASXJ`TN_|Hkg{Z|&> zKM$0@z`8(=7GNMUu>0R#-++O42h3l9e;rJ~`tK%p1Kp{Vz&q}cqv-CLCkQx=A+7)H zqrCq-@c*(?a{);FfvP6w*v)|oSc7lcq0%ZodaFcl(URfrcVJR07_1z0OX3^V zt4@iBuXafrA-xDZ44%xsB(7O`I4fP4`cP9Jgu|&uWH((8Lj&4OM4xdw7HP^%3`9fHKp}1PI!m80+Lrlxb##!D(LcW zCg{@>2ymWuvIz;{o8h34gukUo6|K4P+a0=mVOv?0@KlV7`(x2p>*n4~uUn{86$83| zD@?rq6tTbP_b1j6a4okM;CXlpg&mf*fxID|cISBzjYu7Y!x{WvUWdoTP;D`2+V_Bl zaQg3K{V&h}I=|%{n1|=_FVMxs(*FyzrHw7R)p~#~e$OL#S)y!HsKSmcR}&wX4_qHU zLk;3#e!2&Vd+sfNV(J!&8-(~dBn}lj5SKL`Tx)7!^#7a+^YiJz(C_+(=SuHWLO>{09%F?@*G_Vl#HW~1uRKUwR z`VQL%y?r9y(-ah;Fe8SQ-B2qlo|itpR&zrBl^S1?^@CZy$ZgEg`J>)xh19<0M7yu> zPd!2=wqx-vTH2+2q`ib@`{+o!J~+)FT%lp8HP;G+{uF{b5)67Wu|NFLBW^0>tLyV% z6vCSC76qj{=-<9(X4UCl2GdwJ>qo>8xg&P9eDiaRP9GH8x$N^LBHunmKwmPy$JYp) zVgI6|0N3xuO(eDh!fBXWfYX={x&cm8Z2$a9hqZ%10Y5`%Vp_+u)G^bZ{&|+KrBEt) z7E$}1)Pm~ECfp(&LzyDQPL)id=jv{x`_AZdUnkf3l~%WmFGhbrW5|=E+{5d84A*}m zjjXXX5FBBBRXPxt5Mm%sw|>Cs=Fyd_T-DnZft1~=WP2u;&Mh-G)ondCXQKEd(PFDt zU?rgzTe`tu`OD|hkRVVrY%AL1v&`q_BiX4X3lICQNyhKN@E!v(7^3U{HjXaouaFGU zP&Ul=?=vq?3mFXX<-csVZcJn1Xbn}c0?^Lt?9g8RsM`O0B63B>{;TS z5=$TX&JyJ$L5}JnveukkpOA#j=7!h1xU$*ZaRE5*qP|-O%|ooX1=PHL6>D0*i?T>h-pyBNW!yT{b5dC*&(sOBAYsoo25mr zPhJ)ff)VisOV^D#-Tl#c%7OYavAuo}?zwygeaq# z>!{U-q|?Tyu+Kg-!47EDf87qkb}0YudR>9ya6OBo?BTm3w~iCBnO4zmxr&$kQ3qZ` ztwDj)KVio{k3O%NAvf4w^htzdI)n96WJd?XDEeN*zV(V0rH$c2iGz8mw0!S)8=8_* z65^F4l65`z0shbDBQ32{*BHu_nppP`>>hQ7_Yb}hFz;Jc4pk1m$D=&lfMUbNr3&2c z)MPo=08vE%yEd2v37DisLMNCa(sU5Jd@#{-Z5IU9s#XdZ5E6J-MU2WAnr)L4O zTGCMb*!Jx$22R<<5;;{$G>U0hqVr_Ms^iNzf?JQ&5gmj>-peRmBh2Ye#mUF`)@*Zd z=C;}00{xCLhB!5`?v*7PIx-x2s3C6kT`3?}&BC+(IIfK?{zUqh(TCsWSUd!tq3@#b z#!}*KUNZRv@doGbc^^I?b`iTZT(~5cYTBX&ks#qus+k>RuZI}7c({?jW59M5&k|$9 zp&Z&~D#pxGt*9$_rLN}uTC`-VBKP^au>tEn^twkC%9&ENN?ZBrO<-HJ__5>4Zg3~4TtXc*?1K^E zwpYxXh$Bdr4Ss0KXCBus#R;Fi?N%&;KqqkBi5%ieV!5X+OCMQA6}JP9V|>Zvo9l(o zvfsz6PoE^R23l5EkL)kXO&NITFpz>QXf+1Qg59uWfuqDf*d(VOLzKd`LtgAbiph#)J3VBoU)6GP@(n8f--OjIS2fG(sa8I0Kicri+&ye22i95)4KiGK$CApWPS(y-LO@q^h zw_t5XYO*<23;l*jQ!>qWuQ%fLaA{1lWVfh_(szb5MON(|M-3w5o0OOdKQ}n0*W4#5 zC!<<*qbNF&-q_y+Ze%YvC+#W21RW<&BG)OvAGxrHO%w!7^3qpjmgaQbvEGn1d>D2M zq)44~R(!krtnXaqr`w#~kbAQWV*aCbC9$I>^f%Hq-Q8TGIkAjB`7G7MS;(dQaX_VYtB9SSvc{??L9x4EfRF#yRmZ5`!!5KoqB668X>B{$mea!rfVw#Md$}ua2~X|zZ_RX z=|(Cgjxwq{DY{-^=Lb8pt`)s8;n;b{YiTP>x)8K@3-NivfC$j;gIYqC6kc8&F5O|6 zeYAcwJ(H=|Bx=z?%qf#dZu#2XWuKPnHYPpzBv?v$Q5ie|jd z^f*Wo_gPjFqWVORT*xKF%KXl^d&e>G`@=!?qaXbjmt<8KLpKK5_fYU2B^ofr{g$Uv zu^m{D*qK|vm4b{st3sCy9O|){Tz)2-=m#NiDD~db6mY3@MtC|1^5*lZ@sC=TU0D0) z;V_)(pU;kJW~YsN+9t|rG0qoad*7m4( zZJUO3pzLyb!s**Gh^DXqLh;Mc5|W9%*PuqR#XxlfEj7(lii$s#L`Crpw+jtb@ueIR zoJVSWubo0g>|SCiZtm>yYR#E+@eH2c`N3GxP`=vGK%YP#X5LG@##-MA^qC_c1KGyI zlp#ve!CTcp7Rv<}WLh!(QO`ZBuNY)+f|!F$tL@l=_bs_5Zoh})8yxyynDw#+;8$$6 zcA!I*mf*aF>m!;y;AOOd(rS#B>nu3&arDTyFVNb)B838@bGWoe(CU-5jx9o^`5SIY z#YX*Lex#=R=%(+BN5>*(zpOL<27FwLv+&1EX9fPA2alp!RtSBLXs0*WAHRvnIzn=9 zv|w-Ud^x+ZDV~hzSqX)1WYJYwUQn90zbzc}&05MuI%Xo8Uyw3M@3TPA%I8<)daDL} zH7gF(H4Try9#AZ2vm1QdB~aK4>ApvX?$N6`0juiog(#u-DI5^vbXfL)mLmZ9yrTCM zWA|rO*x9(451}}!nM!B1!Q-c!QUTA^QTW(1$7nwbLNbVtc^6$k0X5oK{BH#=7 z?OMTH?}1IQnj0EB44>}*ijRFun;jHk9gqQX!2Afpm#oQ*B$6b{hE5z@l81q!ox|_H z2cLV?Ou#DrTYN-531CBfp`$M~Zm2{0s4kw8Mtj|AbBw{0m;1b;Lss~)`;RW@wRoM)n=(_cSGSz~mgve7DW3z1HdUN>qpA1>lZ+i-83z5L#N-<>gYi-p$C>PS`&a{hsvTJU+ z_F-HI{v*k-p}=d4rv)+Yn(j=_iX*A#Y718$x+(rL=$6RS6|SWAO9KvFg!IxcDi6>v zB0(zj8ZsaEK{H1P^s)Y_<;>1G_TR=7IM88#h~s3eML*g_4eS}xI9tQsTV3_o3^Z}_ z#*&@E>`SC1Bn$G$F}PisylT098YVRIj_3{P;fiO*q(`r@SrVQrbo4x9>82f<-V~AG zVludQt@t`@xWOd7DOmCF>fz+b_n$E2j8@<8A>}`{xL@whzyFK8|G1o`om>F7>VMpr zfR~Jwn->h?8PA@9TX<3lLO@0W5NWBLM}0{RBRXQm#mIZ%WIk=L-af7a z3U0y)nwSlW+zJY>r(J+Mtc0~TzuD$>h~IuTXwtWB<>>-ua5D zXUuzArnb5YgGldlnyeln&1(oiLe40V1;5?1)_e3GFxMaKLOg0s4j(yC_juq4I6%*6 zhgLzdiSM7k5So20JKjv8WyDovz~C| zKP}Q*o~$N|<-!?-w-qW2+ZE;b0`Z7BK&@T!mE-XRMx!qFBVybe+^wwgG##@CXwjjS zGk7c73yeA+$}c@=8W50NICZx)(50dn-?#e_vCJguW$(J^BHudtKOi;@<=?A|mR#~~ zhf$>YX%S#QBC5Z)vm?dp5yrng8W0`uVu%J>OoP8%^n>a3lctT$Oa^*>fyt_T&Q>(# z4F&bZ#|^zBuf!HF!NS&_ajNK9s)%O`4?y7y?b0El??&3$VCvXyfoFT4w4J7Jdbb9w)a(;d&&RgW2j*Yln2|eKhMLzbA+`zq9_>xv9!JTd+@C*gaR7OHJuE3k+%!JeUYQ78n z9nHZFMmwCmMjtdIlpRx`DsP7ND2mMxE%(nz}=Ye3U`qn{*isT)xreT(dV znx!6>*JkHjG{U}4uQ7cOiSE(f{^W>~in&YRkn}A!fqh2zO;jzVbFmE#lh*MA{ep-A z8B!*rBe+Wf{v|3)nBmjFvosDA0~yu0oV^0mFReSUvFF9=r-v9RFZU`|!giq_h!b1g zgW5f+3m4#${OzWpq}uNQaPPVanTMl&i?oRxAzHZh$8(7Rq@;8kg_XRM6#cXT%L{4& zr~bxuiE!mq+Y`spp{kmn1*i`^ZdElgD`{)wKN);>tl^eppXuU{O}Bd&xB1#PM7O}} zXmxs=F>qUitHY|UMS!DRWEed``9n|D&bUx24NPQX!l`%>@-F!VZGJrST!wm%GV3~F zdhV95UmTRk$HmzBvED&~80mMcrDIQ0`7tW5$EZQ z70AnM>2fxK8LSH_EUHv*#;}CmdSEy!xXLXb6bIG$l_BQ*JAB!V5#PV*26zJgr3G;U zp48vc4HCKx1ZU6vKz;@aaM!6{Os3jO`1!nY>lGO+)+?{R+5&Spbk0DzUed_)w-h(r zv5#1Sv#5Qwf5%jV*g6uYpc%gRYaSgMbxiLXM*2xzOi)IpA=XOnrOij==aek5M_CVg+?1l7UDk*`Q2au2*%v}! z`oUdN?qV&aR`-^*T#S&+oO&amn^}D%AWuT}P@i$d!@WSeS1})=CtZ=2(g1t z_wieqfqBh{-3Rut$7h2*>_^4w2J5V3l9pM~%F}v}zu!qM+nky!G)KsTDX055W6?9q z%@=X$^)5!wtTSxwQmky(B$$w z7l|Fr3g@u;PS;?cMIBMW1$&O}ejY;B_E0P>GcLbke*SFU9H08gE&IAD){JO}!_#-U zR_Ay&%c<>}I{9N))LE^E{YI->a$RMXCVUXYTlb1}MW$BlJ5r?nmhm-%`KW+tf8=KQ z`#7Z(9RZmXe+`Rpl@i0wS2J>2b5p6r^m8*)BlgFcntD_03d0Q|jH9lUYLISKTpcAM z>Cf<|n~w{~ogR7TL(=}{qJ7(Kf};BU8Cu7eJ04iB9=u<_X>@jf8$m5`7o_aC%^9HEjWx|5As z&s){3zUZ=bslH*7%0!>(#@;m=YKyI6c1Dz}i+h%FIIF15O3AeS-F^F`kEi{ydAZt+ z?t$2qf~gh}%U$5+2=!{`PhIO2ye-DhB1M~1!o1Gm0;|PILQ)AQ)=bNoFS@LQIp6kA zLb3a;`DG!*-)Aj4p6yw@s%A#Wb48KGx9;w$4V5z;`!aEiTbq*O#**W|kD6f;logj? zS2v;gx}4~QmDyuT{ca{^;>)FH7fxI$B6YoNqn_-Bs%dY|thkU-+rivwvGuW2|ID#x zrRRRyEO1fUX(gYbF?`9(UbT|Iv}44fp_jrC2gI3B;OU6DhK>;*OD~-``{hT8$W4R% zr-G_8YNnW844v$$^|V|+ucm$76Q!kTOU_P7lQA!Li&dXzXkzG=W_VIIa(1>|g;Qph z_52FU2bEq;=jyv0I$D+_`XzpidpgGYtJyQhjMt`-$I?h&rdijiMx}Lt6^8o2lbllP{_V}s;Hyf)G+q9z|96Q_=#WcLrvM}~i zN_=&ia@6J}O1-2;jmC&qhu_G`elBfmtq{3D=-N1O!G{|=+~poNN!@mlY<7|CwJX~o zcW85nRl)g^Ikz6h|Ilg*R2EX+qTe32s(1gfq(45VP4W7e>2v*PXqk!np@Om*-e!R- z6;?T{IkixhDbo7k=&ebXwStqQPEoD|KA3miM!#}OFwJnSqxbrxZ+1)DqkJ9pcg{}P zUouN0x1hC60oRk9BO5S(w9B*Xqd@`E%l(V@wNdxRS5b^So}V!+&pNlsiC*zt)XL+; zrn$_3{2L`>UzM*hZf|EC-FJGoy=Tx~iMoOv%nQrRmYuPsJeOlgjw^9>us!QFIm>n2 z9r0CG7FOAzQs4i4`|^$Yw(kN;fm;IkbPJ!QJ)FIb*eetY%?jaf1;vYsqB%@aFe{ig zrM7UT+2d1o&rs1F??&`-^)7Jsa2>PS9#6v8Ae)rTY>Cl)O2cP$O98XL5sQO>6X>!C zred{r;cK%i?+zV`KLcufiPMrMMCXT!1p8~oXKu*fDLmi1V&+Hr^2PU)`?K3Dn?6pd zn9Ni@X26_OrHm7@kkGWtB1q1D8xwvd81@6*0e z>q@p99Wm!tpqKyHxB!K??FDuFy5}A#$Q#hK&$`r|Zd}+dW4i2#K-0x0qg*GW&$(Hd zOP8eWdA_xL@j!G#nueAEsWt2cEzT!jSfNk0P`<5k^1`xa!8(uZj^$6B&&kFpCfJUz zI4JA=c7?>Q0+}NZrMigsFA=R=^B$B5-+1im9`W+gL+Y_}vHce5PUrnz?!LTn>{ZY2 z(o-(};U<|DW>8!%KEaFL5ZkbCCB13oI|+tZL#*5H*-D~ro(jdGNpTJB`zi`uE;P$J zyni>*`stP!J{`el4gD*-4D2+K4VRF4Jlc{^%6yiW-+d+AG^4D;PetTus;nFRyW~tcsgnSg*qx+juVo!-X*03)kC@U6m*no| zD!q;{njbA^9-lcqRoUIQD(ntz!tvv~tO}FL;tvp$KU8{4%C5>^F>6EGqR7Vigu4%G zWLgSu*51#LKeN3+fHaP`HemkAEBh$O#C#;R1J#m&0GH3d;kM$#eDA#Kg01n_`Uxdn zlSlgpS!Sp^SO;HAcA`NZk5inzbGNYX1x)zl1pR)Ue?2 z+lHcy%}T)*GY2U1ND@zOOY7a$371dHWHi+T4Rn8MFVFpcD?zP(;ryp79PHh$X-W5I zTu>rE4YeBxdUvxqa=hxR{I7i-qT;IViD?%^j@)KM=Vw{igjDIKPq<8KmU}0*@4VEw zrC&5A*C)$M$Yiv1i%rPezDo8$LJsY|;N$#7K9VA0Jr(Cyo(h*-I(1#zM}ZQPC*g4K z!J89jIy+~|^oQ|Tn2<0S0gnaQFD$H;EsO$2>x@PBf$Z=o?snU=AO@}j8GBJ=%!GLD zAcvmZg#CCW+qy@ohds}CQ%Dir!ZF(xbo$u!YFp!t+{~$Y!EZtbWN-Orj%!Q5StjaH zI`M0Q;@u}r4@gpLsyf|o{Px1>m}ZtvYpG+Syu?(K4wu<$#1>_dRVo8D)$o1#c6(e> zd~K#v#FJE`4CSH@?Jc+>q$8eiR6mgxF!|3yLUK^_^2CFr7lFz?tC*(rb?q3uc5d3+>8nla z=;^MDefVd>vA({~Uq`QUUJ=WW#rdq>NxyOsXPROJ04KPMCgB?mtH*X!VnXqy>9tJZz}IyY*b9|7fk&Y9#60%U8DB6WrX@CLVZ> zOcFe)-`b^--!QS+ajf|2Ju}ZV;Ac4{f3|sNpFY5>Qoom6Q|YlvGruA3#n`4Y0o$-| zHO=A1=RU8z-n(bMyqiSK+wfnuvl}I);PKd8E;D+D ziknqnu7I}Qg(BzahZe;wh+HsX)UFiqSNzJvXSqoK1^odPbQ@tr(IuPB9!%<~?-G&2 z-KuXg@^kCHm%h@nNBhyjNl8WFb9LTlt;Bad6|m9ku`p`;VS2hm&x%>pf7n!LE+y&x zh8YL5uZOlf%=1Z~a^G+3s2rJoekJ6y&LjQG*_meYS=g$BUFm1s5AA|mig{biir2=x z{Bdlyaqaj}LAQjc_IBw*+2^C3_daa8CU3p&)8r;?d9{`-XR*S9Ycd9@p|134jFHzD%l3^Vr&cVR4PTXV~UaE3dmn{t}HcRN-fe&vFjb{dQ`9$(mcLY4=ns(gLd| z)Nf{5b^a(mTzBcS@ac^z-}zOG-#Y0RcCeZW>SECFXbc#5Zkr>rI5zzL_Cm$V_-_Mq z?lfI>c1RVhYH01%D!CrOoTI+wx4m+`w$ctKj$XW$EgGMlP(Pvjo@wX!BKp(@eud(* zTp|M~VMG?CS}?>w2mRusOl+Lsj}@i3a#dq-4)W> zS>-$Np4*A=JtyV*;_$BXq?B|nm*`1M@nOy_pw@#5@cy3udo5Ol7j{nzA0uExpV9fG z^Yt!GGli=`lH?nQsh6gWlhB)`@j0oeL*luE>NT?XjSG`X*N(1J|3+LkYLtZkCBfDE zmu6J-Hm)xps2F4W<-5Rs1MA(q4Fsn!LmSf}_!OKT{KUj9Hnp%=VlxaAhZ^rE91ac= zhQZ-j0IFf&E5Mpy*7=-pI6NB;$CETG{s<=*2yyqXk#M+SNH{85P5oDqaOf5`KX>@T z4`95|l4S_XFdP*}fEI8M@CpQ+1nkz_*8BLl;8p23La2KH0NW5=ju5mFd_kT8?%?-; zZwA4CuVC;ai~}LiJ=h&Ua4bj$!Pg5SAX|5L{|zqQ0Q|rc3;^2#FfzbEL30Cc44^4^ zf(2l(@B~{t!Hx~MBe=2wcLa9?o52&{YfHoNb_8#L+u?q~?FjzE(RO%3Faq0=;H-$S z2~P;+gxvuGj1yXc_3PlWVTf;r_53i{9Zxp{2mKSo9p9s>{{-TW0zj4_qzG0&RGq&8 z;tnH3B3KobMT#E*`wb1AY;03TJ*zD54u9S)}aL54)hs>bjzyXd7Fmb>ObU4;i z0|b;BLNp;v9eSeo06e0>$F$TSK1!1Tl9)zk(a;mC3 zXY5A98XM(}8W5h)usD<_72pn6dfFn%I>qD zNwGY@@T?>NKZ&AIQK_Khs9O;m40hMx6G{sQ9IRHtQc#5n0;l39>7!7nP8c&m zN?5zHnFFN`zj4?*b^;6<13%~xBd~Y?@d2m9+8X!*@x(BP@dM(CL5J}QnO_^7cnUjv=MO34!sBMA%Og~tqPAlG$B%wRNcJhVCH^8uD(`FsH@$@U?E z1(O8?01P0KN?1@N_vq|O`Qc&UW7HuD3gcmj!h)A$r#z6zpvZ+@z=HVWF$Fr5DF6>9 z3>XbQp2yCgB|ZU)%A$ir1pqqu0ZWL%*@0KFH-On@7lI{-VWhBMfN*AQgdv9w;ALkU z3|_3ZI3yVXzK$ic3B2r_L&3|=7L)Y|*p9*s2^JK6kahGCVw#nPJ`6`zsfL20He}43kpsNFc2yQN*BZPQ9$(6(i&E9R6}-YhKh*l z%0om|9xS&hni=6Yw4*6ZMeG(c1voLg#Z18ph23JNV2i?L6;r_ZvKi47U3*Ra&t<|uPI zI-d0A`(eC;o}rXavVcZgr<(2pe7t_x%(s_n|yVPnw1sW-~yoUbSFf{0DTZdki?bO51$<@_4)uRT)=JPRa!o9D_TRn7#R^mxrbUWj*oGU~IU{@KxKekD(3^!=T$r&_CO zw$%D5pYWGbIQiW0^n>;0MT<`L&wG0)#v)c~rkixghjAa0_mE8L+!oF;DT>Pztl8z% zyFkRRVp%Egq>0J-|8EY|FsK!WNzRMb>RN_fwpAiuzc z1cJ}cS1RiR)+pi#a0%P$4Q}v#_TjPoR(s%GUEJ0|HtZA>A$yd(*LvS2=%e=YmMkO_ z$rJ$Qkw|n3>`xsO(dB3Ca)kuAd*G0^iGhPlvDkm0T%dx&!UO-4hL)G;5AVySQNVm3 zs>?wm(;4Ur9c~)hM03&@Bv94x(2z=mI~-CQaM0if(tL2y&_zgGG`QIb`im3Lj1yoygO_ zL^>c-c*0T94OU!nky@LZM&}L3Ao1qQK-a=?#RY9BPnki>&E3DCQ$orFj^SRT5>-czSRHQuQ%vTMVmzzdMYILq}bX9Z@8)sa)8ZW=l`LrK{ zcix2&%`<1im7~0E%Ah0VFlW9D22!hX(8xpzPrs3gRGu~<6V>3zg)1&FG(Oco)4dw`dLYF*^cPDf9vI{j5QME_C}cR*!O6?( Ind#&H0byt<4*&oF literal 0 HcmV?d00001 diff --git a/benchmarking/switchback/speed_benchmark.py b/benchmarking/switchback/speed_benchmark.py new file mode 100644 index 000000000..9ad991194 --- /dev/null +++ b/benchmarking/switchback/speed_benchmark.py @@ -0,0 +1,102 @@ +import json + +import time +import torch +import torch.nn as nn + +from bitsandbytes.triton.quantize_rowwise import quantize_rowwise +from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose +from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize +from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose +from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze + +# KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large. + +def get_time(k, fn, info_dict): + + for _ in range(repeat // 2): + fn() + + torch.cuda.synchronize() + start = time.time() + for _ in range(repeat): + fn() + + torch.cuda.synchronize() + end = time.time() + ms = (end - start) / repeat * 1000 + print(f"time {k}: {ms:.3f} ms") + info_dict[k] = ms + +if __name__ == '__main__': + torch.manual_seed(0) + wm = 4 + for dim in [1024, 1280, 1408, 1664, 2048, 4096]: + # note "batch_size" is actually "batch_size * embed_dim", which is why it's large + for batch_size in [256*32, 256*64, 256*128, 256*256, 256*512]: + + # switch switches dim_in and dim_out + for switch in [False, True]: + + # hparams + repeat = 64 + batch_size = batch_size + dim_out = dim * wm + dim_in = dim + if switch: + dim_out = dim + dim_in = wm * dim + + dim_in = round(dim_in) + dim_out = round(dim_out) + + # simulate forward pass + x = torch.randn(batch_size, dim_in, dtype=torch.float16).cuda() + g = torch.randn(batch_size, dim_out, dtype=torch.float16).cuda() + w = torch.randn(dim_out, dim_in, dtype=torch.float16).cuda() + + x_int8 = x.clone().to(torch.int8) + g_int8 = g.clone().to(torch.int8) + w_int8 = w.clone().to(torch.int8) + wt_int8 = w.t().contiguous().clone().to(torch.int8) + state_x_rowwise = x.max(dim=1)[0] + state_g_rowwise = g.max(dim=1)[0] + state_w_columnwise = w.max(dim=0)[0] + state_w_rowwise = w.max(dim=1)[0] + state_w_global = w.max() + + info = {'repeat' : repeat, 'batch_size' : batch_size, 'dim_out' : dim_out, 'dim_in' : dim_in, 'wm' : wm, 'switch' : switch} + + get_time('standard_fwd', lambda : x.matmul(w.t()), info) + get_time('standard_gw', lambda : g.t().matmul(x), info) + get_time('standard_gx', lambda : g.matmul(w), info) + get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info) + get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info) + get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info) + get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info) + get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info) + get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info) + get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info) + get_time('w_quantize_colwise_transpose', lambda : quantize_columnwise_and_transpose(w), info) + get_time('w_quantize_global', lambda : quantize_global(w), info) + get_time('w_quantize_global_transpose', lambda : quantize_global_transpose(w), info) + + time_standard = info['standard_fwd'] + info['standard_gx'] + info['standard_gw'] + time_rowwise = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_colwise_transpose'] + info['w_quantize_rowwise'] + info['standard_gw'] + info['rowwise_fwd'] + info['rowwise_bwd'] + time_global = info['x_quantize_rowwise'] + info['g_quantize_rowwise'] + info['w_quantize_global'] + info['w_quantize_global_transpose'] + info['standard_gw'] + info['global_fwd'] + info['global_bwd'] + + print('TOTAL STANDARD', time_standard) + print('TOTAL ROWWISE', time_rowwise) + print('TOTAL GLOBAL', time_global) + + print('speedup', -100*(time_global - time_standard)/time_standard) + + info['time_standard'] = time_standard + info['time_rowwise'] = time_rowwise + info['time_global'] = time_global + + info_json = json.dumps(info) + + # TODO: change this to what you want. + with open("speed_benchmark/info.jsonl", "a") as file: + file.write(info_json + "\n") diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py new file mode 100644 index 000000000..f35a3b582 --- /dev/null +++ b/bitsandbytes/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import cuda_setup, utils, research +from .autograd._functions import ( + MatmulLtState, + bmm_cublas, + matmul, + matmul_cublas, + mm_cublas, + matmul_4bit +) +from .cextension import COMPILED_WITH_CUDA +from .nn import modules + +if COMPILED_WITH_CUDA: + from .optim import adam + +__pdoc__ = { + "libbitsandbytes": False, + "optim.optimizer.Optimizer8bit": False, + "optim.optimizer.MockArgs": False, +} + +PACKAGE_GITHUB_URL = "https://github.com/TimDettmers/bitsandbytes" diff --git a/bitsandbytes/__main__.py b/bitsandbytes/__main__.py new file mode 100644 index 000000000..a100b2919 --- /dev/null +++ b/bitsandbytes/__main__.py @@ -0,0 +1,154 @@ +import os +import sys +import shlex +import subprocess + +from warnings import warn +from typing import Tuple +from os.path import isdir + +import torch + +HEADER_WIDTH = 60 + +def execute_and_return(command_string: str) -> Tuple[str, str]: + def _decode(subprocess_err_out_tuple): + return tuple( + to_decode.decode("UTF-8").strip() + for to_decode in subprocess_err_out_tuple + ) + + def execute_and_return_decoded_std_streams(command_string): + return _decode( + subprocess.Popen( + shlex.split(command_string), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ).communicate() + ) + + std_out, std_err = execute_and_return_decoded_std_streams(command_string) + return std_out, std_err + +def find_file_recursive(folder, filename): + cmd = f'find {folder} -name {filename}' + out, err = execute_and_return(cmd) + if len(err) > 0: + raise RuntimeError('Something when wrong when trying to find file. Maybe you do not have a linux system?') + + return out + + +def generate_bug_report_information(): + print_header("") + print_header("BUG REPORT INFORMATION") + print_header("") + print('') + + if 'CONDA_PREFIX' in os.environ: + paths = find_file_recursive(os.environ['CONDA_PREFIX'], '*cuda*so') + print_header("ANACONDA CUDA PATHS") + print(paths) + print('') + if isdir('/usr/local/'): + paths = find_file_recursive('/usr/local', '*cuda*so') + print_header("/usr/local CUDA PATHS") + print(paths) + print('') + + if isdir(os.getcwd()): + paths = find_file_recursive(os.getcwd(), '*cuda*so') + print_header("WORKING DIRECTORY CUDA PATHS") + print(paths) + print('') + + print_header("LD_LIBRARY CUDA PATHS") + lib_path = os.environ['LD_LIBRARY_PATH'].strip() + for path in set(lib_path.split(':')): + try: + if isdir(path): + print_header(f"{path} CUDA PATHS") + paths = find_file_recursive(path, '*cuda*so') + print(paths) + except: + print(f'Could not read LD_LIBRARY_PATH: {path}') + print('') + + + + + +def print_header( + txt: str, width: int = HEADER_WIDTH, filler: str = "+" +) -> None: + txt = f" {txt} " if txt else "" + print(txt.center(width, filler)) + + +def print_debug_info() -> None: + print( + "\nAbove we output some debug information. Please provide this info when " + f"creating an issue via {PACKAGE_GITHUB_URL}/issues/new/choose ...\n" + ) + + +generate_bug_report_information() + + +from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL +from .cuda_setup.env_vars import to_be_ignored +from .cuda_setup.main import get_compute_capabilities, get_cuda_lib_handle + + +print_header("OTHER") +print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}") +cuda = get_cuda_lib_handle() +print(f"COMPUTE_CAPABILITIES_PER_GPU = {get_compute_capabilities(cuda)}") +print_header("") +print_header("DEBUG INFO END") +print_header("") +print( + """ +Running a quick check that: + + library is importable + + CUDA function is callable +""" +) +print("\nWARNING: Please be sure to sanitize sensible info from any such env vars!\n") + +try: + from bitsandbytes.optim import Adam + + p = torch.nn.Parameter(torch.rand(10, 10).cuda()) + a = torch.rand(10, 10).cuda() + + p1 = p.data.sum().item() + + adam = Adam([p]) + + out = a * p + loss = out.sum() + loss.backward() + adam.step() + + p2 = p.data.sum().item() + + assert p1 != p2 + print("SUCCESS!") + print("Installation was successful!") + sys.exit(0) + +except ImportError: + print() + warn( + f"WARNING: {__package__} is currently running as CPU-only!\n" + "Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n" + f"If you think that this is so erroneously,\nplease report an issue!" + ) + print_debug_info() + sys.exit(0) +except Exception as e: + print(e) + print_debug_info() + sys.exit(1) + diff --git a/bitsandbytes/autograd/__init__.py b/bitsandbytes/autograd/__init__.py new file mode 100644 index 000000000..6b9a7e4d1 --- /dev/null +++ b/bitsandbytes/autograd/__init__.py @@ -0,0 +1 @@ +from ._functions import undo_layout, get_inverse_transform_indices diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py new file mode 100644 index 000000000..63b7156b4 --- /dev/null +++ b/bitsandbytes/autograd/_functions.py @@ -0,0 +1,564 @@ +import operator +import warnings +from dataclasses import dataclass +from functools import reduce # Required in Python 3 +from typing import Tuple, Optional, List + +import torch + +import bitsandbytes.functional as F + + +# math.prod not compatible with python < 3.8 +def prod(iterable): + return reduce(operator.mul, iterable, 1) + +tensor = torch.Tensor + + +# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov: +# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py + + + +""" + This class pools outlier dimensions across layers. + This is particularly important for small models where outlier features + are less systematic and occur with low frequency. +""" +class GlobalOutlierPooler: + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def initialize(self): + self.outliers = set() + self.model_dim = None + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + def add_outliers(self, outlier_idx, feature_dim): + if self.model_dim is None: + self.model_dim = feature_dim + if feature_dim != self.model_dim: + return # we do not encode outliers for the 2nd FFN layer + + self.outliers.update(outlier_idx.tolist()) + + def get_current_outlier_idx(self): + return torch.Tensor(list(self.outliers)).to(torch.int64) + + +def get_inverse_transform_indices(transform_tile: callable, tile_size: Tuple[int, int]): + """ + Compute a permutation of indices that invert the specified (tiled) matrix transformation + + :param transform_tile: a function that applies forward transform to a tensor of shape [dim1, dim2] + :param tile_size: higher-level tile dimensions, i.e. (8, 32) for Turing and (32, 32) for Ampere + :note: we assume that tile_transform applies to a cpu-based int8 tensor of shape tile_size + :example: transform_tile function for the turing layout (bitsandbytes.functional as F) + :returns: indices + """ + d1, d2 = tile_size + assert 0 < d1 * d2 < 2**64 + tile_indices = torch.arange(d1 * d2, dtype=torch.int64).view(d1, d2) + # encode each position in tile as a tuple of <= 8 unique bytes + permuted_tile_indices = torch.zeros_like(tile_indices) + for i in range(8): + # select i-th byte, apply transformation and trace where each index ended up + ith_dim_indices = torch.div(tile_indices, 256**i, rounding_mode="trunc") % 256 + sample_tile_i = (ith_dim_indices - 128).to(torch.int8).contiguous() + assert torch.all(sample_tile_i.int() + 128 == ith_dim_indices), "int overflow" + permuted_tile_i = transform_tile(sample_tile_i) + ith_permuted_indices = permuted_tile_i.to(tile_indices.dtype) + 128 + permuted_tile_indices += ith_permuted_indices * (256**i) + if d1 * d2 < 256**i: + break # if all indices fit in i bytes, stop early + return permuted_tile_indices + +def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -> torch.Tensor: + """ + Undo a tiled permutation such as turing or ampere layout + + :param permuted_tensor: torch tensor in a permuted layout + :param tile_indices: reverse transformation indices, from get_inverse_transform_indices + :return: contiguous row-major tensor + """ + (rows, cols), (tile_rows, tile_cols) = permuted_tensor.shape, tile_indices.shape + assert rows % tile_rows == cols % tile_cols == 0, "tensor must contain a whole number of tiles" + tensor = permuted_tensor.reshape(-1, tile_indices.numel()).t() + outputs = torch.empty_like(tensor) # note: not using .index_copy because it was slower on cuda + outputs[tile_indices.flatten()] = tensor + outputs = outputs.reshape(tile_rows, tile_cols, cols // tile_cols, rows // tile_rows) + outputs = outputs.permute(3, 0, 2, 1) # (rows // tile_rows, tile_rows), (cols // tile_cols, tile_cols) + return outputs.reshape(rows, cols).contiguous() + + +class MatMul8bit(torch.autograd.Function): + @staticmethod + def forward(ctx, A, B, out=None, quant_type="vector", precision=None): + if precision is None: + precision = [8, 8, 8] + if precision[0] != 8: + with torch.no_grad(): + output = torch.matmul(A, B) + else: + if len(B.shape) == 2: + dim = 0 + else: + dim = 1 + qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type) + qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type) + iout = F.igemm(qA, qB) + output = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, quant_type) + + if A.requires_grad or B.requires_grad: + ctx.save_for_backward(A, B) + + ctx.quant_type = quant_type + ctx.precision = precision + + return output + + @staticmethod + def backward(ctx, grad_output): + A, B = ctx.saved_tensors + quant_type = ctx.quant_type + precision = ctx.precision + grad_A = grad_B = None + + if B.requires_grad: + if len(A.shape) == 3: + dims = [0, 1] + # bsi -> ibs + permute_dim = [0, 2, 1] + else: + dims = [0] + # bs -> sb + permute_dim = [1, 0] + + if precision[1] != 8: + with torch.no_grad(): + grad_B = torch.matmul(A.permute(permute_dim), grad_output) + else: + if len(B.shape) == 2 and len(A.shape) == 3: + grad_output = grad_output.contiguous() + if not grad_output.is_contiguous(): + grad_output.contiguous() + qgrad_output, S1 = F.vectorwise_quant( + grad_output.view(-1, grad_output.shape[2]), + dim=0, + quant_type=quant_type, + ) + if not A.is_contiguous(): + A = A.contiguous() + qA, S2 = F.vectorwise_quant( + A.view(-1, A.shape[2]), dim=0, quant_type=quant_type + ) + igrad_B = F.igemm(qA.t(), qgrad_output) + grad_B = F.vectorwise_mm_dequant( + igrad_B, S2.t(), S1, grad_output.dtype, quant_type + ) + else: + qgrad_output, S1 = F.vectorwise_quant( + grad_output, dim=dims, quant_type=quant_type + ) + qA, S2 = F.vectorwise_quant( + A, dim=dims, quant_type=quant_type + ) + igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output) + grad_B = F.vectorwise_mm_dequant( + igrad_B, + S2.permute(permute_dim), + S1, + grad_output.dtype, + quant_type, + ) + + if A.requires_grad: + if len(grad_output.shape) == 3: + dims = [2] + else: + dims = [1] + + if len(B.shape) == 3: + # bio -> boi + permute_dim = [0, 2, 1] + dim_B = dims + else: + # io -> oi + permute_dim = [1, 0] + dim_B = [1] + + if precision[2] != 8: + with torch.no_grad(): + grad_A = torch.matmul(grad_output, B.permute(permute_dim)) + else: + qgrad_output, S1 = F.vectorwise_quant( + grad_output, dim=dims, quant_type=quant_type + ) + qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type) + igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim)) + grad_A = F.vectorwise_mm_dequant( + igrad_A, + S1, + S3.permute(permute_dim), + grad_output.dtype, + quant_type, + ) + + return grad_A, grad_B, None, None, None + + +mm_cublas = MatMul8bit.apply +bmm_cublas = MatMul8bit.apply +matmul_cublas = MatMul8bit.apply + + +def supports_igemmlt(device: torch.device) -> bool: + """check if this device supports the optimized int8 kernel""" + if torch.cuda.get_device_capability(device=device) < (7, 5): + return False + device_name = torch.cuda.get_device_name(device=device) + nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series + if any(model_name in device_name for model_name in nvidia16_models): + return False # these devices are technically cuda 7.5-capable, but they lack tensor cores + return True + + +@dataclass +class MatmulLtState: + _tile_indices: Optional[torch.Tensor] = None + force_no_igemmlt: bool = False + CB = None + CxB = None + SB = None + SCB = None + + CxBt = None + SBt = None + CBt = None + + subB = None + + outlier_pool = None + has_accumulated_gradients = False + threshold = 0.0 + idx = None + is_training = True + has_fp16_weights = True + memory_efficient_backward = False + use_pool = False + formatB = F.get_special_format_str() + + def reset_grads(self): + self.CB = None + self.CxB = None + self.SB = None + self.SCB = None + + self.CxBt = None + self.SBt = None + self.CBt = None + + def get_tile_size(self): + assert self.formatB in ( + "col_turing", + "col_ampere", + ), f"please find this assert and manually enter tile size for {self.formatB}" + return (8, 32) if self.formatB == "col_turing" else (32, 32) + + @property + def tile_indices(self): + if self._tile_indices is None: + device = self.CxB.device + transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device) + with torch.no_grad(): + self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device) + return self._tile_indices + + +class MatMul8bitLt(torch.autograd.Function): + # forward is the same, but we added the fallback for pre-turing GPUs + # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") + + @staticmethod + def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): + using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt + # default of pytorch behavior if inputs are empty + ctx.is_empty = False + if prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + ctx.bias = bias + if A.shape[-1] == B.shape[0]: + return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device) + else: + return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device) + + # 1. Quantize A + # 2. Quantize B + # 3. Matmul + # 4. Mixed-precision decomposition matmul + # 5. Save state + formatB = state.formatB + input_shape = A.shape + if state.outlier_pool is None: + state.outlier_pool = GlobalOutlierPooler.get_instance() + + # Cast A to fp16 + if A.dtype != torch.float16: + warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") + + # 1. Quantize A + if len(A.shape) == 3: + A = A.view(-1, A.shape[-1]).contiguous() + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) + + if state.threshold > 0.0 and coo_tensorA is not None: + if state.has_fp16_weights: + idx = torch.unique(coo_tensorA.colidx).long() + CA[:, idx] = 0 + CAt[:, idx] = 0 + subA = A[:, idx] + state.subB = B[:, idx].t().contiguous() + state.idx = idx + else: + if state.CxB is None and using_igemmlt: + # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions + # we also need to convert it to the turing/ampere format + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + else: + if not state.has_fp16_weights and state.CxB is None and using_igemmlt: + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + subA = None + + # 2. Quantize B + if state.has_fp16_weights: + has_grad = True if (getattr(B, "grad", None) is not None) else False + is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) + if is_transposed: + B = B.contiguous() + + if (state.is_training and not has_grad) or state.CxB is None: + state.reset_grads() + ( + CB, + state.CBt, + state.SCB, + state.SCBt, + coo_tensorB, + ) = F.double_quant(B.to(torch.float16)) + if using_igemmlt: + state.CxB, state.SB = F.transform(CB, to_order=formatB) + else: + state.CB = CB + else: + has_grad = False + + if coo_tensorA is not None and not state.has_fp16_weights: + # extract outliers + + outlier_idx = torch.unique(coo_tensorA.colidx) + state.idx = outlier_idx + # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) + # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: + # # do not use pool for 2nd FFN layer + # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) + # else: + # state.idx = outlier_idx + if state.CxB is not None: + outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) + else: + outliers = state.CB[:, state.idx.long()].clone() + + state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) + CA[:, state.idx.long()] = 0 + CAt[:, state.idx.long()] = 0 + subA = A[:, state.idx.long()] + + shapeB = state.SB[0] if state.SB else B.shape + + if len(input_shape) == 3: + output_shape = (input_shape[0], input_shape[1], shapeB[0]) + else: + output_shape = (input_shape[0], shapeB[0]) + + # 3. Matmul + if using_igemmlt: + C32A, SA = F.transform(CA, "col32") + out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) + if bias is None or bias.dtype == torch.float16: + # we apply the fused bias here + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) + output = output.to(A.dtype) + else: # apply bias separately + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) + output = output.to(A.dtype).add_(bias) + + else: + A_wo_outliers = A.clone() + if state.idx is not None: + A_wo_outliers[:, state.idx.long()] = 0 + output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype)) + output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0)) + if bias is not None: + output = output.add_(bias) + + # 4. Mixed-precision decomposition matmul + if coo_tensorA is not None and subA is not None: + output += torch.matmul(subA, state.subB) + + # 5. Save state + ctx.state = state + + ctx.formatB = formatB + ctx.grad_shape = input_shape + ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype + + if any(ctx.needs_input_grad[:2]): + ctx.tensors = (CAt, subA, A) + ctx.tensor_states = (SCAt, state.idx) + else: + ctx.tensors = [None, None, A] + ctx.tensor_states = (None, None) + ctx.save_for_backward(None, None) + + clone_func = torch.clone if len(output_shape) == 3 else lambda x: x + return clone_func(output.view(output_shape)) + + @staticmethod + def backward(ctx, grad_output): + if ctx.is_empty: + bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None + req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad + CAt, subA, A = ctx.tensors + SCAt, idx = ctx.tensor_states + formatB = ctx.formatB + state = ctx.state + grad_A = grad_B = grad_bias = None + + if req_gradBias: + # compute grad_bias first before changing grad_output dtype + grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) + + # Cast grad_output to fp16 + if len(grad_output.shape) == 3: + grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() + + Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) + if req_gradB: + CxAt, SAt = F.transform(CAt, formatB, transpose=True) + C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) + gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) + grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) + if state.threshold > 0.0 and subA is not None: + grad_B[:, idx] += torch.matmul(grad_output.t(), subA) + + if req_gradA: + if state.CBt is not None: + C32grad, Sgrad = F.transform(Cgrad, "col32") + if state.CxBt is None: + state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True) + gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) + + elif state.CB is not None: + CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) + grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) + elif state.CxB is not None: + CB = ( + undo_layout(state.CxB, state.tile_indices) + .to(ctx.dtype_A) + .mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) + ) + grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) + else: + raise Exception("State must contain either CBt or CB or CxB matrix for backward") + + return grad_A, grad_B, None, grad_bias, None + + +class MatMul4Bit(torch.autograd.Function): + # forward is the same, but we added the fallback for pre-turing GPUs + # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") + + @staticmethod + def forward(ctx, A, B, out=None, bias=None, state=None): + # default of pytorch behavior if inputs are empty + ctx.is_empty = False + if prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + ctx.bias = bias + B_shape = state[1] + if A.shape[-1] == B_shape[0]: + return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) + else: + return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) + + + # 1. Dequantize + # 2. MatmulnN + output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias) + + # 3. Save state + ctx.state = state + ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype + + if any(ctx.needs_input_grad[:2]): + ctx.tensors = (A, B) + else: + ctx.tensors = (None, None) + + return output + + @staticmethod + def backward(ctx, grad_output): + if ctx.is_empty: + bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None + + req_gradA, _, _, req_gradBias, _= ctx.needs_input_grad + A, B = ctx.tensors + state = ctx.state + + grad_A, grad_B, grad_bias = None, None, None + + if req_gradBias: + # compute grad_bias first before changing grad_output dtype + grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) + + # not supported by PyTorch. TODO: create work-around + #if req_gradB: grad_B = torch.matmul(grad_output.t(), A) + if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(grad_output.dtype).t()) + + return grad_A, grad_B, None, grad_bias, None + + +def matmul( + A: tensor, + B: tensor, + out: tensor = None, + state: MatmulLtState = None, + threshold=0.0, + bias=None +): + state = state or MatmulLtState() + if threshold > 0.0: + state.threshold = threshold + return MatMul8bitLt.apply(A, B, out, bias, state) + + +def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None): + assert quant_state is not None + return MatMul4Bit.apply(A, B, out, bias, quant_state) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py new file mode 100644 index 000000000..131edc5ee --- /dev/null +++ b/bitsandbytes/cextension.py @@ -0,0 +1,42 @@ +import ctypes as ct +import os +import torch + +from pathlib import Path +from warnings import warn + +from bitsandbytes.cuda_setup.main import CUDASetup + + +setup = CUDASetup.get_instance() +if setup.initialized != True: + setup.run_cuda_setup() + +lib = setup.lib +try: + if lib is None and torch.cuda.is_available(): + CUDASetup.get_instance().generate_instructions() + CUDASetup.get_instance().print_log_stack() + raise RuntimeError(''' + CUDA Setup failed despite GPU being available. Please run the following command to get more information: + + python -m bitsandbytes + + Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them + to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes + and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues''') + lib.cadam32bit_grad_fp32 # runs on an error if the library could not be found -> COMPILED_WITH_CUDA=False + lib.get_context.restype = ct.c_void_p + lib.get_cusparse.restype = ct.c_void_p + lib.cget_managed_ptr.restype = ct.c_void_p + COMPILED_WITH_CUDA = True +except AttributeError as ex: + warn("The installed version of bitsandbytes was compiled without GPU support. " + "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.") + COMPILED_WITH_CUDA = False + print(str(ex)) + + +# print the setup details after checking for errors so we do not print twice +if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': + setup.print_log_stack() diff --git a/bitsandbytes/cuda_setup/__init__.py b/bitsandbytes/cuda_setup/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/cuda_setup/env_vars.py b/bitsandbytes/cuda_setup/env_vars.py new file mode 100644 index 000000000..4fcb643ee --- /dev/null +++ b/bitsandbytes/cuda_setup/env_vars.py @@ -0,0 +1,52 @@ +import os +from typing import Dict + + +def to_be_ignored(env_var: str, value: str) -> bool: + ignorable = { + "PWD", # PWD: this is how the shell keeps track of the current working dir + "OLDPWD", + "SSH_AUTH_SOCK", # SSH stuff, therefore unrelated + "SSH_TTY", + "HOME", # Linux shell default + "TMUX", # Terminal Multiplexer + "XDG_DATA_DIRS", # XDG: Desktop environment stuff + "XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff + "XDG_RUNTIME_DIR", + "MAIL", # something related to emails + "SHELL", # binary for currently invoked shell + "DBUS_SESSION_BUS_ADDRESS", # hardware related + "PATH", # this is for finding binaries, not libraries + "LESSOPEN", # related to the `less` command + "LESSCLOSE", + "_", # current Python interpreter + } + return env_var in ignorable + + +def might_contain_a_path(candidate: str) -> bool: + return "/" in candidate + + +def is_active_conda_env(env_var: str) -> bool: + return "CONDA_PREFIX" == env_var + + +def is_other_conda_env_var(env_var: str) -> bool: + return "CONDA" in env_var + + +def is_relevant_candidate_env_var(env_var: str, value: str) -> bool: + return is_active_conda_env(env_var) or ( + might_contain_a_path(value) and not + is_other_conda_env_var(env_var) and not + to_be_ignored(env_var, value) + ) + + +def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]: + return { + env_var: value + for env_var, value in os.environ.items() + if is_relevant_candidate_env_var(env_var, value) + } diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/cuda_setup/main.py new file mode 100644 index 000000000..e7901d82e --- /dev/null +++ b/bitsandbytes/cuda_setup/main.py @@ -0,0 +1,427 @@ +""" +extract factors the build is dependent on: +[X] compute capability + [ ] TODO: Q - What if we have multiple GPUs of different makes? +- CUDA version +- Software: + - CPU-only: only CPU quantization functions (no optimizer, no matrix multipl) + - CuBLAS-LT: full-build 8-bit optimizer + - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) + +evaluation: + - if paths faulty, return meaningful error + - else: + - determine CUDA version + - determine capabilities + - based on that set the default path +""" + +import ctypes as ct +import os +import errno +import torch +from warnings import warn +from itertools import product + +from pathlib import Path +from typing import Set, Union +from .env_vars import get_potentially_lib_path_containing_env_vars + +# these are the most common libs names +# libcudart.so is missing by default for a conda install with PyTorch 2.0 and instead +# we have libcudart.so.11.0 which causes a lot of errors before +# not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt +CUDA_RUNTIME_LIBS: list = ["libcudart.so", 'libcudart.so.11.0', 'libcudart.so.12.0'] + +# this is a order list of backup paths to search CUDA in, if it cannot be found in the main environmental paths +backup_paths = [] +backup_paths.append('$CONDA_PREFIX/lib/libcudart.so.11.0') + +class CUDASetup: + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def generate_instructions(self): + if getattr(self, 'error', False): return + print(self.error) + self.error = True + if self.cuda is None: + self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected.') + self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.') + self.add_log_entry('CUDA SETUP: Solution 2): If you do not have sudo rights, you can do the following:') + self.add_log_entry('CUDA SETUP: Solution 2a): Find the cuda library via: find / -name libcuda.so 2>/dev/null') + self.add_log_entry('CUDA SETUP: Solution 2b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_2a') + self.add_log_entry('CUDA SETUP: Solution 2c): For a permanent solution add the export from 2b into your .bashrc file, located at ~/.bashrc') + return + + if self.cudart_path is None: + self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA runtime library was not detected.') + self.add_log_entry('CUDA SETUP: Solution 1: To solve the issue the libcudart.so location needs to be added to the LD_LIBRARY_PATH variable') + self.add_log_entry('CUDA SETUP: Solution 1a): Find the cuda runtime library via: find / -name libcudart.so 2>/dev/null') + self.add_log_entry('CUDA SETUP: Solution 1b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_1a') + self.add_log_entry('CUDA SETUP: Solution 1c): For a permanent solution add the export from 1b into your .bashrc file, located at ~/.bashrc') + self.add_log_entry('CUDA SETUP: Solution 2: If no library was found in step 1a) you need to install CUDA.') + self.add_log_entry('CUDA SETUP: Solution 2a): Download CUDA install script: wget https://github.com/TimDettmers/bitsandbytes/blob/main/cuda_install.sh') + self.add_log_entry('CUDA SETUP: Solution 2b): Install desired CUDA version to desired location. The syntax is bash cuda_install.sh CUDA_VERSION PATH_TO_INSTALL_INTO.') + self.add_log_entry('CUDA SETUP: Solution 2b): For example, "bash cuda_install.sh 113 ~/local/" will download CUDA 11.3 and install into the folder ~/local') + return + + make_cmd = f'CUDA_VERSION={self.cuda_version_string}' + if len(self.cuda_version_string) < 3: + make_cmd += ' make cuda92' + elif self.cuda_version_string == '110': + make_cmd += ' make cuda110' + elif self.cuda_version_string[:2] == '11' and int(self.cuda_version_string[2]) > 0: + make_cmd += ' make cuda11x' + elif self.cuda_version_string == '100': + self.add_log_entry('CUDA SETUP: CUDA 10.0 not supported. Please use a different CUDA version.') + self.add_log_entry('CUDA SETUP: Before you try again running bitsandbytes, make sure old CUDA 10.0 versions are uninstalled and removed from $LD_LIBRARY_PATH variables.') + return + + + has_cublaslt = is_cublasLt_compatible(self.cc) + if not has_cublaslt: + make_cmd += '_nomatmul' + + self.add_log_entry('CUDA SETUP: Something unexpected happened. Please compile from source:') + self.add_log_entry('git clone git@github.com:TimDettmers/bitsandbytes.git') + self.add_log_entry('cd bitsandbytes') + self.add_log_entry(make_cmd) + self.add_log_entry('python setup.py install') + + def initialize(self): + if not getattr(self, 'initialized', False): + self.has_printed = False + self.lib = None + self.initialized = False + self.error = False + + def run_cuda_setup(self): + self.initialized = True + self.cuda_setup_log = [] + + binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup() + self.cudart_path = cudart_path + self.cuda = cuda + self.cc = cc + self.cuda_version_string = cuda_version_string + + package_dir = Path(__file__).parent.parent + binary_path = package_dir / binary_name + + print('bin', binary_path) + + try: + if not binary_path.exists(): + self.add_log_entry(f"CUDA SETUP: Required library version not found: {binary_name}. Maybe you need to compile it from source?") + legacy_binary_name = "libbitsandbytes_cpu.so" + self.add_log_entry(f"CUDA SETUP: Defaulting to {legacy_binary_name}...") + binary_path = package_dir / legacy_binary_name + if not binary_path.exists() or torch.cuda.is_available(): + self.add_log_entry('') + self.add_log_entry('='*48 + 'ERROR' + '='*37) + self.add_log_entry('CUDA SETUP: CUDA detection failed! Possible reasons:') + self.add_log_entry('1. CUDA driver not installed') + self.add_log_entry('2. CUDA not installed') + self.add_log_entry('3. You have multiple conflicting CUDA libraries') + self.add_log_entry('4. Required library not pre-compiled for this bitsandbytes release!') + self.add_log_entry('CUDA SETUP: If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION` for example, `make CUDA_VERSION=113`.') + self.add_log_entry('CUDA SETUP: The CUDA version for the compile might depend on your conda install. Inspect CUDA version via `conda list | grep cuda`.') + self.add_log_entry('='*80) + self.add_log_entry('') + self.generate_instructions() + raise Exception('CUDA SETUP: Setup Failed!') + self.lib = ct.cdll.LoadLibrary(binary_path) + else: + self.add_log_entry(f"CUDA SETUP: Loading binary {binary_path}...") + self.lib = ct.cdll.LoadLibrary(binary_path) + except Exception as ex: + self.add_log_entry(str(ex)) + + def add_log_entry(self, msg, is_warning=False): + self.cuda_setup_log.append((msg, is_warning)) + + def print_log_stack(self): + for msg, is_warning in self.cuda_setup_log: + if is_warning: + warn(msg) + else: + print(msg) + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + +def is_cublasLt_compatible(cc): + has_cublaslt = False + if cc is not None: + cc_major, cc_minor = cc.split('.') + if int(cc_major) < 7 or (int(cc_major) == 7 and int(cc_minor) < 5): + CUDASetup.get_instance().add_log_entry("WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!", is_warning=True) + else: + has_cublaslt = True + return has_cublaslt + +def extract_candidate_paths(paths_list_candidate: str) -> Set[Path]: + return {Path(ld_path) for ld_path in paths_list_candidate.split(":") if ld_path} + + +def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]: + existent_directories: Set[Path] = set() + for path in candidate_paths: + try: + if path.exists(): + existent_directories.add(path) + except OSError as exc: + if exc.errno != errno.ENAMETOOLONG: + raise exc + + non_existent_directories: Set[Path] = candidate_paths - existent_directories + if non_existent_directories: + CUDASetup.get_instance().add_log_entry("WARNING: The following directories listed in your path were found to " + f"be non-existent: {non_existent_directories}", is_warning=True) + + return existent_directories + + +def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]: + paths = set() + for libname in CUDA_RUNTIME_LIBS: + for path in candidate_paths: + if (path / libname).is_file(): + paths.add(path / libname) + return paths + + +def resolve_paths_list(paths_list_candidate: str) -> Set[Path]: + """ + Searches a given environmental var for the CUDA runtime library, + i.e. `libcudart.so`. + """ + return remove_non_existent_dirs(extract_candidate_paths(paths_list_candidate)) + + +def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]: + return get_cuda_runtime_lib_paths( + resolve_paths_list(paths_list_candidate) + ) + + +def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None: + if len(results_paths) > 1: + warning_msg = ( + f"Found duplicate {CUDA_RUNTIME_LIBS} files: {results_paths}.. " + "We'll flip a coin and try one of these, in order to fail forward.\n" + "Either way, this might cause trouble in the future:\n" + "If you get `CUDA error: invalid device function` errors, the above " + "might be the cause and the solution is to make sure only one " + f"{CUDA_RUNTIME_LIBS} in the paths that we search based on your env.") + CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True) + + +def determine_cuda_runtime_lib_path() -> Union[Path, None]: + """ + Searches for a cuda installations, in the following order of priority: + 1. active conda env + 2. LD_LIBRARY_PATH + 3. any other env vars, while ignoring those that + - are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`) + - don't contain the path separator `/` + + If multiple libraries are found in part 3, we optimistically try one, + while giving a warning message. + """ + candidate_env_vars = get_potentially_lib_path_containing_env_vars() + + if "CONDA_PREFIX" in candidate_env_vars: + conda_libs_path = Path(candidate_env_vars["CONDA_PREFIX"]) / "lib" + + conda_cuda_libs = find_cuda_lib_in(str(conda_libs_path)) + warn_in_case_of_duplicates(conda_cuda_libs) + + if conda_cuda_libs: + return next(iter(conda_cuda_libs)) + + CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain ' + f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True) + + if "LD_LIBRARY_PATH" in candidate_env_vars: + lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"]) + + if lib_ld_cuda_libs: + return next(iter(lib_ld_cuda_libs)) + warn_in_case_of_duplicates(lib_ld_cuda_libs) + + CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain ' + f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True) + + remaining_candidate_env_vars = { + env_var: value for env_var, value in candidate_env_vars.items() + if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"} + } + + cuda_runtime_libs = set() + for env_var, value in remaining_candidate_env_vars.items(): + cuda_runtime_libs.update(find_cuda_lib_in(value)) + + if len(cuda_runtime_libs) == 0: + CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching in backup paths...') + cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64')) + + warn_in_case_of_duplicates(cuda_runtime_libs) + + return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None + + +def check_cuda_result(cuda, result_val): + # 3. Check for CUDA errors + if result_val != 0: + error_str = ct.c_char_p() + cuda.cuGetErrorString(result_val, ct.byref(error_str)) + if error_str.value is not None: + CUDASetup.get_instance().add_log_entry(f"CUDA exception! Error code: {error_str.value.decode()}") + else: + CUDASetup.get_instance().add_log_entry(f"Unknown CUDA exception! Please check your CUDA install. It might also be that your GPU is too old.") + + +# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION +def get_cuda_version(cuda, cudart_path): + if cuda is None: return None + + try: + cudart = ct.CDLL(cudart_path) + except OSError: + CUDASetup.get_instance().add_log_entry(f'ERROR: libcudart.so could not be read from path: {cudart_path}!') + return None + + version = ct.c_int() + try: + check_cuda_result(cuda, cudart.cudaRuntimeGetVersion(ct.byref(version))) + except AttributeError as e: + CUDASetup.get_instance().add_log_entry(f'ERROR: {str(e)}') + CUDASetup.get_instance().add_log_entry(f'CUDA SETUP: libcudart.so path is {cudart_path}') + CUDASetup.get_instance().add_log_entry(f'CUDA SETUP: Is seems that your cuda installation is not in your path. See https://github.com/TimDettmers/bitsandbytes/issues/85 for more information.') + version = int(version.value) + major = version//1000 + minor = (version-(major*1000))//10 + + if major < 11: + CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!') + + return f'{major}{minor}' + + +def get_cuda_lib_handle(): + # 1. find libcuda.so library (GPU driver) (/usr/lib) + try: + cuda = ct.CDLL("libcuda.so") + except OSError: + CUDASetup.get_instance().add_log_entry('CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!') + return None + check_cuda_result(cuda, cuda.cuInit(0)) + + return cuda + + +def get_compute_capabilities(cuda): + """ + 1. find libcuda.so library (GPU driver) (/usr/lib) + init_device -> init variables -> call function by reference + 2. call extern C function to determine CC + (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html) + 3. Check for CUDA errors + https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api + # bits taken from https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549 + """ + + nGpus = ct.c_int() + cc_major = ct.c_int() + cc_minor = ct.c_int() + + device = ct.c_int() + + check_cuda_result(cuda, cuda.cuDeviceGetCount(ct.byref(nGpus))) + ccs = [] + for i in range(nGpus.value): + check_cuda_result(cuda, cuda.cuDeviceGet(ct.byref(device), i)) + ref_major = ct.byref(cc_major) + ref_minor = ct.byref(cc_minor) + # 2. call extern C function to determine CC + check_cuda_result(cuda, cuda.cuDeviceComputeCapability(ref_major, ref_minor, device)) + ccs.append(f"{cc_major.value}.{cc_minor.value}") + + return ccs + + +# def get_compute_capability()-> Union[List[str, ...], None]: # FIXME: error +def get_compute_capability(cuda): + """ + Extracts the highest compute capbility from all available GPUs, as compute + capabilities are downwards compatible. If no GPUs are detected, it returns + None. + """ + if cuda is None: return None + + # TODO: handle different compute capabilities; for now, take the max + ccs = get_compute_capabilities(cuda) + if ccs: return ccs[-1] + + +def evaluate_cuda_setup(): + if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': + print('') + print('='*35 + 'BUG REPORT' + '='*35) + print(('Welcome to bitsandbytes. For bug reports, please run\n\npython -m bitsandbytes\n\n'), + ('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')) + print('='*80) + if not torch.cuda.is_available(): return 'libbitsandbytes_cpu.so', None, None, None, None + + cuda_setup = CUDASetup.get_instance() + cudart_path = determine_cuda_runtime_lib_path() + cuda = get_cuda_lib_handle() + cc = get_compute_capability(cuda) + cuda_version_string = get_cuda_version(cuda, cudart_path) + + failure = False + if cudart_path is None: + failure = True + cuda_setup.add_log_entry("WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!", is_warning=True) + else: + cuda_setup.add_log_entry(f"CUDA SETUP: CUDA runtime path found: {cudart_path}") + + if cc == '' or cc is None: + failure = True + cuda_setup.add_log_entry("WARNING: No GPU detected! Check your CUDA paths. Proceeding to load CPU-only library...", is_warning=True) + else: + cuda_setup.add_log_entry(f"CUDA SETUP: Highest compute capability among GPUs detected: {cc}") + + if cuda is None: + failure = True + else: + cuda_setup.add_log_entry(f'CUDA SETUP: Detected CUDA version {cuda_version_string}') + + # 7.5 is the minimum CC vor cublaslt + has_cublaslt = is_cublasLt_compatible(cc) + + # TODO: + # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) + # (2) Multiple CUDA versions installed + + # we use ls -l instead of nvcc to determine the cuda version + # since most installations will have the libcudart.so installed, but not the compiler + + if failure: + binary_name = "libbitsandbytes_cpu.so" + elif has_cublaslt: + binary_name = f"libbitsandbytes_cuda{cuda_version_string}.so" + else: + "if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt.so" + binary_name = f"libbitsandbytes_cuda{cuda_version_string}_nocublaslt.so" + + return binary_name, cudart_path, cuda, cc, cuda_version_string diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py new file mode 100644 index 000000000..cc82943b8 --- /dev/null +++ b/bitsandbytes/functional.py @@ -0,0 +1,2464 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import ctypes as ct +import itertools +import operator +import random +import torch +import itertools +import math +from scipy.stats import norm +import numpy as np + +from functools import reduce # Required in Python 3 +from typing import Tuple +from torch import Tensor + +from .cextension import COMPILED_WITH_CUDA, lib + + +# math.prod not compatible with python < 3.8 +def prod(iterable): + return reduce(operator.mul, iterable, 1) + +name2qmap = {} + +if COMPILED_WITH_CUDA: + """C FUNCTIONS FOR OPTIMIZERS""" + str2optimizer32bit = {} + str2optimizer32bit["adam"] = (lib.cadam32bit_grad_fp32, lib.cadam32bit_grad_fp16, lib.cadam32bit_grad_bf16) + str2optimizer32bit["momentum"] = ( + lib.cmomentum32bit_grad_32, + lib.cmomentum32bit_grad_16, + ) + str2optimizer32bit["rmsprop"] = ( + lib.crmsprop32bit_grad_32, + lib.crmsprop32bit_grad_16, + ) + str2optimizer32bit["lion"] = ( + lib.clion32bit_grad_32, + lib.clion32bit_grad_16, + ) + str2optimizer32bit["adagrad"] = ( + lib.cadagrad32bit_grad_32, + lib.cadagrad32bit_grad_16, + ) + + str2optimizer8bit = {} + str2optimizer8bit["adam"] = ( + lib.cadam_static_8bit_grad_32, + lib.cadam_static_8bit_grad_16, + ) + str2optimizer8bit["momentum"] = ( + lib.cmomentum_static_8bit_grad_32, + lib.cmomentum_static_8bit_grad_16, + ) + str2optimizer8bit["rmsprop"] = ( + lib.crmsprop_static_8bit_grad_32, + lib.crmsprop_static_8bit_grad_16, + ) + str2optimizer8bit["lion"] = ( + lib.clion_static_8bit_grad_32, + lib.clion_static_8bit_grad_16, + ) + str2optimizer8bit["lamb"] = ( + lib.cadam_static_8bit_grad_32, + lib.cadam_static_8bit_grad_16, + ) + str2optimizer8bit["lars"] = ( + lib.cmomentum_static_8bit_grad_32, + lib.cmomentum_static_8bit_grad_16, + ) + + str2optimizer8bit_blockwise = {} + str2optimizer8bit_blockwise["adam"] = ( + lib.cadam_8bit_blockwise_grad_fp32, + lib.cadam_8bit_blockwise_grad_fp16, + lib.cadam_8bit_blockwise_grad_bf16, + ) + str2optimizer8bit_blockwise["momentum"] = ( + lib.cmomentum_8bit_blockwise_grad_fp32, + lib.cmomentum_8bit_blockwise_grad_fp16, + ) + str2optimizer8bit_blockwise["rmsprop"] = ( + lib.crmsprop_8bit_blockwise_grad_fp32, + lib.crmsprop_8bit_blockwise_grad_fp16, + ) + str2optimizer8bit_blockwise["lion"] = ( + lib.clion_8bit_blockwise_grad_fp32, + lib.clion_8bit_blockwise_grad_fp16, + ) + str2optimizer8bit_blockwise["adagrad"] = ( + lib.cadagrad_8bit_blockwise_grad_fp32, + lib.cadagrad_8bit_blockwise_grad_fp16, + ) + +class GlobalPageManager: + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def initialize(self): + self.paged_tensors = [] + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + def prefetch_all(self, to_cpu=False): + # assume the first added, will be hte + # ones that are used first, so swap them in last + # in the case they are evicted again + for t in self.paged_tensors[::-1]: + prefetch_tensor(t, to_cpu) + + + +class CUBLAS_Context: + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def initialize(self): + self.context = {} + # prev_device = torch.cuda.current_device() + # for i in range(torch.cuda.device_count()): + # torch.cuda.set_device(torch.device('cuda', i)) + # self.context.append(ct.c_void_p(lib.get_context())) + # torch.cuda.set_device(prev_device) + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + def get_context(self, device): + if device.index not in self.context: + prev_device = torch.cuda.current_device() + torch.cuda.set_device(device) + self.context[device.index] = ct.c_void_p(lib.get_context()) + torch.cuda.set_device(prev_device) + return self.context[device.index] + + +class Cusparse_Context: + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def initialize(self): + self.context = ct.c_void_p(lib.get_cusparse()) + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + +dtype2bytes = {} +dtype2bytes[torch.float32] = 4 +dtype2bytes[torch.float16] = 2 +dtype2bytes[torch.bfloat16] = 2 +dtype2bytes[torch.uint8] = 1 +dtype2bytes[torch.int8] = 1 + +def get_paged(*shape, dtype=torch.float32, device=torch.device('cuda', index=0)): + num_bytes = dtype2bytes[dtype]*prod(shape) + cuda_ptr = lib.cget_managed_ptr(ct.c_size_t(num_bytes)) + c_ptr = ct.cast(cuda_ptr, ct.POINTER(ct.c_int)) + new_array = np.ctypeslib.as_array(c_ptr, shape=shape) + out = torch.frombuffer(new_array, dtype=dtype, count=prod(shape)).view(shape) + out.is_paged = True + out.page_deviceid = device.index + return out + +def prefetch_tensor(A, to_cpu=False): + assert A.is_paged, 'Only paged tensors can be prefetched!' + if to_cpu: + deviceid = -1 + else: + deviceid = A.page_deviceid + + num_bytes = dtype2bytes[A.dtype]*A.numel() + lib.cprefetch(get_ptr(A), ct.c_size_t(num_bytes), ct.c_int32(deviceid)) + +def elementwise_func(func_name, A, B, value, prefetch=True): + func = None + if A.dtype == torch.float32: + func = getattr(lib, f'c{func_name}_fp32', None) + cvalue = ct.c_float(value) + elif A.dtype == torch.uint8: + func = getattr(lib, f'c{func_name}_uint8', None) + cvalue = ct.c_uint8(value) + + if func is None: raise NotImplementedError(f'Function not implemented: {func_name}') + + is_managed = getattr(A, 'is_managed', False) + if is_managed and prefetch: + prefetch_tensor(A) + if B is not None: prefetch_tensor(B) + + func(get_ptr(A), get_ptr(B), cvalue, ct.c_int64(A.numel())) + if A.is_paged or B.is_paged: + # paged function are fully asynchronous + # if we return from this function, we want to the tensor + # to be in the correct state, that is the final state after the + # operation occured. So we synchronize. + torch.cuda.synchronize() + +def fill(A, value, device=None, prefetch=True): elementwise_func('fill', A, None, value) +def arange(A, device=None): elementwise_func('arange', A, None, 0) +def _mul(A, B, device=None): elementwise_func('_mul', A, B, 0) + + +def create_linear_map(signed=True, total_bits=8, add_zero=True): + sign = (-1.0 if signed else 0.0) + total_values = 2**total_bits + if add_zero or total_bits < 8: + # add a zero + # since we simulate less bits by having zeros in the data type, we + # we need to center the quantization around zero and as such lose + # a single value + total_values = (2**total_bits if not signed else 2**total_bits-1) + + values = torch.linspace(sign, 1.0, total_values) + gap = 256 - values.numel() + if gap == 0: + return values + else: + l = values.numel()//2 + #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist()) + return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist()) + +def create_custom_map(seed=0, scale=0.01): + v = [12, 10, 8, 6, 3, 2, 1] + # 16-bit 7B 22.33, 4-bit best 22.88, FP4 23.25, 4-bit 95 22.97, 4-bit evo 22.45 + # 16-bit 13B 70.35, 4-bit best 67.16, FP4 100.78, 4-bit-95 69.39, 4-bit evo 70.48 + + # 13B 100 steps: + # - 4-bit evo: 86.02 + # - 4-bit norm: 78.73 + # - 4-bit FP4: + # - 16-bit: + + # interval search on normal distribution + #v = [3.090232306167813, 1.4589770349449647, 1.064410327932115, 0.7896806653244509, 0.5646884166925807, 0.3653406435875121, 0.17964844284441311] # 0.999 26.5 + #v = [2.3263478740408408, 1.4050715603096329, 1.0364333894937898, 0.7721932141886848, 0.5533847195556727, 0.3584587932511938, 0.1763741647808615] # 0.99 24.99 + #v = [1.6448536269514722, 1.2040469600267016, 0.9208229763683788, 0.6971414348463417, 0.5039653672113453, 0.3280721075316511, 0.16184416680396213] # 0.95 24.53 22.97 + #v = [1.4050715603096329, 1.0803193408149558, 0.8416212335729143, 0.643345405392917, 0.4676987991145084, 0.3054807880993974, 0.1509692154967774] # 0.92 24.81 + #v = [1.2815515655446004, 1.0062699858608395, 0.7916386077433746, 0.6084981344998837, 0.4438613119262478, 0.29050677112339396, 0.14372923370582416] # 0.9 24.68 + #v = [1.8807936081512509, 1.2980047163986055, 0.9769954022693226, 0.7341502955472268, 0.5285136765472481, 0.343225833559403, 0.16910470304375366] # 0.97 25.03 + #v = [1.7506860712521692, 1.2496468758017434, 0.9485350408266378, 0.7155233557034365, 0.5162006366043174, 0.3356393360829622, 0.16547334454641704] # 0.96 24.85 23.01 + #v = [1.5547735945968535, 1.1608220210715001, 0.893800631179489, 0.6789921163940618, 0.4918050830048072, 0.3205236191093902, 0.15821711945563585] # 0.94 24.47 + #v = [1.475791028179171, 1.1196635980209986, 0.8674156943957149, 0.6610637542614526, 0.4797170937629045, 0.31299335020578195, 0.15459215234139795] # 0.93 24.85 + #v = [1.5981931399228175, 1.1821583959486879, 0.9072289939325966, 0.6880384454306778, 0.49787602226482025, 0.3242955535308664, 0.160030379970179] # 0.945 24.287 + ##v = [1.6164363711150211, 1.1908453913294612, 0.9126463450304729, 0.6916727602238111, 0.5003095327012462, 0.3258056171348078, 0.1607558311941979] # 0.947 24.293 + #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.207 + #v = [1.6118251211466303, 1.188665228776879, 0.9112895004060624, 0.690763326564427, 0.4997008778346997, 0.3254280317127771, 0.16057446047146948] # 0.9465 24.30 + #v = [1.6027040905517569, 1.184321770169049, 0.9085808314549837, 0.6889461706317986, 0.4984841229538408, 0.32467299997597887, 0.1602117348657326] # 0.9455 24.293 + #v = [1.6072478919002173, 1.1864907014855421, 0.9099343314196248, 0.6898544638558411, 0.4990924080314459, 0.32505049268156666, 0.16039309503073892] # 0.946 24.37 22.88 + + # 7B evo start + #v = [1.62129629, 1.18870191, 0.90848106, 0.69108646, 0.50515268, 0.34927819905, 0.14122701] # 22.06 + #v = [1.6143079205628337, 1.1888081407660314, 0.8990131955745421, 0.694373759813679, 0.5083033257326773, 0.3452499746844963, 0.1148939728228951] + #v = [1.614442766030303, 1.189401918639665, 0.8998038168964273, 0.6953094818279475, 0.5073264599048384, 0.3449003790823619, 0.11428378427205564] + + # 13B evo start + #v = [1.6077535089716468, 1.1914902148179205, 0.8999752421085561, 0.6967904489387543, 0.4949093928311768, 0.30920472033044544, 0.15391602735952042] + #v = [1.586363722436466, 1.202610827188916, 0.9003332576346587, 0.6904888715206972, 0.49490974688233724, 0.2971151461329376, 0.15683230810738283] + v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908] + + # mean evo 7B + 13B + #v = [1.5993337549066253, 1.1965624035328402, 0.9000864380418481, 0.6925840978034195, 0.5011181210961458, 0.32040328389777434, 0.13570386022711237] + + # theoretically optiomal (0.93333) + #v = [1.501085946044025, 1.1331700302595604, 0.8761428492468408, 0.6670160135425023, 0.48373855304610314, 0.3155014472579608, 0.15580024666388428] # 0.9333333333333333 + + if seed > 0: + v = np.array(v) + np.random.seed(seed) + v += np.random.randn(7)*scale + print(v.tolist()) + #v[0] += (np.random.randn(1)*0.001)[0] + #v[-1] += (np.random.randn(1)*0.001)[0] + #print(v[0], v[-1]) + v = v.tolist() + values = v + [0]*(256-14) + \ + v[::-1] + + values = torch.Tensor(values) + values[0:7] *= -1 + values = values.sort().values + values /= values.max() + assert values.numel() == 256 + return values + +def create_normal_map(offset=0.9677083, use_extra_value=True): + + if use_extra_value: + # one more positive value, this is an asymmetric type + v1 = norm.ppf(torch.linspace(offset, 0.5, 9)[:-1]).tolist() + v2 = [0]*(256-15) ## we have 15 non-zero values in this data type + v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() + v = v1 + v2 + v3 + else: + v1 = norm.ppf(torch.linspace(offset, 0.5, 8)[:-1]).tolist() + v2 = [0]*(256-14) ## we have 14 non-zero values in this data type + v3 = (-norm.ppf(torch.linspace(offset, 0.5, 8)[:-1])).tolist() + v = v1 + v2 + v3 + + values = torch.Tensor(v) + values = values.sort().values + values /= values.max() + assert values.numel() == 256 + return values + +def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8): + e = exponent_bits + p = precision_bits + has_sign = 1 if signed else 0 + assert e+p == total_bits-has_sign + # the exponent is biased to 2^(e-1) -1 == 0 + evalues = [] + pvalues = [] + for i, val in enumerate(range(-((2**(exponent_bits-has_sign))), 2**(exponent_bits-has_sign), 1)): + evalues.append(2**val) + + + values = [] + lst = list(itertools.product([0, 1], repeat=precision_bits)) + #for ev in evalues: + bias = 2**(exponent_bits-1) + for evalue in range(2**(exponent_bits)): + for bit_pattern in lst: + value = (1 if evalue != 0 else 0) + for i, pval in enumerate(list(bit_pattern)): + value += pval*(2**-(i+1)) + if evalue == 0: + # subnormals + value = value*2**-(bias) + else: + # normals + value = value*2**-(evalue-bias-1) + values.append(value) + if signed: + values.append(-value) + + + assert len(values) == 2**total_bits + values.sort() + if total_bits < 8: + gap = 256 - len(values) + for i in range(gap): + values.append(0) + values.sort() + code = torch.Tensor(values) + code /= code.max() + + return code + + + +def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8): + """ + Creates the dynamic quantiztion map. + + The dynamic data type is made up of a dynamic exponent and + fraction. As the exponent increase from 0 to -7 the number + of bits available for the fraction shrinks. + + This is a generalization of the dynamic type where a certain + number of the bits and be reserved for the linear quantization + region (the fraction). n determines the maximum number of + exponent bits. + + For more details see + (8-Bit Approximations for Parallelism in Deep Learning)[https://arxiv.org/abs/1511.04561] + """ + + data = [] + # these are additional items that come from the case + # where all the exponent bits are zero and no + # indicator bit is present + non_sign_bits = total_bits - (1 if signed else 0) + additional_items = 2 ** (non_sign_bits - max_exponent_bits) - 1 + if not signed: + additional_items = 2 * additional_items + for i in range(max_exponent_bits): + fraction_items = int((2 ** (i + non_sign_bits - max_exponent_bits) + 1 if signed else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1)) + boundaries = torch.linspace(0.1, 1, fraction_items) + means = (boundaries[:-1] + boundaries[1:]) / 2.0 + data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() + if signed: + data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() + + if additional_items > 0: + boundaries = torch.linspace(0.1, 1, additional_items + 1) + means = (boundaries[:-1] + boundaries[1:]) / 2.0 + data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() + if signed: + data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist() + + data.append(0) + data.append(1.0) + + gap = 256 - len(data) + for i in range(gap): + data.append(0) + + data.sort() + return Tensor(data) + +def create_quantile_map(A, total_bits=8): + q = estimate_quantiles(A, num_quantiles=2**total_bits-1) + q = q.tolist() + q.append(0) + + gap = 256 - len(q) + for i in range(gap): + q.append(0) + + q.sort() + + q = Tensor(q) + q = q/q.abs().max() + return q + +def get_special_format_str(): + if not torch.cuda.is_available(): return 'col_turing' + major, _minor = torch.cuda.get_device_capability() + if major <= 7: + return "col_turing" + if major == 8: + return "col_ampere" + return "col_turing" + + + +def is_on_gpu(tensors): + on_gpu = True + gpu_ids = set() + for t in tensors: + if t is None: continue # NULL pointers are fine + is_paged = getattr(t, 'is_paged', False) + on_gpu &= (t.device.type == 'cuda' or is_paged) + if not is_paged: + gpu_ids.add(t.device.index) + if not on_gpu: + raise TypeError(f'All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}') + if len(gpu_ids) > 1: + raise TypeError(f'Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}') + return on_gpu + +def get_ptr(A: Tensor) -> ct.c_void_p: + """ + Get the ctypes pointer from a PyTorch Tensor. + + Parameters + ---------- + A : torch.tensor + The PyTorch tensor. + + Returns + ------- + ctypes.c_void_p + """ + if A is None: + return None + else: + return ct.c_void_p(A.data.data_ptr()) + + +def pre_call(device): + prev_device = torch.cuda.current_device() + torch.cuda.set_device(device) + return prev_device + + +def post_call(prev_device): + torch.cuda.set_device(prev_device) + + +def get_transform_func(dtype, orderA, orderOut, transpose=False): + name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}' + if not hasattr(lib, name): + print(name) + raise ValueError( + f"Transform function not supported: {orderA} to {orderOut} for data type {dtype} and transpose={transpose}" + ) + else: + return getattr(lib, name) + + +def get_transform_buffer( + shape, dtype, device, to_order, from_order="row", transpose=False +): + # init_func = torch.empty + init_func = torch.zeros + dims = len(shape) + + if dims == 2: + rows = shape[0] + elif dims == 3: + rows = shape[0] * shape[1] + cols = shape[-1] + + state = (shape, to_order) + if transpose: + # swap dims + tmp = rows + rows = cols + cols = tmp + state = (shape[::-1], to_order) + + if to_order == "row" or to_order == "col": + return init_func(shape, dtype=dtype, device=device), state + elif to_order == "col32": + # blocks of 32 columns (padded) + cols = 32 * ((cols + 31) // 32) + return init_func((rows, cols), dtype=dtype, device=device), state + elif to_order == "col_turing": + # blocks of 32 columns and 8 rows + cols = 32 * ((cols + 31) // 32) + rows = 8 * ((rows + 7) // 8) + return init_func((rows, cols), dtype=dtype, device=device), state + elif to_order == "col_ampere": + # blocks of 32 columns and 32 rows + cols = 32 * ((cols + 31) // 32) + rows = 32 * ((rows + 31) // 32) + return init_func((rows, cols), dtype=dtype, device=device), state + else: + raise NotImplementedError(f"To_order not supported: {to_order}") + + +def nvidia_transform( + A, + to_order, + from_order="row", + out=None, + transpose=False, + state=None, + ld=None, +): + if state is None: + state = (A.shape, from_order) + else: + from_order = state[1] + if out is None: + out, new_state = get_transform_buffer( + state[0], A.dtype, A.device, to_order, state[1] + ) + else: + new_state = (state[1], to_order) + func = get_transform_func(A.dtype, from_order, to_order, transpose) + + shape = state[0] + if len(shape) == 2: + dim1 = ct.c_int32(shape[0]) + dim2 = ct.c_int32(shape[1]) + elif ld is not None: + n = prod(shape) + dim1 = prod([shape[i] for i in ld]) + dim2 = ct.c_int32(n // dim1) + dim1 = ct.c_int32(dim1) + else: + dim1 = ct.c_int32(shape[0] * shape[1]) + dim2 = ct.c_int32(shape[2]) + + ptr = CUBLAS_Context.get_instance().get_context(A.device) + func(ptr, get_ptr(A), get_ptr(out), dim1, dim2) + + return out, new_state + + +def estimate_quantiles(A: Tensor, out: Tensor = None, offset: float = 1 / 512, num_quantiles=256) -> Tensor: + ''' + Estimates 256 equidistant quantiles on the input tensor eCDF. + + Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles + via the eCDF of the input tensor `A`. This is a fast but approximate algorithm + and the extreme quantiles close to 0 and 1 have high variance / large estimation + errors. These large errors can be avoided by using the offset variable which trims + the distribution. The default offset value of 1/512 ensures minimum entropy encoding -- it + trims 1/512 = 0.2% from each side of the distrivution. An offset value of 0.01 to 0.02 + usually has a much lower error but is not a minimum entropy encoding. Given an offset + of 0.02 equidistance points in the range [0.02, 0.98] are used for the quantiles. + + Parameters + ---------- + A : torch.Tensor + The input tensor. Any shape. + out : torch.Tensor + Tensor with the 256 estimated quantiles. + offset : float + The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles) + num_quantiles : int + The number of equally spaced quantiles. + + Returns + ------- + torch.Tensor: + The 256 quantiles in float32 datatype. + ''' + if A.numel() < 256: raise NotImplementedError(f'Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.') + if num_quantiles > 256: raise NotImplementedError(f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}") + if num_quantiles < 256 and offset == 1/(512): + # override default arguments + offset = 1/(2*num_quantiles) + + if out is None: out = torch.zeros((256,), dtype=torch.float32, device=A.device) + is_on_gpu([A, out]) + device = pre_call(A.device) + if A.dtype == torch.float32: + lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) + elif A.dtype == torch.float16: + lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel())) + else: + raise NotImplementedError(f"Not supported data type {A.dtype}") + post_call(device) + + if num_quantiles < 256: + step = round(256/num_quantiles) + idx = torch.linspace(0, 255, num_quantiles).long().to(A.device) + out = out[idx] + + return out + + +def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, out: Tensor = None, blocksize=4096, nested=False) -> Tensor: + """ + Quantize tensor A in blocks of size 4096 values. + + Quantizes tensor A by dividing it into blocks of 4096 values. + Then the absolute maximum value within these blocks is calculated + for the non-linear quantization. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + code : torch.Tensor + The quantization map. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + The output tensor (8-bit). + + Returns + ------- + torch.Tensor: + The 8-bit tensor. + tuple(torch.Tensor, torch.Tensor): + The quantization state to undo the quantization. + """ + + + if code is None: + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] + + if absmax is None: + n = A.numel() + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + absmax = torch.zeros((blocks,), device=A.device) + + if out is None: + out = torch.zeros_like(A, dtype=torch.uint8) + + if A.device.type != 'cpu': + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + cblocksize = ct.c_int32(blocksize) + prev_device = pre_call(A.device) + code = code.to(A.device) + is_on_gpu([code, A, out, absmax]) + if A.dtype == torch.float32: + lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + elif A.dtype == torch.float16: + lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel())) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) + else: + # cpu + code = code.cpu() + lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) + + if nested: + offset = absmax.mean() + absmax -= offset + qabsmax, state2 = quantize_blockwise(absmax, blocksize=blocksize, nested=False) + state = [qabsmax, code, blocksize, nested, offset, state2] + else: + state = [absmax, code, blocksize, nested, None, None] + + + + return out, state + + +def dequantize_blockwise( + A: Tensor, + quant_state: Tuple[Tensor, Tensor] = None, + absmax: Tensor = None, + code: Tensor = None, + out: Tensor = None, + blocksize: int = 4096, + nested=False +) -> Tensor: + """ + Dequantizes blockwise quantized values. + + Dequantizes the tensor A with maximum absolute values absmax in + blocks of size 4096. + + Parameters + ---------- + A : torch.Tensor + The input 8-bit tensor. + quant_state : tuple(torch.Tensor, torch.Tensor) + Tuple of code and absmax values. + absmax : torch.Tensor + The absmax values. + code : torch.Tensor + The quantization map. + out : torch.Tensor + Dequantized output tensor (default: float32) + + + Returns + ------- + torch.Tensor: + Dequantized tensor (default: float32) + """ + assert quant_state is not None or absmax is not None + if code is None and quant_state is None: + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] + + if out is None: + out = torch.zeros_like(A, dtype=torch.float32) + + if quant_state is None: + quant_state = (absmax, code, blocksize) + assert absmax is not None and out is not None + else: + absmax, code, blocksize, nested, offset, state2 = quant_state + if nested: + absmax = dequantize_blockwise(absmax, state2) + absmax += offset + + + if A.device.type != 'cpu': + device = pre_call(A.device) + code = code.to(A.device) + if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: + raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") + is_on_gpu([A, absmax, out]) + if out.dtype == torch.float32: + lib.cdequantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) + elif out.dtype == torch.float16: + lib.cdequantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(A.numel())) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) + else: + code = code.cpu() + lib.cdequantize_blockwise_cpu_fp32(get_ptr(quant_state[1]), get_ptr(A), get_ptr(quant_state[0]), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel())) + + return out + +def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4') + +def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4') + +def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: + """ + Quantize tensor A in blocks of 4-bit values. + + Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + The output tensor (8-bit). + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4} + + Returns + ------- + torch.Tensor: + The 8-bit tensor with packed 4-bit values. + tuple(torch.Tensor, torch.Size, torch.dtype, int): + The quantization state to undo the quantization. + """ + if A.device.type != 'cuda': + raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') + if quant_type not in ['fp4', 'nf4']: + raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + + n = A.numel() + input_shape = A.shape + + if absmax is None: + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + absmax = torch.zeros((blocks,), device=A.device) + + + if out is None: + out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device) + + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + + prev_device = pre_call(A.device) + is_on_gpu([A, out, absmax]) + + if A.dtype == torch.float32: + if quant_type == 'fp4': + lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: + lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + elif A.dtype == torch.float16: + if quant_type == 'fp4': + lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: + lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) + + if compress_statistics: + offset = absmax.mean() + absmax -= offset + #code = create_custom_map().to(absmax.device) + #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256) + qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) + del absmax + state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] + else: + state = [absmax, input_shape, A.dtype, blocksize, None, quant_type] + + return out, state + +def dequantize_fp4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4') + +def dequantize_nf4(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') + +def dequantize_4bit(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor: + """ + Dequantizes FP4 blockwise quantized values. + + Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. + + Parameters + ---------- + A : torch.Tensor + The input 8-bit tensor (packed 4-bit values). + quant_state : tuple(torch.Tensor, torch.Size, torch.dtype) + Tuple of absmax values, original tensor shape and original dtype. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + Dequantized output tensor. + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4} + + + Returns + ------- + torch.Tensor: + Dequantized tensor. + """ + if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: + raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") + if quant_type not in ['fp4', 'nf4']: + raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + + if quant_state is None: + assert absmax is not None and out is not None + shape = out.shape + dtype = out.dtype + else: + absmax, shape, dtype, blocksize, compressed_stats, quant_type = quant_state + + + if compressed_stats is not None: + offset, state2 = compressed_stats + absmax = dequantize_blockwise(absmax, state2) + absmax += offset + + if out is None: + out = torch.empty(shape, dtype=dtype, device=A.device) + + n = out.numel() + + + device = pre_call(A.device) + is_on_gpu([A, absmax, out]) + if out.dtype == torch.float32: + if quant_type == 'fp4': + lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + else: + lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + elif out.dtype == torch.float16: + if quant_type == 'fp4': + lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + else: + lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(blocksize), ct.c_int(n)) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) + + is_transposed = (True if A.shape[0] == 1 else False) + if is_transposed: return out.t() + else: return out + + +def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: + if code is None: + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] + code = code.to(A.device) + + absmax = torch.abs(A).max() + inp = A / absmax + out = quantize_no_absmax(inp, code, out) + return out, (absmax, code) + + +def dequantize( + A: Tensor, + quant_state: Tuple[Tensor, Tensor] = None, + absmax: Tensor = None, + code: Tensor = None, + out: Tensor = None, +) -> Tensor: + assert quant_state is not None or absmax is not None + if code is None and quant_state is None: + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] + code = code.to(A.device) + + if quant_state is None: + quant_state = (absmax, code) + out = dequantize_no_absmax(A, quant_state[1], out) + return out * quant_state[0] + + +def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: + ''' + Quantizes input tensor to 8-bit. + + Quantizes the 32-bit input tensor `A` to the 8-bit output tensor + `out` using the quantization map `code`. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + code : torch.Tensor + The quantization map. + out : torch.Tensor, optional + The output tensor. Needs to be of type byte. + + Returns + ------- + torch.Tensor: + Quantized 8-bit tensor. + ''' + prev_device = pre_call(A.device) + if out is None: out = torch.zeros_like(A, dtype=torch.uint8) + is_on_gpu([A, out]) + lib.cquantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) + post_call(prev_device) + return out + + +def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: + ''' + Dequantizes the 8-bit tensor to 32-bit. + + Dequantizes the 8-bit tensor `A` to the 32-bit tensor `out` via + the quantization map `code`. + + Parameters + ---------- + A : torch.Tensor + The 8-bit input tensor. + code : torch.Tensor + The quantization map. + out : torch.Tensor + The 32-bit output tensor. + + Returns + ------- + torch.Tensor: + 32-bit output tensor. + ''' + prev_device = pre_call(A.device) + if out is None: out = torch.zeros_like(A, dtype=torch.float32) + is_on_gpu([code, A, out]) + lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) + post_call(prev_device) + return out + + +def optimizer_update_32bit( + optimizer_name: str, + g: Tensor, + p: Tensor, + state1: Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Tensor = None, + beta2: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Tensor = None, + max_unorm: float = 0.0, + skip_zeros=False, +) -> None: + """ + Performs an inplace optimizer update with one or two optimizer states. + + Universal optimizer update for 32-bit state and 32/16-bit gradients/weights. + + Parameters + ---------- + optimizer_name : str + The name of the optimizer: {adam}. + g : torch.Tensor + Gradient tensor. + p : torch.Tensor + Parameter tensor. + state1 : torch.Tensor + Optimizer state 1. + beta1 : float + Optimizer beta1. + eps : float + Optimizer epsilon. + weight_decay : float + Weight decay. + step : int + Current optimizer step. + lr : float + The learning rate. + state2 : torch.Tensor + Optimizer state 2. + beta2 : float + Optimizer beta2. + gnorm_scale : float + The factor to rescale the gradient to the max clip value. + unorm_vec : torch.Tensor + The tensor for the update norm. + max_unorm : float + The maximum update norm relative to the weight norm. + skip_zeros : bool + Whether to skip zero-valued gradients or not (default: False). + """ + + param_norm = 0.0 + if max_unorm > 0.0: + param_norm = torch.norm(p.data.float()) + + + optim_func = None + if g.dtype == torch.float32: + optim_func = str2optimizer32bit[optimizer_name][0] + elif g.dtype == torch.float16: + optim_func = str2optimizer32bit[optimizer_name][1] + elif (g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name])==3): + optim_func = str2optimizer32bit[optimizer_name][2] + else: + raise ValueError(f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}") + + is_on_gpu([g, p, state1, state2, unorm_vec]) + prev_device = pre_call(g.device) + optim_func( + get_ptr(g), + get_ptr(p), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_float(weight_decay), + ct.c_int32(step), + ct.c_float(lr), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel())) + post_call(prev_device) + + +def optimizer_update_8bit( + optimizer_name: str, + g: Tensor, + p: Tensor, + state1: Tensor, + state2: Tensor, + beta1: float, + beta2: float, + eps: float, + step: int, + lr: float, + qmap1: Tensor, + qmap2: Tensor, + max1: Tensor, + max2: Tensor, + new_max1: Tensor, + new_max2: Tensor, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Tensor = None, + max_unorm: float = 0.0, +) -> None: + """ + Performs an inplace Adam update. + + Universal Adam update for 32/8-bit state and 32/16-bit gradients/weights. + Uses AdamW formulation if weight decay > 0.0. + + Parameters + ---------- + optimizer_name : str + The name of the optimizer. Choices {adam, momentum} + g : torch.Tensor + Gradient tensor. + p : torch.Tensor + Parameter tensor. + state1 : torch.Tensor + Adam state 1. + state2 : torch.Tensor + Adam state 2. + beta1 : float + Adam beta1. + beta2 : float + Adam beta2. + eps : float + Adam epsilon. + weight_decay : float + Weight decay. + step : int + Current optimizer step. + lr : float + The learning rate. + qmap1 : torch.Tensor + Quantization map for first Adam state. + qmap2 : torch.Tensor + Quantization map for second Adam state. + max1 : torch.Tensor + Max value for first Adam state update. + max2 : torch.Tensor + Max value for second Adam state update. + new_max1 : torch.Tensor + Max value for the next Adam update of the first state. + new_max2 : torch.Tensor + Max value for the next Adam update of the second state. + gnorm_scale : float + The factor to rescale the gradient to the max clip value. + unorm_vec : torch.Tensor + The tensor for the update norm. + max_unorm : float + The maximum update norm relative to the weight norm. + """ + + param_norm = 0.0 + if max_unorm > 0.0: + param_norm = torch.norm(p.data.float()) + + prev_device = pre_call(g.device) + is_on_gpu([g, p, state1, state2, unorm_vec, qmap1, qmap2, max1, max2, new_max1, new_max2]) + if g.dtype == torch.float32 and state1.dtype == torch.uint8: + str2optimizer8bit[optimizer_name][0]( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(max1), + get_ptr(max2), + get_ptr(new_max1), + get_ptr(new_max2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_int32(g.numel()), + ) + elif g.dtype == torch.float16 and state1.dtype == torch.uint8: + str2optimizer8bit[optimizer_name][1]( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + get_ptr(unorm_vec), + ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(max1), + get_ptr(max2), + get_ptr(new_max1), + get_ptr(new_max2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_int32(g.numel()), + ) + else: + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + ) + post_call(prev_device) + + +def optimizer_update_8bit_blockwise( + optimizer_name: str, + g: Tensor, + p: Tensor, + state1: Tensor, + state2: Tensor, + beta1: float, + beta2: float, + eps: float, + step: int, + lr: float, + qmap1: Tensor, + qmap2: Tensor, + absmax1: Tensor, + absmax2: Tensor, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: + + optim_func = None + prev_device = pre_call(g.device) + is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) + if g.dtype == torch.float32 and state1.dtype == torch.uint8: + optim_func = str2optimizer8bit_blockwise[optimizer_name][0] + elif g.dtype == torch.float16 and state1.dtype == torch.uint8: + optim_func = str2optimizer8bit_blockwise[optimizer_name][1] + elif (g.dtype == torch.bfloat16 and state1.dtype == torch.uint8 and + len(str2optimizer8bit_blockwise[optimizer_name])==3): + optim_func = str2optimizer8bit_blockwise[optimizer_name][2] + else: + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}" + ) + post_call(prev_device) + + is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) + + prev_device = pre_call(g.device) + optim_func( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) + post_call(prev_device) + +def percentile_clipping( + grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5 +): + """Applies percentile clipping + + grad: torch.Tensor + The gradient tensor. + gnorm_vec: torch.Tensor + Vector of gradient norms. 100 elements expected. + step: int + The current optimiation steps (number of past gradient norms). + + """ + prev_device = pre_call(grad.device) + is_on_gpu([grad, gnorm_vec]) + if grad.dtype == torch.float32: + lib.cpercentile_clipping_g32( + get_ptr(grad), + get_ptr(gnorm_vec), + ct.c_int32(step), + ct.c_int32(grad.numel()), + ) + elif grad.dtype == torch.float16: + lib.cpercentile_clipping_g16( + get_ptr(grad), + get_ptr(gnorm_vec), + ct.c_int32(step), + ct.c_int32(grad.numel()), + ) + else: + raise ValueError(f"Gradient type {grad.dtype} not supported!") + post_call(prev_device) + + current_gnorm = torch.sqrt(gnorm_vec[step % 100]) + vals, idx = torch.sort(gnorm_vec) + clip_value = torch.sqrt(vals[percentile]) + gnorm_scale = 1.0 + + if current_gnorm > clip_value: + gnorm_scale = clip_value / current_gnorm + + return current_gnorm, clip_value, gnorm_scale + + +def histogram_scatter_add_2d( + histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor +): + assert len(histogram.shape) == 2 + assert histogram.dtype == torch.float32 + assert source.dtype == torch.float32 + assert index1.dtype == torch.int32 + assert index2.dtype == torch.int32 + + assert histogram.device.type == "cuda" + assert index1.device.type == "cuda" + assert index2.device.type == "cuda" + assert source.device.type == "cuda" + + maxdim1 = ct.c_int32(histogram.shape[0]) + n = ct.c_int32(index1.numel()) + is_on_gpu([histogram, index1, index2, source]) + lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n) + +def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8): + if not torch.cuda.is_initialized(): torch.cuda.init() + if A.dtype != expected_type or B.dtype != expected_type: + raise TypeError( + f"Expected torch.int8 input tensors A and B, but got {A.dtype} and {B.dtype}" + ) + + sA = A.shape + sB = B.shape + tA = transposed_A + tB = transposed_B + + correct = True + + if len(sA) == 2 and len(sB) == 2: + if not tA and not tB and A.shape[1] != B.shape[0]: + correct = False + elif tA and not tB and A.shape[0] != B.shape[0]: + correct = False + elif tA and tB and A.shape[0] != B.shape[1]: + correct = False + elif not tA and tB and A.shape[1] != B.shape[1]: + correct = False + elif len(sA) == 3 and len(sB) == 2: + if not tA and not tB and A.shape[2] != B.shape[0]: + correct = False + elif tA and not tB and A.shape[1] != B.shape[0]: + correct = False + elif tA and tB and A.shape[1] != B.shape[1]: + correct = False + elif not tA and tB and A.shape[2] != B.shape[1]: + correct = False + elif len(sA) == 3 and len(sB) == 3: + if not tA and not tB and A.shape[2] != B.shape[1]: + correct = False + elif tA and not tB and A.shape[1] != B.shape[1]: + correct = False + elif tA and tB and A.shape[1] != B.shape[2]: + correct = False + elif not tA and tB and A.shape[2] != B.shape[2]: + correct = False + + if out is not None: + sout = out.shape + # special case common in backprop + if not correct and len(sA) == 3 and len(sB) == 3: + if ( + sout[0] == sA[2] + and sout[1] == sB[2] + and sA[0] == sB[0] + and sA[1] == sB[1] + ): + correct = True + else: + if len(sA) == 2 and len(sB) == 2: + if not tA and not tB: + sout = (sA[0], sB[1]) + elif tA and tB: + sout = (sA[1], sB[0]) + elif tA and not tB: + sout = (sA[1], sB[1]) + elif not tA and tB: + sout = (sA[0], sB[0]) + elif len(sA) == 3 and len(sB) == 2: + if not tA and not tB: + sout = (sA[0], sA[1], sB[1]) + elif tA and tB: + sout = (sA[0], sA[2], sB[0]) + elif tA and not tB: + sout = (sA[0], sA[2], sB[1]) + elif not tA and tB: + sout = (sA[0], sA[1], sB[0]) + elif len(sA) == 3 and len(sB) == 3: + if not tA and not tB: + sout = (sA[0], sA[1], sB[2]) + elif tA and tB: + sout = (sA[0], sA[2], sB[1]) + elif tA and not tB: + sout = (sA[0], sA[2], sB[2]) + elif not tA and tB: + sout = (sA[0], sA[1], sB[1]) + + if not correct: + raise ValueError( + f"Tensor dimensions incorrect for matrix mulitiplication: A x B: {sA} x {sB} with transpose for A x B: {tA} x {tB}." + ) + + return sout + +def cutlass3_gemm( + A: Tensor, + B: Tensor, + out: Tensor = None, + transposed_A=False, + transposed_B=False, + state=None +): + #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) + if state is None: + Bshape = B.shape + bout = Bshape[1] + else: + Bshape = state[1] + bout = Bshape[0] + if out is None: + out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device) + + sA = A.shape + sB = B.shape + if transposed_A and len(sA) == 2: + sA = (sA[1], sA[0]) + elif transposed_A and len(sA) == 3: + sA = (sA[0], sA[2], sA[0]) + if transposed_B and len(sB) == 2: + sB = (sB[1], sB[0]) + elif transposed_B and len(sB) == 3: + sB = (sB[0], sB[2], sB[0]) + # this is a mess: cuBLAS expect column major, but PyTorch is row major. + # So to perform the matrix multiplication, we have to treat A, B, and C matrices + # (transpose of row major is column major) + # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these + + # matrices in the input arguments for cuBLAS + # column major: A @ B = C: [m, k] @ [k, n] = [m, n] + # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n] + # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m] + if len(sB) == 2: + if B.stride()[0] == B.shape[1]: + transposed_B = False + elif B.stride()[1] == B.shape[0]: + transposed_B = True + if len(A.shape) == 2: + if A.stride()[0] == A.shape[1]: + transposed_A = False + elif A.stride()[1] == A.shape[0]: + transposed_A = True + else: + if A.stride()[1] == A.shape[2]: + transposed_A = False + elif A.stride()[2] == A.shape[1]: + transposed_A = True + + if len(sA) == 2: + n = sA[0] + ldb = A.stride()[1 if transposed_A else 0] + elif len(sA) == 3 and len(sB) == 2: + n = sA[0] * sA[1] + ldb = sA[2] + + m = sB[1] + k = sB[0] + lda = B.stride()[0] + ldc = sB[1] + elif len(sB) == 3: + # special case + assert len(sA) == 3 + if not (sA[0] == sB[0] and sA[1] == sB[1]): + raise ValueError( + f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}" + ) + + transposed_A = True + transposed_B = False + + m = sB[2] + n = sA[2] + k = sB[0] * sB[1] + + lda = n + ldb = sA[2] + ldc = m + + ptr = CUBLAS_Context.get_instance().get_context(A.device) + + # B^T @ A^T = C^T + # [km, nk -> mn] + #lda = ldb = ldc = 1 + #lda = 1 + if state is not None: + m = Bshape[0] + k = Bshape[1] + lda = Bshape[0] + ldc = Bshape[0] + ldb = (ldb+1)//2 + #print(m, n, k, lda, ldb, ldc) + is_on_gpu([B, A, out]) + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + lda = ct.c_int32(lda) + ldb = ct.c_int32(ldb) + ldc = ct.c_int32(ldc) + + if B.dtype == torch.uint8: + lib.cgemm_4bit_inference(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3])) + elif A.dtype == torch.float32: + lib.cgemm_host_fp32(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) + elif A.dtype == torch.float16: + lib.cgemm_host_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(out), lda, ldb, ldc) + else: + raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}') + + return out + + + + +def igemm( + A: Tensor, + B: Tensor, + out: Tensor = None, + transposed_A=False, + transposed_B=False, +): + sout = check_matmul(A, B, out, transposed_A, transposed_B) + if out is None: + out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) + if len(A.shape) == 3 and len(B.shape) == 3: + if A.shape[0] == B.shape[0] and A.shape[2] == B.shape[1]: + return batched_igemm(A, B, out) + + sA = A.shape + sB = B.shape + if transposed_A and len(sA) == 2: + sA = (sA[1], sA[0]) + elif transposed_A and len(sA) == 3: + sA = (sA[0], sA[2], sA[0]) + if transposed_B and len(sB) == 2: + sB = (sB[1], sB[0]) + elif transposed_B and len(sB) == 3: + sB = (sB[0], sB[2], sB[0]) + # this is a mess: cuBLAS expect column major, but PyTorch is row major. + # So to perform the matrix multiplication, we have to treat A, B, and C matrices + # (transpose of row major is column major) + # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these + + # matrices in the input arguments for cuBLAS + # column major: A @ B = C: [m, k] @ [k, n] = [m, n] + # row major: B^T @ A^T = C^T: [m, k] @ [k, n] = [m, n] + # column major with row major layout: B^T @ A^T = C^T: [k, m] @ [n, k] = [n, m] + if len(sB) == 2: + if B.stride()[0] == B.shape[1]: + transposed_B = False + elif B.stride()[1] == B.shape[0]: + transposed_B = True + if len(A.shape) == 2: + if A.stride()[0] == A.shape[1]: + transposed_A = False + elif A.stride()[1] == A.shape[0]: + transposed_A = True + else: + if A.stride()[1] == A.shape[2]: + transposed_A = False + elif A.stride()[2] == A.shape[1]: + transposed_A = True + + if len(sA) == 2: + n = sA[0] + ldb = A.stride()[1 if transposed_A else 0] + elif len(sA) == 3 and len(sB) == 2: + n = sA[0] * sA[1] + ldb = sA[2] + + m = sB[1] + k = sB[0] + lda = B.stride()[(1 if transposed_B else 0)] + ldc = sB[1] + elif len(sB) == 3: + # special case + assert len(sA) == 3 + if not (sA[0] == sB[0] and sA[1] == sB[1]): + raise ValueError( + f"Only bsi,bso->io supported for tensor contractions, but dims for A x B were: {sA} x {sB}" + ) + + transposed_A = True + transposed_B = False + + m = sB[2] + n = sA[2] + k = sB[0] * sB[1] + + lda = m + ldb = sA[2] + ldc = m + + ptr = CUBLAS_Context.get_instance().get_context(A.device) + + # B^T @ A^T = C^T + # [km, nk -> mn] + is_on_gpu([B, A, out]) + lib.cigemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), + get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc)) + return out + + +def batched_igemm( + A: Tensor, + B: Tensor, + out: Tensor = None, + transposed_A=False, + transposed_B=False, +): + if not len(A.shape) == 3 or not len(B.shape) == 3: + raise ValueError( + f"Expected 3-dimensional tensors for bmm, but got shapes A and B: {A.shape} and {B.shape}" + ) + sout = check_matmul(A, B, out, transposed_A, transposed_B) + if out is None: + out = torch.zeros(size=sout, dtype=torch.int32, device=A.device) + + if B.is_contiguous(): + lda = B.stride()[1] + transposed_A = False + else: + s = B.stride() + if s[0] != B.shape[0]: + B = B.contiguous() + lda = B.stride()[1] + elif s[2] == B.shape[1]: + transposed_A = True + lda = B.stride()[2] + else: + if s[2] == 1: + B = B.contiguous() + lda = B.stride()[1] + elif s[1] == 1: + B = B.contiguous() + lda = B.stride()[1] + else: + B = B.contiguous() + lda = B.stride()[1] + + if A.is_contiguous(): + ldb = A.stride()[1] + transposed_B = False + else: + s = A.stride() + if s[0] != A.shape[0]: + A = A.contiguous() + ldb = A.stride()[1] + transposed_B = False + elif s[2] == A.shape[1]: + ldb = A.stride()[2] + transposed_B = True + else: + A = A.contiguous() + ldb = A.stride()[1] + transposed_B = False + + # this is a mess: cuBLAS expect column major, but PyTorch is row major. + # So to perform the matrix multiplication, we have to treat A, B, and C matrices + # (transpose of row major is column major) + # This means we compute B^T A^T = C^T and we explicitly switch the dimensions of each of these + # matrices in the input arguments for cuBLAS + + # column major: A @ B = C: [batch, m, k] @ [batch, k, n] = [batch, m, n] + # row major: B^T @ A^T = C^T: [batch, m, k] @ [batch, k, n] = [batch, m, n] + # column major with row major layout: B^T @ A^T = C^T: [batch, k, m] @ [batch, n, k] = [batch, n, m] + num_batch = A.shape[0] + n = A.shape[1] + m = B.shape[2] + k = B.shape[1] + + ldc = m + + strideA = B.shape[1] * B.shape[2] + strideB = A.shape[1] * A.shape[2] + strideC = A.shape[1] * B.shape[2] + + ptr = CUBLAS_Context.get_instance().get_context(A.device) + + is_on_gpu([B, A, out]) + lib.cbatched_igemm(ptr, ct.c_bool(transposed_B), ct.c_bool(transposed_A), ct.c_int32(m), ct.c_int32(n), ct.c_int32(k), + get_ptr(B), get_ptr(A), get_ptr(out), ct.c_int32(lda), ct.c_int32(ldb), ct.c_int32(ldc), + ct.c_long(strideA), ct.c_long(strideB), ct.c_long(strideC), ct.c_uint32(num_batch)) + return out + + +def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): + shapeA = SA[0] + shapeB = SB[0] + dimsA = len(shapeA) + dimsB = len(shapeB) + assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' + if dimsA == 2: + m = shapeA[0] + elif dimsA == 3: + m = shapeA[0] * shapeA[1] + + rows = n = shapeB[0] + assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' + + # if the tensor is empty, return a transformed empty tensor with the right dimensions + if shapeA[0] == 0 and dimsA == 2: + return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16) + elif shapeA[1] == 0 and dimsA == 3: + return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) + + if dimsA == 2 and out is None: + out, Sout = get_transform_buffer( + (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row" + ) + elif dimsA == 3 and out is None: + out, Sout = get_transform_buffer( + (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row" + ) + + assert dimsB != 3, "len(B.shape)==3 not supported" + assert A.device.type == "cuda" + assert B.device.type == "cuda" + assert A.dtype == torch.int8 + assert B.dtype == torch.int8 + assert out.dtype == dtype + assert SA[1] == "col32" + assert SB[1] in ["col_turing", "col_ampere"] + assert Sout[1] == "col32" + assert ( + shapeA[-1] == shapeB[-1] + ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}" + formatB = SB[1] + prev_device = A.device + torch.cuda.set_device(A.device) + + ptr = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + + k = shapeA[-1] + lda = ct.c_int32(m * 32) + if formatB == "col_turing": + # turing: tiles with rows filled up to multiple of 8 rows by 32 columns + # n = rows + ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32) + else: + # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns + # n = rows + ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32) + + ldc = ct.c_int32(m * 32) + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + + has_error = 0 + ptrRowScale = get_ptr(None) + is_on_gpu([A, B, out]) + if formatB == 'col_turing': + if dtype == torch.int32: + has_error = lib.cigemmlt_turing_32( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) + else: + has_error = lib.cigemmlt_turing_8( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) + elif formatB == "col_ampere": + if dtype == torch.int32: + has_error = lib.cigemmlt_ampere_32( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) + else: + has_error = lib.cigemmlt_ampere_8( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) + + if has_error == 1: + print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') + raise Exception('cublasLt ran into an error!') + + torch.cuda.set_device(prev_device) + + return out, Sout + + +def mm_dequant( + A, + quant_state, + row_stats, + col_stats, + out=None, + new_row_stats=None, + new_col_stats=None, + bias=None +): + assert A.dtype == torch.int32 + if bias is not None: assert bias.dtype == torch.float16 + out_shape = quant_state[0] + if len(out_shape) == 3: + out_shape = (out_shape[0] * out_shape[1], out_shape[2]) + + if out is None: + out = torch.empty(out_shape, dtype=torch.float16, device=A.device) + if new_row_stats is None: + new_row_stats = torch.empty( + out_shape[0], dtype=torch.float32, device=A.device + ) + if new_col_stats is None: + new_col_stats = torch.empty( + out_shape[1], dtype=torch.float32, device=A.device + ) + assert ( + new_row_stats.shape[0] == row_stats.shape[0] + ), f"{new_row_stats.shape} vs {row_stats.shape}" + assert ( + new_col_stats.shape[0] == col_stats.shape[0] + ), f"{new_col_stats.shape} vs {col_stats.shape}" + + prev_device = pre_call(A.device) + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + ptrNewRowStats = get_ptr(new_row_stats) + ptrNewColStats = get_ptr(new_col_stats) + ptrBias = get_ptr(bias) + numRows = ct.c_int32(out_shape[0]) + numCols = ct.c_int32(out_shape[1]) + + is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) + lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols) + post_call(prev_device) + + return out + + +def get_colrow_absmax( + A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 +): + assert A.dtype == torch.float16 + device = A.device + + cols = A.shape[-1] + if len(A.shape) == 3: + rows = A.shape[0] * A.shape[1] + else: + rows = A.shape[0] + + col_tiles = (cols + 255) // 256 + tiled_rows = ((rows + 15) // 16) * 16 + if row_stats is None: + row_stats = torch.empty( + (rows,), dtype=torch.float32, device=device + ).fill_(-50000.0) + if col_stats is None: + col_stats = torch.empty( + (cols,), dtype=torch.float32, device=device + ).fill_(-50000.0) + + if nnz_block_ptr is None and threshold > 0.0: + nnz_block_ptr = torch.zeros( + ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device + ) + + ptrA = get_ptr(A) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + ptrNnzrows = get_ptr(nnz_block_ptr) + rows = ct.c_int32(rows) + cols = ct.c_int32(cols) + + prev_device = pre_call(A.device) + is_on_gpu([A, row_stats, col_stats, nnz_block_ptr]) + lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols) + post_call(prev_device) + + if threshold > 0.0: + nnz_block_ptr.cumsum_(0) + + return row_stats, col_stats, nnz_block_ptr + + +class COOSparseTensor: + def __init__(self, rows, cols, nnz, rowidx, colidx, values): + assert rowidx.dtype == torch.int32 + assert colidx.dtype == torch.int32 + assert values.dtype == torch.float16 + assert values.numel() == nnz + assert rowidx.numel() == nnz + assert colidx.numel() == nnz + + self.rows = rows + self.cols = cols + self.nnz = nnz + self.rowidx = rowidx + self.colidx = colidx + self.values = values + + +class CSRSparseTensor: + def __init__(self, rows, cols, nnz, rowptr, colidx, values): + assert rowptr.dtype == torch.int32 + assert colidx.dtype == torch.int32 + assert values.dtype == torch.float16 + assert values.numel() == nnz + assert colidx.numel() == nnz + assert rowptr.numel() == rows + 1 + + self.rows = rows + self.cols = cols + self.nnz = nnz + self.rowptr = rowptr + self.colidx = colidx + self.values = values + + +class CSCSparseTensor: + def __init__(self, rows, cols, nnz, colptr, rowidx, values): + assert colptr.dtype == torch.int32 + assert rowidx.dtype == torch.int32 + assert values.dtype == torch.float16 + assert values.numel() == nnz + assert rowidx.numel() == nnz + assert colptr.numel() == cols + 1 + + self.rows = rows + self.cols = cols + self.nnz = nnz + self.colptr = colptr + self.rowidx = rowidx + self.values = values + + +def coo2csr(cooA): + values, counts = torch.unique(cooA.rowidx, return_counts=True) + values.add_(1) + rowptr = torch.zeros( + (cooA.rows + 1,), dtype=torch.int32, device=cooA.rowidx.device + ) + rowptr.scatter_(index=values.long(), src=counts.int(), dim=0) + rowptr.cumsum_(0) + return CSRSparseTensor( + cooA.rows, cooA.cols, cooA.nnz, rowptr, cooA.colidx, cooA.values + ) + + +def coo2csc(cooA): + val, col2rowidx = torch.sort(cooA.colidx) + rowidx = cooA.rowidx[col2rowidx] + values = cooA.values[col2rowidx] + colvalues, counts = torch.unique(val, return_counts=True) + colvalues.add_(1) + colptr = torch.zeros( + (cooA.cols + 1,), dtype=torch.int32, device=cooA.colidx.device + ) + colptr.scatter_(index=colvalues.long(), src=counts.int(), dim=0) + colptr.cumsum_(0) + return CSCSparseTensor( + cooA.rows, cooA.cols, cooA.nnz, colptr, rowidx, values + ) + + +def coo_zeros(rows, cols, nnz, device, dtype=torch.half): + rowidx = torch.zeros((nnz,), dtype=torch.int32, device=device) + colidx = torch.zeros((nnz,), dtype=torch.int32, device=device) + values = torch.zeros((nnz,), dtype=dtype, device=device) + return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) + + +def double_quant( + A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 +): + device = A.device + assert A.dtype == torch.half + assert device.type == "cuda" + prev_device = pre_call(A.device) + + cols = A.shape[-1] + if len(A.shape) == 3: + rows = A.shape[0] * A.shape[1] + else: + rows = A.shape[0] + + if row_stats is None or col_stats is None: + row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( + A, threshold=threshold + ) + + if out_col is None: + out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) + if out_row is None: + out_row = torch.zeros(A.shape, device=device, dtype=torch.int8) + + coo_tensor = None + ptrA = get_ptr(A) + ptrColStats = get_ptr(col_stats) + ptrRowStats = get_ptr(row_stats) + ptrOutCol = get_ptr(out_col) + ptrOutRow = get_ptr(out_row) + + is_on_gpu([A, col_stats, row_stats, out_col, out_row]) + if threshold > 0.0: + nnz = nnz_row_ptr[-1].item() + if nnz > 0: + coo_tensor = coo_zeros( + A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device + ) + ptrRowIdx = get_ptr(coo_tensor.rowidx) + ptrColIdx = get_ptr(coo_tensor.colidx) + ptrVal = get_ptr(coo_tensor.values) + ptrRowPtr = get_ptr(nnz_row_ptr) + + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + ptrRowIdx, + ptrColIdx, + ptrVal, + ptrRowPtr, + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + ) + val, idx = torch.sort(coo_tensor.rowidx) + coo_tensor.rowidx = val + coo_tensor.colidx = coo_tensor.colidx[idx] + coo_tensor.values = coo_tensor.values[idx] + else: + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + None, + None, + None, + None, + ct.c_float(0.0), + ct.c_int32(rows), + ct.c_int32(cols), + ) + else: + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + None, + None, + None, + None, + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + ) + post_call(prev_device) + + return out_row, out_col, row_stats, col_stats, coo_tensor + + +def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): + prev_device = pre_call(A.device) + if state is None: state = (A.shape, from_order) + else: from_order = state[1] + if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) + else: new_state = (state[0], to_order) # (shape, order) + + shape = state[0] + if len(shape) == 2: + dim1 = ct.c_int32(shape[0]) + dim2 = ct.c_int32(shape[1]) + else: + dim1 = ct.c_int32(shape[0] * shape[1]) + dim2 = ct.c_int32(shape[2]) + + is_on_gpu([A, out]) + if to_order == 'col32': + if transpose: + lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) + else: + lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2) + elif to_order == "col_turing": + if transpose: + lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2) + else: + lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2) + elif to_order == "col_ampere": + if transpose: + lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2) + else: + lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2) + elif to_order == "row": + if from_order == "col_turing": + lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2) + elif from_order == "col_ampere": + lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) + else: + raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') + + post_call(prev_device) + + return out, new_state + + +def spmm_coo(cooA, B, out=None): + if out is None: + out = torch.empty( + (cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype + ) + nnz = cooA.nnz + assert cooA.rowidx.numel() == nnz + assert cooA.colidx.numel() == nnz + assert cooA.values.numel() == nnz + assert cooA.cols == B.shape[0] + + transposed_B = False if B.is_contiguous() else True + + ldb = B.stride()[(1 if transposed_B else 0)] + ldc = B.shape[1] + + ptr = Cusparse_Context.get_instance().context + + ptrRowidx = get_ptr(cooA.rowidx) + ptrColidx = get_ptr(cooA.colidx) + ptrValues = get_ptr(cooA.values) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + cnnz = ct.c_int32(cooA.nnz) + crowsA = ct.c_int32(cooA.rows) + ccolsA = ct.c_int32(cooA.cols) + ccolsB = ct.c_int32(B.shape[1]) + cldb = ct.c_int32(ldb) + cldc = ct.c_int32(ldc) + + is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out]) + lib.cspmm_coo(ptr, ptrRowidx, ptrColidx, ptrValues, cnnz, crowsA, ccolsA, ccolsB, cldb, ptrB, cldc, ptrC, ct.c_bool(transposed_B)) + + return out + + +def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None): + if out is None: + out = torch.zeros( + (cooA.rows, B.shape[1]), device=B.device, dtype=cooA.values.dtype + ) + nnz = cooA.nnz + prev_device = pre_call(B.device) + assert cooA.rowidx.numel() == nnz + assert cooA.colidx.numel() == nnz + assert cooA.values.numel() == nnz + assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}" + + transposed_B = False if B.is_contiguous() else True + + ldb = B.stride()[(1 if transposed_B else 0)] + ldc = B.shape[1] + + values, counts = torch.unique(cooA.rowidx, return_counts=True) + offset = counts.cumsum(0).int() + max_count, max_idx = torch.sort(counts, descending=True) + max_idx = max_idx.int() + max_count = max_count.int() + assert ( + max_count[0] <= 32 + ), f"Current max count per row is 8 but found {max_count[0]}." + assert B.dtype in [torch.float16, torch.int8] + ptrOffset = get_ptr(offset) + ptrMaxCount = get_ptr(max_count) + ptrMaxIdx = get_ptr(max_idx) + + ptrRowidx = get_ptr(cooA.rowidx) + ptrColidx = get_ptr(cooA.colidx) + ptrValues = get_ptr(cooA.values) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + ptrDequantStats = get_ptr(dequant_stats) + cnnz_rows = ct.c_int32(counts.numel()) + cnnz = ct.c_int32(cooA.nnz) + crowsA = ct.c_int32(cooA.rows) + ccolsA = ct.c_int32(cooA.cols) + crowsB = ct.c_int32(B.shape[1]) + ccolsB = ct.c_int32(B.shape[1]) + cldb = ct.c_int32(ldb) + cldc = ct.c_int32(ldc) + + is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats]) + if B.dtype == torch.float16: + lib.cspmm_coo_very_sparse_naive_fp16( + ptrMaxCount, + ptrMaxIdx, + ptrOffset, + ptrRowidx, + ptrColidx, + ptrValues, + ptrB, + ptrC, + ptrDequantStats, + cnnz_rows, + cnnz, + crowsA, + crowsB, + ccolsB, + ) + elif B.dtype == torch.int8: + lib.cspmm_coo_very_sparse_naive_int8( + ptrMaxCount, + ptrMaxIdx, + ptrOffset, + ptrRowidx, + ptrColidx, + ptrValues, + ptrB, + ptrC, + ptrDequantStats, + cnnz_rows, + cnnz, + crowsA, + crowsB, + ccolsB, + ) + # else: assertion error + post_call(prev_device) + + return out + + +C = 127.0 + + +def vectorwise_quant(x, dim=1, quant_type="vector"): + if quant_type == "linear": + max1 = torch.abs(x).max().float() + xq = torch.round(x / max1 * 127).to(torch.int8) + return xq, max1 + elif quant_type in ["vector", "row"]: + max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) + xq = torch.round(x * (C / max1)).to(torch.int8) + return xq, max1 + elif quant_type == "zeropoint": + dtype = x.dtype + x = x.float() + dyna = x.max() - x.min() + if dyna == 0: + dyna = 1 + qx = 255.0 / dyna + minx = x.min() + zpx = torch.round(minx * qx) + x = torch.round(qx * x - zpx) + zpx + return x, qx + elif quant_type in ["vector-zeropoint", "row-zeropoint"]: + dtype = x.dtype + x = x.float() + dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin( + x, dim=dim, keepdim=True + ) + dyna[dyna == 0] = 1 + qx = 255.0 / dyna + minx = torch.amin(x, dim=dim, keepdim=True) + zpx = torch.round(minx * qx) + x = torch.round(qx * x - zpx) + zpx + return x, qx + elif quant_type == "truncated-vector": + with torch.no_grad(): + absx = torch.abs(x) + max1 = torch.amax(absx, dim=dim, keepdim=True) + max1 = max1 * 0.7 + idx = absx > max1.expand_as(absx) + sign = torch.sign(x[idx]) + x[idx] = max1.expand_as(absx)[idx] * sign + xq = torch.round(x / max1 * C).to(torch.int8) + return xq, max1 + else: + return None + + +def vectorwise_dequant(xq, max1, quant_type="vector"): + if quant_type == "vector": + x = (xq / C * max1).to(torch.float32) + return x + else: + return None + + +def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"): + if quant_type == "linear": + norm = S1 * S2 / (C * C) + # double cast needed to prevent overflows + return (xq.float() * norm).to(dtype) + elif quant_type == "zeropoint": + norm = 1.0 / (S1 * S2) + return (xq.float() * norm).to(dtype) + elif quant_type == "row-zeropoint": + norm = 1.0 / (S1 * S2) + x = xq.float() + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) + if len(S1.shape) == 2: + x *= norm + else: + x *= norm + return x.to(dtype) + elif quant_type == "vector-zeropoint": + x = xq.float() + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) + if len(S1.shape) == 2: + x *= 1.0 / S1 + else: + x *= 1.0 / S1 + x *= 1.0 / S2.t() + return x.to(dtype) + elif quant_type == "row": + x = xq.float() + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) + if len(S1.shape) == 2: + x *= S1 * S2 / (C * C) + else: + x *= S1 * S2 / (C * C) + return x.to(dtype) + elif quant_type in ["truncated-vector", "vector"]: + x = xq.float() + if len(S1.shape) == 3 and len(x.shape) == 2: + S1 = S1.squeeze(0) + if len(S2.shape) == 3 and len(x.shape) == 2: + S2 = S2.squeeze(0) + if len(S1.shape) == 2: + x *= S1 / C + else: + x *= S1 / C + x *= S2 / C + return x.to(dtype) + else: + return None + + +def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): + offset = B.float().t().sum(0) * (SA[0] + SA[1]) + x = xq.float() + if len(xq.shape) == 2 and len(SB.shape) == 3: + SB = SB.squeeze(0) + if len(SB.shape) == 2: + x *= SB.t() / 127 + else: + x *= SB / 127 + x *= SA[1] / 127 + x += offset + return x.to(dtype) + + +def extract_outliers(A, SA, idx): + shapeA = SA[0] + formatA = SA[1] + assert formatA in ["col_turing", "col_ampere"] + assert A.device.type == "cuda" + + out = torch.zeros( + (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device + ) + + idx_size = ct.c_int32(idx.numel()) + rows = ct.c_int32(shapeA[0]) + cols = ct.c_int32(shapeA[1]) + ptrA = get_ptr(A) + ptrIdx = get_ptr(idx) + ptrOut = get_ptr(out) + + prev_device = pre_call(A.device) + if formatA == 'col_turing': + lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + elif formatA == "col_ampere": + lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + post_call(prev_device) + + return out + +def pipeline_test(A, batch_size): + out = torch.zeros_like(A) + lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) + return out diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py new file mode 100644 index 000000000..49d7b5ced --- /dev/null +++ b/bitsandbytes/nn/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from .modules import Int8Params, Linear8bitLt, StableEmbedding, Linear4bit, LinearNF4, LinearFP4, Params4bit, OutlierAwareLinear, SwitchBackLinearBnb +from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py new file mode 100644 index 000000000..32849212d --- /dev/null +++ b/bitsandbytes/nn/modules.py @@ -0,0 +1,464 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from typing import Optional, TypeVar, Union, overload + +import torch +import torch.nn.functional as F +from torch import Tensor, device, dtype, nn + +import bitsandbytes as bnb +import bitsandbytes.functional +from bitsandbytes.autograd._functions import get_inverse_transform_indices, undo_layout +from bitsandbytes.optim import GlobalOptimManager +from bitsandbytes.utils import OutlierTracer, find_outlier_dims + +T = TypeVar("T", bound="torch.nn.Module") + + +class StableEmbedding(torch.nn.Embedding): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Tensor] = None, + device=None, + dtype=None, + ) -> None: + super().__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + device, + dtype, + ) + self.norm = torch.nn.LayerNorm(embedding_dim, device=device) + GlobalOptimManager.get_instance().register_module_override( + self, "weight", {"optim_bits": 32} + ) + + def reset_parameters(self) -> None: + torch.nn.init.xavier_uniform_(self.weight) + self._fill_padding_idx_with_zero() + + """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding + to make the Layer compatible with Pytorch < 1.9. + This means that if this changes in future PyTorch releases this need to change too + which is cumbersome. However, with this we can ensure compatibility with previous + PyTorch releases. + """ + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + emb = F.embedding( + input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + # always apply layer norm in full precision + emb = emb.to(torch.get_default_dtype()) + + return self.norm(emb).to(self.weight.dtype) + + +class Embedding(torch.nn.Embedding): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, + norm_type: float = 2.0, + scale_grad_by_freq: bool = False, + sparse: bool = False, + _weight: Optional[Tensor] = None, + ) -> None: + super().__init__( + num_embeddings, + embedding_dim, + padding_idx, + max_norm, + norm_type, + scale_grad_by_freq, + sparse, + _weight, + ) + GlobalOptimManager.get_instance().register_module_override( + self, "weight", {"optim_bits": 32} + ) + + def reset_parameters(self) -> None: + torch.nn.init.xavier_uniform_(self.weight) + self._fill_padding_idx_with_zero() + + """ !!! This is a redefinition of _fill_padding_idx_with_zero in torch.nn.Embedding + to make the Layer compatible with Pytorch < 1.9. + This means that if this changes in future PyTorch releases this need to change too + which is cumbersome. However, with this we can ensure compatibility with previous + PyTorch releases. + """ + + def _fill_padding_idx_with_zero(self) -> None: + if self.padding_idx is not None: + with torch.no_grad(): + self.weight[self.padding_idx].fill_(0) + + def forward(self, input: Tensor) -> Tensor: + emb = F.embedding( + input, + self.weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) + + return emb + +class Params4bit(torch.nn.Parameter): + def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True, quant_type='fp4'): + if data is None: + data = torch.empty(0) + + self = torch.Tensor._make_subclass(cls, data, requires_grad) + self.blocksize = blocksize + self.compress_statistics = compress_statistics + self.quant_type = quant_type + self.quant_state = quant_state + self.data = data + return self + + def cuda(self, device): + w = self.data.contiguous().half().cuda(device) + w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) + self.data = w_4bit + self.quant_state = quant_state + + return self + + @overload + def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T: + ... + + @overload + def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: + ... + + @overload + def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: + ... + + def to(self, *args, **kwargs): + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) + + if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"): + return self.cuda(device) + else: + s = self.quant_state + if s is not None: + # make sure the quantization state is on the right device + s[0] = s[0].to(device) + if self.compress_statistics: + # TODO: refactor this. This is a nightmare + # for 4-bit: + # state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] + # state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] + #s[-2][0] = s[-2][0].to(device) # offset + #s[-2][1][0] = s[-2][1][0].to(device) # nested absmax + + # for 8-bit + s[-2][0] = s[-2][0].to(device) # offset + s[-2][1][0] = s[-2][1][0].to(device) # nested quantiation state statitics + s[-2][1][1] = s[-2][1][1].to(device) # nested quantiation codebook + new_param = Params4bit(super().to(device=device, dtype=dtype, non_blocking=non_blocking), + requires_grad=self.requires_grad, quant_state=self.quant_state, + blocksize=self.blocksize, compress_statistics=self.compress_statistics, + quant_type=self.quant_type) + + return new_param + +class Linear4bit(nn.Linear): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True, quant_type='fp4'): + super().__init__(input_features, output_features, bias) + self.weight = Params4bit(self.weight.data, requires_grad=False, compress_statistics=compress_statistics, quant_type=quant_type) + self.compute_dtype = compute_dtype + + def forward(self, x: torch.Tensor): + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + if getattr(self.weight, 'quant_state', None) is None: + print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.') + inp_dtype = x.dtype + if self.compute_dtype is not None: + x = x.to(self.compute_dtype) + + bias = None if self.bias is None else self.bias.to(self.compute_dtype) + out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state) + + out = out.to(inp_dtype) + + return out + +class LinearFP4(Linear4bit): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True): + super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'fp4') + +class LinearNF4(Linear4bit): + def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True): + super().__init__(input_features, output_features, bias, compute_dtype, compress_statistics, 'nf4') + + + +class Int8Params(torch.nn.Parameter): + def __new__( + cls, + data=None, + requires_grad=True, + has_fp16_weights=False, + CB=None, + SCB=None, + ): + cls.has_fp16_weights = has_fp16_weights + cls.CB = None + cls.SCB = None + if data is None: + data = torch.empty(0) + return torch.Tensor._make_subclass(cls, data, requires_grad) + + def cuda(self, device): + if self.has_fp16_weights: + return super().cuda(device) + else: + # we store the 8-bit rows-major weight + # we convert this weight to the turning/ampere weight during the first inference pass + B = self.data.contiguous().half().cuda(device) + CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) + del CBt + del SCBt + self.data = CB + setattr(self, "CB", CB) + setattr(self, "SCB", SCB) + + return self + + @overload + def to( + self: T, + device: Optional[Union[int, device]] = ..., + dtype: Optional[Union[dtype, str]] = ..., + non_blocking: bool = ..., + ) -> T: + ... + + @overload + def to(self: T, dtype: Union[dtype, str], non_blocking: bool = ...) -> T: + ... + + @overload + def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: + ... + + def to(self, *args, **kwargs): + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( + *args, **kwargs + ) + + if ( + device is not None + and device.type == "cuda" + and self.data.device.type == "cpu" + ): + return self.cuda(device) + else: + new_param = Int8Params( + super().to( + device=device, dtype=dtype, non_blocking=non_blocking + ), + requires_grad=self.requires_grad, + has_fp16_weights=self.has_fp16_weights, + ) + new_param.CB = self.CB + new_param.SCB = self.SCB + + return new_param + + + +class Linear8bitLt(nn.Linear): + def __init__(self, input_features, output_features, bias=True, has_fp16_weights=True, + memory_efficient_backward=False, threshold=0.0, index=None): + super().__init__(input_features, output_features, bias) + assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0" + self.state = bnb.MatmulLtState() + self.index = index + + self.state.threshold = threshold + self.state.has_fp16_weights = has_fp16_weights + self.state.memory_efficient_backward = memory_efficient_backward + if threshold > 0.0 and not has_fp16_weights: + self.state.use_pool = True + + self.weight = Int8Params(self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + if not self.state.has_fp16_weights and self.state.CB is None and self.state.CxB is not None: + # reorder weight layout back from ampere/turing to row + reorder_layout = True + weight_clone = self.weight.data.clone() + else: + reorder_layout = False + + try: + if reorder_layout: + self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices) + + super()._save_to_state_dict(destination, prefix, keep_vars) + + # we only need to save SCB as extra data, because CB for quantized weights is already stored in weight.data + weight_name = "SCB" + + # case 1: .cuda was called, SCB is in self.weight + param_from_weight = getattr(self.weight, weight_name) + # case 2: self.init_8bit_state was called, SCB is in self.state + param_from_state = getattr(self.state, weight_name) + + key_name = prefix + f"{weight_name}" + if param_from_weight is not None: + destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach() + elif not self.state.has_fp16_weights and param_from_state is not None: + destination[key_name] = param_from_state if keep_vars else param_from_state.detach() + finally: + if reorder_layout: + self.weight.data = weight_clone + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs) + for key in unexpected_keys: + input_name = key[len(prefix):] + if input_name == "SCB": + if self.weight.SCB is None: + # buffers not yet initialized, can't call them directly without + raise RuntimeError("Loading a quantized checkpoint into non-quantized Linear8bitLt is " + "not supported. Please call module.cuda() before module.load_state_dict()") + + input_param = state_dict[key] + self.weight.SCB.copy_(input_param) + unexpected_keys.remove(key) + + def init_8bit_state(self): + self.state.CB = self.weight.CB + self.state.SCB = self.weight.SCB + self.weight.CB = None + self.weight.SCB = None + + def forward(self, x: torch.Tensor): + self.state.is_training = self.training + if self.weight.CB is not None: + self.init_8bit_state() + + # weights are cast automatically as Int8Params, but the bias has to be cast manually + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(x.dtype) + + out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) + + if not self.state.has_fp16_weights: + if self.state.CB is not None and self.state.CxB is not None: + # we converted 8-bit row major to turing/ampere format in the first inference pass + # we no longer need the row-major weight + del self.state.CB + self.weight.data = self.state.CxB + return out + + +class OutlierAwareLinear(nn.Linear): + def __init__(self, input_features, output_features, bias=True): + super().__init__(input_features, output_features, bias) + self.outlier_dim = None + self.is_quantized = False + + def forward_with_outliers(self, x, outlier_idx): + raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function') + + def quantize_weight(self, w, outlier_idx): + raise NotImplementedError('Please override the `quantize_weights(self, w, outlier_idx)` function') + + def forward(self, x): + if self.outlier_dim is None: + tracer = OutlierTracer.get_instance() + if not tracer.is_initialized(): + print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer') + outlier_idx = tracer.get_outliers(self.weight) + #print(outlier_idx, tracer.get_hvalue(self.weight)) + self.outlier_dim = outlier_idx + + if not self.is_quantized: + w = self.quantize_weight(self.weight, self.outlier_dim) + self.weight.data.copy_(w) + self.is_quantized = True + +class SwitchBackLinearBnb(nn.Linear): + def __init__( + self, + input_features, + output_features, + bias=True, + has_fp16_weights=True, + memory_efficient_backward=False, + threshold=0.0, + index=None, + ): + super().__init__( + input_features, output_features, bias + ) + self.state = bnb.MatmulLtState() + self.index = index + + self.state.threshold = threshold + self.state.has_fp16_weights = has_fp16_weights + self.state.memory_efficient_backward = memory_efficient_backward + if threshold > 0.0 and not has_fp16_weights: + self.state.use_pool = True + + self.weight = Int8Params( + self.weight.data, has_fp16_weights=has_fp16_weights, requires_grad=has_fp16_weights + ) + + def init_8bit_state(self): + self.state.CB = self.weight.CB + self.state.SCB = self.weight.SCB + self.weight.CB = None + self.weight.SCB = None + + def forward(self, x): + self.state.is_training = self.training + + if self.weight.CB is not None: + self.init_8bit_state() + + out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py new file mode 100644 index 000000000..6fbf583b9 --- /dev/null +++ b/bitsandbytes/nn/triton_based_modules.py @@ -0,0 +1,258 @@ +import torch +import torch.nn as nn +import time +from functools import partial + +from bitsandbytes.triton.triton_utils import is_triton_available + +from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise +from bitsandbytes.triton.quantize_rowwise import quantize_rowwise +from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose +from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize +from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose +from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze + + +class _switchback_global(torch.autograd.Function): + + @staticmethod + def forward(ctx, X_3D, W, bias): + # reshape input to [N * L, D] + X = X_3D.view(-1, X_3D.size(-1)) + + # rowwise quantize for X, global quantize for W + X_int8, state_X = quantize_rowwise(X) + W_int8, state_W = quantize_global(W) + + # save for backward. + ctx.save_for_backward = X, W + + # matmult, fused dequant and add bias + # call "mixed" because we are mixing rowwise quantized and global quantized + return int8_matmul_mixed_dequanitze( + X_int8, W_int8.t(), state_X, state_W, bias + ).view(*X_3D.size()[:-1], -1) + + @staticmethod + def backward(ctx, G_3D): + # reshape input to [N_out * L, D] + G = G_3D.reshape(-1, G_3D.size(-1)) + + grad_X = grad_W = grad_bias = None + + X, W = ctx.save_for_backward + if ctx.needs_input_grad[0]: + # rowwise quantize for G, global quantize for W + # for W, we also fuse the transpose operation because only A @ B^T is supported + # so we transpose once then call .t() in the matmul + G_int8, state_G = quantize_rowwise(G) + W_int8, state_W = quantize_global_transpose(W) + grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view( + *G_3D.size()[:-1], -1 + ) + if ctx.needs_input_grad[1]: + # backward pass uses standard weight grad + grad_W = torch.matmul(G.t(), X.to(G.dtype)) + if ctx.needs_input_grad[2]: + grad_bias = G.sum(dim=0) + + return grad_X, grad_W, grad_bias + +class _switchback_vectorrize(torch.autograd.Function): + + @staticmethod + def forward(ctx, X_3D, W, bias): + # reshape input to [N * L, D] + X = X_3D.view(-1, X_3D.size(-1)) + + ctx.save_for_backward = X, W + # rowwise quantize for X + # columnwise quantize for W (first rowwise, transpose later) + X_int8, state_X = quantize_rowwise(X) + W_int8, state_W = quantize_rowwise(W) + + # matmult, fused dequant and add bias + # call kernel which expects rowwise quantized X and W + return int8_matmul_rowwise_dequantize( + X_int8, W_int8.t(), state_X, state_W, bias + ).view(*X_3D.size()[:-1], -1) + + @staticmethod + def backward(ctx, G_3D): + X, W = ctx.save_for_backward + + G = G_3D.reshape(-1, G_3D.size(-1)) + + grad_X = grad_W = grad_bias = None + + if ctx.needs_input_grad[0]: + # rowwise quantize for G, columnwise quantize for W and fused transpose + # we call .t() for weight later because only A @ B^T is supported + G_int8, state_G = quantize_rowwise(G) + W_int8, state_W = quantize_columnwise_and_transpose(W) + grad_X = int8_matmul_rowwise_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view( + *G_3D.size()[:-1], -1 + ) + if ctx.needs_input_grad[1]: + # backward pass uses standard weight grad + grad_W = torch.matmul(G.t(), X.to(G.dtype)) + if ctx.needs_input_grad[2]: + grad_bias = G.sum(dim=0) + + return grad_X, grad_W, grad_bias + +class _switchback_global_mem_efficient(torch.autograd.Function): + + @staticmethod + def forward(ctx, X_3D, W, bias): + # reshape input to [N * L, D] + X = X_3D.view(-1, X_3D.size(-1)) + X_3D_sz = X_3D.size() + + # rowwise quantize for X, global quantize for W + X_int8, state_X = quantize_rowwise(X) + del X + W_int8, state_W = quantize_global(W) + + # save for backward. + ctx.save_for_backward = X_int8, state_X, W_int8, state_W + + # matmult, fused dequant and add bias + # call "mixed" because we are mixing rowwise quantized and global quantized + return int8_matmul_mixed_dequanitze( + X_int8, W_int8.t(), state_X, state_W, bias + ).view(*X_3D_sz[:-1], -1) + + @staticmethod + def backward(ctx, G_3D): + # reshape input to [N_out * L, D] + G = G_3D.reshape(-1, G_3D.size(-1)) + G_3D_sz = G_3D.size() + + grad_X = grad_W = grad_bias = None + + X_int8, state_X, W_int8, state_W = ctx.save_for_backward + if ctx.needs_input_grad[1]: + real_X = dequantize_rowwise(X_int8, state_X) + del X_int8 + grad_W = torch.matmul(G.t(), real_X.to(G.dtype)) + del real_X + if ctx.needs_input_grad[2]: + grad_bias = G.sum(dim=0) + if ctx.needs_input_grad[0]: + G_int8, state_G = quantize_rowwise(G) + del G + W_int8 = W_int8.t().contiguous() + grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view( + *G_3D_sz[:-1], -1 + ) + + return grad_X, grad_W, grad_bias + +class SwitchBackLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + vector_wise_quantization: bool = False, + mem_efficient : bool = False, + ): + super().__init__(in_features, out_features, bias, device, dtype) + + if not is_triton_available: + raise ImportError('''Could not import triton. Please install triton to use SwitchBackLinear. + Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''') + + # By default, we use the global quantization. + self.vector_wise_quantization = vector_wise_quantization + if self.vector_wise_quantization: + self._fn = _switchback_vectorrize + if mem_efficient: + print('mem efficient is not supported for vector-wise quantization.') + exit(1) + else: + if mem_efficient: + self._fn = _switchback_global_mem_efficient + else: + self._fn = _switchback_global + + def prepare_for_eval(self): + # If we just want to do eval, we can pre-quantize the weights instead of doing it on the forward pass. + # Note this is experimental and not tested thoroughly. + # Note this needs to be explicitly called with something like + # def cond_prepare(m): + # if hasattr(m, "prepare_for_eval"): + # m.prepare_for_eval() + # model.apply(cond_prepare) + print('=> preparing for eval.') + if self.vector_wise_quantization: + W_int8, state_W = quantize_rowwise(self.weight) + else: + W_int8, state_W = quantize_global(self.weight) + + self.register_buffer("W_int8", W_int8) + self.register_buffer("state_W", state_W) + + del self.weight + + def forward(self, x): + if self.training: + return self._fn.apply(x, self.weight, self.bias) + else: + # If it hasn't been "prepared for eval", run the standard forward pass. + if not hasattr(self, "W_int8"): + return self._fn.apply(x, self.weight, self.bias) + + # Otherwise, use pre-computed weights. + X = x.view(-1, x.size(-1)) + X_int8, state_X = quantize_rowwise(X) + + if self.vector_wise_quantization: + return int8_matmul_rowwise_dequantize( + X_int8, self.W_int8.t(), state_X, self.state_W, self.bias + ).view(*x.size()[:-1], -1) + else: + return int8_matmul_mixed_dequanitze( + X_int8, self.W_int8.t(), state_X, self.state_W, self.bias + ).view(*x.size()[:-1], -1) + +SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False) +SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True) +SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True) + +# This is just the standard linear function. +class StandardLinearFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, bias=None): + X = input.view(-1, input.size(-1)) + + ctx.save_for_backward(X, weight, bias) + output = input.matmul(weight.t()) + if bias is not None: + output += bias.unsqueeze(0).expand_as(output) + return output.view(*input.size()[:-1], -1) + + @staticmethod + def backward(ctx, grad_output_3D): + input, weight, bias = ctx.saved_tensors + + grad_output = grad_output_3D.reshape(-1, grad_output_3D.size(-1)) + + grad_input = grad_weight = grad_bias = None + + if ctx.needs_input_grad[0]: + grad_input = grad_output.matmul(weight.to(grad_output.dtype)).view(*grad_output_3D.size()[:-1], -1) + if ctx.needs_input_grad[1]: + grad_weight = grad_output.t().matmul(input.to(grad_output.dtype)) + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(0) + + return grad_input, grad_weight, grad_bias + +class StandardLinear(nn.Linear): + + def forward(self, x): + return StandardLinearFunction.apply(x, self.weight, self.bias) diff --git a/bitsandbytes/optim/__init__.py b/bitsandbytes/optim/__init__.py new file mode 100644 index 000000000..1cfe2410e --- /dev/null +++ b/bitsandbytes/optim/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from bitsandbytes.cextension import COMPILED_WITH_CUDA + +from .adagrad import Adagrad, Adagrad8bit, Adagrad32bit +from .adam import Adam, Adam8bit, Adam32bit, PagedAdam, PagedAdam8bit, PagedAdam32bit +from .adamw import AdamW, AdamW8bit, AdamW32bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit +from .lamb import LAMB, LAMB8bit, LAMB32bit +from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS +from .optimizer import GlobalOptimManager +from .rmsprop import RMSprop, RMSprop8bit, RMSprop32bit +from .lion import Lion, Lion8bit, Lion32bit +from .sgd import SGD, SGD8bit, SGD32bit diff --git a/bitsandbytes/optim/adagrad.py b/bitsandbytes/optim/adagrad.py new file mode 100644 index 000000000..7d8df58ac --- /dev/null +++ b/bitsandbytes/optim/adagrad.py @@ -0,0 +1,132 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from bitsandbytes.optim.optimizer import Optimizer1State + + +class Adagrad(Optimizer1State): + def __init__( + self, + params, + lr=1e-2, + lr_decay=0, + weight_decay=0, + initial_accumulator_value=0, + eps=1e-10, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= weight_decay: + raise ValueError( + f"Invalid weight_decay value: {weight_decay}" + ) + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if initial_accumulator_value != 0.0: + raise ValueError("Initial accumulator value != 0.0 not supported!") + if lr_decay != 0.0: + raise ValueError("Lr Decay != 0.0 not supported!") + super().__init__( + "adagrad", + params, + lr, + (0.0, 0.0), + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + + +class Adagrad8bit(Optimizer1State): + def __init__( + self, + params, + lr=1e-2, + lr_decay=0, + weight_decay=0, + initial_accumulator_value=0, + eps=1e-10, + optim_bits=8, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= weight_decay: + raise ValueError( + f"Invalid weight_decay value: {weight_decay}" + ) + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if initial_accumulator_value != 0.0: + raise ValueError("Initial accumulator value != 0.0 not supported!") + if lr_decay != 0.0: + raise ValueError("Lr Decay != 0.0 not supported!") + assert block_wise + super().__init__( + "adagrad", + params, + lr, + (0.0, 0.0), + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + + +class Adagrad32bit(Optimizer1State): + def __init__( + self, + params, + lr=1e-2, + lr_decay=0, + weight_decay=0, + initial_accumulator_value=0, + eps=1e-10, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= weight_decay: + raise ValueError( + f"Invalid weight_decay value: {weight_decay}" + ) + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if initial_accumulator_value != 0.0: + raise ValueError("Initial accumulator value != 0.0 not supported!") + if lr_decay != 0.0: + raise ValueError("Lr Decay != 0.0 not supported!") + super().__init__( + "adagrad", + params, + lr, + (0.0, 0.0), + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py new file mode 100644 index 000000000..86981eb86 --- /dev/null +++ b/bitsandbytes/optim/adam.py @@ -0,0 +1,273 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import os + +import torch +import torch.distributed as dist + +import bitsandbytes.functional as F +from bitsandbytes.optim.optimizer import Optimizer2State + + +class Adam(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + +class Adam8bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + +class Adam32bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + +class PagedAdam(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdam8bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdam32bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class AnalysisAdam(torch.optim.Optimizer): + """Adam that performs 8-bit vs 32-bit error analysis. + + This implementation is modified from torch.optim.Adam based on: + `Fixed Weight Decay Regularization in Adam` + (see https://arxiv.org/abs/1711.05101) + + It has been proposed in `Adam: A Method for Stochastic Optimization`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + + .. _Adam: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + bnb_analysis="dynamic-blockwise", + savedir=None, + ): + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad, + ) + super().__init__(params, defaults) + self.analysis = bnb_analysis + self.savedir = savedir + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return True + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p_id, p in enumerate(group["params"]): + if p.grad is None: + continue + grad = p.grad.data + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + amsgrad = group.get("amsgrad", False) + assert not amsgrad + + p_data_fp32 = p.data + if p.data.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state = self.state[p] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p_data_fp32) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) + state["abserrors"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) + state["relerrors"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) + state["counts"] = torch.zeros( + (256, 256), device=p_data_fp32.device + ) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) + else: + state["exp_avg"] = state["exp_avg"].to(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) + if amsgrad: + state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to( + p_data_fp32 + ) + + state["step"] += 1 + beta1, beta2 = group["betas"] + bias_correction1 = 1 - beta1 ** state["step"] + bias_correction2 = 1 - beta2 ** state["step"] + step_size = ( + group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + ) + e = state["abserrors"] + rele = state["relerrors"] + counts = state["counts"] + + if group["weight_decay"] != 0: + p_data_fp32.add_( + p_data_fp32, alpha=-group["weight_decay"] * group["lr"] + ) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + if amsgrad: + max_exp_avg_sq = state["max_exp_avg_sq"] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + denom = exp_avg_sq.sqrt().add_(group["eps"]) + update_fp32 = exp_avg / denom + + if ( + p_data_fp32.numel() <= 8192 + or p_data_fp32.numel() > 50000 * 1000 + ): + # embedding layer or too small + p_data_fp32 += -step_size * update_fp32 + else: + if self.analysis == "dynamic-blockwise": + code1 = F.create_dynamic_map(signed=True).to(p.device) + code2 = F.create_dynamic_map(signed=False).to(p.device) + C1, S1 = F.quantize_blockwise(exp_avg, code=code1) + state1 = F.dequantize_blockwise(C1, S1) + C2, S2 = F.quantize_blockwise(exp_avg_sq, code=code2) + state2 = F.dequantize_blockwise(C2, S2) + elif self.analysis == "dynamic": + code1 = F.create_dynamic_map(signed=True).to(p.device) + code2 = F.create_dynamic_map(signed=False).to(p.device) + C1, S1 = F.quantize(exp_avg, code=code1) + state1 = F.dequantize(C1, S1) + C2, S2 = F.quantize(exp_avg_sq, code=code2) + state2 = F.dequantize(C2, S2) + elif self.analysis == "linear": + code1 = F.create_linear_map(signed=True).to(p.device) + code2 = F.create_linear_map(signed=False).to(p.device) + C1, S1 = F.quantize(exp_avg, code=code1) + state1 = F.dequantize(C1, S1) + C2, S2 = F.quantize(exp_avg_sq, code=code2) + state2 = F.dequantize(C2, S2) + elif self.analysis == "quantile": + code1 = F.estimate_quantiles(exp_avg) + code2 = F.estimate_quantiles(exp_avg_sq) + C1 = F.quantize_no_absmax(exp_avg, code=code1) + state1 = F.dequantize_no_absmax(C1, code1) + C2 = F.quantize_no_absmax(exp_avg_sq, code=code2) + state2 = F.dequantize_no_absmax(C2, code2) + elif self.analysis == "my-quantization-routine": + pass + # 1. get code + # 2. quantize + # 3. dequantize + # Error will be calculated automatically! + else: + raise ValueError( + f"Invalid analysis value: {self.analysis}!" + ) + + denom = state2.sqrt().add_(group["eps"]) + update_8bit = state1 / denom + + abserr = torch.abs(update_8bit - update_fp32) + relerr = abserr / torch.abs(update_fp32 + 1e-6) + + C1, C2 = C1.int(), C2.int() + + F.histogram_scatter_add_2d(e, C1.int(), C2.int(), abserr) + F.histogram_scatter_add_2d(rele, C1.int(), C2.int(), relerr) + F.histogram_scatter_add_2d( + counts, C1.int(), C2.int(), torch.ones_like(abserr) + ) + + p_data_fp32 += -step_size * update_fp32 + + if not dist.is_initialized() or dist.get_rank() == 0: + if self.savedir != "" and state["step"] % 100 == 0: + if not os.path.exists(self.savedir): + os.makedirs(self.savedir) + shapestr = "_".join( + [str(dim) for dim in p_data_fp32.shape] + ) + pathe = os.path.join( + self.savedir, f"{p_id}_{shapestr}_abserr.pkl" + ) + pathrele = os.path.join( + self.savedir, f"{p_id}_{shapestr}_relerr.pkl" + ) + pathcounts = os.path.join( + self.savedir, f"{p_id}_{shapestr}_counts.pkl" + ) + torch.save(e, pathe) + torch.save(rele, pathrele) + torch.save(counts, pathcounts) + + if p.data.dtype in {torch.float16, torch.bfloat16}: + p.data.copy_(p_data_fp32) + + return loss diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py new file mode 100644 index 000000000..21077f1a0 --- /dev/null +++ b/bitsandbytes/optim/adamw.py @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from bitsandbytes.optim.optimizer import Optimizer2State + + + +class AdamW(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) + +class AdamW8bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged ) + +class AdamW32bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=is_paged) + + +class PagedAdamW(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdamW8bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 8, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + +class PagedAdamW32bit(Optimizer2State): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, optim_bits=32, + args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True): + super().__init__( "adam", params, lr, betas, eps, weight_decay, 32, args, min_8bit_size, percentile_clipping, block_wise, is_paged=True) + diff --git a/bitsandbytes/optim/lamb.py b/bitsandbytes/optim/lamb.py new file mode 100644 index 000000000..1fbb6fadc --- /dev/null +++ b/bitsandbytes/optim/lamb.py @@ -0,0 +1,105 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from bitsandbytes.optim.optimizer import Optimizer2State + + +class LAMB(Optimizer2State): + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + adam_w_mode=True, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=False, + max_unorm=1.0, + ): + super().__init__( + "lamb", + params, + lr, + betas, + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + max_unorm=1.0, + ) + + +class LAMB8bit(Optimizer2State): + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + adam_w_mode=True, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=False, + max_unorm=1.0, + ): + super().__init__( + "lamb", + params, + lr, + betas, + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + max_unorm=1.0, + ) + + +class LAMB32bit(Optimizer2State): + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + adam_w_mode=True, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=False, + max_unorm=1.0, + ): + super().__init__( + "lamb", + params, + lr, + betas, + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + max_unorm=1.0, + ) diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py new file mode 100644 index 000000000..73554e3cc --- /dev/null +++ b/bitsandbytes/optim/lars.py @@ -0,0 +1,210 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import torch +from torch.optim import Optimizer + +from bitsandbytes.optim.optimizer import Optimizer1State + + +class LARS(Optimizer1State): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + max_unorm=0.02, + ): + if momentum == 0: + raise NotImplementedError( + "LARS without momentum is not supported!" + ) + super().__init__( + "lars", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + max_unorm=max_unorm, + block_wise=False, + ) + + +class LARS8bit(Optimizer1State): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + max_unorm=0.02, + ): + if momentum == 0: + raise NotImplementedError( + "LARS without momentum is not supported!" + ) + super().__init__( + "lars", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + max_unorm=max_unorm, + block_wise=False, + ) + + +class LARS32bit(Optimizer1State): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + max_unorm=0.02, + ): + if momentum == 0: + raise NotImplementedError( + "LARS without momentum is not supported!" + ) + super().__init__( + "lars", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + max_unorm=max_unorm, + block_wise=False, + ) + + +class PytorchLARS(Optimizer): + def __init__( + self, + params, + lr=0.01, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + max_unorm=0.02, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if momentum < 0.0: + raise ValueError(f"Invalid momentum value: {momentum}") + if weight_decay < 0.0: + raise ValueError( + f"Invalid weight_decay value: {weight_decay}" + ) + + defaults = dict( + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + max_unorm=max_unorm, + ) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError( + "Nesterov momentum requires a momentum and zero dampening" + ) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("nesterov", False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + d_p_list = [] + momentum_buffer_list = [] + weight_decay = group["weight_decay"] + momentum = group["momentum"] + dampening = group["dampening"] + nesterov = group["nesterov"] + max_unorm = group["max_unorm"] + lr = group["lr"] + + for p in group["params"]: + if p.grad is None: + continue + + state = self.state[p] + d_p = p.grad + if weight_decay != 0: + d_p = d_p.add(p, alpha=weight_decay) + + if momentum != 0: + buf = state.get("momentum_buffer", None) + + if buf is None: + buf = torch.clone(d_p).detach() + state["momentum_buffer"] = buf + else: + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + + if nesterov: + update = d_p + buf * momentum + else: + update = buf + + update_scale = 1.0 + if max_unorm > 0.0: + assert p.dtype == torch.float32 + pnorm = torch.norm(p.detach()) + unorm = torch.norm(update) + if unorm > max_unorm * pnorm: + update_scale = max_unorm * pnorm / unorm + + p.add_(update, alpha=-lr * update_scale) + + return loss diff --git a/bitsandbytes/optim/lion.py b/bitsandbytes/optim/lion.py new file mode 100644 index 000000000..2551b68e1 --- /dev/null +++ b/bitsandbytes/optim/lion.py @@ -0,0 +1,87 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from bitsandbytes.optim.optimizer import Optimizer1State + + +class Lion(Optimizer1State): + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super().__init__( + "lion", + params, + lr, + betas, + 0., + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + + +class Lion8bit(Optimizer1State): + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super().__init__( + "lion", + params, + lr, + betas, + 0., + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + + +class Lion32bit(Optimizer1State): + def __init__( + self, + params, + lr=1e-4, + betas=(0.9, 0.99), + weight_decay=0, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + super().__init__( + "lion", + params, + lr, + betas, + 0., + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py new file mode 100644 index 000000000..fb83eddf0 --- /dev/null +++ b/bitsandbytes/optim/optimizer.py @@ -0,0 +1,724 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from collections import abc as container_abcs +from collections import defaultdict +from copy import deepcopy +from itertools import chain + +import torch + +import bitsandbytes.functional as F + + +class MockArgs: + def __init__(self, initial_data): + for key in initial_data: + setattr(self, key, initial_data[key]) + + +class GlobalOptimManager: + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def initialize(self): + self.pid2config = {} + self.index2config = {} + self.optimizer = None + self.uses_config_override = False + self.module_weight_config_triple = [] + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + cls._instance.initialize() + return cls._instance + + def register_parameters(self, params): + param_groups = list(params) + if not isinstance(param_groups[0], dict): + param_groups = [{"params": param_groups}] + + for group_index, group in enumerate(param_groups): + for p_index, p in enumerate(group["params"]): + if id(p) in self.pid2config: + self.index2config[(group_index, p_index)] = self.pid2config[ + id(p) + ] + + def override_config( + self, parameters, key=None, value=None, key_value_dict=None + ): + """ + Overrides initial optimizer config for specific parameters. + + The key-values of the optimizer config for the input parameters are overridden + This can be both, optimizer parameters like "betas", or "lr" or it can be + 8-bit specific parameters like "optim_bits", "percentile_clipping". + + Parameters + ---------- + parameters : torch.Tensor or list(torch.Tensors) + The input parameters. + key : str + The hyperparamter to override. + value : object + The value for the hyperparamters. + key_value_dict : dict + A dictionary with multiple key-values to override. + """ + self.uses_config_override = True + if isinstance(parameters, torch.nn.Parameter): + parameters = [parameters] + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + if key is not None and value is not None: + assert key_value_dict is None + key_value_dict = {key: value} + + if key_value_dict is not None: + for p in parameters: + if id(p) in self.pid2config: + self.pid2config[id(p)].update(key_value_dict) + else: + self.pid2config[id(p)] = key_value_dict + + def register_module_override(self, module, param_name, config): + self.module_weight_config_triple.append((module, param_name, config)) + + +class Optimizer8bit(torch.optim.Optimizer): + def __init__(self, params, defaults, optim_bits=32, is_paged=False): + super().__init__(params, defaults) + self.initialized = False + self.name2qmap = {} + self.is_paged = is_paged + self.page_mng = F.GlobalPageManager.get_instance() + + self.mng = GlobalOptimManager.get_instance() + self.non_castable_tensor_keys = { + "qmap1", + "qmap2", + "max1", + "max2", + "new_max1", + "new_max2", + "state1", + "state2", + "gnorm_vec", + "absmax1", + "absmax2", + "unorm_vec", + } + + if optim_bits == 8: + self.fill_qmap() + + def fill_qmap(self): + self.name2qmap["dynamic"] = F.create_dynamic_map(signed=True) + self.name2qmap["udynamic"] = F.create_dynamic_map(signed=False) + + def __setstate__(self, state): + super().__setstate__(state) + + def load_state_dict(self, state_dict): + r"""Loads the optimizer state. + + Args: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = deepcopy(state_dict) + # Validate the state_dict + groups = self.param_groups + saved_groups = state_dict["param_groups"] + + if len(groups) != len(saved_groups): + raise ValueError( + "loaded state dict has a different number of " + "parameter groups" + ) + param_lens = (len(g["params"]) for g in groups) + saved_lens = (len(g["params"]) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError( + "loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group" + ) + + # Update the state + id_map = { + old_id: p + for old_id, p in zip( + chain.from_iterable(g["params"] for g in saved_groups), + chain.from_iterable(g["params"] for g in groups), + ) + } + + def cast(param, value): + r"""Make a deep copy of value, casting all tensors to device of param.""" + if isinstance(value, torch.Tensor): + # Floating-point types are a bit special here. They are the only ones + # that are assumed to always match the type of params. + if param.is_floating_point() and value.dtype != torch.uint8: + value = value.to(param.dtype) + return value + elif isinstance(value, dict): + for k, v in value.items(): + if k in self.non_castable_tensor_keys: + value[k] = v.to(param.device) + else: + value[k] = cast(param, v) + + return value + elif isinstance(value, container_abcs.Iterable): + return type(value)(cast(param, v) for v in value) + else: + return value + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). + state = defaultdict(dict) + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + state[param] = cast(param, v) + else: + state[k] = v + + # Update parameter groups, setting their 'params' value + def update_group(group, new_group): + new_group["params"] = group["params"] + return new_group + + param_groups = [ + update_group(g, ng) for g, ng in zip(groups, saved_groups) + ] + self.__setstate__({"state": state, "param_groups": param_groups}) + + def to_gpu(self): + for gindex, group in enumerate(self.param_groups): + for pindex, p in enumerate(group["params"]): + if p in self.state: + values = self.state[p] + for k, v in values.items(): + if isinstance(v, torch.Tensor): + is_paged = getattr(v, 'is_paged', False) + if not is_paged: + self.state[p][k] = v.to(p.device) + + def check_overrides(self): + for module, attr, config in self.mng.module_weight_config_triple: + pmodule = getattr(module, attr) + assert pmodule is not None + assert isinstance(pmodule, torch.Tensor) or isinstance( + pmodule, torch.Parameter + ) + found = False + for gindex, group in enumerate(self.param_groups): + if found: + break + for pindex, p in enumerate(group["params"]): + if found: + break + if id(p) == id(pmodule): + # found the matching parameter + # init override + self.mng.pid2config[id(p)] = config + self.mng.index2config[ + (gindex, pindex) + ] = self.mng.pid2config[id(p)] + found = True + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + overflows = [] + + if not self.initialized: + self.check_overrides() + self.to_gpu() # needed for fairseq pure fp16 training + self.initialized = True + + #if self.is_paged: self.page_mng.prefetch_all() + for gindex, group in enumerate(self.param_groups): + for pindex, p in enumerate(group["params"]): + if p.grad is None: + continue + state = self.state[p] + if len(state) == 0: + self.init_state(group, p, gindex, pindex) + + self.prefetch_state(p) + self.update_step(group, p, gindex, pindex) + torch.cuda.synchronize() + if self.is_paged: + # all paged operation are asynchronous, we need + # to sync to make sure all tensors are in the right state + torch.cuda.synchronize() + + + return loss + + def get_config(self, gindex, pindex, group): + config = {} + config["betas"] = group["betas"] + config["eps"] = group["eps"] + config["weight_decay"] = group["weight_decay"] + config["lr"] = group["lr"] + config["optim_bits"] = self.args.optim_bits + config["min_8bit_size"] = self.args.min_8bit_size + config["percentile_clipping"] = self.args.percentile_clipping + config["block_wise"] = self.args.block_wise + config["max_unorm"] = self.args.max_unorm + config["skip_zeros"] = self.args.skip_zeros + + if (gindex, pindex) in self.mng.index2config: + config.update(self.mng.index2config[(gindex, pindex)]) + return config + + def init_state(self, group, p, gindex, pindex): + raise NotImplementedError("init_state method needs to be overridden") + + def update_step(self, group, p, gindex, pindex): + raise NotImplementedError( + "The update_step method needs to be overridden" + ) + + def get_state_buffer(self, p, dtype=torch.float32): + if not self.is_paged or p.numel() < 1e5: + return torch.zeros_like(p, dtype=dtype, device=p.device) + else: + # > 1 MB + buff = F.get_paged(*p.shape, dtype=dtype, device=p.device) + F.fill(buff, 0) + self.page_mng.paged_tensors.append(buff) + return buff + + def prefetch_state(self, p): + if self.is_paged: + state = self.state[p] + s1 = state['state1'] + is_paged = getattr(s1, 'is_paged', False) + if is_paged: + F.prefetch_tensor(state['state1']) + if 'state2' in state: + F.prefetch_tensor(state['state2']) + + +class Optimizer2State(Optimizer8bit): + def __init__( + self, + optimizer_name, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0.0, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + max_unorm=0.0, + skip_zeros=False, + is_paged=False + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if isinstance(betas, str): + # format: '(beta1, beta2)' + betas = betas.replace("(", "").replace(")", "").strip().split(",") + betas = [float(b) for b in betas] + for i in range(len(betas)): + if not 0.0 <= betas[i] < 1.0: + raise ValueError( + f"Invalid beta parameter at index {i}: {betas[i]}" + ) + if not 0.0 <= weight_decay: + raise ValueError( + f"Invalid weight_decay value: {weight_decay}" + ) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super().__init__(params, defaults, optim_bits, is_paged) + + if args is None: + args = {} + args["optim_bits"] = optim_bits + args["percentile_clipping"] = 100 + args["min_8bit_size"] = min_8bit_size + args["percentile_clipping"] = percentile_clipping + args["block_wise"] = block_wise + args["max_unorm"] = max_unorm + args["skip_zeros"] = skip_zeros + + self.args = MockArgs(args) + else: + self.args = args + + self.optimizer_name = optimizer_name + + @torch.no_grad() + def init_state(self, group, p, gindex, pindex): + config = self.get_config(gindex, pindex, group) + + if config["optim_bits"] == 32: + dtype = torch.float32 + elif config["optim_bits"] == 8: + dtype = torch.uint8 + else: + raise NotImplementedError( + f'Amount of optimizer bits not supported: {config["optim_bits"]}' + ) + + if p.numel() < config["min_8bit_size"]: + dtype = torch.float32 + + state = self.state[p] + state["step"] = 0 + + if dtype == torch.float32 or ( + dtype == torch.uint8 and p.numel() < 4096 + ): + state["state1"] = self.get_state_buffer(p, dtype=torch.float32) + state["state2"] = self.get_state_buffer(p, dtype=torch.float32) + elif dtype == torch.uint8: + if state["step"] == 0: + if "dynamic" not in self.name2qmap: + self.fill_qmap() + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( + p.device + ) + self.name2qmap["udynamic"] = self.name2qmap["udynamic"].to( + p.device + ) + + state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) + state["qmap1"] = self.name2qmap["dynamic"] + + state["state2"] = self.get_state_buffer(p, dtype=torch.uint8) + state["qmap2"] = self.name2qmap["udynamic"] + + if config["block_wise"]: + n = p.numel() + blocks = n // 2048 + blocks += 1 if n % 2048 > 0 else 0 + + state["absmax1"] = torch.zeros( + (blocks,), dtype=torch.float32, device=p.device + ) + state["absmax2"] = torch.zeros( + (blocks,), dtype=torch.float32, device=p.device + ) + else: + state["max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) + state["new_max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) + state["max2"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) + state["new_max2"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) + + if config["percentile_clipping"] < 100: + state["gnorm_vec"] = torch.zeros((100,), device=p.device) + + if config["max_unorm"] > 0.0: + state["unorm_vec"] = torch.zeros((1,), device=p.device) + + @torch.no_grad() + def update_step(self, group, p, gindex, pindex): + state = self.state[p] + grad = p.grad + + config = self.get_config(gindex, pindex, group) + + state["step"] += 1 + step = state["step"] + + if config["percentile_clipping"] < 100: + current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( + grad, state["gnorm_vec"], step, config["percentile_clipping"] + ) + else: + gnorm_scale = 1.0 + + if state["state1"].dtype == torch.float: + F.optimizer_update_32bit( + self.optimizer_name, + grad, + p, + state["state1"], + config["betas"][0], + config["eps"], + step, + config["lr"], + state["state2"], + config["betas"][1], + config["weight_decay"], + gnorm_scale, + state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + max_unorm=config["max_unorm"], + skip_zeros=config["skip_zeros"], + ) + + elif state["state1"].dtype == torch.uint8 and not config["block_wise"]: + F.optimizer_update_8bit( + self.optimizer_name, + grad, + p, + state["state1"], + state["state2"], + config["betas"][0], + config["betas"][1], + config["eps"], + step, + config["lr"], + state["qmap1"], + state["qmap2"], + state["max1"], + state["max2"], + state["new_max1"], + state["new_max2"], + config["weight_decay"], + gnorm_scale=gnorm_scale, + unorm_vec=state["unorm_vec"] + if config["max_unorm"] > 0.0 + else None, + max_unorm=config["max_unorm"], + ) + + # swap maxes + state["max1"], state["new_max1"] = state["new_max1"], state["max1"] + state["max2"], state["new_max2"] = state["new_max2"], state["max2"] + elif state["state1"].dtype == torch.uint8 and config["block_wise"]: + F.optimizer_update_8bit_blockwise( + self.optimizer_name, + grad, + p, + state["state1"], + state["state2"], + config["betas"][0], + config["betas"][1], + config["eps"], + step, + config["lr"], + state["qmap1"], + state["qmap2"], + state["absmax1"], + state["absmax2"], + config["weight_decay"], + gnorm_scale=gnorm_scale, + skip_zeros=config["skip_zeros"], + ) + + +class Optimizer1State(Optimizer8bit): + def __init__( + self, + optimizer_name, + params, + lr=1e-3, + betas=(0.9, 0.0), + eps=1e-8, + weight_decay=0.0, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + max_unorm=0.0, + skip_zeros=False, + is_paged=False + ): + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + for i in range(len(betas)): + if not 0.0 <= betas[i] < 1.0: + raise ValueError( + f"Invalid beta parameter at index {i}: {betas[i]}" + ) + if not 0.0 <= weight_decay: + raise ValueError( + f"Invalid weight_decay value: {weight_decay}" + ) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super().__init__(params, defaults, optim_bits, is_paged) + + if args is None: + args = {} + args["optim_bits"] = optim_bits + args["percentile_clipping"] = 100 + args["min_8bit_size"] = min_8bit_size + args["percentile_clipping"] = percentile_clipping + args["block_wise"] = block_wise + args["max_unorm"] = max_unorm + args["skip_zeros"] = skip_zeros + + self.args = MockArgs(args) + else: + self.args = args + + self.optimizer_name = optimizer_name + + @torch.no_grad() + def init_state(self, group, p, gindex, pindex): + config = self.get_config(gindex, pindex, group) + + if config["optim_bits"] == 32: + dtype = torch.float32 + elif config["optim_bits"] == 8: + dtype = torch.uint8 + else: + raise NotImplementedError( + f'Amount of optimizer bits not supported: {config["optim_bits"]}' + ) + + if p.numel() < config["min_8bit_size"]: + dtype = torch.float32 + + state = self.state[p] + state["step"] = 0 + + if dtype == torch.float32 or ( + dtype == torch.uint8 and p.numel() < 4096 + ): + state["state1"] = self.get_state_buffer(p, dtype=torch.float32) + elif dtype == torch.uint8: + if state["step"] == 0: + if "dynamic" not in self.name2qmap: + self.fill_qmap() + self.name2qmap["dynamic"] = self.name2qmap["dynamic"].to( + p.device + ) + + state["state1"] = self.get_state_buffer(p, dtype=torch.uint8) + state["qmap1"] = self.name2qmap["dynamic"] + + if config["block_wise"]: + n = p.numel() + blocks = n // 2048 + blocks += 1 if n % 2048 > 0 else 0 + + state["absmax1"] = torch.zeros( + (blocks,), dtype=torch.float32, device=p.device + ) + else: + state["max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) + state["new_max1"] = torch.zeros( + (1,), dtype=torch.float32, device=p.device + ) + + if config["percentile_clipping"] < 100: + state["gnorm_vec"] = torch.zeros((100,), device=p.device) + + if config["max_unorm"] > 0.0: + state["unorm_vec"] = torch.zeros((1,), device=p.device) + + @torch.no_grad() + def update_step(self, group, p, gindex, pindex): + state = self.state[p] + grad = p.grad + + config = self.get_config(gindex, pindex, group) + + state["step"] += 1 + step = state["step"] + + if config["percentile_clipping"] < 100: + current_gnorm, clip_value, gnorm_scale = F.percentile_clipping( + grad, state["gnorm_vec"], step, config["percentile_clipping"] + ) + else: + gnorm_scale = 1.0 + + if state["state1"].dtype == torch.float: + F.optimizer_update_32bit( + self.optimizer_name, + grad, + p, + state["state1"], + config["betas"][0], + config["eps"], + step, + config["lr"], + None, + config['betas'][1], + config["weight_decay"], + gnorm_scale, + state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + max_unorm=config["max_unorm"], + skip_zeros=config["skip_zeros"], + ) + + elif state["state1"].dtype == torch.uint8 and not config["block_wise"]: + F.optimizer_update_8bit( + self.optimizer_name, + grad, + p, + state["state1"], + None, + config["betas"][0], + config["betas"][1], + config["eps"], + step, + config["lr"], + state["qmap1"], + None, + state["max1"], + None, + state["new_max1"], + None, + config["weight_decay"], + gnorm_scale, + state["unorm_vec"] if config["max_unorm"] > 0.0 else None, + max_unorm=config["max_unorm"], + ) + + state["max1"], state["new_max1"] = state["new_max1"], state["max1"] + elif state["state1"].dtype == torch.uint8 and config["block_wise"]: + F.optimizer_update_8bit_blockwise( + self.optimizer_name, + grad, + p, + state["state1"], + None, + config["betas"][0], + config["betas"][1], + config["eps"], + step, + config["lr"], + state["qmap1"], + None, + state["absmax1"], + None, + config["weight_decay"], + gnorm_scale=gnorm_scale, + skip_zeros=config["skip_zeros"], + ) diff --git a/bitsandbytes/optim/rmsprop.py b/bitsandbytes/optim/rmsprop.py new file mode 100644 index 000000000..2853ca723 --- /dev/null +++ b/bitsandbytes/optim/rmsprop.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from bitsandbytes.optim.optimizer import Optimizer1State + + +class RMSprop(Optimizer1State): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + if alpha == 0: + raise NotImplementedError( + "RMSprop with alpha==0.0 is not supported!" + ) + if centered: + raise NotImplementedError("Centered RMSprop is not supported!") + super().__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + + +class RMSprop8bit(Optimizer1State): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + if alpha == 0: + raise NotImplementedError( + "RMSprop with alpha==0.0 is not supported!" + ) + if centered: + raise NotImplementedError("Centered RMSprop is not supported!") + super().__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + + +class RMSprop32bit(Optimizer1State): + def __init__( + self, + params, + lr=1e-2, + alpha=0.99, + eps=1e-8, + weight_decay=0, + momentum=0, + centered=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + + if alpha == 0: + raise NotImplementedError( + "RMSprop with alpha==0.0 is not supported!" + ) + if centered: + raise NotImplementedError("Centered RMSprop is not supported!") + super().__init__( + "rmsprop", + params, + lr, + (alpha, momentum), + eps, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) diff --git a/bitsandbytes/optim/sgd.py b/bitsandbytes/optim/sgd.py new file mode 100644 index 000000000..3c0fc2b9f --- /dev/null +++ b/bitsandbytes/optim/sgd.py @@ -0,0 +1,99 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from bitsandbytes.optim.optimizer import Optimizer1State + + +class SGD(Optimizer1State): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + optim_bits=32, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + if momentum == 0: + raise NotImplementedError("SGD without momentum is not supported!") + super().__init__( + "momentum", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + optim_bits, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + + +class SGD8bit(Optimizer1State): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + if momentum == 0: + raise NotImplementedError("SGD without momentum is not supported!") + super().__init__( + "momentum", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + 8, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) + + +class SGD32bit(Optimizer1State): + def __init__( + self, + params, + lr, + momentum=0, + dampening=0, + weight_decay=0, + nesterov=False, + args=None, + min_8bit_size=4096, + percentile_clipping=100, + block_wise=True, + ): + if momentum == 0: + raise NotImplementedError("SGD without momentum is not supported!") + super().__init__( + "momentum", + params, + lr, + (momentum, dampening), + 0.0, + weight_decay, + 32, + args, + min_8bit_size, + percentile_clipping, + block_wise, + ) diff --git a/bitsandbytes/research/__init__.py b/bitsandbytes/research/__init__.py new file mode 100644 index 000000000..47b720d78 --- /dev/null +++ b/bitsandbytes/research/__init__.py @@ -0,0 +1,6 @@ +from . import nn +from .autograd._functions import ( + switchback_bnb, + matmul_fp8_global, + matmul_fp8_mixed, +) diff --git a/bitsandbytes/research/autograd/__init__.py b/bitsandbytes/research/autograd/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py new file mode 100644 index 000000000..0dff351e0 --- /dev/null +++ b/bitsandbytes/research/autograd/_functions.py @@ -0,0 +1,411 @@ +import operator +import warnings +from dataclasses import dataclass +from functools import reduce # Required in Python 3 + +import torch + +import bitsandbytes.functional as F + +from bitsandbytes.autograd._functions import MatmulLtState, GlobalOutlierPooler + + +# math.prod not compatible with python < 3.8 +def prod(iterable): + return reduce(operator.mul, iterable, 1) + +tensor = torch.Tensor + +class MatMulFP8Mixed(torch.autograd.Function): + # forward is the same, but we added the fallback for pre-turing GPUs + # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") + + @staticmethod + def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024): + # default of pytorch behavior if inputs are empty + ctx.is_empty = False + if prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + + B_shape = B.shape + if A.shape[-1] == B_shape[0]: + return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) + else: + return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) + + # 1. Dequantize + # 2. MatmulnN + cA, state = F.quantize_blockwise(A, code=fw_code, blocksize=bsz) + fp8A = F.dequantize_blockwise(cA, state, blocksize=bsz).to(A.dtype) + + cB, state = F.quantize(B.float(), code=fw_code) + fp8B = F.dequantize(cB, state).to(B.dtype) + + output = torch.matmul(fp8A, fp8B) + + # output is half + + # 3. Save state + ctx.fw_code = fw_code + ctx.bw_code = bw_code + ctx.bsz = bsz + ctx.bsz2 = bsz2 + ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype + + if any(ctx.needs_input_grad[:2]): + # NOTE: we send back A, and re-quant. + ctx.tensors = (A, fp8B) + else: + ctx.tensors = (None, None) + + return output + + @staticmethod + def backward(ctx, grad_output): + if ctx.is_empty: + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None + + req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad + A, B = ctx.tensors + + grad_A, grad_B = None, None + + # TODO: Fix blocksize to be output_dim + cgrad_out, state = F.quantize_blockwise(grad_output, code=ctx.bw_code, blocksize=ctx.bsz2) + fp8out = F.dequantize_blockwise(cgrad_out, state, blocksize=ctx.bsz2).to(grad_output.dtype) + + # cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code) + # fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype) + + # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() + # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector') + # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose + # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2]) + + # not supported by PyTorch. TODO: create work-around + if req_gradA: + grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype) + + if req_gradB: + if len(A.shape) == 3: + At = A.transpose(2, 1).contiguous() + else: + At = A.transpose(1, 0).contiguous() + # cA, state = F.quantize(At.float(), code=ctx.fw_code) + # fp8At = F.dequantize(cA, state).to(A.dtype) + grad_B = torch.matmul(At.to(grad_output.dtype), grad_output).to(B.dtype) + + return grad_A, grad_B, None, None, None, None, None + + +class MatMulFP8Global(torch.autograd.Function): + # forward is the same, but we added the fallback for pre-turing GPUs + # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") + + @staticmethod + def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024): + # default of pytorch behavior if inputs are empty + ctx.is_empty = False + if prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + + B_shape = B.shape + if A.shape[-1] == B_shape[0]: + return torch.empty(A.shape[:-1] + B_shape[1:], dtype=A.dtype, device=A.device) + else: + return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device) + + # 1. Dequantize + # 2. MatmulnN + cA, state = F.quantize(A.float(), code=fw_code) + fp8A = F.dequantize(cA, state).to(A.dtype) + + cB, state = F.quantize(B.float(), code=fw_code) + fp8B = F.dequantize(cB, state).to(B.dtype) + + output = torch.matmul(fp8A, fp8B) + + # output is half + + # 3. Save state + ctx.fw_code = fw_code + ctx.bw_code = bw_code + ctx.bsz = bsz + ctx.bsz2 = bsz2 + ctx.dtype_A, ctx.dtype_B = A.dtype, B.dtype + + if any(ctx.needs_input_grad[:2]): + # NOTE: we send back A, and re-quant. + ctx.tensors = (A, fp8B) + else: + ctx.tensors = (None, None) + + return output + + @staticmethod + def backward(ctx, grad_output): + if ctx.is_empty: + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None, None, None, None + + req_gradA, req_gradB, _, _, _, _, _ = ctx.needs_input_grad + A, B = ctx.tensors + + grad_A, grad_B = None, None + + # TODO: Fix blocksize to be output_dim + cgrad_out, state = F.quantize(grad_output.float(), code=ctx.bw_code) + fp8out = F.dequantize(cgrad_out, state).to(grad_output.dtype) + + # cgrad_output_2, state_2 = F.quantize(grad_output.float(), code=ctx.bw_code) + # fp8out_2 = F.dequantize(cgrad_output_2, state_2).to(grad_output.dtype) + + # grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() + # fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector') + # fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose + # fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2]) + + # not supported by PyTorch. TODO: create work-around + if req_gradA: + grad_A = torch.matmul(fp8out, B.t().to(fp8out.dtype)).to(A.dtype) + + if req_gradB: + if len(A.shape) == 3: + At = A.transpose(2, 1).contiguous() + else: + At = A.transpose(1, 0).contiguous() + cA, state = F.quantize(At.float(), code=ctx.fw_code) + fp8At = F.dequantize(cA, state).to(A.dtype) + grad_B = torch.matmul(fp8At.to(fp8out.dtype), fp8out).to(B.dtype) + + return grad_A, grad_B, None, None, None, None, None + + +class SwitchBackBnb(torch.autograd.Function): + @staticmethod + def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): + # default to pytorch behavior if inputs are empty + ctx.is_empty = False + if prod(A.shape) == 0: + ctx.is_empty = True + ctx.A = A + ctx.B = B + ctx.bias = bias + if A.shape[-1] == B.shape[0]: + return torch.empty(A.shape[:-1]+B.shape[1:], dtype=A.dtype, device=A.device) + else: + return torch.empty(A.shape[:-1]+B.shape[:1], dtype=A.dtype, device=A.device) + + # 1. Quantize A + # 2. Quantize B + # 3. Matmul + # 4. Mixed-precision decomposition matmul + # 5. Save state + formatB = state.formatB + input_shape = A.shape + if state.outlier_pool is None: + state.outlier_pool = GlobalOutlierPooler.get_instance() + + # Cast A to fp16 + if A.dtype != torch.float16: + warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") + + # 1. Quantize A + if len(A.shape) == 3: + A = A.view(-1, A.shape[-1]).contiguous() + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant( + A.to(torch.float16), threshold=state.threshold + ) + + if state.threshold > 0.0 and coo_tensorA is not None: + if state.has_fp16_weights: + idx = torch.unique(coo_tensorA.colidx).long() + CA[:, idx] = 0 + CAt[:, idx] = 0 + subA = A[:, idx] + state.subB = B[:, idx].t().contiguous() + state.idx = idx + else: + if state.CxB is None: + # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions + # we also need to convert it to the turing/ampere format + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + else: + #print('A shape', A.shape) + if not state.has_fp16_weights and state.CxB is None: + state.CxB, state.SB = F.transform(state.CB, to_order=formatB) + subA = None + + # 2. Quantize B + if state.has_fp16_weights: + #print('B shape', B.shape) + has_grad = True if (getattr(B, "grad", None) is not None) else False + is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) + if is_transposed: + B = B.contiguous() + + if (state.is_training and not has_grad) or state.CxB is None: + state.reset_grads() + ( + CB, + state.CBt, + state.SCB, + state.SCBt, + coo_tensorB, + ) = F.double_quant(B.to(torch.float16)) + state.CxB, state.SB = F.transform(CB, to_order=formatB) + else: + has_grad = False + + if coo_tensorA is not None and not state.has_fp16_weights: + # extract outliers + + outlier_idx = torch.unique(coo_tensorA.colidx) + state.idx = outlier_idx + # state.outlier_pool.add_outliers(outlier_idx, A.shape[-1]) + # if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]: + # # do not use pool for 2nd FFN layer + # state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device) + # else: + # state.idx = outlier_idx + outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int()) + state.subB = ( + (outliers * state.SCB.view(-1, 1) / 127.0) + .t() + .contiguous() + .to(A.dtype) + ) + CA[:, state.idx.long()] = 0 + CAt[:, state.idx.long()] = 0 + subA = A[:, state.idx.long()] + + shapeB = state.SB[0] + + if len(input_shape) == 3: + output_shape = (input_shape[0], input_shape[1], shapeB[0]) + else: + output_shape = (input_shape[0], shapeB[0]) + + # 3. Matmul + C32A, SA = F.transform(CA, "col32") + out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) + # we apply the fused bias here + + if bias is None or bias.dtype == torch.float16: + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) + output = output.to(A.dtype) + else: # apply bias separately + output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) + output = output.to(A.dtype).add_(bias) + + # 4. Mixed-precision decomposition matmul + if coo_tensorA is not None and subA is not None: + output += torch.matmul(subA, state.subB) + + # 5. Save state + ctx.state = state + + ctx.formatB = formatB + ctx.grad_shape = input_shape + ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype + + if any(ctx.needs_input_grad[:2]): + ctx.tensors = (CAt, subA, A) + ctx.tensor_states = (SCAt, state.idx) + else: + ctx.tensors = [None, None, None] + ctx.tensor_states = (None, None) + ctx.save_for_backward(None, None) + + + clone_func = torch.clone if len(output_shape) == 3 else lambda x : x + return clone_func(output.view(output_shape)) + + @staticmethod + def backward(ctx, grad_output): + if ctx.is_empty: + bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias)) + return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None + req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad + CAt, subA, A = ctx.tensors + SCAt, idx = ctx.tensor_states + formatB = ctx.formatB + state = ctx.state + grad_A = grad_B = grad_bias = None + + if req_gradBias: + # compute grad_bias first before changing grad_output dtype + grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) + + # Cast grad_output to fp16 + if len(grad_output.shape) == 3: + grad_output = grad_output.reshape( + -1, grad_output.shape[-1] + ).contiguous() + + Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) + + if req_gradB: + # print('back A shape', A.shape) + # print('grad output t shape', grad_output.t().shape) + grad_B = torch.matmul(grad_output.t(), A) + + if req_gradA: + if state.CBt is not None: + C32grad, Sgrad = F.transform(Cgrad, "col32") + if state.CxBt is None: + state.CxBt, state.SBt = F.transform( + state.CBt, to_order=formatB, transpose=True + ) + # print('back B shape', state.CxBt.shape) + # print('back grad shape', C32grad.shape) + gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt) + grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A) + + elif state.CB is not None: + CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1. / 127.0)) + grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) + else: + raise Exception('State must contain either CBt or CB matrix for backward') + + return grad_A, grad_B, None, grad_bias, None + +def get_block_sizes(input_matrix, weight_matrix): + input_features = input_matrix.shape[-1] + output_features = (weight_matrix.shape[0] if weight_matrix.shape[1] == input_features else weight_matrix.shape[1]) + array = [4096, 2048, 1024, 512, 256, 128, 64, 0] + bsz, bsz2 = 1024, 1024 + for i, k in enumerate(array): + if input_features > array[i + 1]: + bsz = k + break + for i, k in enumerate(array): + if output_features > array[i + 1]: + bsz2 = k + break + + return bsz, bsz2 + +def matmul_fp8_global(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1): + if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) + return MatMulFP8Global.apply(A, B, out, fw_code, bw_code, bsz, bsz2) + +def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out: tensor = None, bsz : int = -1, bsz2 : int = -1): + if bsz == -1 or bsz2 == -1: bsz, bsz2 = get_block_sizes(A, B) + return MatMulFP8Mixed.apply(A, B, out, fw_code, bw_code, bsz, bsz2) + +def switchback_bnb( + A: tensor, + B: tensor, + out: tensor = None, + state: MatmulLtState = None, + threshold=0.0, + bias=None +): + state = state or MatmulLtState() + if threshold > 0.0: + state.threshold = threshold + return SwitchBackBnb.apply(A, B, out, bias, state) diff --git a/bitsandbytes/research/nn/__init__.py b/bitsandbytes/research/nn/__init__.py new file mode 100644 index 000000000..8faec10bb --- /dev/null +++ b/bitsandbytes/research/nn/__init__.py @@ -0,0 +1 @@ +from .modules import LinearFP8Mixed, LinearFP8Global diff --git a/bitsandbytes/research/nn/modules.py b/bitsandbytes/research/nn/modules.py new file mode 100644 index 000000000..2a46b40c3 --- /dev/null +++ b/bitsandbytes/research/nn/modules.py @@ -0,0 +1,64 @@ +from typing import Optional, TypeVar, Union, overload + +import torch +import torch.nn.functional as F +from torch import Tensor, device, dtype, nn + +import bitsandbytes as bnb +from bitsandbytes.optim import GlobalOptimManager +from bitsandbytes.utils import OutlierTracer, find_outlier_dims + +T = TypeVar("T", bound="torch.nn.Module") + + +class LinearFP8Mixed(nn.Linear): + def __init__(self, input_features, output_features, bias=True): + super().__init__(input_features, output_features, bias) + self.bw_code = None + self.fw_code = None + array = [4096, 2048, 1024, 512, 256, 128, 64, 0] + for i, k in enumerate(array): + if input_features > array[i + 1]: + self.bsz = k + break + for i, k in enumerate(array): + if output_features > array[i + 1]: + self.bsz2 = k + break + + def forward(self, x: torch.Tensor): + if self.fw_code is None: + self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) + self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) + + out = bnb.research.matmul_fp8_mixed(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) + if self.bias is not None: + out += self.bias + + return out + +class LinearFP8Global(nn.Linear): + def __init__(self, input_features, output_features, bias=True): + super().__init__(input_features, output_features, bias) + self.bw_code = None + self.fw_code = None + array = [4096, 2048, 1024, 512, 256, 128, 64, 0] + for i, k in enumerate(array): + if input_features > array[i + 1]: + self.bsz = k + break + for i, k in enumerate(array): + if output_features > array[i + 1]: + self.bsz2 = k + break + + def forward(self, x: torch.Tensor): + if self.fw_code is None: + self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device) + self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device) + + out = bnb.matmul_fp8_global(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code, bsz=self.bsz, bsz2=self.bsz2) + if self.bias is not None: + out += self.bias + + return out diff --git a/bitsandbytes/triton/__init__.py b/bitsandbytes/triton/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/triton/dequantize_rowwise.py b/bitsandbytes/triton/dequantize_rowwise.py new file mode 100644 index 000000000..e092680b8 --- /dev/null +++ b/bitsandbytes/triton/dequantize_rowwise.py @@ -0,0 +1,64 @@ +import math +import torch +import time +from bitsandbytes.triton.triton_utils import is_triton_available + +if not is_triton_available(): + def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None +else: + + import triton + import triton.language as tl + from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + # rowwise quantize + + # TODO: autotune this better. + @triton.autotune( + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] + ) + @triton.jit + def _dequantize_rowwise( + x_ptr, + state_x, + output_ptr, + inv_127, + n_elements, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + arange = tl.arange(0, P2) + offsets = block_start + arange + row_mask = arange < BLOCK_SIZE + x = tl.load(x_ptr + offsets, mask=row_mask) + max_val = tl.load(state_x + pid) + output = max_val * x * inv_127 + tl.store(output_ptr + offsets, output, mask=row_mask) + + + def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): + output = torch.empty(*x.shape, device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (x.shape[0],) + _dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) + return output diff --git a/bitsandbytes/triton/int8_matmul_mixed_dequanitze.py b/bitsandbytes/triton/int8_matmul_mixed_dequanitze.py new file mode 100644 index 000000000..60a56e698 --- /dev/null +++ b/bitsandbytes/triton/int8_matmul_mixed_dequanitze.py @@ -0,0 +1,163 @@ +import torch +from bitsandbytes.triton.triton_utils import is_triton_available + +if not is_triton_available(): + def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None +else: + + import triton + import triton.language as tl + from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + + # This is a matmul kernel based on triton.ops.matmul + # It is modified to support rowwise quantized input and global quantized weight + # It's purpose is fused matmul then dequantize + # It does support bias. + + def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + + @triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 + }, + ) + @triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, + }) + @triton.jit + def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias : tl.constexpr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + w_factor = tl.load(state_w_ptr) + x_factor = tl.load(state_x_ptr + ram)[:, None] + + # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + acc = (w_factor * (x_factor * (acc * divfactor))) + acc = acc.to(C.dtype.element_ty) + + # conditionally add bias + if has_bias: + bias = tl.load(bias + rn).to(C.dtype.element_ty) + acc = acc + bias[None, :] + + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + + def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): + device = a.device + divfactor = 1. / (127. * 127.) + has_bias = 0 if bias is None else 1 + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c = torch.empty((M, N), device=device, dtype=torch.float16) + # accumulator types + ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # launch int8_matmul_mixed_dequantize kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + GROUP_M=8, ACC_TYPE=ACC_TYPE) + return c diff --git a/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py new file mode 100644 index 000000000..33f4d13f2 --- /dev/null +++ b/bitsandbytes/triton/int8_matmul_rowwise_dequantize.py @@ -0,0 +1,164 @@ +import torch + +from bitsandbytes.triton.triton_utils import is_triton_available + +if not is_triton_available(): + def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None +else: + import triton + import triton.language as tl + from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + # This is a matmul kernel based on triton.ops.matmul + # It is modified to support rowwise quantized input and columnwise quantized weight + # It's purpose is fused matmul then dequantize + # It does support bias. + + def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + + def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + + + @triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 + }, + ) + @triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, + }) + @triton.jit + def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias : tl.constexpr, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + w_factor = tl.load(state_w_ptr + rbn)[None, :] + x_factor = tl.load(state_x_ptr + ram)[:, None] + + # acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) + acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + + acc = (w_factor * (x_factor * (acc * divfactor))) + acc = acc.to(C.dtype.element_ty) + + if has_bias: + bias = tl.load(bias + rn).to(C.dtype.element_ty) + acc = acc + bias[None, :] + + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + + def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): + divfactor = 1. / (127. * 127.) + + has_bias = 0 if bias is None else 1 + + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c = torch.empty((M, N), device=device, dtype=torch.float16) + # accumulator types + ACC_TYPE = tl.float32 #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + # launch int8_matmul_rowwise_dequantize kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + GROUP_M=8, ACC_TYPE=ACC_TYPE) + return c diff --git a/bitsandbytes/triton/quantize_columnwise_and_transpose.py b/bitsandbytes/triton/quantize_columnwise_and_transpose.py new file mode 100644 index 000000000..54220d95a --- /dev/null +++ b/bitsandbytes/triton/quantize_columnwise_and_transpose.py @@ -0,0 +1,74 @@ +import math +import torch +import time +from bitsandbytes.triton.triton_utils import is_triton_available + +if not is_triton_available(): + def quantize_columnwise_and_transpose(x: torch.Tensor): return None +else: + + import triton + import triton.language as tl + from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + # This kernel does fused columnwise quantization and transpose. + + # TODO: autotune this better. + @triton.autotune( + configs=[ + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_stages=16), + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=16, num_warps=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] + ) + @triton.jit + def _quantize_columnwise_and_transpose( + x_ptr, + output_ptr, + output_maxs, + n_elements, + M : tl.constexpr, N : tl.constexpr, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid + p2_arange = tl.arange(0, P2) + p2_arange_mask = p2_arange < M + arange = p2_arange * N + offsets = block_start + arange + x = tl.load(x_ptr + offsets, mask=p2_arange_mask) + abs_x = tl.abs(x) + max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0) + output = tl.libdevice.llrint(127. * (x / max_val)) + + new_start = pid * M + new_offsets = new_start + p2_arange + tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask) + tl.store(output_maxs + pid, max_val) + + def quantize_columnwise_and_transpose(x: torch.Tensor): + M, N = x.shape + output = torch.empty(N, M, device=x.device, dtype=torch.int8) + output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(M)))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2) + return output, output_maxs + diff --git a/bitsandbytes/triton/quantize_global.py b/bitsandbytes/triton/quantize_global.py new file mode 100644 index 000000000..845db6ecd --- /dev/null +++ b/bitsandbytes/triton/quantize_global.py @@ -0,0 +1,107 @@ +import math +import torch +import time +from bitsandbytes.triton.triton_utils import is_triton_available + +if not is_triton_available(): + def quantize_global_transpose(input): return None + def quantize_global(x: torch.Tensor): return None +else: + + import triton + import triton.language as tl + from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + # global quantize + @triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4), + triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1), + + ], + key=['n_elements'] + ) + @triton.jit + def _quantize_global( + x_ptr, + absmax_inv_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + absmax_inv = tl.load(absmax_inv_ptr) + output = tl.libdevice.llrint(127. * (x * absmax_inv)) + tl.store(output_ptr + offsets, output, mask=mask) + + def quantize_global(x: torch.Tensor): + absmax = x.abs().max().unsqueeze(0) + absmax_inv = 1./ absmax + output = torch.empty(*x.shape, device='cuda', dtype=torch.int8) + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + _quantize_global[grid](x, absmax_inv, output, n_elements) + return output, absmax + + + # global quantize and transpose + @triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4), + + # ... + ], + key=['M', 'N'] + ) + @triton.jit + def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N, + BLOCK_M : tl.constexpr, + BLOCK_N : tl.constexpr, + GROUP_M : tl.constexpr): + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // group_size + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an) + mask = (rm < M)[:, None] & (rn < N)[None, :] + a = tl.load(A, mask=mask) + absmax_inv = tl.load(absmax_inv_ptr) + + # rematerialize to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + + output = tl.libdevice.llrint(127. * (a * absmax_inv)) + + tl.store(B, output, mask=mask) + + def quantize_global_transpose(input): + absmax = input.abs().max().unsqueeze(0) + absmax_inv = 1./ absmax + M, N = input.shape + out = torch.empty(N, M, device='cuda', dtype=torch.int8) + + assert out.size(0) == N and out.size(1) == M + assert input.stride(0) == 1 or input.stride(1) == 1 + assert out.stride(0) == 1 or out.stride(1) == 1 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) + _quantize_global_transpose[grid](input, absmax_inv, out, input.stride(0), input.stride(1), out.stride(0), out.stride(1), M, N) + return out, absmax + diff --git a/bitsandbytes/triton/quantize_rowwise.py b/bitsandbytes/triton/quantize_rowwise.py new file mode 100644 index 000000000..26d218321 --- /dev/null +++ b/bitsandbytes/triton/quantize_rowwise.py @@ -0,0 +1,68 @@ +import math +import torch +import time + +from bitsandbytes.triton.triton_utils import is_triton_available + +if not is_triton_available(): + def quantize_rowwise(x: torch.Tensor): return None +else: + + import triton + import triton.language as tl + from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + # rowwise quantize + + # TODO: autotune this better. + @triton.autotune( + configs=[ + triton.Config({}, num_stages=1, num_warps=8), + triton.Config({}, num_stages=2, num_warps=8), + triton.Config({}, num_stages=4, num_warps=8), + triton.Config({}, num_stages=8, num_warps=8), + triton.Config({}, num_stages=1), + triton.Config({}, num_stages=2), + triton.Config({}, num_stages=4), + triton.Config({}, num_stages=8), + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['n_elements'] + ) + @triton.jit + def _quantize_rowwise( + x_ptr, + output_ptr, + output_maxs, + n_elements, + BLOCK_SIZE: tl.constexpr, + P2: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + arange = tl.arange(0, P2) + offsets = block_start + arange + row_mask = arange < BLOCK_SIZE + x = tl.load(x_ptr + offsets, mask=row_mask) + + abs_x = tl.abs(x) + max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) + output = tl.libdevice.llrint(127. * (x / max_val)) + tl.store(output_ptr + offsets, output, mask=row_mask) + tl.store(output_maxs + pid, max_val) + + def quantize_rowwise(x: torch.Tensor): + output = torch.empty(*x.shape, device=x.device, dtype=torch.int8) + output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16) + + P2 = int(2 ** (math.ceil(math.log2(x.shape[1])))) + + assert x.is_cuda and output.is_cuda + n_elements = output.numel() + grid = lambda meta: (x.shape[0],) + _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2) + return output, output_maxs + diff --git a/bitsandbytes/triton/triton_utils.py b/bitsandbytes/triton/triton_utils.py new file mode 100644 index 000000000..c74c23962 --- /dev/null +++ b/bitsandbytes/triton/triton_utils.py @@ -0,0 +1,4 @@ +import importlib + +def is_triton_available(): + return importlib.util.find_spec("triton") is not None diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py new file mode 100644 index 000000000..6729f7cd4 --- /dev/null +++ b/bitsandbytes/utils.py @@ -0,0 +1,199 @@ +import shlex +import subprocess +import torch +from typing import Tuple + +def outlier_hook(module, input): + assert isinstance(module, torch.nn.Linear) + tracer = OutlierTracer.get_instance() + hvalue = tracer.get_hvalue(module.weight) + if hvalue not in tracer.hvalue2outlier_idx: + outlier_idx = find_outlier_dims(module.weight) + tracer.outliers.append(outlier_idx) + tracer.hvalues.append(hvalue) + if len(tracer.outliers) > 1: + # assign the current layer the outlier idx found from the weight + # of the previous linear layer + if tracer.outliers[-1].numel() > 0: + assert tracer.outliers[-1].max() < module.weight.shape[1] + tracer.hvalue2outlier_idx[hvalue] = tracer.outliers[-1] + + else: + # first layer, we cannot use the weight for outlier detection + # we follow a mixed approach: + # (1) zscore test of std of hidden dimension + # (2) magnitude > 6 test + merged = input[0].view(-1, input[0].shape[-1]) + # (1) zscore test of std of hidden dimension + outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3) + # (2) magnitude > 6 test + dims = (torch.abs(input[0])> 6).sum(dim=list(range(len(input[0].shape)-1))) + outlier_idx2 = torch.where(dims > 0)[0] + outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique() + tracer.hvalue2outlier_idx[hvalue] = outlier_idx + else: + for hook in tracer.hooks: + hook.remove() + + +class OutlierTracer(object): + _instance = None + + def __init__(self): + raise RuntimeError("Call get_instance() instead") + + def initialize(self, model): + self.last_w = None + self.current_outlier_dims = None + self.hvalues = [] + self.outliers = [] + self.hvalue2outlier_idx = {} + self.initialized = True + self.hooks = [] + + for n, m in model.named_modules(): + if isinstance(m, torch.nn.Linear): + self.hooks.append(m.register_forward_pre_hook(outlier_hook)) + + def is_initialized(self): + return getattr(self, 'initialized', False) + + def get_hvalue(self, weight): + return weight.data.storage().data_ptr() + + def get_outliers(self, weight): + if not self.is_initialized(): + print('Outlier tracer is not initialized...') + return None + hvalue = self.get_hvalue(weight) + if hvalue in self.hvalue2outlier_idx: + return self.hvalue2outlier_idx[hvalue] + else: + return None + + @classmethod + def get_instance(cls): + if cls._instance is None: + cls._instance = cls.__new__(cls) + return cls._instance + +def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False): + if rdm: + return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long() + + m = weight.mean(reduction_dim) + mm = m.mean() + mstd = m.std() + zm = (m-mm)/mstd + + std = weight.std(reduction_dim) + stdm = std.mean() + stdstd = std.std() + + zstd = (std-stdm)/stdstd + + if topk is not None: + val, idx = torch.topk(std.abs(), k=topk, dim=0) + else: + idx = torch.where(zstd > zscore)[0] + + return idx + +def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None): + """ + Replace linear modules with a new Linear module. + + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + linear_replacement (`torch.nn.Module`): + The linear module that replaces the old one. Only expects standard arguments. + If other arguments need to be passed, use a lambda. + skip_modules (`List[str]`, *optional*, defaults to `lm_head`): + List of modules names not to convert. Defaults to `lm_head`. + copy_weights (`bool`): + Copy the weights from the old linear module to the new one + post_processing_fun_name (`str`): + A function name of the replacement linear class that is called + after processing. + """ + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function) + + if isinstance(module, torch.nn.Linear) and name not in skip_modules: + old_module = model._modules[name] + model._modules[name] = linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + ) + if copy_weights: + model._modules[name].weight = old_module.weight + model._modules[name].bias = old_module.bias + + if post_processing_function is not None: + func = getattr(module, post_processing_function, None) + if func is not None: func(module) + return model + + + +def execute_and_return(command_string: str) -> Tuple[str, str]: + def _decode(subprocess_err_out_tuple): + return tuple( + to_decode.decode("UTF-8").strip() + for to_decode in subprocess_err_out_tuple + ) + + def execute_and_return_decoded_std_streams(command_string): + return _decode( + subprocess.Popen( + shlex.split(command_string), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ).communicate() + ) + + std_out, std_err = execute_and_return_decoded_std_streams(command_string) + return std_out, std_err + + + +def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None): + """ + Replace linear modules with a new Linear module. + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + linear_replacement (`torch.nn.Module`): + The linear module that replaces the old one. Only expects standard arguments. + If other arguments need to be passed, use a lambda. + skip_modules (`List[str]`, *optional*, defaults to `lm_head`): + List of modules names not to convert. Defaults to `lm_head`. + copy_weights (`bool`): + Copy the weights from the old linear module to the new one + post_processing_fun_name (`str`): + A function name of the replacement linear class that is called + after processing. + """ + for name, module in model.named_children(): + if len(list(module.children())) > 0: + replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function) + + if isinstance(module, torch.nn.Linear) and name not in skip_modules: + old_module = model._modules[name] + model._modules[name] = linear_replacement( + module.in_features, + module.out_features, + module.bias is not None, + ) + if copy_weights: + model._modules[name].weight = old_module.weight + model._modules[name].bias = old_module.bias + + if post_processing_function is not None: + func = getattr(module, post_processing_function, None) + if func is not None: func(module) + return model + diff --git a/check_bnb_install.py b/check_bnb_install.py new file mode 100644 index 000000000..77cd03ec4 --- /dev/null +++ b/check_bnb_install.py @@ -0,0 +1,20 @@ +import bitsandbytes as bnb +import torch + +p = torch.nn.Parameter(torch.rand(10,10).cuda()) +a = torch.rand(10,10).cuda() + +p1 = p.data.sum().item() + +adam = bnb.optim.Adam([p]) + +out = a*p +loss = out.sum() +loss.backward() +adam.step() + +p2 = p.data.sum().item() + +assert p1 != p2 +print('SUCCESS!') +print('Installation was successful!') diff --git a/compile_from_source.md b/compile_from_source.md new file mode 100644 index 000000000..9d4f89da2 --- /dev/null +++ b/compile_from_source.md @@ -0,0 +1,35 @@ +# Compiling from source + +Basic steps. +1. `CUDA_VERSION=XXX make [target]` where `[target]` is among `cuda92, cuda10x, cuda110, cuda11x, cuda12x, cpuonly` +2. `python setup.py install` + +To run these steps you will need to have the nvcc compiler installed that comes with a CUDA installation. If you use anaconda (recommended) then you can figure out which version of CUDA you are using with PyTorch via the command `conda list | grep cudatoolkit`. Then you can install the nvcc compiler by downloading and installing the same CUDA version from the [CUDA toolkit archive](https://developer.nvidia.com/cuda-toolkit-archive). + +You can install CUDA locally without sudo by following the following steps: + +```bash +wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh +# Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH +# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121} +# EXPORT_TO_BASH in {0, 1} with 0=False and 1=True + +# For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc +bash cuda install 117 ~/local 1 +``` + +By default, the Makefile will look at your `CUDA_HOME` environmental variable to find your CUDA version for compiling the library. If this path is not set it is inferred from the path of your `nvcc` compiler. + +Either `nvcc` needs to be in path for the `CUDA_HOME` variable needs to be set to the CUDA directory root (e.g. `/usr/local/cuda`) in order for compilation to succeed + +If you type `nvcc` and it cannot be found, you might need to add to your path or set the CUDA_HOME variable. You can run `python -m bitsandbytes` to find the path to CUDA. For example if `python -m bitsandbytes` shows you the following: +``` +++++++++++++++++++ /usr/local CUDA PATHS +++++++++++++++++++ +/usr/local/cuda-11.7/targets/x86_64-linux/lib/libcudart.so +``` +You can set `CUDA_HOME` to `/usr/local/cuda-11.7`. For example, you might be able to compile like this. + +``CUDA_HOME=~/local/cuda-11.7 CUDA_VERSION=117 make cuda11x`` + + +If you have problems compiling the library with these instructions from source, please open an issue. diff --git a/csrc/common.cpp b/csrc/common.cpp new file mode 100644 index 000000000..52f029917 --- /dev/null +++ b/csrc/common.cpp @@ -0,0 +1,39 @@ +#include +#include + +void *quantize_block(void *arguments) { + // 1. find absmax in block + // 2. divide input value by absmax to normalize into [-1.0, 1.0] + // 3. do binary search to find the closest value + // 4. check minimal distance + // 5. store index + + struct quantize_block_args *args = (quantize_block_args *) arguments; + + // 1. find absmax in block + float absmax_block = -FLT_MAX; + for (long long i = args->block_idx; i < args->block_end; i++) + absmax_block = fmax(absmax_block, fabs(args->A[i])); + + args->absmax[args->block_idx / args->blocksize] = absmax_block; + + for (long long i = args->block_idx; i < args->block_end; i++) { + // 2. divide input value by absmax to normalize into [-1.0, 1.0] + // 3. do binary search to find the closest value + float normed_value = args->A[i] / absmax_block; + long long idx = args->bin_searcher->scalar(normed_value); + + // 4. check minimal distance + // The binary search returns always the value to the left, which might not be the closest value + if (idx < 255) { + float dist_left = fabs(normed_value - (args->code[idx])); + float dist_right = fabs(normed_value - (args->code[idx + 1])); + if (dist_right < dist_left) { idx += 1; } + } + + // 5. store index + args->out[i] = (unsigned char) idx; + } + + return NULL; +} diff --git a/csrc/common.h b/csrc/common.h new file mode 100644 index 000000000..c99034e78 --- /dev/null +++ b/csrc/common.h @@ -0,0 +1,25 @@ +#include + +#ifndef common +#define common + +using namespace BinSearch; + +#define BLOCK_SIZE 16384 + +struct quantize_block_args { + BinAlgo *bin_searcher; + float *code; + float *A; + float *absmax; + unsigned char *out; + long long block_end; + long long block_idx; + long long threadidx; + long long blocksize; +}; + + +void *quantize_block(void *arguments); + +#endif diff --git a/csrc/cpu_ops.cpp b/csrc/cpu_ops.cpp new file mode 100644 index 000000000..e28e7b2c2 --- /dev/null +++ b/csrc/cpu_ops.cpp @@ -0,0 +1,73 @@ +#include +#include +#include + +using namespace BinSearch; + +void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n) { + for (long long block_idx = 0; block_idx < n; block_idx += blocksize) { + long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; + long long block_end = block_idx + valid_items; + for (long long i = block_idx; i < block_end; i++) + out[i] = code[A[i]] * absmax[block_idx / blocksize]; + } +} + +void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n) +{ + + // the default code is has range [-0.993, 1.0] which can cause an error in the binary search algorithm used below + code[0] = -1.0f; + + long long num_blocks = n / blocksize; + num_blocks += n % blocksize == 0 ? 0 : 1; + + const uint32 elements_code = 256; + BinAlgo bin_searcher(code, elements_code); + + int thread_wave_size = 256; + // we chunk the thresds into waves of 256 since the max limit is + // between 16k and 64k on Linux (we reach this when running BLOOM-176B with a large batch size) + for(long long offset = 0; offset < num_blocks; offset+=thread_wave_size) + { + long long valid_chunks = num_blocks - offset >= thread_wave_size ? thread_wave_size : num_blocks - offset; + pthread_t *threads = (pthread_t *) malloc(sizeof(pthread_t) * valid_chunks); + + struct quantize_block_args **args = (quantize_block_args **) malloc(valid_chunks * sizeof(quantize_block_args *)); + + for(long long i = 0; i < valid_chunks; i++) + args[i] = (quantize_block_args *) malloc(sizeof(quantize_block_args)); + + int chunks_processed = 0; + for(long long block_idx = offset*blocksize; block_idx < n; block_idx += blocksize) + { + long long valid_items = n - block_idx >= blocksize ? blocksize : n - block_idx; + long long block_end = block_idx + valid_items; + + struct quantize_block_args *arg = args[chunks_processed]; + arg->bin_searcher = &bin_searcher; + arg->code = code; + arg->A = A; + arg->absmax = absmax; + arg->out = out; + arg->block_end = block_end; + arg->block_idx = block_idx; + arg->threadidx = block_idx / blocksize; + arg->blocksize = blocksize; + + pthread_create(&threads[chunks_processed], NULL, &quantize_block, (void *) arg); + chunks_processed += 1; + if(chunks_processed == valid_chunks){ break; } + } + + for (int i = 0; i < valid_chunks; i++) + int err = pthread_join(threads[i], NULL); + + free(threads); + for (int i = 0; i < valid_chunks; i++) + free(args[i]); + free(args); + + } + +} diff --git a/csrc/cpu_ops.h b/csrc/cpu_ops.h new file mode 100644 index 000000000..2ddf81e49 --- /dev/null +++ b/csrc/cpu_ops.h @@ -0,0 +1,10 @@ +#ifndef BITSANDBYTES_CPU_OPS_H +#define BITSANDBYTES_CPU_OPS_H + +#include +#include + +void quantize_cpu(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n); +void dequantize_cpu(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n); + +#endif diff --git a/csrc/kernels.cu b/csrc/kernels.cu new file mode 100644 index 000000000..9e135dbc5 --- /dev/null +++ b/csrc/kernels.cu @@ -0,0 +1,3605 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#define HLF_MAX 65504 +#define TH 1024 +#define NUM 4 +#define NUM_BLOCK 4096 + +using namespace nvcuda; + +// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda +__device__ float atomicMax(float* address, float val) { + int* address_as_i = reinterpret_cast(address); + int old = *address_as_i, assumed; + do { + assumed = old; + old = atomicCAS( + reinterpret_cast(address), assumed, + __float_as_int(fmaxf(val, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); +} + +__device__ float atomicMin(float* address, float val) { + int* address_as_i = reinterpret_cast(address); + int old = *address_as_i, assumed; + do { + assumed = old; + old = atomicCAS( + reinterpret_cast(address), assumed, + __float_as_int(fminf(val, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); +} + +__device__ float dDequantizeFP4(unsigned char val, float absmax) +{ +} + +__device__ float d2DequantizeFP4(unsigned char val) +{ +} + +__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) +{ +} + +__device__ unsigned char dQuantizeFP4(float x) +{ +} + +__device__ half dhDequantizeNF4(unsigned char val) +{ +} + +__device__ float dDequantizeNF4(unsigned char val) +{ +} + +__device__ unsigned char dQuantizeNF4(float x) +{ +} +// sign function for lion +// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA + +template __device__ int sgn(T val) +{ + return (T(0) < val) - (val < T(0)); +} + +template +__device__ unsigned char dQuantize(float* smem_code, const float rand, float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = -1.0f; + float upper = 1.0f; + + float val = smem_code[pivot]; + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + } + val = smem_code[pivot]; + } + + if(upper_pivot == 255) + upper = smem_code[upper_pivot]; + if(lower_pivot == 0) + lower = smem_code[lower_pivot]; + + if(!STOCHASTIC) + { + if(x > val) + { + float midpoint = (upper+val)*0.5f; + if(x > midpoint) + { + return upper_pivot; + } + else + return pivot; + } + else + { + float midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } + } + else + { + if(x > val) + { + float dist_to_upper = fabsf(upper-x); + float dist_full = upper-val; + if(rand >= dist_to_upper/dist_full) return upper_pivot; + else return pivot; + } + else + { + float dist_to_lower = fabsf(lower-x); + float dist_full = val-lower; + if(rand >= dist_to_lower/dist_full) return lower_pivot; + else return pivot; + } + } +} + +template +__device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x) +{ + int pivot = 127; + int upper_pivot = 255; + int lower_pivot = 0; + + float lower = SIGNED ? -1.0f : 0.0f; + float upper = 1.0f; + float midpoint; + float val = quadrants[1]; + int local_pivot = 1; + int offset = 1; + + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 64; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + //val = i == 64 ? quadrants[2] : smem_code[pivot]; + local_pivot += offset; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + //val = i == 64 ? quadrants[0] : smem_code[pivot]; + local_pivot -= offset; + } + val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot]; + offset -= 1; + } + + if(x > val) + { + midpoint = (upper+val)*0.5f; + if(x > midpoint) + return upper_pivot; + else + return pivot; + } + else + { + midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } +} + +template +__device__ __forceinline__ unsigned char quantize_quadrant(int QUADRANT, float *__restrict__ const smem_code, float x, float lower, float midpoint, float upper) +{ + int lower_pivot = QUADRANT*16-1 - 0; + int pivot = QUADRANT*16-1 + 16; + int upper_pivot = QUADRANT*16-1 + 31; + + float val = midpoint; + + // i>>=1 = {32, 16, 8, 4, 2, 1} + for(int i = 16; i > 0; i>>=1) + { + if(x > val) + { + lower_pivot = pivot; + lower = val; + pivot+=i; + } + else + { + upper_pivot = pivot; + upper = val; + pivot-=i; + } + val = smem_code[pivot]; + } + + if(x > val) + { + midpoint = (upper+val)*0.5f; + if(x > midpoint) + return upper_pivot; + else + return pivot; + } + else + { + midpoint = (lower+val)*0.5f; + if(x < midpoint) + return lower_pivot; + else + return pivot; + } +} + +__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n) +{ + const int tid = threadIdx.x + (blockDim.x*blockIdx.x); + const int numThreads = blockDim.x*gridDim.x; + + for(int i = tid; i < n; i+=numThreads) + { + int idx = (index1[i]*maxidx1) + index2[i]; + atomicAdd(&histogram[idx], src[i]); + } +} + +template +__global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n) +{ + typedef cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage; + typedef cub::BlockLoad LoadT; + __shared__ typename LoadT::TempStorage loadt; + + const int warp_idx = threadIdx.x/32; + const int valid_items = n - (blockIdx.x*BLOCK_SIZE) > BLOCK_SIZE ? BLOCK_SIZE : n - (blockIdx.x*BLOCK_SIZE); + + // BLOCK_SIZE/32 == number of warps + __shared__ int smem_max_indices[8*BLOCK_SIZE/32]; + __shared__ float smem_max_values[8*BLOCK_SIZE/32]; + + T values[8]; + T max1 = -64000.0f; + T max2 = -64000.0f; + int max_idx1 = -1; + int max_idx2 = -1; + int sign1 = -1; + int sign2 = -1; + + // 1. load 8 values per thread + // 2. compute 2-max in registers (64 max per warp) + // 3. do warp reduction + broadcast back + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + // 5. Repeat (3) 8 times for top 8 values in 256 + // 6. store with byte index + + LoadT(loadt).Load(&(A[(blockIdx.x*BLOCK_SIZE)]), values, valid_items, (T)0.0f); + #pragma unroll 8 + for(int i = 0; i < 8; i++) + { + T absval = fabsf(values[i]); + if(absval > max1) + { + max1 = values[i]; + sign1 = signbit(values[i]); + max_idx1 = 8*threadIdx.x + i; + } + else if(absval > max2) + { + max2 = values[i]; + sign2 = signbit(values[i]); + max_idx2 = 8*threadIdx.x + i; + } + } + + float warp_max; + for(int i = 0; i < 8; i++) + { + // 3. do warp reduction + broadcast back + warp_max = WarpReduce(temp_storage).Reduce(max1, cub::Max()); + warp_max = cub::ShuffleIndex<32>(warp_max, 0, 0xffffffff); + + // 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest + if(warp_max == max1) + { + smem_max_values[warp_idx*8 + i] = sign1 != 0 ? -max1 : max1; + smem_max_indices[warp_idx*8 + i] = max_idx1; + + sign1 = sign2; + max1 = max2; + max_idx1 = max_idx2; + + max2 = -64000.0f; + } + __syncwarp(); + } + + if(threadIdx.x % 32 < 8) + { + // offset: 8 values per 256 input values + // + int offset = BLOCK_SIZE*blockIdx.x*BLOCK_SIZE/32*8; + } + +} + +#define THREADS_ESTIMATE 512 +#define NUM_ESTIMATE 8 +#define BLOCK_ESTIMATE 4096 + +template +__launch_bounds__(THREADS_ESTIMATE, 1) +__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n) +{ + const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE); + int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*BLOCK_ESTIMATE) : BLOCK_ESTIMATE; + const int base_idx = (blockIdx.x * BLOCK_ESTIMATE); + const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE)); + + T vals[NUM_ESTIMATE]; + + typedef cub::BlockRadixSort BlockRadixSort; + typedef cub::BlockLoad LoadFloat; + + __shared__ union { + typename LoadFloat::TempStorage loadf; + typename BlockRadixSort::TempStorage sort; + int smem_qidx[BLOCK_ESTIMATE]; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_ESTIMATE) + { + valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i; + + // do not process half-blocks + if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; } + + #pragma unroll 4 + for(int j = 0; j < NUM_ESTIMATE; j++) + vals[j] = max_val; + + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items); + + #pragma unroll 4 + for(int j = 0; j < NUM_ESTIMATE; j++) + vals[j] = ((float)vals[j]) * reciprocal_num_blocks; + + + __syncthreads(); + // sort into striped pattern to mitigate bank conflicts + // striped pattern index for thread 0 [0, 1024, 2048, 3096] + // striped pattern index for thread 1 [1, 1025, 2049, 3097] + BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals); + + __syncthreads(); + for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x) + temp_storage.smem_qidx[j] = -1; + + if(threadIdx.x < 256) + { + float q_interval = (1.0f-(2.0f*offset))/255.0f; + int local_idx = round(((offset+(threadIdx.x*q_interval))*(valid_items-1))); + temp_storage.smem_qidx[local_idx] = threadIdx.x; + } + + __syncthreads(); + + for(int i = threadIdx.x; i < BLOCK_ESTIMATE; i+=blockDim.x) + { + if(temp_storage.smem_qidx[i] != -1) + atomicAdd(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]); + } + } +} + + +__launch_bounds__(TH, 4) +__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n) +{ + const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); + int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*NUM_BLOCK) : NUM_BLOCK; + const int base_idx = (blockIdx.x * NUM_BLOCK); + + float vals[NUM]; + unsigned char qvals[NUM]; + //const int lane_id = threadIdx.x % 2; + + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockStore StoreChar; + + __shared__ typename LoadFloat::TempStorage loadf; + __shared__ typename StoreChar::TempStorage storec; + __shared__ float smem_code[256]; + //__shared__ float smem_code[2][257]; + + if(threadIdx.x < 256) + { + smem_code[threadIdx.x] = code[threadIdx.x]; + //smem_code[0][threadIdx.x] = code[threadIdx.x]; + //smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x]; + } + + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_BLOCK) + { + // number of values already processed in blocks + + // number of values already processed in this block + + // rand_offset % mod value + valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i; + + __syncthreads(); + LoadFloat(loadf).Load(&(A[i]), vals, valid_items); + + + #pragma unroll 4 + for(int j = 0; j < NUM; j++) + qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]); + + __syncthreads(); + StoreChar(storec).Store(&(out[i]), qvals, valid_items); + } +} + +template +//__launch_bounds__(TH, 4) +__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n) +{ + const int n_full = gridDim.x * BLOCK_SIZE; + int valid_items = 0; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + + T vals[NUM_PER_TH]; + float rand_vals[NUM_PER_TH]; + unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH]; + //float local_abs_max = -FLT_MAX; + float local_abs_max = 0.0f; + int local_rand_idx = 0; + + typedef cub::BlockLoad LoadT; + typedef cub::BlockStore 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar; + typedef cub::BlockReduce BlockReduce; + typedef cub::BlockLoad LoadFloat; + + __shared__ typename LoadT::TempStorage loadt; + __shared__ typename LoadFloat::TempStorage loadf; + __shared__ typename StoreChar::TempStorage storec; + __shared__ typename BlockReduce::TempStorage reduce; + __shared__ float smem_code[256]; + __shared__ float smem_absmax_value[1]; + + if(DATA_TYPE == General8bit) + for(int i = threadIdx.x; i < 256; i+=blockDim.x) + smem_code[i] = code[i]; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; + local_abs_max = -FLT_MAX; + + __syncthreads(); + LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f); + + // 1. compute local max + // 2. broadcast local max + // 3. normalize inputs and quantize + + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); + + local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items); + + if(threadIdx.x == 0) + smem_absmax_value[0] = local_abs_max; + + __syncthreads(); + + if(threadIdx.x == 0) + absmax[i/BLOCK_SIZE] = local_abs_max; + else + local_abs_max = smem_absmax_value[0]; + + __syncwarp(); + + local_abs_max = 1.0f/local_abs_max; + + if(STOCHASTIC) + { + local_rand_idx = ((blockIdx.x*NUM_BLOCK) + (threadIdx.x*NUM) + rand_offset) % (1024-4); + LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0); + } + + unsigned char packed_4bit = 0; + switch(DATA_TYPE) + { + case General8bit: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + if(!STOCHASTIC) + qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max); + else + qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max); + } + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + //packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4; + //packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max); + //qvals[j] = packed_4bit; + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH/2; j++) + { + //packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4; + //packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max); + //qvals[j] = packed_4bit; + } + break; + } + + __syncthreads(); + StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items); + } +} + +template +__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n) +{ + + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); + + T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; + unsigned char qvals[NUM_PER_TH]; + float local_abs_max = -FLT_MAX; + + typedef cub::BlockLoad LoadChar; + typedef cub::BlockStore 0) ? 2 : 1), cub::BLOCK_STORE_WARP_TRANSPOSE> StoreT; + + __shared__ typename LoadChar::TempStorage loadchar; + __shared__ typename StoreT::TempStorage storet; + + for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE) + { + if(DATA_TYPE > 0) + { + valid_items_load = (n+1)/2 - i > TILE_SIZE ? TILE_SIZE : (n+1)/2 - i; + valid_items_store = n - i*2 > TILE_SIZE*2 ? TILE_SIZE*2 : n - i*2; + } + else + { + valid_items_load = n - i > TILE_SIZE ? TILE_SIZE : n - i; + valid_items_store = n - i > TILE_SIZE ? TILE_SIZE : n - i; + } + local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH)/(blocksize)]); + + __syncthreads(); + LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128); + + switch(DATA_TYPE) + { + case General8bit: + // load code through read-only cache via __ldg + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + vals[j] = __ldg(&code[qvals[j]])*local_abs_max; + break; + case FP4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + //vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max); + //vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max); + } + break; + case NF4: + #pragma unroll NUM_PER_TH + for(int j = 0; j < NUM_PER_TH; j++) + { + //vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max; + //vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max; + } + break; + } + + __syncthreads(); + StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store); + } +} + +__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n) +{ + const unsigned int numThreads = blockDim.x * gridDim.x; + const int idx = (blockIdx.x * blockDim.x) + threadIdx.x; + + __shared__ float smem_code[256]; + if(threadIdx.x < 256) + { + smem_code[threadIdx.x] = code[threadIdx.x]; + } + + __syncthreads(); + + for (int i = idx;i < n; i += numThreads) + { + out[i] = smem_code[A[i]]; + } +} + + + +template +__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) +__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + float s2_vals[NUM_VALS]; + + const float correction1 = 1.0f/(1.0f - powf(beta1, step)); + const float correction2 = 1.0f/(1.0f - powf(beta2, step)); + + typedef cub::BlockLoad Load; + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockReduce BlockReduce; + + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case ADAM: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + s1_vals[j] = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update + s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update) + break; + } + } + + # pragma unroll NUM_VALS-1 + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]); + + if(threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + __syncwarp(); + } +} + + + +#define NUM_PER_THREAD 4 + +template +__launch_bounds__(TH, 1) +__global__ void kOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + float s2_vals[NUM_PER_THREAD]; + + const float correction1 = 1.0f - powf(beta1, step); + const float correction2 = sqrtf(1.0f - powf(beta2, step)); + const float step_size = -lr*correction2/correction1; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + typedef cub::BlockLoad Load; + typedef cub::BlockStore Store; + + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + switch(OPTIMIZER) + { + case ADAM: + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j])); + s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j]))); + p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2)))); + + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + break; + } + } + + __syncthreads(); + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items); + } +} + +template +__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1) +__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n) +{ + + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS); + int valid_items = 0; + + T g_vals[NUM_VALS]; + + float s1_vals[NUM_VALS]; + + typedef cub::BlockLoad Load; + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockReduce BlockReduce; + + __shared__ union { + typename Load::TempStorage load; + typename LoadFloat::TempStorage loadf; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + g_vals[j] = gnorm_scale*((float)g_vals[j]); + + # pragma unroll NUM_VALS + for(unsigned int j = 0; j < NUM_VALS; j++) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; // state update + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update + s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update + s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value + s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm + break; + } + } + + # pragma unroll + for(unsigned int j = 1; j < NUM_VALS; j++) + s1_vals[0] += s1_vals[j]; + + __syncthreads(); + s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items); + + if(threadIdx.x == 0) + atomicAdd(&unorm[0], s1_vals[0]); + + __syncwarp(); + } +} + +template +__launch_bounds__(TH, 1) +__global__ void kOptimizer32bit1State(T *g, T *p, + float *state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n) +{ + + const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD)); + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = 0; + float update_scale = 0.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm+eps){ update_scale = (max_unorm*param_norm+eps)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + T g_vals[NUM_PER_THREAD]; + T p_vals[NUM_PER_THREAD]; + + float s1_vals[NUM_PER_THREAD]; + + typedef cub::BlockLoad Load; + typedef cub::BlockStore Store; + + typedef cub::BlockLoad LoadFloat; + typedef cub::BlockStore StoreFloat; + + __shared__ union { + typename Load::TempStorage load; + typename Store::TempStorage store; + typename LoadFloat::TempStorage loadf; + typename StoreFloat::TempStorage storef; + } temp_storage; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items); + __syncthreads(); + LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items); + __syncthreads(); + Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items); + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + g_vals[j] = gnorm_scale*((float)g_vals[j]); + if(weight_decay > 0.0f) + g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD; j++) + { + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j])); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j])))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j])); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); + p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); + p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps); + break; + } + } + } + + __syncthreads(); + Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items); + } +} + + +#define NUM8BIT 16 +#define NUM_THREADS 256 +#define NUM_PER_BLOCK 4096 + +template +__global__ void +__launch_bounds__(NUM_THREADS, 2) +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n) +{ + const int n_full = gridDim.x * NUM_PER_BLOCK; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_max_s2 = -FLT_MAX; + float local_unorm = 0.0f; + + float s2_vals[NUM8BIT]; + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + unsigned char r_c2[NUM8BIT]; + + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadUInt8; + typedef cub::BlockReduce BlockReduce; + + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + __shared__ float smem_quantiles1[256]; + __shared__ float smem_quantiles2[256]; + + if(threadIdx.x < 256) + { + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x]; + } + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*gridDim.x*NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128); + __syncthreads(); + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]*beta1; + s1_vals[j] += (1.0f-beta1)*g_val; + local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); + } + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s2_vals[j] = smem_quantiles2[r_c2[j]]*max2[0]*beta2; + s2_vals[j] += (1.0f-beta2)*g_val*g_val; + local_max_s2 = fmaxf(local_max_s2, fabsf(s2_vals[j])); + } + + if(unorm != NULL) + { + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step)); + float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step)); + s1_vals[j] *= correction1; + s2_vals[j] *= correction2; + float update_val = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update + local_unorm += update_val*update_val; + } + } + } + + __syncthreads(); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items); + __syncthreads(); + local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cub::Max(), valid_items); + if(unorm != NULL) + { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items); + } + + if(threadIdx.x == 0) + { + atomicMax(&new_max1[0], local_max_s1); + atomicMax(&new_max2[0], local_max_s2); + if(unorm != NULL){ atomicAdd(&unorm[0], local_unorm); } + } +} + +#define NUM_PER_THREAD2 4 +#define NUM_THREADS2 1024 +#define NUM_PER_BLOCK2 4096 + +template +__global__ void +__launch_bounds__(NUM_THREADS2, 1) +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, const int n) +{ + + const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float s2_vals[NUM_PER_THREAD2]; + const float correction1 = 1.0f - powf(beta1, step); + const float correction2 = sqrtf(1.0f - powf(beta2, step)); + const float step_size = -lr*correction2/correction1; + //const float step_size = -lr*correction2/correction1; + float new_max_val1 = 1.0f/new_max1[0]; + float new_max_val2 = 1.0f/new_max2[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + unsigned char c2s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadChar; + + typedef cub::BlockStore StoreChar; + typedef cub::BlockStore StoreT; + + __shared__ float smem_quantiles1[256]; + __shared__ float smem_quantiles2[256]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + if(threadIdx.x < 512) + { + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + else + smem_quantiles2[threadIdx.x-256] = quantiles2[threadIdx.x-256]; + } + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[c1s[j]]; + s1_vals[j] = s1_vals[j]*max1[0]; + + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + + s2_vals[j] = smem_quantiles2[c2s[j]]; + s2_vals[j] = s2_vals[j]*max2[0]; + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2); + } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps)))))); + if(weight_decay > 0.0f) + p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + __syncthreads(); + } +} + + +template +__global__ void +__launch_bounds__(NUM_THREADS, 2) +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n) +{ + const int n_full = gridDim.x * NUM_PER_BLOCK; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD); + int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK); + float g_val = 0.0f; + float local_max_s1 = -FLT_MAX; + float local_unorm = 0.0f; + + float s1_vals[NUM8BIT]; + T g_vals[NUM8BIT]; + unsigned char m_c1[NUM8BIT]; + + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadUInt8; + typedef cub::BlockReduce BlockReduce; + + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadUInt8::TempStorage loadc; + typename BlockReduce::TempStorage reduce; + } temp_storage; + + __shared__ float smem_quantiles1[256]; + + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS*NUM8BIT) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128); + + #pragma unroll 16 + for(int j = 0; j < NUM8BIT; j++) + { + g_val = g_vals[j]; + g_val *= gnorm_scale; + s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]; + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = (float)g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + if(unorm != NULL) + local_unorm += s1_vals[j]*s1_vals[j]; + break; + case LION: + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + } + + local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j])); + } + } + + __syncthreads(); + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items); + if(threadIdx.x == 0){ atomicMax(&new_max1[0], local_max_s1); } + if(unorm != NULL) + { + __syncthreads(); + local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items); + if(threadIdx.x == 0){ atomicAdd(&unorm[0], local_unorm); } + } + +} + +template +__global__ void +__launch_bounds__(1024, 1) +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, + const float gnorm_scale, const int n) +{ + + const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2; + const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[NUM_PER_THREAD2]; + float new_max_val1 = 1.0f/new_max1[0]; + float update_scale = 1.0f; + + if(max_unorm > 0.0f) + { + update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f; + if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; } + else{ update_scale = 1.0f; } + } + else{ update_scale = 1.0f; } + + unsigned char c1s[NUM_PER_THREAD2]; + T p_vals[NUM_PER_THREAD2]; + T g_vals[NUM_PER_THREAD2]; + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadChar; + + typedef cub::BlockStore StoreChar; + typedef cub::BlockStore StoreT; + + __shared__ float smem_quantiles1[256]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + + if(threadIdx.x < 256) + smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x]; + + __syncthreads(); + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2) + { + valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i; + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items); + + if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; } + + # pragma unroll 4 + for(unsigned int j = 0; j < NUM_PER_THREAD2; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case MOMENTUM: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0]; + + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_vals[j]; + else + s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); + + p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j])); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val)))); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps)); + break; + } + + c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1); + + // make sure state1 term has still the same sign after quantization + if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + } +} + + +template +__global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n) +{ + const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE); + int valid_items = 0; + + typedef cub::BlockReduce BlockReduce; + typedef cub::BlockLoad LoadT; + + __shared__ typename BlockReduce::TempStorage reduce; + + __shared__ typename LoadT::TempStorage loadT; + T vals[NUM_VALS]; + float local_sum = 0.0f; + + for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x*BLOCK_SIZE) + { + valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i; + local_sum = 0.0f; + + __syncthreads(); + LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f); + + #pragma unroll NUM_VALS + for(int j = 0; j < NUM_VALS; j++) + local_sum += ((float)vals[j])*((float)vals[j]); + + local_sum = BlockReduce(reduce).Sum(local_sum, valid_items); + if(threadIdx.x == 0) + { + if(step == 1) + { + // initialize with the same norm for all positions + //#pragma unroll 10 + for(int j = 0; j < 100; j++) + atomicAdd(&gnorm_vec[j], local_sum); + } + else + atomicAdd(&gnorm_vec[step % 100], local_sum); + } + + } +} + + +#define LANES 2 +#define QUAD 3 +template +__launch_bounds__(256, 3) +__global__ void +kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* absmax1, float* absmax2, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n) +{ + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = gridDim.x * BLOCK_SIZE; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + float s2_vals[N_PER_TH]; + // 2-5% + const float correction1 = 1.0f - __powf(beta1, step); + const float correction2 = sqrtf(1.0f -__powf(beta2, step)); + const float step_size = __fdividef(-lr*correction2,correction1); + const int lane_id = threadIdx.x % LANES; + float new_local_abs_max1 = -FLT_MAX; + float new_local_abs_max2 = -FLT_MAX; + float quadrants1[QUAD]; + float quadrants2[QUAD]; + + unsigned char c1s[N_PER_TH]; + unsigned char c2s[N_PER_TH]; + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadChar; + + typedef cub::BlockStore StoreChar; + typedef cub::BlockStore StoreT; + + __shared__ float smem_quantiles1[LANES][257]; + __shared__ float smem_quantiles2[LANES][257]; + typedef cub::BlockReduce BlockReduce1; + typedef cub::BlockReduce BlockReduce2; + __shared__ typename BlockReduce1::TempStorage reduce1; + __shared__ typename BlockReduce2::TempStorage reduce2; + __shared__ float smem_exchange1[1]; + __shared__ float smem_exchange2[1]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; + smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x]; + # pragma unroll + for(unsigned int j = 1; j < LANES; j++) + { + smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; + smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x]; + } + + __syncthreads(); + + #pragma unroll + for(int k = 0; k < QUAD; k++) + { + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + } + + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0); + + new_local_abs_max1 = -FLT_MAX; + new_local_abs_max2 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) + { + s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE]; + g_val = g_vals[j]; + //float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps); + //g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val; + g_val *= gnorm_scale; + + s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val)); + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val)); + } + else + { + s1_vals[j] = 0.0f; + s2_vals[j] = 0.0f; + } + + new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); + new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j])); + } + + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max()); + new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cub::Max()); + + if(threadIdx.x == 0) + { + smem_exchange1[0] = new_local_abs_max1; + smem_exchange2[0] = new_local_abs_max2; + } + + __syncthreads(); + + if(threadIdx.x == 0) + { + absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + absmax2[i/BLOCK_SIZE] = new_local_abs_max2; + } + else + { + new_local_abs_max1 = smem_exchange1[0]; + new_local_abs_max2 = smem_exchange2[0]; + } + + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + // reduce: 2.67/1.69 -> 2.67/1.70 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + //if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) + { + p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); + if(weight_decay > 0.0f) + p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); + } + } + + // store: 0.85/1.44 -> 2.48/1.57 + __syncthreads(); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + + // quantizaztion: 2.67/1.70 -> 3.4/3.3 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); + c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j],new_local_abs_max2)); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items); + } +} + + +#define LANES 2 +#define QUAD 3 +template +__launch_bounds__(256, 3) +__global__ void +kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* absmax1, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n) +{ + + //const int n_full = n + (n%BLOCK_SIZE); + const int n_full = gridDim.x * BLOCK_SIZE; + const int base_idx = (blockIdx.x * BLOCK_SIZE); + int valid_items = 0; + float g_val = 0.0f; + float s1_vals[N_PER_TH]; + // 2-5% + const int lane_id = threadIdx.x % LANES; + float new_local_abs_max1 = -FLT_MAX; + float quadrants1[QUAD]; + + unsigned char c1s[N_PER_TH]; + T g_vals[N_PER_TH]; + T p_vals[N_PER_TH]; + + typedef cub::BlockLoad LoadT; + typedef cub::BlockLoad LoadChar; + + typedef cub::BlockStore StoreChar; + typedef cub::BlockStore StoreT; + + __shared__ float smem_quantiles1[LANES][257]; + typedef cub::BlockReduce BlockReduce1; + __shared__ typename BlockReduce1::TempStorage reduce1; + __shared__ float smem_exchange1[1]; + + __shared__ union { + typename LoadT::TempStorage loadh; + typename LoadChar::TempStorage loadc; + typename StoreChar::TempStorage storec; + typename StoreT::TempStorage storeh; + } temp_storage; + // init: 0.2 -> 0.23 + + // 0.23 -> 0.23 + smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x]; + # pragma unroll + for(unsigned int j = 1; j < LANES; j++) + smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x]; + + __syncthreads(); + + #pragma unroll + for(int k = 0; k < QUAD; k++) + quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)]; + + for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE) + { + // loads: 0.23 -> 0.85/1.44 + valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i; + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f); + __syncthreads(); + LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128); + __syncthreads(); + LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f); + + new_local_abs_max1 = -FLT_MAX; + + // update: 2.48/1.57 -> 2.51/1.60 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + g_val = float(g_vals[j]); + g_val *= gnorm_scale; + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + if(weight_decay > 0.0f) { + switch(OPTIMIZER) { + case MOMENTUM: + case ADAGRAD: + case RMSPROP: + g_val += ((float)p_vals[j])*weight_decay; + break; + case LION: + p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay); + break; + } + } + + s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE]; + + switch(OPTIMIZER) + { + case MOMENTUM: + if(step == 1) + s1_vals[j] = g_val; + else + s1_vals[j] = (s1_vals[j]*beta1) + g_val; + break; + case LION: + // here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2 + g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val)); + s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val); + break; + case RMSPROP: + s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val)); + break; + case ADAGRAD: + s1_vals[j] = s1_vals[j] + (g_val*g_val); + break; + } + } + + new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j])); + } + + + // reduce: 2.51/1.60 -> 2.67/1.69 + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max()); + + if(threadIdx.x == 0) + smem_exchange1[0] = new_local_abs_max1; + + __syncthreads(); + + if(threadIdx.x == 0) + absmax1[i/BLOCK_SIZE] = new_local_abs_max1; + else + new_local_abs_max1 = smem_exchange1[0]; + + // reduce: 2.67/1.69 -> 2.67/1.70 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) + { + switch(OPTIMIZER) + { + case MOMENTUM: + p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]); + break; + case LION: + p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]); + break; + case RMSPROP: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + break; + case ADAGRAD: + g_val = g_vals[j]; + p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps)); + break; + } + } + } + + // store: 0.85/1.44 -> 2.48/1.57 + __syncthreads(); + StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items); + + // quantizaztion: 2.67/1.70 -> 3.4/3.3 + # pragma unroll N_PER_TH + for(unsigned int j = 0; j < N_PER_TH; j++) + { + c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1)); + + // make sure state1 term has still the same sign after quantization + // (not needed for state2 term which has only positive values) + if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j])) + { + if(s1_vals[j] > 0.0f) + c1s[j] += 1; + else + c1s[j] -= 1; + } + } + + __syncthreads(); + StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items); + } +} + +template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols) +{ + // 0. reset stats to -FLT_MAX + // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) + // 2. compute col max (per thread); store in smem due to register pressure + // 3. compute row max (per block); store in smem to accumulate full global mem transation + // 4. store data via atomicMax + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + const int items_per_load = ITEMS_PER_THREAD*THREADS; + + typedef cub::BlockLoad LoadT; + typedef cub::BlockReduce BlockRowReduce; + typedef cub::BlockReduce BlockRowSum; + typedef cub::BlockExchange BlockExchange; + + __shared__ union { + typename BlockExchange::TempStorage exchange; + typename BlockRowReduce::TempStorage rowreduce; + typename BlockRowSum::TempStorage rowsum; + typename LoadT::TempStorage loadt; + } temp_storage; + + __shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS]; + __shared__ int smem_row_nnz_values[TILE_ROWS]; + + half local_data[ITEMS_PER_THREAD]; + float local_data_fp32[ITEMS_PER_THREAD]; + float local_col_absmax_values[ITEMS_PER_THREAD]; + int local_row_nnz_count = 0; + float row_absmax = -FLT_MAX; + + // 0. reset stats to -FLT_MAX + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX; + smem_row_absmax_values[threadIdx.x + (j*THREADS)] = -FLT_MAX; + smem_row_nnz_values[threadIdx.x + (j*THREADS)] = 0; + } + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_col_absmax_values[j] = -FLT_MAX; + + __syncthreads(); + + int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col; + int i = base_idx; + // we load row after row from the base_position + // 1. load row-by-row ITEMS_PER_THREAD (TILE_SIZE==THREADS*ITEMS_PER_THREAD) + for(int row = 0; row < TILE_ROWS; row++) + { + if(base_row+row >= rows){ break; } + local_row_nnz_count = 0; + i = base_idx + ((row)*cols); + // each thread gets data from the same column + __syncthreads(); + LoadT(temp_storage.loadt).Load(&(A[i]), local_data, valid_items, __float2half(0.0f)); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data[j] = fabsf(local_data[j]); + + + if(SPARSE_DECOMP) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + if((float)local_data[j] >= nnz_threshold) + { + local_row_nnz_count += 1; + local_data[j] = 0.0f; + } + } + + // 2. compute col max (per thread); store in smem due to register pressure + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + // take the col max for this row + // we use shared memory because register pressure is too high if we do this locally + //smem_col_absmax_values[threadIdx.x + (j*THREADS)] = fmaxf(smem_col_absmax_values[threadIdx.x + (j*THREADS)], __half2float(local_data[j])); + local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j])); + + // 3. compute row max (per block); store in smem to accumulate full global mem transation + + // this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data_fp32[j] = local_data[j]; + + __syncthreads(); + + row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max()); + if(SPARSE_DECOMP) + { + __syncthreads(); + local_row_nnz_count = BlockRowSum(temp_storage.rowsum).Sum(local_row_nnz_count); + } + // we store the data temporarily in shared memory so we + // can execute a full atomic block transaction into global memory later + // we use a striped arrangement [0, 8, 16, 24, ..] for t0 for faster stores + if(threadIdx.x == 0) + { + smem_row_absmax_values[(row % ITEMS_PER_THREAD) + ((row/ITEMS_PER_THREAD)*ITEMS_PER_THREAD)] = row_absmax; + // each blockIdx.x process 16 rows and 64*4=256 columns -> we sum nnz over 256 columns and have 16 values per block + smem_row_nnz_values[row] = local_row_nnz_count; + } + + __syncthreads(); + + } + + // 4. store data via atomicMax + // to store col data efficienctly we need to rewrite the smem blocked data [0, 1, 2, 3...] for t0 + // into a striped arangement: [0, 8, 16, 24, ..] for t0 + __syncthreads(); + BlockExchange(temp_storage.exchange).BlockedToStriped(local_col_absmax_values); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_col+threadIdx.x+(j*THREADS) < cols) + { + float val = colStats[base_col+(threadIdx.x+(j*THREADS))]; + if(val < local_col_absmax_values[j]) + atomicMax(&colStats[base_col+(threadIdx.x+(j*THREADS))], local_col_absmax_values[j]); + } + + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_row+threadIdx.x+(j*THREADS) < rows) + { + float val = rowStats[base_row+(threadIdx.x+(j*THREADS))]; + if(val < smem_row_absmax_values[threadIdx.x+(j*THREADS)]) + atomicMax(&rowStats[base_row+(threadIdx.x+(j*THREADS))], smem_row_absmax_values[threadIdx.x+(j*THREADS)]); + } + + if(SPARSE_DECOMP) + if(threadIdx.x < TILE_ROWS) + nnz_count_row[blockIdx.x*TILE_ROWS+threadIdx.x+1] = smem_row_nnz_values[threadIdx.x]; + +} + +template __global__ void kgetColRowStats(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); +template __global__ void kgetColRowStats(half * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); + +#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f) + +template __global__ void kdequant_mm_int32_fp16(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half *__restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n) +{ + + // Strategy: To dequantize we need to load col/row statistics. This can be very expensive + // since different row/col stats need to be loaded with each thread. + // (1, bad algorithm) Loading 32 items per thread would only occur 1 row load, but this increases register pressure + // and would lead to low global load utilization. + // (2, bad algorithm) If each thread loads some columns and multiple rows one needs to do lot of row loads + // for each thread and this is duplicated by a factor of 32/num-cols-per-thread. + // (3, good algorithm) Combining (1) and (2) we use sub-tiles of size 32xk in shared memory per threadblock. + // This allows for efficient row/col loading from shared memory within the tile. + // We can run for example 32x128 sub-tiles and warp-strided loads of 4 elements so that each thread has + // the same col statistic but needs to load 4 row stats from shared memory. To prevent bank conflicts + // we use a block-striped shared memory config [1, 31, 63, 95] so no bank conflicts happen during the + // shared memory loads. + + // data is in 32 column-tile major with tile width 32 columns and numRows rows + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. + // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + // C1. Compute val(row_stat*col_stat)/(127*127) (load 1/(127*127 into register)) + // C2. Compute normalization values and store col values in register + // S1. Store C1 into 16-bit output + // S2. Store col/row statistics of new buffer in shared memory + + // We allow for sub-tiles to span multiple col32 tiles. This is okay + // since the items per thread only rely on a single column statistic. + + + const int n_out = numRows*numCols; + + int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1); + // we have tiles of size numRows*32, thus col only increases every numRows + // num_row_tiles is the tiles after which the column increases by 32 + // blockIdx.x is the index of the current tile + int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32)); + // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached + int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS); + + // SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS + // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD + // Total subtiles should be n/(32*SUBTILE_ROWS) where each subtile has SUBTILE_ROW*32/4 threads. + // For example for a 1024x1024 matrix with 128 SUBTILE_ROWS and 4 ITEMS_PER_THREAD we have + // 1024*1024/(128*32) = 256 tiles + // 256 tiles are 256*128*32/4 = 256*1024 threads + + // 1. Figure out how index relates to the start of the sub-tile + // 2. Each thread < SUBTILE_ROWS calculates row index + // 3. Load striped and store in shared memory + + int local_values[ITEMS_PER_THREAD]; + half local_output[ITEMS_PER_THREAD]; + float local_rowStats[ITEMS_PER_THREAD]; + __shared__ float smem_rowStats[SUBTILE_ROWS]; + + typedef cub::BlockLoad LoadInt32; + typedef cub::BlockExchange ExchangeInt32; + __shared__ typename LoadInt32::TempStorage loadint32; + __shared__ typename ExchangeInt32::TempStorage exchangeint32; + + + // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory. + float colStat = col >= numCols ? 0.0f : colStats[col]; + float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]); + // no block loads for rows for now -- keep it simple + for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x) + { + // todo: is this global mem access slow due to overlaps or does the L1 cache work well here? + int row = (base_row+j) % numRows; // wrap around + // each warp accesses the same element, for four consequitive elements + // todo: update description about striped shared memory, it is not needed + // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements + smem_rowStats[j] = rowStats[row]; + } + __syncthreads(); + + + // each block processes SUBTILE_ROWS*32 elements + const int items_per_load = THREADS*ITEMS_PER_THREAD; + const int rows_per_load = items_per_load/32; + + int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile + int row_offset = 0; + // subtile_idx starts at the base_row*32 + the total offset for a full numRow*32 tile is passed + int subtile_start = (blockIdx.x/num_row_tiles)*(numRows*32) + (base_row*32); + for(int subtile_idx = subtile_start; subtile_idx < subtile_start + (SUBTILE_ROWS*32); subtile_idx+=items_per_load) + { + int valid_rows = numRows - (base_row+row_offset) > rows_per_load ? rows_per_load : numRows - (base_row+row_offset); + int valid_items = valid_rows*32; + if(valid_items <= 0) // the sub-tile might have more elements than the tile itself + break; + + // L2. Load data in warp-striped arangement (t0 holds colidx [0, 0, 0, 0], rowidx [0, 1, 2, 3]) + LoadInt32(loadint32).Load(&(A[subtile_idx]), local_values, valid_items, 0); + ExchangeInt32(exchangeint32).BlockedToWarpStriped(local_values, local_values); + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j]; + + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue); + //absmax_col = fmax(fabsf(local_output[j]), absmax_col); + + // we store data in row major + // to store data efficiently, we want to use block exchange: [0, 32, 64, 92] -> [0, 1, 2, 3] + // so that each thread holds ITEMS_PER_THREAD consecutive items for each row + // this way throughput into storage is increased by a factor of ~2x + // for now we use a simple store + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int outIdx = col + ((base_row+subtile_base_row+row_offset+j)*numCols); + if(outIdx< n_out && col < numCols) + out[outIdx] = local_output[j]; + } + + row_offset += rows_per_load; + } +} + + +template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols) +{ + // assumes TILE_SIZE == THREADS*ITEMS_PER_THREAD + // Each thread reads the same column but multiple rows + // Rows are loaded in shared memory and access is shared across the threadblock (broadcast) + + // 0. Load row stats data into shared memory; load col stat (1 fixed per thread) + // 1. Load data row by row (should be at least with TILE_SIZE = 512) + // 2. quantize data with row/col stats + // 3. Store data (TILE_SIZE = 512 is a bit slow, but should still be close enough to good performance) + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + const int items_per_load = ITEMS_PER_THREAD*THREADS; + + typedef cub::BlockLoad LoadHalf; + __shared__ typename LoadHalf::TempStorage loadhalf; + typedef cub::BlockStore StoreInt8; + __shared__ typename StoreInt8::TempStorage storeint8; + + __shared__ float smem_row_stats[TILE_ROWS]; + __shared__ unsigned int smem_nnz_row_idx[TILE_ROWS]; + + half local_data[ITEMS_PER_THREAD]; + float local_col_stats[ITEMS_PER_THREAD]; + char local_quantized_data[ITEMS_PER_THREAD]; + + // 0. Load row stats data into shared memory; load col stat (1 fixed per thread) + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + if(base_col+(threadIdx.x*ITEMS_PER_THREAD) + j < cols) + local_col_stats[j] = __fdividef(127.0f, colStats[base_col+(threadIdx.x*ITEMS_PER_THREAD)+j]); + + for(int i = threadIdx.x; i < TILE_ROWS; i+=blockDim.x) + { + if(base_row + i < rows) + smem_row_stats[i] = rowStats[base_row+i]; + + if(SPARSE_DECOMP) + smem_nnz_row_idx[i] = nnz_block_ptr[(TILE_ROWS*blockIdx.x) + i]; + } + __syncthreads(); + + // we load row after row from the base_position + // 1. Load data row by row (should be at least with TILE_SIZE = 512) + for(int row = 0; row < TILE_ROWS; row++) + { + if(base_row + row >= rows){ break; } + int i = base_idx + (row*cols); + int valid_items = cols - base_col > items_per_load ? items_per_load : cols - base_col; + + + LoadHalf(loadhalf).Load(&(A[i]), local_data, valid_items, 0.0f); + float row_stat = __fdividef(127.0f, smem_row_stats[row]); + + // 2. quantize data with row/col stats + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + // we already pre-normalized the col/row stat: + // what this does is float/absmax*127 = int8 + if(SPARSE_DECOMP) + { + if(fabsf((float)local_data[j]) >= threshold) + { + local_quantized_data[j] = 0; + + int old_idx = atomicInc(&smem_nnz_row_idx[row], UINT_MAX); + + rowidx[old_idx] = base_row+row; + colidx[old_idx] = base_col+(threadIdx.x*ITEMS_PER_THREAD)+j; + val[old_idx] = local_data[j]; + } + else + { + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat)); + } + } + else + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*row_stat)); + } + + StoreInt8(storeint8).Store(&(out_row_normed[i]), local_quantized_data, valid_items); + + // 2. quantize data with row/col stats + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + // we already pre-normalized the col/row stat: + // what this does is float/absmax*127 = int8 + local_quantized_data[j] = (char)(rintf(__half2float(local_data[j])*local_col_stats[j])); + } + + __syncthreads(); + StoreInt8(storeint8).Store(&(out_col_normed[i]), local_quantized_data, valid_items); + + } +} + +template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols) +{ + + // 0. Load data into 32*32 shared memory tiles + // 1. transpose / reorder in shared memory + // 2. store + + // COL32 FORMAT: + // rows*32 tiles + + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + + // AMPERE FORMAT: + // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: + // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] + + + // To have efficient loads and stores if we transpose we need 128 consequitive bytes which at 1 byte are 128 values + // As such we need: + // at least 32*4 shared memory tiles for col32; preferably 32*32 + // at least 32*6 shared memory tiles for col32_ampere: preferably 32*32 + // at least 32*8 shared memory tiles for col4_turing: preferably 32*32 + // for efficient loading of row major we need to load 128 elements and repeat this 32 items + // this would imply a 32x128 shared memory tile -> 4kb + // It is more efficient to have more than 1 warp, so with 64 threads we need 32x128 -> 8 kb + // we have 64k sharded mem per SM in Turing which is 8 blocks per SM which is 2*8 = 32 warps = 100% occupancy + // for turing and 50% for A100 and 75% for RTX 30s / A40 which is probably good enough + // register pressure should be low with: 8 registers from local memoryh per block and 64 registers per SM + // + // to make the shared memory work with that occupancy we might need to union the block loads/stores + + // each block loads TILE_COLs columns and TILE_ROW rows + // after reading a tile the row counter increase by TILE_ROWS + // the col counter reset after reading TILE_COL elements + const int base_row = ((blockIdx.x*TILE_COLS)/tiledCols)*TILE_ROWS; + // col increases by TILE_SIZE for each block and wraps back to 0 after tiledCols is reached + const int base_col = (blockIdx.x*TILE_COLS) % tiledCols; + const int base_idx = (base_row*cols) + base_col; + + // we load 128 bytes per warp with + // 32 rows for transposes that fill col32 types + // so that we can have contiguous stores + __shared__ char smem_data[32*33*ITEMS_PER_THREAD]; + char local_data[ITEMS_PER_THREAD]; + typedef cub::BlockExchange BlockExchange; + + // we load row after row from the base_position + // Load data row by row + int warps = blockDim.x/32; + int warp_id = threadIdx.x/32; + int warp_lane = threadIdx.x % 32; + int offset = 0; + + int smem_row = 0; + // each warp loads one row of 128 bytes + for(int row = warp_id; row < TILE_ROWS; row+=warps) + { + int i = base_idx + (row*cols); + // we load up to 128 bytes/items per load + int valid_items = cols - base_col > 32*ITEMS_PER_THREAD ? 32*ITEMS_PER_THREAD : cols - base_col; + + // 0. Load data into 32*32 shared memory tiles + if(base_row + row < rows) + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int col_idx = warp_lane+(j*32); + if(col_idx < valid_items) + local_data[j] = A[i+col_idx]; + else + local_data[j] = 0; + } + } + else + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + local_data[j] = 0; + } + + if(TRANSPOSE) + { + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + int local_col = (32*j)+warp_lane; + //int local_row = row; + // store as 256x32 + smem_data[(local_col*33) + row] = local_data[j]; + } + } + else + { + // treat smem as 32x256, that is 32 rows and 256 columns + #pragma unroll ITEMS_PER_THREAD + for(int j = 0; j < ITEMS_PER_THREAD; j++) + smem_data[row*32*ITEMS_PER_THREAD + (warp_lane) + (j*32)] = local_data[j]; + } + + + + smem_row += warps; + + // 1. transpose / reorder in shared memory + if(smem_row % 32 == 0) + { + smem_row = 0; + __syncthreads(); + + for(int subrow = warp_id; subrow < 32; subrow+=warps) + { + for(int j = 0; j < ITEMS_PER_THREAD; j++) + { + + switch(FORMAT) + { + case COL32: + if(TRANSPOSE) + { + // data lies in shared memory in the following way: + // row0 [col0 col1 ... col31] + // row1 [col0 col1 ... col31] + // ... + // + // As such we read consequtive entries with 256 threads (8rows x 32 columns) + // as j increase, the row increase by a factor of 8 + // We load 8 rows per subrow loop, and subrow increase by 8 per loop + // so we have an offset of 8 rows every loop or (subrow/warps)*8 = (subrow/8)*8 + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size outRows*32 and base_row is done in increments of 32 + offset = base_row*outRows; + out[offset + (base_col + jrow + subrow_loop_row)*32 + threadIdx.x] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + offset = (base_col/32)*(32*rows); + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + out[offset+(base_row+subrow)*32 + ((j)*rows*32)+warp_lane] = data; + } + } + break; + case COL_TURING: + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*4 = 64 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + // + // [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] + if(TRANSPOSE) + { + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size 8*32 = 256 elements offset + // for each row offset of 8 we increaes the tile first + // after all rows are exhausted, we increase the col + int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/8)*256; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows + + // we increase by row_tile_column every 32 columns + // base_row increase in increments of 32 + //int row_tile_column = 256*outRows/8; // there are outRows/8 row tiles, and each tile is 256 elements + //int col_offset = (base_row/32)*row_tile_column; + // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 + // 256*outRows/8*base_row/32 = outRows*base_row + int col_offset = outRows*base_row; + + offset = row_offset+col_offset; + + // since we process even number of rows with each j (8) and with each subrow (8j) we can determine + // odd or even rows with the warp_id (each warp processes one row) + // the col is warp_lane (max 32 columns per row) and the row warp_id + if(warp_id % 2 == 1) + // odd + offset += 128 + (warp_lane/4)*16 + (warp_lane%4) + (((warp_id%8)-1)*2); + else + // even + offset += 0 + (warp_lane/4)*16 + (warp_lane%4) + ((warp_id%8)*2); + + out[offset] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + // set offset designates the tile offset among the 8*32 tiles + // we first increase rows and then columns. Since we load 128 columns at once + // we increase the offset by outRows*32 every 32 columns + // additionally, we increase the offset by 8*32=256 every 8 rows + offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/8)*256); // global offset (8x32 tile) + // first 4 rows are reserved for even rows, [0, 2, 4, 6], the next 4 for odd + // each of these has 32 values in total for 32*4 = 128 as offset if odd + // every set of 4 columns increases the total offset by 16 + // each even row increase the offset by 4, for example row 2 is offset by 4, 4 by 6 etc so: subrow/2*4 = subrow*2 + // this happends every 8 rows anew (subrow % 8) + // one writes 4 columns at once that is (col % 4) for the particular index in the subtile + int subcol = warp_lane; + + // add local offset (4x4 sub-tile) + if(subrow % 2 == 1) + // odd + offset += 128 + (subcol/4)*16 + (subcol%4) + (((subrow%8)-1)*2); + else + // even + offset += 0 + (subcol/4)*16 + (subcol%4) + ((subrow%8)*2); + + out[offset] = data; + } + } + break; + case COL_AMPERE: + // AMPERE FORMAT: + // 32*32 tiles with 8*32 subtiles. The rows are interleaved in pairs of two rows with offset of 8 between pairs of two rows: + // row idx (each number stands for 32 values): [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // the tiles are column-major ordered, so after 1024*1024 values we process: A[32:64, 0:32] + if(TRANSPOSE) + { + const int jrow = j*ITEMS_PER_THREAD; // 8 rows per j + const int subrow_loop_row = (subrow/warps)*ITEMS_PER_THREAD*ITEMS_PER_THREAD; // 8 rows per j; 8j per subrow loop (subrow/warps) + //const int local_row = warp_id; // each warp_id is one row + //const int block_row = base_col; // block offset for row + //const int local_col = warp_lane + //const int global_col = base_row; // block offset for col + if((base_col + subrow_loop_row + jrow + warp_id < outRows) && (base_row+warp_lane < rows)) + { + // each row hae 32 columns and is offset by 1 to prevent bank conflict during storage into smem + char data = smem_data[(subrow_loop_row + jrow + warp_id)*33 + warp_lane]; + + // each 32 columns we have new tile + // each tile has size 32*32 = 1024 elements offset + // for each row offset of 32 we increaes the tile first + // after all rows are exhausted, we increase the col + int row_offset = ((base_col+jrow+subrow_loop_row+warp_id)/32)*1024; // global_row+jrow+subrow_loop_row+local_row, increase tile(=256) every 8 rows + + // we increase by row_tile_column every 32 columns + // base_row increase in increments of 32 + //int row_tile_column = 1024*outRows/32; // there are outRows/32 row tiles, and each tile is 1024 elements + //int col_offset = (base_row/32)*row_tile_column; + // -> we can remove the divisions to speed up compute since outRows is always a multiple of 8 + // 1024*outRows/32*base_row/32 = outRows*base_row + int col_offset = outRows*base_row; + + offset = row_offset+col_offset; + + + // same as in the non-transpose case (see below) + // the difference is that now rows = cols + // in this case warp_id = subrow + + // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc + // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row + // every 2 rows, the offset increases by two [0, 1, 8, 9...] + // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] + int local_row = (jrow + warp_id) % 32; // offset for row > 32 is already calculated into row_offset + int ampere_row = ((local_row % 8)/2)*8 + (local_row/8)*2 + (local_row % 2); + + // global offset + row with 32 cols each + 32 cols per j + col_idx=warp_lane + out[offset + (ampere_row*32) + warp_lane] = data; + } + } + else + { + if(((base_row+subrow) < rows) && (base_col+(j*32)+warp_lane < outCols)) + { + char data = smem_data[(subrow*32*ITEMS_PER_THREAD) + (j*32) + warp_lane]; + + // set offset designates the tile offset among the 32*32 tiles + // we first increase rows and then columns. Since we load 128 columns at once + // we increase the offset by outRows*32 every 32 columns + // additionally, we increase the offset by 32*32=1024 every 32 rows + offset = ((base_col+(j*32))/32)*outRows*32 + (((base_row+subrow)/32)*1024); // global offset (32x32 tile) + + // [0 1 8 9 16 17 24 25] [2 3 10 11 18 19 26 27]... + // subrow % 8 -> [0,1] in tile0, [2, 3] in tile 1 etc + // subrow % 2 -> 0 for 1st row in the pair, 1 for the 2nd row + // every 2 rows, the offset increases by two [0, 1, 8, 9...] + // every 2 rows, the row index increase by 8 [0, 1, 8, 9...] + int local_row = ((subrow % 8)/2)*8 + (subrow/8)*2 + (subrow % 2); + + // global offset + row with 32 cols each + 32 cols per j + col_idx + out[offset + (local_row*32) + warp_lane] = data; + } + } + break; + } + } + } + } + } +} + +#define DENORM 1.0f/127.0f +#define MAX_SPARSE_COUNT 32 +#define SMEM_SIZE 8*256 +template +__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) +{ + + // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block + // If a block finishes, the next one is scheduled. Since the last blocks like have fewer + // elements they finish faster "fillin up" the gaps left by larger blocks + + // without tensor cores + // 1. use rowidx_length to find what to load (as many blocks as there are rows) + // 2. Load A into registers + // 3. each warp loads all required rows of B but each warp is offset by k + // 4. Do mma operations that accumulate into registers + // 5. Each warp stores its output row into matrix C + + const int count = max_count[blockIdx.x]; + const int local_max_idx = max_idx[blockIdx.x]; + const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1]; + const int local_row_idx = rowidx[offset]; + + const int warp_id = threadIdx.x / 32; + const int warp_idx = threadIdx.x % 32; + const int warp_offset = (warp_id*32)*SPMM_ITEMS; + const int num_items = BITS == 8 ? 8 : 8; + int idx_col_B = warp_offset; + int local_idx_col_B_offset = 0; + + half local_valA[MAX_SPARSE_COUNT]; + int local_colidxA[MAX_SPARSE_COUNT]; + half local_valC[SPMM_ITEMS]; + T local_valsB[num_items]; + half local_valOut[num_items]; + // 128 byte loads per warp == 4 bytes per thread + + // 2. Load A into registers + for(int j = 0; j < MAX_SPARSE_COUNT; j++) + { + local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f); + local_colidxA[j] = j < count ? colidx[offset+j] : 0; + } + + // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192 + // we expect each warp to be SPMM_ITEMS*32 apart + // we have a total of 128 bytes for the bank with a bank size of 4 bytes + // added 3 bytes = 6 values between warps should reduce bank conflicts + __shared__ half smem_dequant_stats[SMEM_SIZE]; + + + while(idx_col_B < colsB) + { + + if(dequant_stats != NULL) + { + for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x) + if((idx_col_B+i-local_idx_col_B_offset) < colsB) + smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset]; + + __syncthreads(); + } + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j++) + local_valC[j] = 0.0f; + + #pragma unroll + for(int i = 0; i < count; i++) + { + // 3. each warp loads all required rows of B but each warp is offset by k + int row_offset = colsB*local_colidxA[i]; + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + // 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes it reached + int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j; + if(idx >= colsB){ break; } + if((idx+num_items < colsB)) + { + if(BITS == 8) + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + else + reinterpret_cast(local_valsB)[0] = reinterpret_cast(B)[(row_offset+ idx)/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx+k < colsB) + local_valsB[k] = B[row_offset+idx+k]; + else + local_valsB[k] = 0.0f; + } + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + { + if(BITS == 8 && dequant_stats != NULL) + // we do texture cache reads (__ldg) on dequant_stats which should be super fast + { + float valB = local_valsB[k]; + float valA = local_valA[i]; + if(valB != 0.0 && valA != 0.0) + local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA; + } + else + local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i]; + } + } + } + + int idx_row_C = (colsB*local_row_idx); + + #pragma unroll SPMM_ITEMS + for(int j = 0; j < SPMM_ITEMS; j+=num_items) + { + //int idx_col_C = idx_col_B + (32*j) + warp_idx; + int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j; + int idx_val = idx_col_C + idx_row_C; + + if(idx_col_C +num_items < colsB) + { + + // load outputs to do inplace addition + reinterpret_cast(local_valOut)[0] = reinterpret_cast(out)[idx_val/num_items]; + + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k]; + + reinterpret_cast(out)[idx_val/num_items] = reinterpret_cast(local_valC)[j/num_items]; + } + else + { + #pragma unroll num_items + for(int k = 0; k < num_items; k++) + if(idx_col_C + k < colsB) + out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k]; + } + } + + idx_col_B += blockDim.x*SPMM_ITEMS; + local_idx_col_B_offset += blockDim.x*SPMM_ITEMS; + } +} + +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA) +{ + int local_colidx = idx[blockIdx.x]; + + if(FORMAT==COL_TURING) + { + // TURING FORMAT: + // 8*32 tiles with 4*4 subtiles + // the 8*32 subtile has first all 4*4 subtiles of even rows (max 4*4*8 = 128 elements) + // the subsequent 4*4 subtiles are for all odd rows if some rows columns are empty the values are zero + // the tile repeats again after the 8*32 tile in a major column order, meaning: (next 8 rows are A[8:16, 0:32]) + // the next tile is the next 8 rows for the same 32 columns. Once all rows are finished, the column + // index increases by 32 + // columns are grouped in increments of 4, meaning that one has the following rows and columns + // rows: [0 0 0 0, 2 2 2 2, 4 4 4 4, 6 6 6 6, 0 0 0 0 ...] + // cols: [0 1 2 3, 0 1 2 4, 0 1 2 3, 0 1 2 3, 4 5 6 7 ...] + + // each thread reads 1 element = 1 row + for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) + { + int offset_per_col_tile = ((rowsA+7)/8)*32*8; + int tile_offset_rows = (row/8)*32*8; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int offset = 0; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 8; + if(row % 2 == 1) + offset += 128 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + ((subtile_row_idx-1)*2); + else + // even + offset += 0 + (subtile_col_idx/4)*16 + (subtile_col_idx%4) + (subtile_row_idx*2); + + offset += tile_offset_rows + tile_offset_cols; + + char val = A[offset]; + + int out_idx = (row*idx_size) + blockIdx.x; + out[out_idx] = val; + } + } + else if(FORMAT == COL_AMPERE) + { + + for(int row = threadIdx.x; row < rowsA; row+= blockDim.x) + { + // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element + // within each tile. + int offset_per_col_tile = ((rowsA+31)/32)*32*32; + int tile_offset_rows = (row/32)*32*32; + int tile_offset_cols = (local_colidx/32)*offset_per_col_tile; + int subtile_col_idx = local_colidx%32; + int subtile_row_idx = row % 32; + // this magic is taken from the cublasLt doc (search for COL32) + int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx; + offset += tile_offset_cols + tile_offset_rows; + + char val = A[offset]; + int out_idx = (row*idx_size) + blockIdx.x; + out[out_idx] = val; + } + } +} + + +//template __global__ void kMatmul_inference_4bit(INPT *A, unsigned char *B, OUTT *out, int lda, int ldb, int rowsA, int colsA, int colsB) +//{ +//// element-wise kernel +//// 1. Load batch x k into registers +//// 2. Load k x k into registers +//// 3. dequantize and store in second pair of k x k +//// 4. matmul +//// 5. sum with cub +//// 6. store outputs +//// TC kernel +//// use k warps per thread block +//// 1. threadblock use read-only cache to read in register tile for A into shared memory +//// 2. each warp loops over shared memory tiles of A of size 8x16 and loads them into fragments +//// 3. each warp reads a segment of values 16x32 from B +//// 4. do dequantization from register of B into second pair of registers +//// 5. store (4) into fragment +//// 6. matmul aggregate into fragment C +//// 7. aggreecate files of C into shared memroy block C +//// 8. sum (7) +//// 9. write outputs to matmul output matrix +//} + +template __device__ inline void vector_load(T *local, T * __restrict__ const buffer, int idx, int limit_base, int limit, float zero_value = 0.0f) +{ + if(limit_base + ITEMS <= limit) + reinterpret_cast(local)[0] = reinterpret_cast(buffer)[idx/ITEMS]; + else + { + for(int k = 0; k < ITEMS; k++) + { + if(limit_base + k < limit) + local[k] = buffer[idx+k]; + else + local[k] = (T)zero_value; + } + } +} + +#define WARPS 5 +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) +{ + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + const int val_per_iter = blockDim.x-32; + + T local_A[4]; + T local_B[128]; + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + //__shared__ T smem_C[8*32]; + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + + __syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + //local_A[0] = A[idx]; + + //#pragma unroll 32 + //for(int col = 0; col < 32; col++) + // local_B[col] = B[(col_offset+col)*ldb+idx]; + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+(1*val_per_iter)]; + local_A[2] = A[idx+(2*val_per_iter)]; + local_A[3] = A[idx+(3*val_per_iter)]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B[col] = B[(col_offset+col)*ldb+idx]; + local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; + local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; + local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; + } + loaded_values = 3; + + } + else + { + + if(loaded_values == 3) + { + local_A[0] = local_A[1]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(32)]; + } + else if(loaded_values == 2) + { + local_A[0] = local_A[2]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(64)]; + } + else + { + local_A[0] = local_A[3]; + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = local_B[col+(96)]; + } + loaded_values--; + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_A[warp_lane]; +} + +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int col_offset = blockIdx.x *32; + const int warp_id = threadIdx.x / 32; + const int half_warp_id = threadIdx.x / 16; + const int half_warp_lane = threadIdx.x % 16; + const int batch_size_warps = (WARPS-1)*2; + + T local_A[2]; + T local_B[64]; + unsigned char local_B_4bit[32]; + + const int a_tile_offset = 16; + const int b_tile_offset = (16*32 + 16); + + __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; + __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; + //__shared__ T smem_C[8*32]; + + wmma::fragment a_frag; + wmma::fragment b_frag; + wmma::fragment c_frag; + wmma::fill_fragment(c_frag, 0.0f); + + int ticktock = 0; + int idx = 0 + threadIdx.x; + int loaded_values = 0; + // prefetch + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f); + local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f); + } + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) + { + idx = base_idx + threadIdx.x; + + __syncthreads(); + if(idx < K && warp_id < (WARPS-1)) + { + if(loaded_values == 0) + { + local_A[0] = A[idx]; + local_A[1] = A[idx+blockDim.x-32]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + { + local_B_4bit[col] = B[(col_offset+col)*ldb+idx]; + local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx]; + } + + loaded_values = 1; + } + else + { + local_A[0] = local_A[1]; + loaded_values--; + + int absidx = (idx + col_offset)/blocksize; + half local_absmax = __ldg(&(absmax[absidx])); + + #pragma unroll 64 + for(int col = 0; col < 64; col+=2) + { + local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx); + local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx); + } + } + + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; + } + else if(warp_id < (WARPS-1)) + { + local_A[0] = T(0.0); + smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + local_B[col] = 0.0f; + + #pragma unroll 32 + for(int col = 0; col < 32; col++) + smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; + } + ticktock = ticktock == 0 ? 1 : 0; + + if(warp_id == (WARPS-1)) + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + } + + __syncthreads(); + if(warp_id != (WARPS-1)){ return; } + // only warp_id == (WARPS-1) from here + int warp_lane = threadIdx.x % 32; + + ticktock = ticktock == 0 ? 1 : 0; + for(int k = 0; k < batch_size_warps; k++) + { + wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu + wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu + wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); + } + + // 129 mu + if(warp_id == (WARPS-1)) + wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); + + if(col_offset + warp_lane < M) + out[col_offset + warp_lane] = smem_A[warp_lane]; +} + +//#define ROWS 2 +//template __global__ void gemm_device(int M, int N, int K, T const* A, T* B, T * out, int lda, int ldb, int ldc) +//{ +//// 0. We want to fill a 8x128 tile for a thread block so we have 8x16 tile for each warp +//// 1. Load dataB into register +//// 2. Dequantize B +//// 3. Fetch data from A and multiply +// +// typedef cub::BlockLoad LoadA; +// //__shared__ typename LoadA::TempStorage loada; +// typedef cub::BlockLoad LoadB; +// //__shared__ typename LoadB::TempStorage loadb; +// typedef cub::BlockReduce BlockReduce; +// // Allocate shared memory for BlockReduce +// //__shared__ typename BlockReduce::TempStorage reduce; +// +// __shared__ union { +// typename BlockReduce::TempStorage reduce; +// typename LoadB::TempStorage loadb; +// typename LoadA::TempStorage loada; +// } temp_storage; +// +// +// T dataA[ITEMS]; +// T local_B[ITEMS]; +// T local_accC[ROWS]; +// int valid_items = 0; +// const int col_offset = blockIdx.x * 8; +// +// __shared__ T tileA[ROWS*THREADS*ITEMS]; +// __shared__ T accumulatorC[ROWS*8]; +// +// //#pragma unroll 8 +// //for(int i = 0; i < 8; i++) +// // tileA[threadIdx.x + (i*256)] = 0.0f; +// //__syncthreads(); +// if(threadIdx.x < 64) +// accumulatorC[threadIdx.x] = 0.0f; +// __syncthreads(); +// +// +// for(int inner_idx = 0; inner_idx < K; inner_idx+= THREADS*ITEMS) +// { +// valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx; +// int baserow = 0; +// for(int row = baserow; row < (baserow+ROWS) && row < N; row++) +// { +// LoadA(temp_storage.loada).Load(&(A[(row*K) + inner_idx]), dataA, valid_items, 0.0f); +// +// #pragma unroll ITEMS +// for(int k = 0; k < ITEMS; k++) +// tileA[row*THREADS*ITEMS + threadIdx.x + (k*THREADS)] = dataA[k]; +// +// __syncthreads(); +// } +// baserow += ROWS; +// +// // load 16 columns from B at a time. B is transposed, so its like loading rows +// // each warp loads one row +// // each thread loads 128 byte +// +// // col: inner_idx + warp_lane +// // row: ldb*(offset + warp_id) +// for(int col = 0; col < 8 && (col_offset + col) < M; col++) +// { +// int colB = col_offset + col; +// +// for(int k = 0; k < ROWS; k++) +// local_accC[k] = 0.0f; +// +// int base_idxB = ldb*colB; +// valid_items = K - inner_idx > THREADS*ITEMS ? THREADS*ITEMS : K - inner_idx; +// LoadB(temp_storage.loadb).Load(&(B[base_idxB + inner_idx]), local_B, valid_items, 0.0f); +// __syncthreads(); +// +// for(int row = 0; row < ROWS && row < N; row++) +// { +// #pragma unroll ITEMS +// for(int k = 0; k < ITEMS; k++) +// { +// int idxA = row*THREADS*ITEMS + threadIdx.x + (THREADS*k); +// local_accC[row] += tileA[idxA]*local_B[k]; +// } +// +// local_accC[row] = BlockReduce(temp_storage.reduce).Reduce(local_accC[row], cub::Sum()); +// if(threadIdx.x == 0) +// atomicAdd(&accumulatorC[row*8 + col], local_accC[row]); +// } +// } +// } +// +// for(int row = 0; row < ROWS && row < N; row++) +// { +// int out_idx = ldc*row + col_offset; +// +// //if(threadIdx.x < 8) +// // if(accumulatorC[row*8 + threadIdx.x] != 0.0) +// // printf("%i %i %i %i %f idx %i %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx, blockIdx.x); +// +// if(threadIdx.x < 8 && (col_offset + threadIdx.x) < M) +// { +// //printf("%i %i %i %i %f idx %i %i\n", row, col_offset, threadIdx.x, N, accumulatorC[row*8 + threadIdx.x], ldc, out_idx); +// out[out_idx + threadIdx.x] = accumulatorC[row*8 + threadIdx.x]; +// } +// } +// +// +// +//} + + +template __global__ void kfunc(T *A, T *B, T value, long n) +{ + for(long i = (blockDim.x*blockIdx.x) + threadIdx.x; i < n; i+=(blockDim.x*gridDim.x)) + { + switch(FUNC) + { + case FILL: + A[i] = (T)value; + break; + case ARANGE: + A[i] = (T)i; + break; + case _MUL: + A[i] = A[i]*B[i]; + break; + } + } +} + + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(unsigned char *A, unsigned char *B, unsigned char value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); +template __global__ void kfunc(float *A, float *B, float value, long n); + +// these are not used and make no sense, but the compiler needs them +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +// these are not used and make no sense, but the compiler needs them + +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +//template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); + +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); + +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB); + +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL32>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); +template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); + +template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); + +template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); +template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); + +template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x); +template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x); + +template __global__ void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n); +template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *code, const float offset, const half max_val, const int n); + +#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \ +template __global__ void kPreconditionOptimizer32bit1State(gtype* g, gtype* p, \ + float* state1, float *unorm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const int n); \ + +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half) +MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, half) +MAKE_PreconditionOptimizer32bit1State(RMSPROP, float) +MAKE_PreconditionOptimizer32bit1State(LION, half) +MAKE_PreconditionOptimizer32bit1State(LION, float) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half) +MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float) + +#define MAKE_Optimizer32bit1State(oname, gtype) \ +template __global__ void kOptimizer32bit1State(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_Optimizer32bit1State(MOMENTUM, half) +MAKE_Optimizer32bit1State(MOMENTUM, float) +MAKE_Optimizer32bit1State(RMSPROP, half) +MAKE_Optimizer32bit1State(RMSPROP, float) +MAKE_Optimizer32bit1State(LION, half) +MAKE_Optimizer32bit1State(LION, float) +MAKE_Optimizer32bit1State(ADAGRAD, half) +MAKE_Optimizer32bit1State(ADAGRAD, float) + +#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ +template __global__ void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ + float* state1, float* state2, float *unorm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const int n); \ + +MAKE_PreconditionOptimizer32bit2State(ADAM, float) +MAKE_PreconditionOptimizer32bit2State(ADAM, half) +MAKE_PreconditionOptimizer32bit2State(ADAM, __nv_bfloat16) + +template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); +template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +#define MAKE_PreconditionStatic8bit1State(oname, gtype) \ +template __global__ void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ + float *unorm, \ + const float beta1, \ + const float beta2, \ + const float eps, const int step, \ + float* __restrict__ const quantiles1, \ + float* max1, float* new_max1, \ + const float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_PreconditionStatic8bit1State(MOMENTUM, half) +MAKE_PreconditionStatic8bit1State(MOMENTUM, float) +MAKE_PreconditionStatic8bit1State(RMSPROP, half) +MAKE_PreconditionStatic8bit1State(RMSPROP, float) +MAKE_PreconditionStatic8bit1State(LION, half) +MAKE_PreconditionStatic8bit1State(LION, float) + +#define MAKE_optimizerStatic8bit1State(oname, gtype) \ +template __global__ void kOptimizerStatic8bit1State(gtype* p, gtype* const g, unsigned char* state1, \ + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, \ + const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, \ + float* max1, float* new_max1, \ + float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_optimizerStatic8bit1State(MOMENTUM, half) +MAKE_optimizerStatic8bit1State(MOMENTUM, float) +MAKE_optimizerStatic8bit1State(RMSPROP, half) +MAKE_optimizerStatic8bit1State(RMSPROP, float) +MAKE_optimizerStatic8bit1State(LION, half) +MAKE_optimizerStatic8bit1State(LION, float) + +#define MAKE_PreconditionStatic8bit2State(oname, gtype) \ +template __global__ void kPreconditionOptimizerStatic8bit2State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \ + float *unorm, \ + const float beta1, const float beta2, \ + const float eps, const int step, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_PreconditionStatic8bit2State(ADAM, half) +MAKE_PreconditionStatic8bit2State(ADAM, float) + +#define MAKE_optimizerStatic8bit2State(oname, gtype) \ +template __global__ void kOptimizerStatic8bit2State(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \ + const float *unorm, const float max_unorm, const float param_norm, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, \ + const float gnorm_scale, \ + const int n); \ + +MAKE_optimizerStatic8bit2State(ADAM, half) +MAKE_optimizerStatic8bit2State(ADAM, float) + +template __global__ void kPercentileClipping(float * __restrict__ g, float *gnorm_vec, int step, const int n); +template __global__ void kPercentileClipping(half * __restrict__ g, float *gnorm_vec, int step, const int n); + +#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \ +template __global__ void kQuantizeBlockwise(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \ + +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4) +MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4) +MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4) +MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4) + +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n); + +#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template __global__ void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ + float* absmax1, float* absmax2, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8) +MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 2048, 8) + + +#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \ +template __global__ void kOptimizerStatic8bit1StateBlockwise( \ + gtype* p, gtype* __restrict__ const g, unsigned char* state1, \ + const float beta1, const float beta2, \ + const float eps, const int step, const float lr, \ + float* __restrict__ const quantiles1, \ + float* absmax1, \ + float weight_decay, \ + const float gnorm_scale, const bool skip_zeros, const int n); \ + +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8) +MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8) diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh new file mode 100644 index 000000000..30faf4a80 --- /dev/null +++ b/csrc/kernels.cuh @@ -0,0 +1,130 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include + +#ifndef kernels +#define kernels + +//template __global__ void kMatmul_inference_4bit(INP_TYPE *A, unsigned char *B, OUT_TYPE *out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n); + +__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n); +__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n); + +template __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); +template __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n); + +template +__global__ void kPreconditionOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n); + +template +__global__ void kOptimizer32bit2State(T* g, T* p, + float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +template +__global__ void kPreconditionOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const int n); + +template +__global__ void kOptimizer32bit1State(T* g, T* p, + float* state1, float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +template +__global__ void +kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + const float weight_decay, + const float gnorm_scale, const int n); + + +template +__global__ void +kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* max1, float* new_max1, + float weight_decay, const float gnorm_scale, const int n); + + + +template +__global__ void +kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, + float *unorm, + const float beta1, const float beta2, + const float eps, const int step, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + const float gnorm_scale, const int n); + + +template +__global__ void +kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2, + const float *unorm, const float max_unorm, const float param_norm, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, const float gnorm_scale, const int n); + +template __global__ void kOptimizerStatic8bit2StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, + const float beta1, const float beta2, const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, + float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); + +template __global__ void kOptimizerStatic8bit1StateBlockwise( + T* p, T* __restrict__ const g, unsigned char* state1, + const float beta1, const float beta2, + const float eps, const int step, const float lr, + float* __restrict__ const quantiles1, + float* absmax1, + float weight_decay, + const float gnorm_scale, const bool skip_zeros, const int n); + + +template __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n); + +__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n); + + +template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); + +template __global__ void kdequant_mm_int32_fp16( + int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, + half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); + +template __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); +template __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); + +template __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); + +template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); + +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); +template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); + +template __global__ void kfunc(T *A, T *B, T value, long n); + +#endif diff --git a/csrc/ops.cu b/csrc/ops.cu new file mode 100644 index 000000000..7f3a83152 --- /dev/null +++ b/csrc/ops.cu @@ -0,0 +1,846 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include + + +using namespace BinSearch; +using std::cout; +using std::endl; + +void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n) +{ + int threads = 512; + int num_blocks = n/threads; + num_blocks = n % threads == 0 ? num_blocks : num_blocks + 1; + kHistogramScatterAdd2D<<>>(histogram, index1, index2, src, maxidx1, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +template void estimateQuantiles(T *A, float *code, float offset, int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + CUDA_CHECK_RETURN(cudaMemset(code, 0, 256*sizeof(float))); + kEstimateQuantiles<<>>(A, code, offset, std::numeric_limits::max(), n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +void quantize(float *code, float *A, unsigned char *out, int n) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + kQuantize<<>>(code, A, out, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +void dequantize(float *code, unsigned char *A, float *out, int n) +{ + int num_blocks = n/1024; + num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; + kDequantize<<>>(code, A, out, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n) +{ + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + + if(blocksize == 4096) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 2048) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 1024) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 512) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 256) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 128) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + else if(blocksize == 64) + kQuantizeBlockwise<<>>(code, A, absmax, out, rand, rand_offset, n); + + + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) +{ + int num_blocks = n/blocksize; + num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; + int tile_size = (DATA_TYPE > 0) ? 1024 : 512; + + if(DATA_TYPE > 0) + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n); + else + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n); + + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + + +//void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB) +//{ +// int num_blocks = (colsB+32-1)/32; +// kMatmul_inference_4bit<<>>(A, B, out, lda, ldb, rowsA, colsA, colsB); +// CUDA_CHECK_RETURN(cudaPeekAtLastError()); +//} + + +template void optimizer32bit(T* g, T* p, + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, + const float beta1, const float beta2, const float eps, const float weight_decay, + const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + switch(OPTIMIZER) + { + case ADAM: + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); + kPreconditionOptimizer32bit2State<<>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + kOptimizer32bit2State<<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); + kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + + kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case LION: + // in lion, the momentum update after the parameter update + kOptimizer32bit1State<<>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + + if(max_unorm > 0.0f) + { + CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); + kPreconditionOptimizer32bit1State<<>>(g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + break; + } +} + +template void optimizerStatic8bit(T* p, T* g, + unsigned char* state1, unsigned char* state2, + float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, + float eps, int step, float lr, + float* quantiles1, float* quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, int n) +{ + int num_blocks = n/4096; + num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1; + + if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); } + + switch(OPTIMIZER) + { + case ADAM: + CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); + CUDA_CHECK_RETURN(cudaMemset(new_max2, 0, 1*sizeof(float))); + kPreconditionOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + kOptimizerStatic8bit2State<<>>(p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); + kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case LION: + // in lion, the momentum update happens after the parameter update + kOptimizerStatic8bit1State<<>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, + quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + + CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); + kPreconditionOptimizerStatic8bit1State<<>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + default: + break; + } +} + +#define BLOCKSIZE_2STATE 2048 +#define NUM_2STATE 8 +#define BLOCKSIZE_1STATE 2048 +#define NUM_1STATE 8 + +template void optimizerStatic8bitBlockwise(T* p, T* g, + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) +{ + + int num_blocks = 0; + switch(OPTIMIZER) + { + case ADAM: + num_blocks = n/BLOCKSIZE_2STATE; + num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; + kOptimizerStatic8bit2StateBlockwise<<>>(p, g, state1, state2, beta1, beta2, eps, step, lr, + quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + case MOMENTUM: + case RMSPROP: + case ADAGRAD: + case LION: + num_blocks = n/BLOCKSIZE_1STATE; + num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1; + kOptimizerStatic8bit1StateBlockwise<<>>(p, g, state1, beta1, beta2, eps, step, lr, + quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + break; + } +} + + + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n) +{ + int num_blocks = n/2048; + num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1; + CUDA_CHECK_RETURN(cudaMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float))); + kPercentileClipping<<>>(g, gnorm_vec, step, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) +{ + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + cublasStatus_t status; + + status = cublasGemmEx(context->m_handle, + transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, + transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, + m, n, k, + alpha, A, CUDA_R_8I, lda, B, CUDA_R_8I, ldb, beta, + C, CUDA_R_32I, ldc, + CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP); + + if (status != CUBLAS_STATUS_SUCCESS) + { + std::cout << "CUBLAS ERROR: Status " << status << std::endl; + } + +} + +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount) +{ + const int falpha = 1; + const int fbeta = 0; + const void * alpha = &falpha; + const void * beta = &fbeta; + cublasStatus_t status; + + //cout << transposeA << transposeB << endl; + //printf("%i %i %i\n", m,n,k); + //printf("%i %i %i\n", lda,ldb,ldc); + //printf("%i %i %i\n", strideA, strideB, strideC); + //printf("%i\n", batchCount); + + status = cublasGemmStridedBatchedEx(context->m_handle, + transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, + transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, + m, n, k, + alpha, A, CUDA_R_8I, lda, (long long int)strideA, B, CUDA_R_8I, ldb, (long long int)strideB, beta, + C, CUDA_R_32I, ldc, (long long int)strideC, batchCount, + CUDA_R_32I, CUBLAS_GEMM_DEFAULT); + + if (status != CUBLAS_STATUS_SUCCESS) + { + std::cout << "CUBLAS ERROR: Status " << status << std::endl; + } + +} + +int roundoff(int v, int d) { + return (v + d - 1) / d * d; +} + + +#ifdef NO_CUBLASLT +#else +template cublasLtOrder_t get_order() +{ + switch(ORDER) + { + case ROW: + return CUBLASLT_ORDER_ROW; + break; + case COL: + return CUBLASLT_ORDER_COL; + break; + case COL32: + return CUBLASLT_ORDER_COL32; + break; + case COL_TURING: + return CUBLASLT_ORDER_COL4_4R2_8C; + break; + case COL_AMPERE: + return CUBLASLT_ORDER_COL32_2R_4R4; + break; + default: + break; + } + + return CUBLASLT_ORDER_ROW; +} + +template cublasLtOrder_t get_order(); +template cublasLtOrder_t get_order(); +template cublasLtOrder_t get_order(); +template cublasLtOrder_t get_order(); +template cublasLtOrder_t get_order(); +#endif + + +template int get_leading_dim(int dim1, int dim2) +{ + switch(ORDER) + { + case ROW: + return dim2; + break; + case COL: + return dim1; + break; + case COL32: + // 32*row tiles + return dim1*32; + break; + case COL_TURING: + return 32*roundoff(dim1, 8); + break; + case COL_AMPERE: + // 32*32 tiles + return 32*roundoff(dim1, 32); + break; + default: + return 0; + break; + } +} + +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); +template int get_leading_dim(int dim1, int dim2); + +template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2) +{ +#ifdef NO_CUBLASLT +#else + cublasLtOrder_t orderA = get_order(); + cublasLtOrder_t orderOut = get_order(); + int ldA = get_leading_dim(dim1, dim2); + int ldOut = get_leading_dim(dim1, dim2); + + cublasLtMatrixLayout_t A_desc = NULL, out_desc = NULL; + cublasLtMatrixTransformDesc_t A2Out_desc = NULL; + cublasOperation_t opTranspose = CUBLAS_OP_T; + float transformAlpha = 1.0f, transformBeta = 0.0f; + + + if(DTYPE == 8) + { + checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_8I, dim1, dim2, ldA)); + checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_8I, dim1, dim2, ldOut)); + } + else if(DTYPE == 32) + { + checkCublasStatus(cublasLtMatrixLayoutCreate(&A_desc, CUDA_R_32I, dim1, dim2, ldA)); + checkCublasStatus(cublasLtMatrixLayoutCreate(&out_desc, CUDA_R_32I, dim1, dim2, ldOut)); + } + else + { + printf("ERROR WRONG TYPE FOR TRANSFORM: %i\n", DTYPE); + } + + checkCublasStatus(cublasLtMatrixLayoutSetAttribute(A_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderA, sizeof(orderA))); + checkCublasStatus(cublasLtMatrixLayoutSetAttribute(out_desc, CUBLASLT_MATRIX_LAYOUT_ORDER, &orderOut, sizeof(orderOut))); + + checkCublasStatus(cublasLtMatrixTransformDescCreate(&A2Out_desc, CUDA_R_32F)); + + if(transpose){ checkCublasStatus(cublasLtMatrixTransformDescSetAttribute(A2Out_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &opTranspose, sizeof(opTranspose))); } + + checkCublasStatus(cublasLtMatrixTransform(ltHandle, A2Out_desc, &transformAlpha, A, A_desc, &transformBeta, NULL, NULL, out, out_desc, 0)); + + if (A_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(A_desc)); + if (out_desc) checkCublasStatus(cublasLtMatrixLayoutDestroy(out_desc)); + if (A2Out_desc) checkCublasStatus(cublasLtMatrixTransformDescDestroy(A2Out_desc)); +#endif +} + +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int8_t *A, int8_t *out, int dim1, int dim2); +template void transform(cublasLtHandle_t ltHandle, int32_t *A, int32_t *out, int dim1, int dim2); + +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) +{ +#ifdef NO_CUBLASLT + cout << "" << endl; + cout << "=============================================" << endl; + cout << "ERROR: Your GPU does not support Int8 Matmul!" << endl; + cout << "=============================================" << endl; + cout << "" << endl; + assert(false); + + return 0; +#else + int has_error = 0; + cublasLtMatmulDesc_t matmulDesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; + cublasOperation_t opT = CUBLAS_OP_T; + cublasLtPointerMode_t alphaVec = CUBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_ZERO; + cublasLtOrder_t col32 = CUBLASLT_ORDER_COL32; + cublasLtOrder_t col_turing = CUBLASLT_ORDER_COL4_4R2_8C; + cublasLtOrder_t col_ampere = CUBLASLT_ORDER_COL32_2R_4R4; + + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_8I, m, k, lda)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_8I, n, k, ldb)); + + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + if(FORMATB == COL_TURING) + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_turing, sizeof(col_turing))); + else + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col_ampere, sizeof(col_ampere))); + + if(DTYPE_OUT == 32) + { + has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32I)); + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_32I, m, n, ldc)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + int alpha = 1, beta = 0; + has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, NULL, NULL, 0, 0)); + } + else + { + has_error |= checkCublasStatus(cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F)); + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT))); + has_error |= checkCublasStatus(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_8I, m, n, ldc)); + has_error |= checkCublasStatus(cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32))); + if(!SCALE_ROWS) + { + float alpha = 1.0f, beta = 0.0f; + has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0)); + } + else + { + has_error |= checkCublasStatus(cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &alphaVec, sizeof(alphaVec))); + has_error |= checkCublasStatus(cublasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, NULL, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, NULL, NULL, 0, 0)); + } + } + + + if (Cdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Cdesc)); + if (Bdesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Bdesc)); + if (Adesc) has_error |= checkCublasStatus(cublasLtMatrixLayoutDestroy(Adesc)); + if (matmulDesc) has_error |= checkCublasStatus(cublasLtMatmulDescDestroy(matmulDesc)); + if(has_error == 1) + printf("error detected"); + + return has_error; +#endif +} + +int fill_up_to_nearest_multiple(int value, int multiple) +{ + return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); +} + +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols) +{ + int threads = 512; + int tileCols = fill_up_to_nearest_multiple(numCols, 32); + int n = numRows*tileCols; + int subtile_rows = 128; + int tilesize = 32*subtile_rows; + int num_blocks = numRows/subtile_rows; + num_blocks += (numRows % subtile_rows == 0) ? 0 : 1; + num_blocks = num_blocks*(tileCols/32); + assert(threads <= tilesize); + + kdequant_mm_int32_fp16<4, 128, 512><<>>(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +#define STATS_THREADS 64 +#define STATS_ITEMS 4 +#define STATS_ROWS 16 +void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) +{ + int tile_cols = STATS_THREADS*STATS_ITEMS; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, STATS_ROWS); + int row_tiles = (tiledRows/STATS_ROWS); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + + if(nnz_threshold == 0.0) + kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); + else if(nnz_threshold != 0.0) + kgetColRowStats<<>>(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols, tiledRows, tiledCols); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + +} + +void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols) +{ + int threads = 64; + int items_per_thread = 4; + int tile_cols = threads*items_per_thread; + int tile_rows = 16; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + + + if(threshold > 0.0f) + kDoubleRowColQuant<64, 4, 16, 64*4, 1><<>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); + else + kDoubleRowColQuant<64, 4, 16, 64*4, 0><<>>(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_block_ptr, threshold, rows, cols, tiledCols); + + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +template void transformRowToFormat(char * A, char *out, int rows, int cols) +{ + int threads = 256; + int items_per_thread = 8; + // we load 128 column values per warp + int tile_cols = 32*items_per_thread; + int tile_rows = 32; + int tiledCols = fill_up_to_nearest_multiple(cols, tile_cols); + int tiledRows = fill_up_to_nearest_multiple(rows, tile_rows); + int row_tiles = (tiledRows/tile_rows); + int col_tiles = (tiledCols/tile_cols); + row_tiles = row_tiles > 0 ? row_tiles : 1; + col_tiles = col_tiles > 0 ? col_tiles : 1; + int num_blocks = row_tiles * col_tiles; + + int outCols = fill_up_to_nearest_multiple(cols, 32); + int outRows = fill_up_to_nearest_multiple(rows, 32); + if(FORMAT == COL_TURING) + { + if(TRANSPOSE) + outRows = fill_up_to_nearest_multiple(cols, 8); + else + outRows = fill_up_to_nearest_multiple(rows, 8); + } + else if(FORMAT == COL_AMPERE) + { + if(TRANSPOSE) + outRows = fill_up_to_nearest_multiple(cols, 32); + else + outRows = fill_up_to_nearest_multiple(rows, 32); + } + else + { + if(TRANSPOSE) + { + outCols = fill_up_to_nearest_multiple(rows, 32); + outRows = cols; + } + } + + kTransformRowToFormat<256, 8, 32, 32*8, TRANSPOSE, FORMAT><<>>(A, out, rows, cols, tiledCols, outRows, outCols); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) +{ + +#ifdef NO_CUBLASLT +#else + + cusparseSpMatDescr_t descA; + cusparseDnMatDescr_t descB, descC; + + float alpha = 1.0f; + float beta = 0.0f; + void *dBuffer = NULL; + size_t bufferSize = 0; + + CHECK_CUSPARSE( cusparseCreateCoo(&descA, A_rows, A_cols, A_nnz, + A_rowidx, A_colidx, A_vals, + CUSPARSE_INDEX_32I, + CUSPARSE_INDEX_BASE_ZERO, CUDA_R_16F) ); + // Create dense matrix C + CHECK_CUSPARSE( cusparseCreateDnMat(&descC, A_rows, B_cols, ldc, C, + CUDA_R_16F, CUSPARSE_ORDER_ROW) ); + // Create dense matrix B + if(transposed_B) + { + int tmp = A_cols; + A_cols = B_cols; + B_cols = tmp; + } + + CHECK_CUSPARSE( cusparseCreateDnMat(&descB, A_cols, B_cols, ldb, B, + CUDA_R_16F, CUSPARSE_ORDER_ROW) ); + // allocate an external buffer if needed + CHECK_CUSPARSE( cusparseSpMM_bufferSize( + handle, + CUSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, CUDA_R_32F, + CUSPARSE_SPMM_ALG_DEFAULT, &bufferSize) ); + CUDA_CHECK_RETURN( cudaMalloc(&dBuffer, bufferSize) ); + + // execute SpMM + CHECK_CUSPARSE( cusparseSpMM(handle, + CUSPARSE_OPERATION_NON_TRANSPOSE, + transposed_B ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE, + &alpha, descA, descB, &beta, descC, CUDA_R_32F, + CUSPARSE_SPMM_ALG_DEFAULT, dBuffer)); + + // destroy matrix/vector descriptors + CHECK_CUSPARSE( cusparseDestroySpMat(descA) ); + CHECK_CUSPARSE( cusparseDestroyDnMat(descB) ); + CHECK_CUSPARSE( cusparseDestroyDnMat(descC) ); + CUDA_CHECK_RETURN( cudaFree(dBuffer) ); +#endif +} + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) +{ + + kspmm_coo_very_sparse_naive<<>>(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + + +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols) +{ + int threads = 256; + // we load 128 column values per warp + int tiledCols = tiledCols = fill_up_to_nearest_multiple(cols, 32); + int tiledRows = 0; + + int num_blocks = idx_size; + + if(FORMAT == COL_TURING) + { + tiledRows = fill_up_to_nearest_multiple(rows, 8); + } + else if(FORMAT == COL_AMPERE) + { + tiledRows = fill_up_to_nearest_multiple(rows, 32); + } + + kExtractOutliers<<>>(A, idx, out, idx_size, rows, cols, tiledRows, tiledCols); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + + + + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) +{ + + int num_blocks = (m+31)/32; + + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + //if(bits == 32) + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + if(bits == 16) + //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + gemm_device<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); + //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); +} + +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) +{ + + int num_blocks = (m+31)/32; + + //cout << num_blocks << endl; + //cout << lda << endl; + //cout << ldb << endl; + //cout << ldc << endl; + + //cout << m << endl; + //cout << n << endl; + //cout << k << endl; + kgemm_4bit_inference<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); + //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); +} + +template void func(T *A, T *B, T value, long n) +{ + int threads = 512; + int blocks = n/threads; + blocks = n % threads == 0 ? blocks : blocks + 1; + blocks = blocks > 65535 ? 65535 : blocks; + kfunc<<>>(A, B, value, n); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); +} + +//============================================================== +// TEMPLATE DEFINITIONS +//============================================================== + +template void func(float *A, float *B, float value, long n); +template void func(unsigned char *A, unsigned char *B, unsigned char value, long n); +template void func(float *A, float *B, float value, long n); +template void func(float *A, float *B, float value, long n); + +template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); +//template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); +template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); + +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); +template void transformRowToFormat(char * A, char *out, int rows, int cols); + +template void estimateQuantiles(half *A, float *code, float offset, int n); +template void estimateQuantiles(float *A, float *code, float offset, int n); + +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void quantizeBlockwise(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); + +#define MAKE_optimizer32bit(name, gtype) \ +template void optimizer32bit(gtype* g, gtype* p, \ + float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); + +MAKE_optimizer32bit(ADAM, half) +MAKE_optimizer32bit(ADAM, float) +MAKE_optimizer32bit(ADAM, __nv_bfloat16) +MAKE_optimizer32bit(MOMENTUM, half) +MAKE_optimizer32bit(MOMENTUM, float) +MAKE_optimizer32bit(RMSPROP, half) +MAKE_optimizer32bit(RMSPROP, float) +MAKE_optimizer32bit(LION, half) +MAKE_optimizer32bit(LION, float) +MAKE_optimizer32bit(ADAGRAD, half) +MAKE_optimizer32bit(ADAGRAD, float) + +#define MAKE_optimizerStatic8bit(name, gtype) \ +template void optimizerStatic8bit(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ + float *unorm, float max_unorm, float param_norm, \ + float beta1, float beta2, \ + float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, \ + const float gnorm_scale, int n); \ + +MAKE_optimizerStatic8bit(ADAM, half) +MAKE_optimizerStatic8bit(ADAM, float) +MAKE_optimizerStatic8bit(MOMENTUM, half) +MAKE_optimizerStatic8bit(MOMENTUM, float) +MAKE_optimizerStatic8bit(RMSPROP, half) +MAKE_optimizerStatic8bit(RMSPROP, float) +MAKE_optimizerStatic8bit(LION, half) +MAKE_optimizerStatic8bit(LION, float) + +#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ +template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ + +MAKE_optimizerStatic8bitBlockwise(half, ADAM); +MAKE_optimizerStatic8bitBlockwise(float, ADAM); +MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); +MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(half, LION); +MAKE_optimizerStatic8bitBlockwise(float, LION); +MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); + +template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); +template void percentileClipping(half * g, float *gnorm_vec, int step, const int n); + +MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM); diff --git a/csrc/ops.cuh b/csrc/ops.cuh new file mode 100644 index 000000000..5b9a32b74 --- /dev/null +++ b/csrc/ops.cuh @@ -0,0 +1,206 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + + +#ifndef ops_H +#define ops_H + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + + + +#define CUDA_CHECK_RETURN(value) { \ + cudaError_t _m_cudaStat = value; \ + if (_m_cudaStat != cudaSuccess) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", \ + cudaGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ + exit(1); \ + } } + +#define THREADS_PER_BLOCKS (512) + +#define CHECK_CUSPARSE(value) { \ + cusparseStatus_t _m_cudaStat = value; \ + if (_m_cudaStat != CUSPARSE_STATUS_SUCCESS) { \ + fprintf(stderr, "Error %s at line %d in file %s\n", \ + cusparseGetErrorString(_m_cudaStat), __LINE__, __FILE__); \ + exit(1); \ + } } + + +#define THREADS_PER_BLOCKS (512) + + +inline void checkCudaStatus(cudaError_t status) { + if (status != cudaSuccess) { + printf("cuda API failed with status %d: %s\n", status, cudaGetErrorString(status)); + throw std::logic_error("cuda API failed"); + } +} + +inline int checkCublasStatus(cublasStatus_t status) { + if (status != CUBLAS_STATUS_SUCCESS) { + printf("cuBLAS API failed with status %d\n", status); + //throw std::logic_error("cuBLAS API failed"); + return 1; + } + return 0; +} + +typedef enum Operations_t +{ + ksmul = 0, +} Operations_t; + +typedef enum Optimizer_t +{ + ADAM = 0, + MOMENTUM = 1, + RMSPROP = 2, + LARS = 3, + ADAGRAD = 4, + LION = 5, +} Optimizer_t; + +typedef enum Transform_t +{ + ROW = 0, + COL = 1, + COL32 = 2, + COL_TURING = 3, + COL_AMPERE = 4, +} Transform_t; + +typedef enum DataType_t +{ + General8bit = 0, + FP4 = 1, + NF4 = 2, +} DataType_t; + +typedef enum Funcs_t +{ + FILL = 0, + ARANGE = 1, + _MUL = 2, +} Funcs_t; + +class Context +{ + public: + cublasHandle_t m_handle; + + Context() + { + cublasHandle_t handle; + cublasCreate_v2(&handle); + m_handle = handle; + } + +}; + +class ContextLt +{ + public: + cublasLtHandle_t m_handle; + + ContextLt() + { + cublasLtHandle_t handle; + cublasLtCreate(&handle); + m_handle = handle; + } + +}; + +class ContextCusparse +{ + public: + cusparseHandle_t m_handle; + + ContextCusparse() + { + cusparseHandle_t handle; + cusparseCreate(&handle); + m_handle = handle; + } + +}; + + +template void estimateQuantiles(T *A, float *code, float offset, int n); + +void quantize(float *code, float *A, unsigned char *out, int n); +void dequantize(float *code, unsigned char *A, float *out, int n); +template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); + +template void optimizer32bit(T* g, T* p, + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, float eps, float weight_decay, + int step, float lr, const float gnorm_scale, bool skip_zeros, int n); + +template void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, + float *unorm, float max_unorm, float param_norm, + float beta1, float beta2, + float eps, int step, float lr, + float* quantiles1, float* quantiles2, + float* max1, float* max2, float* new_max1, float* new_max2, + float weight_decay, + const float gnorm_scale, int n); + +template void optimizerStatic8bitBlockwise(T* p, T* g, + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, + bool skip_zeros, int n); + +template void percentileClipping(T * g, float *gnorm_vec, int step, const int n); + +void histogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n); + +void gemmex(Context * context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long long int strideA, long long int strideB, long long int strideC, int batchCount); + + +template int igemmlt(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc); + +template void transform(cublasLtHandle_t ltHandle, T *A, T *out, int dim1, int dim2); +void cutlass_igemm(bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc); +void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols); +void getColRowStats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols); +void doubleRowColQuant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, + int *rowidx, int *colidx, half *val, int *nnz_block_ptr, float threshold, int rows, int cols); + +template void transformRowToFormat(char * A, char *out, int rows, int cols); + +void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B); + +template void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB); + +template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); + +void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); + +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); +template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); + +template void func(T *A, T *B, T value, long n); + +#endif diff --git a/csrc/pythonInterface.c b/csrc/pythonInterface.c new file mode 100644 index 000000000..776497b67 --- /dev/null +++ b/csrc/pythonInterface.c @@ -0,0 +1,370 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +#if BUILD_CUDA +#include +#endif +#include + +// We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary. +// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to +// maintain all that boilerplate +//=================================================================================== +// UNMANGLED CALLS +//=================================================================================== + +#if BUILD_CUDA +void estimateQuantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } +void estimateQuantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles(A, code, offset, n); } + + +//void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) +//{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 32); } +void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc) +{ gemm_host(M, N, K, A, B, out, lda, ldb, ldc, 16); } + +void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) +{ gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + +#define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ +void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func(A, B, value, n); } \ + +MAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL) +MAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL) +MAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) +MAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) + + +#define MAKE_FUNC32(fname, oname, gtype, gbits) \ +void fname##32bit_grad_##gbits(gtype *g, gtype *p, \ + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \ +{ optimizer32bit(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ + +MAKE_FUNC32(momentum, MOMENTUM, float, 32) +MAKE_FUNC32(momentum, MOMENTUM, half, 16) +MAKE_FUNC32(adam, ADAM, float, fp32) +MAKE_FUNC32(adam, ADAM, half, fp16) +MAKE_FUNC32(adam, ADAM, __nv_bfloat16, bf16) +MAKE_FUNC32(rmsprop, RMSPROP, float, 32) +MAKE_FUNC32(rmsprop, RMSPROP, half, 16) +MAKE_FUNC32(lion, LION, float, 32) +MAKE_FUNC32(lion, LION, half, 16) +MAKE_FUNC32(adagrad, ADAGRAD, float, 32) +MAKE_FUNC32(adagrad, ADAGRAD, half, 16) + +#define MAKE_FUNC8(fname, oname, gtype, gbits) \ +void fname##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ + float *unorm, float max_unorm, float param_norm, \ + float beta1, float beta2, \ + float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, float gnorm_scale, int n) \ +{ \ + optimizerStatic8bit(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ + quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ +} \ + +MAKE_FUNC8(adam, ADAM, float, 32) +MAKE_FUNC8(adam, ADAM, half, 16) +MAKE_FUNC8(momentum, MOMENTUM, float, 32) +MAKE_FUNC8(momentum, MOMENTUM, half, 16) +MAKE_FUNC8(rmsprop, RMSPROP, float, 32) +MAKE_FUNC8(rmsprop, RMSPROP, half, 16) +MAKE_FUNC8(lion, LION, float, 32) +MAKE_FUNC8(lion, LION, half, 16) + +#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \ +void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \ + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\ +{ optimizerStatic8bitBlockwise(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\ + +MAKE_BLOCKWISE8(adam, ADAM, half, fp16) +MAKE_BLOCKWISE8(adam, ADAM, float, fp32) +MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16) +MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32) +MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16) +MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32) +MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16) +MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32) +MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) +MAKE_BLOCKWISE8(lion, LION, half, fp16) +MAKE_BLOCKWISE8(lion, LION, float, fp32) + + +void percentileClipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } +void percentileClipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping(g, gnorm_vec, step, n); } + +void quantizeBlockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(code, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } +void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } + +void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } + + +#define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ +void transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(cublasLtHandle_t ltHandle, dtype *A, dtype *out, int dim1, int dim2) \ +{ \ + transform(ltHandle, A, out, dim1, dim2); \ +} \ + +MAKE_FUNC_TRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8); +MAKE_FUNC_TRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8); +MAKE_FUNC_TRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8); +MAKE_FUNC_TRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32); +MAKE_FUNC_TRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8); +MAKE_FUNC_TRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8); +MAKE_FUNC_TRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8); +MAKE_FUNC_TRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32); + +void transform_row2col32(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } +void transform_row2col32T(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } +void transform_row2turing(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } +void transform_row2turingT(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } +void transform_row2ampere(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } +void transform_row2ampereT(char * A, char *out, int rows, int cols){ transformRowToFormat(A, out, rows, cols); } + +void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } +void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers(A, idx, out, idx_size, rows, cols); } + + int igemmlt_turing_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int igemmlt_turing_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int igemmlt_turing_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int igemmlt_ampere_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int igemmlt_ampere_8(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int igemmlt_ampere_8_rowscale(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + +void spmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) +{ spmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } + +void spmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) +{ spmm_coo_very_sparse_naive(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } +#endif + +extern "C" +{ +#if BUILD_CUDA + void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); } + void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } + void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } + void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); } + void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp32(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } + + void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } + + void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp16_nf4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } + void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } + + #define MAKE_CFUNC32(name, gtype, gbits) \ + void c##name##32bit_grad_##gbits(gtype *g, gtype *p, \ + float* state1, float* state2, float *unorm, float max_unorm, float param_norm, \ + const float beta1, const float beta2, const float eps, const float weight_decay, \ + const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \ + { name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ + + MAKE_CFUNC32(adam, float, fp32) + MAKE_CFUNC32(adam, half, fp16) + MAKE_CFUNC32(adam, __nv_bfloat16, bf16) + MAKE_CFUNC32(momentum, float, 32) + MAKE_CFUNC32(momentum, half, 16) + MAKE_CFUNC32(rmsprop, float, 32) + MAKE_CFUNC32(rmsprop, half, 16) + MAKE_CFUNC32(lion, float, 32) + MAKE_CFUNC32(lion, half, 16) + MAKE_CFUNC32(adagrad, float, 32) + MAKE_CFUNC32(adagrad, half, 16) + + #define MAKE_CFUNC8(name, gtype, gbits) \ + void c##name##_static_8bit_grad_##gbits(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ + float *unorm, float max_unorm, float param_norm, \ + float beta1, float beta2, \ + float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, \ + float* max1, float* max2, float* new_max1, float* new_max2, \ + float weight_decay, float gnorm_scale, int n) \ + { \ + name##_static_8bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr, \ + quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n); \ + } \ + + MAKE_CFUNC8(adam, float, 32) + MAKE_CFUNC8(adam, half, 16) + MAKE_CFUNC8(momentum, float, 32) + MAKE_CFUNC8(momentum, half, 16) + MAKE_CFUNC8(rmsprop, float, 32) + MAKE_CFUNC8(rmsprop, half, 16) + MAKE_CFUNC8(lion, float, 32) + MAKE_CFUNC8(lion, half, 16) + + #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ + void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \ + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \ + { fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ + + MAKE_CBLOCKWISE8(adam, ADAM, half, fp16) + MAKE_CBLOCKWISE8(adam, ADAM, float, fp32) + MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16) + MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32) + MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16) + MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16) + MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32) + MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) + MAKE_CBLOCKWISE8(lion, LION, half, fp16) + MAKE_CBLOCKWISE8(lion, LION, float, fp32) + + void cpercentile_clipping_g32(float * g, float *gnorm_vec, int step, const int n){ percentileClipping_g32(g, gnorm_vec, step, n); } + void cpercentile_clipping_g16(half * g, float *gnorm_vec, int step, const int n){ percentileClipping_g16(g, gnorm_vec, step, n); } + void chistogram_scatter_add_2d(float* histogram, int *index1, int *index2, float *src, int maxidx1, int n){ histogramScatterAdd2D(histogram, index1, index2, src, maxidx1, n); } + + void cigemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc) + { gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc); } + void cbatched_igemm(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc, + long strideA, long strideB, long strideC, int batchCount) + { strided_gemmex(context, transposeA, transposeB, m, n, k, A, B, C, lda, ldb, ldc, strideA, strideB, strideC, batchCount); } + + Context *get_context(){ return new Context(); } + ContextCusparse *get_cusparse(){ return new ContextCusparse(); } + + int cigemmlt_turing_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + //{ (cublasLtHandle_t)context->m_handle; return 0; } + //{ return 0; }//igemmlt_turing_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int cigemmlt_turing_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt_turing_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int cigemmlt_turing_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt_turing_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int cigemmlt_ampere_32(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt_ampere_32((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int cigemmlt_ampere_8_rowscale(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt_ampere_8_rowscale((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + int cigemmlt_ampere_8(Context *context, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) + { return igemmlt_ampere_8((cublasLtHandle_t) context->m_handle, m, n, k, A, B, C, row_scale, lda, ldb, ldc); } + + #define MAKE_FUNC_CTRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ + void ctransform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose(Context *context, dtype *A, dtype *out, int dim1, int dim2) \ + { \ + transform_##fbits##_##fsrc##_to_##ftrgt##_##ftranspose((cublasLtHandle_t) context->m_handle, A, out, dim1, dim2); \ + } \ + + MAKE_FUNC_CTRANSFORM(8, row, col, n, int8_t, ROW, COL, false, 8) + MAKE_FUNC_CTRANSFORM(8, row, row, n, int8_t, ROW, ROW, false, 8) + MAKE_FUNC_CTRANSFORM(8, row, col32, n, int8_t, ROW, COL32, false, 8) + MAKE_FUNC_CTRANSFORM(32, row, col32, n, int32_t, ROW, COL32, false, 32) + MAKE_FUNC_CTRANSFORM(8, row, col_turing, n, int8_t, ROW, COL_TURING, false, 8) + MAKE_FUNC_CTRANSFORM(8, row, col_ampere, n, int8_t, ROW, COL_AMPERE, false, 8) + MAKE_FUNC_CTRANSFORM(8, col32, row, n, int8_t, COL32, ROW, false, 8) + MAKE_FUNC_CTRANSFORM(32, col32, row, n, int32_t, COL32, ROW, false, 32) + + void cdequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half* bias, int numRows, int numCols) + { dequant_mm_int32_fp16(A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols); } + void cget_col_row_stats(half * A, float *rowStats, float *colStats, int *nnz_count_row, float nnz_threshold, int rows, int cols) + { getColRowStats(A, rowStats, colStats, nnz_count_row, nnz_threshold, rows, cols); } + + void cdouble_rowcol_quant(half * A, float *rowStats, float *colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int *nnz_row_ptr, float threshold, int rows, int cols) + { doubleRowColQuant(A, rowStats, colStats, out_col_normed, out_row_normed, rowidx, colidx, val, nnz_row_ptr, threshold, rows, cols); } + + void ctransform_row2col32(char * A, char *out, int rows, int cols) + { transform_row2col32(A, out, rows, cols); } + + void ctransform_row2col32T(char * A, char *out, int rows, int cols) + { transform_row2col32T(A, out, rows, cols); } + + void ctransform_row2turing(char * A, char *out, int rows, int cols) + { transform_row2turing(A, out, rows, cols); } + + void ctransform_row2turingT(char * A, char *out, int rows, int cols) + { transform_row2turingT(A, out, rows, cols); } + + void ctransform_row2ampere(char * A, char *out, int rows, int cols) + { transform_row2ampere(A, out, rows, cols); } + + void ctransform_row2ampereT(char * A, char *out, int rows, int cols) + { transform_row2ampereT(A, out, rows, cols); } + + void cspmm_coo(ContextCusparse *context, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B) + { spmm_coo((cusparseHandle_t) context->m_handle, A_rowidx, A_colidx, A_vals, A_nnz, A_rows, A_cols, B_cols, ldb, B, ldc, C, transposed_B); } + + void cspmm_coo_very_sparse_naive_fp16(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) + { spmm_coo_very_sparse_naive_fp16(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } + + void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB) + { spmm_coo_very_sparse_naive_int8(max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz_rows, nnz, rowsA, rowsB, colsB); } + + void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); } + void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); } + + //void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc) + //{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); } + + void cgemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int lda, int ldb, int ldc) + { gemm_host_fp16(M, N, K, A, B, out, lda, ldb, ldc); } + + void cgemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) + { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } + + void *cget_managed_ptr(size_t bytes) + { + void *ptr; + CUDA_CHECK_RETURN(cudaMallocManaged(&ptr, bytes, cudaMemAttachHost)); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + + return ptr; + } + + void cprefetch(void *ptr, size_t bytes, int device) + { + CUDA_CHECK_RETURN(cudaMemPrefetchAsync(ptr, bytes, device, 0)); + CUDA_CHECK_RETURN(cudaPeekAtLastError()); + } + + #define CMAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ + void c##fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ fname##_##type_name(A, B, value, n); } \ + + CMAKE_ELEMENTWISE_FUNC(fill, fp32, float, FILL) + CMAKE_ELEMENTWISE_FUNC(fill, uint8, unsigned char, FILL) + CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) + CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) + +#endif + void cquantize_blockwise_cpu_fp32(float *code, float *A, float *absmax, unsigned char *out, long long blocksize, long long n){ quantize_cpu(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_cpu_fp32(float *code, unsigned char *A, float *absmax, float *out, long long blocksize, long long n){ dequantize_cpu(code, A, absmax, out, blocksize, n); } +} diff --git a/cuda_install.sh b/cuda_install.sh new file mode 100644 index 000000000..678f7ca50 --- /dev/null +++ b/cuda_install.sh @@ -0,0 +1,89 @@ +URL92=https://developer.nvidia.com/compute/cuda/9.2/Prod2/local_installers/cuda_9.2.148_396.37_linux +URL100=https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda_10.0.130_410.48_linux +URL101=https://developer.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.105_418.39_linux.run +URL102=https://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda_10.2.89_440.33.01_linux.run +URL110=https://developer.download.nvidia.com/compute/cuda/11.0.3/local_installers/cuda_11.0.3_450.51.06_linux.run +URL111=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run +URL112=https://developer.download.nvidia.com/compute/cuda/11.2.2/local_installers/cuda_11.2.2_460.32.03_linux.run +URL113=https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.19.01_linux.run +URL114=https://developer.download.nvidia.com/compute/cuda/11.4.4/local_installers/cuda_11.4.4_470.82.01_linux.run +URL115=https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda_11.5.2_495.29.05_linux.run +URL116=https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda_11.6.2_510.47.03_linux.run +URL117=https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run +URL118=https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run +URL120=https://developer.download.nvidia.com/compute/cuda/12.0.0/local_installers/cuda_12.0.0_525.60.13_linux.run +URL121=https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run + + +CUDA_VERSION=$1 +BASE_PATH=$2 +EXPORT_BASHRC=$3 + +if [[ -n "$CUDA_VERSION" ]]; then + if [[ "$CUDA_VERSION" -eq "92" ]]; then + URL=$URL92 + FOLDER=cuda-9.2 + elif [[ "$CUDA_VERSION" -eq "100" ]]; then + URL=$URL100 + FOLDER=cuda-10.0 + elif [[ "$CUDA_VERSION" -eq "101" ]]; then + URL=$URL101 + FOLDER=cuda-10.1 + elif [[ "$CUDA_VERSION" -eq "102" ]]; then + URL=$URL102 + FOLDER=cuda-10.2 + elif [[ "$CUDA_VERSION" -eq "110" ]]; then + URL=$URL110 + FOLDER=cuda-11.0 + elif [[ "$CUDA_VERSION" -eq "111" ]]; then + URL=$URL111 + FOLDER=cuda-11.1 + elif [[ "$CUDA_VERSION" -eq "112" ]]; then + URL=$URL112 + FOLDER=cuda-11.2 + elif [[ "$CUDA_VERSION" -eq "113" ]]; then + URL=$URL113 + FOLDER=cuda-11.3 + elif [[ "$CUDA_VERSION" -eq "114" ]]; then + URL=$URL114 + FOLDER=cuda-11.4 + elif [[ "$CUDA_VERSION" -eq "115" ]]; then + URL=$URL115 + FOLDER=cuda-11.5 + elif [[ "$CUDA_VERSION" -eq "116" ]]; then + URL=$URL116 + FOLDER=cuda-11.6 + elif [[ "$CUDA_VERSION" -eq "117" ]]; then + URL=$URL117 + FOLDER=cuda-11.7 + elif [[ "$CUDA_VERSION" -eq "118" ]]; then + URL=$URL118 + FOLDER=cuda-11.8 + elif [[ "$CUDA_VERSION" -eq "120" ]]; then + URL=$URL120 + FOLDER=cuda-12.0 + elif [[ "$CUDA_VERSION" -eq "121" ]]; then + URL=$URL121 + FOLDER=cuda-12.1 + else + echo "argument error: No cuda version passed as input. Choose among versions 92 to 121" + fi +else + echo "argument error: No cuda version passed as input. Choose among versions 92 to 112" +fi + +FILE=$(basename $URL) + +if [[ -n "$CUDA_VERSION" ]]; then + echo $URL + echo $FILE + #wget $URL + bash $FILE --no-drm --no-man-page --override --toolkitpath=$BASE_PATH/$FOLDER/ --toolkit --silent + if [ "$EXPORT_BASHRC" -eq "1" ]; then + echo "export LD_LIBRARY_PATH=\$LD_LIBRARY_PATH:$BASE_PATH/$FOLDER/lib64" >> ~/.bashrc + echo "export PATH=\$PATH:$BASE_PATH/$FOLDER/bin" >> ~/.bashrc + source ~/.bashrc + fi +else + echo "" +fi diff --git a/deploy.sh b/deploy.sh new file mode 100644 index 000000000..24d6cbf6b --- /dev/null +++ b/deploy.sh @@ -0,0 +1,265 @@ +#!/bin/bash +BASE_PATH=$1 + +echo "MAKE SURE LD_LIBRARY_PATH IS EMPTY!" +echo $LD_LIBRARY_PATH + +if [[ ! -z "${LD_LIBRARY_PATH}" ]]; then + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + + +module unload cuda && echo "no module function available. Probably not on a slurm cluster." +module unload gcc && echo "no module function available. Probably not on a slurm cluster." + +rm -rf dist build +make cleaneggs +make cleanlibs + +make clean +export CUDA_HOME= +export CUDA_VERSION= +make cpuonly CUDA_VERSION="CPU" + +if [ ! -f "./bitsandbytes/libbitsandbytes_cpu.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.0 +make cuda110 CUDA_VERSION=110 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda110.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.1 +make cuda11x CUDA_VERSION=111 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda111.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.2 +make cuda11x CUDA_VERSION=112 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda112.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.3 +make cuda11x CUDA_VERSION=113 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda113.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.4 +make cuda11x CUDA_VERSION=114 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda114.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.5 +make cuda11x CUDA_VERSION=115 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda115.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.6 + +make cuda11x CUDA_VERSION=116 +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda116.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.7 +make cuda11x CUDA_VERSION=117 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda117.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.8 +make cuda12x CUDA_VERSION=118 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda118.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-12.0 +make cuda12x CUDA_VERSION=120 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-12.1 +make cuda12x CUDA_VERSION=121 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + + +make clean +export CUDA_HOME=$BASE_PATH/cuda-10.2 +make cuda10x_nomatmul CUDA_VERSION=102 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda102_nocublaslt.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.0 +make cuda110_nomatmul CUDA_VERSION=110 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda110_nocublaslt.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.1 +make cuda11x_nomatmul CUDA_VERSION=111 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda111_nocublaslt.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.2 +make cuda11x_nomatmul CUDA_VERSION=112 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda112_nocublaslt.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.3 +make cuda11x_nomatmul CUDA_VERSION=113 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda113_nocublaslt.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.4 +make cuda11x_nomatmul CUDA_VERSION=114 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda114_nocublaslt.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.5 +make cuda11x_nomatmul CUDA_VERSION=115 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda115_nocublaslt.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.6 + +make cuda11x_nomatmul CUDA_VERSION=116 +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda116_nocublaslt.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.7 +make cuda11x_nomatmul CUDA_VERSION=117 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda117_nocublaslt.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-11.8 +make cuda12x_nomatmul CUDA_VERSION=118 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda118_nocublaslt.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-12.0 +make cuda12x_nomatmul CUDA_VERSION=120 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda120_nocublaslt.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +make clean +export CUDA_HOME=$BASE_PATH/cuda-12.1 +make cuda12x_nomatmul CUDA_VERSION=121 + +if [ ! -f "./bitsandbytes/libbitsandbytes_cuda121_nocublaslt.so" ]; then + # Control will enter here if $DIRECTORY doesn't exist. + echo "Compilation unsuccessul!" 1>&2 + exit 64 +fi + +python -m build +python -m twine upload dist/* --verbose diff --git a/environment.yml b/environment.yml new file mode 100644 index 000000000..93f5b3857 --- /dev/null +++ b/environment.yml @@ -0,0 +1,15 @@ +name: 8-bit +channels: + - conda-forge + - pytorch +dependencies: + - python=3.9 + - pytest + - pytorch + - torchaudio + - torchvision + - cudatoolkit=11.1 + - typer + - ca-certificates + - certifi + - openssl diff --git a/errors_and_solutions.md b/errors_and_solutions.md new file mode 100644 index 000000000..5b8cbcdd5 --- /dev/null +++ b/errors_and_solutions.md @@ -0,0 +1,21 @@ +# No kernel image available + +This problem arises with the cuda version loaded by bitsandbytes is not supported by your GPU, or if you pytorch CUDA version mismatches. To solve this problem you need to debug ``$LD_LIBRARY_PATH``, ``$CUDA_HOME``, ``$PATH``. You can print these via ``echo $PATH``. You should look for multiple paths to different CUDA versions. This can include versions in your anaconda path, for example ``$HOME/anaconda3/lib``. You can check those versions via ``ls -l $HOME/anaconda3/lib/*cuda*`` or equivalent paths. Look at the CUDA versions of files in these paths. Does it match with ``nvidia-smi``? + +If you are feeling lucky, you can also try to compile the library from source. This can be still problematic if your PATH variables have multiple cuda versions. As such, it is recommended to figure out path conflicts before you proceed with compilation. + + +__If you encounter any other error not listed here please create an issue. This will help resolve your problem and will help out others in the future. + + +# fatbinwrap + +This error occurs if there is a mismatch between CUDA versions in the C++ library and the CUDA part. Make sure you have right CUDA in your $PATH and $LD_LIBRARY_PATH variable. In the conda base environment you can find the library under: +```bash +ls $CONDA_PREFIX/lib/*cudart* +``` +Make sure this path is appended to the `LD_LIBRARY_PATH` so bnb can find the CUDA runtime environment library (cudart). + +If this does not fix the issue, please try [compilation from source](compile_from_source.md) next. + +If this does not work, please open an issue and paste the printed environment if you call `make` and the associated error when running bnb. diff --git a/examples/int8_inference_huggingface.py b/examples/int8_inference_huggingface.py new file mode 100644 index 000000000..dc80a44db --- /dev/null +++ b/examples/int8_inference_huggingface.py @@ -0,0 +1,27 @@ +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +MAX_NEW_TOKENS = 128 +model_name = 'decapoda-research/llama-7b-hf' + +text = 'Hamburg is in which country?\n' +tokenizer = AutoTokenizer.from_pretrained(model_name) +input_ids = tokenizer(text, return_tensors="pt").input_ids + +free_in_GB = int(torch.cuda.mem_get_info()[0]/1024**3) +max_memory = f'{int(torch.cuda.mem_get_info()[0]/1024**3)-2}GB' + +n_gpus = torch.cuda.device_count() +max_memory = {i: max_memory for i in range(n_gpus)} + +model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map='auto', + load_in_8bit=True, + max_memory=max_memory +) +generated_ids = model.generate(input_ids, max_length=MAX_NEW_TOKENS) +print(tokenizer.decode(generated_ids[0], skip_special_tokens=True)) + + + diff --git a/howto_config_override.md b/howto_config_override.md new file mode 100644 index 000000000..55b24e3ab --- /dev/null +++ b/howto_config_override.md @@ -0,0 +1,40 @@ +# How to override config hyperparameters for particular weights/parameters + +If you want to optimize some unstable parameters with 32-bit Adam and others with 8-bit Adam, you can use the `GlobalOptimManager`. With this, we can also configure specific hyperparameters for particular layers, such as embedding layers. To do that, we need two things: (1) register the parameter while they are still on the CPU, (2) override the config with the new desired hyperparameters (anytime, anywhere). See our [guide](howto_config_override.md) for more details + +For global overrides in many different places in your code you can do: +```python +import torch +import bitsandbytes as bnb + +mng = bnb.optim.GlobalOptimManager.get_instance() + +model = MyModel() +mng.register_parameters(model.parameters()) # 1. register parameters while still on CPU + +model = model.cuda() +# use 8-bit optimizer states for all parameters +adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8) + +# 2a. override: the parameter model.fc1.weight now uses 32-bit Adam +mng.override_config(model.fc1.weight, 'optim_bits', 32) + +# 2b. override: the two special layers use +# sparse optimization + different learning rate + different Adam betas +mng.override_config([model.special.weight, model.also_special.weight], + key_value_dict ={'is_sparse': True, 'lr': 1e-5, 'betas'=(0.9, 0.98)}) +``` +Possible options for the config override are: `betas, eps, weight_decay, lr, optim_bits, min_8bit_size, percentile_clipping, block_wise, max_unorm` + +For overrides for particular layers we recommend overriding locally in each module. You can do this by passing the module, the parameter, and its attribute name to the GlobalOptimManager: +```python +class MyModule(torch.nn.Module): + def __init__(din, dout): + super(MyModule, self).__init__() + self.linear = torch.nn.Linear(din, dout) + # optimization will happen in 32-bit and + # learning rate will be set to 0.0001 independent of the main learning rate + config = {'optim_bits': 32, 'lr' : 0.0001} + GlobalOptimManager.get_instance().register_module_override(self, 'weight', config) + +``` diff --git a/include/AAlloc.h b/include/AAlloc.h new file mode 100644 index 000000000..6c2ae419f --- /dev/null +++ b/include/AAlloc.h @@ -0,0 +1,86 @@ +#pragma once + +#include "Portable.h" + +namespace BinSearch { +namespace Details { + +template +bool isAligned(const T *p, size_t A) +{ + return (reinterpret_cast(p) % A) == 0; +} + +template +struct AlignedVec +{ + AlignedVec() + : m_storage(0) + , m_data(0) + , m_sz(0) + { + } + + static size_t nBytes(size_t sz) + { + return sz * sizeof(T) + A; + } + + static size_t shiftAmt(char *p) + { + return A>1? (A - (reinterpret_cast(p) % A)) % A: 0; + } + + void setPtr(char *p, size_t sz) + { + m_sz = sz; + m_data = reinterpret_cast(p + shiftAmt(p)); + } + + //void setPtr(T *p, size_t sz) + //{ + // m_sz = sz; + // if (A>1) + // myassert(((reinterpret_cast(p) % A) == 0), "bad alignment"); + // m_data = p; + //} + + // internal allocation + void resize(size_t sz) + { + m_storage = new char[nBytes(sz)]; + setPtr(m_storage, sz); + } + + // external allocation + void set(char *storage, size_t sz) + { + setPtr(storage, sz); + } + + ~AlignedVec() + { + if (m_storage) + delete [] m_storage; + } + + size_t size() const { return m_sz; } + T& operator[](size_t i) { return m_data[i]; } + const T& operator[](size_t i) const { return m_data[i]; } + T* begin() { return m_data; } + T* end() { return m_data+m_sz; } + const T* begin() const { return m_data; } + const T* end() const { return m_data+m_sz; } + T& front() { return m_data[0]; } + T& back() { return m_data[m_sz-1]; } + const T& front() const { return m_data[0]; } + const T& back() const { return m_data[m_sz - 1]; } + +private: + char *m_storage; + T *m_data; + size_t m_sz; +}; + +} // namespace Details +} // namespace BinSearch diff --git a/include/Algo-Direct-Common.h b/include/Algo-Direct-Common.h new file mode 100644 index 000000000..c97084904 --- /dev/null +++ b/include/Algo-Direct-Common.h @@ -0,0 +1,341 @@ +#pragma once + +#include +#include +#include +#include "AAlloc.h" + +namespace BinSearch { +namespace Details { + +namespace DirectAux { + +#define SAFETY_MULTI_PASS true + +template +struct HResults +{ + HResults(T h, double ratio, size_t n) : H(h), hRatio(ratio), nInc(n) {} + T H; + double hRatio; + size_t nInc; +}; + + +#ifdef USE_FMA +template struct IsDirect { static const bool value = (A == Direct) || (A == DirectFMA); }; +template struct IsDirect2 { static const bool value = (A == Direct2) || (A == Direct2FMA); }; +template struct IsDirectCache { static const bool value = (A == DirectCache) || (A == DirectCacheFMA); }; +#else +template struct IsDirect { static const bool value = (A == Direct); }; +template struct IsDirect2 { static const bool value = (A == Direct2); }; +template struct IsDirectCache { static const bool value = (A == DirectCache); }; +#endif + +// general definition +template +struct BucketElem +{ + FORCE_INLINE void set( uint32 b, const T *) + { + m_b = b; + } + + FORCE_INLINE uint32 index() const { return m_b; } + +private: + uint32 m_b; +}; + +// specialization for DirectCache methods + +template struct MatchingIntType; +template <> struct MatchingIntType { typedef uint64 type; }; +template <> struct MatchingIntType { typedef uint32 type; }; + +template +struct BucketElem::value >::type > +{ + typedef typename MatchingIntType::type I; + + void set(uint32 b, const T *xi) + { + u.u.x = xi[b]; + u.u.b = b; + } + + FORCE_INLINE I index() const { return u.u.b; } + FORCE_INLINE T x() const { return u.u.x; } + +private: + union { + double dummy; + struct + { + T x; + I b; + } u; + } u; +}; + + +template +struct DirectTraits +{ + static void checkH(T scaler, T x0, T xN) + { + T Dn = xN - x0; + T ifmax = Dn * scaler; + myassert((ifmax < std::numeric_limits::max() - (Gap - 1)), + "Problem unfeasible: index size exceeds uint32 capacity:" + << " D[N] =" << Dn + << ", H =" << scaler + << ", H D[n] =" << ifmax << "\n" + ); + } + + FORCE_INLINE static uint32 f(T scaler, T x0, T z) + { + T tmp = scaler * (z - x0); +#ifdef USE_SSE2 + return ftoi(FVec1(tmp)); +#else + return static_cast(tmp); +#endif + } + + template + FORCE_INLINE static typename FTOITraits::vec_t f(const FVec& scaler, const FVec& x0, const FVec& z) + { + return ftoi(scaler*(z-x0)); + } + + static T cst0(T scaler, T x0) + { + return x0; + } +}; + +#ifdef USE_FMA +template +struct DirectTraits +{ + typedef FVec1 fVec1; + + static void checkH(T scaler, T H_Times_x0, T xN) + { + union { + typename FVec1::vec_t v; + T s; + } ifmax; + ifmax.v = mulSub(fVec1(scaler), fVec1(xN), fVec1(H_Times_x0)); + myassert((ifmax.s < std::numeric_limits::max() - (Gap - 1)), + "Problem unfeasible: index size exceeds uint32 capacity:" + << " H X[0] =" << H_Times_x0 + << ", H =" << scaler + << ", X[N] =" << xN + << ", H X[N] - H X[0] =" << ifmax.s << "\n" + ); + } + + FORCE_INLINE static uint32 f(T scaler, T Hx0, T xi) + { + return ftoi(mulSub(fVec1(scaler), fVec1(xi), fVec1(Hx0))); + } + + template + FORCE_INLINE static typename FTOITraits::vec_t f(const FVec& scaler, const FVec& H_Times_X0, const FVec& z) + { + return ftoi(mulSub(scaler, z, H_Times_X0)); + } + + static T cst0(T scaler, T x0) + { + return scaler*x0; + } +}; +#endif + +template +struct DirectInfo +{ + static const bool UseFMA = (A == DirectFMA) || (A == Direct2FMA) || (A == DirectCacheFMA); + typedef DirectTraits fun_t; + typedef BucketElem bucket_t; + typedef AlignedVec bucketvec_t; + + struct Data { + Data() : buckets(0), xi(0), scaler(0), cst0(0) {} + Data( const T *x // for Direct must persist if xws=NULL + , uint32 n + , T H + , bucket_t *bws // assumed to gave size nb, as computed below + , T *xws = NULL // assumed to have size (n+Gap-1). Optional for Direct, unused for DirectCache, required for DirectGap + ) + : buckets(bws) + , scaler(H) + , cst0(fun_t::cst0(H, x[0])) + { + myassert(((bws != NULL) && (isAligned(bws,64))), "bucket pointer not allocated or incorrectly aligned"); + + uint32 nb = 1 + fun_t::f(H, cst0, x[n-1]); + + const uint32 npad = Gap-1; + const uint32 n_sz = n + npad; // size of padded vector + + if (xws) { + myassert(isAligned(xws,8), "x pointer not allocated or incorrectly aligned"); + std::fill_n(xws, npad, x[0]); // pad in front with x[0] + std::copy(x, x+n, xws + npad); + xi = xws; + } + else { + myassert(Gap==1, "if Gap>1 then X workspace must be provided"); + xi = x; + } + + populateIndex(bws, nb, xi, n_sz, scaler, cst0); + } + + const bucket_t *buckets; + const T *xi; + T scaler; + T cst0; // could be x0 or (scaler*x0), depending if we are using FMA or not + } data; + + static T growStep(T H) + { + T step; + T P = next(H); + while ((step = P - H) == 0) + P = next(P); + return step; + } + + static HResults computeH(const T *px, uint32 nx) + { + myassert((nx > Gap), "Array X too small"); + myassert(((Gap == 1) || (Gap == 2)), "Only tested for these values of Gap"); + + const T x0 = px[0]; + const T xN = px[nx-1]; + + const T range = xN - x0; + myassert((range < std::numeric_limits::max()), "range too large"); + + // check that D_i are strictly increasing and compute minimum value D_{i+Offset}-D_i + T deltaDMin = range; + for (uint32 i = Gap; i < nx; ++i) { + T Dnew = px[i] - x0; + T Dold = px[i - Gap] - x0; + myassert((Dnew > Dold), + "Problem unfeasible: D_i sequence not strictly increasing" + << " X[" << 0 << "]=" << x0 + << " X[" << i - Gap << "]=" << px[i - Gap] + << " X[" << i << "]=" << px[i] + << "\n" + ); + T deltaD = Dnew - Dold; + if (deltaD < deltaDMin) + deltaDMin = deltaD; + } + + // initial guess for H + const T H0 = T(1.0) / deltaDMin; + T H = H0; + + T cst0 = fun_t::cst0(H, x0); + fun_t::checkH(H, cst0, xN); + + // adjust H by trial and error until succeed + size_t nInc = 0; + bool modified = false; + size_t npasses = 0; + T step = growStep(H); + uint32 seg_already_checked_from = nx; + do { + myassert((npasses++ < 2), "verification failed\n"); + // if there has been an increase, then check only up to that point + uint32 last_seg_to_be_checked = seg_already_checked_from - 1; + modified = false; + uint32 inew = 0; + for (uint32 i = Gap; i <= last_seg_to_be_checked; ++i) { + uint32 iold = fun_t::f(H, cst0, px[i-Gap]); + uint32 inew = fun_t::f(H, cst0, px[i]); + while (inew == iold) { + seg_already_checked_from = i; + last_seg_to_be_checked = nx-1; // everything needs to be checked + modified = true; + H = H + step; + step *= 2; + // recalculate all constants and indices + cst0 = fun_t::cst0(H, x0); + fun_t::checkH(H, cst0, xN); + iold = fun_t::f(H, cst0, px[i - Gap]); + inew = fun_t::f(H, cst0, px[i]); + } + } + } while (SAFETY_MULTI_PASS && modified); + + return HResults(H, (((double)H) / H0) - 1.0, nInc); + } + + static void populateIndex(BucketElem *buckets, uint32 index_size, const T *px, uint32 x_size, T scaler, T cst0) + { + for (uint32 i = x_size-1, b = index_size-1, j=0; ; --i) { + uint32 idx = fun_t::f(scaler, cst0, px[i]); + while (b > idx) { // in the 1st iteration it is j=0 but this condition is always false + buckets[b].set( j, px ); + --b; + } + if (Gap==1 || b == idx) { // if Gap==1, which is known at compile time, the check b==idx is redundant + j = i - (Gap-1); // subtracting (Gap-1) points to the index of the first X-element to check + buckets[b].set(j, px); + if (b-- == 0) + break; + } + } + } + + DirectInfo(const Data& d) + : data(d) + { + } + + DirectInfo(const T* px, const uint32 n) + { + HResults res = computeH(px, n); + +#ifdef PAPER_TEST + nInc = res.nInc; + hRatio = res.hRatio; +#endif + const uint32 npad = Gap-1; + const uint32 n_sz = n + npad; // size of padded vector + + if (npad) + xi.resize(n_sz); + + T H = res.H; + T cst0 = fun_t::cst0(H, px[0]); + const uint32 maxIndex = fun_t::f(H, cst0, px[n-1]); + buckets.resize(maxIndex + 1); + + data = Data(px, n, H, buckets.begin(), (npad? xi.begin(): NULL)); + } + +private: + bucketvec_t buckets; + AlignedVec xi; + +#ifdef PAPER_TEST +public: + double hRatio; + size_t nInc; +#endif +}; + + +} // namespace DirectAux +} // namespace Details +} // namespace BinSearch diff --git a/include/Algo-Direct2.h b/include/Algo-Direct2.h new file mode 100644 index 000000000..d5fa58d12 --- /dev/null +++ b/include/Algo-Direct2.h @@ -0,0 +1,305 @@ +#pragma once + +#include "Algo-Direct-Common.h" + +namespace BinSearch { +namespace Details { + +template +struct AlgoScalarBase::value>::type> : DirectAux::DirectInfo<2, T, A> +{ +private: + typedef DirectAux::DirectInfo<2, T, A> base_t; + static const size_t Offset=2; + +public: + AlgoScalarBase(const T* x, const uint32 n) + : base_t(x, n) + { + } + + FORCE_INLINE uint32 scalar(T z) const + { + const T* px = base_t::data.xi; + const uint32* buckets = reinterpret_cast(base_t::data.buckets); + uint32 bidx = base_t::fun_t::f(base_t::data.scaler, base_t::data.cst0, z); + uint32 iidx = buckets[bidx]; + px += iidx; + if (z < *px) + --iidx; + if (z < *(px+1)) + --iidx; + return iidx; + } +}; + + +template +struct AlgoVecBase::value>::type> : AlgoScalarBase +{ + static const uint32 nElem = sizeof(typename InstrFloatTraits::vec_t) / sizeof(T); + + typedef FVec fVec; + typedef IVec i128; + + struct Constants + { + fVec vscaler; + fVec vcst0; + IVec one; + }; + +private: + typedef AlgoScalarBase base_t; + + FORCE_INLINE + //NO_INLINE + void resolve(const FVec& vz, const IVec& bidx, uint32 *pr) const + { + union U { + __m128i vec; + uint32 ui32[4]; + } u; + + const uint32* buckets = reinterpret_cast(base_t::data.buckets); + const float *xi = base_t::data.xi; + + // read indices t + const double *p3 = reinterpret_cast(&xi[(u.ui32[3] = buckets[bidx.get3()])]); + const double *p2 = reinterpret_cast(&xi[(u.ui32[2] = buckets[bidx.get2()])]); + const double *p1 = reinterpret_cast(&xi[(u.ui32[1] = buckets[bidx.get1()])]); + const double *p0 = reinterpret_cast(&xi[(u.ui32[0] = buckets[bidx.get0()])]); + +#if 0 + // read pairs ( X(t-1), X(t) ) + __m128 xp3 = _mm_castpd_ps(_mm_load_sd(p3)); + __m128 xp2 = _mm_castpd_ps(_mm_load_sd(p2)); + __m128 xp1 = _mm_castpd_ps(_mm_load_sd(p1)); + __m128 xp0 = _mm_castpd_ps(_mm_load_sd(p0)); + + // build: + // { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) } + // { X(t(0)), X(t(1)), X(t(2)), X(t(3)) } + __m128 h13 = _mm_shuffle_ps(xp1, xp3, (1 << 2) + (1 << 6)); + __m128 h02 = _mm_shuffle_ps(xp0, xp2, (1 << 2) + (1 << 6)); + __m128 u01 = _mm_unpacklo_ps(h02, h13); + __m128 u23 = _mm_unpackhi_ps(h02, h13); + __m128 vxm = _mm_shuffle_ps(u01, u23, (0) + (1 << 2) + (0 << 4) + (1 << 6)); + __m128 vxp = _mm_shuffle_ps(u01, u23, (2) + (3 << 2) + (2 << 4) + (3 << 6)); +#else + __m128 xp23 = _mm_castpd_ps(_mm_set_pd(*p3, *p2)); + __m128 xp01 = _mm_castpd_ps(_mm_set_pd(*p1, *p0)); + __m128 vxm = _mm_shuffle_ps(xp01, xp23, (0) + (2 << 2) + (0 << 4) + (2 << 6)); + __m128 vxp = _mm_shuffle_ps(xp01, xp23, (1) + (3 << 2) + (1 << 4) + (3 << 6)); +#endif + IVec i(u.vec); + IVec vlem = vz < vxm; + IVec vlep = vz < vxp; + i = i + vlem + vlep; + i.store(pr); + } + + FORCE_INLINE + //NO_INLINE + void resolve(const FVec& vz, const IVec& bidx, uint32 *pr) const + { + const uint32* buckets = reinterpret_cast(base_t::data.buckets); + const double *xi = base_t::data.xi; + + uint32 b1 = buckets[bidx.get1()]; + uint32 b0 = buckets[bidx.get0()]; + + const double *p1 = &xi[b1]; + const double *p0 = &xi[b0]; + + // read pairs ( X(t-1), X(t) ) + __m128d vx1 = _mm_loadu_pd(p1); + __m128d vx0 = _mm_loadu_pd(p0); + + // build: + // { X(t(0)-1), X(t(1)-1) } + // { X(t(0)), X(t(1)) } + __m128d vxm = _mm_shuffle_pd(vx0, vx1, 0); + __m128d vxp = _mm_shuffle_pd(vx0, vx1, 3); + + IVec i(b1, b0); + IVec vlem = (vz < vxm); + IVec vlep = (vz < vxp); + i = i + vlem + vlep; + + union { + __m128i vec; + uint32 ui32[4]; + } u; + u.vec = i; + pr[0] = u.ui32[0]; + pr[1] = u.ui32[2]; + } + +#ifdef USE_AVX + + FORCE_INLINE + //NO_INLINE + void resolve(const FVec& vz, const IVec& bidx, uint32 *pr) const + { + const uint32* buckets = reinterpret_cast(base_t::data.buckets); + const float *xi = base_t::data.xi; + +#if 0 // use gather instructions + + IVec idxm; + idxm.setidx(buckets, bidx); + __m256i z = _mm256_setzero_si256(); + IVec minusone = _mm256_cmpeq_epi32(z,z); + IVec idxp = idxm - minusone; + + FVec vxm = _mm256_i32gather_ps(xi, idxm, sizeof(float)); + FVec vxp = _mm256_i32gather_ps(xi, idxp, sizeof(float)); + IVec ip = idxm; + +#else // do not use gather instrucions + + union U { + __m256i vec; + uint32 ui32[8]; + } u; + + // read indices t + + const double *p7 = reinterpret_cast(&xi[(u.ui32[7] = buckets[bidx.get7()])]); + const double *p6 = reinterpret_cast(&xi[(u.ui32[6] = buckets[bidx.get6()])]); + const double *p5 = reinterpret_cast(&xi[(u.ui32[5] = buckets[bidx.get5()])]); + const double *p4 = reinterpret_cast(&xi[(u.ui32[4] = buckets[bidx.get4()])]); + const double *p3 = reinterpret_cast(&xi[(u.ui32[3] = buckets[bidx.get3()])]); + const double *p2 = reinterpret_cast(&xi[(u.ui32[2] = buckets[bidx.get2()])]); + const double *p1 = reinterpret_cast(&xi[(u.ui32[1] = buckets[bidx.get1()])]); + const double *p0 = reinterpret_cast(&xi[(u.ui32[0] = buckets[bidx.get0()])]); + +#if 0 // perform 8 loads in double precision + + // read pairs ( X(t-1), X(t) ) + __m128 xp7 = _mm_castpd_ps(_mm_load_sd(p7)); + __m128 xp6 = _mm_castpd_ps(_mm_load_sd(p6)); + __m128 xp5 = _mm_castpd_ps(_mm_load_sd(p5)); + __m128 xp4 = _mm_castpd_ps(_mm_load_sd(p4)); + __m128 xp3 = _mm_castpd_ps(_mm_load_sd(p3)); + __m128 xp2 = _mm_castpd_ps(_mm_load_sd(p2)); + __m128 xp1 = _mm_castpd_ps(_mm_load_sd(p1)); + __m128 xp0 = _mm_castpd_ps(_mm_load_sd(p0)); + + // build: + // { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) } + // { X(t(0)), X(t(1)), X(t(2)), X(t(3)) } + __m128 h57 = _mm_shuffle_ps(xp5, xp7, (1 << 2) + (1 << 6)); // F- F+ H- H+ + __m128 h46 = _mm_shuffle_ps(xp4, xp6, (1 << 2) + (1 << 6)); // E- E+ G- G+ + __m128 h13 = _mm_shuffle_ps(xp1, xp3, (1 << 2) + (1 << 6)); // B- B+ D- D+ + __m128 h02 = _mm_shuffle_ps(xp0, xp2, (1 << 2) + (1 << 6)); // A- A+ C- C+ + + __m128 u01 = _mm_unpacklo_ps(h02, h13); // A- B- A+ B+ + __m128 u23 = _mm_unpackhi_ps(h02, h13); // C- D- C+ D+ + __m128 u45 = _mm_unpacklo_ps(h46, h57); // E- F- E+ F+ + __m128 u67 = _mm_unpackhi_ps(h46, h57); // G- H- G+ H+ + + __m128 abcdm = _mm_shuffle_ps(u01, u23, (0) + (1 << 2) + (0 << 4) + (1 << 6)); // A- B- C- D- + __m128 abcdp = _mm_shuffle_ps(u01, u23, (2) + (3 << 2) + (2 << 4) + (3 << 6)); // A+ B+ C+ D+ + __m128 efghm = _mm_shuffle_ps(u45, u67, (0) + (1 << 2) + (0 << 4) + (1 << 6)); // E- F- G- H- + __m128 efghp = _mm_shuffle_ps(u45, u67, (2) + (3 << 2) + (2 << 4) + (3 << 6)); // E+ F+ G+ H+ + + FVec vxp = _mm256_insertf128_ps(_mm256_castps128_ps256(abcdm), efghm, 1); + FVec vxm = _mm256_insertf128_ps(_mm256_castps128_ps256(abcdp), efghp, 1); + + IVec ip(u.vec); + +#else // use __mm256_set_pd + + // read pairs ( X(t-1), X(t) ) + __m256 x0145 = _mm256_castpd_ps(_mm256_set_pd(*p5, *p4, *p1, *p0)); // { x0(t-1), x0(t), x1(t-1), x1(t), x4(t-1), x4(t), x5(t-1), x5(t) } + __m256 x2367 = _mm256_castpd_ps(_mm256_set_pd(*p7, *p6, *p3, *p2)); // { x2(t-1), x2(t), x3(t-1), x3(t), x6(t-1), x6(t), x7(t-1), x7(t) } + + // { x0(t-1), x1(t-1), x2(t-1), 3(t-1, x4(t-1), x5(t-1), x6(t-1), xt(t-1) } + FVec vxm = _mm256_shuffle_ps(x0145, x2367, 0 + (2 << 2) + (0 << 4) + (2 << 6) ); + // { x0(t), x1(t), x2(t), 3(t, x4(t), x5(t), x6(t), xt(t) } + FVec vxp = _mm256_shuffle_ps(x0145, x2367, 1 + (3 << 2) + (1 << 4) + (3 << 6) ); + + IVec ip(u.vec); + +#endif + +#endif + + IVec vlem = vz < vxm; + IVec vlep = vz < vxp; + ip = ip + vlem + vlep; + + ip.store(pr); + } + + + + FORCE_INLINE + //NO_INLINE + void resolve(const FVec& vz, const IVec& bidx, uint32 *pr) const + { + union { + __m256i vec; + uint64 ui64[4]; + } u; + + const uint32* buckets = reinterpret_cast(base_t::data.buckets); + const double *xi = base_t::data.xi; + + // read indices t + const double *p3 = &xi[(u.ui64[3] = buckets[bidx.get3()])]; + const double *p2 = &xi[(u.ui64[2] = buckets[bidx.get2()])]; + const double *p1 = &xi[(u.ui64[1] = buckets[bidx.get1()])]; + const double *p0 = &xi[(u.ui64[0] = buckets[bidx.get0()])]; + + // read pairs ( X(t-1), X(t) ) + __m128d xp3 = _mm_loadu_pd(p3); + __m128d xp2 = _mm_loadu_pd(p2); + __m128d xp1 = _mm_loadu_pd(p1); + __m128d xp0 = _mm_loadu_pd(p0); + + // build: + // { X(t(0)-1), X(t(1)-1), X(t(2)-1), X(t(3)-1) } + // { X(t(0)), X(t(1)), X(t(2)), X(t(3)) } + __m256d x02 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp0), xp2, 1); + __m256d x13 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xp1), xp3, 1); + FVec vxm = _mm256_unpacklo_pd(x02,x13); + FVec vxp = _mm256_unpackhi_pd(x02,x13); + + +// __m128d h01m = _mm_shuffle_pd(xp0, xp1, 0); +// __m128d h23m = _mm_shuffle_pd(xp2, xp3, 0); +// __m128d h01p = _mm_shuffle_pd(xp0, xp1, 3); +// __m128d h23p = _mm_shuffle_pd(xp2, xp3, 3); +// FVec vxm = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01m), h23m, 1); +// FVec vxp = _mm256_insertf128_pd(_mm256_castpd128_pd256(h01p), h23p, 1); + + IVec i(u.vec); + IVec vlem = vz < vxm; + IVec vlep = vz < vxp; + i = i + vlem + vlep; + i.extractLo32s().store(pr); + } +#endif + +public: + + AlgoVecBase(const T* x, const uint32 n) : base_t(x, n) {} + + void initConstants(Constants& cst) const + { + cst.vscaler.setN(base_t::data.scaler); + cst.vcst0.setN(base_t::data.cst0); + cst.one.setN(uint32(1)); + } + + void vectorial(uint32 *pr, const T *pz, const Constants& cst) const + { + fVec vz(pz); + resolve(vz, base_t::fun_t::f(cst.vscaler, cst.vcst0, vz), pr); + } +}; +} // namespace Details +} // namespace BinSearch diff --git a/include/AlgoXCodes.h b/include/AlgoXCodes.h new file mode 100644 index 000000000..bdc9b00b6 --- /dev/null +++ b/include/AlgoXCodes.h @@ -0,0 +1,23 @@ +ALGOENUM(DirectCacheFMA, 5) +ALGOENUM(DirectFMA, 15) +ALGOENUM(Direct2FMA, 25) +ALGOENUM(DirectCache, 10) +ALGOENUM(Direct, 20) +ALGOENUM(Direct2, 30) +ALGOENUM(Nonary, 40) +ALGOENUM(Pentary, 50) +ALGOENUM(Ternary, 60) +ALGOENUM(Eytzinger, 70) +ALGOENUM(BitSet, 80) +ALGOENUM(ClassicOffset, 90) +#ifdef PAPER_TEST +ALGOENUM(MorinOffset, 100) +ALGOENUM(BitSetNoPad, 110) +ALGOENUM(ClassicMod, 120) +ALGOENUM(MorinBranchy, 130) +ALGOENUM(Classic, 140) +ALGOENUM(LowerBound, 145) +#ifdef USE_MKL +ALGOENUM(MKL, 150) +#endif +#endif diff --git a/include/BinAlgo.h b/include/BinAlgo.h new file mode 100644 index 000000000..aac67a0c7 --- /dev/null +++ b/include/BinAlgo.h @@ -0,0 +1,77 @@ +#pragma once + +#include "Type.h" +#include + +namespace BinSearch { + +template +struct BinAlgo : Details::BinAlgoBase +{ + typedef Details::BinAlgoBase base_t; + + BinAlgo(const T* px, const uint32 n) : base_t(px, n), x0(px[0]), xN(px[n-1]), N(n) {} + BinAlgo(const T* px, const uint32 n, const typename base_t::Data& d) : base_t(d), x0(px[0]), xN(px[n-1]), N(n) {} + + FORCE_INLINE + uint32 scalar(T z) const + { + if (!L || z >= x0) + if (!R || z < xN) + return base_t::scalar(z); + else + return N; + else + return std::numeric_limits::max(); + } + + + FORCE_INLINE + void vectorial(uint32 *pr, const T *pz, uint32 n) const + { + if (!L && !R) { + Details::Loop::loop(*this, pr, pz, n); + } + else { + const uint32 nElem = base_t::nElem; + const uint32 idealbufsize = 256; + const uint32 bufsize = nElem * (idealbufsize / nElem + ((idealbufsize % nElem) ? 1 : 0)); + T databuf[bufsize]; + uint32 resbuf[bufsize]; + uint32 indexbuf[bufsize]; + + uint32 *prend = pr + n; + while(pr != prend) { + uint32 cnt = 0; + uint32 niter = std::min(bufsize, (uint32)std::distance(pr,prend)); + for (uint32 j = 0; j < niter; ++j) { + T z = pz[j]; + // FIXME: use SSE2? + if (!L || z >= x0) + if (!R || z < xN) { + databuf[cnt] = z; + indexbuf[cnt] = j; + ++cnt; + } + else + pr[j] = N; + else + pr[j] = std::numeric_limits::max(); + } + // FIXME: merge these two loops + Details::Loop::loop(*this, resbuf, databuf, cnt); + for (uint32 j = 0; j < cnt; ++j) + pr[indexbuf[j]] = resbuf[j]; + pr += niter; + pz += niter; + } + } + } + + Details::CondData x0; + Details::CondData xN; + Details::CondData N; +}; + + +} // namespace BinSearch diff --git a/include/BinSearch.h b/include/BinSearch.h new file mode 100644 index 000000000..336f52963 --- /dev/null +++ b/include/BinSearch.h @@ -0,0 +1,11 @@ +#pragma once + +#include "AAlloc.h" +#include "BinAlgo.h" +#include "SIMD.h" + +#include +#include + + +#include "Algo-Direct2.h" diff --git a/include/Portable.h b/include/Portable.h new file mode 100644 index 000000000..1710b0502 --- /dev/null +++ b/include/Portable.h @@ -0,0 +1,151 @@ +#pragma once +#include +#include +#include +#include + +#ifdef __FMA__ +#define USE_FMA +#endif + +#ifdef __AVX2__ +#define USE_AVX2 +#endif + +#ifdef __AVX__ +#define USE_AVX +#endif + + +#ifdef __SSE4_1__ +#define USE_SSE41 +#endif + +#ifdef __SSE4_2__ +#define USE_SSE42 +#endif + + +#ifndef _MSC_VER +#include +#endif + +namespace BinSearch { + +#ifndef _MSC_VER +typedef int8_t int8; +typedef uint8_t uint8; +typedef int32_t int32; +typedef uint32_t uint32; +typedef int64_t int64; +typedef uint64_t uint64; +#else +typedef __int8 int8; +typedef unsigned __int8 uint8; +typedef __int32 int32; +typedef unsigned __int32 uint32; +typedef __int64 int64; +typedef unsigned __int64 uint64; +#endif + +namespace Details { + +#define myassert(cond, msg) if (!cond){ std::ostringstream os; os << "\nassertion failed: " << #cond << ", " << msg << "\n"; throw std::invalid_argument(os.str()); } + +// log2 is not defined in VS2008 +#if defined(_MSC_VER) +inline uint32 log2 (uint32 val) { + if (val == 1) return 0; + uint32 ret = 0; + do { + ret++; + val >>= 1; + } while (val > 1); + return ret; +} +#endif + +#ifdef _DEBUG +#define DEBUG +#endif + +#ifdef _MSC_VER +# define FORCE_INLINE __forceinline +# define NO_INLINE __declspec(noinline) +#else +# define NO_INLINE __attribute__((noinline)) +# ifdef DEBUG +# define FORCE_INLINE NO_INLINE +# else +# define FORCE_INLINE __attribute__((always_inline)) inline +# endif +#endif + +#ifdef USE_AVX +#define COMISS "vcomiss" +#define COMISD "vcomisd" +#else +#define COMISS "comiss" +#define COMISD "comisd" +#endif + +// nextafter is not defined in VS2008 +#if defined(_MSC_VER) && (_MSC_VER <= 1500) +#include +inline float mynext(float x) +{ + return _nextafterf(x, std::numeric_limits::max()); +} + +inline double mynext(double x) +{ + return _nextafter(x, std::numeric_limits::max()); +} +inline float myprev(float x) +{ + return _nextafterf(x, -std::numeric_limits::max()); +} + +inline double myprev(double x) +{ + return _nextafter(x, -std::numeric_limits::max()); +} +#else +inline float mynext(float x) +{ + return std::nextafterf(x, std::numeric_limits::max()); +} + +inline double mynext(double x) +{ + return std::nextafter(x, std::numeric_limits::max()); +} +inline float myprev(float x) +{ + return std::nextafterf(x, -std::numeric_limits::max()); +} + +inline double myprev(double x) +{ + return std::nextafter(x, -std::numeric_limits::max()); +} +#endif + +template +inline T next(T x) +{ + for (int i = 0; i < 4; ++i) + x = mynext(x); + return x; +} + +template +inline T prev(T x) +{ + for (int i = 0; i < 4; ++i) + x = myprev(x); + return x; +} + +} // namepsace Details +} // namespace BinSearch diff --git a/include/SIMD.h b/include/SIMD.h new file mode 100644 index 000000000..a2ac1a9ae --- /dev/null +++ b/include/SIMD.h @@ -0,0 +1,562 @@ +#pragma once + +#include "Portable.h" + +#ifdef USE_SSE42 +#ifndef _MSC_VER +#include +#define popcnt32 _mm_popcnt_u32 +#else +#include +#define popcnt32 __popcnt +#endif +#else // USE_SSE42 +namespace BinSearch { +FORCE_INLINE int popcnt32(int x32) +{ + // strictly speaking this is not correct, as it ignores higher order bits + // however, this is only used on the resuot of movemask on a 128-bit register, which is 8 at most, so it is ok + // with 256-bit registers, SSE42 is defined, and we do not use this function + uint8 x = static_cast(x32); + x = (x & 0x55) + (x >> 1 & 0x55); + x = (x & 0x33) + (x >> 2 & 0x33); + x = (x & 0x0f) + (x >> 4 & 0x0f); + return x; +} +} // namespace +#endif + +#if defined(USE_AVX) || defined(USE_AVX2) +#include +#else +#include +#ifdef USE_SSE41 +#include +#endif +#endif + +#include "Type.h" + +namespace BinSearch { +namespace Details { + +template +struct FVec; + +template +struct IVec; + +template +struct FVec1; + +template <> struct InstrIntTraits +{ + typedef __m128i vec_t; +}; + +template <> struct InstrFloatTraits +{ + typedef __m128 vec_t; +}; + +template <> struct InstrFloatTraits +{ + typedef __m128d vec_t; +}; + +template +struct FTOITraits +{ + typedef IVec vec_t; +}; + +#ifdef USE_AVX + +template <> +struct FTOITraits +{ + typedef IVec vec_t; +}; + +template <> struct InstrIntTraits +{ + typedef __m256i vec_t; +}; + +template <> struct InstrFloatTraits +{ + typedef __m256 vec_t; +}; + +template <> struct InstrFloatTraits +{ + typedef __m256d vec_t; +}; + +#endif + + +template +struct VecStorage +{ + typedef typename TR::vec_t vec_t; + + FORCE_INLINE operator vec_t&() { return vec; } + FORCE_INLINE operator const vec_t&() const { return vec; } + +protected: + FORCE_INLINE VecStorage() {} + FORCE_INLINE VecStorage(const vec_t& v) : vec( v ) {} + + vec_t vec; +}; + +template +struct IVecBase; + +template <> +struct IVecBase : VecStorage> +{ +protected: + FORCE_INLINE IVecBase() {} + FORCE_INLINE IVecBase( const vec_t& v) : VecStorage>( v ) {} +public: + FORCE_INLINE static vec_t zero() { return _mm_setzero_si128(); } + + FORCE_INLINE int32 get0() const { return _mm_cvtsi128_si32( vec ); } + + FORCE_INLINE void assignIf( const vec_t& val, const vec_t& mask ) + { +#ifdef USE_SSE41 + vec = _mm_blendv_epi8(vec, val, mask); +#else + vec = _mm_or_si128(_mm_andnot_si128(mask,vec), _mm_and_si128(mask,val)); +#endif + } + FORCE_INLINE void orIf(const vec_t& val, const vec_t& mask) + { + vec = _mm_or_si128(vec, _mm_and_si128(val,mask)); + } +}; + +template <> +struct IVec : IVecBase +{ + FORCE_INLINE IVec() {} + FORCE_INLINE IVec( int32 i ) : IVecBase( _mm_set1_epi32( i ) ) {} + FORCE_INLINE IVec( const vec_t& v) : IVecBase( v ) {} + FORCE_INLINE IVec( uint32 u3, uint32 u2, uint32 u1, uint32 u0) : IVecBase( _mm_set_epi32( u3, u2, u1, u0 ) ) {} + + void setN( int32 i ) { vec = _mm_set1_epi32( i ); } + +#ifdef USE_SSE41 + FORCE_INLINE int32 get1() const { return _mm_extract_epi32(vec, 1); } + FORCE_INLINE int32 get2() const { return _mm_extract_epi32(vec, 2); } + FORCE_INLINE int32 get3() const { return _mm_extract_epi32(vec, 3); } +#else + FORCE_INLINE int32 get1() const { return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 1 ) ); } + FORCE_INLINE int32 get2() const { return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 2 ) ); } + FORCE_INLINE int32 get3() const { return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 3 ) ); } +#endif + + FORCE_INLINE void store( uint32 *pi ) const { _mm_storeu_si128( reinterpret_cast(pi), vec ); } + + FORCE_INLINE int countbit() + { + return popcnt32(_mm_movemask_ps(_mm_castsi128_ps(vec))); + } +}; + +template <> +struct IVec : IVecBase +{ + FORCE_INLINE IVec() {} + FORCE_INLINE IVec( int32 i ) : IVecBase( _mm_set1_epi64x( i ) ) {} + FORCE_INLINE IVec( const vec_t& v) : IVecBase( v ) {} + FORCE_INLINE IVec( uint64 u1, uint64 u0 ) : IVecBase( _mm_set_epi64x(u1, u0) ) {} + + void setN( int32 i ) { vec = _mm_set1_epi64x( i ); } + + FORCE_INLINE int32 get1() const + { +#ifdef USE_SSE41 + return _mm_extract_epi32(vec, 2); +#else + return _mm_cvtsi128_si32( _mm_shuffle_epi32( vec, 2 ) ); +#endif + } + + // extract the 2 32 bits integers no. 0, 2 and store them in a __m128i + FORCE_INLINE IVec extractLo32s() const + { + return _mm_shuffle_epi32(vec, ((2 << 2) | 0)); + } + + FORCE_INLINE void store( uint32 *pi ) const + { + pi[0] = get0(); + pi[1] = get1(); + } + + FORCE_INLINE int countbit() + { +#if 1 + // takes 4 cycles + __m128i hi = _mm_shuffle_epi32(vec, 2); // 1 cycle + __m128i s = _mm_add_epi32(vec, hi); + int32 x = _mm_cvtsi128_si32(s); + return -x; +#else + // takes 6 cycles + return popcnt32(_mm_movemask_pd(_mm_castsi128_pd(vec))); +#endif + } +}; + +template +FORCE_INLINE IVec operator>> (const IVec& a, unsigned n) { return _mm_srli_epi32(a, n); } +template +FORCE_INLINE IVec operator<< (const IVec& a, unsigned n) { return _mm_slli_epi32(a, n); } +template +FORCE_INLINE IVec operator& (const IVec& a, const IVec& b ) { return _mm_and_si128( a, b ); } +template +FORCE_INLINE IVec operator| (const IVec& a, const IVec& b ) { return _mm_or_si128( a, b ); } +template +FORCE_INLINE IVec operator^ (const IVec& a, const IVec& b ) { return _mm_xor_si128( a, b ); } +template +FORCE_INLINE IVec operator+ (const IVec& a, const IVec& b ) { return _mm_add_epi32( a, b ); } +template +FORCE_INLINE IVec operator- (const IVec& a, const IVec& b ) { return _mm_sub_epi32( a, b ); } +#ifdef USE_SSE41 +template +FORCE_INLINE IVec min (const IVec& a, const IVec& b ) { return _mm_min_epi32( a, b ); } +#endif + +typedef VecStorage> FVec128Float; + +template <> +struct FVec1 : FVec128Float +{ + FORCE_INLINE FVec1() {} + FORCE_INLINE FVec1( float f ) : FVec128Float( _mm_load_ss( &f ) ) {} + FORCE_INLINE FVec1( const vec_t& v ): FVec128Float( v ) {} + + FORCE_INLINE float get0() const { return _mm_cvtss_f32( vec ); } +}; + +template <> +struct FVec : FVec128Float +{ + FORCE_INLINE FVec() {} + FORCE_INLINE FVec( float f ) : FVec128Float( _mm_set1_ps( f ) ) {} + FORCE_INLINE FVec( const float *v ) : FVec128Float( _mm_loadu_ps( v ) ) {} + FORCE_INLINE FVec( const vec_t& v) : FVec128Float(v) {} + FORCE_INLINE FVec( float f3, float f2, float f1, float f0 ) : FVec128Float( _mm_set_ps(f3, f2, f1, f0) ) {} + + void set0( float f ) { vec = _mm_load_ss( &f ); } + void setN( float f ) { vec = _mm_set1_ps( f ); } + + FORCE_INLINE void setidx( const float *xi, const IVec& idx ) + { + uint32 i0 = idx.get0(); + uint32 i1 = idx.get1(); + uint32 i2 = idx.get2(); + uint32 i3 = idx.get3(); + vec = _mm_set_ps( xi[i3], xi[i2], xi[i1], xi[i0] ); + } + + FORCE_INLINE float get0() const { return _mm_cvtss_f32( vec ); } + FORCE_INLINE float get1() const { return _mm_cvtss_f32( _mm_shuffle_ps( vec, vec, 1 ) ); } + FORCE_INLINE float get2() const { return _mm_cvtss_f32( _mm_shuffle_ps( vec, vec, 2 ) ); } + FORCE_INLINE float get3() const { return _mm_cvtss_f32( _mm_shuffle_ps( vec, vec, 3 ) ); } +}; + +FORCE_INLINE FVec1 operator+ (const FVec1& a, const FVec1& b) { return _mm_add_ss( a, b ); } +FORCE_INLINE FVec1 operator- (const FVec1& a, const FVec1& b) { return _mm_sub_ss( a, b ); } +FORCE_INLINE FVec1 operator* (const FVec1& a, const FVec1& b) { return _mm_mul_ss( a, b ); } +FORCE_INLINE FVec1 operator/ (const FVec1& a, const FVec1& b) { return _mm_div_ss( a, b ); } +FORCE_INLINE int ftoi (const FVec1& a) { return _mm_cvttss_si32(a); } +FORCE_INLINE IVec operator> (const FVec1& a, const FVec1& b) { return _mm_castps_si128( _mm_cmpgt_ss( a, b ) ); } +#ifdef USE_FMA +FORCE_INLINE FVec1 mulSub(const FVec1& a, const FVec1& b, const FVec1& c) { return _mm_fmsub_ss(a, b, c); } +#endif + +FORCE_INLINE FVec operator- (const FVec& a, const FVec& b) { return _mm_sub_ps( a, b ); } +FORCE_INLINE FVec operator* (const FVec& a, const FVec& b) { return _mm_mul_ps( a, b ); } +FORCE_INLINE FVec operator/ (const FVec& a, const FVec& b) { return _mm_div_ps( a, b ); } +FORCE_INLINE IVec ftoi (const FVec& a) { return _mm_cvttps_epi32(a); } +FORCE_INLINE IVec operator<= (const FVec& a, const FVec& b) { return _mm_castps_si128( _mm_cmple_ps( a, b ) ); } +FORCE_INLINE IVec operator>= (const FVec& a, const FVec& b) { return _mm_castps_si128( _mm_cmpge_ps( a, b ) ); } +FORCE_INLINE IVec operator< (const FVec& a, const FVec& b) { return _mm_castps_si128(_mm_cmplt_ps(a, b)); } +#ifdef USE_FMA +FORCE_INLINE FVec mulSub(const FVec& a, const FVec& b, const FVec& c) { return _mm_fmsub_ps(a, b, c); } +#endif + +typedef VecStorage> FVec128Double; + +template <> +struct FVec1 : FVec128Double +{ + FORCE_INLINE FVec1() {} + FORCE_INLINE FVec1( double f ) : FVec128Double( _mm_load_sd( &f ) ) {} + FORCE_INLINE FVec1( const vec_t& v ) : FVec128Double( v ) {} + + FORCE_INLINE double get0() const { return _mm_cvtsd_f64( vec ); } +}; + +template <> +struct FVec : FVec128Double +{ + FORCE_INLINE FVec() {} + FORCE_INLINE FVec( double d ) : FVec128Double( _mm_set1_pd( d ) ) {} + FORCE_INLINE FVec( const double *v ) : FVec128Double( _mm_loadu_pd( v ) ) {} + FORCE_INLINE FVec( const vec_t& v) : FVec128Double( v ) {} + FORCE_INLINE FVec( double f1, double f0 ) : FVec128Double( _mm_set_pd(f1, f0) ) {} + + void set0( double f ) { vec = _mm_load_sd( &f ); } + void setN( double f ) { vec = _mm_set1_pd( f ); } + + FORCE_INLINE void setidx( const double *xi, const IVec& idx ) + { + vec = _mm_set_pd( xi[idx.get1()], xi[idx.get0()] ); + } + + FORCE_INLINE double get0() const { return _mm_cvtsd_f64( vec ); } + FORCE_INLINE double get1() const { return _mm_cvtsd_f64( _mm_shuffle_pd( vec, vec, 1 ) ); }; +}; + +FORCE_INLINE FVec1 operator+ (const FVec1& a, const FVec1& b) { return _mm_add_sd( a, b ); } +FORCE_INLINE FVec1 operator- (const FVec1& a, const FVec1& b) { return _mm_sub_sd( a, b ); } +FORCE_INLINE FVec1 operator* (const FVec1& a, const FVec1& b) { return _mm_mul_sd( a, b ); } +FORCE_INLINE FVec1 operator/ (const FVec1& a, const FVec1& b) { return _mm_div_sd( a, b ); } +FORCE_INLINE int ftoi (const FVec1& a) { return _mm_cvttsd_si32(a); } +FORCE_INLINE IVec operator> (const FVec1& a, const FVec1& b) { return _mm_castpd_si128( _mm_cmpgt_sd( a, b ) ); } +#ifdef USE_FMA +FORCE_INLINE FVec1 mulSub(const FVec1& a, const FVec1& b, const FVec1& c) { return _mm_fmsub_sd(a, b, c); } +#endif + +FORCE_INLINE FVec operator- (const FVec& a, const FVec& b) { return _mm_sub_pd( a, b ); } +FORCE_INLINE FVec operator* (const FVec& a, const FVec& b) { return _mm_mul_pd( a, b ); } +FORCE_INLINE FVec operator/ (const FVec& a, const FVec& b) { return _mm_div_pd( a, b ); } +FORCE_INLINE IVec ftoi (const FVec& a) { return _mm_cvttpd_epi32(a); } +FORCE_INLINE IVec operator<= (const FVec& a, const FVec& b) { return _mm_castpd_si128( _mm_cmple_pd( a, b ) ); } +FORCE_INLINE IVec operator< (const FVec& a, const FVec& b) { return _mm_castpd_si128(_mm_cmplt_pd(a, b)); } +FORCE_INLINE IVec operator>= (const FVec& a, const FVec& b) { return _mm_castpd_si128( _mm_cmpge_pd( a, b ) ); } +#ifdef USE_FMA +FORCE_INLINE FVec mulSub(const FVec& a, const FVec& b, const FVec& c ) { return _mm_fmsub_pd(a, b, c); } +#endif + +#ifdef USE_AVX + +template <> +struct IVecBase : VecStorage> +{ +protected: + FORCE_INLINE IVecBase() {} + FORCE_INLINE IVecBase( const vec_t& v) : VecStorage>( v ) {} +public: + FORCE_INLINE static vec_t zero() { return _mm256_setzero_si256(); } + + FORCE_INLINE int32 get0() const { return _mm_cvtsi128_si32(_mm256_castsi256_si128(vec)); } + + FORCE_INLINE void assignIf( const vec_t& val, const vec_t& mask ) { vec = _mm256_blendv_epi8(vec, val, mask); } + FORCE_INLINE void orIf(const vec_t& val, const vec_t& mask) + { + vec = _mm256_blendv_epi8(vec, val, mask); + //vec = _mm256_or_si256(vec, _mm256_and_si256(val,mask)); + } + + FORCE_INLINE __m128i lo128() const { return _mm256_castsi256_si128(vec); } + FORCE_INLINE __m128i hi128() const { return _mm256_extractf128_si256(vec, 1); } +}; + +template <> +struct IVec : IVecBase +{ + FORCE_INLINE IVec() {} + FORCE_INLINE IVec( int32 i ) : IVecBase( _mm256_set1_epi32( i ) ) {} + FORCE_INLINE IVec( const vec_t& v) : IVecBase( v ) {} + FORCE_INLINE IVec(uint32 u7, uint32 u6, uint32 u5, uint32 u4, uint32 u3, uint32 u2, uint32 u1, uint32 u0) : IVecBase(_mm256_set_epi32(u7, u6, u5, u4, u3, u2, u1, u0)) {} + + void setN( int32 i ) { vec = _mm256_set1_epi32( i ); } + + FORCE_INLINE int32 get1() const { return _mm256_extract_epi32(vec, 1); } + FORCE_INLINE int32 get2() const { return _mm256_extract_epi32(vec, 2); } + FORCE_INLINE int32 get3() const { return _mm256_extract_epi32(vec, 3); } + FORCE_INLINE int32 get4() const { return _mm256_extract_epi32(vec, 4); } + FORCE_INLINE int32 get5() const { return _mm256_extract_epi32(vec, 5); } + FORCE_INLINE int32 get6() const { return _mm256_extract_epi32(vec, 6); } + FORCE_INLINE int32 get7() const { return _mm256_extract_epi32(vec, 7); } + + FORCE_INLINE void setidx( const uint32 *bi, const IVec& idx ) + { + vec = _mm256_i32gather_epi32(reinterpret_cast(bi), idx, sizeof(uint32)); + } + + FORCE_INLINE void store( uint32 *pi ) const { _mm256_storeu_si256( reinterpret_cast(pi), vec ); } + + FORCE_INLINE int countbit() + { + return popcnt32(_mm256_movemask_ps(_mm256_castsi256_ps(vec))); + } +}; + +template <> +struct IVec : IVecBase +{ + FORCE_INLINE IVec() {} + FORCE_INLINE IVec( int32 i ) : IVecBase( _mm256_set1_epi64x( i ) ) {} + FORCE_INLINE IVec( const vec_t& v) : IVecBase( v ) {} + FORCE_INLINE IVec(uint64 u3, uint64 u2, uint64 u1, uint64 u0) : IVecBase(_mm256_set_epi64x(u3, u2, u1, u0)) {} + + void setN( int32 i ) { vec = _mm256_set1_epi64x( i ); } + + // extract the 4 32 bits integers no. 0, 2, 4, 6 and store them in a __m128i + FORCE_INLINE IVec extractLo32s() const + { + union { + uint32 u32[4]; + __m128i u; + } mask = {0,2,4,6}; + //__m256 ps256 = _mm256_castsi256_ps(vec); + //__m128 lo128 = _mm256_castps256_ps128(ps256); + //__m128 hi128 = _mm256_extractf128_ps(ps256, 1); + //__m128 blend = _mm_shuffle_ps(lo128, hi128, 0 + (2<<2) + (0<<4) + (2<<6)); + __m256i blend = _mm256_permutevar8x32_epi32(vec, _mm256_castsi128_si256(mask.u)); + return _mm256_castsi256_si128(blend); + } + + //int32 get1() const { return _mm256_cvtsi256_si32( _mm256_shuffle_epi32( vec, 2 ) ); }; + FORCE_INLINE int32 get1() const { return _mm256_extract_epi32(vec, 2); } + + FORCE_INLINE void store( uint32 *pi ) const + { + extractLo32s().store(pi); + } + + FORCE_INLINE int countbit() + { + return popcnt32(_mm256_movemask_pd(_mm256_castsi256_pd(vec))); + } +}; + +template +FORCE_INLINE IVec operator>> (const IVec& a, unsigned n) { return _mm256_srli_epi32(a, n); } +template +FORCE_INLINE IVec operator<< (const IVec& a, unsigned n) { return _mm256_slli_epi32(a, n); } +template +FORCE_INLINE IVec operator& (const IVec& a, const IVec& b ) { return _mm256_and_si256( a, b ); } +template +FORCE_INLINE IVec operator| (const IVec& a, const IVec& b ) { return _mm256_or_si256( a, b ); } +template +FORCE_INLINE IVec operator^ (const IVec& a, const IVec& b ) { return _mm256_xor_si256( a, b ); } +template +FORCE_INLINE IVec min (const IVec& a, const IVec& b ) { return _mm256_min_epi32( a, b ); } + +FORCE_INLINE IVec operator+ (const IVec& a, const IVec& b ) { return _mm256_add_epi32( a, b ); } +FORCE_INLINE IVec operator- (const IVec& a, const IVec& b ) { return _mm256_sub_epi32( a, b ); } +FORCE_INLINE IVec operator+ (const IVec& a, const IVec& b ) { return _mm256_add_epi64( a, b ); } +FORCE_INLINE IVec operator- (const IVec& a, const IVec& b ) { return _mm256_sub_epi64( a, b ); } + + +typedef VecStorage> FVec256Float; + +template <> +struct FVec : FVec256Float +{ + FORCE_INLINE FVec() {} + FORCE_INLINE FVec( float f ) : FVec256Float( _mm256_set1_ps( f ) ) {} + FORCE_INLINE FVec( const float *v ) : FVec256Float( _mm256_loadu_ps( v ) ) {} + FORCE_INLINE FVec( const vec_t& v) : FVec256Float(v) {} + FORCE_INLINE FVec(float f7, float f6, float f5, float f4, float f3, float f2, float f1, float f0) : FVec256Float(_mm256_set_ps(f7, f6, f5, f4, f3, f2, f1, f0)) {} + + //void set0( float f ) { vec = _mm256_load_ss( &f ); } + void setN( float f ) { vec = _mm256_set1_ps( f ); } + + FORCE_INLINE void setidx( const float *xi, const IVec& idx ) + { +#if 1 // use gather primitives + vec = _mm256_i32gather_ps (xi, idx, 4); +#elif 0 + uint32 i0 = idx.get0(); + uint32 i1 = idx.get1(); + uint32 i2 = idx.get2(); + uint32 i3 = idx.get3(); + uint32 i4 = idx.get4(); + uint32 i5 = idx.get5(); + uint32 i6 = idx.get6(); + uint32 i7 = idx.get7(); + vec = _mm256_set_ps( xi[i7], xi[i6], xi[i5], xi[i4], xi[i3], xi[i2], xi[i1], xi[i0] ); +#else + union { + __m256i vec; + uint32 ui32[8]; + } i; + i.vec = static_cast(idx); + vec = _mm256_set_ps(xi[i.ui32[7]], xi[i.ui32[6]], xi[i.ui32[5]], xi[i.ui32[4]], xi[i.ui32[3]], xi[i.ui32[2]], xi[i.ui32[1]], xi[i.ui32[0]]); +#endif + } + + FORCE_INLINE FVec lo128() const { return _mm256_castps256_ps128(vec); } + FORCE_INLINE FVec hi128() const { return _mm256_extractf128_ps(vec, 1); } + + //FORCE_INLINE float get0() const { return _mm256_cvtss_f32( vec ); } + //FORCE_INLINE float get1() const { return _mm256_cvtss_f32( _mm256_shuffle_ps( vec, vec, 1 ) ); } + //FORCE_INLINE float get2() const { return _mm256_cvtss_f32( _mm256_shuffle_ps( vec, vec, 2 ) ); } + //FORCE_INLINE float get3() const { return _mm256_cvtss_f32( _mm256_shuffle_ps( vec, vec, 3 ) ); } +}; + +FORCE_INLINE FVec operator- (const FVec& a, const FVec& b) { return _mm256_sub_ps( a, b ); } +FORCE_INLINE FVec operator* (const FVec& a, const FVec& b) { return _mm256_mul_ps( a, b ); } +FORCE_INLINE FVec operator/ (const FVec& a, const FVec& b) { return _mm256_div_ps( a, b ); } +FORCE_INLINE IVec ftoi (const FVec& a) { return _mm256_cvttps_epi32(a); } +FORCE_INLINE IVec operator<= (const FVec& a, const FVec& b) { return _mm256_castps_si256( _mm256_cmp_ps( a, b, _CMP_LE_OS) ); } +FORCE_INLINE IVec operator>= (const FVec& a, const FVec& b) { return _mm256_castps_si256( _mm256_cmp_ps( a, b, _CMP_GE_OS ) ); } +FORCE_INLINE IVec operator< (const FVec& a, const FVec& b) { return _mm256_castps_si256(_mm256_cmp_ps(a, b, _CMP_LT_OS )); } +#ifdef USE_FMA +FORCE_INLINE FVec mulSub(const FVec& a, const FVec& b, const FVec& c) { return _mm256_fmsub_ps(a, b, c); } +#endif + +typedef VecStorage> FVec256Double; + +template <> +struct FVec : FVec256Double +{ + FORCE_INLINE FVec() {} + FORCE_INLINE FVec( double d ) : FVec256Double( _mm256_set1_pd( d ) ) {} + FORCE_INLINE FVec( const double *v ) : FVec256Double( _mm256_loadu_pd( v ) ) {} + FORCE_INLINE FVec( const vec_t& v) : FVec256Double( v ) {} + FORCE_INLINE FVec(double d3, double d2, double d1, double d0) : FVec256Double(_mm256_set_pd(d3, d2, d1, d0)) {} + + //void set0( double f ) { vec = _mm256_load_sd( &f ); } + void setN( double f ) { vec = _mm256_set1_pd( f ); } + + FORCE_INLINE void setidx( const double *xi, const IVec& idx ) + { + vec = _mm256_i32gather_pd(xi, idx, 8); + } + + FORCE_INLINE void setidx( const double *xi, const IVec& idx ) + { + vec = _mm256_i64gather_pd(xi, idx, 8); + } + +// FORCE_INLINE double get0() const { return _mm256_cvtsd_f64( vec ); } +// FORCE_INLINE double get1() const { return _mm256_cvtsd_f64( _mm256_shuffle_pd( vec, vec, 1 ) ); }; +}; + +FORCE_INLINE FVec operator- (const FVec& a, const FVec& b) { return _mm256_sub_pd( a, b ); } +FORCE_INLINE FVec operator* (const FVec& a, const FVec& b) { return _mm256_mul_pd( a, b ); } +FORCE_INLINE FVec operator/ (const FVec& a, const FVec& b) { return _mm256_div_pd( a, b ); } +FORCE_INLINE IVec ftoi (const FVec& a) { return _mm256_cvttpd_epi32(a); } +FORCE_INLINE IVec operator<= (const FVec& a, const FVec& b) { return _mm256_castpd_si256(_mm256_cmp_pd( a, b, _CMP_LE_OS ) ); } +FORCE_INLINE IVec operator< (const FVec& a, const FVec& b) { return _mm256_castpd_si256(_mm256_cmp_pd(a, b, _CMP_LT_OS)); } +FORCE_INLINE IVec operator>= (const FVec& a, const FVec& b) { return _mm256_castpd_si256(_mm256_cmp_pd( a, b, _CMP_GE_OS ) ); } +#ifdef USE_FMA +FORCE_INLINE FVec mulSub(const FVec& a, const FVec& b, const FVec& c) { return _mm256_fmsub_pd(a, b, c); } +#endif + +#endif + +} // namepsace Details +} // namespace BinSearch diff --git a/include/Type.h b/include/Type.h new file mode 100644 index 000000000..720bfb86f --- /dev/null +++ b/include/Type.h @@ -0,0 +1,221 @@ + #pragma once + +#include +#include +#include + +#include "Portable.h" + +using std::size_t; + +namespace BinSearch { + +enum InstrSet { Scalar, SSE, AVX }; + +#define ALGOENUM(x, b) x, +enum Algos + { +#include "AlgoXCodes.h" + }; +#undef ALGOENUM + +namespace Details { + + template + struct InstrIntTraits; + + template + struct InstrFloatTraits; + + // base class for algorithm supporting the method: + // uint32 scalar(T z) const + template + struct AlgoScalarBase; + + // base class for algorithm supporting the following methods, constants and definitions: + // static const uint32 nElem + // struct Constants; + // void initConstants(Constants& cst) const + // void vectorial(uint32 *pr, const T *pz, const Constants& cst) const + // The function vectorial processes nElem items + template + struct AlgoVecBase; + + template struct IntTraits; + + template <> struct IntTraits + { + typedef uint32 itype; + }; + template <> struct IntTraits + { + typedef uint64 itype; + }; + + template + struct Body + { + template + FORCE_INLINE static void iteration(const Expr& e, uint32 *ri, const T* zi, const typename Expr::Constants& cst) + { + e.vectorial(ri, zi, cst); + Body::template iteration(e, ri + D, zi + D, cst); + } + + }; + + template <> + struct Body<0> + { + template + FORCE_INLINE static void iteration(const Expr& e, uint32 *ri, const T* zi, const H&) + { + } + }; + + template + struct Loop + { + typedef Algo algo_type; + static const uint32 M = 4; + static const uint32 D = algo_type::nElem; + + FORCE_INLINE static void loop(const algo_type& e, uint32 *ri, const T* zi, uint32 n) + { + typename algo_type::Constants cst; + e.initConstants(cst); + + uint32 j = 0; + while (j + (D*M) <= n) { + Details::Body::template iteration(e, ri + j, zi + j, cst); + j += (D*M); + } + while (j + D <= n) { + e.vectorial(ri + j, zi + j, cst); + j += D; + } + while (D > 1 && j < n) { + ri[j] = e.scalar(zi[j]); + j += 1; + } + } + }; + + template + struct _Pipeliner + { + template + FORCE_INLINE static void go(const Expr& e, Data* d) + { + e.template run(d); + _Pipeliner::go(e, d); + } + }; + + template + struct _Pipeliner + { + template + FORCE_INLINE static void go(const Expr& e, Data* d) + { + } + }; + + template + struct Pipeliner + { + template + FORCE_INLINE static void go(const Expr& e, Data* d) + { + _Pipeliner::go(e, d); + } + }; + + +#if 1 + template + char is_complete_impl(char (*)[sizeof(T)]); + + template + long is_complete_impl(...); + + template + struct IsComplete + { + static const bool value = sizeof(is_complete_impl(0)) == sizeof(char); + }; +#else + template + std::true_type is_complete_impl(T *); + + std::false_type is_complete_impl(...); + + template + struct IsComplete : decltype(is_complete_impl(std::declval())) {}; +#endif + +template +struct AlgoScalarToVec : AlgoScalarBase +{ + typedef AlgoScalarBase base_t; + + AlgoScalarToVec(const typename base_t::Data& d) : base_t(d) {} + AlgoScalarToVec(const T* px, const uint32 n) : base_t(px, n) {} + + static const uint32 nElem = 1; + + struct Constants + { + }; + + void initConstants(Constants& cst) const + { + } + + FORCE_INLINE + void vectorial(uint32 *pr, const T *pz, const Constants& cst) const + { + *pr = base_t::scalar(*pz); + } +}; + +template +struct conditional { typedef T type; }; + +template +struct conditional { typedef F type; }; + +template +struct CondData +{ + FORCE_INLINE CondData(T x) : v(x) {} + FORCE_INLINE operator const T&() const { return v;} +private: + T v; +}; + +template +struct CondData +{ + FORCE_INLINE CondData(T) {} + FORCE_INLINE operator const T() const { return 0;} +}; + +template +struct BinAlgoBase : Details::conditional< Details::IsComplete>::value + , Details::AlgoVecBase + , Details::AlgoScalarToVec + >::type +{ + typedef typename Details::conditional< Details::IsComplete>::value + , Details::AlgoVecBase + , Details::AlgoScalarToVec + >::type base_t; + + BinAlgoBase(const T* px, const uint32 n) : base_t(px, n) {} + BinAlgoBase(const typename base_t::Data& d) : base_t(d) {} +}; + +} // namespace Details + +} // namespace BinSearch diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..374b58cbf --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,6 @@ +[build-system] +requires = [ + "setuptools>=42", + "wheel" +] +build-backend = "setuptools.build_meta" diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..883b2e42e --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +lion-pytorch +pytest diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..009fd3d94 --- /dev/null +++ b/setup.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import glob +import os + +from setuptools import find_packages, setup + +libs = list(glob.glob("./bitsandbytes/libbitsandbytes*.so")) +libs = [os.path.basename(p) for p in libs] +print("libs:", libs) + + +def read(fname): + return open(os.path.join(os.path.dirname(__file__), fname)).read() + + +setup( + name=f"bitsandbytes", + version=f"0.38.1", + author="Tim Dettmers", + author_email="dettmers@cs.washington.edu", + description="8-bit optimizers and matrix multiplication routines.", + license="MIT", + keywords="gpu optimizers optimization 8-bit quantization compression", + url="https://github.com/TimDettmers/bitsandbytes", + packages=find_packages(), + package_data={"": libs}, + long_description=read("README.md"), + long_description_content_type="text/markdown", + classifiers=[ + "Development Status :: 4 - Beta", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], +) diff --git a/tests/test_autograd.py b/tests/test_autograd.py new file mode 100644 index 000000000..152243efb --- /dev/null +++ b/tests/test_autograd.py @@ -0,0 +1,627 @@ +from itertools import permutations, product + +import pytest +import torch + +import bitsandbytes as bnb + +n = 1 +k = 25 +dim1 = torch.randint(16, 64, size=(n,)).tolist() +dim2 = torch.randint(32, 96, size=(n,)).tolist() +dim3 = torch.randint(32, 96, size=(n,)).tolist() +dim4 = torch.randint(32, 96, size=(n,)).tolist() +funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)] +str_funcs = ["bmm", "matmul"] +req_grad = [(False, False), (True, False), (True, True), (False, True)] +req_grad_str = ["FF", "TF", "TT", "FT"] +transpose = [(False, False), (False, True), (True, True), (True, False)] +str_transpose = ["FF", "FT", "TT", "TF"] +dtype = [torch.float32, torch.float16] +values = list( + product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose) +) +str_values = list( + product( + dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose + ) +) +names = [ + "dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format( + *vals + ) + for vals in str_values +] + + +@pytest.mark.parametrize( + "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", + values, + ids=names, +) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): + if dim2 > 0: + dim2 = dim2 - (dim2 % 16) + dim3 = dim3 - (dim3 % 16) + dim4 = dim4 - (dim4 % 16) + for i in range(k): + + # normal multiply + if funcs[0] in [torch.mm, torch.matmul]: + dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) + dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) + A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0]) + B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1]) + target = torch.randn( + size=(dim2, dim4), device="cuda", requires_grad=req_grad[1] + ) + torch.nn.init.xavier_uniform_(B) + + if not transpose[0] and not transpose[1]: + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B) + elif not transpose[0] and transpose[1]: + out_torch = funcs[0](A, B.t()) + out_bnb = funcs[1](A, B.t()) + elif transpose[0] and not transpose[1]: + out_torch = funcs[0](A.t(), B) + out_bnb = funcs[1](A.t(), B) + elif transpose[0] and transpose[1]: + out_torch = funcs[0](A.t(), B.t()) + out_bnb = funcs[1](A.t(), B.t()) + + n = out_bnb.numel() + idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) + assert (idx == 0).sum().item() < n * 0.0175 + idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) + assert (idx == 0).sum().item() < n * 0.001 + + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + + loss_torch = torch.nn.functional.mse_loss( + out_torch, target + ).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + + if req_grad[0]: + torch.testing.assert_close( + gradA1, gradA2, atol=0.015, rtol=0.1 + ) + if req_grad[1]: + n = gradB1.numel() + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + assert (idx == 0).sum().item() < n * 0.1 + idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) + assert (idx == 0).sum().item() < n * 0.02 + torch.testing.assert_close( + gradB1, gradB2, atol=0.18, rtol=0.3 + ) + + # batched matrix multiply + if funcs[0] in [torch.bmm, torch.matmul]: + A = torch.randn( + size=(dim1, dim2, dim3), + device="cuda", + requires_grad=req_grad[0], + ) + B = torch.randn( + size=(dim1, dim3, dim4), + device="cuda", + requires_grad=req_grad[1], + ) + target = torch.randn( + size=(dim1, dim2, dim4), + device="cuda", + requires_grad=req_grad[1], + ) + torch.nn.init.xavier_uniform_(B) + + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B) + + n = out_bnb.numel() + idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) + assert (idx == 0).sum().item() < n * 0.01 + torch.testing.assert_close( + out_bnb, out_torch, atol=0.027, rtol=0.2 + ) + + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + + loss_torch = torch.nn.functional.mse_loss( + out_torch, target + ).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + + if req_grad[0]: + torch.testing.assert_close( + gradA1, gradA2, atol=0.015, rtol=0.1 + ) + if req_grad[1]: + n = gradB1.numel() + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + assert (idx == 0).sum().item() < n * 0.1 + idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) + assert (idx == 0).sum().item() < n * 0.02 + + if funcs[0] in [torch.matmul]: + dim1 = dim1 - (dim1 % 16) + A = torch.randn( + size=(dim1, dim2, dim3), + device="cuda", + requires_grad=req_grad[0], + ) + dimB = (dim4, dim3) if transpose[1] else (dim3, dim4) + B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1]) + target = torch.randn( + size=(dim1, dim2, dim4), + device="cuda", + requires_grad=req_grad[1], + ) + torch.nn.init.xavier_uniform_(B) + + if transpose[1]: + out_torch = funcs[0](A, B.t()) + out_bnb = funcs[1](A, B.t()) + else: + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B) + + n = out_bnb.numel() + idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) + assert (idx == 0).sum().item() < n * 0.0175 + idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) + assert (idx == 0).sum().item() < n * 0.001 + + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + + loss_torch = torch.nn.functional.mse_loss( + out_torch, target + ).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + + if req_grad[0]: + torch.testing.assert_close( + gradA1, gradA2, atol=0.015, rtol=0.1 + ) + if req_grad[1]: + n = gradB1.numel() + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + assert (idx == 0).sum().item() < n * 0.1 + idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) + assert (idx == 0).sum().item() < n * 0.02 + + +n = 1 +k = 3 +dim1 = torch.randint(16, 64, size=(n,)).tolist() +dim2 = torch.randint(32, 96, size=(n,)).tolist() +dim3 = torch.randint(32, 96, size=(n,)).tolist() +dim4 = torch.randint(32, 96, size=(n,)).tolist() + +dim2.append(0) + +decomp = [0.0, 6.0] +funcs = [(torch.matmul, bnb.matmul), (torch.matmul, bnb.research.switchback_bnb)] +str_funcs = ["matmullt", 'switchback_bnb'] +req_grad = [(False, False), (True, False), (True, True), (False, True)] +req_grad = list(product([True, False], repeat=3)) +req_grad_str = [] +for c in req_grad: + strval = '' + for v in c: + if v == True: strval += 'T' + else: strval += 'F' + req_grad_str.append(strval) + +transpose = [(False, True), (False, False)] +str_transpose = ["NT", "NN"] +dtype = [torch.float16, torch.bfloat16, torch.float32] +has_fp16_weights = [True, False] +has_bias = [True, False] +values = list( + product( + dim1, + dim2, + dim3, + dim4, + funcs, + dtype, + req_grad, + transpose, + decomp, + has_fp16_weights, + has_bias + ) +) +str_values = list( + product( + dim1, + dim2, + dim3, + dim4, + str_funcs, + dtype, + req_grad_str, + str_transpose, + decomp, + has_fp16_weights, + has_bias + ) +) +names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values] + + +@pytest.mark.parametrize( + "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias", + values, + ids=names, +) +def test_matmullt( + dim1, + dim2, + dim3, + dim4, + funcs, + dtype, + req_grad, + transpose, + decomp, + has_fp16_weights, + has_bias +): + if not torch.cuda.is_available(): pytest.skip('No GPU found.') + dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) + dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) + outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda") + if has_bias == False: + req_grad = list(req_grad) + req_grad[2] = False + + for i in range(k): + + # normal multiply + if funcs[0] in [torch.mm, torch.matmul]: + A = torch.randn( + size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype + ) + if decomp == 6.0: + with torch.no_grad(): + A[:, outlier_dim] = 6.0 + B = torch.randn( + size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype + ) + target = torch.randn( + size=(dim2, dim4), + device="cuda", + requires_grad=req_grad[1], + dtype=dtype, + ) + bias = None + bias2 = None + if has_bias: + bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) + bias2 = bias.clone() + torch.nn.init.xavier_uniform_(B) + B2 = B.clone() + + state = bnb.MatmulLtState() + state.threshold = decomp + state.has_fp16_weights = has_fp16_weights + if not has_fp16_weights: + if not transpose[0] and not transpose[1]: + B2 = B2.t().contiguous() + ( + state.CB, + CBt, + state.SCB, + SCBt, + coo_tensorB, + ) = bnb.functional.double_quant(B2.to(torch.float16)) + B2 = state.CB + + if not transpose[0] and transpose[1]: + out_torch = funcs[0](A, B.t()) + out_bnb = funcs[1](A, B2, state=state, bias=bias2) + elif not transpose[0] and not transpose[1]: + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B2.t(), state=state, bias=bias2) + + if has_bias: + out_torch += bias + + assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}" + + n = out_bnb.numel() + err = torch.abs(out_bnb - out_torch).mean().item() + # print(f'abs error {err:.4f}') + + idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1) + assert (idx == 0).sum().item() <= n * (0.0175 if dtype == torch.float16 else 0.021) + idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2) + assert (idx == 0).sum().item() <= n * 0.001 + + if has_fp16_weights: + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss( + out_bnb, target + ).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + if has_bias: + gradBias1 = bias.grad + bias.grad = None + + loss_torch = torch.nn.functional.mse_loss( + out_torch, target + ).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + if has_bias: + gradBias2 = bias.grad + bias.grad = None + + if req_grad[0]: + torch.testing.assert_close( + gradA1, gradA2, atol=0.015, rtol=0.1 + ) + if req_grad[1]: + n = gradB1.numel() + if dim2 > 0: + assert torch.abs(gradB1).sum() > 0.0 + assert torch.abs(gradB2).sum() > 0.0 + else: + assert torch.abs(gradB1).sum() == 0.0 + assert torch.abs(gradB2).sum() == 0.0 + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + + assert (idx == 0).sum().item() <= n * 0.1 + idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) + assert (idx == 0).sum().item() <= n * 0.02 + torch.testing.assert_close( + gradB1, gradB2, atol=0.18, rtol=0.3 + ) + + if req_grad[2]: + torch.testing.assert_close(gradBias1, gradBias2) + + +n = 1 +k = 3 +dim1 = torch.randint(16, 64, size=(n,)).tolist() +dim2 = torch.randint(32, 96, size=(n,)).tolist() +dim3 = torch.randint(32, 96, size=(n,)).tolist() +dim4 = torch.randint(32, 96, size=(n,)).tolist() + +dim2.append(0) + +funcs = [(torch.matmul, bnb.matmul_4bit)] +str_funcs = ["matmul"] +req_grad = list(product([True, False], repeat=3)) +req_grad_str = [] +for c in req_grad: + strval = '' + for v in c: + if v == True: strval += 'T' + else: strval += 'F' + req_grad_str.append(strval) + +transpose = [(False, True), (False, False)] +str_transpose = ["NT", "NN"] +dtype = [torch.float16, torch.float32] +compress_statistics = [False, True] +has_fp16_weights = [True, False] +has_bias = [True, False] +quant_type = ['fp4', 'nf4'] +values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type)) +str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics, quant_type)) +names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics_{}_quant_type_{}".format(*vals) for vals in str_values] +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.skip('bitsandbytes 4-bit beta feature') +@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type", values, ids=names) +def test_matmul_4bit( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics, quant_type): + dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) + dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) + if has_bias == False: + req_grad = list(req_grad) + req_grad[2] = False + + for i in range(k): + # normal multiply + if funcs[0] in [torch.mm, torch.matmul]: + A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) + B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype) + target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype) + bias = None + bias2 = None + if has_bias: + bias = torch.randn(dim4, device='cuda', dtype=dtype, requires_grad=req_grad[2]) + bias2 = bias.clone() + torch.nn.init.xavier_uniform_(B) + + B2, quant_state = bnb.functional.quantize_4bit(B, compress_statistics=compress_statistics, quant_type=quant_type) + + if not transpose[0] and transpose[1]: + out_torch = funcs[0](A, B.t()) + out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2) + elif not transpose[0] and not transpose[1]: + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B2, quant_state, bias=bias2) + + if has_bias: + out_torch += bias + + assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}" + + n = out_bnb.numel() + err = torch.abs(out_bnb - out_torch).float().mean().item() + if n > 0: + assert err < 0.115 + + #assert err < 0.20 + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + if has_bias: + gradBias1 = bias.grad + bias.grad = None + + loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + if has_bias: + gradBias2 = bias.grad + bias.grad = None + + if req_grad[0]: + torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1) + + if req_grad[2]: + torch.testing.assert_close(gradBias1, gradBias2) + + +funcs = [(torch.matmul, bnb.research.matmul_fp8_mixed), (torch.matmul, bnb.research.matmul_fp8_global)] +str_funcs = ["matmul_fp8_mixed", 'matmul_fp8_global'] +req_grad = list(product([True, False], repeat=3)) +req_grad_str = [] +for c in req_grad: + strval = '' + for v in c: + if v == True: strval += 'T' + else: strval += 'F' + req_grad_str.append(strval) + +transpose = [(False, True), (False, False)] +str_transpose = ["NT", "NN"] +dtype = [torch.float16, torch.float32] +has_fp16_weights = [True, False] +values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)) +str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose)) +names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}".format(*vals) for vals in str_values] +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize( "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names) +def test_matmul_fp8( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose): + dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2) + dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3) + req_grad = list(req_grad) + req_grad[2] = False + + for i in range(k): + # normal multiply + if funcs[0] in [torch.mm, torch.matmul]: + A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype) + B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype) + target = torch.randn(size=(dim2, dim4), device="cuda", requires_grad=req_grad[1], dtype=dtype) + + torch.nn.init.xavier_uniform_(B) + + fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(A.device) + bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(A.device) + + if not transpose[0] and transpose[1]: + out_torch = funcs[0](A, B.t()) + out_bnb = funcs[1](A, B.t(), fw_code, bw_code) + elif not transpose[0] and not transpose[1]: + out_torch = funcs[0](A, B) + out_bnb = funcs[1](A, B, fw_code, bw_code) + + assert out_bnb.dtype == A.dtype, f"bnb matmullt received {A.dtype} but returned {out_bnb.dtype}" + + n = out_bnb.numel() + err = torch.abs(out_bnb - out_torch).float().mean().item() + if n > 0: + assert err < 0.115 + #assert err < 0.20 + if any(req_grad): + out_bnb.data.copy_(out_torch) + torch.cuda.synchronize() + loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean() + loss_bnb.backward() + gradA1 = A.grad + gradB1 = B.grad + A.grad = None + B.grad = None + + loss_torch = torch.nn.functional.mse_loss( out_torch, target ).mean() + loss_torch.backward() + gradA2 = A.grad + gradB2 = B.grad + A.grad = None + B.grad = None + + if req_grad[0]: + torch.testing.assert_close( gradA1, gradA2, atol=0.015, rtol=0.1) + + if req_grad[1]: + n = gradB1.numel() + if dim2 > 0: + assert torch.abs(gradB1).sum() > 0.0 + assert torch.abs(gradB2).sum() > 0.0 + else: + assert torch.abs(gradB1).sum() == 0.0 + assert torch.abs(gradB2).sum() == 0.0 + idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3) + + assert (idx == 0).sum().item() <= n * 0.1 + idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3) + assert (idx == 0).sum().item() <= n * 0.02 + grad_err = (gradB1-gradB2).abs().mean() + assert grad_err.item() < 0.003 + torch.testing.assert_close( + gradB1, gradB2, atol=0.18, rtol=0.3 + ) + diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py new file mode 100644 index 000000000..4973da50d --- /dev/null +++ b/tests/test_cuda_setup_evaluator.py @@ -0,0 +1,40 @@ +import os +from typing import List, NamedTuple + +import pytest + +import bitsandbytes as bnb +from bitsandbytes.cuda_setup.main import ( + determine_cuda_runtime_lib_path, + evaluate_cuda_setup, + extract_candidate_paths, +) + + +def test_cuda_full_system(): + ## this only tests the cuda version and not compute capability + + # if CONDA_PREFIX exists, it has priority before all other env variables + # but it does not contain the library directly, so we need to look at the a sub-folder + version = "" + if "CONDA_PREFIX" in os.environ: + ls_output, err = bnb.utils.execute_and_return(f'ls -l {os.environ["CONDA_PREFIX"]}/lib/libcudart.so.11.0') + major, minor, revision = (ls_output.split(" ")[-1].replace("libcudart.so.", "").split(".")) + version = float(f"{major}.{minor}") + + if version == "" and "LD_LIBRARY_PATH" in os.environ: + ld_path = os.environ["LD_LIBRARY_PATH"] + paths = ld_path.split(":") + version = "" + for p in paths: + if "cuda" in p: + idx = p.rfind("cuda-") + version = p[idx + 5 : idx + 5 + 4].replace("/", "") + version = float(version) + break + + + assert version > 0 + binary_name, cudart_path, cuda, cc, cuda_version_string = evaluate_cuda_setup() + binary_name = binary_name.replace("libbitsandbytes_cuda", "") + assert binary_name.startswith(str(version).replace(".", "")) diff --git a/tests/test_functional.py b/tests/test_functional.py new file mode 100644 index 000000000..40abaa23b --- /dev/null +++ b/tests/test_functional.py @@ -0,0 +1,2514 @@ +import math +import random +import time +from itertools import product + +import einops +import pytest +import torch +import numpy as np + +import bitsandbytes as bnb +from bitsandbytes import functional as F +from scipy.stats import norm + +torch.set_printoptions( + precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 +) +k = 20 + + +def assert_all_approx_close(a, b, rtol=1e-3, atol=1e-3, count=0, throw=True): + idx = torch.isclose(a, b, rtol, atol) + sumval = (idx == 0).sum().item() + if sumval > count: + if throw: + print(f"Too many values not close: assert {sumval} < {count}") + torch.testing.assert_close(a, b, rtol, atol) + + return sumval + + +class FFN(torch.nn.Module): + def __init__(self, input_features, hidden_size, bias=True): + super().__init__() + self.fc1 = torch.nn.Linear(input_features, hidden_size, bias=bias) + self.fc2 = torch.nn.Linear(hidden_size, input_features, bias=bias) + + with torch.no_grad(): + torch.nn.init.xavier_uniform_(self.fc1.weight) + torch.nn.init.xavier_uniform_(self.fc2.weight) + + def forward(self, x): + x = torch.relu(self.fc1(x)) + x = self.fc2(x) + return x + + +class Timer: + def __init__(self): + self.starts = {} + self.ends = {} + self.agg = {} + + def tick(self, name="default"): + if name not in self.starts: + self.starts[name] = torch.cuda.Event(enable_timing=True) + self.ends[name] = torch.cuda.Event(enable_timing=True) + self.starts[name].record() + else: + ms = self.tock(name, evict=True, print_ms=False) + + def tock(self, name="default", evict=True, print_ms=True): + if name in self.ends: + self.ends[name].record() + torch.cuda.synchronize() + ms = self.starts[name].elapsed_time(self.ends[name]) + if name not in self.agg: + self.agg[name] = 0.0 + self.agg[name] += ms + if evict: + self.starts.pop(name) + self.ends.pop(name) + + if print_ms and name in self.agg: + print(f"{name} took: {self.agg[name] / 1000.0:.5f}s") + + return self.agg[name] + + def reset(self): + self.starts = {} + self.ends = {} + self.agg = {} + print("Resetting benchmark data") + + +def setup(): + pass + + +def teardown(): + pass + + +@pytest.mark.parametrize( + "dtype", [torch.float32, torch.float16], ids=["float", "half"] +) +def test_estimate_quantiles(dtype): + A = torch.rand(1024, 1024, device="cuda") + A = A.to(dtype) + code = F.estimate_quantiles(A) + + percs = torch.linspace(1 / 512, 511 / 512, 256, device=A.device) + torch.testing.assert_close(percs, code, atol=1e-3, rtol=1e-2) + + A = torch.randn(1024, 1024, device="cuda") + A = A.to(dtype) + code = F.estimate_quantiles(A) + + quantiles = torch.quantile(A.float(), percs) + diff = torch.abs(code - quantiles) + assert (diff > 5e-02).sum().item() == 0 + + +def test_quantile_quantization(): + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + code = F.estimate_quantiles(A1) + C = F.quantize_no_absmax(A1, code) + A2 = F.dequantize_no_absmax(C, code) + diff = torch.abs(A1 - A2).mean().item() + assert diff < 0.0075 + + A1 = torch.rand(1024, 1024, device="cuda") + code = F.estimate_quantiles(A1) + C = F.quantize_no_absmax(A1, code) + A2 = F.dequantize_no_absmax(C, code) + diff = torch.abs(A1 - A2).mean().item() + torch.testing.assert_close(A1, A2, atol=5e-3, rtol=0) + assert diff < 0.001 + + +def test_dynamic_quantization(): + diffs = [] + reldiffs = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + C, S = F.quantize(A1) + A2 = F.dequantize(C, S) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + assert diff.mean().item() < 0.0135 + # print(sum(diffs)/len(diffs)) + # print(sum(reldiffs)/len(reldiffs)) + + for i in range(100): + A1 = torch.rand(1024, 1024, device="cuda") + C, S = F.quantize(A1) + A2 = F.dequantize(C, S) + diff = torch.abs(A1 - A2).mean().item() + torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) + assert diff < 0.004 + + + +@pytest.mark.parametrize("nested", [False, True], ids=["False", "True"]) +@pytest.mark.parametrize("blocksize", [4096, 2048, 1024, 512, 256, 128, 64]) +def test_dynamic_blockwise_quantization(nested, blocksize): + #print('') + diffs = [] + reldiffs = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + abserr = sum(diffs)/len(diffs) + relerr = sum(reldiffs)/len(reldiffs) + assert abserr < 0.011 + assert relerr < 0.018 + #print('nested=', nested, 'randn', blocksize, sum(diffs)/len(diffs)) + #print('nested=', nested, 'randn', blocksize, sum(reldiffs)/len(reldiffs)) + + diffs = [] + for i in range(100): + A1 = torch.rand(1024, 1024, device="cuda") + C, S = F.quantize_blockwise(A1, blocksize=blocksize, nested=nested) + A2 = F.dequantize_blockwise(C, S) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + #torch.testing.assert_close(A1, A2, atol=1e-2, rtol=0) + abserr = sum(diffs)/len(diffs) + relerr = sum(reldiffs)/len(reldiffs) + assert abserr < 0.0035 + assert relerr < 0.015 + #print('nested=', nested, 'rand', blocksize, sum(diffs)/len(diffs)) + #print('nested=', nested, 'rand', blocksize, sum(reldiffs)/len(reldiffs)) + + + +@pytest.mark.parametrize( + "gtype", [torch.float32, torch.float16], ids=["float", "half"] +) +def test_percentile_clipping(gtype): + gnorm_vec1 = torch.zeros(100, device="cuda") + gnorm_vec2 = torch.zeros(100, device="cuda") + n = 4 + step = 0 + percentile = 5 + for i in range(k): + step += 1 + g = torch.randn(n, n, dtype=gtype, device="cuda") + gnorm1, clip2, gnorm_scale = F.percentile_clipping( + g, gnorm_vec2, step, percentile=percentile + ) + assert gnorm_scale == 1.0 if gnorm1 < clip2 else clip2 / gnorm1 + + gnorm2 = torch.norm(g.float()) + if step == 1: + gnorm_vec1[:] = gnorm2 + else: + gnorm_vec1[step % 100] = gnorm2 + + vals, idx = torch.sort(gnorm_vec1) + clip1 = vals[percentile] + + torch.testing.assert_close(gnorm_vec1, torch.sqrt(gnorm_vec2)) + torch.testing.assert_close(clip1, clip2) + torch.testing.assert_close(gnorm1, gnorm2) + + +def quant(x): + max1 = torch.abs(x).max() + x = torch.round(x / max1 * 127) + return max1, x.to(torch.int8) + + +def dequant(c, maxC): + return c.float() * (maxC / 127) + + +def mm_dequant(maxA, maxB, C): + return C.float() * (maxA / 127) * (maxB / 127) + + +def quant_multi(x, dim): + max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) + max1[max1 == 0] = 1.0 + x = torch.round(x / max1 * 127) + return max1, x.to(torch.int8) + + +def quant_multi_chunk(x, dim, chunk_size=32): + if dim == 1: + x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size) + max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True) + max1 = torch.tile(max1, (1, 1, x.shape[1])) + max1 = max1.view(x.shape) + elif dim == 0: + x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size) + max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True) + max1 = torch.tile(max1, (x.shape[0], 1, 1)) + max1 = max1.view(x.shape) + max1[max1 == 0] = 1.0 + x = torch.round(x / max1 * 127) + return max1, x.to(torch.int8) + + +def quant_minmax(A): + minA = A.min() + maxA = A.max() + + +def mean(xx): + return sum(xx) / float(len(xx)) + + +# dim1 = torch.randint(1,1024*4, size=(4,)).tolist() +# dim2 = torch.randint(1,1024*4, size=(4,)).tolist() +dim1 = [1024 * 2] +dim2 = [1024 * 16] +methods = [ + ( + lambda x, dim: quant(x), + lambda x, dim: quant(x), + dequant, + dequant, + mm_dequant, + ) +] +methods.append((quant_multi, quant_multi, dequant, dequant, mm_dequant)) +# methods.append((lambda x: quant_multi_chunk(x, dim=-1), lambda x: quant_multi_chunk(x, dim=0), dequant, dequant, mm_dequant)) +method_names = ["linear", "vectorwise"] +batched = [False, True] +values = list(product(dim1, dim2, methods, batched)) +values_names = list(product(dim1, dim2, method_names, batched)) +names = [ + "dim1_{}_dim2_{}_quant_{}_batched_{}".format(*vals) + for vals in values_names +] + + +@pytest.mark.parametrize( + "dim1, dim2, quant_methods, batched", values, ids=names +) +def test_approx_igemm(dim1, dim2, quant_methods, batched): + dim1 = dim1 - (dim1 % 32) + dim2 = dim2 - (dim2 % 32) + errors = [] + relerrors = [] + #print("") + for i in range(5): + if batched: + A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda") + B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda") + maxA, Ac = quant_methods[0](A, 2) + maxB, Bc = quant_methods[1](B, 1) + else: + A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda") + B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda") + maxA, Ac = quant_methods[0](A, 1) + maxB, Bc = quant_methods[1](B, 0) + torch.testing.assert_close( + quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05 + ) + if batched: + out2 = torch.bmm(A, B) + C = torch.bmm(Ac.float(), Bc.float()) + else: + out2 = torch.mm(A, B) + C = F.igemm(Ac, Bc) + out = quant_methods[4](maxA, maxB, C) + std = out2.std() + out /= std + out2 /= std + err = torch.abs(out - out2) + relerr = err / torch.abs(out2) + errors.append(err.mean().item()) + relerrors.append(relerr.mean().item()) + #print(mean(errors)) + #print(mean(relerrors)) + + +def test_stable_embedding(): + layer = bnb.nn.StableEmbedding(1024, 1024) + layer.reset_parameters() + + +n = 2 +hidden_dim = torch.randint(32, 256, size=(n,)).tolist() +batch_dim = torch.randint(16, 256, size=(n,)).tolist() +seq_dim = torch.randint(16, 256, size=(n,)).tolist() +transpose = [(False, False), (False, True), (True, False), (True, True)] +values = list(product(hidden_dim, batch_dim, transpose, seq_dim)) +names = [ + "hidden_dim_{}_batch_dim_{},transpose_{}_seq_dim_{}".format(*vals) + for vals in values +] + + +@pytest.mark.parametrize( + "hidden_dim, batch_dim, transpose, seq_dim", values, ids=names +) +def test_igemm(hidden_dim, batch_dim, transpose, seq_dim): + hidden_dim = hidden_dim - (hidden_dim % 32) + batch_dim = batch_dim - (batch_dim % 16) + seq_dim = seq_dim - (seq_dim % 16) + for i in range(k): + shapeA = ( + (batch_dim, hidden_dim) + if not transpose[0] + else (hidden_dim, batch_dim) + ) + shapeB = ( + (32 * random.randint(1, 4), hidden_dim) + if transpose[1] + else (hidden_dim, 32 * random.randint(1, 4)) + ) + A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) + if not transpose[0] and not transpose[1]: + out2 = torch.matmul(A.float(), B.float()) + out = F.igemm(A, B) + elif not transpose[0] and transpose[1]: + out2 = torch.matmul(A.float(), B.t().float()) + out = F.igemm(A, B.t()) + elif transpose[0] and not transpose[1]: + out2 = torch.matmul(A.t().float(), B.float()) + out = F.igemm(A.t(), B) + elif transpose[0] and transpose[1]: + out2 = torch.matmul(A.t().float(), B.t().float()) + out = F.igemm(A.t(), B.t()) + + torch.testing.assert_close(out.float(), out2) + + for i in range(k): + shapeA = (batch_dim, seq_dim, hidden_dim) + shapeB = ( + (32 * random.randint(1, 4), hidden_dim) + if transpose[1] + else (hidden_dim, 32 * random.randint(1, 4)) + ) + A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) + if not transpose[0] and not transpose[1]: + out2 = torch.matmul(A.float(), B.float()) + out = F.igemm(A, B) + elif not transpose[0] and transpose[1]: + out2 = torch.matmul(A.float(), B.t().float()) + out = F.igemm(A, B.t()) + + torch.testing.assert_close(out.float(), out2) + + +n = 3 +seq_dim = torch.randint(32, 512, size=(n,)).tolist() +hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist() +batch_dim = torch.randint(2, 16, size=(n,)).tolist() +values = list(product(seq_dim, hidden_dim, batch_dim)) +names = [ + "seq_dim{}_hidden_dim{}_batch_dim{}".format(*vals) for vals in values +] + + +@pytest.mark.parametrize("seq_dim, hidden_dim, batch_dim", values, ids=names) +def test_dim3_igemm(seq_dim, hidden_dim, batch_dim): + seq_dim = seq_dim - (seq_dim % 32) + hidden_dim = hidden_dim - (hidden_dim % 32) + batch_dim = batch_dim - (batch_dim % 2) + for i in range(25): + A = torch.randint( + -128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda" + ).to(torch.int8) + B = torch.randint( + -128, 127, size=(batch_dim, seq_dim, 1024), device="cuda" + ).to(torch.int8) + out2 = torch.einsum("bsi, bso->io", A.float(), B.float()) + iout = torch.empty( + A.shape[2], B.shape[2], dtype=torch.int32, device=A.device + ) + out = F.igemm(A, B, out=iout) + + torch.testing.assert_close(out.float(), out2) + + +n = 2 +seq_dim = torch.randint(32, 512, size=(n,)).tolist() +hidden_dim = torch.randint(32, 1024 * 4, size=(n,)).tolist() +batch_dim = torch.randint(2, 16, size=(n,)).tolist() +transpose = [False, True] +values = list(product(seq_dim, hidden_dim, batch_dim, transpose)) +names = [ + "seq_dim={}_hidden_dim={}_batch_dim={}_transpose{}".format(*vals) + for vals in values +] + + +@pytest.mark.parametrize( + "seq_dim, hidden_dim, batch_dim, transpose", values, ids=names +) +def test_minmax_igemm(seq_dim, hidden_dim, batch_dim, transpose): + def min_max(x): + maxA = torch.amax(x, dim=2, keepdim=True) + minA = torch.amin(x, dim=2, keepdim=True) + scale = (maxA - minA) / 2.0 + return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale + + seq_dim = seq_dim - (seq_dim % 16) + hidden_dim = hidden_dim - (hidden_dim % 16) + batch_dim = batch_dim - (batch_dim % 2) + errs = [] + relerrs = [] + errs2 = [] + relerrs2 = [] + for i in range(k): + A = torch.normal( + 0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda" + ) + if transpose: + B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda") + else: + B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda") + Ac, minA, scale = min_max(A) + if transpose: + maxB, Bc = quant_multi(B, dim=(1 if transpose else 0)) + out = F.igemm(Ac, Bc.t()) + out2 = torch.matmul(A, B.t()) + offset = B.t().sum(0) * (minA + scale) + out = out.float() + out = (out * maxB.t() * scale / (127 * 127)) + offset + + maxA, Ac = quant_multi(A, dim=2) + out3 = F.igemm(Ac, Bc.t()) + out3 = mm_dequant(maxA, maxB.t(), out3) + else: + maxB, Bc = quant_multi(B, dim=0) + offset = B.sum(0) * (minA + scale) + out = F.igemm(Ac, Bc) + out2 = torch.matmul(A, B) + out = out.float() + out = (out * maxB * scale / (127 * 127)) + offset + + maxA, Ac = quant_multi(A, dim=2) + out3 = F.igemm(Ac, Bc) + out3 = mm_dequant(maxA, maxB, out3) + + std = out2.std() + out2 /= std + out /= std + out3 /= std + + err = torch.abs(out - out2) + relerr = err / (torch.abs(out2) + 1e-7) + + err2 = torch.abs(out3 - out2) + relerr2 = err2 / (torch.abs(out2) + 1e-7) + + errs.append(err.mean().item()) + relerrs.append(relerr.mean().item()) + errs2.append(err2.mean().item()) + relerrs2.append(relerr2.mean().item()) + # print(mean(errs)) + # print(mean(relerrs)) + # print(mean(errs2)) + # print(mean(relerrs2)) + assert mean(errs) < 0.015 + assert mean(relerrs) < 0.3 + + +n = 2 +dim1 = torch.randint(1, 64, size=(n,)).tolist() +dim2 = torch.randint(32, 128, size=(n,)).tolist() +dim3 = torch.randint(32, 256, size=(n,)).tolist() +dim4 = torch.randint(32, 256, size=(n,)).tolist() +transpose = [(False, False), (True, False), (False, True), (True, True)] +values = list(product(dim1, dim2, dim3, dim4, transpose)) +names = [ + "dim1_{}_dim2_{}_dim3_{}_dim4_{}_transpose_{}".format(*vals) + for vals in values +] + + +@pytest.mark.parametrize("dim1, dim2, dim3, dim4, transpose", values, ids=names) +def test_ibmm(dim1, dim2, dim3, dim4, transpose): + dim2 = dim2 - (dim2 % 16) + dim3 = dim3 - (dim3 % 16) + dim4 = dim4 - (dim4 % 16) + for i in range(k): + shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3) + shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4) + A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) + B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8) + + if not transpose[0] and not transpose[1]: + out2 = torch.bmm(A.float(), B.float()) + out = F.igemm(A, B) + elif not transpose[0] and transpose[1]: + out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float()) + out = F.igemm(A, B.permute([0, 2, 1])) + elif transpose[0] and not transpose[1]: + out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float()) + out = F.igemm(A.permute([0, 2, 1]), B) + elif transpose[0] and transpose[1]: + out2 = torch.bmm( + A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float() + ) + out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1])) + torch.testing.assert_close(out.float(), out2.float()) + + +n = 1 +dim1 = torch.randint(1, 64, size=(n,)).tolist() +dim2 = torch.randint(32, 128, size=(n,)).tolist() +dim3 = torch.randint(32, 256, size=(n,)).tolist() +values = list(product(dim1, dim2, dim3)) +names = ["dim1_{}_dim2_{}_dim3_{}".format(*vals) for vals in values] + + +@pytest.mark.parametrize("dim1, dim2, dim3", values, ids=names) +def test_vector_quant(dim1, dim2, dim3): + dim2 = dim2 - (dim2 % 16) + dim3 = dim3 - (dim3 % 16) + for i in range(k): + A = torch.randn(size=(dim2, dim3), device="cuda") + qA, SA = F.vectorwise_quant(A, dim=0) + A1 = F.vectorwise_dequant(qA, SA) + n = A1.numel() + assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n*0.002)) + + + + +n = 2 +dim1 = torch.randint(2, 256, size=(n,)).tolist() +dim2 = torch.randint(2, 256, size=(n,)).tolist() +dim3 = torch.randint(2, 256, size=(n,)).tolist() +# dim1, dim2 = (256,), (256,) +dtype = [torch.int8, torch.int32] +a_order = ["row"] +out_order = ["col", "row", "col32"] +transpose = [False] +dims = [2, 3] +values = list(product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose)) + +names = ["dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_transpose_{}".format(*vals)for vals in values] + + +@pytest.mark.parametrize("dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",values,ids=names) +def test_nvidia_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): + if dims == 3 and out_order != "col32": + return + if dtype == torch.int32 and out_order != "col32": + return + func = F.get_transform_func(dtype, orderA, orderOut, transpose) + + if dims == 2: + A = torch.randint(-128, 127, size=(dim1, dim2), device="cuda").to(dtype) + elif dims == 3: + A = torch.randint(-128, 127, size=(dim1, dim2, dim3), device="cuda").to( + dtype + ) + + out, S = F.nvidia_transform(A, to_order=orderOut) + + if orderOut == "row": + torch.testing.assert_close(A.flatten(), out.flatten()) + elif orderOut == "col": + torch.testing.assert_close(A.t().flatten(), out.flatten()) + elif orderOut == "col32": + if dims == 2: + n = A.shape[0] * (A.shape[1] + (32 - (A.shape[1] % 32))) + elif dims == 3: + n = ( + A.shape[0] + * A.shape[1] + * (A.shape[2] + (32 - (A.shape[2] % 32))) + ) + assert out.numel() == n + elif orderOut == "col_turing": + # 32 col 8 row tiles + n = (A.shape[0] + (8 - A.shape[0] % 8)) * ( + A.shape[1] + (32 - (A.shape[1] % 32)) + ) + assert out.numel() == n + total_coltile = (A.shape[1] // 32) + (1 if A.shape[1] % 32 != 0 else 0) + for row in range(A.shape[0]): + for col in range(A.shape[1]): + i = row * A.shape[1] + j = col + + coltile = (col // 32) + (1 if col % 32 != 0 else 0) + rowtile = ( + (row // 8) + (1 if row % 8 != 0 else 0) + ) * total_coltile + offset = 32 * 8 * (rowtile + coltile) + col2 = col % 32 + row2 = (row % 8) * 32 + + assert A.flatten()[i + j] == A[row, col] + # assert A.flatten()[i+j] == out.flatten()[row2+col2] + # torch.testing.assert_close(A.flatten()[i+j], A[row, col]) + # torch.testing.assert_close(A.flatten()[i+j], out.flatten()[row2+ col2+block_offset]) + + if orderOut == "col32": + out2, S = F.nvidia_transform( + out, from_order=orderOut, to_order="row", state=S + ) + torch.testing.assert_close(A, out2) + + +n = 1 +dim1 = torch.randint(1, 256, size=(n,)).tolist() +dim2 = torch.randint(32, 512, size=(n,)).tolist() +dim3 = torch.randint(32, 1024, size=(n,)).tolist() +dim4 = torch.randint(32, 1024, size=(n,)).tolist() + +# dim1 = [2] +# dim2 = [2] +# dim3 = [2] +# dim4 = [2] + +dims = (2, 3) +ldb = [0] +# ldb = list(range(256, 1*1024, 256)) +values = list(product(dim1, dim2, dim3, dim4, dims, ldb)) +names = [ + "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}_ldb_{}".format(*vals) + for vals in values +] + + +@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims, ldb", values, ids=names) +def test_igemmlt_int(dim1, dim2, dim3, dim4, dims, ldb): + for i in range(k): + if dims == 2: + A = torch.randint(-128, 127, size=(dim1, dim3), device="cuda").to( + torch.int8 + ) + elif dims == 3: + A = torch.randint( + -128, 127, size=(dim1, dim2, dim3), device="cuda" + ).to(torch.int8) + B = torch.randint(-128, 127, size=(dim4, dim3), device="cuda").to( + torch.int8 + ) + C1 = torch.matmul(A.float(), B.t().float()) + + A2, SA = F.transform(A, "col32") + B2, SB = F.transform(B, "col_turing") + C2, SC = F.igemmlt(A2, B2, SA, SB) + C3, S = F.nvidia_transform(C2, "row", state=SC) + torch.testing.assert_close(C1, C3.float()) + + # transpose + B = torch.randint(-128, 127, size=(dim3, dim4), device="cuda").to( + torch.int8 + ) + C1 = torch.matmul(A.float(), B.float()) + + B2t, SBt = F.transform(B, "col_turing", transpose=True) + C2, SC = F.igemmlt(A2, B2t, SA, SBt) + C3, S = F.nvidia_transform(C2, "row", state=SC) + torch.testing.assert_close(C1, C3.float()) + + +dim1 = [32] +dim2 = [32] +dim3 = [32] +dim4 = [32] + +dims = (2,) +# ldb = list(range(256, 1*1024, 256)) +values = list(product(dim1, dim2, dim3, dim4, dims)) +names = [ + "dim1_{}_dim2_{}_dim3_{}_dim4_{}_dims_{}".format(*vals) + for vals in values +] + + +@pytest.mark.parametrize("dim1, dim2, dim3, dim4, dims", values, ids=names) +def test_igemmlt_half(dim1, dim2, dim3, dim4, dims): + formatB = F.get_special_format_str() + for i in range(k): + if dims == 2: + A = torch.normal(0, 0.5, size=(dim1, dim3), device="cuda").half() + elif dims == 3: + A = torch.normal( + 0, 0.5, size=(dim1, dim2, dim3), device="cuda" + ).half() + B = torch.randn((dim4, dim3), device="cuda").half() + torch.nn.init.xavier_uniform_(B) + C1 = torch.matmul(A, B.t()) + C2 = bnb.matmul(A, B.t()) + + A = A.view(-1, A.shape[-1]) + + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B) + C32A, SA = F.transform(CA, "col32") + CxB, SB = F.transform(CB, to_order=formatB) + out1_32, Sout1_32 = F.igemmlt(C32A, CxB, SA, SB) + output = F.mm_dequant(out1_32, Sout1_32, statsAt, statsBt) + + # print('') + # print(output.flatten()[:10]) + # print(C1.flatten()[:10]) + # print(C2.flatten()[:10]) + + # torch.testing.assert_close(C1.view(-1, C1.shape[-1]), output, atol=0.025, rtol=0.05) + + # transpose + # B = torch.randint(-128, 127, size=(dim3, dim4), device='cuda').to(torch.int8) + # C1 = torch.matmul(A.float(), B.float()) + + # B2t, SBt = F.transform2(B, 'col_turing', transpose=True) + # C2, SC = F.igemmlt(A2, B2t, SA, SBt) + # C3, S = F.transform(C2, 'row', state=SC) + # torch.testing.assert_close(C1, C3.float()) + + +batch_size = 2 +seqdim = 512 +# values = [(batch_size, seqdim, 4*1024, 16*1024),(batch_size, seqdim, 5120, 4*5120),(batch_size, seqdim, 12*1024, 4*12*1024)] +values = [ + (batch_size, seqdim, 4 * 1024, 3 * 4 * 1024), + (batch_size, seqdim, 5120, 3 * 5120), + (batch_size, seqdim, 12 * 1024, 4 * 12 * 1024), +] + + +# values = list(product(batch, seq, model, hidden)) +names = [ + "batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values +] + + +@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) +def test_bench_8bit_training(batch, seq, model, hidden): + formatB = F.get_special_format_str() + A = torch.randn(batch, seq, model, device="cuda").half() + grad = torch.randn(batch, seq, model, device="cuda").half() + w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half() + w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half() + print("") + + # torch.cuda.synchronize() + ## warmup + # for i in range(100): + # torch.matmul(A, w1.t()) + # torch.cuda.synchronize() + + dtype = torch.int8 + A = A.view(-1, A.shape[-1]).contiguous() + grad = grad.view(-1, grad.shape[-1]).contiguous() + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + + out1 = torch.matmul(A, w1.t()) # fc1 + # out2 = torch.matmul(out1, w2.t())# fc2 + + # d1 = torch.matmul(grad, w2) # delta1 + # d2 = torch.matmul(d1, w1) # delta2 + + # grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2 + # grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1 + + torch.cuda.synchronize() + t16 = time.time() - t0 + print(t16) + + # torch.cuda.empty_cache() + + # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + + # CTw1, Sw1 = F.transform2(Cw1, formatB) + # CTw2, Sw2 = F.transform2(Cw2, formatB) + # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) + # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) + + # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + # C32A, SA = F.transform2(CA, 'col32') + ## fc1 + # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype) + ##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t) + + ## fc2 + # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1) + # C32out1, Sout1 = F.transform2(Cout1, 'col32') + # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype) + ##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t) + + ## delta1 + # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad) + # C32grad, Sgrad = F.transform2(Cgrad, 'col32') + ##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype) + ##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2) + + ## delta2 + # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1) + # C32d1, Sd1 = F.transform2(Cd1, 'col32') + ##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype) + ##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1) + + ## grad1 + # C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True) + # CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True) + ##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype) + ##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad) + + ## grad2 + # C32At, SAt = F.transform2(CAt, 'col32', transpose=True) + # CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True) + ##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype) + ##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1) + + # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + + # Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + # Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2) + + # CTw1, Sw1 = F.transform2(Cw1, formatB) + # CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) + # CTw2, Sw2 = F.transform2(Cw2, formatB) + # CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(k): + # #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + # #CTw1, Sw1 = F.transform2(Cw1, formatB) + # #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + # #CTw1, Sw1 = F.transform2(Cw1, formatB) + + # #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5) + # CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + # #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True) + # #CTw2, Sw2 = F.transform2(Cw2, formatB) + # #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True) + + # C32A, SA = F.transform2(CA, 'col32') + + # # fc1 + # out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype) + # #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) + + # #print(coo_tensor.nnz) + # #out1sp = F.spmm_coo(coo_tensor, w1.t()) + # #print(w1.t().shape) + # #out1 = out1dn + out1sp + + # # fc2 + # Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1) + # C32out1, Sout1 = F.transform2(Cout1, 'col32') + # out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype) + # #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2) + + # # delta1 + # Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad) + # C32grad, Sgrad = F.transform2(Cgrad, 'col32') + # d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype) + # #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t) + + # # delta2 + # Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1) + # C32d1, Sd1 = F.transform2(Cd1, 'col32') + # d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype) + # #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t) + + # # grad1 + # #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True) + # #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True) + # #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype) + # #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt) + + # ## grad2 + # #C32At, SAt = F.transform2(CAt, 'col32', transpose=True) + # #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True) + # #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype) + # #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t) + + # torch.cuda.synchronize() + # t8 = time.time() - t0 + # print(t8) + + +n = 2 +dim1 = torch.randint(64, 256, size=(n,)).tolist() +dim4 = torch.randint(64, 1024, size=(n,)).tolist() + +#dim1 = [2*1024] +#dim4 = [2*1024] + +#dim1 = [4] +#dim4 = [4] + +dims = (2,) +formatB = ["col_turing", "col_ampere"] +has_bias = [True, False] +values = list(product(dim1, dim4, dims, formatB, has_bias)) +names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values] + + +@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names) +def test_dequant_mm(dim1, dim4, dims, formatB, has_bias): + inner = torch.randint(1, 128, size=(1,)).item() + bias = None + if has_bias: bias = torch.randn(dim4, device='cuda', dtype=torch.float16) + formatB = F.get_special_format_str() + for i in range(1): + A = torch.randn(dim1, inner, device="cuda") + B = torch.randn(dim4, inner, device="cuda") + C1 = torch.matmul(A.half(), B.t().half()) + if has_bias: C1 += bias + + A1, maxA = F.vectorwise_quant(A, dim=1) + B1, maxB = F.vectorwise_quant(B, dim=1) + + A2, SA = F.nvidia_transform(A1, "col32") + B2, SB = F.nvidia_transform(B1, formatB) + C2, SC = F.igemmlt(A2, B2, SA, SB) + + C3, S = F.nvidia_transform(C2, "row", state=SC) + C4 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) + if has_bias: C4 += bias + + # TODO: is something wrong here? If so, the problem goes deeper + #n = C1.numel() + #p = 0.06 + std = C1.std(0).view(1, -1) + C1 /= std + C4 /= std + #assert_all_approx_close(C1, C4, atol=0.02, rtol=0.1, count=int(n*0.06)) + #assert (count / n < p), f"error in more than {p} of elements: {count}/{n}={count/n}" + + C5 = F.mm_dequant(C2, SC, maxA.flatten(), maxB.flatten(), bias=bias) + #torch.testing.assert_close(C5, C4, atol=0.015, rtol=0.1) + n = C5.numel() + assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01*n)) + + +n = 2 +dim1 = [1 * 1024] +dim2 = [1 * 1024] +# dim1 = torch.randint(1,4*1024, size=(n,)).tolist() +# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() + +dims = (2,) +# ldb = list(range(256, 1*1024, 256)) +values = list(product(dim1, dim2, dims)) +names = ["dim1_{}_dim2_{}_dims_{}".format(*vals) for vals in values] + + +@pytest.mark.parametrize("dim1, dim2, dims", values, ids=names) +def test_colrow_absmax(dim1, dim2, dims): + for i in range(k): + threshold = 3.0 + A = torch.randn(dim1, dim2, device="cuda").half() + A_truncated = A.clone() + A_truncated[torch.abs(A_truncated) >= 3.0] = 0.0 + if dims == 2: + row_stats1, _ = torch.abs(A.float()).max(1) + col_stats1, _ = torch.abs(A.float()).max(0) + row_stats1_trunc, _ = torch.abs(A_truncated.float()).max(1) + col_stats1_trunc, _ = torch.abs(A_truncated.float()).max(0) + else: + assert False + + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( + A, threshold=threshold + ) + + A_blocked = einops.rearrange( + torch.abs(A), + "(rows row_tiles) (cols block_size)-> rows cols row_tiles block_size", + row_tiles=16, + block_size=64 * 4, + ) + nnz_rows1_counts = (torch.abs(A_blocked) >= threshold).sum(3).flatten() + nnz_block_ptr1 = torch.zeros( + nnz_rows1_counts.shape[0] + 1, + dtype=nnz_rows1_counts.dtype, + device=nnz_rows1_counts.device, + ) + nnz_block_ptr1[1:] = nnz_rows1_counts.cumsum(0) + + torch.testing.assert_close(col_stats1_trunc, col_stats2) + torch.testing.assert_close(row_stats1_trunc, row_stats2) + torch.testing.assert_close(nnz_block_ptr1.int(), nnz_block_ptr2) + + row_stats2, col_stats2, nnz_block_ptr2 = F.get_colrow_absmax( + A, threshold=0.0 + ) + + torch.testing.assert_close(col_stats1, col_stats2) + torch.testing.assert_close(row_stats1, row_stats2) + assert nnz_block_ptr2 is None + + +n = 2 +# dim1 = [8*1024] +# dim2 = [4*1024] +dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist() + +values = list(product(dim1, dim2)) +names = ["dim1_{}_dim2_{}".format(*vals) for vals in values] + + +@pytest.mark.parametrize("dim1, dim2", values, ids=names) +def test_double_quant(dim1, dim2): + for i in range(k): + A = torch.randn(dim1, dim2, device="cuda").half() + out_col1, Scol = F.vectorwise_quant(A, dim=0) + out_row1, Srow = F.vectorwise_quant(A, dim=1) + + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + + # max difference is 1 due to rounding differences + torch.testing.assert_close(CA, out_row1, atol=1, rtol=0) + torch.testing.assert_close(CAt, out_col1, atol=1, rtol=0) + + n = CAt.numel() + num_not_close_rows = ( + (torch.isclose(CA, out_row1, atol=1) == 0).sum().item() + ) + num_not_close_cols = ( + (torch.isclose(CAt, out_col1, atol=1) == 0).sum().item() + ) + + # allow for 1:500 error due to rounding differences + min_error = 1 / 500 + if num_not_close_cols > (min_error * n): + print( + f"Min error exceeded {num_not_close_cols} elements are different. Error: {num_not_close_cols/n:.4f}" + ) + assert False + if num_not_close_rows > (min_error * n): + print( + f"Min error exceeded {num_not_close_rows} elements are different. Error: {num_not_close_rows/n:.4f}" + ) + assert False + + torch.testing.assert_close(Srow.flatten().float(), statsA) + torch.testing.assert_close(Scol.flatten().float(), statsAt) + + +n = 4 +dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +inner = torch.randint(1, 4 * 1024, size=(n,)).tolist() + +values = list(zip(dim1, dim4, inner)) +names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values] + + +@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +def test_integrated_igemmlt(dim1, dim4, inner): + for i in range(k): + A = torch.randn(dim1, inner, device="cuda").half() + B = torch.randn(dim4, inner, device="cuda").half() + + out1 = torch.matmul(A.half(), B.t().half()) + + C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A) + C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) + A1, maxA = F.vectorwise_quant(A, dim=1) + B1, maxB = F.vectorwise_quant(B, dim=1) + + torch.testing.assert_close(maxA.flatten().float(), stats1a) + torch.testing.assert_close(maxB.flatten().float(), stats2a) + torch.testing.assert_close(C1a, A1, rtol=0, atol=1) + torch.testing.assert_close(C2a, B1, rtol=0, atol=1) + + A2, SA = F.nvidia_transform(C1a, "col32") + B2, SB = F.nvidia_transform(C2a, "col_turing") + outC32, SC = F.igemmlt(A2, B2, SA, SB) + out2 = F.mm_dequant(outC32, SC, stats1a, stats2a) + + A2, SA = F.nvidia_transform(A1, "col32") + B2, SB = F.nvidia_transform(B1, "col_turing") + C2, SC = F.igemmlt(A2, B2, SA, SB) + + C3, S = F.nvidia_transform(C2, "row", state=SC) + out3 = F.vectorwise_mm_dequant(C3.float(), maxA, maxB.t()) + + err1 = torch.abs(out1 - out2).mean().item() + err2 = torch.abs(out1 - out3).mean().item() + assert err2 <= err1 * 1.025 + + +n = 6 +dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +dim4 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +inner = torch.randint(1, 4 * 1024, size=(n,)).tolist() + +values = list(zip(dim1, dim4, inner)) +names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values] + + +@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +@pytest.mark.skip("Row scale has some bugs for ampere") +def test_igemmlt_row_scale(dim1, dim4, inner): + formatB = F.get_special_format_str() + err1, err2, err3 = [], [], [] + relerr1, relerr2 = [], [] + scale = 1 + for i in range(k): + A = torch.randn(dim1, inner, device="cuda").half() + B = torch.randn(dim4, inner, device="cuda").half() + torch.nn.init.xavier_uniform_(B) + C1 = torch.matmul(A, B.t()) + + out1 = torch.matmul(A.half(), B.t().half()) + + C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A) + CB, absmaxB = F.vectorwise_quant(B, quant_type="linear") + A2, SA = F.nvidia_transform(C1a, "col32") + B2, SB = F.nvidia_transform(CB, formatB) + A1, maxA = F.vectorwise_quant(A, dim=1) + + c = 10.0 * inner * scale + row_scale = torch.ones_like(maxA) / c + outC32, SC = F.igemmlt( + A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale + ) + C3, S = F.nvidia_transform(outC32, "row", state=SC) + maxval = torch.abs(C3).max() + if maxval == 127: + scale = 1.5 + else: + scale = maxval / 120 + out3 = C3 * maxA * absmaxB * c / (127 * 127) + + C4 = torch.matmul(C1a.float(), CB.float().t()) + + C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) + B2, SB = F.nvidia_transform(C2a, formatB) + outC32, SC = F.igemmlt(A2, B2, SA, SB) + out2 = F.mm_dequant(outC32, SC, stats1a, stats2a) + + CA, SA = F.vectorwise_quant(A, dim=1, quant_type="vector") + CB, SB = F.vectorwise_quant(B, dim=1, quant_type="linear") + + C = torch.matmul(CA.float(), CB.t().float()) + out4 = C * SA * SB / (127 * 127) + # out4 = torch.clip(torch.round(C*SA/c), -127, 127)*c*SB/(127*127) + + # print('='*80) + # print(out1) + # print(out2) + # print(out3) + + # print(out1) + # print(out2) + # print(out3) + err1.append(torch.abs(out1 - out2).mean().item()) + err2.append(torch.abs(out1 - out3).mean().item()) + err3.append(torch.abs(out1 - out4).mean().item()) + + # assert_all_approx_close(C3.float(), torch.round(C4*row_scale), rtol=0, atol=0, count=10) + print("") + print(sum(err1) / len(err1)) + print(sum(err2) / len(err2)) + print(sum(err3) / len(err3)) + + +dim1 = [1024, 2048] +inner = [12288 * 4, 4096 * 4] +dim4 = [12288, 4096] + +values = list(zip(dim1, dim4, inner)) +names = ["dim1_{}_dim4_{}_inner_{}".format(*vals) for vals in values] + + +@pytest.mark.parametrize("dim1, dim4, inner", values, ids=names) +@pytest.mark.skip("Row scale has some bugs for ampere") +def test_row_scale_bench(dim1, dim4, inner): + err1, err2, err3 = [], [], [] + relerr1, relerr2 = [], [] + scale = 1 + A = torch.randn(dim1, inner, device="cuda").half() + B = torch.randn(dim4, inner, device="cuda").half() + torch.nn.init.xavier_uniform_(B) + # warmpup + for i in range(k): + C1 = torch.matmul(A, B.t()) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + C1 = torch.matmul(A, B.t()) + torch.cuda.synchronize() + print("16", time.time() - t0) + + C1a, C1b, stats1a, stats1b, coo_tensor = F.double_quant(A) + CB, absmaxB = F.vectorwise_quant(B, quant_type="linear") + A2, SA = F.nvidia_transform(C1a, "col32") + B2, SB = F.nvidia_transform(CB, formatB) + A1, maxA = F.vectorwise_quant(A, dim=1) + + c = 10.0 * inner * scale + row_scale = maxA / c + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + outC32, SC = F.igemmlt( + A2, B2, SA, SB, dtype=torch.int8, row_scale=row_scale + ) + torch.cuda.synchronize() + print("row-wise", time.time() - t0) + + C2a, C2b, stats2a, stats2b, coo_tensor = F.double_quant(B) + B2, SB = F.nvidia_transform(C2a, formatB) + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + outC32, SC = F.igemmlt(A2, B2, SA, SB) + torch.cuda.synchronize() + print("vector-wise", time.time() - t0) + + +n = 2 +dim1 = torch.randint(2, 1024, size=(n,)).tolist() +dim2 = torch.randint(2, 1024, size=(n,)).tolist() +# dim1 = [8*1024] +# dim2 = [4*1024] + +dim3 = [0] +dtype = [torch.int8] +a_order = ["row"] +out_order = ["col32", "col_turing", "col_ampere"] +transpose = [False, True] +dims = [2] +values = list( + product(dim1, dim2, dim3, dims, dtype, a_order, out_order, transpose) +) +names = [ + "dim1_{}_dim2_{}_dim3_{}_dims_{}_dtype_{}_orderA_{}_orderOut_{}_{}".format( + *vals + ) + for vals in values +] + + +@pytest.mark.parametrize( + "dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose", + values, + ids=names, +) +def test_transform(dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose): + for i in range(k): + if dims == 2: + A = torch.randint(10, 99, size=(dim1, dim2), device="cuda").to( + dtype + ) + elif dims == 3: + A = torch.randint( + 10, 99, size=(dim1, dim2, dim3), device="cuda" + ).to(dtype) + + A.view(-1)[-1] = -1 + if transpose: + At = A.t().contiguous() + out1, S1 = F.nvidia_transform(At, to_order=orderOut) + else: + out1, S1 = F.nvidia_transform(A, to_order=orderOut) + out2, S2 = F.transform(A, to_order=orderOut, transpose=transpose) + + assert S1[0][0] == S2[0][0] + assert S1[0][1] == S2[0][1] + # print(out1) + # print(out2) + + torch.testing.assert_close(out1, out2) + + +n = 2 +# dim1 = torch.randint(2,1024, size=(n,)).tolist() +# dim2 = torch.randint(2,1024, size=(n,)).tolist() +dim1 = [1] +dim2 = [33] + +dtype = [torch.int8] +# a_order = ['col_turing', 'col_ampere'] +a_order = ["col_turing"] +out_order = ["row"] +values = list(product(dim1, dim2, dtype, a_order, out_order)) +names = [ + "dim1_{}_dim2_{}_dtype_{}_orderA_{}_orderOut_{}".format(*vals) + for vals in values +] + + +def test_overflow(): + formatB = F.get_special_format_str() + print(formatB) + for i in range(2): + a = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1) + b = torch.arange(5, 15).cuda().to(torch.int8).view(-1, 1) + + Ca, Sa = F.nvidia_transform(a, "col32") + Cb, Sb = F.nvidia_transform(b, formatB) + + c = F.igemmlt(Ca, Cb, Sa, Sb, dtype=torch.int8) + c2 = torch.matmul(a.float(), b.float().t()) + + +n = 2 +dim1 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +dim2 = torch.randint(1, 4 * 1024, size=(n,)).tolist() +# dim1 = [4] +# dim2 = [5] + +values = list(product(dim1, dim2)) +names = ["dim1_{}_dim2_{}".format(*vals) for vals in values] + + +@pytest.mark.parametrize("dim1, dim2", values, ids=names) +def test_coo_double_quant(dim1, dim2): + threshold = 3.00 + for i in range(k): + A = torch.randn(dim1, dim2, device="cuda").half() + + idx = torch.abs(A) >= threshold + CA2, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant( + A, threshold=threshold + ) + + if coo_tensor is not None: + A1 = A * idx + A2 = torch.zeros_like(A) + A2[ + coo_tensor.rowidx.long(), coo_tensor.colidx.long() + ] = coo_tensor.values + torch.testing.assert_close(A1, A2) + + A1 = A * (idx == 0) + A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() + torch.testing.assert_close( + A * (idx == 0), A2, rtol=0.05, atol=1.5e-2 + ) + + +n = 2 +dim1 = torch.randint(1, 1 * 1024, size=(n,)).tolist() +dim2 = torch.randint(1, 1 * 1024, size=(n,)).tolist() +# dim1 = [7] +# dim2 = [11] +transposed_B = [False, True] +values = list(product(dim1, dim2, transposed_B)) +names = ["dim1_{}_dim2_{}_transposed_B_{}".format(*vals) for vals in values] + + +@pytest.mark.parametrize("dim1, dim2, transposed_B", values, ids=names) +def test_spmm_coo(dim1, dim2, transposed_B): + threshold = 1.5 + dim3 = torch.randint(32, 128, size=(1,)).item() + # dim3 = 17 + for i in range(k): + A = torch.randn(dim1, dim2).cuda().half() + if transposed_B: + B = torch.randn(dim3, dim2).cuda().half() + else: + B = torch.randn(dim2, dim3).cuda().half() + + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) + A2 = A * idx + + if transposed_B: + out2 = F.spmm_coo(cooA, B.t()) + out1 = torch.matmul(A2, B.t()) + else: + out2 = F.spmm_coo(cooA, B) + out1 = torch.matmul(A2, B) + + assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=30) + + +def test_spmm_bench(): + batch = 2 + model = 1024 * 1 + hidden = model * 4 + seq = 1024 + dim1 = batch * seq + dim2 = model + dim3 = hidden + threshold = 4 + A = torch.randn(dim1, dim2, device="cuda").half() + B = torch.randn(dim2, dim3, device="cuda").half() + for i in range(10): + C1 = bnb.matmul(A, B.t()) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + C1 = bnb.matmul(A, B.t()) + torch.cuda.synchronize() + t8 = time.time() - t0 + + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + print(nnz / idx.numel()) + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) + + for i in range(10): + out2 = F.spmm_coo(cooA, B) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(k): + out2 = F.spmm_coo(cooA, B) + torch.cuda.synchronize() + tsp = time.time() - t0 + print(tsp, t8) + print(tsp / t8) + + +n = 2 +dim1 = torch.randint(256, 1 * 1024, size=(n,)).tolist() +dim2 = torch.randint(256, 1 * 1024, size=(n,)).tolist() +values = list(product(dim1, dim2)) +names = ["dim1_{}_dim2_{}".format(*vals) for vals in values] + + +@pytest.mark.parametrize("dim1, dim2", values, ids=names) +def test_integrated_sparse_decomp(dim1, dim2): + threshold = 3.0 + formatB = "col_turing" + for i in range(k): + A = torch.randn(dim1, dim2).cuda().half() + w1 = torch.randn(dim1, dim2).cuda().half() + out1 = torch.matmul(A, w1.t()) + + Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) + CTw1, Sw1 = F.transform(Cw1, formatB) + + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + C32A, SA = F.transform(CA, "col32") + + out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) + out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) + + CA, CAt, statsA, statsAt, coo_tensor = F.double_quant( + A, threshold=threshold + ) + C32A, SA = F.transform(CA, "col32") + + out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1) + out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) + + assert coo_tensor is not None + + out4 = F.spmm_coo(coo_tensor, w1.t()) + out5 = out3 + out4 + + err1 = torch.abs(out1 - out2).mean().item() + err2 = torch.abs(out1 - out5).mean().item() + assert err2 < err1 + + +def test_matmuls(): + a = torch.randn(256, 512).half().cuda() + b = torch.randn(256, 512).half().cuda() + c1 = torch.matmul(a, b.t()) + c2 = bnb.matmul(a, b) + c3 = bnb.matmul_cublas(a, b.t()) + + err1 = torch.abs(c1 - c2).mean().item() + err2 = torch.abs(c1 - c3).mean().item() + assert err1 < 0.2 + assert err2 < 0.2 + print(err1, err2) + + +n = 2 +# dim1 = torch.randint(1,1*1024, size=(n,)).tolist() +# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() +dim1 = [1 * 2048] +dim2 = [12288] +# dim1 = [32] +# dim2 = [32] +# dtype = [torch.float16, torch.int8] +dtype = [torch.float16] +out_function = ["zeros", "ones"] +values = list(product(dim1, dim2, dtype, out_function)) +names = [ + "dim1_{}_dim2_{}_dtype_{}_out_func_{}".format(*vals) for vals in values +] + + +@pytest.mark.parametrize("dim1, dim2, dtype, out_func", values, ids=names) +def test_spmm_coo_very_sparse(dim1, dim2, dtype, out_func): + out_func = getattr(torch, out_func) + + threshold = 3.3 + # threshold = 2.8 + # threshold = 0.0 + A = torch.randn(dim1, dim2, device="cuda").half() + if dtype == torch.float16: + B = torch.randn(dim2, dim2 * 4, device="cuda").half() + torch.nn.init.xavier_uniform_(B) + else: + B = torch.randn(dim2, dim2 * 4, device="cuda").half() + torch.nn.init.xavier_uniform_(B) + B, SB = F.vectorwise_quant(B, quant_type="linear") + # B = torch.randint(-127, 127, size=(dim2, dim2*4), device='cuda').to(torch.int8) + + print("") + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) + A2 = A * idx + out1 = torch.matmul(A2.half(), B.half()) + out = out_func(out1.shape, dtype=torch.float16, device=out1.device) + out1 += out.clone() + out2 = F.spmm_coo_very_sparse(cooA, B, out=out) + # print(B) + # print(out1) + # print(out2) + p = 200 / (2048 * 12288 * 4) + n = out1.numel() + count = math.ceil(p * n) + std = out1.std() + out1 /= std + out2 /= std + assert_all_approx_close( + out1, out2.half(), rtol=0.01, atol=3.0e-2, count=count + ) + # assert_all_approx_close(out1, out2.half(), rtol=0.05, atol=0.01, count=count) + + idx_col = torch.randint(0, A2.shape[-1], size=(15,)) + + # torch.testing.assert_close(out1, out2.half(), rtol=0.05, atol=0.001) + + # Bt = torch.randn(dim2*4, dim2, device='cuda').half() + # torch.cuda.synchronize() + # t0 = time.time() + # print(A2.shape, B.shape) + # for i in range(100): + # #out3 = F.spmm_coo(cooA, Bt.t()) + # #out2 = F.spmm_coo(cooA, B) + # #out2 = F.spmm_coo_very_sparse(cooA, B) + # #out1 = torch.matmul(A, Bt.t()) + + # torch.cuda.synchronize() + # print(time.time() - t0) + + +def test_coo2csr(): + threshold = 1 + A = torch.randn(128, 128).half().cuda() + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) + A2 = A * idx + csrA = F.coo2csr(cooA) + counts = csrA.rowptr[1:] - csrA.rowptr[:-1] + assert counts.numel() == A.shape[0] + + torch.testing.assert_close(counts.long(), (A2 != 0).sum(1)) + idx = A2 != 0 + torch.testing.assert_close(A2[idx], csrA.values) + + +def test_coo2csc(): + threshold = 1 + A = torch.randn(128, 128).half().cuda() + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) + A2 = A * idx + cscA = F.coo2csc(cooA) + counts = cscA.colptr[1:] - cscA.colptr[:-1] + assert counts.numel() == A.shape[1] + + torch.testing.assert_close(counts.long(), (A2 != 0).sum(0)) + # torch uses row-major -> use transpose to transfer to col-major + idx = A2.t() != 0 + torch.testing.assert_close(A2.t()[idx], cscA.values) + + +n = 2 +# dim1 = torch.randint(1,1*1024, size=(n,)).tolist() +# dim2 = torch.randint(1,4*1024, size=(n,)).tolist() +dim1 = [1 * 2048] +# dim2 = [12288] +dim2 = [2048] +# dim1 = [2] +# dim2 = [2] +dtype = [torch.int8] +values = list(product(dim1, dim2, dtype)) +names = ["dim1_{}_dim2_{}_dtype_{}".format(*vals) for vals in values] + + +@pytest.mark.parametrize("dim1, dim2, dtype", values, ids=names) +def test_spmm_coo_dequant(dim1, dim2, dtype): + threshold = 6.0 + # threshold = 2.8 + # threshold = 0.0 + A = torch.randn(dim1, dim2, device="cuda").half() + B = torch.empty(dim2, dim2 * 4, device="cuda", dtype=torch.float16) + torch.nn.init.xavier_uniform_(B) + Bt = B.t().contiguous() + + CB, CBt, statsB, statsBt, coo_tensor = F.double_quant(B) + + rowidx = torch.randint(0, A.shape[-1], size=(15,)) + + A[:, rowidx] = 8.0 + + idx = torch.abs(A) >= threshold + nnz = (idx == 1).sum().item() + rows, cols = torch.where(idx) + values = A[idx] + cooA = F.COOSparseTensor( + A.shape[0], A.shape[1], nnz, rows.int(), cols.int(), values + ) + A2 = A * idx + out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) + out1 = torch.matmul(A2, B.half()) + out3 = F.spmm_coo_very_sparse(cooA, CBt.half()) + out3 = out3 * statsBt.half() / 127 + + values, counts = torch.unique(cooA.rowidx, return_counts=True) + offset = counts.cumsum(0).int() + max_count, max_idx = torch.sort(counts, descending=True) + print(torch.median(max_count.float())) + + torch.testing.assert_close(out2, out3, rtol=0.05, atol=0.001) + + p = 200 / (2048 * 12288 * 4) + n = out1.numel() + count = math.ceil(p * n) + assert_all_approx_close(out1, out2, rtol=0.01, atol=3.0e-2, count=count) + + # torch.cuda.synchronize() + # t0 = time.time() + # for i in range(100): + # out2 = F.spmm_coo_very_sparse(cooA, B) + # torch.cuda.synchronize() + # print('fp16', time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out2 = F.spmm_coo(cooA, B) + torch.cuda.synchronize() + print("cusparse fp16", time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out2 = F.spmm_coo_very_sparse(cooA, CBt) + torch.cuda.synchronize() + print("int8", time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) + torch.cuda.synchronize() + print("int8+dequant", time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out2 = torch.matmul(A, B) + torch.cuda.synchronize() + print("matmul", time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out1 = bnb.matmul(A, Bt) + out2 = F.spmm_coo_very_sparse(cooA, CBt, dequant_stats=statsBt) + out = out1 + out2 + torch.cuda.synchronize() + print("sparse+ matmul", time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out1 = bnb.matmul(A, Bt) + torch.matmul(A[:, rowidx], Bt.t()[rowidx], out=out1) + torch.cuda.synchronize() + print("partial matmul", time.time() - t0) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + out1 = bnb.matmul(A, Bt) + torch.cuda.synchronize() + print("partial matmul", time.time() - t0) + + +batch_size = 2 +seqdim = 2048 +values = [] +values.append((batch_size, seqdim, 768, 4 * 768)) +#values.append((batch_size, seqdim, 1024, 4*1024)) +#values.append((batch_size, seqdim, 1536, 4*1536)) +#values.append((batch_size, seqdim, 2048, 4*2048)) +#values.append((batch_size, seqdim, 2560, 4*2560)) +#values.append((batch_size, seqdim, 4096, 4*4096)) +#values.append((batch_size, seqdim, 5140, 4*5140)) +#values.append((batch_size, seqdim, 12288, 4*12288)) +names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values] +@pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names) +def test_bench_matmul(batch, seq, model, hidden): + iters = 1 + formatB = F.get_special_format_str() + + A = torch.randn(batch, seq, model, device="cuda").half() + B = torch.empty(hidden, model, dtype=torch.float16, device="cuda") + torch.nn.init.xavier_uniform_(B) + + B_fp4, state = F.quantize_fp4(B) + B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True) + + B_nf4, state_nf4= F.quantize_nf4(B) + + linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() + linear8bit.eval() + + outliers = torch.randint(0, model, size=(5,)).cuda() + A[:, :, outliers] = 8.0 + + linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()) + linearMixedBit.eval() + + linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half() + linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half() + + # warmup + for i in range(iters): + torch.matmul(A, B.t()) + torch.cuda.synchronize() + print("") + + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + torch.matmul(A, B.t()) + torch.cuda.synchronize() + print( f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + bnb.matmul_4bit(A, B_fp4.t(), quant_state=state) + torch.cuda.synchronize() + print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c) + torch.cuda.synchronize() + print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4) + torch.cuda.synchronize() + print( f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" ) + + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # bnb.matmul(A, B) + #torch.cuda.synchronize() + #print(f"CB -> CxB conversion (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # bnb.matmul(A, B, threshold=6.0) + #torch.cuda.synchronize() + #print(f"CB -> CxB conversion + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0) + #C32A, SA = F.transform(CA, "col32") + #CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B) + #CxB, SB = F.transform(CB, to_order=formatB) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + #torch.cuda.synchronize() + #print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #BA, statsB = F.vectorwise_quant(B, dim=1) + #CxB, SB = F.nvidia_transform(CB, to_order=formatB) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # A2 = A.view(-1, A.shape[-1]).contiguous() + # CA, statsA = F.vectorwise_quant(A2, dim=1) + # C32A, SA = F.nvidia_transform(CA, "col32") + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) + # F.vectorwise_mm_dequant(Cout, statsA, statsB.t()) + #torch.cuda.synchronize() + #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear") + #CxB, SB = F.nvidia_transform(CB, to_order=formatB) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # A2 = A.view(-1, A.shape[-1]).contiguous() + # CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear") + # C32A, SA = F.nvidia_transform(CA, "col32") + # out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB) + # Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32) + # out = Cout * statsB * statsA * (1.0 / (127 * 127)) + #torch.cuda.synchronize() + #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #linear8bit(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # linear8bit(A) + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #linearMixedBit(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # linearMixedBit(A) + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #linear8bit_train(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # linear8bit_train(A) + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + + #linear8bit_train_thresh(A) + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # linear8bit_train(A) + #torch.cuda.synchronize() + #print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s") + +def test_zeropoint(): + def quant_zp(x): + dtype = x.dtype + x = x.float() + dyna = x.max() - x.min() + if dyna == 0: + dyna = 1 + qx = 254.0 / dyna + minx = x.min() + # zpx = torch.round(minx* qx) + # zpx = 127 - torch.round(x.max()* qx) + zpx = torch.round(x.min() * qx) - 127 + x = (qx * x) + zpx + return x, qx, zpx + + batch = 2 + seq = 512 + model = 1024 + hidden = 4 * model + A = torch.randn(batch * seq, model, device="cuda").half() * 0.1 + B = torch.randn(model, hidden, device="cuda").half() * 0.1 + + C0 = torch.matmul(A, B) + + # A, SA = F.vectorwise_quant(A, quant_type='linear') + # B, SB = F.vectorwise_quant(B, quant_type='linear') + A = A.float() + B = B.float() + + C1 = torch.matmul(A, B) + C3 = bnb.matmul(A.half(), B.t().contiguous().half()) + + zp = 1 + # C2 = torch.matmul(A-zp, B) + # C2 += B.sum(0).view(1, -1)*zp + C2 = torch.matmul(A, B - zp) + C2 -= A.sum(1).view(-1, 1) * zp + + ca, cqa, cza = quant_zp(A) + print(ca.min(), ca.max()) + print((ca - cza).min(), (ca - cza).max()) + + zp = 1 + scale = 2.0 + C5 = torch.matmul((A * scale) - zp, B) + C5 += B.sum(0) * zp + C5 /= scale + + CA, qa, zpa = quant_zp(A) + C4 = torch.matmul(CA, B) + C4 -= B.sum(0) * zpa + C4 /= qa + + zpb = 1 + zpa = 1 + qa = 2 + qb = 2 + C6 = torch.matmul((A * qa) + zpa, (B * qb) + zpb) + C6 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb) + C6 -= zpa * zpb * A.shape[1] + C6 /= qa * qb + + CA, qa, zpa = quant_zp(A) + CB, qb, zpb = quant_zp(B) + C7 = torch.matmul(CA, CB) + C7 -= (qb * B.sum(0).view(1, -1) * zpa) + (qa * A.sum(1).view(-1, 1) * zpb) + C7 -= zpa * zpb * A.shape[1] + C7 /= qa * qb + + print("") + # print(C0.flatten()[:10]) + print(C1.flatten()[:10]) + print(C2.flatten()[:10]) + print(C3.flatten()[:10]) + print(C5.flatten()[:10]) + print(C6.flatten()[:10]) + print(C7.flatten()[:10]) + err1 = torch.abs(C1 - C2).mean().item() + err2 = torch.abs(C1 - C3).mean().item() + err3 = torch.abs(C1 - C4).mean().item() + err4 = torch.abs(C1 - C5).mean().item() + err5 = torch.abs(C1 - C6).mean().item() + err6 = torch.abs(C1 - C7).mean().item() + print(err1, err2, err3, err4, err5, err6) + + +def test_extract_outliers(): + for i in range(k): + shapeA = (4096, 4096 * 4) + idx = torch.unique(torch.randint(0, shapeA[1], size=(10,)).int()).cuda() + # idx = torch.Tensor([0]).int().cuda() + A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8) + outliers1 = A[:, idx.long()] + + CA, SA = F.transform(A, "col_turing") + + outliers2 = F.extract_outliers(CA, SA, idx) + + assert outliers2.shape[0] == shapeA[0] + assert outliers2.shape[1] == idx.numel() + + torch.testing.assert_close(outliers1, outliers2) + + CA, SA = F.transform(A, "col_ampere") + + outliers2 = F.extract_outliers(CA, SA, idx) + + assert outliers2.shape[0] == shapeA[0] + assert outliers2.shape[1] == idx.numel() + + torch.testing.assert_close(outliers1, outliers2) + + + +def test_blockwise_cpu_large(): + diffs = [] + reldiffs = [] + batch = 128 + seq = 128 + for hidden in [128]:#, 14336]: + for blocksize in [4096, 16384]: + for i in range(2): + A1 = torch.randn(batch, seq, hidden, device='cpu') + t0 = time.time() + C, S = F.quantize_blockwise(A1, blocksize=blocksize) + A2 = F.dequantize_blockwise(C, S, blocksize=blocksize) + print(time.time() - t0) + diff = torch.abs(A1 - A2) + reldiff = diff / torch.abs(A1 + 1e-8) + diffs.append(diff.mean().item()) + reldiffs.append(reldiff.mean().item()) + assert diffs[-1] < 0.011 + # print(sum(diffs)/len(diffs)) + # print(sum(reldiffs)/len(reldiffs)) + + + +def test_fp8_quant(): + for e_bits in range(1, 7): + p_bits = 7-e_bits + code = F.create_fp8_map(True, e_bits, p_bits).cuda() + + abserr = [] + relerr = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + C, SC = F.quantize_blockwise(A1, code=code) + A2 = F.dequantize_blockwise(C, SC) + diff = torch.abs(A1 - A2) + reldiff = diff/torch.abs(A1+1e-8) + abserr.append(diff.mean().item()) + relerr.append(reldiff.mean().item()) + #assert diff < 0.0075 + #print(sum(abserr)/len(abserr)) + #print(sum(relerr)/len(relerr)) + + abserr = [] + relerr = [] + for i in range(100): + A1 = torch.rand(1024, 1024, device="cuda") + C, SC = F.quantize_blockwise(A1, code=code) + A2 = F.dequantize_blockwise(C, SC) + diff = torch.abs(A1 - A2) + reldiff = diff/torch.abs(A1+1e-8) + abserr.append(diff.mean().item()) + relerr.append(reldiff.mean().item()) + #assert diff < 0.0075 + #print(sum(abserr)/len(abserr)) + #print(sum(relerr)/len(relerr)) + + abserr = [] + relerr = [] + for i in range(100): + A1 = torch.randn(1024, 1024, device="cuda") + C, SC = F.quantize_blockwise(A1) + A2 = F.dequantize_blockwise(C, SC) + diff = torch.abs(A1 - A2) + reldiff = diff/torch.abs(A1+1e-8) + abserr.append(diff.mean().item()) + relerr.append(reldiff.mean().item()) + #assert diff < 0.0075 + #print(3, sum(abserr)/len(abserr)) + #print(3, sum(relerr)/len(relerr)) + + +def test_few_bit_quant(): + + #print('') + for bits in range(2, 9): + #print('='*30, bits, '='*30) + for method in ['linear', 'fp8', 'dynamic', 'quantile']: + abserrs = [] + relerrs = [] + code = None + if method == 'linear': + code = F.create_linear_map(True, total_bits=bits).cuda() + elif method == 'fp8': + ebits = math.ceil(bits/2) + pbits = bits-ebits-1 + code = F.create_fp8_map(True, ebits, pbits, bits).cuda() + elif method == 'dynamic': + code = F.create_dynamic_map(True, bits-0, bits).cuda() + elif method == 'quantile': + values = torch.randn(2048, 2048, device='cuda') + code = F.create_quantile_map(values, bits).cuda() + # for some data types we have no zero + # for some data types we have one zero + # for some data types we have two zeros + assert torch.unique(code).numel() in [2**bits, 2**bits-1], f'bits: {bits}, method: {method}' + #print(method, (code==0).sum()) + assert code.numel() == 256 + for i in range(10): + + values = torch.randn(1, 32, device='cuda') + values /= values.abs().max() + #values[values.abs() < 1e-6] += 1e-5 + + q1 = [] + v1 = [] + for v in values[0]: + idx = torch.abs(v-code).argmin() + q1.append(idx.item()) + v1.append(code[idx].item()) + + q1 = torch.Tensor(q1).cuda() + v1 = torch.Tensor(v1).cuda() + + q2, S2 = F.quantize_blockwise(values, code=code) + v2 = F.dequantize_blockwise(q2, S2) + + idx = torch.isclose(q1.int(), q2.int()) + err2 = torch.abs(v2-values) + abserrs.append(err2.mean().item()) + relerrs.append((err2/(1e-10+values).abs()).mean().item()) + if idx.sum(): + # some weird cases + err1 = torch.abs(v1-values).mean() + #assert err2.mean() <= err1 + + else: + torch.testing.assert_close(q1, q2) + #print(method, 'abserr:', sum(abserrs)/len(abserrs), 'relerr:', sum(relerrs)/len(relerrs)) + #assert False + + +def test_kbit_quantile_estimation(): + for i in range(100): + data = torch.randn(1024, 1024, device='cuda') + for bits in range(2, 9): + p = np.linspace(1.3e-4, 1-1.3e-4, 2**bits) + val1 = torch.Tensor(norm.ppf(p)).cuda() + val2 = F.estimate_quantiles(data, offset=0, num_quantiles=2**bits) + err = torch.abs(val1-val2).mean() + assert err < 0.038 + + for i in range(100): + data = torch.randn(1024, 1024, device='cuda') + for bits in range(2, 4): + total_values = 2**bits-1 + p = np.linspace(0, 1, 2*total_values+1) + idx = np.arange(1, 2*total_values+1, 2) + p = p[idx] + offset = 1/(2*total_values) + p = np.linspace(offset, 1-offset, total_values) + val1 = torch.Tensor(norm.ppf(p)).cuda() + val2 = F.estimate_quantiles(data, num_quantiles=2**bits-1) + err = torch.abs(val1-val2).mean() + assert err < 0.035 + + +def test_bench_dequantization(): + a = torch.rand(1024, 1024, device='cuda').half() + code =F.create_fp8_map(True, 3, 0, 4).cuda() + qa, SA = F.quantize_blockwise(a, code=code) + print(qa.max()) + + max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000 + #print(max_theoretical_mu) + + torch.cuda.synchronize() + t0 = time.time() + for i in range(100): + qa, SA = F.quantize_blockwise(a) + torch.cuda.synchronize() + #print((time.time()-t0)/1e6) + + + +@pytest.mark.skip('bitsandbytes 4-bit beta feature') +def test_fp4_quant(): + vals = list(product([0, 1], repeat=4)) + + code = {} + for bits in vals: + result = 0 + bias = 3 + sign, e1, e2, p1 = bits + idx = sign*8 + e1*4 + e2*2 + p1*1 + sign = -1.0 if sign else 1.0 + exp = e1*2 + e2*1 + if exp == 0: + # sub-normal + if p1 == 0: result = 0 + else: result = sign*0.0625 + else: + # normal + exp = 2**(-exp + bias + 1) + frac = 1.5 if p1 else 1.0 + result = sign*exp*frac + code[idx] = result + + A1 = torch.randn(1024, 1024, device='cuda').half() + qa, SA = F.quantize_fp4(A1, blocksize=64) + A2 = F.dequantize_fp4(qa, SA) + + err = (A1 - A2).abs().float() + relerr = (err/A1.abs().float()).mean() + idx = err > 1.0 + err = err.mean() + + + assert err.item() < 0.1 + assert relerr.item() < 0.28 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.skip('bitsandbytes 4-bit beta feature') +@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) +def test_4bit_compressed_stats(quant_type): + for blocksize in [128, 64]: + errs1 = [] + errs2 = [] + for i in range(10): + A1 = torch.randn(1024, 1024, device='cuda').half() + q2, SA2 = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type) + q3, SA3= F.quantize_4bit(A1, blocksize=blocksize, compress_statistics=True, quant_type=quant_type) + A2 = F.dequantize_4bit(q2, SA2, quant_type=quant_type) + A3 = F.dequantize_4bit(q3, SA3, quant_type=quant_type) + + + err = (A1 - A2).abs().float() + relerr = (err/(A1.abs().float()+1e-15)).mean() + err = err.mean() + + errs1.append(err.item()) + + + assert err.item() < 0.11 + assert relerr.item() < 0.28 + + err = (A1 - A3).abs().float() + relerr = (err/(A1.abs().float()+1e-15)).mean() + err = err.mean() + + errs2.append(err.item()) + + assert err.item() < 0.11 + assert relerr.item() < 0.28 + + #print(sum(errs1)/len(errs1), blocksize, quant_type) + #print(sum(errs2)/len(errs2), blocksize, quant_type) + + + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize("quant_type", ['fp4', 'nf4']) +def test_bench_4bit_dequant(quant_type): + blocksize = 256 + a = torch.rand(1024*12*4, 1024*12, device='cuda').half() + qa, SA = F.quantize_4bit(a, blocksize=blocksize, quant_type=quant_type) + + input_size = a.numel()/2 + output_size = a.numel()*2 + num_bytes = input_size+output_size + GB = num_bytes/1e9 + max_theoretical_s = GB/768 + #print(max_theoretical_s*1e6) + b = torch.randn(128, 1024*12, device='cuda').half() + + iters = 5 + torch.cuda.synchronize() + t0 = time.time() + for i in range(iters): + F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type) + #b.copy_(a) + torch.cuda.synchronize() + #print((time.time()-t0)/iters*1e6) + + #torch.cuda.synchronize() + #t0 = time.time() + #for i in range(iters): + # torch.matmul(b, a.t()) + #torch.cuda.synchronize() + #print((time.time()-t0)/iters*1e6) + + + +def test_normal_map_tree(): + code = F.create_normal_map() + values =code[:8].tolist() + code[-8:].tolist() + num_pivots = 1 + print(values) + while num_pivots <16: + idx = list(range(16//num_pivots//2, 16, 16//num_pivots)) + print(idx) + num_pivots *= 2 + pivots = [] + for i in idx: + pivots.append((values[i-1]+values[i])/2) + print(pivots) + + +#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) +@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) +def test_cutlass3_gemm(dtype): + debug = True + #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: + #for dim in [4096, 5120, 6656, 8192]: + for dim in [4096]: + #for dim in [128+1]: + errs = [] + relerrs = [] + max_err = 0 + max_relerr = 0 + for i in range(100): + A = torch.randn(1, dim, dtype=dtype, device='cuda') + B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim) + #B = torch.randn(1, dim, dtype=dtype, device='cuda')/math.sqrt(dim) + + #print('') + #print(A) + #print(B.t()) + #A[:, :-1] = 0 + #B[:, :-1] = 0 + + + C1 = torch.matmul(A, B.t()) + C2 = F.cutlass3_gemm(A, B.t()) + + # tensor cores are non-deterministic + # so we need to analyze errors around the mean + # to test our implementation + err = torch.abs(C1-C2) + mag = torch.abs(C1)+1e-8 + relerr = err/mag + max_err = max(err.max(), max_err) + max_relerr = max(relerr.max(), max_relerr) + err = err.mean().item() + relerr = relerr.mean().item() + + errs.append(err) + relerrs.append(relerr) + + #if not debug and err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: + # print('') + # print(i, err, relerr) + # print(A.flatten()[-6:]) + # print(B.flatten()[-6:]) + # out = A.flatten()[-6:]*B.flatten()[-6:] + # print(out) + # print(out[:-1].sum()) + # print('='*80) + # print(C1.flatten()[-6:]) + # print(C2.flatten()[-6:]) + # #assert False, 'ERROR' + + c = int(C1.numel()*0.0014*(dim/256))+1 + + c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=not debug) + #print(c/math.sqrt(dim)) + print('') + print(dim, sum(errs)/len(errs)/math.sqrt(dim)) + print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) + print(dim, (max_err.item(), max_relerr.item())) + +#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16']) +@pytest.mark.parametrize("dtype", [torch.float16], ids=['fp16']) +def test_gemm_4bit(dtype): + #for dim in [32, 64, 128, 256, 512, 1024, 2048, 4096]: + #for dim in [4096, 5120, 6656, 8192]: + #for dim in [32]: + for dim in [4096]: + errs = [] + relerrs = [] + max_err = 0 + max_relerr = 0 + for i in range(1): + #A = torch.rand(2, 4092, dtype=dtype, device='cuda') + #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda') + #A = torch.rand(1, 4096, dtype=dtype, device='cuda') + #B = torch.rand(4*4096, 4096, dtype=dtype, device='cuda') + A = torch.randn(1, dim+0, dtype=dtype, device='cuda') + B = torch.randn(4*dim, dim+0, dtype=dtype, device='cuda')/math.sqrt(dim) + + #print('') + #print(A) + #print(B.t()) + #A[:, :-1] = 0 + #B[:, :-1] = 0 + + qB, state = F.quantize_nf4(B) + F.dequantize_nf4(qB, state) + + C3 = torch.matmul(A, B.t()) + C2 = F.cutlass3_gemm(A, qB.t(), state=state) + C1 = bnb.matmul_4bit(A, qB.t(), state) + C2 = F.cutlass3_gemm(A, qB.t(), state=state) + + print(C1.shape, C2.shape) + + # tensor cores are non-deterministic + # so we need to analyze errors around the mean + # to test our implementation + err = torch.abs(C1-C2) + mag = torch.abs(C1)+1e-8 + relerr = err/mag + max_err = max(err.max(), max_err) + max_relerr = max(relerr.max(), max_relerr) + err = err.mean().item() + relerr = relerr.mean().item() + + errs.append(err) + relerrs.append(relerr) + + if err/torch.abs(C1).mean() > 5e-5 or err > 3.2e-5: + print('') + print(i, err, relerr) + print(A.flatten()[-6:]) + print(B.flatten()[-6:]) + out = A.flatten()[-6:]*B.flatten()[-6:] + print(out) + print(out[:-1].sum()) + print('='*80) + print(C1.flatten()[-6:]) + print(C2.flatten()[-6:]) + #assert False, 'ERROR' + + c = int(C1.numel()*0.0014*(dim/256))+1 + + c = assert_all_approx_close(C1, C2, 1e-5, 0.01, count=c, throw=False) + #print(c/math.sqrt(dim)) + print('') + print(dim, sum(errs)/len(errs)/math.sqrt(dim)) + print(dim, sum(relerrs)/len(relerrs)/math.sqrt(dim)) + print(dim, (max_err.item(), max_relerr.item())) + +def test_managed(): + n = 32*10 + A = F.get_paged(n, n, dtype=torch.float32) + B = F.get_paged(n, n, dtype=torch.uint8) + B2 = F.get_paged(n, n, dtype=torch.float32) + assert A.is_paged + assert B.is_paged + assert A.page_deviceid==0 + assert B.page_deviceid==0 + F.fill(A, 17.0) + F.fill(B, 17) + F.fill(B2, 2) + assert (A==17).sum().item() == n*n + assert (B==17).sum().item() == n*n + C = A*B.float() + assert (C==289).sum().item() == n*n + F._mul(A, B2) + F._mul(A, B2) + F._mul(A, B2) + assert (A==17*(2**3)).sum().item() == n*n + # F.prefetch_tensor(A) + # F.prefetch_tensor(B) + + + # F.fill(B2, 17.0) + # F._mul(A, B2) + + # F.prefetch_tensor(A, to_cpu=True) + # F.prefetch_tensor(B, to_cpu=True) + # F.prefetch_tensor(B2, to_cpu=True) + # torch.cuda.synchronize() + + # assert (A==17).sum().item() == n*n + + # torch.testing.assert_close(A, torch.ones(A.shape)*289) diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py new file mode 100644 index 000000000..37f7af9cb --- /dev/null +++ b/tests/test_linear8bitlt.py @@ -0,0 +1,143 @@ +import os +from contextlib import nullcontext +from itertools import product +from tempfile import TemporaryDirectory + +import pytest +import torch + +import bitsandbytes as bnb +from bitsandbytes import functional as F +from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout +from bitsandbytes.nn.modules import Linear8bitLt + + +# contributed by Alex Borzunov, see: +# https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 5), + reason="this test requires a turing-generation or newer GPU, see bitsandbytes docs", +) +def test_layout_exact_match(): + x = (torch.randn(14336 * 3, 14336) * 10).to(torch.int8).cuda() + for tile_size, order in ((8, 32), "col_turing"), ((32, 32), "col_ampere"): + transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device) + tile_indices = get_inverse_transform_indices(transform, tile_size) + cxb = transform(x) + + torch.cuda.synchronize() + restored_x = undo_layout(cxb, tile_indices) + torch.cuda.synchronize() + assert restored_x.is_contiguous() + assert torch.all(torch.eq(restored_x, x)) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +def test_linear_no_igemmlt(): + linear = torch.nn.Linear(1024, 3072) + x = torch.randn(3, 1024, dtype=torch.half) + linear_custom = Linear8bitLt( + linear.in_features, + linear.out_features, + linear.bias is not None, + has_fp16_weights=False, + threshold=6.0, + ) + linear_custom.state.force_no_igemmlt = True + + linear_custom.weight = bnb.nn.Int8Params( + linear.weight.data.clone(), requires_grad=False, has_fp16_weights=False + ).to(linear.weight.dtype) + linear_custom.bias = linear.bias + linear_custom = linear_custom.cuda() + linear = linear.half().cuda() + + x_ref = x.clone().cuda().requires_grad_(True) + x_ours = x.clone().cuda().requires_grad_(True) + fx_ref = linear(x_ref).float() + grad_proj = torch.randn_like(fx_ref) + (fx_ref * grad_proj).mean().backward() + + fx_ours = linear_custom(x_ours).float() + (fx_ours * grad_proj).mean().backward() + assert torch.allclose(fx_ref, fx_ours, atol=0.02) + assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01) + assert not linear_custom.state.has_fp16_weights + assert linear_custom.state.CB is not None + assert linear_custom.state.CxB is None + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize("has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt", + list(product([False, True], [False, True], [False, True], [False, True]))) +def test_linear_serialization(has_fp16_weights, serialize_before_forward, deserialize_before_cuda, force_no_igemmlt): + linear = torch.nn.Linear(32, 96) + x = torch.randn(3, 32, dtype=torch.half) + + linear_custom = Linear8bitLt( + linear.in_features, + linear.out_features, + linear.bias is not None, + has_fp16_weights=has_fp16_weights, + threshold=6.0, + ) + if force_no_igemmlt: + linear_custom.state.force_no_igemmlt = True + + linear_custom.weight = bnb.nn.Int8Params( + linear.weight.data.clone(), requires_grad=has_fp16_weights, has_fp16_weights=has_fp16_weights + ) + linear_custom.bias = linear.bias + linear_custom = linear_custom.cuda() + + if serialize_before_forward: + state_dict_8bit = linear_custom.state_dict() + + x_first = x.clone().cuda().requires_grad_(True) + fx_first = linear_custom(x_first).float() + grad_proj = torch.randn_like(fx_first) + (fx_first * grad_proj).mean().backward() + + if not serialize_before_forward: + state_dict_8bit = linear_custom.state_dict() + + with TemporaryDirectory() as tmpdir: + state_path_8bit = os.path.join(tmpdir, "state_8bit.pth") + state_path = os.path.join(tmpdir, "state.pth") + + torch.save(linear.state_dict(), state_path) + torch.save(state_dict_8bit, state_path_8bit) + + if not has_fp16_weights: + assert os.path.getsize(state_path_8bit) < 0.5 * os.path.getsize(state_path) + + new_state_dict = torch.load(state_path_8bit) + + new_linear_custom = Linear8bitLt( + linear.in_features, + linear.out_features, + linear.bias is not None, + has_fp16_weights=has_fp16_weights, + threshold=6.0, + ) + if force_no_igemmlt: + new_linear_custom.state.force_no_igemmlt = True + + if deserialize_before_cuda: + with nullcontext() if has_fp16_weights else pytest.raises(RuntimeError): + new_linear_custom.load_state_dict(new_state_dict, strict=True) + + new_linear_custom = new_linear_custom.cuda() + + if not deserialize_before_cuda: + new_linear_custom.load_state_dict(new_state_dict, strict=True) + + x_second = x.clone().cuda().requires_grad_(True) + fx_second = new_linear_custom(x_second).float() + (fx_second * grad_proj).mean().backward() + + # if 8-bit weights were loaded before .cuda, state is incorrect anyway and RuntimeError was raised + if has_fp16_weights or not deserialize_before_cuda: + assert torch.allclose(fx_first, fx_second, atol=1e-5) + assert torch.allclose(x_first.grad, x_second.grad, atol=1e-5) diff --git a/tests/test_modules.py b/tests/test_modules.py new file mode 100644 index 000000000..714d07dab --- /dev/null +++ b/tests/test_modules.py @@ -0,0 +1,618 @@ +from itertools import product + +import pytest +import torch +from torch import nn + +import bitsandbytes as bnb + + +class MockArgs: + def __init__(self, initial_data): + for key in initial_data: + setattr(self, key, initial_data[key]) + + +class MLP8bit(torch.nn.Module): + def __init__(self, dim1, dim2, has_fp16_weights=True, memory_efficient_backward=False, threshold=0.0): + super().__init__() + self.fc1 = bnb.nn.Linear8bitLt( + dim1, dim2, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, + threshold=threshold + ) + self.fc2 = bnb.nn.Linear8bitLt( + dim2, dim1, has_fp16_weights=has_fp16_weights, memory_efficient_backward=memory_efficient_backward, + threshold=threshold + ) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + +def get_args(): + args = MockArgs([]) + args.quant_type = "vector" + args.use_8bit_training = "full" + args.clip_freq = 9999 + return args + + +def assert_all_approx_close(a, b, atol=1e-8, rtol=1e-5, count=10): + idx = torch.isclose(a, b, rtol, atol) + sumval = (idx == 0).sum().item() + if sumval > count: + print(f"Too many values not close: assert {sumval} < {count}") + torch.testing.assert_close(a, b, rtol, atol) + + +class LinearFunction(torch.autograd.Function): + @staticmethod + def get_8bit_linear_trimmed(x, stochastic=False, trim_value=3.0): + round_func = ( + LinearFunction.round_stoachastic if stochastic else torch.round + ) + norm = math.sqrt(math.pi) / math.sqrt(2.0) + # std = torch.abs(x).mean()*norm + std = torch.std(x) + max1 = std * trim_value + x = x / max1 * 127 + x = round_func(x) + x[x > 127] = 127 + x[x < -127] = -127 + x = x / 127 * max1 + + return x + + def quant(x, quant_type, dim=1): + if quant_type == "linear": + max1 = torch.abs(x).max().float() + xq = torch.round(x / max1 * 127).to(torch.int8) + return xq, max1 + elif quant_type == "vector": + max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) + xq = torch.round(x / max1 * 127).to(torch.int8) + return xq, max1 + elif quant_type == "min-max": + maxA = torch.amax(x, dim=dim, keepdim=True).float() + minA = torch.amin(x, dim=dim, keepdim=True).float() + scale = (maxA - minA) / 2.0 + xq = torch.round(127 * (x - minA - scale) / scale).to(torch.int8) + return xq, (minA.float(), scale.float()) + else: + return None + + def dequant(xq, S1, S2, dtype, quant_type): + if quant_type == "linear": + norm = S1 * S2 / (127 * 127) + # double cast needed to prevent overflows + return (xq.float() * norm).to(dtype) + elif quant_type == "vector": + x = xq.float() + if len(xq.shape) == 2 and len(S1.shape) == 3: + S1 = S1.squeeze(0) + if len(xq.shape) == 2 and len(S2.shape) == 3: + S2 = S2.squeeze(0) + # print(x.shape, S1.shape, S2.shape) + if len(S1.shape) == 2: + x *= S1.t() / 127 + else: + x *= S1 / 127 + x *= S2 / 127 + return x.to(dtype) + else: + return None + + def dequant_min_max(xq, A, B, SA, SB, dtype): + offset = B.float().t().sum(0) * (SA[0] + SA[1]) + x = xq.float() + if len(xq.shape) == 2 and len(SB.shape) == 3: + SB = SB.squeeze(0) + if len(xq.shape) == 2 and len(SA.shape) == 3: + SA = SA.squeeze(0) + if len(SB.shape) == 2: + x *= SB.t() / 127 + else: + x *= SB / 127 + x *= SA[1] / 127 + x += offset + return x.to(dtype) + + def get_8bit_linear(x, stochastic=False): + round_func = ( + LinearFunction.round_stoachastic if stochastic else torch.round + ) + max1 = torch.abs(x).max() + x = x / max1 * 127 + x = round_func(x) / 127 * max1 + # x = torch.round(x)/128*max1 + return x + + @staticmethod + def get_8bit_vector_wise(x, dim, stochastic=False): + round_func = ( + LinearFunction.round_stoachastic if stochastic else torch.round + ) + max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True) + max1[max1 == 0] = 1.0 + x = (x * 127) / max1 + x = round_func(x) / 127 * max1 + return x + + @staticmethod + def round_stoachastic(x): + sign = torch.sign(x) + absx = torch.abs(x) + decimal = absx - torch.floor(absx) + rdm = torch.rand_like(decimal) + return sign * (torch.floor(absx) + (rdm < decimal).to(x.dtype)) + + @staticmethod + def fake_8bit_storage(w, exponent_bits): + code = bnb.functional.create_dynamic_map(n=exponent_bits).to(w.device) + absmax, C = bnb.functional.quantize_blockwise(w.data, code=code) + out = bnb.functional.dequantize_blockwise(absmax, C, code) + out = out.half() + w.copy_(out) + return out + + @staticmethod + def fake_8bit_storage_quantile(w, args): + code = bnb.functional.estimate_quantiles(w.data, offset=args.offset) + # C = bnb.functional.quantize_no_absmax(code, w) + # out = bnb.functional.dequantize_no_absmax(code, C, out=w.data) + # print(out) + # out = out.half() + code /= torch.max(torch.abs(code)) + absmax, C = bnb.functional.quantize_blockwise(w.data, code=code) + out = bnb.functional.dequantize_blockwise(absmax, C, code) + out = out.half() + w.copy_(out) + return out + + @staticmethod + def fake_8bit_storage_stoachstic(w): + rand = torch.rand(1024, device=w.device) + absmax, C = bnb.functional.quantize_blockwise(w.data, rand=rand) + out = bnb.functional.dequantize_blockwise(absmax, C) + out = out.half() + w.copy_(out) + return out + + @staticmethod + def fake_8bit_storage_with_max(w, topk=8): + blocked_w = einops.rearrange(w.flatten(), "(h b) -> h b", b=256) + max_val, idx = torch.sort(torch.abs(blocked_w), dim=1, descending=True) + idx = idx[:, :topk] + max_val = max_val[:, :topk] + + mask = torch.zeros_like(blocked_w) + mask.scatter_(dim=1, index=idx, src=torch.ones_like(max_val)) + mask = mask.bool() + + # 1. zero out max values + # 2. quantize + dequantize + # 3. write back max values + # 4. copy matrix back to weight + + values = blocked_w[mask] + blocked_w[mask] = 0 + + code = bnb.functional.create_dynamic_map() + code = code.to(w.device) + absmax, C = bnb.functional.quantize_blockwise(blocked_w.data) + bnb.functional.dequantize_blockwise(absmax, C, out=blocked_w) + + blocked_w[mask] = values + + unblocked_w = blocked_w.flatten().view(w.shape) + + w.copy_(unblocked_w) + return unblocked_w + + @staticmethod + def forward(ctx, x, weight, bias=None, args=None): + if args.use_8bit_training != "off": + weight8, S1 = LinearFunction.quant(weight, args.quant_type, dim=1) + x8, S2 = LinearFunction.quant(x, args.quant_type, dim=2) + outputq = bnb.functional.igemm(x8, weight8.t()) + output = LinearFunction.dequant( + outputq, S1, S2, x.dtype, args.quant_type + ) + # if torch.rand(1) < 0.01: + # output32 = torch.matmul(x, weight.t()) + # err = torch.abs(output-output32).float() + # relerr = err/(torch.abs(output32).float()+1e-8) + # print(f'{err.mean().item():.4f}, {relerr.mean().item():.4f}', args.quant_type, 'forward', proxy) + else: + # output = torch.matmul(x, weight.t()) + output = torch.einsum("bsi,oi->bso", x, weight) + + ctx.save_for_backward(x, weight, bias) + ctx.args = args + + if bias is not None: + output += bias.unsqueeze(0).expand_as(output) + return output + + @staticmethod + def backward(ctx, grad_output): + x, weight, bias = ctx.saved_tensors + args = ctx.args + stochastic = False + grad_input = grad_weight = grad_bias = None + if bias is not None and ctx.needs_input_grad[2]: + grad_bias = grad_output.sum(0) + + # weight and x are already 8bit + # -> transform grad_output to 8-bit + if args.use_8bit_training == "forward+wgrad": + grad_output8, S1 = LinearFunction.quant( + grad_output, args.quant_type, dim=[0, 1] + ) + x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) + grad_weight8 = bnb.functional.igemm(grad_output8, x8) + grad_weight = LinearFunction.dequant( + grad_weight8, S1, S2, grad_output.dtype, args.quant_type + ) + + # grad_weight32 = torch.einsum('bso,bsi->oi', grad_output, x) + + grad_input = grad_output.matmul(weight) + elif args.use_8bit_training == "full": + grad_output8, S1 = LinearFunction.quant( + grad_output, args.quant_type, dim=[0, 1] + ) + x8, S2 = LinearFunction.quant(x, args.quant_type, dim=[0, 1]) + grad_weight8 = torch.zeros_like(weight, dtype=torch.int32) + bnb.functional.igemm(grad_output8, x8, out=grad_weight8) + grad_weight = LinearFunction.dequant( + grad_weight8, S1, S2, grad_output.dtype, args.quant_type + ) + + grad_output8, S1 = LinearFunction.quant( + grad_output, args.quant_type, dim=2 + ) + weight8, S3 = LinearFunction.quant(weight, args.quant_type, dim=0) + grad_input8 = bnb.functional.igemm(grad_output8, weight8) + grad_input = LinearFunction.dequant( + grad_input8, S1, S3, grad_output.dtype, args.quant_type + ) + + else: + grad_input = grad_output.matmul(weight) + grad_weight = torch.einsum("bsi,bso->oi", x, grad_output) + + return grad_input, grad_weight, grad_bias, None + + +class Linear8bit(nn.Module): + def __init__(self, input_features, output_features, bias=True, args=None): + super().__init__() + self.input_features = input_features + self.output_features = output_features + self.args = args + + self.weight = nn.Parameter(torch.empty(output_features, input_features)) + if bias: + self.bias = nn.Parameter(torch.empty(output_features)) + else: + self.register_parameter("bias", None) + + torch.nn.init.xavier_uniform_(self.weight) + if self.bias is not None: + torch.nn.init.zeros_(self.bias) + + def forward(self, x): + self.args.training = self.training + + return LinearFunction.apply(x, self.weight, self.bias, self.args) + + +threshold = [0.0, 3.0] +values = threshold +names = [f"threshold_{vals}" for vals in values] + + +@pytest.mark.parametrize("threshold", values, ids=names) +def test_linear8bitlt_inference(threshold): + l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half() + assert l1.weight.device.type == "cuda" + assert l1.weight.dtype == torch.float16 + + l1.eval() + for i in range(100): + b1 = torch.randn(16, 8, 32, device="cuda").half() + o1 = l1(b1) + if i == 1: + assert l1.state.CxB is not None + + +def test_linear8bitlt_accumulated_gradient(): + l1 = torch.nn.Sequential(*[bnb.nn.Linear8bitLt(32, 32).cuda().half() for i in range(2)]) + l2 = torch.nn.Sequential(*[torch.nn.Linear(32, 32).cuda().half() for i in range(2)]) + l1[0].weight.data.copy_(l2[0].weight.data) + l1[1].weight.data.copy_(l2[1].weight.data) + l1[0].bias.data.copy_(l2[0].bias.data) + l1[1].bias.data.copy_(l2[1].bias.data) + + opt1 = bnb.optim.Adam32bit(l1.parameters(), lr=0.001) + opt2 = bnb.optim.Adam32bit(l2.parameters(), lr=0.001) + + acc_steps = 10 + + for i in range(10): + b1 = torch.randn(16, 8, 32, device="cuda").half() + o1 = l1(b1) + o2 = l2(b1) + loss1 = o1.mean() + loss2 = o2.mean() + loss1.backward() + loss2.backward() + if i == 2: + assert l1[0].state.CxB is not None + assert l1[1].state.CxB is not None + + if i > 0 and i % acc_steps == 0: + opt1.step() + opt1.zero_grad(True) + opt2.step() + opt2.zero_grad(True) + assert_all_approx_close( + l1[0].weight, l2[0].weight, rtol=1.05, atol=0.01, count=2 + ) + assert_all_approx_close( + l1[1].weight, l2[1].weight, rtol=1.05, atol=0.01, count=2 + ) + # we do this copy because otherwise we have small divergences over time that add up + l1[0].weight.data.copy_(l2[0].weight.data) + l1[1].weight.data.copy_(l2[1].weight.data) + l1[0].bias.data.copy_(l2[0].bias.data) + l1[1].bias.data.copy_(l2[1].bias.data) + else: + torch.testing.assert_close(l1[0].weight.grad, l2[0].weight.grad, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(l1[1].weight.grad, l2[1].weight.grad, atol=1e-3, rtol=1e-3) + + +@pytest.mark.parametrize("threshold", [0.0, 2.0]) +@pytest.mark.parametrize("memory_efficient_backward", [False]) +def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward): + l1 = (bnb.nn.Linear8bitLt( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).cuda().half()) + assert l1.weight.dtype == torch.int8 + + l1.eval() + for i in range(100): + b1 = torch.randn(16, 8, 32, device="cuda").half() + o1 = l1(b1) + assert o1.dtype == torch.float16 + + mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).cuda() + assert mlp.fc1.weight.dtype == torch.int8 + assert mlp.fc2.weight.dtype == torch.int8 + + for i in range(100): + b1 = torch.randn(16, 8, 32, device="cuda").half() + o1 = mlp(b1) + assert o1.dtype == torch.float16 + if threshold > 0: + assert mlp.fc1.state.idx is not None + if threshold > 0: + assert mlp.fc2.state.idx is not None + + mlp = ( + MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) + .cuda() + .half() + ) + assert mlp.fc1.weight.dtype == torch.int8 + assert mlp.fc2.weight.dtype == torch.int8 + + for i in range(100): + b1 = torch.randn(16, 8, 32, device="cuda").half() + o1 = mlp(b1) + assert o1.dtype == torch.float16 + if threshold > 0: + assert mlp.fc1.state.idx is not None + if threshold > 0: + assert mlp.fc2.state.idx is not None + + mlp = ( + MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False) + .half() + .cuda() + ) + + for i in range(100): + b1 = torch.randn(16, 8, 32, device="cuda").half() + o1 = mlp(b1) + assert o1.dtype == torch.float16 + if threshold > 0: + assert mlp.fc1.state.idx is not None + if threshold > 0: + assert mlp.fc2.state.idx is not None + assert mlp.fc1.weight.dtype == torch.int8 + assert mlp.fc2.weight.dtype == torch.int8 + + mlp = ( MLP8bit( 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward).half().to("cuda")) + + for i in range(100): + b1 = torch.randn(16, 8, 32, device="cuda").half() + o1 = mlp(b1) + assert o1.dtype == torch.float16 + if threshold > 0: + assert mlp.fc1.state.idx is not None + if threshold > 0: + assert mlp.fc2.state.idx is not None + assert mlp.fc1.weight.dtype == torch.int8 + assert mlp.fc2.weight.dtype == torch.int8 + assert mlp.fc1.weight.device.type == "cuda" + assert mlp.fc2.weight.device.type == "cuda" + + mlp = MLP8bit( + 32, 64, threshold=threshold, has_fp16_weights=False, memory_efficient_backward=memory_efficient_backward + ) + w1, w2 = mlp.fc1.weight.clone().cuda(), mlp.fc2.weight.clone().cuda() # grab weights before quantization, + mlp = mlp.cuda().half() # and this line triggers quantization + + for i in range(100): + b1 = torch.randn(16, 8, 32, device="cuda").half() + o1 = mlp(b1) + assert o1.dtype == torch.float16 + if threshold > 0: + assert mlp.fc1.state.idx is not None + if threshold > 0: + assert mlp.fc2.state.idx is not None + + assert mlp.fc1.weight.dtype == torch.int8 + assert mlp.fc2.weight.dtype == torch.int8 + assert mlp.fc1.weight.device.type == "cuda" + assert mlp.fc2.weight.device.type == "cuda" + + if memory_efficient_backward: + b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half) + o1 = mlp(b1) + assert o1.dtype == torch.float16 + assert o1.requires_grad + grad_proj = torch.randn_like(o1) + + mlp.zero_grad() + (o1 * grad_proj).sum().backward() + grad_ref = grad_proj.flatten(2) @ w2.half() @ w1.half() + scale = grad_ref.abs().mean() + + torch.testing.assert_close(b1.grad, grad_ref, rtol=0, atol=0.05 * scale) + idx = torch.isclose(b1.grad, grad_ref, atol=0.01 * scale, rtol=0.1) + assert (idx == 0).sum().item() <= b1.numel() * 0.005 + + +@pytest.mark.parametrize("module", [lambda nin, nout, bias=True: bnb.nn.Linear8bitLt(nin, nout, bias=bias, has_fp16_weights=False), bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4']) +def test_linear_kbit_fp32_bias(module): + # casts model to fp16 -> int8 automatically + l1 = module(32, 64).cuda() + assert l1.weight.dtype in [torch.int8, torch.uint8] + assert l1.bias.dtype == torch.float32 + + for i in range(100): + b1 = torch.randn(16, 8, 32, device="cuda").half() + # casts bias to fp32 + o1 = l1(b1) + assert l1.bias.dtype == torch.float16 + + # casts model to fp16 -> int8 automatically + l1 = module(32, 64, bias=False).cuda() + assert l1.weight.dtype in [torch.int8, torch.uint8] + assert l1.bias is None + + for i in range(100): + b1 = torch.randn(16, 8, 32, device="cuda").half() + o1 = l1(b1) + assert l1.bias is None + +modules = [] +modules.append(bnb.nn.Linear8bitLt) +modules.append(bnb.nn.Linear4bit) +modules.append(bnb.nn.LinearFP4) +modules.append(bnb.nn.LinearNF4) +modules.append(lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True)) +modules.append(lambda d1, d2: bnb.nn.LinearNF4(d1, d2, compress_statistics=True)) +names = ['Int8Lt', '4bit', 'FP4', 'NF4', 'FP4+C', 'NF4+C'] +@pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU") +@pytest.mark.parametrize("module", modules, ids=names) +def test_kbit_backprop(module): + # TODO: Remove after beta + if '4' in str(module): pytest.skip() + if 'lambda' in str(module): pytest.skip() + b = 17 + dim1 = 37 + dim2 = 83 + + ref = nn.Sequential(*[torch.nn.Linear(dim1, dim2), torch.nn.Linear(dim2, 10)]) + ref[1].weight.requires_grad = False + torch.nn.init.kaiming_normal_(ref[0].weight) + torch.nn.init.kaiming_normal_(ref[1].weight) + kbit = nn.Sequential(*[torch.nn.Linear(dim1, dim2), module(dim2, 10)]) + kbit[0].weight.detach().copy_(ref[0].weight) + kbit[1].weight.detach().copy_(ref[1].weight) + kbit[0].bias.detach().copy_(ref[0].bias) + kbit[1].bias.detach().copy_(ref[1].bias) + ref = ref.half().cuda() + kbit = kbit.half().cuda() + + errs1 = [] + errs2 = [] + relerrs1 = [] + relerrs2 = [] + for i in range(100): + batch = torch.randn(b, dim1).half().cuda() + out1 = ref(batch) + out2 = kbit(batch) + out1.mean().backward() + out2.mean().backward() + + grad1 = ref[0].weight.grad + grad2 = kbit[0].weight.grad + bgrad1 = ref[0].bias.grad + bgrad2 = kbit[0].bias.grad + + err1 = (out1-out2).abs().float() + err2 = (grad1-grad2).abs().float() + relerr1 = (err1/(out1.abs().float()+1e-9)) + relerr2 = (err2/(grad1.abs().float()+1e-9)) + errs1.append(err1.mean().item()) + errs2.append(err2.mean().item()) + relerrs1.append(relerr1.mean().item()) + relerrs2.append(relerr2.mean().item()) + + if isinstance(module, bnb.nn.Linear8bitLt): + torch.testing.assert_close(grad1, grad2, atol=0.008, rtol=0.05) + torch.testing.assert_close(bgrad1, bgrad2, atol=0.008, rtol=0.05) + else: + torch.testing.assert_close(grad1, grad2, atol=0.015, rtol=0.05) + torch.testing.assert_close(bgrad1, bgrad2, atol=0.02, rtol=0.05) + ref.zero_grad() + kbit.zero_grad() + + assert kbit[0].weight.grad is None or kbit[0].weight.grad.sum().item() == 0 + assert kbit[0].weight.grad is None or kbit[0].bias.grad.sum().item() == 0 + print('out', sum(errs1)/len(errs1)) + print('grad', sum(errs2)/len(errs2)) + print('rel out', sum(relerrs1)/len(relerrs1)) + print('rel grad', sum(relerrs2)/len(relerrs2)) + +def test_fp8linear(): + + b = 10 + h = 1024 + inp = torch.randn(b, h).cuda() + fp32 = torch.nn.Linear(h, h*2).cuda() + fp8 = bnb.research.nn.LinearFP8Mixed(h, h*2).cuda() + fp32b = torch.nn.Linear(h*2, h).cuda() + fp8b = bnb.research.nn.LinearFP8Mixed(h*2, h).cuda() + + fp8.weight.data.copy_(fp32.weight.data) + fp8.bias.data.copy_(fp32.bias.data) + fp8b.weight.data.copy_(fp32b.weight.data) + fp8b.bias.data.copy_(fp32b.bias.data) + + a = fp32b(torch.nn.functional.gelu(fp32(inp))) + b = fp8b(torch.nn.functional.gelu(fp8(inp))) + + err = (a-b).abs().mean() + + a.mean().backward() + b.mean().backward() + + graderr = (fp8.weight.grad-fp32.weight.grad).abs().mean() + bgraderr = (fp8.bias.grad-fp32.bias.grad).abs().mean() + + assert err < 0.05 + assert graderr < 0.00002 + assert bgraderr < 0.00002 + + + + + + + diff --git a/tests/test_optim.py b/tests/test_optim.py new file mode 100644 index 000000000..98e4289dd --- /dev/null +++ b/tests/test_optim.py @@ -0,0 +1,562 @@ +import ctypes +import os +import shutil +import time +import uuid +from itertools import product +from os.path import join + +import pytest +from lion_pytorch import Lion + +import torch + +import bitsandbytes as bnb +import bitsandbytes.functional as F + +# import apex + +k = 20 + +def assert_most_approx_close(a, b, rtol=1e-3, atol=1e-3, max_error_count=0): + idx = torch.isclose(a, b, rtol, atol) + error_count = (idx == 0).sum().item() + if error_count > max_error_count: + print(f"Too many values not close: assert {error_count} < {max_error_count}") + torch.testing.assert_close(a, b, rtol, atol) + + +def get_temp_dir(): + path = f"/tmp/autoswap/{str(uuid.uuid4())}" + os.makedirs(path, exist_ok=True) + return path + + +def rm_path(path): + shutil.rmtree(path) + +str2bf16support = {} +str2bf16support['adam8bit_blockwise'] = True + +str2optimizers = {} +str2optimizers["adam_pytorch"] = (None, torch.optim.Adam, bnb.optim.Adam) +# str2optimizers['adam_apex'] = (None, apex.optimizers.FusedAdam, bnb.optim.Adam) +# str2optimizers['momentum_apex'] = (None, lambda pxx: apex.optimizers.FusedSGD(pxx, 0.01, 0.9), bnb.optim.Adam) +str2optimizers["lion_pytorch"] = (None, Lion, bnb.optim.Lion) +str2optimizers["momentum_pytorch"] = ( + None, + lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), + bnb.optim.Adam, +) +str2optimizers["adam"] = (torch.optim.Adam, bnb.optim.Adam) +str2optimizers["paged_adamw"] = (torch.optim.AdamW, bnb.optim.PagedAdamW) +str2optimizers["paged_adam"] = (torch.optim.Adam, bnb.optim.PagedAdam) +# str2optimizers['fused_adam'] = (apex.optimizers.FusedAdam, bnb.optim.Adam) +str2optimizers["lion"] = (Lion, bnb.optim.Lion) +str2optimizers["momentum"] = ( + lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.SGD(pxx, 0.01, 0.9, block_wise=False), +) +str2optimizers["rmsprop"] = ( + lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.RMSprop(pxx, 0.01, 0.9, block_wise=False), +) +str2optimizers["adam8bit"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=False)) +str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False)) +str2optimizers["momentum8bit"] = ( + lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=False), +) +str2optimizers["rmsprop8bit"] = ( + lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=False), +) + +str2optimizers["adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.Adam8bit(pxx, block_wise=True)) +str2optimizers["paged_adamw8bit_blockwise"] = (torch.optim.AdamW, lambda pxx: bnb.optim.PagedAdamW8bit(pxx, block_wise=True)) +str2optimizers["paged_adam8bit_blockwise"] = (torch.optim.Adam, lambda pxx: bnb.optim.PagedAdam8bit(pxx, block_wise=True)) +str2optimizers["lion8bit_blockwise"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=True)) +str2optimizers["momentum8bit_blockwise"] = ( + lambda pxx: torch.optim.SGD(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.SGD8bit(pxx, 0.01, 0.9, block_wise=True), +) +str2optimizers["rmsprop8bit_blockwise"] = ( + lambda pxx: torch.optim.RMSprop(pxx, 0.01, 0.9), + lambda pxx: bnb.optim.RMSprop8bit(pxx, 0.01, 0.9, block_wise=True), +) + +str2statenames = {} +str2statenames["adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["paged_adamw"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["paged_adam"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["lion"] = [("exp_avg", "state1")] +str2statenames["momentum"] = [("momentum_buffer", "state1")] +str2statenames["lamb"] = [("exp_avg", "state1"), ("exp_avg_sq", "state2")] +str2statenames["rmsprop"] = [("square_avg", "state1")] +str2statenames["adam8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] +str2statenames["lamb8bit"] = [("exp_avg", "state1", "qmap1", "max1"), ("exp_avg_sq", "state2", "qmap2", "max2")] +str2statenames["adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["paged_adam8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["paged_adamw8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1"), ("exp_avg_sq", "state2", "qmap2", "absmax2")] +str2statenames["momentum8bit"] = [("momentum_buffer", "state1", "qmap1", "max1")] +str2statenames["lion8bit"] = [("exp_avg", "state1", "qmap1", "max1")] +str2statenames["momentum8bit_blockwise"] = [("momentum_buffer", "state1", "qmap1", "absmax1")] +str2statenames["rmsprop8bit"] = [("square_avg", "state1", "qmap1", "max1")] +str2statenames["rmsprop8bit_blockwise"] = [("square_avg", "state1", "qmap1", "absmax1")] +str2statenames["lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")] + +dim1 = [1024] +dim2 = [32, 1024, 4097, 1] +gtype = [torch.float32, torch.float16] +optimizer_names = ["adam", "momentum", "rmsprop", 'paged_adamw', 'paged_adam', 'lion'] +values = list(product(dim1, dim2, gtype, optimizer_names)) +names = ["dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) +def test_optimizer32bit(dim1, dim2, gtype, optim_name): + if dim1 == 1 and dim2 == 1: + return + p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + p2 = p1.clone() + p1 = p1.float() + + torch_optimizer = str2optimizers[optim_name][0]([p1]) + bnb_optimizer = str2optimizers[optim_name][1]([p2]) + + if gtype == torch.float32: + atol, rtol = 1e-6, 1e-5 + elif gtype == torch.bfloat16: + atol, rtol = 1e-3, 1e-2 + else: + atol, rtol = 1e-4, 1e-3 + + for i in range(k): + g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 + p1.grad = g.clone().float() + p2.grad = g.clone() + + bnb_optimizer.step() + torch_optimizer.step() + + + for name1, name2 in str2statenames[optim_name]: + torch.testing.assert_close( + torch_optimizer.state[p1][name1], + bnb_optimizer.state[p2][name2].cuda(), + atol=atol, + rtol=rtol, + ) + + # since Lion can have pretty noisy updates where things lie at the boundary + # allow up to 10 errors for Lion + assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10) + + if i % (k // 5) == 0 and i > 0: + path = get_temp_dir() + torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt")) + del bnb_optimizer + bnb_optimizer = None + bnb_optimizer = str2optimizers[optim_name][1]([p2]) + bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) + rm_path(path) + # since Lion can have pretty noisy updates where things lie at the boundary + # allow up to 10 errors for Lion + assert_most_approx_close(p1, p2.float(), atol, rtol, max_error_count=10) + for name1, name2 in str2statenames[optim_name]: + # since Lion can have pretty noisy updates where things lie at the boundary + # allow up to 10 errors for Lion + assert_most_approx_close(torch_optimizer.state[p1][name1], bnb_optimizer.state[p2][name2], + atol=atol, rtol=rtol, + max_error_count=10) + + if gtype != torch.float32: + # the adam buffers should also be close because they are 32-bit + # but the paramters can diverge because they are 16-bit + # the difference grow larger and larger with each update + # --> copy the state to keep weights close + p1.data = p1.data.to(p2.dtype).float() + p2.copy_(p1.data) + torch.testing.assert_close(p1.to(p2.dtype), p2) + if optim_name in ["lars", "lamb"]: + assert bnb_optimizer.state[p2]["unorm_vec"] > 0.0 + + +dim1 = [1024] +dim2 = [32, 1024, 4097] +gtype = [torch.float32, torch.float16] +values = list(product(dim1, dim2, gtype)) +names = ["dim1_{}_dim2_{}_gtype_{}".format(*vals) for vals in values] + + +@pytest.mark.parametrize("dim1, dim2, gtype", values, ids=names) +def test_global_config(dim1, dim2, gtype): + if dim1 == 1 and dim2 == 1: + return + p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 + p2 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 + p3 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 + mask = torch.rand_like(p2) < 0.1 + beta1 = 0.9 + beta2 = 0.999 + lr = 0.001 + eps = 1e-8 + + bnb.optim.GlobalOptimManager.get_instance().initialize() + bnb.optim.GlobalOptimManager.get_instance().override_config( + p3, "optim_bits", 8 + ) + + bnb.optim.GlobalOptimManager.get_instance().register_parameters( + [p1, p2, p3] + ) + p1 = p1.cuda() + p2 = p2.cuda() + p3 = p3.cuda() + + adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps) + + if gtype == torch.float32: + atol, rtol = 1e-6, 1e-5 + else: + atol, rtol = 1e-4, 1e-3 + + for i in range(50): + g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 + g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 + g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 + p1.grad = g1 + p2.grad = g2 + p3.grad = g3 + + adam2.step() + + assert adam2.state[p3]["state1"].dtype == torch.uint8 + assert adam2.state[p3]["state2"].dtype == torch.uint8 + + +dim1 = [1024] +dim2 = [32, 1024, 4097] +gtype = [torch.float32, torch.float16, torch.bfloat16] +optimizer_names = [ + "adam8bit", + "lion8bit", + "momentum8bit", + "rmsprop8bit", + "adam8bit_blockwise", + "lion8bit_blockwise", + "momentum8bit_blockwise", + "rmsprop8bit_blockwise", +] +values = list(product(dim1, dim2, gtype, optimizer_names)) +names = [ + "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values +] + + +@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) +def test_optimizer8bit(dim1, dim2, gtype, optim_name): + if gtype == torch.bfloat16 and optim_name not in str2bf16support: return + if dim1 == 1 and dim2 == 1: + return + p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + p2 = p1.clone() + p1 = p1.float() + blocksize = 2048 + + torch_optimizer = str2optimizers[optim_name][0]([p1]) + bnb_optimizer = str2optimizers[optim_name][1]([p2]) + + if gtype == torch.float32: + atol, rtol = 3e-3, 1e-3 + patol, prtol = 1e-5, 1e-3 + elif gtype == torch.bfloat16: + atol, rtol = 3e-3, 1e-3 + patol, prtol = 1e-4, 1e-2 + else: + atol, rtol = 3e-3, 1e-3 + patol, prtol = 1e-5, 1e-3 + + errors = [] + relerrors = [] + + for i in range(100): + g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 + p1.grad = g.clone().float() + p2.grad = g.clone() + + bnb_optimizer.step() + torch_optimizer.step() + + # since Lion can have pretty noisy updates where things lie at the boundary + # allow up to 5 errors for Lion + assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5) + + dequant_states = [] + for name1, name2, qmap, max_val in str2statenames[optim_name]: + # print(bnb_optimizer.state[p2][max_val], name1) + if "blockwise" in optim_name: + s1 = F.dequantize_blockwise( + code=bnb_optimizer.state[p2][qmap], + absmax=bnb_optimizer.state[p2][max_val], + A=bnb_optimizer.state[p2][name2], + blocksize=blocksize, + ) + else: + s1 = F.dequantize( + code=bnb_optimizer.state[p2][qmap], + absmax=bnb_optimizer.state[p2][max_val], + A=bnb_optimizer.state[p2][name2], + ) + num_not_close = ( + torch.isclose( + torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol + ) + == 0 + ) + #assert num_not_close.sum().item() < 20 + dequant_states.append(s1.clone()) + + err = torch.abs(p1 - p2) + relerr = err / (torch.abs(p1)+1e-9) + if g.dtype == torch.bfloat16: + assert err.mean() < 0.00015 + assert relerr.mean() < 0.0016 + else: + assert err.mean() < 0.00012 + assert relerr.mean() < 0.0012 + + errors.append(err.mean().item()) + relerrors.append(relerr.mean().item()) + + if i % 10 == 0 and i > 0: + for (name1, name2, qmap, max_val), s in zip( + str2statenames[optim_name], dequant_states + ): + s1cpy = s.clone() + raws1cpy = bnb_optimizer.state[p2][name2].clone() + qmap1 = bnb_optimizer.state[p2][qmap].clone() + + path = get_temp_dir() + torch.save(bnb_optimizer.state_dict(), join(path, "opt.pt")) + del bnb_optimizer + bnb_optimizer = None + bnb_optimizer = str2optimizers[optim_name][1]([p2]) + bnb_optimizer.load_state_dict(torch.load(join(path, "opt.pt"))) + rm_path(path) + torch.testing.assert_close(raws1cpy, bnb_optimizer.state[p2][name2]) + torch.testing.assert_close(qmap1, bnb_optimizer.state[p2][qmap]) + + if "blockwise" in optim_name: + s1 = F.dequantize_blockwise( + code=bnb_optimizer.state[p2][qmap], + absmax=bnb_optimizer.state[p2][max_val], + A=bnb_optimizer.state[p2][name2], + blocksize=blocksize, + ) + else: + s1 = F.dequantize( + code=bnb_optimizer.state[p2][qmap], + absmax=bnb_optimizer.state[p2][max_val], + A=bnb_optimizer.state[p2][name2], + ) + torch.testing.assert_close(s1cpy, s1) + + num_not_close = (torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0) + assert num_not_close.sum().item() < 20 + # since Lion can have pretty noisy updates where things lie at the boundary + # allow up to 5 errors for Lion + assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=5) + + # the parameters diverge quickly. Here we keep them close + # together so we can test against the Adam error + p1.data = p1.data.to(gtype).float() + p2.copy_(p1.data) + torch.testing.assert_close(p1.to(gtype), p2) + for (name1, name2, qmap, max_val), s in zip(str2statenames[optim_name], dequant_states): + torch_optimizer.state[p1][name1].copy_(s.data) + + # print(sum(errors)/len(errors)) + # print(sum(relerrors)/len(relerrors)) + + +dim1 = [1024] +dim2 = [32, 1024, 4097] +gtype = [torch.float32] +optim_bits = [32, 8] +values = list(product(dim1, dim2, gtype, optim_bits)) +names = [ + "dim1_{}_dim2_{}_gtype_{}_optim_bits_{}".format(*vals) + for vals in values +] + + +@pytest.mark.parametrize("dim1, dim2, gtype, optim_bits", values, ids=names) +def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits): + if dim1 == 1 and dim2 == 1: + return + p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 + beta1 = 0.9 + beta2 = 0.999 + lr = 0.001 + eps = 1e-8 + p1 = p1.cuda() + p2 = p1.clone() + adam1 = bnb.optim.Adam([p1], lr, (beta1, beta2), eps, optim_bits=optim_bits) + adam2 = bnb.optim.Adam( + [p2], + lr, + (beta1, beta2), + eps, + optim_bits=optim_bits, + percentile_clipping=5, + ) + + gnorm_vec = torch.zeros(100).cuda() + step = 0 + + for i in range(50): + step += 1 + g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + ( + 0.01 * i + ) + g2 = g1.clone() + p2.grad = g2 + + current_gnorm, clip_val, gnorm_scale = F.percentile_clipping( + g1, gnorm_vec, step, 5 + ) + g1 = (g1.float() * gnorm_scale).to(gtype) + p1.grad = g1 + + adam1.step() + adam2.step() + + # gnorm_scale is not deterministic (warp reductions), as such there can be slight differences in state + if optim_bits == 32: + torch.testing.assert_close(p1, p2) + torch.testing.assert_close( + adam1.state[p1]["state1"], + adam2.state[p2]["state1"], + atol=5e-5, + rtol=1e-4, + ) + torch.testing.assert_close( + adam1.state[p1]["state2"], + adam2.state[p2]["state2"], + atol=5e-5, + rtol=1e-4, + ) + elif optim_bits == 8: + torch.testing.assert_close(p1, p2, atol=1e-4, rtol=1e-3) + torch.testing.assert_close( + adam1.state[p1]["state1"], + adam2.state[p2]["state1"], + atol=2, + rtol=1e-3, + ) + torch.testing.assert_close( + adam1.state[p1]["state2"], + adam2.state[p2]["state2"], + atol=2, + rtol=1e-3, + ) + adam1.state[p1]["state1"].copy_(adam2.state[p2]["state1"]) + adam1.state[p1]["state2"].copy_(adam2.state[p2]["state2"]) + if i % 10 == 0 and i > 0: + path = get_temp_dir() + torch.save(adam2.state_dict(), join(path, "opt.pt")) + del adam2 + adam2 = None + adam2 = bnb.optim.Adam( + [p2], + lr, + (beta1, beta2), + eps, + optim_bits=optim_bits, + percentile_clipping=5, + ) + adam2.load_state_dict(torch.load(join(path, "opt.pt"))) + + +dim1 = [4096] +dim2 = [4096] +gtype = [torch.float32, torch.float16] +# optimizer_names = ['adam8bit_blockwise', 'adam8bit', 'lamb8bit'] +# optimizer_names = ['adam8bit_blockwise', 'adam_apex', 'adam8bit', 'adam', 'adam_pytorch'] +# optimizer_names = ['momentum_apex', 'momentum8bit', 'momentum_pytorch'] +# optimizer_names = ['lamb_apex', 'lamb8bit'] +# optimizer_names = ['lars_apex', 'lars8bit'] +optimizer_names = ["adam8bit_blockwise", 'paged_adam8bit_blockwise', 'paged_adamw8bit_blockwise'] +values = list(product(dim1, dim2, gtype, optimizer_names)) +names = [ + "dim1_{}_dim2_{}_gtype_{}_optim_{}".format(*vals) for vals in values +] + + +@pytest.mark.parametrize("dim1, dim2, gtype, optim_name", values, ids=names) +def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): + if dim1 == 1 and dim2 == 1: + return + p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + + bnb_optimizer = str2optimizers[optim_name][1]([p1]) + + g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 + p1.grad = g + for i in range(k): + if i == k // 5: + # 100 iterations for burn-in + torch.cuda.synchronize() + t0 = time.time() + + bnb_optimizer.step() + + torch.cuda.synchronize() + s = time.time() - t0 + print("") + params = (k - k // 5) * dim1 * dim2 + print(optim_name, gtype, s / params) + # assert s < 3.9 + +dim1 = [2*1024] +gtype = [torch.float16] +#mode = ['torch', 'bnb'] +mode = ['bnb'] +optimizer_names = ['paged_adamw'] +#optimizer_names = ['paged_adamw8bit_blockwise'] +values = list(product(dim1,gtype, optimizer_names, mode)) +names = ['dim1_{0}_gtype_{1}_optim_{2}_mode_{3}'.format(*vals) for vals in values] +@pytest.mark.parametrize("dim1, gtype, optim_name, mode", values, ids=names) +def test_stream_optimizer_bench(dim1, gtype, optim_name, mode): + layers1 = torch.nn.Sequential(*torch.nn.ModuleList([torch.nn.Linear(dim1, dim1) for i in range(10)])) + layers1 = layers1.to(gtype) + layers1 = layers1.cuda() + + large_tensor = None + if mode == 'torch': + optim = str2optimizers[optim_name][0](layers1.parameters()) + else: + optim = str2optimizers[optim_name][1](layers1.parameters()) + # 12 GB + large_tensor = torch.empty((int(4.5e9),), device='cuda') + + torch.cuda.synchronize() + time.sleep(5) + + num_batches = 5 + batches = torch.randn(num_batches, 128, dim1, device='cuda').to(gtype) + lbls = torch.randint(0, 10, size=(num_batches,128)).cuda() + + for i in range(num_batches): + print(i) + b = batches[i] + if i ==2: + torch.cuda.synchronize() + t0 = time.time() + + out1 = layers1(b) + + loss1 = torch.nn.functional.cross_entropy(out1, lbls[i]).mean() + loss1.backward() + optim.step() + torch.cuda.synchronize() + print(mode, time.time() - t0) diff --git a/tests/test_triton.py b/tests/test_triton.py new file mode 100644 index 000000000..e18c7a930 --- /dev/null +++ b/tests/test_triton.py @@ -0,0 +1,59 @@ +import pytest +import torch + +from bitsandbytes.triton.triton_utils import is_triton_available +from bitsandbytes.nn.triton_based_modules import SwitchBackLinear +from bitsandbytes.nn import Linear8bitLt + +@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8, + reason="This test requires triton and a GPU with compute capability 8.0 or higher.") +@pytest.mark.parametrize("vector_wise_quantization", [False, True]) +def test_switchback(vector_wise_quantization): + for dim in [83]: + for batch in [13]: + + standard = torch.nn.Linear(dim, 4 * dim).cuda().half() + switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half() + baseline = Linear8bitLt(dim, 4 * dim).cuda().half() + switchback.weight.data.copy_(standard.weight) + switchback.bias.data.copy_(standard.bias) + baseline.weight.data.copy_(standard.weight) + baseline.bias.data.copy_(standard.bias) + + x1 = torch.randn(batch, dim).cuda().half().requires_grad_(True) + x2 = x1.clone().detach().requires_grad_(True) + x3 = x1.clone().detach().requires_grad_(True) + + out_standard = standard(x1) + (2**10 * out_standard.abs().mean()).backward() + + print(x2.dtype) + out_sb = switchback(x2) + (2**10 * out_sb.abs().mean()).backward() + + out_baseline = baseline(x3) + (2**10 * out_baseline.abs().mean()).backward() + + err_sb = (out_standard - out_sb).abs().mean() + err_baseline = (out_standard - out_baseline).abs().mean() + print('OUT', err_sb, err_baseline) + assert err_sb < 2 * err_baseline + + err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean() + err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean() + + print('GW2', err_sb, err_baseline) + assert err_sb < 2 * err_baseline + + err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean() + err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean() + + print('GW1', err_sb, err_baseline) + assert err_sb < 2 * err_baseline + + err_sb = (x1.grad - x2.grad).abs().mean() + err_baseline = (x1.grad - x3.grad).abs().mean() + + print('GX1', err_sb, err_baseline) + assert err_sb < 2 * err_baseline +