diff --git a/Cargo.lock b/Cargo.lock index 8ce19ebc21f..0e41213df3a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -10067,7 +10067,9 @@ dependencies = [ "fastlanes", "mimalloc", "parquet 58.0.0", + "paste", "rand 0.10.0", + "rand_distr 0.6.0", "serde_json", "tokio", "tracing", @@ -10097,6 +10099,7 @@ dependencies = [ "vortex-sequence", "vortex-session", "vortex-sparse", + "vortex-turboquant", "vortex-utils", "vortex-zigzag", "vortex-zstd", @@ -10257,6 +10260,7 @@ dependencies = [ "vortex-runend", "vortex-sequence", "vortex-sparse", + "vortex-turboquant", "vortex-utils", "vortex-zigzag", "vortex-zstd", @@ -10570,6 +10574,7 @@ dependencies = [ "vortex-sequence", "vortex-session", "vortex-sparse", + "vortex-turboquant", "vortex-utils", "vortex-zigzag", "vortex-zstd", @@ -10989,6 +10994,23 @@ dependencies = [ "web-sys", ] +[[package]] +name = "vortex-turboquant" +version = "0.1.0" +dependencies = [ + "half", + "prost 0.14.3", + "rand 0.10.0", + "rand_distr 0.6.0", + "rstest", + "vortex-array", + "vortex-buffer", + "vortex-error", + "vortex-fastlanes", + "vortex-session", + "vortex-utils", +] + [[package]] name = "vortex-utils" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 75353ca0b3a..5098c0a7db9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,7 @@ members = [ "encodings/zstd", "encodings/bytebool", "encodings/parquet-variant", + "encodings/turboquant", # Benchmarks "benchmarks/lance-bench", "benchmarks/compress-bench", @@ -285,6 +286,7 @@ vortex-sequence = { version = "0.1.0", path = "encodings/sequence", default-feat vortex-session = { version = "0.1.0", path = "./vortex-session", default-features = false } vortex-sparse = { version = "0.1.0", path = "./encodings/sparse", default-features = false } vortex-tensor = { version = "0.1.0", path = "./vortex-tensor", default-features = false } +vortex-turboquant = { version = "0.1.0", path = "./encodings/turboquant", default-features = false } vortex-utils = { version = "0.1.0", path = "./vortex-utils", 
default-features = false } vortex-zigzag = { version = "0.1.0", path = "./encodings/zigzag", default-features = false } vortex-zstd = { version = "0.1.0", path = "./encodings/zstd", default-features = false } diff --git a/encodings/turboquant/Cargo.toml b/encodings/turboquant/Cargo.toml new file mode 100644 index 00000000000..71504a71f82 --- /dev/null +++ b/encodings/turboquant/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "vortex-turboquant" +authors = { workspace = true } +categories = { workspace = true } +description = "Vortex TurboQuant vector quantization encoding" +edition = { workspace = true } +homepage = { workspace = true } +include = { workspace = true } +keywords = { workspace = true } +license = { workspace = true } +readme = { workspace = true } +repository = { workspace = true } +rust-version = { workspace = true } +version = { workspace = true } + +[lints] +workspace = true + +[dependencies] +half = { workspace = true } +prost = { workspace = true } +rand = { workspace = true } +vortex-array = { workspace = true } +vortex-buffer = { workspace = true } +vortex-error = { workspace = true } +vortex-fastlanes = { workspace = true } +vortex-session = { workspace = true } +vortex-utils = { workspace = true } + +[dev-dependencies] +rand_distr = { workspace = true } +rstest = { workspace = true } +vortex-array = { workspace = true, features = ["_test-harness"] } diff --git a/encodings/turboquant/public-api.lock b/encodings/turboquant/public-api.lock new file mode 100644 index 00000000000..f48b7834d6a --- /dev/null +++ b/encodings/turboquant/public-api.lock @@ -0,0 +1,185 @@ +pub mod vortex_turboquant + +pub struct vortex_turboquant::QjlCorrection + +impl vortex_turboquant::QjlCorrection + +pub fn vortex_turboquant::QjlCorrection::residual_norms(&self) -> &vortex_array::array::ArrayRef + +pub fn vortex_turboquant::QjlCorrection::rotation_signs(&self) -> &vortex_array::array::ArrayRef + +pub fn vortex_turboquant::QjlCorrection::signs(&self) -> 
&vortex_array::array::ArrayRef + +impl core::clone::Clone for vortex_turboquant::QjlCorrection + +pub fn vortex_turboquant::QjlCorrection::clone(&self) -> vortex_turboquant::QjlCorrection + +impl core::fmt::Debug for vortex_turboquant::QjlCorrection + +pub fn vortex_turboquant::QjlCorrection::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub struct vortex_turboquant::TurboQuant + +impl vortex_turboquant::TurboQuant + +pub const vortex_turboquant::TurboQuant::ID: vortex_array::vtable::dyn_::ArrayId + +impl core::clone::Clone for vortex_turboquant::TurboQuant + +pub fn vortex_turboquant::TurboQuant::clone(&self) -> vortex_turboquant::TurboQuant + +impl core::fmt::Debug for vortex_turboquant::TurboQuant + +pub fn vortex_turboquant::TurboQuant::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::arrays::dict::take::TakeExecute for vortex_turboquant::TurboQuant + +pub fn vortex_turboquant::TurboQuant::take(array: &vortex_turboquant::TurboQuantArray, indices: &vortex_array::array::ArrayRef, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> + +impl vortex_array::arrays::slice::SliceReduce for vortex_turboquant::TurboQuant + +pub fn vortex_turboquant::TurboQuant::slice(array: &vortex_turboquant::TurboQuantArray, range: core::ops::range::Range) -> vortex_error::VortexResult> + +impl vortex_array::vtable::VTable for vortex_turboquant::TurboQuant + +pub type vortex_turboquant::TurboQuant::Array = vortex_turboquant::TurboQuantArray + +pub type vortex_turboquant::TurboQuant::Metadata = vortex_array::metadata::ProstMetadata + +pub type vortex_turboquant::TurboQuant::OperationsVTable = vortex_turboquant::TurboQuant + +pub type vortex_turboquant::TurboQuant::ValidityVTable = vortex_array::vtable::validity::ValidityVTableFromChild + +pub fn vortex_turboquant::TurboQuant::array_eq(array: &vortex_turboquant::TurboQuantArray, other: &vortex_turboquant::TurboQuantArray, precision: 
vortex_array::hash::Precision) -> bool + +pub fn vortex_turboquant::TurboQuant::array_hash(array: &vortex_turboquant::TurboQuantArray, state: &mut H, precision: vortex_array::hash::Precision) + +pub fn vortex_turboquant::TurboQuant::buffer(_array: &vortex_turboquant::TurboQuantArray, idx: usize) -> vortex_array::buffer::BufferHandle + +pub fn vortex_turboquant::TurboQuant::buffer_name(_array: &vortex_turboquant::TurboQuantArray, _idx: usize) -> core::option::Option + +pub fn vortex_turboquant::TurboQuant::build(dtype: &vortex_array::dtype::DType, len: usize, metadata: &Self::Metadata, _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TurboQuant::child(array: &vortex_turboquant::TurboQuantArray, idx: usize) -> vortex_array::array::ArrayRef + +pub fn vortex_turboquant::TurboQuant::child_name(_array: &vortex_turboquant::TurboQuantArray, idx: usize) -> alloc::string::String + +pub fn vortex_turboquant::TurboQuant::deserialize(bytes: &[u8], _dtype: &vortex_array::dtype::DType, _len: usize, _buffers: &[vortex_array::buffer::BufferHandle], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TurboQuant::dtype(array: &vortex_turboquant::TurboQuantArray) -> &vortex_array::dtype::DType + +pub fn vortex_turboquant::TurboQuant::execute(array: alloc::sync::Arc>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TurboQuant::execute_parent(array: &vortex_array::vtable::typed::Array, parent: &vortex_array::array::ArrayRef, child_idx: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult> + +pub fn vortex_turboquant::TurboQuant::id(&self) -> vortex_array::vtable::dyn_::ArrayId + +pub fn vortex_turboquant::TurboQuant::len(array: &vortex_turboquant::TurboQuantArray) -> usize + +pub fn vortex_turboquant::TurboQuant::metadata(array: 
&vortex_turboquant::TurboQuantArray) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TurboQuant::nbuffers(_array: &vortex_turboquant::TurboQuantArray) -> usize + +pub fn vortex_turboquant::TurboQuant::nchildren(array: &vortex_turboquant::TurboQuantArray) -> usize + +pub fn vortex_turboquant::TurboQuant::reduce_parent(array: &vortex_array::vtable::typed::Array, parent: &vortex_array::array::ArrayRef, child_idx: usize) -> vortex_error::VortexResult> + +pub fn vortex_turboquant::TurboQuant::serialize(metadata: Self::Metadata) -> vortex_error::VortexResult>> + +pub fn vortex_turboquant::TurboQuant::stats(array: &vortex_turboquant::TurboQuantArray) -> vortex_array::stats::array::StatsSetRef<'_> + +pub fn vortex_turboquant::TurboQuant::vtable(_array: &Self::Array) -> &Self + +pub fn vortex_turboquant::TurboQuant::with_children(array: &mut Self::Array, children: alloc::vec::Vec) -> vortex_error::VortexResult<()> + +impl vortex_array::vtable::operations::OperationsVTable for vortex_turboquant::TurboQuant + +pub fn vortex_turboquant::TurboQuant::scalar_at(array: &vortex_turboquant::TurboQuantArray, index: usize, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +impl vortex_array::vtable::validity::ValidityChild for vortex_turboquant::TurboQuant + +pub fn vortex_turboquant::TurboQuant::validity_child(array: &vortex_turboquant::TurboQuantArray) -> &vortex_array::array::ArrayRef + +pub struct vortex_turboquant::TurboQuantArray + +impl vortex_turboquant::TurboQuantArray + +pub fn vortex_turboquant::TurboQuantArray::bit_width(&self) -> u8 + +pub fn vortex_turboquant::TurboQuantArray::centroids(&self) -> &vortex_array::array::ArrayRef + +pub fn vortex_turboquant::TurboQuantArray::codes(&self) -> &vortex_array::array::ArrayRef + +pub fn vortex_turboquant::TurboQuantArray::dimension(&self) -> u32 + +pub fn vortex_turboquant::TurboQuantArray::has_qjl(&self) -> bool + +pub fn vortex_turboquant::TurboQuantArray::norms(&self) -> 
&vortex_array::array::ArrayRef + +pub fn vortex_turboquant::TurboQuantArray::padded_dim(&self) -> u32 + +pub fn vortex_turboquant::TurboQuantArray::qjl(&self) -> core::option::Option<&vortex_turboquant::QjlCorrection> + +pub fn vortex_turboquant::TurboQuantArray::rotation_signs(&self) -> &vortex_array::array::ArrayRef + +pub fn vortex_turboquant::TurboQuantArray::try_new_mse(dtype: vortex_array::dtype::DType, codes: vortex_array::array::ArrayRef, norms: vortex_array::array::ArrayRef, centroids: vortex_array::array::ArrayRef, rotation_signs: vortex_array::array::ArrayRef, dimension: u32, bit_width: u8) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TurboQuantArray::try_new_qjl(dtype: vortex_array::dtype::DType, codes: vortex_array::array::ArrayRef, norms: vortex_array::array::ArrayRef, centroids: vortex_array::array::ArrayRef, rotation_signs: vortex_array::array::ArrayRef, qjl: vortex_turboquant::QjlCorrection, dimension: u32, bit_width: u8) -> vortex_error::VortexResult + +impl vortex_turboquant::TurboQuantArray + +pub fn vortex_turboquant::TurboQuantArray::to_array(&self) -> vortex_array::array::ArrayRef + +impl core::clone::Clone for vortex_turboquant::TurboQuantArray + +pub fn vortex_turboquant::TurboQuantArray::clone(&self) -> vortex_turboquant::TurboQuantArray + +impl core::convert::AsRef for vortex_turboquant::TurboQuantArray + +pub fn vortex_turboquant::TurboQuantArray::as_ref(&self) -> &dyn vortex_array::array::DynArray + +impl core::convert::From for vortex_array::array::ArrayRef + +pub fn vortex_array::array::ArrayRef::from(value: vortex_turboquant::TurboQuantArray) -> vortex_array::array::ArrayRef + +impl core::fmt::Debug for vortex_turboquant::TurboQuantArray + +pub fn vortex_turboquant::TurboQuantArray::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::ops::deref::Deref for vortex_turboquant::TurboQuantArray + +pub type vortex_turboquant::TurboQuantArray::Target = dyn vortex_array::array::DynArray + +pub fn 
vortex_turboquant::TurboQuantArray::deref(&self) -> &Self::Target + +impl vortex_array::array::IntoArray for vortex_turboquant::TurboQuantArray + +pub fn vortex_turboquant::TurboQuantArray::into_array(self) -> vortex_array::array::ArrayRef + +pub struct vortex_turboquant::TurboQuantConfig + +pub vortex_turboquant::TurboQuantConfig::bit_width: u8 + +pub vortex_turboquant::TurboQuantConfig::seed: core::option::Option + +impl core::clone::Clone for vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::clone(&self) -> vortex_turboquant::TurboQuantConfig + +impl core::default::Default for vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::default() -> Self + +impl core::fmt::Debug for vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub const vortex_turboquant::FIXED_SHAPE_TENSOR_EXT_ID: &str + +pub const vortex_turboquant::VECTOR_EXT_ID: &str + +pub fn vortex_turboquant::initialize(session: &mut vortex_session::VortexSession) + +pub fn vortex_turboquant::turboquant_encode_mse(fsl: &vortex_array::arrays::fixed_size_list::array::FixedSizeListArray, config: &vortex_turboquant::TurboQuantConfig) -> vortex_error::VortexResult + +pub fn vortex_turboquant::turboquant_encode_qjl(fsl: &vortex_array::arrays::fixed_size_list::array::FixedSizeListArray, config: &vortex_turboquant::TurboQuantConfig) -> vortex_error::VortexResult diff --git a/encodings/turboquant/src/array.rs b/encodings/turboquant/src/array.rs new file mode 100644 index 00000000000..682ced5a0bd --- /dev/null +++ b/encodings/turboquant/src/array.rs @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant array definition: stores quantized coordinate codes, norms, +//! centroids (codebook), rotation signs, and optional QJL correction fields. 
+ +use vortex_array::ArrayRef; +use vortex_array::dtype::DType; +use vortex_array::stats::ArrayStats; +use vortex_array::vtable; +use vortex_array::vtable::ArrayId; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +/// Encoding marker type for TurboQuant. +#[derive(Clone, Debug)] +pub struct TurboQuant; + +impl TurboQuant { + pub const ID: ArrayId = ArrayId::new_ref("vortex.turboquant"); +} + +vtable!(TurboQuant); + +/// Protobuf metadata for TurboQuant encoding. +#[derive(Clone, prost::Message)] +pub struct TurboQuantMetadata { + /// Vector dimension d. + #[prost(uint32, tag = "1")] + pub dimension: u32, + /// MSE bits per coordinate (1-8). + #[prost(uint32, tag = "2")] + pub bit_width: u32, + /// Whether QJL correction children are present. + #[prost(bool, tag = "3")] + pub has_qjl: bool, +} + +/// Optional QJL (Quantized Johnson-Lindenstrauss) correction for unbiased +/// inner product estimation. When present, adds 3 additional children. +#[derive(Clone, Debug)] +pub struct QjlCorrection { + /// Sign bits: `BoolArray`, length `num_rows * padded_dim`. + pub(crate) signs: ArrayRef, + /// Residual norms: `PrimitiveArray`, length `num_rows`. + pub(crate) residual_norms: ArrayRef, + /// QJL rotation signs: `BoolArray`, length `3 * padded_dim` (inverse order). + pub(crate) rotation_signs: ArrayRef, +} + +impl QjlCorrection { + /// The QJL sign bits. + pub fn signs(&self) -> &ArrayRef { + &self.signs + } + + /// The residual norms. + pub fn residual_norms(&self) -> &ArrayRef { + &self.residual_norms + } + + /// The QJL rotation signs (BoolArray, inverse application order). + pub fn rotation_signs(&self) -> &ArrayRef { + &self.rotation_signs + } +} + +/// TurboQuant array. 
+/// +/// Core children (always present): +/// - 0: `codes` — `BitPackedArray` or `PrimitiveArray` (quantized indices) +/// - 1: `norms` — `PrimitiveArray` (one per vector row) +/// - 2: `centroids` — `PrimitiveArray` (codebook, length 2^bit_width) +/// - 3: `rotation_signs` — `BoolArray` (3 * padded_dim bits, inverse application order) +/// +/// Optional QJL children (when `has_qjl` is true): +/// - 4: `qjl_signs` — `BoolArray` (num_rows * padded_dim bits) +/// - 5: `qjl_residual_norms` — `PrimitiveArray` (one per row) +/// - 6: `qjl_rotation_signs` — `BoolArray` (3 * padded_dim bits, QJL rotation, inverse order) +#[derive(Clone, Debug)] +pub struct TurboQuantArray { + pub(crate) dtype: DType, + pub(crate) codes: ArrayRef, + pub(crate) norms: ArrayRef, + pub(crate) centroids: ArrayRef, + pub(crate) rotation_signs: ArrayRef, + pub(crate) qjl: Option, + pub(crate) dimension: u32, + pub(crate) bit_width: u8, + pub(crate) stats_set: ArrayStats, +} + +impl TurboQuantArray { + /// Build a TurboQuant array with MSE-only encoding (no QJL correction). + #[allow(clippy::too_many_arguments)] + pub fn try_new_mse( + dtype: DType, + codes: ArrayRef, + norms: ArrayRef, + centroids: ArrayRef, + rotation_signs: ArrayRef, + dimension: u32, + bit_width: u8, + ) -> VortexResult { + vortex_ensure!( + (1..=8).contains(&bit_width), + "MSE bit_width must be 1-8, got {bit_width}" + ); + Ok(Self { + dtype, + codes, + norms, + centroids, + rotation_signs, + qjl: None, + dimension, + bit_width, + stats_set: Default::default(), + }) + } + + /// Build a TurboQuant array with QJL correction (MSE + QJL). 
+ #[allow(clippy::too_many_arguments)] + pub fn try_new_qjl( + dtype: DType, + codes: ArrayRef, + norms: ArrayRef, + centroids: ArrayRef, + rotation_signs: ArrayRef, + qjl: QjlCorrection, + dimension: u32, + bit_width: u8, + ) -> VortexResult { + vortex_ensure!( + (1..=8).contains(&bit_width), + "MSE bit_width must be 1-8, got {bit_width}" + ); + Ok(Self { + dtype, + codes, + norms, + centroids, + rotation_signs, + qjl: Some(qjl), + dimension, + bit_width, + stats_set: Default::default(), + }) + } + + /// The vector dimension d. + pub fn dimension(&self) -> u32 { + self.dimension + } + + /// MSE bits per coordinate. + pub fn bit_width(&self) -> u8 { + self.bit_width + } + + /// Padded dimension (next power of 2 >= dimension). + pub fn padded_dim(&self) -> u32 { + self.dimension.next_power_of_two() + } + + /// Whether QJL correction is present. + pub fn has_qjl(&self) -> bool { + self.qjl.is_some() + } + + /// The quantized codes child. + pub fn codes(&self) -> &ArrayRef { + &self.codes + } + + /// The norms child. + pub fn norms(&self) -> &ArrayRef { + &self.norms + } + + /// The centroids (codebook) child. + pub fn centroids(&self) -> &ArrayRef { + &self.centroids + } + + /// The MSE rotation signs child (BoolArray, length 3 * padded_dim). + pub fn rotation_signs(&self) -> &ArrayRef { + &self.rotation_signs + } + + /// The optional QJL correction. + pub fn qjl(&self) -> Option<&QjlCorrection> { + self.qjl.as_ref() + } +} diff --git a/encodings/turboquant/src/centroids.rs b/encodings/turboquant/src/centroids.rs new file mode 100644 index 00000000000..4742cbab3a4 --- /dev/null +++ b/encodings/turboquant/src/centroids.rs @@ -0,0 +1,282 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Max-Lloyd centroid computation for TurboQuant scalar quantizers. +//! +//! Pre-computes optimal scalar quantizer centroids for the marginal distribution of coordinates +//! after random rotation of a unit-norm vector. 
In high dimensions, each coordinate of a randomly +//! rotated unit vector follows a distribution proportional to `(1 - x^2)^((d-3)/2)` on `[-1, 1]`, +//! which converges to `N(0, 1/d)`. The Max-Lloyd algorithm finds optimal quantization centroids +//! that minimize MSE for this distribution. + +use std::sync::LazyLock; + +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_utils::aliases::dash_map::DashMap; + +/// Number of numerical integration points for computing conditional expectations. +const INTEGRATION_POINTS: usize = 1000; + +/// Max-Lloyd convergence threshold. +const CONVERGENCE_EPSILON: f64 = 1e-12; + +/// Maximum iterations for Max-Lloyd algorithm. +const MAX_ITERATIONS: usize = 200; + +/// Global centroid cache keyed by (dimension, bit_width). +static CENTROID_CACHE: LazyLock>> = LazyLock::new(DashMap::default); + +/// Get or compute cached centroids for the given dimension and bit width. +/// +/// Returns `2^bit_width` centroids sorted in ascending order, representing +/// optimal scalar quantization levels for the coordinate distribution after +/// random rotation in `dimension`-dimensional space. +pub fn get_centroids(dimension: u32, bit_width: u8) -> VortexResult> { + if !(1..=8).contains(&bit_width) { + vortex_bail!("TurboQuant bit_width must be 1-8, got {bit_width}"); + } + if dimension < 2 { + vortex_bail!("TurboQuant dimension must be >= 2, got {dimension}"); + } + + if let Some(centroids) = CENTROID_CACHE.get(&(dimension, bit_width)) { + return Ok(centroids.clone()); + } + + let centroids = max_lloyd_centroids(dimension, bit_width); + CENTROID_CACHE.insert((dimension, bit_width), centroids.clone()); + Ok(centroids) +} + +/// Compute optimal centroids via the Max-Lloyd (Lloyd-Max) algorithm. +/// +/// Operates on the marginal distribution of a single coordinate of a randomly +/// rotated unit vector in d dimensions. 
The PDF is: +/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]` +/// where `C_d` is the normalizing constant. +fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Vec { + let num_centroids = 1usize << bit_width; + let dim = dimension as f64; + + // For the marginal distribution on [-1, 1], we use the exponent (d-3)/2. + let exponent = (dim - 3.0) / 2.0; + + // Initialize centroids uniformly on [-1, 1]. + let mut centroids: Vec = (0..num_centroids) + .map(|idx| -1.0 + (2.0 * (idx as f64) + 1.0) / (num_centroids as f64)) + .collect(); + + let mut boundaries: Vec = vec![0.0; num_centroids + 1]; + for _ in 0..MAX_ITERATIONS { + // Compute decision boundaries (midpoints between adjacent centroids). + boundaries[0] = -1.0; + for idx in 0..num_centroids - 1 { + boundaries[idx + 1] = (centroids[idx] + centroids[idx + 1]) / 2.0; + } + boundaries[num_centroids] = 1.0; + + // Update each centroid to the conditional mean within its Voronoi cell. + let mut max_change = 0.0f64; + for idx in 0..num_centroids { + let lo = boundaries[idx]; + let hi = boundaries[idx + 1]; + let new_centroid = conditional_mean(lo, hi, exponent); + max_change = max_change.max((new_centroid - centroids[idx]).abs()); + centroids[idx] = new_centroid; + } + + if max_change < CONVERGENCE_EPSILON { + break; + } + } + + centroids.into_iter().map(|val| val as f32).collect() +} + +/// Compute the conditional mean of the coordinate distribution on interval [lo, hi]. +/// +/// Returns `E[X | lo <= X <= hi]` where X has PDF proportional to `(1 - x^2)^exponent` +/// on [-1, 1]. 
+fn conditional_mean(lo: f64, hi: f64, exponent: f64) -> f64 { + if (hi - lo).abs() < 1e-15 { + return (lo + hi) / 2.0; + } + + let dx = (hi - lo) / INTEGRATION_POINTS as f64; + + let mut numerator = 0.0; + let mut denominator = 0.0; + + for step in 0..=INTEGRATION_POINTS { + let x_val = lo + (step as f64) * dx; + let weight = pdf_unnormalized(x_val, exponent); + + let trap_weight = if step == 0 || step == INTEGRATION_POINTS { + 0.5 + } else { + 1.0 + }; + + numerator += trap_weight * x_val * weight; + denominator += trap_weight * weight; + } + + if denominator.abs() < 1e-30 { + (lo + hi) / 2.0 + } else { + numerator / denominator + } +} + +/// Unnormalized PDF of the coordinate distribution: `(1 - x^2)^exponent`. +/// +/// Uses `powi` + `sqrt` instead of `powf` for the half-integer exponents +/// that arise from `(d-3)/2`. This is significantly faster than the general +/// `powf` which goes through `exp(exponent * ln(base))`. +#[inline] +fn pdf_unnormalized(x_val: f64, exponent: f64) -> f64 { + let base = (1.0 - x_val * x_val).max(0.0); + + let int_part = exponent as i32; + let frac = exponent - int_part as f64; + if frac.abs() < 1e-10 { + // Integer exponent: use powi. + base.powi(int_part) + } else { + // Half-integer exponent: powi(floor) * sqrt(base). + base.powi(int_part) * base.sqrt() + } +} + +/// Precompute decision boundaries (midpoints between adjacent centroids). +/// +/// For `k` centroids, returns `k-1` boundaries. A value below `boundaries[0]` maps +/// to centroid 0, a value in `[boundaries[i-1], boundaries[i])` maps to centroid `i`, +/// and a value >= `boundaries[k-2]` maps to centroid `k-1`. +pub fn compute_boundaries(centroids: &[f32]) -> Vec { + centroids.windows(2).map(|w| (w[0] + w[1]) * 0.5).collect() +} + +/// Find the index of the nearest centroid using precomputed decision boundaries. +/// +/// `boundaries` must be the output of [`compute_boundaries`] for the corresponding +/// centroids. 
Uses binary search on the midpoints, avoiding distance comparisons +/// in the inner loop. +#[inline] +pub fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 { + debug_assert!( + boundaries.windows(2).all(|w| w[0] <= w[1]), + "boundaries must be sorted" + ); + + boundaries.partition_point(|&b| b < value) as u8 +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_error::VortexResult; + + use super::*; + + #[rstest] + #[case(128, 1, 2)] + #[case(128, 2, 4)] + #[case(128, 3, 8)] + #[case(128, 4, 16)] + #[case(768, 2, 4)] + #[case(1536, 3, 8)] + fn centroids_have_correct_count( + #[case] dim: u32, + #[case] bits: u8, + #[case] expected: usize, + ) -> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + assert_eq!(centroids.len(), expected); + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 2)] + #[case(128, 3)] + #[case(128, 4)] + #[case(768, 2)] + fn centroids_are_sorted(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + for window in centroids.windows(2) { + assert!( + window[0] < window[1], + "centroids not sorted: {:?}", + centroids + ); + } + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 2)] + #[case(256, 2)] + #[case(768, 2)] + fn centroids_are_symmetric(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + let count = centroids.len(); + for idx in 0..count / 2 { + let diff = (centroids[idx] + centroids[count - 1 - idx]).abs(); + assert!( + diff < 1e-5, + "centroids not symmetric: c[{idx}]={}, c[{}]={}", + centroids[idx], + count - 1 - idx, + centroids[count - 1 - idx] + ); + } + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 4)] + fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { + let centroids = get_centroids(dim, bits)?; + for &val in ¢roids { + assert!( + (-1.0..=1.0).contains(&val), + "centroid out of [-1, 1]: {val}", + ); + } + Ok(()) + } + + #[test] 
+ fn centroids_cached() -> VortexResult<()> { + let c1 = get_centroids(128, 2)?; + let c2 = get_centroids(128, 2)?; + assert_eq!(c1, c2); + Ok(()) + } + + #[test] + fn find_nearest_basic() -> VortexResult<()> { + let centroids = get_centroids(128, 2)?; + let boundaries = compute_boundaries(¢roids); + assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0); + + let last_idx = (centroids.len() - 1) as u8; + assert_eq!(find_nearest_centroid(1.0, &boundaries), last_idx); + for (idx, &cv) in centroids.iter().enumerate() { + let expected = idx as u8; + assert_eq!(find_nearest_centroid(cv, &boundaries), expected); + } + Ok(()) + } + + #[test] + fn rejects_invalid_params() { + assert!(get_centroids(128, 0).is_err()); + assert!(get_centroids(128, 9).is_err()); + assert!(get_centroids(1, 2).is_err()); + } +} diff --git a/encodings/turboquant/src/compress.rs b/encodings/turboquant/src/compress.rs new file mode 100644 index 00000000000..30c5c5fbd45 --- /dev/null +++ b/encodings/turboquant/src/compress.rs @@ -0,0 +1,341 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant encoding (quantization) logic. + +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; +use vortex_fastlanes::bitpack_compress::bitpack_encode; + +use crate::array::QjlCorrection; +use crate::array::TurboQuantArray; +use crate::centroids::compute_boundaries; +use crate::centroids::find_nearest_centroid; +use crate::centroids::get_centroids; +use crate::rotation::RotationMatrix; + +/// Configuration for TurboQuant encoding. +#[derive(Clone, Debug)] +pub struct TurboQuantConfig { + /// Bits per coordinate. 
+ /// + /// For MSE encoding: 1-8. + /// For QJL encoding: 2-9 (the MSE component uses `bit_width - 1`). + pub bit_width: u8, + /// Optional seed for the rotation matrix. If None, the default seed is used. + pub seed: Option, +} + +impl Default for TurboQuantConfig { + fn default() -> Self { + Self { + bit_width: 5, + seed: Some(42), + } + } +} + +/// Extract elements from a FixedSizeListArray as a flat f32 PrimitiveArray. +#[allow(clippy::cast_possible_truncation)] +fn extract_f32_elements(fsl: &FixedSizeListArray) -> VortexResult { + let elements = fsl.elements(); + let primitive = elements.to_canonical()?.into_primitive(); + let ptype = primitive.ptype(); + + match ptype { + PType::F16 => Ok(primitive + .as_slice::() + .iter() + .map(|&v| f32::from(v)) + .collect()), + PType::F32 => Ok(primitive), + PType::F64 => Ok(primitive + .as_slice::() + .iter() + .map(|&v| v as f32) + .collect()), + _ => vortex_bail!("TurboQuant requires float elements, got {ptype:?}"), + } +} + +/// Compute the L2 norm of a vector. +#[inline] +fn l2_norm(x: &[f32]) -> f32 { + x.iter().map(|&v| v * v).sum::().sqrt() +} + +/// Shared intermediate results from the MSE quantization loop. +struct MseQuantizationResult { + rotation: RotationMatrix, + f32_elements: PrimitiveArray, + centroids: Vec, + all_indices: BufferMut, + norms: BufferMut, + padded_dim: usize, +} + +/// Core quantization: extract f32 elements, build rotation, normalize/rotate/quantize all rows. 
+fn turboquant_quantize_core( + fsl: &FixedSizeListArray, + seed: u64, + bit_width: u8, +) -> VortexResult { + let dimension = fsl.list_size() as usize; + let num_rows = fsl.len(); + + let rotation = RotationMatrix::try_new(seed, dimension)?; + let padded_dim = rotation.padded_dim(); + + let f32_elements = extract_f32_elements(fsl)?; + + let centroids = get_centroids(padded_dim as u32, bit_width)?; + let boundaries = compute_boundaries(¢roids); + + let mut all_indices = BufferMut::::with_capacity(num_rows * padded_dim); + let mut norms = BufferMut::::with_capacity(num_rows); + let mut padded = vec![0.0f32; padded_dim]; + let mut rotated = vec![0.0f32; padded_dim]; + + let f32_slice = f32_elements.as_slice::(); + for row in 0..num_rows { + let x = &f32_slice[row * dimension..(row + 1) * dimension]; + let norm = l2_norm(x); + norms.push(norm); + + if norm > 0.0 { + let inv_norm = 1.0 / norm; + for (dst, &src) in padded[..dimension].iter_mut().zip(x.iter()) { + *dst = src * inv_norm; + } + } else { + padded[..dimension].fill(0.0); + } + rotation.rotate(&padded, &mut rotated); + + for j in 0..padded_dim { + all_indices.push(find_nearest_centroid(rotated[j], &boundaries)); + } + } + + Ok(MseQuantizationResult { + rotation, + f32_elements, + centroids, + all_indices, + norms, + padded_dim, + }) +} + +/// Build a `TurboQuantArray` (MSE-only) from quantization results. +fn build_turboquant_mse( + fsl: &FixedSizeListArray, + core: MseQuantizationResult, + bit_width: u8, +) -> VortexResult { + let dimension = fsl.list_size(); + + let num_rows = fsl.len(); + let padded_dim = core.padded_dim; + let codes_elements = + PrimitiveArray::new::(core.all_indices.freeze(), Validity::NonNullable); + let codes = FixedSizeListArray::try_new( + codes_elements.into_array(), + padded_dim as u32, + Validity::NonNullable, + num_rows, + )? 
+ .into_array(); + let norms_array = + PrimitiveArray::new::(core.norms.freeze(), Validity::NonNullable).into_array(); + + // TODO(perf): `get_centroids` returns Vec; could avoid the copy by + // supporting Buffer::from(Vec) or caching as Buffer directly. + let mut centroids_buf = BufferMut::::with_capacity(core.centroids.len()); + centroids_buf.extend_from_slice(&core.centroids); + let centroids_array = + PrimitiveArray::new::(centroids_buf.freeze(), Validity::NonNullable).into_array(); + + let rotation_signs = bitpack_rotation_signs(&core.rotation)?; + + TurboQuantArray::try_new_mse( + fsl.dtype().clone(), + codes, + norms_array, + centroids_array, + rotation_signs, + dimension, + bit_width, + ) +} + +/// Encode a FixedSizeListArray into a MSE-only `TurboQuantArray`. +/// +/// The input must be non-nullable. TurboQuant is a lossy encoding that does not +/// preserve null positions; callers must handle validity externally. +pub fn turboquant_encode_mse( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, +) -> VortexResult { + vortex_ensure!( + fsl.dtype().nullability() == Nullability::NonNullable, + "TurboQuant requires non-nullable input, got nullable FixedSizeListArray" + ); + vortex_ensure!( + config.bit_width >= 1 && config.bit_width <= 8, + "MSE bit_width must be 1-8, got {}", + config.bit_width + ); + let dimension = fsl.list_size(); + vortex_ensure!( + dimension >= 2, + "TurboQuant requires dimension >= 2, got {dimension}" + ); + + if fsl.is_empty() { + return Ok(fsl.clone().into_array()); + } + + let seed = config.seed.unwrap_or(42); + let core = turboquant_quantize_core(fsl, seed, config.bit_width)?; + + Ok(build_turboquant_mse(fsl, core, config.bit_width)?.into_array()) +} + +/// Encode a FixedSizeListArray into a `TurboQuantArray` with QJL correction. +/// +/// The QJL variant uses `bit_width - 1` MSE bits plus 1 bit of QJL residual +/// correction, giving unbiased inner product estimation. The input must be +/// non-nullable. 
+pub fn turboquant_encode_qjl( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, +) -> VortexResult { + vortex_ensure!( + fsl.dtype().nullability() == Nullability::NonNullable, + "TurboQuant requires non-nullable input, got nullable FixedSizeListArray" + ); + vortex_ensure!( + config.bit_width >= 2 && config.bit_width <= 9, + "QJL bit_width must be 2-9, got {}", + config.bit_width + ); + let dimension = fsl.list_size(); + vortex_ensure!( + dimension >= 2, + "TurboQuant requires dimension >= 2, got {dimension}" + ); + + if fsl.is_empty() { + return Ok(fsl.clone().into_array()); + } + + let seed = config.seed.unwrap_or(42); + let dim = dimension as usize; + let mse_bit_width = config.bit_width - 1; + + let core = turboquant_quantize_core(fsl, seed, mse_bit_width)?; + let padded_dim = core.padded_dim; + + // QJL uses a different rotation than the MSE stage to ensure statistical + // independence between the quantization noise and the sign projection. + let qjl_rotation = RotationMatrix::try_new(seed.wrapping_add(25), dim)?; + + let num_rows = fsl.len(); + let mut residual_norms_buf = BufferMut::::with_capacity(num_rows); + let mut qjl_sign_u8 = BufferMut::::with_capacity(num_rows * padded_dim); + + let mut dequantized_rotated = vec![0.0f32; padded_dim]; + let mut dequantized = vec![0.0f32; padded_dim]; + let mut residual = vec![0.0f32; padded_dim]; + let mut projected = vec![0.0f32; padded_dim]; + + // Compute QJL residuals using precomputed indices and norms from the core. + { + let f32_slice = core.f32_elements.as_slice::(); + let indices_slice: &[u8] = &core.all_indices; + let norms_slice: &[f32] = &core.norms; + + for row in 0..num_rows { + let x = &f32_slice[row * dim..(row + 1) * dim]; + let norm = norms_slice[row]; + + // Dequantize from precomputed indices. 
+ let row_indices = &indices_slice[row * padded_dim..(row + 1) * padded_dim]; + for j in 0..padded_dim { + dequantized_rotated[j] = core.centroids[row_indices[j] as usize]; + } + + core.rotation + .inverse_rotate(&dequantized_rotated, &mut dequantized); + if norm > 0.0 { + for val in dequantized[..dim].iter_mut() { + *val *= norm; + } + } + + // Compute residual: r = x - x̂. + for j in 0..dim { + residual[j] = x[j] - dequantized[j]; + } + let residual_norm = l2_norm(&residual[..dim]); + residual_norms_buf.push(residual_norm); + + // QJL: sign(S · r). + if residual_norm > 0.0 { + qjl_rotation.rotate(&residual, &mut projected); + } else { + projected.fill(0.0); + } + + for j in 0..padded_dim { + qjl_sign_u8.push(if projected[j] >= 0.0 { 1u8 } else { 0u8 }); + } + } + } + + // Build the MSE part. + let mut array = build_turboquant_mse(fsl, core, mse_bit_width)?; + + // Attach QJL correction. + let residual_norms_array = + PrimitiveArray::new::(residual_norms_buf.freeze(), Validity::NonNullable); + let qjl_signs_elements = PrimitiveArray::new::(qjl_sign_u8.freeze(), Validity::NonNullable); + let qjl_signs = FixedSizeListArray::try_new( + qjl_signs_elements.into_array(), + padded_dim as u32, + Validity::NonNullable, + num_rows, + )?; + let qjl_rotation_signs = bitpack_rotation_signs(&qjl_rotation)?; + + array.qjl = Some(QjlCorrection { + signs: qjl_signs.into_array(), + residual_norms: residual_norms_array.into_array(), + rotation_signs: qjl_rotation_signs, + }); + + Ok(array.into_array()) +} + +/// Export rotation signs as a 1-bit `BitPackedArray` for efficient storage. +/// +/// The rotation matrix's 3 × padded_dim sign values are exported as 0/1 u8 +/// values in inverse application order, then bitpacked to 1 bit per sign. +/// On decode, FastLanes SIMD-unpacks back to `&[u8]` of 0/1 values. 
+fn bitpack_rotation_signs(rotation: &RotationMatrix) -> VortexResult<ArrayRef> {
+    let signs_u8 = rotation.export_inverse_signs_u8();
+    let mut buf = BufferMut::<u8>::with_capacity(signs_u8.len());
+    buf.extend_from_slice(&signs_u8);
+    let prim = PrimitiveArray::new::<u8>(buf.freeze(), Validity::NonNullable);
+    // 1 bit per sign; `None` = no patches needed since all values are 0/1.
+    Ok(bitpack_encode(&prim, 1, None)?.into_array())
+}
diff --git a/encodings/turboquant/src/compute/cosine_similarity.rs b/encodings/turboquant/src/compute/cosine_similarity.rs
new file mode 100644
index 00000000000..63ebac99d4e
--- /dev/null
+++ b/encodings/turboquant/src/compute/cosine_similarity.rs
@@ -0,0 +1,76 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+//! Approximate cosine similarity in the quantized domain.
+//!
+//! Since the SRHT is orthogonal, inner products are preserved in the rotated
+//! domain. For two vectors from the same TurboQuant column (same rotation and
+//! centroids), we can compute the dot product of their quantized representations
+//! without full decompression:
+//!
+//! ```text
+//! cos(a, b) = dot(a, b) / (||a|| × ||b||)
+//!           = ||a|| × ||b|| × dot(â_rot, b̂_rot) / (||a|| × ||b||)
+//!           = sum(centroids[code_a[j]] × centroids[code_b[j]])
+//! ```
+//!
+//! where `â_rot` and `b̂_rot` are the quantized unit-norm rotated vectors.
+
+use vortex_array::ExecutionCtx;
+use vortex_array::arrays::FixedSizeListArray;
+use vortex_array::arrays::PrimitiveArray;
+use vortex_error::VortexResult;
+
+use crate::array::TurboQuantArray;
+
+/// Compute approximate cosine similarity between two rows of a TurboQuant array
+/// without full decompression.
+///
+/// Both rows must come from the same array (same rotation matrix and codebook).
+/// The result has bounded error proportional to the quantization distortion.
+/// +/// TODO: Wire into `vortex-tensor` cosine_similarity scalar function dispatch +/// so that `cosine_similarity(Extension(TurboQuant), Extension(TurboQuant))` +/// short-circuits to this when both arguments share the same encoding. +#[allow(dead_code)] // TODO: wire into vortex-tensor cosine_similarity dispatch +pub fn cosine_similarity_quantized( + array: &TurboQuantArray, + row_a: usize, + row_b: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let pd = array.padded_dim() as usize; + + // Read norms — execute to handle cascade-compressed children. + let norms_prim = array.norms().clone().execute::(ctx)?; + let norms = norms_prim.as_slice::(); + let norm_a = norms[row_a]; + let norm_b = norms[row_b]; + + if norm_a == 0.0 || norm_b == 0.0 { + return Ok(0.0); + } + + // Read codes from the FixedSizeListArray → flat u8. + let codes_fsl = array.codes().clone().execute::(ctx)?; + let codes_prim = codes_fsl.elements().to_canonical()?.into_primitive(); + let all_codes = codes_prim.as_slice::(); + + // Read centroids. + let centroids_prim = array.centroids().clone().execute::(ctx)?; + let c = centroids_prim.as_slice::(); + + let codes_a = &all_codes[row_a * pd..(row_a + 1) * pd]; + let codes_b = &all_codes[row_b * pd..(row_b + 1) * pd]; + + // Dot product of unit-norm quantized vectors in rotated domain. + // Since SRHT preserves inner products, this equals the dot product + // of the dequantized (but still unit-norm) vectors. + let dot: f32 = codes_a + .iter() + .zip(codes_b.iter()) + .map(|(&ca, &cb)| c[ca as usize] * c[cb as usize]) + .sum(); + + Ok(dot) +} diff --git a/encodings/turboquant/src/compute/l2_norm.rs b/encodings/turboquant/src/compute/l2_norm.rs new file mode 100644 index 00000000000..60aece9f98e --- /dev/null +++ b/encodings/turboquant/src/compute/l2_norm.rs @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! L2 norm direct readthrough for TurboQuant. +//! +//! 
TurboQuant stores the exact original L2 norm of each vector in the `norms` +//! child. This enables O(1) per-vector norm lookup without any decompression. + +use vortex_array::ArrayRef; + +use crate::array::TurboQuantArray; + +/// Return the stored norms directly — no decompression needed. +#[allow(dead_code)] // TODO: wire into vortex-tensor L2Norm dispatch +/// +/// The norms are computed before quantization, so they are exact (not affected +/// by the lossy encoding). The returned `ArrayRef` is a `PrimitiveArray` +/// with one element per vector row. +/// +/// TODO: Wire into `vortex-tensor` L2Norm scalar function dispatch so that +/// `l2_norm(Extension(TurboQuant(...)))` short-circuits to this. +pub fn l2_norm_direct(array: &TurboQuantArray) -> &ArrayRef { + array.norms() +} diff --git a/encodings/turboquant/src/compute/mod.rs b/encodings/turboquant/src/compute/mod.rs new file mode 100644 index 00000000000..1c249352d5e --- /dev/null +++ b/encodings/turboquant/src/compute/mod.rs @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Compute pushdown implementations for TurboQuant. 
+ +pub(crate) mod cosine_similarity; +pub(crate) mod l2_norm; +mod ops; +pub(crate) mod rules; +mod slice; +mod take; diff --git a/encodings/turboquant/src/compute/ops.rs b/encodings/turboquant/src/compute/ops.rs new file mode 100644 index 00000000000..9038371559b --- /dev/null +++ b/encodings/turboquant/src/compute/ops.rs @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::ExecutionCtx; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::slice::SliceReduce; +use vortex_array::scalar::Scalar; +use vortex_array::vtable::OperationsVTable; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; + +use crate::array::TurboQuant; +use crate::array::TurboQuantArray; + +impl OperationsVTable for TurboQuant { + fn scalar_at( + array: &TurboQuantArray, + index: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + // Slice to single row, decompress that one row. + let Some(sliced) = ::slice(array, index..index + 1)? 
else { + vortex_bail!("slice returned None for index {index}") + }; + let decoded = sliced.execute::(ctx)?; + decoded.scalar_at(0) + } +} diff --git a/encodings/turboquant/src/compute/rules.rs b/encodings/turboquant/src/compute/rules.rs new file mode 100644 index 00000000000..13cf20b1e19 --- /dev/null +++ b/encodings/turboquant/src/compute/rules.rs @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::arrays::dict::TakeExecuteAdaptor; +use vortex_array::arrays::slice::SliceReduceAdaptor; +use vortex_array::kernel::ParentKernelSet; +use vortex_array::optimizer::rules::ParentRuleSet; + +use crate::array::TurboQuant; + +pub(crate) static RULES: ParentRuleSet = + ParentRuleSet::new(&[ParentRuleSet::lift(&SliceReduceAdaptor(TurboQuant))]); + +pub(crate) static PARENT_KERNELS: ParentKernelSet = + ParentKernelSet::new(&[ParentKernelSet::lift(&TakeExecuteAdaptor(TurboQuant))]); diff --git a/encodings/turboquant/src/compute/slice.rs b/encodings/turboquant/src/compute/slice.rs new file mode 100644 index 00000000000..b3702254ed6 --- /dev/null +++ b/encodings/turboquant/src/compute/slice.rs @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::ops::Range; + +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::arrays::slice::SliceReduce; +use vortex_error::VortexResult; + +use crate::array::QjlCorrection; +use crate::array::TurboQuant; +use crate::array::TurboQuantArray; + +impl SliceReduce for TurboQuant { + fn slice(array: &TurboQuantArray, range: Range) -> VortexResult> { + let sliced_codes = array.codes.slice(range.clone())?; + let sliced_norms = array.norms.slice(range.clone())?; + + let sliced_qjl = array + .qjl + .as_ref() + .map(|qjl| -> VortexResult { + Ok(QjlCorrection { + signs: qjl.signs.slice(range.clone())?, + residual_norms: qjl.residual_norms.slice(range.clone())?, + 
rotation_signs: qjl.rotation_signs.clone(), + }) + }) + .transpose()?; + + let mut result = TurboQuantArray::try_new_mse( + array.dtype.clone(), + sliced_codes, + sliced_norms, + array.centroids.clone(), + array.rotation_signs.clone(), + array.dimension, + array.bit_width, + )?; + result.qjl = sliced_qjl; + + Ok(Some(result.into_array())) + } +} diff --git a/encodings/turboquant/src/compute/take.rs b/encodings/turboquant/src/compute/take.rs new file mode 100644 index 00000000000..ddbc28d8cd9 --- /dev/null +++ b/encodings/turboquant/src/compute/take.rs @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::ArrayRef; +use vortex_array::DynArray; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::dict::TakeExecute; +use vortex_error::VortexResult; + +use crate::array::QjlCorrection; +use crate::array::TurboQuant; +use crate::array::TurboQuantArray; + +impl TakeExecute for TurboQuant { + fn take( + array: &TurboQuantArray, + indices: &ArrayRef, + _ctx: &mut ExecutionCtx, + ) -> VortexResult> { + // FSL children handle per-row take natively. 
+ let taken_codes = array.codes.take(indices.clone())?; + let taken_norms = array.norms.take(indices.clone())?; + + let taken_qjl = array + .qjl + .as_ref() + .map(|qjl| -> VortexResult { + Ok(QjlCorrection { + signs: qjl.signs.take(indices.clone())?, + residual_norms: qjl.residual_norms.take(indices.clone())?, + rotation_signs: qjl.rotation_signs.clone(), + }) + }) + .transpose()?; + + let mut result = TurboQuantArray::try_new_mse( + array.dtype.clone(), + taken_codes, + taken_norms, + array.centroids.clone(), + array.rotation_signs.clone(), + array.dimension, + array.bit_width, + )?; + result.qjl = taken_qjl; + + Ok(Some(result.into_array())) + } +} diff --git a/encodings/turboquant/src/decompress.rs b/encodings/turboquant/src/decompress.rs new file mode 100644 index 00000000000..c26d905df8c --- /dev/null +++ b/encodings/turboquant/src/decompress.rs @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant decoding (dequantization) logic. + +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; + +use crate::array::TurboQuantArray; +use crate::rotation::RotationMatrix; + +/// QJL correction scale factor: `sqrt(π/2) / padded_dim`. +/// +/// Accounts for the SRHT normalization (`1/padded_dim^{3/2}` per transform) +/// combined with `E[|z|] = sqrt(2/π)` for half-normal sign expectations. +/// Verified empirically via the `qjl_inner_product_bias` test suite. +#[inline] +fn qjl_correction_scale(padded_dim: usize) -> f32 { + (std::f32::consts::FRAC_PI_2).sqrt() / (padded_dim as f32) +} + +/// Decompress a `TurboQuantArray` into a `FixedSizeListArray` of floats. 
+/// +/// Reads stored centroids and rotation signs from the array's children, +/// avoiding any recomputation. If QJL correction is present, applies +/// the residual correction after MSE decoding. +pub fn execute_decompress( + array: TurboQuantArray, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let dim = array.dimension() as usize; + let padded_dim = array.padded_dim() as usize; + let num_rows = array.norms.len(); + + if num_rows == 0 { + let elements = PrimitiveArray::empty::(array.dtype.nullability()); + return Ok(FixedSizeListArray::try_new( + elements.into_array(), + array.dimension(), + Validity::NonNullable, + 0, + )? + .into_array()); + } + + // Read stored centroids — no recomputation. + let centroids_prim = array.centroids.clone().execute::(ctx)?; + let centroids = centroids_prim.as_slice::(); + + // FastLanes SIMD-unpacks the 1-bit bitpacked rotation signs into u8 0/1 values, + // then we expand to u32 XOR masks once (amortized over all rows). This enables + // branchless XOR-based sign application in the per-row SRHT hot loop. + let signs_prim = array + .rotation_signs + .clone() + .execute::(ctx)?; + let rotation = RotationMatrix::from_u8_slice(signs_prim.as_slice::(), dim)?; + + // Unpack codes from FixedSizeListArray → flat u8 elements. + let codes_fsl = array.codes.clone().execute::(ctx)?; + let codes_prim = codes_fsl.elements().to_canonical()?.into_primitive(); + let indices = codes_prim.as_slice::(); + + let norms_prim = array.norms.clone().execute::(ctx)?; + let norms = norms_prim.as_slice::(); + + // MSE decode: dequantize → inverse rotate → scale by norm. 
+ let mut mse_output = BufferMut::::with_capacity(num_rows * dim); + let mut dequantized = vec![0.0f32; padded_dim]; + let mut unrotated = vec![0.0f32; padded_dim]; + + for row in 0..num_rows { + let row_indices = &indices[row * padded_dim..(row + 1) * padded_dim]; + let norm = norms[row]; + + for idx in 0..padded_dim { + dequantized[idx] = centroids[row_indices[idx] as usize]; + } + + rotation.inverse_rotate(&dequantized, &mut unrotated); + + for idx in 0..dim { + unrotated[idx] *= norm; + } + + mse_output.extend_from_slice(&unrotated[..dim]); + } + + // If no QJL correction, we're done. + let Some(qjl) = &array.qjl else { + let elements = PrimitiveArray::new::(mse_output.freeze(), Validity::NonNullable); + return Ok(FixedSizeListArray::try_new( + elements.into_array(), + array.dimension(), + Validity::NonNullable, + num_rows, + )? + .into_array()); + }; + + // Apply QJL residual correction. + // Unpack QJL signs from FixedSizeListArray → flat u8 0/1 values. + let qjl_signs_fsl = qjl.signs.clone().execute::(ctx)?; + let qjl_signs_prim = qjl_signs_fsl.elements().to_canonical()?.into_primitive(); + let qjl_signs_u8 = qjl_signs_prim.as_slice::(); + + let residual_norms_prim = qjl.residual_norms.clone().execute::(ctx)?; + let residual_norms = residual_norms_prim.as_slice::(); + + let qjl_rot_signs_prim = qjl.rotation_signs.clone().execute::(ctx)?; + let qjl_rot = RotationMatrix::from_u8_slice(qjl_rot_signs_prim.as_slice::(), dim)?; + + let qjl_scale = qjl_correction_scale(padded_dim); + let mse_elements = mse_output.as_ref(); + + let mut output = BufferMut::::with_capacity(num_rows * dim); + let mut qjl_signs_vec = vec![0.0f32; padded_dim]; + let mut qjl_projected = vec![0.0f32; padded_dim]; + + for row in 0..num_rows { + let mse_row = &mse_elements[row * dim..(row + 1) * dim]; + let residual_norm = residual_norms[row]; + + // Convert u8 0/1 → f32 ±1.0 for this row's signs. 
+ let row_signs = &qjl_signs_u8[row * padded_dim..(row + 1) * padded_dim]; + for idx in 0..padded_dim { + qjl_signs_vec[idx] = if row_signs[idx] != 0 { 1.0 } else { -1.0 }; + } + + qjl_rot.inverse_rotate(&qjl_signs_vec, &mut qjl_projected); + let scale = qjl_scale * residual_norm; + + for idx in 0..dim { + output.push(mse_row[idx] + scale * qjl_projected[idx]); + } + } + + let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); + Ok(FixedSizeListArray::try_new( + elements.into_array(), + array.dimension(), + Validity::NonNullable, + num_rows, + )? + .into_array()) +} diff --git a/encodings/turboquant/src/lib.rs b/encodings/turboquant/src/lib.rs new file mode 100644 index 00000000000..6d71596d0aa --- /dev/null +++ b/encodings/turboquant/src/lib.rs @@ -0,0 +1,995 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +// Numerical truncations are intentional throughout this crate (dimension u32↔usize, +// f64→f32 centroids, partition_point→u8 indices, etc.). +#![allow(clippy::cast_possible_truncation)] + +//! TurboQuant vector quantization encoding for Vortex. +//! +//! Implements the TurboQuant algorithm ([arXiv:2504.19874]) for lossy compression of +//! high-dimensional vector data. The encoding operates on `FixedSizeList` arrays of floats +//! (the storage format of `Vector` and `FixedShapeTensor` extension types). +//! +//! [arXiv:2504.19874]: https://arxiv.org/abs/2504.19874 +//! +//! # Variants +//! +//! - **MSE** (`TurboQuantVariant::Mse`): Minimizes mean-squared reconstruction error +//! (1-8 bits per coordinate). +//! - **Prod** (`TurboQuantVariant::Prod`): Preserves inner products with an unbiased +//! estimator (uses `b-1` bits for MSE + 1-bit QJL residual correction, 2-9 bits). +//! At `b=9`, the MSE codes are raw int8 values suitable for direct use with +//! tensor core int8 GEMM kernels. +//! +//! # Theoretical error bounds +//! +//! 
For unit-norm vectors quantized at `b` bits per coordinate, the paper's Theorem 1 +//! guarantees normalized MSE distortion: +//! +//! > `E[||x - x̂||² / ||x||²] ≤ (√3 · π / 2) / 4^b` +//! +//! | Bits | MSE bound | Quality | +//! |------|------------|-------------------| +//! | 1 | 6.80e-01 | Poor | +//! | 2 | 1.70e-01 | Usable for ANN | +//! | 3 | 4.25e-02 | Good | +//! | 4 | 1.06e-02 | Very good | +//! | 5 | 2.66e-03 | Excellent | +//! | 6 | 6.64e-04 | Near-lossless | +//! | 7 | 1.66e-04 | Near-lossless | +//! | 8 | 4.15e-05 | Near-lossless | +//! +//! # Compression ratios +//! +//! Each vector is stored as `padded_dim × bit_width / 8` bytes of quantized codes plus a +//! 4-byte f32 norm. Non-power-of-2 dimensions are padded to the next power of 2 for the +//! Walsh-Hadamard transform, which reduces the effective ratio for those dimensions. +//! +//! | dim | padded | bits | f32 bytes | TQ bytes | ratio | +//! |------|--------|------|-----------|----------|--------| +//! | 768 | 1024 | 2 | 3072 | 260 | 11.8x | +//! | 1024 | 1024 | 2 | 4096 | 260 | 15.8x | +//! | 768 | 1024 | 4 | 3072 | 516 | 6.0x | +//! | 1024 | 1024 | 4 | 4096 | 516 | 7.9x | +//! | 768 | 1024 | 8 | 3072 | 1028 | 3.0x | +//! | 1024 | 1024 | 8 | 4096 | 1028 | 4.0x | +//! +//! # Example +//! +//! ``` +//! use vortex_array::IntoArray; +//! use vortex_array::arrays::FixedSizeListArray; +//! use vortex_array::arrays::PrimitiveArray; +//! use vortex_array::validity::Validity; +//! use vortex_buffer::BufferMut; +//! use vortex_turboquant::{TurboQuantConfig, turboquant_encode_mse}; +//! +//! // Create a FixedSizeListArray of 100 random 128-d vectors. +//! let num_rows = 100; +//! let dim = 128; +//! let mut buf = BufferMut::::with_capacity(num_rows * dim); +//! for i in 0..(num_rows * dim) { +//! buf.push((i as f32 * 0.001).sin()); +//! } +//! let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); +//! let fsl = FixedSizeListArray::try_new( +//! 
elements.into_array(), dim as u32, Validity::NonNullable, num_rows, +//! ).unwrap(); +//! +//! // Quantize at 2 bits per coordinate using MSE-optimal encoding. +//! let config = TurboQuantConfig { bit_width: 2, seed: Some(42) }; +//! let encoded = turboquant_encode_mse(&fsl, &config).unwrap(); +//! +//! // Verify compression: 100 vectors × 128 dims × 4 bytes = 51200 bytes input. +//! assert!(encoded.nbytes() < 51200); +//! ``` + +pub use array::QjlCorrection; +pub use array::TurboQuant; +pub use array::TurboQuantArray; +pub use compress::TurboQuantConfig; +pub use compress::turboquant_encode_mse; +pub use compress::turboquant_encode_qjl; + +mod array; +pub(crate) mod centroids; +mod compress; +mod compute; +pub(crate) mod decompress; +pub(crate) mod rotation; +mod vtable; + +/// Extension ID for the `Vector` type from `vortex-tensor`. +pub const VECTOR_EXT_ID: &str = "vortex.tensor.vector"; + +/// Extension ID for the `FixedShapeTensor` type from `vortex-tensor`. +pub const FIXED_SHAPE_TENSOR_EXT_ID: &str = "vortex.tensor.fixed_shape_tensor"; + +use vortex_array::session::ArraySessionExt; +use vortex_session::VortexSession; + +/// Initialize the TurboQuant encoding in the given session. 
+pub fn initialize(session: &mut VortexSession) {
+    session.arrays().register(TurboQuant);
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::LazyLock;
+
+    use rand::RngExt;
+    use rand::SeedableRng;
+    use rand::rngs::StdRng;
+    use rand_distr::Distribution;
+    use rand_distr::Normal;
+    use rstest::rstest;
+    use vortex_array::ArrayRef;
+    use vortex_array::IntoArray;
+    use vortex_array::VortexSessionExecute;
+    use vortex_array::arrays::FixedSizeListArray;
+    use vortex_array::arrays::PrimitiveArray;
+    use vortex_array::matcher::Matcher;
+    use vortex_array::session::ArraySession;
+    use vortex_array::validity::Validity;
+    use vortex_buffer::BufferMut;
+    use vortex_error::VortexResult;
+    use vortex_session::VortexSession;
+
+    use crate::TurboQuant;
+    use crate::TurboQuantConfig;
+    use crate::rotation::RotationMatrix;
+    use crate::turboquant_encode_mse;
+    use crate::turboquant_encode_qjl;
+
+    // Shared session with the TurboQuant encoding registered, built once.
+    static SESSION: LazyLock<VortexSession> =
+        LazyLock::new(|| VortexSession::empty().with::<TurboQuant>());
+
+    /// Create a FixedSizeListArray of random f32 vectors (i.i.d. standard normal).
+ fn make_fsl(num_rows: usize, dim: usize, seed: u64) -> FixedSizeListArray { + let mut rng = StdRng::seed_from_u64(seed); + let normal = Normal::new(0.0f32, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(normal.sample(&mut rng)); + } + + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new( + elements.into_array(), + dim as u32, + Validity::NonNullable, + num_rows, + ) + .unwrap() + } + + fn theoretical_mse_bound(bit_width: u8) -> f32 { + let sqrt3_pi_over_2 = (3.0f32).sqrt() * std::f32::consts::PI / 2.0; + sqrt3_pi_over_2 / (4.0f32).powi(bit_width as i32) + } + + fn per_vector_normalized_mse( + original: &[f32], + reconstructed: &[f32], + dim: usize, + num_rows: usize, + ) -> f32 { + let mut total = 0.0f32; + for row in 0..num_rows { + let orig = &original[row * dim..(row + 1) * dim]; + let recon = &reconstructed[row * dim..(row + 1) * dim]; + let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); + if norm_sq < 1e-10 { + continue; + } + let err_sq: f32 = orig + .iter() + .zip(recon.iter()) + .map(|(&a, &b)| (a - b) * (a - b)) + .sum(); + total += err_sq / norm_sq; + } + total / num_rows as f32 + } + + /// Encode and decode, returning (original, decoded) flat f32 slices. 
+ fn encode_decode( + fsl: &FixedSizeListArray, + encode_fn: impl FnOnce(&FixedSizeListArray) -> VortexResult, + ) -> VortexResult<(Vec, Vec)> { + let original: Vec = { + let prim = fsl.elements().to_canonical().unwrap().into_primitive(); + prim.as_slice::().to_vec() + }; + let encoded = encode_fn(fsl)?; + let mut ctx = SESSION.create_execution_ctx(); + let decoded = encoded.execute::(&mut ctx)?; + let decoded_elements: Vec = { + let prim = decoded.elements().to_canonical().unwrap().into_primitive(); + prim.as_slice::().to_vec() + }; + Ok((original, decoded_elements)) + } + + fn encode_decode_mse( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, + ) -> VortexResult<(Vec, Vec)> { + let config = config.clone(); + encode_decode(fsl, |fsl| turboquant_encode_mse(fsl, &config)) + } + + fn encode_decode_qjl( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, + ) -> VortexResult<(Vec, Vec)> { + let config = config.clone(); + encode_decode(fsl, |fsl| turboquant_encode_qjl(fsl, &config)) + } + + // ----------------------------------------------------------------------- + // MSE encoding tests + // ----------------------------------------------------------------------- + + #[rstest] + #[case(32, 1)] + #[case(32, 2)] + #[case(32, 3)] + #[case(32, 4)] + #[case(128, 2)] + #[case(128, 4)] + #[case(128, 6)] + #[case(128, 8)] + #[case(256, 2)] + fn roundtrip_mse(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let fsl = make_fsl(10, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + assert_eq!(decoded.len(), original.len()); + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 2)] + #[case(128, 3)] + #[case(128, 4)] + #[case(256, 2)] + #[case(256, 4)] + fn mse_within_theoretical_bound(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + let config = TurboQuantConfig { + 
bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + + let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + let bound = theoretical_mse_bound(bit_width); + + assert!( + normalized_mse < bound, + "Normalized MSE {normalized_mse:.6} exceeds bound {bound:.6} for dim={dim}, bits={bit_width}", + ); + Ok(()) + } + + #[rstest] + #[case(128, 6)] + #[case(128, 8)] + #[case(256, 6)] + #[case(256, 8)] + fn high_bitwidth_mse_is_small(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + + let config_4bit = TurboQuantConfig { + bit_width: 4, + seed: Some(123), + }; + let (original_4, decoded_4) = encode_decode_mse(&fsl, &config_4bit)?; + let mse_4bit = per_vector_normalized_mse(&original_4, &decoded_4, dim, num_rows); + + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + + assert!( + mse < mse_4bit, + "{bit_width}-bit MSE ({mse:.6}) should be < 4-bit MSE ({mse_4bit:.6})" + ); + assert!(mse < 0.01, "{bit_width}-bit MSE ({mse:.6}) should be < 1%"); + Ok(()) + } + + #[test] + fn mse_decreases_with_bits() -> VortexResult<()> { + let dim = 128; + let num_rows = 50; + let fsl = make_fsl(num_rows, dim, 99); + + let mut prev_mse = f32::MAX; + for bit_width in 1..=8u8 { + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + assert!( + mse <= prev_mse * 1.01, + "MSE should decrease: {bit_width}-bit={mse:.6} > prev={prev_mse:.6}" + ); + prev_mse = mse; + } + Ok(()) + } + + // ----------------------------------------------------------------------- + // QJL encoding tests + // 
----------------------------------------------------------------------- + + #[rstest] + #[case(32, 2)] + #[case(32, 3)] + #[case(128, 2)] + #[case(128, 4)] + #[case(128, 6)] + #[case(128, 8)] + #[case(128, 9)] + #[case(768, 3)] + fn roundtrip_qjl(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let fsl = make_fsl(10, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(456), + }; + let (original, decoded) = encode_decode_qjl(&fsl, &config)?; + assert_eq!(decoded.len(), original.len()); + Ok(()) + } + + #[rstest] + #[case(128, 2)] + #[case(128, 3)] + #[case(128, 4)] + #[case(128, 6)] + #[case(128, 8)] + #[case(128, 9)] + #[case(768, 3)] + #[case(768, 4)] + fn qjl_inner_product_bias(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 100; + let fsl = make_fsl(num_rows, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(789), + }; + let (original, decoded) = encode_decode_qjl(&fsl, &config)?; + + let num_pairs = 500; + let mut rng = StdRng::seed_from_u64(0); + let mut signed_errors = Vec::with_capacity(num_pairs); + + for _ in 0..num_pairs { + let qi = rng.random_range(0..num_rows); + let xi = rng.random_range(0..num_rows); + if qi == xi { + continue; + } + + let query = &original[qi * dim..(qi + 1) * dim]; + let orig_vec = &original[xi * dim..(xi + 1) * dim]; + let quant_vec = &decoded[xi * dim..(xi + 1) * dim]; + + let true_ip: f32 = query.iter().zip(orig_vec).map(|(&a, &b)| a * b).sum(); + let quant_ip: f32 = query.iter().zip(quant_vec).map(|(&a, &b)| a * b).sum(); + + if true_ip.abs() > 1e-6 { + signed_errors.push((quant_ip - true_ip) / true_ip.abs()); + } + } + + if signed_errors.is_empty() { + return Ok(()); + } + + let mean_rel_error: f32 = signed_errors.iter().sum::() / signed_errors.len() as f32; + assert!( + mean_rel_error.abs() < 0.3, + "QJL inner product bias too high: {mean_rel_error:.4} for dim={dim}, bits={bit_width}" + ); + Ok(()) + } + + #[test] + fn 
qjl_mse_decreases_with_bits() -> VortexResult<()> { + let dim = 128; + let num_rows = 50; + let fsl = make_fsl(num_rows, dim, 99); + + let mut prev_mse = f32::MAX; + for bit_width in 2..=9u8 { + let config = TurboQuantConfig { + bit_width, + seed: Some(123), + }; + let (original, decoded) = encode_decode_qjl(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + assert!( + mse <= prev_mse * 1.01, + "QJL MSE should decrease: {bit_width}-bit={mse:.6} > prev={prev_mse:.6}" + ); + prev_mse = mse; + } + Ok(()) + } + + // ----------------------------------------------------------------------- + // Edge cases + // ----------------------------------------------------------------------- + + #[rstest] + #[case(0)] + #[case(1)] + fn roundtrip_mse_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { + let fsl = make_fsl(num_rows, 128, 42); + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let mut ctx = SESSION.create_execution_ctx(); + let decoded = encoded.execute::(&mut ctx)?; + assert_eq!(decoded.len(), num_rows); + Ok(()) + } + + #[rstest] + #[case(0)] + #[case(1)] + fn roundtrip_qjl_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { + let fsl = make_fsl(num_rows, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(456), + }; + let encoded = turboquant_encode_qjl(&fsl, &config)?; + let mut ctx = SESSION.create_execution_ctx(); + let decoded = encoded.execute::(&mut ctx)?; + assert_eq!(decoded.len(), num_rows); + Ok(()) + } + + #[test] + fn mse_rejects_dimension_below_2() { + let fsl = make_fsl_dim1(); + let config = TurboQuantConfig { + bit_width: 2, + seed: Some(0), + }; + assert!(turboquant_encode_mse(&fsl, &config).is_err()); + } + + #[test] + fn qjl_rejects_dimension_below_2() { + let fsl = make_fsl_dim1(); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(0), + }; + assert!(turboquant_encode_qjl(&fsl, 
&config).is_err()); + } + + fn make_fsl_dim1() -> FixedSizeListArray { + let mut buf = BufferMut::::with_capacity(1); + buf.push(1.0); + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new(elements.into_array(), 1, Validity::NonNullable, 1).unwrap() + } + + // ----------------------------------------------------------------------- + // Verification tests for stored metadata + // ----------------------------------------------------------------------- + + /// Verify that the centroids stored in the MSE array match what get_centroids() computes. + #[test] + fn stored_centroids_match_computed() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = TurboQuant::try_match(&*encoded).unwrap(); + + let mut ctx = SESSION.create_execution_ctx(); + let stored_centroids_prim = encoded + .centroids() + .clone() + .execute::(&mut ctx)?; + let stored = stored_centroids_prim.as_slice::(); + + let padded_dim = encoded.padded_dim(); + let computed = crate::centroids::get_centroids(padded_dim, 3)?; + + assert_eq!(stored.len(), computed.len()); + for i in 0..stored.len() { + assert_eq!(stored[i], computed[i], "Centroid mismatch at {i}"); + } + Ok(()) + } + + /// Verify that stored rotation signs produce identical decode to seed-based decode. + /// + /// Encodes the same data twice: once with the new path (stored signs), and + /// once by manually recomputing the rotation from the seed. Both should + /// produce identical output. + #[test] + fn stored_rotation_signs_produce_correct_decode() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = TurboQuant::try_match(&*encoded).unwrap(); + + // Decode via the stored-signs path (normal decode). 
+ let mut ctx = SESSION.create_execution_ctx(); + let decoded_fsl = encoded + .clone() + .into_array() + .execute::(&mut ctx)?; + let decoded = decoded_fsl.elements().to_canonical()?.into_primitive(); + let decoded_slice = decoded.as_slice::(); + + // Verify stored signs match seed-derived signs. + let rot_from_seed = RotationMatrix::try_new(123, 128)?; + let expected_u8 = rot_from_seed.export_inverse_signs_u8(); + let stored_signs = encoded + .rotation_signs() + .clone() + .execute::(&mut ctx)?; + let stored_u8 = stored_signs.as_slice::(); + + assert_eq!(expected_u8.len(), stored_u8.len()); + for i in 0..expected_u8.len() { + assert_eq!(expected_u8[i], stored_u8[i], "Sign mismatch at index {i}"); + } + + // Also verify decode output is non-empty and has expected size. + assert_eq!(decoded_slice.len(), 20 * 128); + Ok(()) + } + + // ----------------------------------------------------------------------- + // QJL-specific quality tests + // ----------------------------------------------------------------------- + + /// Verify that QJL's MSE component (at bit_width-1) satisfies the theoretical bound. + #[rstest] + #[case(128, 3)] + #[case(128, 4)] + #[case(256, 3)] + fn qjl_mse_within_theoretical_bound( + #[case] dim: usize, + #[case] bit_width: u8, + ) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + let config = TurboQuantConfig { + bit_width, + seed: Some(789), + }; + let (original, decoded) = encode_decode_qjl(&fsl, &config)?; + + let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + + // QJL at b bits uses (b-1)-bit MSE plus a correction term. + // The MSE should be at most the (b-1)-bit theoretical bound, though + // in practice the QJL correction often improves it further. 
+ let mse_bound = theoretical_mse_bound(bit_width - 1); + assert!( + normalized_mse < mse_bound, + "QJL MSE {normalized_mse:.6} exceeds (b-1)-bit bound {mse_bound:.6} \ + for dim={dim}, bits={bit_width}", + ); + Ok(()) + } + + /// Verify that high-bitwidth QJL (8-9 bits) achieves very low distortion. + #[rstest] + #[case(128, 8)] + #[case(128, 9)] + fn high_bitwidth_qjl_is_small(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { + let num_rows = 200; + let fsl = make_fsl(num_rows, dim, 42); + + // Compare against 4-bit QJL as reference ceiling. + let config_4bit = TurboQuantConfig { + bit_width: 4, + seed: Some(789), + }; + let (original_4, decoded_4) = encode_decode_qjl(&fsl, &config_4bit)?; + let mse_4bit = per_vector_normalized_mse(&original_4, &decoded_4, dim, num_rows); + + let config = TurboQuantConfig { + bit_width, + seed: Some(789), + }; + let (original, decoded) = encode_decode_qjl(&fsl, &config)?; + let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); + + assert!( + mse < mse_4bit, + "{bit_width}-bit QJL MSE ({mse:.6}) should be < 4-bit ({mse_4bit:.6})" + ); + assert!( + mse < 0.01, + "{bit_width}-bit QJL MSE ({mse:.6}) should be < 1%" + ); + Ok(()) + } + + // ----------------------------------------------------------------------- + // Edge case and input format tests + // ----------------------------------------------------------------------- + + /// Verify that all-zero vectors roundtrip correctly (norm == 0 branch). 
+ #[test] + fn all_zero_vectors_roundtrip() -> VortexResult<()> { + let num_rows = 10; + let dim = 128; + let buf = BufferMut::::full(0.0f32, num_rows * dim); + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim as u32, + Validity::NonNullable, + num_rows, + )?; + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(42), + }; + let (original, decoded) = encode_decode_mse(&fsl, &config)?; + // All-zero vectors should decode to all-zero (norm=0 → 0 * anything = 0). + for (i, (&o, &d)) in original.iter().zip(decoded.iter()).enumerate() { + assert_eq!(o, 0.0, "original[{i}] not zero"); + assert_eq!(d, 0.0, "decoded[{i}] not zero for all-zero input"); + } + Ok(()) + } + + /// Verify that f64 input is accepted and encoded (converted to f32 internally). + #[test] + fn f64_input_encodes_successfully() -> VortexResult<()> { + let num_rows = 10; + let dim = 64; + let mut rng = StdRng::seed_from_u64(99); + let normal = Normal::new(0.0f64, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(num_rows * dim); + for _ in 0..(num_rows * dim) { + buf.push(normal.sample(&mut rng)); + } + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim as u32, + Validity::NonNullable, + num_rows, + )?; + + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(42), + }; + // Verify encoding succeeds with f64 input (f64→f32 conversion). + let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = TurboQuant::try_match(&*encoded).unwrap(); + assert_eq!(encoded.norms().len(), num_rows); + assert_eq!(encoded.dimension(), dim as u32); + Ok(()) + } + + /// Verify serde roundtrip: serialize MSE array metadata + children, then rebuild. 
+ #[test] + fn mse_serde_roundtrip() -> VortexResult<()> { + use vortex_array::DynArray; + use vortex_array::vtable::VTable; + + let fsl = make_fsl(10, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let encoded = TurboQuant::try_match(&*encoded).unwrap(); + + // Serialize metadata. + let metadata = ::metadata(encoded)?; + let serialized = + ::serialize(metadata)?.expect("metadata should serialize"); + + // Collect children. + let nchildren = ::nchildren(encoded); + assert_eq!(nchildren, 4); + let children: Vec = (0..nchildren) + .map(|i| ::child(encoded, i)) + .collect(); + + // Deserialize and rebuild. + let deserialized = ::deserialize( + &serialized, + encoded.dtype(), + encoded.len(), + &[], + &SESSION, + )?; + + // Verify metadata fields survived roundtrip. + assert_eq!(deserialized.dimension, encoded.dimension()); + assert_eq!(deserialized.bit_width, encoded.bit_width() as u32); + assert_eq!(deserialized.has_qjl, encoded.has_qjl()); + + // Verify the rebuilt array decodes identically. + let mut ctx = SESSION.create_execution_ctx(); + let decoded_original = encoded + .clone() + .into_array() + .execute::(&mut ctx)?; + let original_elements = decoded_original.elements().to_canonical()?.into_primitive(); + + // Rebuild from children (simulating deserialization). + let rebuilt = crate::array::TurboQuantArray::try_new_mse( + encoded.dtype().clone(), + children[0].clone(), + children[1].clone(), + children[2].clone(), + children[3].clone(), + deserialized.dimension, + deserialized.bit_width as u8, + )?; + let decoded_rebuilt = rebuilt + .into_array() + .execute::(&mut ctx)?; + let rebuilt_elements = decoded_rebuilt.elements().to_canonical()?.into_primitive(); + + assert_eq!( + original_elements.as_slice::(), + rebuilt_elements.as_slice::() + ); + Ok(()) + } + + /// Verify serde roundtrip for QJL: serialize metadata + children, then rebuild. 
+ #[test] + fn qjl_serde_roundtrip() -> VortexResult<()> { + use vortex_array::DynArray; + use vortex_array::vtable::VTable; + + let fsl = make_fsl(10, 128, 42); + let config = TurboQuantConfig { + bit_width: 4, + seed: Some(456), + }; + let encoded = turboquant_encode_qjl(&fsl, &config)?; + let encoded = TurboQuant::try_match(&*encoded).unwrap(); + + // Serialize metadata. + let metadata = ::metadata(encoded)?; + let serialized = + ::serialize(metadata)?.expect("metadata should serialize"); + + // Collect children — QJL has 7 (4 MSE + 3 QJL). + let nchildren = ::nchildren(encoded); + assert_eq!(nchildren, 7); + let children: Vec = (0..nchildren) + .map(|i| ::child(encoded, i)) + .collect(); + + // Deserialize metadata. + let deserialized = ::deserialize( + &serialized, + encoded.dtype(), + encoded.len(), + &[], + &SESSION, + )?; + + assert!(deserialized.has_qjl); + assert_eq!(deserialized.dimension, encoded.dimension()); + + // Verify decode: original vs rebuilt from children. + let mut ctx = SESSION.create_execution_ctx(); + let decoded_original = encoded + .clone() + .into_array() + .execute::(&mut ctx)?; + let original_elements = decoded_original.elements().to_canonical()?.into_primitive(); + + // Rebuild with QJL children. 
+ let rebuilt = crate::array::TurboQuantArray::try_new_qjl( + encoded.dtype().clone(), + children[0].clone(), + children[1].clone(), + children[2].clone(), + children[3].clone(), + crate::array::QjlCorrection { + signs: children[4].clone(), + residual_norms: children[5].clone(), + rotation_signs: children[6].clone(), + }, + deserialized.dimension, + deserialized.bit_width as u8, + )?; + let decoded_rebuilt = rebuilt + .into_array() + .execute::(&mut ctx)?; + let rebuilt_elements = decoded_rebuilt.elements().to_canonical()?.into_primitive(); + + assert_eq!( + original_elements.as_slice::(), + rebuilt_elements.as_slice::() + ); + Ok(()) + } + + // ----------------------------------------------------------------------- + // Compute pushdown tests + // ----------------------------------------------------------------------- + + #[test] + fn slice_preserves_data() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + + // Full decompress then slice. + let mut ctx = SESSION.create_execution_ctx(); + let full_decoded = encoded.clone().execute::(&mut ctx)?; + let expected = full_decoded.slice(5..10)?; + let expected_prim = expected.to_canonical()?.into_fixed_size_list(); + let expected_elements = expected_prim.elements().to_canonical()?.into_primitive(); + + // Slice then decompress. 
+ let sliced = encoded.slice(5..10)?; + let sliced_decoded = sliced.execute::(&mut ctx)?; + let actual_elements = sliced_decoded.elements().to_canonical()?.into_primitive(); + + assert_eq!( + expected_elements.as_slice::(), + actual_elements.as_slice::() + ); + Ok(()) + } + + #[test] + fn slice_qjl_preserves_data() -> VortexResult<()> { + let fsl = make_fsl(20, 128, 42); + let config = TurboQuantConfig { + bit_width: 4, + seed: Some(456), + }; + let encoded = turboquant_encode_qjl(&fsl, &config)?; + + let mut ctx = SESSION.create_execution_ctx(); + let full_decoded = encoded.clone().execute::(&mut ctx)?; + let expected = full_decoded.slice(3..8)?; + let expected_prim = expected.to_canonical()?.into_fixed_size_list(); + let expected_elements = expected_prim.elements().to_canonical()?.into_primitive(); + + let sliced = encoded.slice(3..8)?; + let sliced_decoded = sliced.execute::(&mut ctx)?; + let actual_elements = sliced_decoded.elements().to_canonical()?.into_primitive(); + + assert_eq!( + expected_elements.as_slice::(), + actual_elements.as_slice::() + ); + Ok(()) + } + + #[test] + fn scalar_at_matches_decompress() -> VortexResult<()> { + let fsl = make_fsl(10, 64, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + + let mut ctx = SESSION.create_execution_ctx(); + let full_decoded = encoded.clone().execute::(&mut ctx)?; + + for i in [0, 1, 5, 9] { + let expected = full_decoded.scalar_at(i)?; + let actual = encoded.scalar_at(i)?; + assert_eq!(expected, actual, "scalar_at mismatch at index {i}"); + } + Ok(()) + } + + #[test] + fn l2_norm_readthrough() -> VortexResult<()> { + let fsl = make_fsl(10, 128, 42); + let config = TurboQuantConfig { + bit_width: 3, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let tq = TurboQuant::try_match(&*encoded).unwrap(); + + // Stored norms should match the actual L2 norms of the input. 
+ let norms_prim = tq.norms().to_canonical()?.into_primitive(); + let stored_norms = norms_prim.as_slice::(); + + let input_prim = fsl.elements().to_canonical()?.into_primitive(); + let input_f32 = input_prim.as_slice::(); + for row in 0..10 { + let vec = &input_f32[row * 128..(row + 1) * 128]; + let actual_norm: f32 = vec.iter().map(|&v| v * v).sum::().sqrt(); + assert!( + (stored_norms[row] - actual_norm).abs() < 1e-5, + "norm mismatch at row {row}: stored={}, actual={}", + stored_norms[row], + actual_norm + ); + } + Ok(()) + } + + #[test] + fn cosine_similarity_quantized_accuracy() -> VortexResult<()> { + use crate::compute::cosine_similarity::cosine_similarity_quantized; + + let fsl = make_fsl(20, 128, 42); + let config = TurboQuantConfig { + bit_width: 4, + seed: Some(123), + }; + let encoded = turboquant_encode_mse(&fsl, &config)?; + let tq = TurboQuant::try_match(&*encoded).unwrap(); + + // Compute exact cosine similarity from original data. + let input_prim = fsl.elements().to_canonical()?.into_primitive(); + let input_f32 = input_prim.as_slice::(); + + for (row_a, row_b) in [(0, 1), (5, 10), (0, 19)] { + let a = &input_f32[row_a * 128..(row_a + 1) * 128]; + let b = &input_f32[row_b * 128..(row_b + 1) * 128]; + + let dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum(); + let norm_a: f32 = a.iter().map(|&v| v * v).sum::().sqrt(); + let norm_b: f32 = b.iter().map(|&v| v * v).sum::().sqrt(); + let exact_cos = dot / (norm_a * norm_b); + + let mut ctx = SESSION.create_execution_ctx(); + let approx_cos = cosine_similarity_quantized(tq, row_a, row_b, &mut ctx)?; + + // 4-bit quantization: expect reasonable accuracy. 
+ let error = (exact_cos - approx_cos).abs(); + assert!( + error < 0.15, + "cosine similarity error too large for ({row_a}, {row_b}): \ + exact={exact_cos:.4}, approx={approx_cos:.4}, error={error:.4}" + ); + } + Ok(()) + } +} diff --git a/encodings/turboquant/src/rotation.rs b/encodings/turboquant/src/rotation.rs new file mode 100644 index 00000000000..466b843e04a --- /dev/null +++ b/encodings/turboquant/src/rotation.rs @@ -0,0 +1,379 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Deterministic random rotation for TurboQuant. +//! +//! Uses a Structured Random Hadamard Transform (SRHT) for O(d log d) rotation +//! instead of a full d×d matrix multiply. The SRHT applies the sequence +//! D₃ · H · D₂ · H · D₁ where H is the Walsh-Hadamard transform and Dₖ are +//! random diagonal ±1 sign matrices. Three rounds of HD provide sufficient +//! randomness for near-uniform distribution on the sphere. +//! +//! For dimensions that are not powers of 2, the input is zero-padded to the +//! next power of 2 before the transform and truncated afterward. +//! +//! # Sign representation +//! +//! Signs are stored internally as `u32` XOR masks: `0x00000000` for +1 (no-op) +//! and `0x80000000` for -1 (flip IEEE 754 sign bit). The sign application +//! function uses integer XOR instead of floating-point multiply, which avoids +//! FP dependency chains and auto-vectorizes into `vpxor`/`veor`. + +use rand::RngExt; +use rand::SeedableRng; +use rand::rngs::StdRng; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +/// IEEE 754 sign bit mask for f32. +const F32_SIGN_BIT: u32 = 0x8000_0000; + +/// A structured random Hadamard transform for O(d log d) pseudo-random rotation. +pub struct RotationMatrix { + /// XOR masks for each of the 3 diagonal matrices, each of length `padded_dim`. + /// `0x00000000` = multiply by +1 (no-op), `0x80000000` = multiply by -1 (flip sign bit). 
+    sign_masks: [Vec<u32>; 3],
+    /// The padded dimension (next power of 2 >= dimension).
+    padded_dim: usize,
+    /// Normalization factor: 1/(padded_dim * sqrt(padded_dim)), applied once at the end.
+    norm_factor: f32,
+}
+
+impl RotationMatrix {
+    /// Create a new SRHT rotation from a deterministic seed.
+    pub fn try_new(seed: u64, dimension: usize) -> VortexResult<Self> {
+        let padded_dim = dimension.next_power_of_two();
+        let mut rng = StdRng::seed_from_u64(seed);
+
+        let sign_masks = std::array::from_fn(|_| gen_random_sign_masks(&mut rng, padded_dim));
+        let norm_factor = 1.0 / (padded_dim as f32 * (padded_dim as f32).sqrt());
+
+        Ok(Self {
+            sign_masks,
+            padded_dim,
+            norm_factor,
+        })
+    }
+
+    /// Apply forward rotation: `output = SRHT(input)`.
+    ///
+    /// Both `input` and `output` must have length `padded_dim()`. The caller
+    /// is responsible for zero-padding input beyond `dim` positions.
+    pub fn rotate(&self, input: &[f32], output: &mut [f32]) {
+        debug_assert_eq!(input.len(), self.padded_dim);
+        debug_assert_eq!(output.len(), self.padded_dim);
+
+        output.copy_from_slice(input);
+        self.apply_srht(output);
+    }
+
+    /// Apply inverse rotation: `output = SRHT⁻¹(input)`.
+    ///
+    /// Both `input` and `output` must have length `padded_dim()`.
+    pub fn inverse_rotate(&self, input: &[f32], output: &mut [f32]) {
+        debug_assert_eq!(input.len(), self.padded_dim);
+        debug_assert_eq!(output.len(), self.padded_dim);
+
+        output.copy_from_slice(input);
+        self.apply_inverse_srht(output);
+    }
+
+    /// Returns the padded dimension (next power of 2 >= dim).
+    ///
+    /// All rotate/inverse_rotate buffers must be this length.
+    pub fn padded_dim(&self) -> usize {
+        self.padded_dim
+    }
+
+    /// Apply the SRHT: D₃ · H · D₂ · H · D₁ · x, with normalization.
+    fn apply_srht(&self, buf: &mut [f32]) {
+        apply_signs_xor(buf, &self.sign_masks[0]);
+        walsh_hadamard_transform(buf);
+
+        apply_signs_xor(buf, &self.sign_masks[1]);
+        walsh_hadamard_transform(buf);
+
+        apply_signs_xor(buf, &self.sign_masks[2]);
+        walsh_hadamard_transform(buf);
+
+        let norm = self.norm_factor;
+        buf.iter_mut().for_each(|val| *val *= norm);
+    }
+
+    /// Apply the inverse SRHT.
+    ///
+    /// Forward is: norm · H · D₃ · H · D₂ · H · D₁
+    /// Inverse is: norm · D₁ · H · D₂ · H · D₃ · H
+    fn apply_inverse_srht(&self, buf: &mut [f32]) {
+        walsh_hadamard_transform(buf);
+        apply_signs_xor(buf, &self.sign_masks[2]);
+
+        walsh_hadamard_transform(buf);
+        apply_signs_xor(buf, &self.sign_masks[1]);
+
+        walsh_hadamard_transform(buf);
+        apply_signs_xor(buf, &self.sign_masks[0]);
+
+        let norm = self.norm_factor;
+        buf.iter_mut().for_each(|val| *val *= norm);
+    }
+
+    /// Export the 3 sign vectors as a flat `Vec<u8>` of 0/1 values in inverse
+    /// application order `[D₃ | D₂ | D₁]`.
+    ///
+    /// Convention: `1` = positive (+1), `0` = negative (-1).
+    /// The output has length `3 * padded_dim` and is suitable for bitpacking
+    /// via FastLanes `bitpack_encode(..., 1, None)`.
+    pub fn export_inverse_signs_u8(&self) -> Vec<u8> {
+        let total = 3 * self.padded_dim;
+        let mut out = Vec::with_capacity(total);
+
+        // Store in inverse order: sign_masks[2] (D₃), sign_masks[1] (D₂), sign_masks[0] (D₁)
+        for sign_idx in [2, 1, 0] {
+            for &mask in &self.sign_masks[sign_idx] {
+                out.push(if mask == 0 { 1u8 } else { 0u8 });
+            }
+        }
+        out
+    }
+
+    /// Reconstruct a `RotationMatrix` from unpacked `u8` 0/1 values.
+    ///
+    /// The input must have length `3 * padded_dim` with signs in inverse
+    /// application order `[D₃ | D₂ | D₁]` (as produced by [`export_inverse_signs_u8`]).
+    /// Convention: `1` = positive, `0` = negative.
+    ///
+    /// This is the decode-time reconstruction path: FastLanes SIMD-unpacks the
+    /// stored `BitPackedArray` into `&[u8]`, which is passed here.
+    pub fn from_u8_slice(signs_u8: &[u8], dimension: usize) -> VortexResult<Self> {
+        let padded_dim = dimension.next_power_of_two();
+        vortex_ensure!(
+            signs_u8.len() == 3 * padded_dim,
+            "Expected {} sign bytes, got {}",
+            3 * padded_dim,
+            signs_u8.len()
+        );
+
+        // Reconstruct in storage order (inverse): [D₃, D₂, D₁] → sign_masks[2], [1], [0]
+        let mut sign_masks: [Vec<u32>; 3] = std::array::from_fn(|_| Vec::with_capacity(padded_dim));
+
+        for (round, sign_idx) in [2, 1, 0].into_iter().enumerate() {
+            let offset = round * padded_dim;
+            sign_masks[sign_idx] = signs_u8[offset..offset + padded_dim]
+                .iter()
+                .map(|&v| if v != 0 { 0u32 } else { F32_SIGN_BIT })
+                .collect();
+        }
+
+        let norm_factor = 1.0 / (padded_dim as f32 * (padded_dim as f32).sqrt());
+
+        Ok(Self {
+            sign_masks,
+            padded_dim,
+            norm_factor,
+        })
+    }
+}
+
+/// Generate a vector of random XOR sign masks.
+fn gen_random_sign_masks(rng: &mut StdRng, len: usize) -> Vec<u32> {
+    (0..len)
+        .map(|_| {
+            if rng.random_bool(0.5) {
+                0u32 // +1: no-op
+            } else {
+                F32_SIGN_BIT // -1: flip sign bit
+            }
+        })
+        .collect()
+}
+
+/// Apply sign masks via XOR on the IEEE 754 sign bit.
+///
+/// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM).
+/// Equivalent to multiplying each element by ±1.0, but avoids FP dependency chains.
+#[inline]
+fn apply_signs_xor(buf: &mut [f32], masks: &[u32]) {
+    for (val, &mask) in buf.iter_mut().zip(masks.iter()) {
+        *val = f32::from_bits(val.to_bits() ^ mask);
+    }
+}
+
+/// In-place Walsh-Hadamard Transform (unnormalized, iterative).
+///
+/// Input length must be a power of 2. Runs in O(n log n).
+///
+/// Uses a fixed-size chunk strategy: for each stage, the buffer is processed
+/// in `CHUNK`-element blocks with a compile-time-known butterfly function.
+/// This lets LLVM unroll and auto-vectorize the butterfly into NEON/AVX SIMD.
+fn walsh_hadamard_transform(buf: &mut [f32]) { + let len = buf.len(); + debug_assert!(len.is_power_of_two()); + + let mut half = 1; + while half < len { + let stride = half * 2; + // Process in chunks of `stride` elements. Within each chunk, + // split into non-overlapping (lo, hi) halves for the butterfly. + for chunk in buf.chunks_exact_mut(stride) { + let (lo, hi) = chunk.split_at_mut(half); + butterfly(lo, hi); + } + half *= 2; + } +} + +/// Butterfly: `lo[i], hi[i] = lo[i] + hi[i], lo[i] - hi[i]`. +/// +/// Separate function so LLVM can see the slice lengths match and auto-vectorize. +#[inline(always)] +fn butterfly(lo: &mut [f32], hi: &mut [f32]) { + debug_assert_eq!(lo.len(), hi.len()); + for (a, b) in lo.iter_mut().zip(hi.iter_mut()) { + let sum = *a + *b; + let diff = *a - *b; + *a = sum; + *b = diff; + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_error::VortexResult; + + use super::*; + + #[test] + fn deterministic_from_seed() -> VortexResult<()> { + let r1 = RotationMatrix::try_new(42, 64)?; + let r2 = RotationMatrix::try_new(42, 64)?; + let pd = r1.padded_dim(); + + let mut input = vec![0.0f32; pd]; + for i in 0..64 { + input[i] = i as f32; + } + let mut out1 = vec![0.0f32; pd]; + let mut out2 = vec![0.0f32; pd]; + + r1.rotate(&input, &mut out1); + r2.rotate(&input, &mut out2); + + assert_eq!(out1, out2); + Ok(()) + } + + /// Verify roundtrip is exact to f32 precision across many dimensions, + /// including non-power-of-two dimensions that require padding. 
+    #[rstest]
+    #[case(32)]
+    #[case(64)]
+    #[case(100)]
+    #[case(128)]
+    #[case(256)]
+    #[case(512)]
+    #[case(768)]
+    #[case(1024)]
+    fn roundtrip_exact(#[case] dim: usize) -> VortexResult<()> {
+        let rot = RotationMatrix::try_new(42, dim)?;
+        let padded_dim = rot.padded_dim();
+
+        let mut input = vec![0.0f32; padded_dim];
+        for i in 0..dim {
+            input[i] = (i as f32 + 1.0) * 0.01;
+        }
+        let mut rotated = vec![0.0f32; padded_dim];
+        let mut recovered = vec![0.0f32; padded_dim];
+
+        rot.rotate(&input, &mut rotated);
+        rot.inverse_rotate(&rotated, &mut recovered);
+
+        let max_err: f32 = input
+            .iter()
+            .zip(recovered.iter())
+            .map(|(a, b)| (a - b).abs())
+            .fold(0.0f32, f32::max);
+        let max_val: f32 = input.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
+        let rel_err = max_err / max_val;
+
+        // SRHT roundtrip should be exact up to f32 precision (~1e-6).
+        assert!(
+            rel_err < 1e-5,
+            "roundtrip relative error too large for dim={dim}: {rel_err:.2e}"
+        );
+        Ok(())
+    }
+
+    /// Verify norm preservation across dimensions.
+    #[rstest]
+    #[case(128)]
+    #[case(768)]
+    fn preserves_norm(#[case] dim: usize) -> VortexResult<()> {
+        let rot = RotationMatrix::try_new(7, dim)?;
+        let padded_dim = rot.padded_dim();
+
+        let mut input = vec![0.0f32; padded_dim];
+        for i in 0..dim {
+            input[i] = (i as f32) * 0.01;
+        }
+        let input_norm: f32 = input.iter().map(|x| x * x).sum::<f32>().sqrt();
+
+        let mut rotated = vec![0.0f32; padded_dim];
+        rot.rotate(&input, &mut rotated);
+        let rotated_norm: f32 = rotated.iter().map(|x| x * x).sum::<f32>().sqrt();
+
+        assert!(
+            (input_norm - rotated_norm).abs() / input_norm < 1e-5,
+            "norm not preserved for dim={dim}: {} vs {} (rel err: {:.2e})",
+            input_norm,
+            rotated_norm,
+            (input_norm - rotated_norm).abs() / input_norm
+        );
+        Ok(())
+    }
+
+    /// Verify that export → from_u8_slice produces identical rotation output.
+ #[rstest] + #[case(64)] + #[case(128)] + #[case(768)] + fn sign_export_import_roundtrip(#[case] dim: usize) -> VortexResult<()> { + let rot = RotationMatrix::try_new(42, dim)?; + let padded_dim = rot.padded_dim(); + + let signs_u8 = rot.export_inverse_signs_u8(); + let rot2 = RotationMatrix::from_u8_slice(&signs_u8, dim)?; + + let mut input = vec![0.0f32; padded_dim]; + for i in 0..dim { + input[i] = (i as f32 + 1.0) * 0.01; + } + + let mut out1 = vec![0.0f32; padded_dim]; + let mut out2 = vec![0.0f32; padded_dim]; + rot.rotate(&input, &mut out1); + rot2.rotate(&input, &mut out2); + assert_eq!(out1, out2, "Forward rotation mismatch after export/import"); + + rot.inverse_rotate(&out1, &mut out2); + let mut out3 = vec![0.0f32; padded_dim]; + rot2.inverse_rotate(&out1, &mut out3); + assert_eq!(out2, out3, "Inverse rotation mismatch after export/import"); + + Ok(()) + } + + #[test] + fn wht_basic() { + // WHT of [1, 0, 0, 0] should be [1, 1, 1, 1] + let mut buf = vec![1.0f32, 0.0, 0.0, 0.0]; + walsh_hadamard_transform(&mut buf); + assert_eq!(buf, vec![1.0, 1.0, 1.0, 1.0]); + + // WHT is self-inverse (up to scaling by n) + walsh_hadamard_transform(&mut buf); + // After two WHTs: each element multiplied by n=4 + assert_eq!(buf, vec![4.0, 0.0, 0.0, 0.0]); + } +} diff --git a/encodings/turboquant/src/vtable.rs b/encodings/turboquant/src/vtable.rs new file mode 100644 index 00000000000..07751509cb8 --- /dev/null +++ b/encodings/turboquant/src/vtable.rs @@ -0,0 +1,306 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! VTable implementation for TurboQuant encoding. 
+ +use std::hash::Hash; +use std::ops::Deref; +use std::sync::Arc; + +use vortex_array::ArrayEq; +use vortex_array::ArrayHash; +use vortex_array::ArrayRef; +use vortex_array::DeserializeMetadata; +use vortex_array::DynArray; +use vortex_array::ExecutionCtx; +use vortex_array::ExecutionResult; +use vortex_array::Precision; +use vortex_array::ProstMetadata; +use vortex_array::SerializeMetadata; +use vortex_array::buffer::BufferHandle; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::serde::ArrayChildren; +use vortex_array::stats::StatsSetRef; +use vortex_array::vtable::Array; +use vortex_array::vtable::ArrayId; +use vortex_array::vtable::VTable; +use vortex_array::vtable::ValidityChild; +use vortex_array::vtable::ValidityVTableFromChild; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_panic; +use vortex_session::VortexSession; + +use crate::array::QjlCorrection; +use crate::array::TurboQuant; +use crate::array::TurboQuantArray; +use crate::array::TurboQuantMetadata; +use crate::decompress::execute_decompress; + +const MSE_CHILDREN: usize = 4; +const QJL_CHILDREN: usize = 3; + +impl VTable for TurboQuant { + type Array = TurboQuantArray; + type Metadata = ProstMetadata; + type OperationsVTable = TurboQuant; + type ValidityVTable = ValidityVTableFromChild; + + fn vtable(_array: &Self::Array) -> &Self { + &TurboQuant + } + + fn id(&self) -> ArrayId { + Self::ID + } + + fn len(array: &TurboQuantArray) -> usize { + array.norms.len() + } + + fn dtype(array: &TurboQuantArray) -> &DType { + &array.dtype + } + + fn stats(array: &TurboQuantArray) -> StatsSetRef<'_> { + array.stats_set.to_ref(array.as_ref()) + } + + fn array_hash( + array: &TurboQuantArray, + state: &mut H, + precision: Precision, + ) { + array.dtype.hash(state); + array.dimension.hash(state); + array.bit_width.hash(state); + array.has_qjl().hash(state); + 
array.codes.array_hash(state, precision); + array.norms.array_hash(state, precision); + array.centroids.array_hash(state, precision); + array.rotation_signs.array_hash(state, precision); + if let Some(qjl) = &array.qjl { + qjl.signs.array_hash(state, precision); + qjl.residual_norms.array_hash(state, precision); + qjl.rotation_signs.array_hash(state, precision); + } + } + + fn array_eq(array: &TurboQuantArray, other: &TurboQuantArray, precision: Precision) -> bool { + array.dtype == other.dtype + && array.dimension == other.dimension + && array.bit_width == other.bit_width + && array.has_qjl() == other.has_qjl() + && array.codes.array_eq(&other.codes, precision) + && array.norms.array_eq(&other.norms, precision) + && array.centroids.array_eq(&other.centroids, precision) + && array + .rotation_signs + .array_eq(&other.rotation_signs, precision) + && match (&array.qjl, &other.qjl) { + (Some(a), Some(b)) => { + a.signs.array_eq(&b.signs, precision) + && a.residual_norms.array_eq(&b.residual_norms, precision) + && a.rotation_signs.array_eq(&b.rotation_signs, precision) + } + (None, None) => true, + _ => false, + } + } + + fn nbuffers(_array: &TurboQuantArray) -> usize { + 0 + } + + fn buffer(_array: &TurboQuantArray, idx: usize) -> BufferHandle { + vortex_panic!("TurboQuantArray buffer index {idx} out of bounds") + } + + fn buffer_name(_array: &TurboQuantArray, _idx: usize) -> Option { + None + } + + fn nchildren(array: &TurboQuantArray) -> usize { + if array.has_qjl() { + MSE_CHILDREN + QJL_CHILDREN + } else { + MSE_CHILDREN + } + } + + fn child(array: &TurboQuantArray, idx: usize) -> ArrayRef { + match idx { + 0 => array.codes.clone(), + 1 => array.norms.clone(), + 2 => array.centroids.clone(), + 3 => array.rotation_signs.clone(), + 4 => array + .qjl + .as_ref() + .vortex_expect("QJL child requested but has_qjl is false") + .signs + .clone(), + 5 => array + .qjl + .as_ref() + .vortex_expect("QJL child requested but has_qjl is false") + .residual_norms + .clone(), + 6 
=> array + .qjl + .as_ref() + .vortex_expect("QJL child requested but has_qjl is false") + .rotation_signs + .clone(), + _ => vortex_panic!("TurboQuantArray child index {idx} out of bounds"), + } + } + + fn child_name(_array: &TurboQuantArray, idx: usize) -> String { + match idx { + 0 => "codes".to_string(), + 1 => "norms".to_string(), + 2 => "centroids".to_string(), + 3 => "rotation_signs".to_string(), + 4 => "qjl_signs".to_string(), + 5 => "qjl_residual_norms".to_string(), + 6 => "qjl_rotation_signs".to_string(), + _ => vortex_panic!("TurboQuantArray child_name index {idx} out of bounds"), + } + } + + fn metadata(array: &TurboQuantArray) -> VortexResult { + Ok(ProstMetadata(TurboQuantMetadata { + dimension: array.dimension, + bit_width: array.bit_width as u32, + has_qjl: array.has_qjl(), + })) + } + + fn serialize(metadata: Self::Metadata) -> VortexResult>> { + Ok(Some(metadata.serialize())) + } + + fn deserialize( + bytes: &[u8], + _dtype: &DType, + _len: usize, + _buffers: &[BufferHandle], + _session: &VortexSession, + ) -> VortexResult { + Ok(ProstMetadata( + as DeserializeMetadata>::deserialize(bytes)?, + )) + } + + fn build( + dtype: &DType, + len: usize, + metadata: &Self::Metadata, + _buffers: &[BufferHandle], + children: &dyn ArrayChildren, + ) -> VortexResult { + let bit_width = u8::try_from(metadata.bit_width)?; + let padded_dim = metadata.dimension.next_power_of_two() as usize; + let num_centroids = 1usize << bit_width; + + let u8_nn = DType::Primitive(PType::U8, Nullability::NonNullable); + let f32_nn = DType::Primitive(PType::F32, Nullability::NonNullable); + let codes_dtype = DType::FixedSizeList( + Arc::new(u8_nn.clone()), + padded_dim as u32, + Nullability::NonNullable, + ); + let codes = children.get(0, &codes_dtype, len)?; + + let norms = children.get(1, &f32_nn, len)?; + let centroids = children.get(2, &f32_nn, num_centroids)?; + + let signs_dtype = DType::Primitive(PType::U8, Nullability::NonNullable); + let rotation_signs = children.get(3, 
&signs_dtype, 3 * padded_dim)?; + + let qjl = if metadata.has_qjl { + let qjl_signs_dtype = + DType::FixedSizeList(Arc::new(u8_nn), padded_dim as u32, Nullability::NonNullable); + let qjl_signs = children.get(4, &qjl_signs_dtype, len)?; + let qjl_residual_norms = children.get(5, &f32_nn, len)?; + let qjl_rotation_signs = children.get(6, &signs_dtype, 3 * padded_dim)?; + Some(QjlCorrection { + signs: qjl_signs, + residual_norms: qjl_residual_norms, + rotation_signs: qjl_rotation_signs, + }) + } else { + None + }; + + Ok(TurboQuantArray { + dtype: dtype.clone(), + codes, + norms, + centroids, + rotation_signs, + qjl, + dimension: metadata.dimension, + bit_width, + stats_set: Default::default(), + }) + } + + fn with_children(array: &mut Self::Array, children: Vec) -> VortexResult<()> { + let expected = if array.has_qjl() { + MSE_CHILDREN + QJL_CHILDREN + } else { + MSE_CHILDREN + }; + vortex_ensure!( + children.len() == expected, + "TurboQuantArray expects {expected} children, got {}", + children.len() + ); + let mut iter = children.into_iter(); + array.codes = iter.next().vortex_expect("codes child"); + array.norms = iter.next().vortex_expect("norms child"); + array.centroids = iter.next().vortex_expect("centroids child"); + array.rotation_signs = iter.next().vortex_expect("rotation_signs child"); + if let Some(qjl) = &mut array.qjl { + qjl.signs = iter.next().vortex_expect("qjl_signs child"); + qjl.residual_norms = iter.next().vortex_expect("qjl_residual_norms child"); + qjl.rotation_signs = iter.next().vortex_expect("qjl_rotation_signs child"); + } + Ok(()) + } + + fn reduce_parent( + array: &Array, + parent: &ArrayRef, + child_idx: usize, + ) -> VortexResult> { + crate::compute::rules::RULES.evaluate(array, parent, child_idx) + } + + fn execute_parent( + array: &Array, + parent: &ArrayRef, + child_idx: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult> { + crate::compute::rules::PARENT_KERNELS.execute(array, parent, child_idx, ctx) + } + + fn execute(array: 
Arc>, ctx: &mut ExecutionCtx) -> VortexResult { + let inner = Arc::try_unwrap(array) + .map(|a| a.into_inner()) + .unwrap_or_else(|arc| arc.as_ref().deref().clone()); + Ok(ExecutionResult::done(execute_decompress(inner, ctx)?)) + } +} + +impl ValidityChild for TurboQuant { + fn validity_child(array: &TurboQuantArray) -> &ArrayRef { + array.codes() + } +} diff --git a/vortex-btrblocks/Cargo.toml b/vortex-btrblocks/Cargo.toml index 1c745306c4a..4e51ee33014 100644 --- a/vortex-btrblocks/Cargo.toml +++ b/vortex-btrblocks/Cargo.toml @@ -35,6 +35,7 @@ vortex-pco = { workspace = true, optional = true } vortex-runend = { workspace = true } vortex-sequence = { workspace = true } vortex-sparse = { workspace = true } +vortex-turboquant = { workspace = true } vortex-utils = { workspace = true } vortex-zigzag = { workspace = true } vortex-zstd = { workspace = true, optional = true } diff --git a/vortex-btrblocks/public-api.lock b/vortex-btrblocks/public-api.lock index 55d23a96a26..17564a21025 100644 --- a/vortex-btrblocks/public-api.lock +++ b/vortex-btrblocks/public-api.lock @@ -194,6 +194,8 @@ pub vortex_btrblocks::BtrBlocksCompressor::int_schemes: alloc::vec::Vec<&'static pub vortex_btrblocks::BtrBlocksCompressor::string_schemes: alloc::vec::Vec<&'static dyn vortex_btrblocks::compressor::string::StringScheme> +pub vortex_btrblocks::BtrBlocksCompressor::turboquant_config: core::option::Option + impl vortex_btrblocks::BtrBlocksCompressor pub fn vortex_btrblocks::BtrBlocksCompressor::compress(&self, array: &vortex_array::array::ArrayRef) -> vortex_error::VortexResult @@ -236,6 +238,8 @@ pub fn vortex_btrblocks::BtrBlocksCompressorBuilder::include_int(self, codes: im pub fn vortex_btrblocks::BtrBlocksCompressorBuilder::include_string(self, codes: impl core::iter::traits::collect::IntoIterator) -> Self +pub fn vortex_btrblocks::BtrBlocksCompressorBuilder::with_turboquant(self, config: vortex_turboquant::compress::TurboQuantConfig) -> Self + impl core::clone::Clone for 
vortex_btrblocks::BtrBlocksCompressorBuilder pub fn vortex_btrblocks::BtrBlocksCompressorBuilder::clone(&self) -> vortex_btrblocks::BtrBlocksCompressorBuilder diff --git a/vortex-btrblocks/src/builder.rs b/vortex-btrblocks/src/builder.rs index d329ec8c139..851c4e6d986 100644 --- a/vortex-btrblocks/src/builder.rs +++ b/vortex-btrblocks/src/builder.rs @@ -46,6 +46,7 @@ pub struct BtrBlocksCompressorBuilder { int_schemes: HashSet<&'static dyn IntegerScheme>, float_schemes: HashSet<&'static dyn FloatScheme>, string_schemes: HashSet<&'static dyn StringScheme>, + turboquant_config: Option, } impl Default for BtrBlocksCompressorBuilder { @@ -66,6 +67,7 @@ impl Default for BtrBlocksCompressorBuilder { .copied() .filter(|s| s.code() != StringCode::Zstd && s.code() != StringCode::ZstdBuffers) .collect(), + turboquant_config: None, } } } @@ -77,6 +79,7 @@ impl BtrBlocksCompressorBuilder { int_schemes: Default::default(), float_schemes: Default::default(), string_schemes: Default::default(), + turboquant_config: None, } } @@ -134,6 +137,16 @@ impl BtrBlocksCompressorBuilder { self } + /// Enables TurboQuant lossy vector quantization for tensor extension types. + /// + /// When enabled, `Vector` and `FixedShapeTensor` extension columns will be + /// quantized at the configured bit-width instead of using the default + /// recursive storage compression. + pub fn with_turboquant(mut self, config: vortex_turboquant::TurboQuantConfig) -> Self { + self.turboquant_config = Some(config); + self + } + /// Builds the configured `BtrBlocksCompressor`. pub fn build(self) -> BtrBlocksCompressor { // Note we should apply the schemes in the same order, in case try conflict. 
@@ -153,6 +166,7 @@ impl BtrBlocksCompressorBuilder { .into_iter() .sorted_by_key(|s| s.code()) .collect_vec(), + turboquant_config: self.turboquant_config, } } } diff --git a/vortex-btrblocks/src/canonical_compressor.rs b/vortex-btrblocks/src/canonical_compressor.rs index 46252060a1f..203144912fe 100644 --- a/vortex-btrblocks/src/canonical_compressor.rs +++ b/vortex-btrblocks/src/canonical_compressor.rs @@ -40,6 +40,8 @@ use crate::compressor::float::FloatScheme; use crate::compressor::integer::IntegerScheme; use crate::compressor::string::StringScheme; use crate::compressor::temporal::compress_temporal; +use crate::compressor::turboquant::compress_turboquant; +use crate::compressor::turboquant::is_tensor_extension; /// Trait for compressors that can compress canonical arrays. /// @@ -101,6 +103,9 @@ pub struct BtrBlocksCompressor { /// String compressor with configured schemes. pub string_schemes: Vec<&'static dyn StringScheme>, + + /// Optional TurboQuant configuration for tensor extension types. + pub turboquant_config: Option, } impl Default for BtrBlocksCompressor { @@ -290,6 +295,15 @@ impl CanonicalCompressor for BtrBlocksCompressor { return compress_temporal(self, temporal_array); } + // Compress tensor extension types with TurboQuant if configured. + // Falls through to default compression for nullable storage. + if let Some(tq_config) = &self.turboquant_config + && is_tensor_extension(&ext_array) + && let Some(compressed) = compress_turboquant(&ext_array, tq_config)? + { + return Ok(compressed); + } + // Compress the underlying storage array. 
let compressed_storage = self.compress(ext_array.storage_array())?; diff --git a/vortex-btrblocks/src/compressor/mod.rs b/vortex-btrblocks/src/compressor/mod.rs index 5c3a31271cd..e97c1d9b87b 100644 --- a/vortex-btrblocks/src/compressor/mod.rs +++ b/vortex-btrblocks/src/compressor/mod.rs @@ -34,6 +34,7 @@ mod patches; mod rle; pub(crate) mod string; pub(crate) mod temporal; +pub(crate) mod turboquant; /// Maximum cascade depth for compression. pub(crate) const MAX_CASCADE: usize = 3; diff --git a/vortex-btrblocks/src/compressor/turboquant.rs b/vortex-btrblocks/src/compressor/turboquant.rs new file mode 100644 index 00000000000..974518acb14 --- /dev/null +++ b/vortex-btrblocks/src/compressor/turboquant.rs @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Specialized compressor for TurboQuant vector quantization of tensor extension types. + +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use vortex_error::VortexResult; +use vortex_turboquant::FIXED_SHAPE_TENSOR_EXT_ID; +use vortex_turboquant::TurboQuantConfig; +use vortex_turboquant::VECTOR_EXT_ID; +use vortex_turboquant::turboquant_encode_qjl; + +/// Check if an extension array has a tensor extension type. +pub(crate) fn is_tensor_extension(ext_array: &ExtensionArray) -> bool { + let ext_id = ext_array.ext_dtype().id(); + ext_id.as_ref() == VECTOR_EXT_ID || ext_id.as_ref() == FIXED_SHAPE_TENSOR_EXT_ID +} + +/// Try to compress a tensor extension array using TurboQuant. +/// +/// Returns `Ok(Some(...))` on success, or `Ok(None)` if the storage is nullable +/// (TurboQuant requires non-nullable input). The caller should fall through to +/// default compression when `None` is returned. +/// +/// Produces a `TurboQuantArray` with QJL correction, stored inside the Extension +/// wrapper. 
The per-row children (codes, QJL signs) are `FixedSizeListArray`s +/// whose inner elements will be cascading-compressed by the layout writer. +pub(crate) fn compress_turboquant( + ext_array: &ExtensionArray, + config: &TurboQuantConfig, +) -> VortexResult> { + let storage = ext_array.storage_array(); + let fsl = storage.to_canonical()?.into_fixed_size_list(); + + if fsl.dtype().is_nullable() { + return Ok(None); + } + if fsl.is_empty() { + return Ok(None); + } + + let encoded = turboquant_encode_qjl(&fsl, config)?; + + Ok(Some( + ExtensionArray::new(ext_array.ext_dtype().clone(), encoded).into_array(), + )) +} diff --git a/vortex-file/Cargo.toml b/vortex-file/Cargo.toml index d568328bb52..0752553c1e4 100644 --- a/vortex-file/Cargo.toml +++ b/vortex-file/Cargo.toml @@ -54,6 +54,7 @@ vortex-scan = { workspace = true } vortex-sequence = { workspace = true } vortex-session = { workspace = true } vortex-sparse = { workspace = true } +vortex-turboquant = { workspace = true } vortex-utils = { workspace = true, features = ["dashmap"] } vortex-zigzag = { workspace = true } vortex-zstd = { workspace = true, optional = true } diff --git a/vortex-file/public-api.lock b/vortex-file/public-api.lock index 84cca867cba..ffb19c25fb5 100644 --- a/vortex-file/public-api.lock +++ b/vortex-file/public-api.lock @@ -358,6 +358,8 @@ pub fn vortex_file::WriteStrategyBuilder::with_flat_strategy(self, flat: alloc:: pub fn vortex_file::WriteStrategyBuilder::with_row_block_size(self, row_block_size: usize) -> Self +pub fn vortex_file::WriteStrategyBuilder::with_vector_quantization(self, config: vortex_turboquant::compress::TurboQuantConfig) -> Self + impl core::default::Default for vortex_file::WriteStrategyBuilder pub fn vortex_file::WriteStrategyBuilder::default() -> Self diff --git a/vortex-file/src/lib.rs b/vortex-file/src/lib.rs index d888eb88def..b99ba26d9e9 100644 --- a/vortex-file/src/lib.rs +++ b/vortex-file/src/lib.rs @@ -178,4 +178,5 @@ pub fn register_default_encodings(session: 
&mut VortexSession) { vortex_fastlanes::initialize(session); vortex_runend::initialize(session); vortex_sequence::initialize(session); + vortex_turboquant::initialize(session); } diff --git a/vortex-file/src/strategy.rs b/vortex-file/src/strategy.rs index 4d6031a220c..61385791b84 100644 --- a/vortex-file/src/strategy.rs +++ b/vortex-file/src/strategy.rs @@ -28,7 +28,6 @@ use vortex_array::arrays::VarBinView; use vortex_array::dtype::FieldPath; use vortex_array::session::ArrayRegistry; use vortex_array::session::ArraySession; -#[cfg(feature = "zstd")] use vortex_btrblocks::BtrBlocksCompressorBuilder; #[cfg(feature = "zstd")] use vortex_btrblocks::FloatCode; @@ -61,6 +60,7 @@ use vortex_pco::Pco; use vortex_runend::RunEnd; use vortex_sequence::Sequence; use vortex_sparse::Sparse; +use vortex_turboquant::TurboQuant; use vortex_utils::aliases::hash_map::HashMap; use vortex_zigzag::ZigZag; #[cfg(feature = "zstd")] @@ -110,6 +110,7 @@ pub static ALLOWED_ENCODINGS: LazyLock = LazyLock::new(|| { session.register(Sequence); session.register(Sparse); session.register(ZigZag); + session.register(TurboQuant); #[cfg(feature = "zstd")] session.register(Zstd); @@ -126,6 +127,7 @@ pub static ALLOWED_ENCODINGS: LazyLock = LazyLock::new(|| { /// bulk decoding performance, and IOPS required to perform an indexed read. pub struct WriteStrategyBuilder { compressor: Option>, + turboquant_config: Option, row_block_size: usize, field_writers: HashMap>, allow_encodings: Option, @@ -138,6 +140,7 @@ impl Default for WriteStrategyBuilder { fn default() -> Self { Self { compressor: None, + turboquant_config: None, row_block_size: 8192, field_writers: HashMap::new(), allow_encodings: Some(ALLOWED_ENCODINGS.clone()), @@ -231,6 +234,29 @@ impl WriteStrategyBuilder { self } + /// Configure lossy vector quantization for tensor columns using TurboQuant. + /// + /// Columns with `Vector` or `FixedShapeTensor` extension types will be quantized at the + /// specified bit-width. 
All other columns use the default BtrBlocks compression strategy. + /// The TurboQuant array's children (norms, codes) are recursively compressed by the + /// BtrBlocks compressor. + /// + /// This can be combined with other builder methods. If a custom compressor is also set + /// via [`with_compressor`](Self::with_compressor), the custom compressor takes precedence + /// and the TurboQuant config is ignored. + /// + /// # Examples + /// + /// ```ignore + /// WriteStrategyBuilder::default() + /// .with_vector_quantization(TurboQuantConfig { bit_width: 3, seed: None }) + /// .build() + /// ``` + pub fn with_vector_quantization(mut self, config: vortex_turboquant::TurboQuantConfig) -> Self { + self.turboquant_config = Some(config); + self + } + /// Builds the canonical [`LayoutStrategy`] implementation, with the configured overrides /// applied. pub fn build(self) -> Arc { @@ -249,6 +275,14 @@ impl WriteStrategyBuilder { // 5. compress each chunk let compressing = if let Some(ref compressor) = self.compressor { CompressingStrategy::new_opaque(buffered, compressor.clone()) + } else if let Some(tq_config) = self.turboquant_config { + let btrblocks = BtrBlocksCompressorBuilder::default() + .with_turboquant(tq_config) + .build(); + CompressingStrategy::new_opaque( + buffered, + Arc::new(btrblocks) as Arc, + ) } else { CompressingStrategy::new_btrblocks(buffered, true) }; diff --git a/vortex/Cargo.toml b/vortex/Cargo.toml index d8dc89882b0..d41db760117 100644 --- a/vortex/Cargo.toml +++ b/vortex/Cargo.toml @@ -44,6 +44,7 @@ vortex-scan = { workspace = true } vortex-sequence = { workspace = true } vortex-session = { workspace = true } vortex-sparse = { workspace = true } +vortex-turboquant = { workspace = true } vortex-utils = { workspace = true } vortex-zigzag = { workspace = true } vortex-zstd = { workspace = true, optional = true } @@ -54,8 +55,10 @@ arrow-array = { workspace = true } divan = { workspace = true } fastlanes = { workspace = true } mimalloc = { workspace 
= true } +paste = { workspace = true } parquet = { workspace = true } rand = { workspace = true } +rand_distr = { workspace = true } serde_json = { workspace = true } tokio = { workspace = true, features = ["full"] } tracing = { workspace = true } diff --git a/vortex/benches/single_encoding_throughput.rs b/vortex/benches/single_encoding_throughput.rs index 4776afa4a52..7e46b22322f 100644 --- a/vortex/benches/single_encoding_throughput.rs +++ b/vortex/benches/single_encoding_throughput.rs @@ -11,16 +11,19 @@ use divan::Bencher; #[cfg(not(codspeed))] use divan::counter::BytesCount; use mimalloc::MiMalloc; +use paste::paste; use rand::RngExt; use rand::SeedableRng; use rand::prelude::IndexedRandom; use rand::rngs::StdRng; use vortex::array::IntoArray; use vortex::array::ToCanonical; +use vortex::array::arrays::FixedSizeListArray; use vortex::array::arrays::PrimitiveArray; use vortex::array::arrays::VarBinViewArray; use vortex::array::builders::dict::dict_encode; use vortex::array::builtins::ArrayBuiltins; +use vortex::array::validity::Validity; use vortex::dtype::PType; use vortex::encodings::alp::RDEncoder; use vortex::encodings::alp::alp_encode; @@ -32,11 +35,15 @@ use vortex::encodings::fsst::fsst_train_compressor; use vortex::encodings::pco::PcoArray; use vortex::encodings::runend::RunEndArray; use vortex::encodings::sequence::sequence_encode; +use vortex::encodings::turboquant::TurboQuantConfig; +use vortex::encodings::turboquant::turboquant_encode_mse; +use vortex::encodings::turboquant::turboquant_encode_qjl; use vortex::encodings::zigzag::zigzag_encode; use vortex::encodings::zstd::ZstdArray; use vortex_array::VortexSessionExecute; use vortex_array::dtype::Nullability; use vortex_array::session::ArraySession; +use vortex_buffer::BufferMut; use vortex_sequence::SequenceArray; use vortex_session::VortexSession; @@ -405,3 +412,108 @@ fn bench_zstd_decompress_string(bencher: Bencher) { .with_inputs(|| &compressed) .bench_refs(|a| a.to_canonical()); } + +// 
TurboQuant vector quantization benchmarks + +const NUM_VECTORS: usize = 1_000; + +/// Generate `num_vectors` random f32 vectors of the given dimension using i.i.d. +/// standard normal components. This is a conservative test distribution: real +/// neural network embeddings typically have structure (clustered, anisotropic) +/// that the SRHT exploits for better quantization, so Gaussian i.i.d. is a +/// worst-case baseline for TurboQuant. +fn setup_vector_fsl(dim: usize) -> FixedSizeListArray { + let mut rng = StdRng::seed_from_u64(42); + let normal = rand_distr::Normal::new(0.0f32, 1.0).unwrap(); + + let mut buf = BufferMut::::with_capacity(NUM_VECTORS * dim); + for _ in 0..(NUM_VECTORS * dim) { + buf.push(rand_distr::Distribution::sample(&normal, &mut rng)); + } + + let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); + FixedSizeListArray::try_new( + elements.into_array(), + dim as u32, + Validity::NonNullable, + NUM_VECTORS, + ) + .unwrap() +} + +fn turboquant_config(bit_width: u8) -> TurboQuantConfig { + TurboQuantConfig { + bit_width, + seed: Some(123), + } +} + +macro_rules! turboquant_bench { + (compress, $dim:literal, $bits:literal, $name:ident) => { + paste! 
{ + #[divan::bench(name = concat!("turboquant_compress_dim", stringify!($dim), "_", stringify!($bits), "bit_mse"))] + fn [<$name _mse>](bencher: Bencher) { + let fsl = setup_vector_fsl($dim); + let config = turboquant_config($bits); + with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) + .with_inputs(|| &fsl) + .bench_refs(|a| turboquant_encode_mse(a, &config).unwrap()); + } + + #[divan::bench(name = concat!("turboquant_compress_dim", stringify!($dim), "_", stringify!($bits), "bit_qjl"))] + fn [<$name _qjl>](bencher: Bencher) { + let fsl = setup_vector_fsl($dim); + let config = turboquant_config($bits); + with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) + .with_inputs(|| &fsl) + .bench_refs(|a| turboquant_encode_qjl(a, &config).unwrap()); + } + } + }; + (decompress, $dim:literal, $bits:literal, $name:ident) => { + paste! { + #[divan::bench(name = concat!("turboquant_decompress_dim", stringify!($dim), "_", stringify!($bits), "bit_mse"))] + fn [<$name _mse>](bencher: Bencher) { + let fsl = setup_vector_fsl($dim); + let config = turboquant_config($bits); + let compressed = turboquant_encode_mse(&fsl, &config).unwrap(); + with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) + .with_inputs(|| &compressed) + .bench_refs(|a| { + let mut ctx = SESSION.create_execution_ctx(); + a.clone() + .into_array() + .execute::(&mut ctx) + .unwrap() + }); + } + + #[divan::bench(name = concat!("turboquant_decompress_dim", stringify!($dim), "_", stringify!($bits), "bit_qjl"))] + fn [<$name _qjl>](bencher: Bencher) { + let fsl = setup_vector_fsl($dim); + let config = turboquant_config($bits); + let compressed = turboquant_encode_qjl(&fsl, &config).unwrap(); + with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) + .with_inputs(|| &compressed) + .bench_refs(|a| { + let mut ctx = SESSION.create_execution_ctx(); + a.clone() + .into_array() + .execute::(&mut ctx) + .unwrap() + }); + } + } + }; +} + +turboquant_bench!(compress, 128, 4, bench_tq_compress_128_4); 
+turboquant_bench!(decompress, 128, 4, bench_tq_decompress_128_4); +turboquant_bench!(compress, 768, 4, bench_tq_compress_768_4); +turboquant_bench!(decompress, 768, 4, bench_tq_decompress_768_4); +turboquant_bench!(compress, 1024, 2, bench_tq_compress_1024_2); +turboquant_bench!(decompress, 1024, 2, bench_tq_decompress_1024_2); +turboquant_bench!(compress, 1024, 4, bench_tq_compress_1024_4); +turboquant_bench!(decompress, 1024, 4, bench_tq_decompress_1024_4); +turboquant_bench!(compress, 1024, 8, bench_tq_compress_1024_8); +turboquant_bench!(decompress, 1024, 8, bench_tq_decompress_1024_8); diff --git a/vortex/public-api.lock b/vortex/public-api.lock index 0c8ce9d0cd9..325812fafc4 100644 --- a/vortex/public-api.lock +++ b/vortex/public-api.lock @@ -74,6 +74,10 @@ pub mod vortex::encodings::sparse pub use vortex::encodings::sparse::<> +pub mod vortex::encodings::turboquant + +pub use vortex::encodings::turboquant::<> + pub mod vortex::encodings::zigzag pub use vortex::encodings::zigzag::<> diff --git a/vortex/src/lib.rs b/vortex/src/lib.rs index a532fc1adad..454886077c3 100644 --- a/vortex/src/lib.rs +++ b/vortex/src/lib.rs @@ -143,6 +143,10 @@ pub mod encodings { pub use vortex_sparse::*; } + pub mod turboquant { + pub use vortex_turboquant::*; + } + pub mod zigzag { pub use vortex_zigzag::*; }