Add k-bit blockwise quantization (K=2-5) with warp-level CUDA kernels #1858
Open

TimDettmers wants to merge 9 commits into main from
Conversation
Implements Stages 0-5 of the k-bit quantization plan from cuda-spec.md:

- Pure Python reference (quantize_kbit_ref, dequantize_kbit_ref) with 57 passing tests
- CUDA kernels using __ballot_sync bit-plane packing and __shfl_sync codebook lookup
- Test kernels (pack/unpack, memory format, codebook lookup) and production kernels
- All C interface symbols exported and loadable via ctypes

CUDA kernels compile but are not yet executable due to an RDC device-linking issue where template instantiations in kernels.cu are not pulled into the final fatbinary. See KBIT_PROGRESS.md for the diagnosis and recommended fix (move kernel bodies into ops.cu or a new self-contained file).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
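The pure Python reference mentioned above is not shown in this thread; a minimal sketch of how such a reference might look (names match the commit, but the exact signatures and details are assumptions, not the PR's code):

```python
import torch

def quantize_kbit_ref(x: torch.Tensor, codebook: torch.Tensor):
    """Per 32-element block: compute absmax, normalize to [-1, 1], and map
    each value to the index of the nearest codebook entry."""
    assert x.numel() % 32 == 0  # one block = 32 elements = one warp
    blocks = x.float().reshape(-1, 32)
    absmax = blocks.abs().amax(dim=1, keepdim=True)
    scaled = blocks / absmax.clamp(min=torch.finfo(torch.float32).tiny)
    codes = (scaled.unsqueeze(-1) - codebook).abs().argmin(dim=-1)
    return codes.reshape(x.shape), absmax.squeeze(1)

def dequantize_kbit_ref(codes: torch.Tensor, absmax: torch.Tensor,
                        codebook: torch.Tensor) -> torch.Tensor:
    """Inverse: codebook lookup, then rescale each block by its absmax."""
    vals = codebook[codes.reshape(-1, 32)] * absmax.unsqueeze(1)
    return vals.reshape(codes.shape)
```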
The "invalid device function" error was caused by mismatched kernel declarations in kernels.cuh (without __restrict__) vs definitions in ops.cu (with __restrict__). With CUDA separable compilation (-rdc=true), this created conflicting host stubs in the function registration. Fix: remove forward declarations from kernels.cuh, keep kernel definitions and launch wrappers together in ops.cu. Also added CUDA_RESOLVE_DEVICE_SYMBOLS ON to CMakeLists.txt. All 157 tests now pass: Stage 0 (Python ref), Stages 1-3 (CUDA test kernels), Stage 4 (quantize), Stage 5 (dequantize) -- covering K=2-5, fp16/bf16/fp32, various tensor sizes, and analytical error bounds. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Stage 6: Error analysis on 1M+ elements (analytical bounds, MSE, SQNR)
- Stage 7: Cross-validation against existing NF4 dequant
- Stage 8: Performance benchmarks (bandwidth utilization, throughput scaling)
- Python API: quantize_kbit(), dequantize_kbit(), create_normal_float_codebook() in functional.py, with torch.library registration in _ops.py and CUDA kernel dispatch in backends/cuda/ops.py
- Codebook caching per (k, device) pair

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Not needed in the final branch. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Vectorized dequant kernel (half2 stores, 4 blocks/warp) gives a 1.23-1.29x speedup over the scalar kernel, reaching 80-87% of peak HBM bandwidth. Routes fp16 output through the vectorized path; bf16/fp32 use the scalar fallback.

E4M4 uint8 absmax (bias=11, IEEE-style subnormals) reduces absmax storage from 4 bytes to 1 byte per block. K=4 drops from 5.0 to 4.25 bits/elem, matching NF4 bs=64 storage. SQNR degradation is <0.4 dB across all K values. Decode uses direct IEEE 754 bit construction for zero overhead on the dequant hot path.

240 tests passing (22 new E4M4 tests).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Remove the scalar dequant kernel (vectorized is strictly better)
- Remove the fp32 absmax dequant path; E4M4 uint8 is now the default, with fp16 absmax kept as an option
- Remove the Stage 1-3 test scaffolding kernels (pack/unpack, memory format, codebook lookup) and their C wrappers
- Dequant always produces fp16 at the CUDA level; bf16/fp32 output via cast in Python
- Net removal of 334 lines; 188 tests passing

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace the half2-specific vectorized kernel with a generic version templated on T (output type) and ABSMAX_T (absmax format). Scalar stores via (T)val; the hardware coalesces warp writes. No fp16 regression (within benchmark noise). Native bf16 and fp32 output at the kernel level, with no Python-side cast needed.

Add output-dtype correctness tests (bf16/fp32 match fp16) and asymmetric codebook tests (all-positive, all-negative, skewed, non-uniform spacing, duplicate entries). 222 tests passing.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Apply ruff lint fix (unused variable), ruff format, and clang-format to pass CI pre-commit hooks. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
The error bound was using a flat 1.25x multiplier on the quantization error, but E4M4 absmax quantization adds up to 1/16 (6.25%) absolute scale error. For K=5, where the codebook gap is ~0.0625, this E4M4 error is 2x the quantization error itself, exceeding the 1.25x margin.

Fix by computing the bound as (max_gap/2 + 1/16) * absmax, which adds both error sources instead of scaling one by a fixed factor.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
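For illustration, the corrected bound from this commit as a hedged sketch (the helper name is mine; it assumes `absmax` is the per-block scale and `codebook` is normalized to [-1, 1]):

```python
import torch

def dequant_error_bound(codebook: torch.Tensor, absmax: torch.Tensor) -> torch.Tensor:
    """Per-block bound from the commit message: half the widest codebook gap
    (worst-case rounding error) plus 1/16 (worst-case E4M4 scale error), both
    in normalized units, then rescaled by the block absmax."""
    max_gap = codebook.sort().values.diff().max()
    return (max_gap / 2 + 1.0 / 16) * absmax
```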
Summary
- Warp-level CUDA kernels for K=2-5 blockwise quantization (`__ballot_sync` packing, `__shfl_sync` codebook lookup)
- Python API: `quantize_kbit()`, `dequantize_kbit()`, `create_normal_float_codebook()`, `encode_absmax_e4m4()`, `decode_absmax_e4m4()`

Architecture
Each quantization block of 32 elements maps to exactly one CUDA warp. This enables:
- Bit-plane packing with `__ballot_sync`: K calls produce K uint32 words with zero bit waste for any K value, and no word-boundary issues for odd K (unlike sequential packing). A pure-Python model follows this list.
- Codebook lookup with `__shfl_sync`: each lane holds one codebook entry (up to 2^K = 32 entries fit in the warp width), so lookup is register-to-register (~5 cycles) with no shared memory needed.
- Absmax reduction with `__shfl_down_sync`: 5 reduction steps, no shared memory, no `__syncthreads()`.

Zero shared memory used. Zero warp divergence in the hot path. Templated on output type (fp16/bf16/fp32) and absmax format (E4M4 uint8/fp16).
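The ballot trick is easy to model on the host: each `__ballot_sync` call contributes one 32-bit word whose bit i is bit b of lane i's code. A pure-Python sketch (function names are illustrative, in the spirit of the PR's Python reference):

```python
def pack_bitplanes(codes: list[int], k: int) -> list[int]:
    """Model of the __ballot_sync packing: ballot call b collects bit b of
    all 32 lane codes into one uint32 (lane i -> bit i), so 32 K-bit codes
    become exactly K words with no wasted bits."""
    assert len(codes) == 32
    return [sum(((c >> b) & 1) << lane for lane, c in enumerate(codes))
            for b in range(k)]

def unpack_bitplanes(planes: list[int], k: int) -> list[int]:
    """Inverse: lane i rebuilds its K-bit code from bit i of each plane."""
    return [sum(((planes[b] >> lane) & 1) << b for b in range(k))
            for lane in range(32)]
```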
E4M4 absmax format
uint8 micro-float with 4-bit exponent, 4-bit mantissa, bias=11, IEEE-style subnormals. Range [6.1e-5, 31.0]. Mean encode/decode relative error: 1.1%; 95th percentile: 2.4%. SQNR degradation vs fp32 absmax: <0.4 dB across all K values. Decode uses direct IEEE 754 bit construction (`__uint_as_float`) for zero overhead on the dequant hot path.
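Based on the stated parameters (bias 11, 4-bit exponent and mantissa, IEEE-style subnormals), a plausible pure-Python model of the format; the actual device decode builds IEEE 754 bits directly, and the rounding behavior here is an assumption:

```python
import math

E4M4_BIAS = 11

def decode_e4m4(byte: int) -> float:
    """value = (1 + m/16) * 2**(e - 11) for e >= 1; subnormals m/16 * 2**-10.
    Smallest nonzero = 2**-14 ~= 6.1e-5, largest = 31.0, matching the range."""
    e, m = byte >> 4, byte & 0xF
    if e == 0:  # IEEE-style subnormal
        return (m / 16.0) * 2.0 ** (1 - E4M4_BIAS)
    return (1.0 + m / 16.0) * 2.0 ** (e - E4M4_BIAS)

def encode_e4m4(x: float) -> int:
    """Round a non-negative absmax to the nearest E4M4 code."""
    if x <= 0.0:
        return 0
    frac, exp = math.frexp(min(x, 31.0))  # x = frac * 2**exp, frac in [0.5, 1)
    e = exp - 1 + E4M4_BIAS               # biased exponent of (1.m) * 2**(e - bias)
    if e <= 0:                            # subnormal range
        m = round(x * 2.0 ** (E4M4_BIAS - 1 + 4))
        return 0x10 if m >= 16 else m     # may round up to the smallest normal
    m = round((frac * 2.0 - 1.0) * 16)    # 4-bit mantissa
    if m == 16:                           # mantissa rounding carried over
        m, e = 0, e + 1
    return 0xFF if e >= 16 else (e << 4) | m
```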
Benchmarks (RTX 4090, 67M elements, E4M4 absmax)
Dequant kernel throughput
Comparison with existing NF4 (fp16 output)
K=4 is at parity with NF4 in kernel throughput (0.97x) while using 0.25 fewer bits/element. K=2 and K=3 are faster because there is less data to read. Both kernels are bandwidth-bound at 68-78% of peak HBM bandwidth.
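A back-of-the-envelope bandwidth model makes the "less data to read" point concrete (the 1008 GB/s figure is the RTX 4090's theoretical peak; the real kernels hit the 68-78% utilization reported above):

```python
# Dequant traffic per element: K/8 bytes of packed codes, 1/32 byte of
# E4M4 absmax (one uint8 per 32-element block), and 2 bytes of fp16 output.
PEAK_HBM_GBPS = 1008.0  # RTX 4090 theoretical peak

for k in (2, 3, 4, 5):
    bytes_per_elem = k / 8 + 1 / 32 + 2
    print(f"K={k}: {bytes_per_elem:.2f} B/elem -> "
          f"{PEAK_HBM_GBPS / bytes_per_elem:.0f} Gelem/s at 100% of peak")
```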
Quality (SQNR, 1M elements, normal distribution)
K=4 achieves comparable quality to NF4 at 0.25 fewer bits/element (4.25 vs 4.50). K=3 offers 3.25 bits/element (4.9x compression) with 15 dB SQNR. K=2 provides 7.1x compression for extreme quantization.
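SQNR here is the usual signal-to-quantization-noise ratio; a one-liner for reproducing such a measurement (how the PR's Stage 6 analysis aggregates it is an assumption):

```python
import torch

def sqnr_db(x: torch.Tensor, x_hat: torch.Tensor) -> float:
    """10 * log10(signal power / quantization-noise power), in dB."""
    noise = x - x_hat
    return (10 * torch.log10(x.pow(2).mean() / noise.pow(2).mean())).item()
```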
Storage comparison
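The storage cost follows directly from the 32-element block size and the 1-byte E4M4 absmax; a quick sketch of the arithmetic:

```python
# K payload bits per value plus one uint8 absmax per 32-element block.
for k in (2, 3, 4, 5):
    bits = k + 8 / 32
    print(f"K={k}: {bits:.2f} bits/elem, {16 / bits:.1f}x smaller than fp16")
# Reference points from above: K=4 -> 4.25 bits/elem; NF4 (bs=64, fp32
# absmax) -> 4 + 32/64 = 4.50 bits/elem.
```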
Key design decisions
- One warp per 32-element quantization block, so `__ballot_sync`/`__shfl_sync` see exactly one block
- Bit-plane packing via `__ballot_sync` instead of sequential bit packing
- Codebook lookup via `__shfl_sync` from lane registers instead of shared memory
- Hot path built entirely from `_sync` warp primitives; no shared memory, no `__syncthreads()`

API
The default codebook is a symmetric normal-float codebook (quantiles of N(0,1), normalized to [-1, 1]). Unlike the existing NF4 codebook, which is asymmetric (7 negative, 1 zero, 8 positive), this codebook has equal representation on both sides (8 negative, 8 positive, no explicit zero). Custom codebooks of any shape can be passed via the `codebook` parameter.
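A sketch of how such a symmetric normal-float codebook can be built (the construction via equal-probability quantile midpoints is an assumption; `create_normal_float_codebook()` may differ in detail):

```python
import math
import torch

def normal_float_codebook(k: int) -> torch.Tensor:
    """2**k quantiles of N(0, 1), normalized to [-1, 1]; symmetric, with no
    explicit zero when 2**k is even (e.g. 8 negative / 8 positive at k=4)."""
    n = 2 ** k
    probs = (torch.arange(n, dtype=torch.float64) + 0.5) / n  # bin midpoints
    q = math.sqrt(2.0) * torch.erfinv(2.0 * probs - 1.0)      # inverse normal CDF
    return (q / q.abs().max()).float()
```

At k=4 this yields 16 entries mirrored around zero with none exactly at zero, matching the description above.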
Test plan
222 tests covering:

- The pure Python reference and the CUDA quantize/dequantize paths
- K=2-5 with fp16/bf16/fp32 outputs and various tensor sizes
- Analytical error bounds plus MSE/SQNR error analysis on 1M+ elements
- E4M4 absmax encode/decode
- Output dtype correctness (bf16/fp32 match fp16) and asymmetric/custom codebooks (all-positive, all-negative, skewed, non-uniform spacing, duplicate entries)
Files changed
- `csrc/ops.cu` — CUDA kernels and launchers (+229 lines)
- `csrc/pythonInterface.cpp` — C interface wrappers (+125 lines)
- `bitsandbytes/functional.py` — Python API (+194 lines)
- `bitsandbytes/_ops.py` — torch.library op definitions (+44 lines)
- `bitsandbytes/backends/cuda/ops.py` — CUDA backend dispatch (+90 lines)
- `tests/test_kbit_quantization.py` — Test suite (+1372 lines)