
Add k-bit blockwise quantization (K=2-5) with warp-level CUDA kernels #1858

Open
TimDettmers wants to merge 9 commits into main from feature/kbit-quantization

Conversation

@TimDettmers (Collaborator) commented Feb 14, 2026

Summary

  • Add k-bit blockwise quantization/dequantization CUDA kernels for K=2, 3, 4, 5 with blocksize=32, using warp-level intrinsics (__ballot_sync packing, __shfl_sync codebook lookup)
  • Add E4M4 uint8 absmax format (bias=11) that reduces per-block scale storage from 4 bytes to 1 byte, bringing K=4 to 4.25 bits/element
  • New Python API: quantize_kbit(), dequantize_kbit(), create_normal_float_codebook(), encode_absmax_e4m4(), decode_absmax_e4m4()
  • 222 tests passing, including correctness, error analysis, NF4 cross-validation, and performance benchmarks

Architecture

Each quantization block of 32 elements maps to exactly one CUDA warp. This enables:

  • Bit-plane packing via __ballot_sync: K calls produce K uint32 words with zero bit waste for any K value. No word-boundary issues for odd K (unlike sequential packing).
  • Codebook lookup via __shfl_sync: Each lane holds one codebook entry (up to 2^K=32 entries fit in warp width). Lookup is register-to-register (~5 cycles), no shared memory needed.
  • Absmax reduction via __shfl_down_sync: 5 reduction steps, no shared memory, no __syncthreads().
  • Multi-block dequant: Each warp processes 4 quantization blocks, amortizing codebook load.

Zero shared memory used. Zero warp divergence in the hot path. Templated on output type (fp16/bf16/fp32) and absmax format (E4M4 uint8/fp16).
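For intuition, here is the bit-plane layout in plain PyTorch for a single 32-element block. This is only an illustrative mirror of what the device-side `__ballot_sync` packing produces; the helper names are not part of the shipped API:

```python
import torch

def pack_block_bitplanes(codes: torch.Tensor, k: int) -> torch.Tensor:
    """Pack 32 k-bit codes into k bit-plane words, lane i -> bit i of each word.

    Word b collects bit b of every lane's code, which is the word that
    __ballot_sync(0xffffffff, (code >> b) & 1) returns on device.
    """
    assert codes.numel() == 32 and 2 <= k <= 5
    lanes = torch.arange(32, dtype=torch.int64)
    codes = codes.to(torch.int64)
    return torch.stack([(((codes >> b) & 1) << lanes).sum() for b in range(k)])

def unpack_block_bitplanes(words: torch.Tensor, k: int) -> torch.Tensor:
    """Inverse of the packing above: rebuild the 32 k-bit codes."""
    lanes = torch.arange(32, dtype=torch.int64)
    bits = (words.view(k, 1) >> lanes) & 1                        # (k, 32) bit matrix
    return (bits << torch.arange(k, dtype=torch.int64).view(k, 1)).sum(dim=0)

codes = torch.randint(0, 2 ** 4, (32,))
assert torch.equal(unpack_block_bitplanes(pack_block_bitplanes(codes, 4), 4), codes)
```

Because each of the K words holds exactly one bit per lane, storage is K * 32 bits per block regardless of K, which is where the "zero waste for odd K" property comes from.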

E4M4 absmax format

uint8 micro-float with 4-bit exponent, 4-bit mantissa, bias=11, IEEE-style subnormals. Range [6.1e-5, 31.0]. Mean encode/decode relative error: 1.1%, 95th percentile: 2.4%. SQNR degradation vs fp32 absmax: <0.4 dB across all K values. Decode uses direct IEEE 754 bit construction (__uint_as_float) for zero overhead on the dequant hot path.
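As a rough illustration of the format (scalar Python; not the shipped `encode_absmax_e4m4`/`decode_absmax_e4m4`, and not the bit-trick decode used in the kernel, whose rounding details may differ):

```python
import math

def e4m4_decode(byte: int, bias: int = 11) -> float:
    """Decode an E4M4 byte (eeee mmmm): IEEE-style, exp == 0 means subnormal."""
    exp, man = (byte >> 4) & 0xF, byte & 0xF
    if exp == 0:
        return (man / 16.0) * 2.0 ** (1 - bias)             # no implicit leading 1
    return (1.0 + man / 16.0) * 2.0 ** (exp - bias)

def e4m4_encode(x: float, bias: int = 11) -> int:
    """Encode a non-negative absmax into E4M4, round-to-nearest, clamped to 31.0."""
    if x <= 0.0:
        return 0
    x = min(x, (1.0 + 15.0 / 16.0) * 2.0 ** (15 - bias))    # max representable: 31.0
    exp = math.floor(math.log2(x)) + bias
    if exp <= 0:                                            # subnormal region
        man = round(x / 2.0 ** (1 - bias) * 16.0)
        return man if man < 16 else 1 << 4                  # round-up lands on smallest normal
    man = round((x / 2.0 ** (exp - bias) - 1.0) * 16.0)
    if man == 16:                                           # mantissa carry into the exponent
        exp, man = exp + 1, 0
    return (exp << 4) | man

assert abs(e4m4_decode(e4m4_encode(1.0)) - 1.0) < 1e-6      # 1.0 is exactly representable
```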

Benchmarks (RTX 4090, 67M elements, E4M4 absmax)

Dequant kernel throughput

| K | bits/elem | fp16 (µs) | bf16 (µs) | fp32 (µs) | GB/s (fp16) | % peak BW |
|---|-----------|-----------|-----------|-----------|-------------|-----------|
| 2 | 2.25 | 205 | 208 | 394 | 781 | 78% |
| 3 | 3.25 | 215 | 215 | 416 | 786 | 78% |
| 4 | 4.25 | 244 | 246 | 420 | 729 | 72% |
| 5 | 5.25 | 271 | 270 | 428 | 689 | 68% |

Comparison with existing NF4 (fp16 output)

| Method | bits/elem | Dequant (µs) | GB/s | % peak BW | vs NF4 |
|--------|-----------|--------------|------|-----------|--------|
| NF4 bs=64 | 4.50 | 239 | 754 | 75% | ref |
| kbit K=2 | 2.25 | 215 | 745 | 74% | 1.11x |
| kbit K=3 | 3.25 | 229 | 737 | 73% | 1.04x |
| kbit K=4 | 4.25 | 246 | 722 | 72% | 0.97x |
| kbit K=5 | 5.25 | 271 | 689 | 68% | 0.88x |

K=4 is at parity with NF4 in kernel throughput (0.97x) while using 0.25 fewer bits per element. K=2 and K=3 are faster simply because there is less data to read. Both kernel families are bandwidth-bound, running at 68-78% of peak memory bandwidth.

Quality (SQNR, 1M elements, normal distribution)

| Method | SQNR (dB) | MSE | bits/elem |
|--------|-----------|-----|-----------|
| NF4 bs=64 | 20.74 | 0.0085 | 4.50 |
| kbit K=2 | 7.43 | 0.181 | 2.25 |
| kbit K=3 | 14.99 | 0.032 | 3.25 |
| kbit K=4 | 21.09 | 0.0078 | 4.25 |
| kbit K=5 | 25.95 | 0.0026 | 5.25 |

K=4 matches NF4 quality (21.09 vs 20.74 dB SQNR) at 0.25 fewer bits/element (4.25 vs 4.50). K=3 offers 3.25 bits/element (4.9x compression) at 15 dB SQNR. K=2 pushes to 7.1x compression for extreme quantization.
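SQNR here is the usual 10·log10 ratio of signal power to quantization-noise power. A sketch of how such a number could be measured with the new API (not the shipped benchmark code; requires a CUDA build with these kernels):

```python
import torch
from bitsandbytes.functional import quantize_kbit, dequantize_kbit

def kbit_sqnr_db(k: int, n: int = 1 << 20) -> float:
    """SQNR = 10*log10(E[x^2] / E[(x - x_hat)^2]) for a standard-normal input."""
    x = torch.randn(n, device="cuda", dtype=torch.float32)
    packed, absmax, codebook = quantize_kbit(x, k=k)
    x_hat = dequantize_kbit(packed, absmax, codebook, k=k, n=n, dtype=torch.float32)
    mse = (x - x_hat).pow(2).mean()
    return (10.0 * torch.log10(x.pow(2).mean() / mse)).item()
```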

Storage comparison

| Method | bits/elem | Compression vs fp16 |
|--------|-----------|---------------------|
| NF4 bs=64 | 4.50 | 3.6x |
| kbit K=5 | 5.25 | 3.0x |
| kbit K=4 | 4.25 | 3.8x |
| kbit K=3 | 3.25 | 4.9x |
| kbit K=2 | 2.25 | 7.1x |

Key design decisions

| Decision | Choice | Rationale |
|----------|--------|-----------|
| Blocksize | 32 (fixed) | Maps exactly to CUDA warp width for __ballot_sync/__shfl_sync |
| K range | 2-5 | Codebook fits in warp lanes (2^K <= 32) |
| Packing | Bit-plane via __ballot_sync | Zero waste for any K; sequential packing has word-boundary issues for odd K |
| Codebook lookup | __shfl_sync from lane registers | Register-to-register, no shared memory |
| Absmax format | E4M4 uint8 (default), fp16 (option) | 1 byte/block vs 2-4 bytes; <0.4 dB SQNR loss |
| Codebook | Symmetric normal-float default, arbitrary user-provided | Maximum flexibility; NF codebooks precomputed in Python |
| Output types | fp16, bf16, fp32 (templated) | Feature parity with existing kernels |
| Target hardware | CC 7.0+ (Volta) | Required for _sync warp primitives |

API

```python
import torch

from bitsandbytes.functional import quantize_kbit, dequantize_kbit

# Quantize (default: symmetric normal-float codebook, E4M4 absmax)
packed, absmax, codebook = quantize_kbit(A, k=4)

# Dequantize
recovered = dequantize_kbit(packed, absmax, codebook, k=4, n=A.numel(), dtype=torch.float16)

# Custom codebook
my_cb = torch.tensor([...], dtype=torch.float32, device="cuda")
packed, absmax, cb = quantize_kbit(A, k=3, codebook=my_cb)
```

The default codebook is a symmetric normal-float codebook (quantiles of N(0,1), normalized to [-1, 1]). Unlike the existing NF4 codebook which is asymmetric (7 negative, 1 zero, 8 positive), this codebook has equal representation on both sides (8 negative, 8 positive, no explicit zero). Custom codebooks of any shape can be passed via the codebook parameter.
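For reference, one plausible construction of such a codebook; the shipped `create_normal_float_codebook` may use a different quantile spacing:

```python
import torch
from torch.distributions import Normal

def symmetric_normal_float_codebook(k: int) -> torch.Tensor:
    """2^k symmetric quantiles of N(0, 1), rescaled so the extremes sit at +/-1.

    Using midpoint probabilities (i + 0.5) / 2^k keeps the grid symmetric around
    zero with no explicit zero entry and equal counts of negative and positive values.
    """
    n = 2 ** k
    p = (torch.arange(n, dtype=torch.float64) + 0.5) / n
    q = Normal(0.0, 1.0).icdf(p)
    return (q / q.abs().max()).to(torch.float32)

cb = symmetric_normal_float_codebook(4)   # 16 entries, 8 negative / 8 positive
```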

Test plan

222 tests covering:

  • Pure Python reference implementation
  • CUDA quantize/dequant correctness against reference
  • Error analysis (analytical bounds, MSE scaling, dtype consistency)
  • Cross-validation against existing NF4 at K=4
  • Performance benchmarks (bandwidth utilization, scaling, NF4 comparison)
  • E4M4 absmax (encode/decode round-trip, SQNR degradation, edge cases)
  • Output dtype correctness (bf16/fp32 match fp16 within precision)
  • Asymmetric codebooks (all-positive, all-negative, skewed, non-uniform, duplicate entries)
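As a flavor of what the correctness tests check, a minimal round-trip test in the style of tests/test_kbit_quantization.py (illustrative only; the real suite is far more thorough, and the bound mirrors the one described in the commit messages below):

```python
import pytest
import torch
from bitsandbytes.functional import quantize_kbit, dequantize_kbit

@pytest.mark.parametrize("k", [2, 3, 4, 5])
def test_kbit_round_trip_error_bound(k):
    # Worst-case per-element error: half the largest codebook gap (nearest-entry
    # rounding) plus up to 1/16 of absmax from E4M4 scale encoding.
    x = torch.randn(4096, device="cuda", dtype=torch.float32)
    packed, absmax, codebook = quantize_kbit(x, k=k)
    x_hat = dequantize_kbit(packed, absmax, codebook, k=k, n=x.numel(), dtype=torch.float32)
    max_gap = codebook.sort().values.diff().max()
    block_absmax = x.abs().view(-1, 32).amax(dim=1)
    bound = (max_gap / 2 + 1.0 / 16.0) * block_absmax
    block_err = (x - x_hat).abs().view(-1, 32).amax(dim=1)
    assert torch.all(block_err <= bound * 1.01)
```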

Files changed

  • csrc/ops.cu — CUDA kernels and launchers (+229 lines)
  • csrc/pythonInterface.cpp — C interface wrappers (+125 lines)
  • bitsandbytes/functional.py — Python API (+194 lines)
  • bitsandbytes/_ops.py — torch.library op definitions (+44 lines)
  • bitsandbytes/backends/cuda/ops.py — CUDA backend dispatch (+90 lines)
  • tests/test_kbit_quantization.py — Test suite (+1372 lines)

TimDettmers and others added 7 commits February 14, 2026 00:50
Implements Stages 0-5 of the k-bit quantization plan from cuda-spec.md:
- Pure Python reference (quantize_kbit_ref, dequantize_kbit_ref) with 57 passing tests
- CUDA kernels using __ballot_sync bit-plane packing and __shfl_sync codebook lookup
- Test kernels (pack/unpack, memory format, codebook lookup) and production kernels
- All C interface symbols exported and loadable via ctypes

CUDA kernels compile but are not yet executable due to an RDC device
linking issue where template instantiations in kernels.cu are not
pulled into the final fatbinary. See KBIT_PROGRESS.md for diagnosis
and recommended fix (move kernel bodies into ops.cu or a new self-contained file).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

The "invalid device function" error was caused by mismatched kernel
declarations in kernels.cuh (without __restrict__) vs definitions in
ops.cu (with __restrict__). With CUDA separable compilation (-rdc=true),
this created conflicting host stubs in the function registration.

Fix: remove forward declarations from kernels.cuh, keep kernel
definitions and launch wrappers together in ops.cu. Also added
CUDA_RESOLVE_DEVICE_SYMBOLS ON to CMakeLists.txt.

All 157 tests now pass: Stage 0 (Python ref), Stages 1-3 (CUDA test
kernels), Stage 4 (quantize), Stage 5 (dequantize) -- covering K=2-5,
fp16/bf16/fp32, various tensor sizes, and analytical error bounds.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

- Stage 6: Error analysis on 1M+ elements (analytical bounds, MSE, SQNR)
- Stage 7: Cross-validation against existing NF4 dequant
- Stage 8: Performance benchmarks (bandwidth utilization, throughput scaling)
- Python API: quantize_kbit(), dequantize_kbit(), create_normal_float_codebook()
  in functional.py with torch.library registration in _ops.py and CUDA
  kernel dispatch in backends/cuda/ops.py
- Codebook caching per (k, device) pair

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Not needed in the final branch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Vectorized dequant kernel (half2 stores, 4 blocks/warp) gives 1.23-1.29x
speedup over scalar kernel, reaching 80-87% of peak HBM bandwidth.
Routes fp16 output through vectorized path; bf16/fp32 use scalar fallback.

E4M4 uint8 absmax (bias=11, IEEE-style subnormals) reduces absmax storage
from 4 bytes to 1 byte per block. K=4 drops from 5.0 to 4.25 bits/elem,
matching NF4 bs=64 storage. SQNR degradation is <0.4 dB across all K
values. Decode uses direct IEEE 754 bit construction for zero overhead
on the dequant hot path.

240 tests passing (22 new E4M4 tests).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

- Remove scalar dequant kernel (vectorized is strictly better)
- Remove fp32 absmax dequant path; E4M4 uint8 is now the default,
  fp16 absmax kept as an option
- Remove Stage 1-3 test scaffolding kernels (pack/unpack, memory
  format, codebook lookup) and their C wrappers
- Dequant always produces fp16 at the CUDA level; bf16/fp32 output
  via cast in Python
- Net removal of 334 lines; 188 tests passing

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Replace half2-specific vectorized kernel with a generic version
templated on T (output type) and ABSMAX_T (absmax format). Scalar
stores via (T)val; hardware coalesces warp writes. No fp16 regression
(within benchmark noise). Native bf16 and fp32 output at the kernel
level — no Python-side cast needed.

Add output dtype correctness tests (bf16/fp32 match fp16) and
asymmetric codebook tests (all-positive, all-negative, skewed,
non-uniform spacing, duplicate entries). 222 tests passing.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

TimDettmers and others added 2 commits February 14, 2026 01:04
Apply ruff lint fix (unused variable), ruff format, and clang-format
to pass CI pre-commit hooks.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

The error bound was using a flat 1.25x multiplier on the quantization
error, but E4M4 absmax quantization adds up to 1/16 (6.25%) absolute
scale error. For K=5 where the codebook gap is ~0.0625, this E4M4
error is 2x the quantization error itself, exceeding the 1.25x margin.

Fix by computing the bound correctly as (max_gap/2 + 1/16) * absmax,
which adds both error sources instead of scaling one by a fixed factor.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>