diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index e7baf491fef..5485e8a08d3 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -2,6 +2,122 @@ pub mod vortex_tensor pub mod vortex_tensor::encodings +pub mod vortex_tensor::encodings::norm + +pub struct vortex_tensor::encodings::norm::NormVector + +impl core::clone::Clone for vortex_tensor::encodings::norm::NormVector + +pub fn vortex_tensor::encodings::norm::NormVector::clone(&self) -> vortex_tensor::encodings::norm::NormVector + +impl core::fmt::Debug for vortex_tensor::encodings::norm::NormVector + +pub fn vortex_tensor::encodings::norm::NormVector::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl vortex_array::vtable::VTable for vortex_tensor::encodings::norm::NormVector + +pub type vortex_tensor::encodings::norm::NormVector::Array = vortex_tensor::encodings::norm::NormVectorArray + +pub type vortex_tensor::encodings::norm::NormVector::Metadata = vortex_array::metadata::EmptyMetadata + +pub type vortex_tensor::encodings::norm::NormVector::OperationsVTable = vortex_tensor::encodings::norm::NormVector + +pub type vortex_tensor::encodings::norm::NormVector::ValidityVTable = vortex_array::vtable::validity::ValidityVTableFromChild + +pub fn vortex_tensor::encodings::norm::NormVector::array_eq(array: &vortex_tensor::encodings::norm::NormVectorArray, other: &vortex_tensor::encodings::norm::NormVectorArray, precision: vortex_array::hash::Precision) -> bool + +pub fn vortex_tensor::encodings::norm::NormVector::array_hash(array: &vortex_tensor::encodings::norm::NormVectorArray, state: &mut H, precision: vortex_array::hash::Precision) + +pub fn vortex_tensor::encodings::norm::NormVector::buffer(_array: &vortex_tensor::encodings::norm::NormVectorArray, idx: usize) -> vortex_array::buffer::BufferHandle + +pub fn vortex_tensor::encodings::norm::NormVector::buffer_name(_array: &vortex_tensor::encodings::norm::NormVectorArray, idx: usize) -> core::option::Option + +pub fn vortex_tensor::encodings::norm::NormVector::build(dtype: &vortex_array::dtype::DType, len: usize, _metadata: &Self::Metadata, _buffers: &[vortex_array::buffer::BufferHandle], children: &dyn vortex_array::serde::ArrayChildren) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVector::child(array: &vortex_tensor::encodings::norm::NormVectorArray, idx: usize) -> vortex_array::array::ArrayRef + +pub fn vortex_tensor::encodings::norm::NormVector::child_name(_array: &vortex_tensor::encodings::norm::NormVectorArray, idx: usize) -> alloc::string::String + +pub fn vortex_tensor::encodings::norm::NormVector::deserialize(_bytes: &[u8], _dtype: &vortex_array::dtype::DType, _len: usize, _buffers: &[vortex_array::buffer::BufferHandle], _session: &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVector::dtype(array: &vortex_tensor::encodings::norm::NormVectorArray) -> &vortex_array::dtype::DType + +pub fn vortex_tensor::encodings::norm::NormVector::execute(array: alloc::sync::Arc>, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVector::id(&self) -> vortex_array::vtable::dyn_::ArrayId + +pub fn vortex_tensor::encodings::norm::NormVector::len(array: &vortex_tensor::encodings::norm::NormVectorArray) -> usize + +pub fn vortex_tensor::encodings::norm::NormVector::metadata(_array: &vortex_tensor::encodings::norm::NormVectorArray) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVector::nbuffers(_array: &vortex_tensor::encodings::norm::NormVectorArray) -> usize + +pub fn vortex_tensor::encodings::norm::NormVector::nchildren(_array: &vortex_tensor::encodings::norm::NormVectorArray) -> usize + +pub fn vortex_tensor::encodings::norm::NormVector::serialize(_metadata: Self::Metadata) -> vortex_error::VortexResult>> + +pub fn vortex_tensor::encodings::norm::NormVector::slot_name(_array: &Self::Array, idx: usize) -> alloc::string::String + +pub fn vortex_tensor::encodings::norm::NormVector::slots(array: &Self::Array) -> &[core::option::Option] + +pub fn vortex_tensor::encodings::norm::NormVector::stats(array: &vortex_tensor::encodings::norm::NormVectorArray) -> vortex_array::stats::array::StatsSetRef<'_> + +pub fn vortex_tensor::encodings::norm::NormVector::vtable(_array: &Self::Array) -> &Self + +pub fn vortex_tensor::encodings::norm::NormVector::with_slots(array: &mut Self::Array, slots: alloc::vec::Vec>) -> vortex_error::VortexResult<()> + +impl vortex_array::vtable::operations::OperationsVTable for vortex_tensor::encodings::norm::NormVector + +pub fn vortex_tensor::encodings::norm::NormVector::scalar_at(array: &vortex_tensor::encodings::norm::NormVectorArray, index: usize, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +impl vortex_array::vtable::validity::ValidityChild for vortex_tensor::encodings::norm::NormVector + +pub fn vortex_tensor::encodings::norm::NormVector::validity_child(array: &vortex_tensor::encodings::norm::NormVectorArray) -> &vortex_array::array::ArrayRef + +pub struct vortex_tensor::encodings::norm::NormVectorArray + +impl vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::compress(vector_array: vortex_array::array::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVectorArray::decompress(&self, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVectorArray::norms(&self) -> &vortex_array::array::ArrayRef + +pub fn vortex_tensor::encodings::norm::NormVectorArray::try_new(vector_array: vortex_array::array::ArrayRef, norms: vortex_array::array::ArrayRef) -> vortex_error::VortexResult + +pub fn vortex_tensor::encodings::norm::NormVectorArray::vector_array(&self) -> &vortex_array::array::ArrayRef + +impl vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::to_array(&self) -> vortex_array::array::ArrayRef + +impl core::clone::Clone for vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::clone(&self) -> vortex_tensor::encodings::norm::NormVectorArray + +impl core::convert::AsRef for vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::as_ref(&self) -> &dyn vortex_array::array::DynArray + +impl core::convert::From for vortex_array::array::ArrayRef + +pub fn vortex_array::array::ArrayRef::from(value: vortex_tensor::encodings::norm::NormVectorArray) -> vortex_array::array::ArrayRef + +impl core::fmt::Debug for vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::ops::deref::Deref for vortex_tensor::encodings::norm::NormVectorArray + +pub type vortex_tensor::encodings::norm::NormVectorArray::Target = dyn vortex_array::array::DynArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::deref(&self) -> &Self::Target + +impl vortex_array::array::IntoArray for vortex_tensor::encodings::norm::NormVectorArray + +pub fn vortex_tensor::encodings::norm::NormVectorArray::into_array(self) -> vortex_array::array::ArrayRef + pub mod vortex_tensor::fixed_shape pub struct vortex_tensor::fixed_shape::FixedShapeTensor diff --git a/vortex-tensor/src/encodings/mod.rs b/vortex-tensor/src/encodings/mod.rs index 090151e9226..616c48fb458 100644 --- a/vortex-tensor/src/encodings/mod.rs +++ b/vortex-tensor/src/encodings/mod.rs @@ -4,7 +4,7 @@ //! Encodings for the different tensor types. // TODO(connor): -// pub mod norm; // Unit-normalized vectors. +pub mod norm; // Unit-normalized vectors. // pub mod spherical; // Spherical transform on unit-normalized vectors. // TODO(will): diff --git a/vortex-tensor/src/encodings/norm/array.rs b/vortex-tensor/src/encodings/norm/array.rs new file mode 100644 index 00000000000..81176fd06a7 --- /dev/null +++ b/vortex-tensor/src/encodings/norm/array.rs @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use num_traits::Float; +use num_traits::Zero; +use vortex::array::ArrayRef; +use vortex::array::ExecutionCtx; +use vortex::array::IntoArray; +use vortex::array::LEGACY_SESSION; +use vortex::array::VortexSessionExecute; +use vortex::array::arrays::ExtensionArray; +use vortex::array::arrays::FixedSizeListArray; +use vortex::array::arrays::PrimitiveArray; +use vortex::array::builtins::ArrayBuiltins; +use vortex::array::match_each_float_ptype; +use vortex::array::stats::ArrayStats; +use vortex::array::validity::Validity; +use vortex::dtype::DType; +use vortex::dtype::Nullability; +use vortex::dtype::extension::ExtDType; +use vortex::dtype::extension::ExtDTypeRef; +use vortex::encodings::runend::RunEndArray; +use vortex::encodings::sequence::SequenceArray; +use vortex::error::VortexExpect; +use vortex::error::VortexResult; +use vortex::error::vortex_ensure; +use vortex::error::vortex_ensure_eq; +use vortex::error::vortex_err; +use vortex::expr::Expression; +use vortex::expr::root; +use vortex::extension::EmptyMetadata; +use vortex::scalar::PValue; +use vortex::scalar_fn::ScalarFn; +use vortex::scalar_fn::fns::operators::Operator; + +use crate::encodings::norm::vtable::NORMS_SLOT; +use crate::encodings::norm::vtable::VECTORS_SLOT; +use crate::scalar_fns::ApproxOptions; +use crate::scalar_fns::l2_norm::L2Norm; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; +use crate::vector::Vector; + +/// A normalized array that stores unit-normalized vectors alongside their original L2 norms. +/// +/// Each vector in the array is divided by its L2 norm, producing a unit-normalized vector. The +/// original norms are stored separately so that the original vectors can be reconstructed. +/// +/// The `vector_array` child carries its own validity and nullability, so a nullable input vector +/// array produces a nullable `NormVectorArray`. +#[derive(Debug, Clone)] +pub struct NormVectorArray { + /// The slots of the child arrays, which are vectors and norms. + /// + /// The vector array is the backing vector array that has been unit normalized, and the norm + /// array is the L2 norms of each vector. + /// + /// The underlying elements of the vector array must be floating-point. This child may be + /// nullable, and its validity determines the validity of the `NormVectorArray`. + /// + /// This must have the same validity as the vector array, and the same dtype as the elements of + /// the vector array. + pub(super) slots: Vec>, + + /// Stats set owned by this array. + pub(super) stats_set: ArrayStats, +} + +impl NormVectorArray { + /// Creates a new [`NormVectorArray`] from a unit-normalized vector array and associated L2 + /// norms for each vector. + /// + /// The `vector_array` must be a [`Vector`] extension array with floating-point elements, and + /// `norms` must be a primitive array of the same float type with the same length. The + /// `vector_array` may be nullable. + pub fn try_new(vector_array: ArrayRef, norms: ArrayRef) -> VortexResult { + let ext = Self::validate(&vector_array, &norms)?; + + let element_ptype = extension_element_ptype(&ext)?; + + let nullability = Nullability::from(vector_array.dtype().is_nullable()); + let expected_norms_dtype = DType::Primitive(element_ptype, nullability); + vortex_ensure_eq!( + *norms.dtype(), + expected_norms_dtype, + "norms dtype must match vector element type" + ); + + vortex_ensure_eq!( + vector_array.len(), + norms.len(), + "vector_array and norms must have the same length" + ); + + let slots = vec![Some(vector_array), Some(norms)]; + + Ok(Self { + slots, + stats_set: ArrayStats::default(), + }) + } + + /// Validates that the given array has the [`Vector`] extension type and returns the + /// [`ExtDTypeRef`] of the vector array on success. + fn validate_vector_array(vector_array: &ArrayRef) -> VortexResult { + let ext = vector_array.dtype().as_extension_opt().ok_or_else(|| { + vortex_err!( + "vector_array dtype must be an extension type, got {}", + vector_array.dtype() + ) + })?; + + vortex_ensure!( + ext.is::(), + "vector_array must have the Vector extension type, got {}", + vector_array.dtype() + ); + + Ok(ext.clone()) + } + + /// Validates that the given `vector_array` and `norms` are compatible. + /// + /// Checks that: + /// - The `vector_array` has the [`Vector`] extension type. + /// - Both arrays have the same length. + /// - The element primitive type of the vectors matches the primitive type of the norms. + /// - Both arrays share the same validity mask. + /// + /// Returns the [`ExtDTypeRef`] of the vector array on success. + fn validate(vector_array: &ArrayRef, norms: &ArrayRef) -> VortexResult { + let ext = Self::validate_vector_array(vector_array)?; + + vortex_ensure_eq!( + vector_array.len(), + norms.len(), + "vector_array and norms must have the same length" + ); + + let element_ptype = extension_element_ptype(&ext)?; + vortex_ensure_eq!( + element_ptype, + norms.dtype().as_ptype(), + "vector elements ptype must be the same as the norms ptype" + ); + + // TODO(connor): Is there a better way to do this? + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let mask_eq = vector_array + .validity()? + .mask_eq(&norms.validity()?, &mut ctx)?; + vortex_ensure!( + mask_eq, + "vector_array and norms must have the same validity" + ); + + Ok(ext) + } + + /// Returns a reference to the backing vector array that has been unit normalized. + pub fn vector_array(&self) -> &ArrayRef { + self.slots[VECTORS_SLOT] + .as_ref() + .vortex_expect("vector_array slot must be present") + } + + /// Returns a reference to the L2 norms of each vector. + pub fn norms(&self) -> &ArrayRef { + self.slots[NORMS_SLOT] + .as_ref() + .vortex_expect("norms slot must be present") + } + + /// Encodes a [`Vector`] extension array into a [`NormVectorArray`] by computing L2 norms and + /// dividing each vector by its norm. + /// + /// The input must be a [`Vector`] extension array with floating-point elements. Nullable inputs + /// are supported; the validity mask is preserved and the normalized data for null rows is + /// unspecified. + /// + /// Note that compression is lossy per floating-point operations. + pub fn compress(vector_array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let ext = Self::validate_vector_array(&vector_array)?; + + let list_size = extension_list_size(&ext)? as usize; + let row_count = vector_array.len(); + let nullability = Nullability::from(vector_array.dtype().is_nullable()); + let validity = vector_array.validity()?; + + // Compute L2 norms using the scalar function. If the input is nullable, the norms will + // also be nullable (null vectors produce null norms). + let storage = extension_storage(&vector_array)?; + let l2_norm_expr = Expression::try_new( + ScalarFn::new(L2Norm, ApproxOptions::Exact).erased(), + [root()], + )?; + let norms_prim: PrimitiveArray = vector_array.apply(&l2_norm_expr)?.execute(ctx)?; + let norms_array = norms_prim.clone().into_array(); + + // Extract flat elements from the (always non-nullable) storage for normalization. + let flat = extract_flat_elements(&storage, list_size, ctx)?; + + match_each_float_ptype!(flat.ptype(), |T| { + let norms_slice = norms_prim.as_slice::(); + + let normalized_elems: PrimitiveArray = (0..row_count) + .map(|i| -> VortexResult> { + if !validity.is_valid(i)? { + return Ok(vec![T::zero(); list_size]); + } + + let inv_norm = safe_inv_norm(norms_slice[i]); + Ok(flat.row::(i).iter().map(|&v| v * inv_norm).collect()) + }) + .collect::>>>()? + .into_iter() + .flatten() + .collect(); + + // Reconstruct the vector array with the same nullability as the input. + let validity = Validity::from(nullability); + let fsl = FixedSizeListArray::new( + normalized_elems.into_array(), + u32::try_from(list_size)?, + validity, + row_count, + ); + + let ext_dtype = + ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); + let normalized_vector = ExtensionArray::new(ext_dtype, fsl.into_array()).into_array(); + + Self::try_new(normalized_vector, norms_array) + }) + } + + /// Reconstructs the original vectors by multiplying each unit-normalized vector by its L2 norm. + /// + /// The returned array has the same dtype (including nullability) as the original + /// `vector_array` child. + pub fn decompress(&self, ctx: &mut ExecutionCtx) -> VortexResult { + let ext = self + .dtype() + .as_extension_opt() + .vortex_expect("somehow had a non-extension dtype"); + + let storage = extension_storage(self.vector_array())?; + let fsl: FixedSizeListArray = storage.execute(ctx)?; + + let denormalized_fsl = + broadcast_binary_to_elements(fsl, self.norms().clone(), Operator::Mul, ctx)?; + + Ok(ExtensionArray::new(ext.clone(), denormalized_fsl.into_array()).into_array()) + } +} + +/// We do not have any kind of "broadcast" expression where we evaluate a binary expression between +/// every `FixedSizeList` element and another value. We can mimic this by creating a +/// `RunEnd(Sequence)` array that we evaluate with the elements of the [`FixedSizeListArray`]. +fn broadcast_binary_to_elements( + fsl: FixedSizeListArray, + values: ArrayRef, + op: Operator, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let num_lists = fsl.len(); + let list_size = fsl.list_size(); + let validity = fsl.validity()?; + let elements = fsl.elements(); + debug_assert!(elements.dtype().is_primitive()); + + // Create the broadcasting array via a runend array with a sequence of ends. + let base: PValue = list_size.into(); + let multiplier: PValue = base; + let ends_ptype = base.ptype(); + let ends_nullability = Nullability::NonNullable; + + let ends = SequenceArray::try_new(base, multiplier, ends_ptype, ends_nullability, num_lists)?; + let runend = RunEndArray::try_new(ends.into_array(), values)?; + + let binary_eval = elements.binary(runend.into_array(), op)?; + let executed: PrimitiveArray = binary_eval.execute(ctx)?; + + // SAFETY: We simply evaluated a scalar function on all of the elements, so none of the length + // properties have changed. + let fsl = unsafe { + FixedSizeListArray::new_unchecked(executed.into_array(), list_size, validity, num_lists) + }; + + Ok(fsl) +} + +/// Returns `1 / norm` if the norm is non-zero, or zero otherwise. +/// +/// This avoids division by zero for zero-length or all-zero vectors. +fn safe_inv_norm(norm: T) -> T { + if norm == T::zero() { + T::zero() + } else { + T::one() / norm + } +} diff --git a/vortex-tensor/src/encodings/norm/mod.rs b/vortex-tensor/src/encodings/norm/mod.rs new file mode 100644 index 00000000000..4a060567bbc --- /dev/null +++ b/vortex-tensor/src/encodings/norm/mod.rs @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +mod array; +pub use array::NormVectorArray; + +// TODO: Compute operations for NormVector. + +mod vtable; +pub use vtable::NormVector; + +#[cfg(test)] +mod tests; diff --git a/vortex-tensor/src/encodings/norm/tests.rs b/vortex-tensor/src/encodings/norm/tests.rs new file mode 100644 index 00000000000..66b3cdb14dc --- /dev/null +++ b/vortex-tensor/src/encodings/norm/tests.rs @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::IntoArray; +use vortex::array::LEGACY_SESSION; +use vortex::array::VortexSessionExecute; +use vortex::array::arrays::Extension; +use vortex::array::arrays::PrimitiveArray; +use vortex::error::VortexResult; + +use crate::encodings::norm::NormVectorArray; +use crate::utils::test_helpers::assert_close; +use crate::utils::test_helpers::extract_vector_rows; +use crate::utils::test_helpers::vector_array; + +#[test] +fn encode_unit_vectors() -> VortexResult<()> { + // Already unit-length vectors: norms should be 1.0 and vectors unchanged. + let arr = vector_array( + 3, + &[ + 1.0, 0.0, 0.0, // norm = 1.0 + 0.0, 1.0, 0.0, // norm = 1.0 + ], + )?; + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let norm = NormVectorArray::compress(arr, &mut ctx)?; + let norms: PrimitiveArray = norm.norms().clone().execute(&mut ctx)?; + assert_close(norms.as_slice::(), &[1.0, 1.0]); + + let rows = extract_vector_rows(norm.vector_array(), &mut ctx)?; + assert_close(&rows[0], &[1.0, 0.0, 0.0]); + assert_close(&rows[1], &[0.0, 1.0, 0.0]); + + Ok(()) +} + +#[test] +fn encode_non_unit_vectors() -> VortexResult<()> { + let arr = vector_array( + 2, + &[ + 3.0, 4.0, // norm = 5.0 + 0.0, 0.0, // norm = 0.0 (zero vector) + ], + )?; + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let norm = NormVectorArray::compress(arr, &mut ctx)?; + let norms: PrimitiveArray = norm.norms().clone().execute(&mut ctx)?; + assert_close(norms.as_slice::(), &[5.0, 0.0]); + + let rows = extract_vector_rows(norm.vector_array(), &mut ctx)?; + assert_close(&rows[0], &[3.0 / 5.0, 4.0 / 5.0]); + assert_close(&rows[1], &[0.0, 0.0]); + + Ok(()) +} + +#[test] +fn execute_round_trip() -> VortexResult<()> { + let arr = vector_array( + 2, + &[ + 3.0, 4.0, // norm = 5.0 + 6.0, 8.0, // norm = 10.0 + ], + )?; + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let norm = NormVectorArray::compress(arr, &mut ctx)?; + + // Execute to reconstruct the original vectors. + let reconstructed = norm.decompress(&mut ctx)?; + + // The reconstructed array should be a Vector extension array. + assert!(reconstructed.as_opt::().is_some()); + + let rows = extract_vector_rows(&reconstructed, &mut ctx)?; + assert_close(&rows[0], &[3.0, 4.0]); + assert_close(&rows[1], &[6.0, 8.0]); + + Ok(()) +} + +#[test] +fn execute_round_trip_zero_vector() -> VortexResult<()> { + let arr = vector_array(2, &[0.0, 0.0])?; + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let norm = NormVectorArray::compress(arr, &mut ctx)?; + + let reconstructed = norm.decompress(&mut ctx)?; + + let rows = extract_vector_rows(&reconstructed, &mut ctx)?; + // Zero vector should remain zero after round-trip. + assert_close(&rows[0], &[0.0, 0.0]); + + Ok(()) +} + +#[test] +fn scalar_at_returns_original_vector() -> VortexResult<()> { + let arr = vector_array( + 2, + &[ + 3.0, 4.0, // norm = 5.0 + 6.0, 8.0, // norm = 10.0 + ], + )?; + + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let encoded = NormVectorArray::compress(arr, &mut ctx)?; + + // `scalar_at` on the NormVectorArray should match `scalar_at` on the decompressed result. + let decompressed = encoded.decompress(&mut ctx)?; + + let norm_array = encoded.into_array(); + for i in 0..2 { + assert_eq!(norm_array.scalar_at(i)?, decompressed.scalar_at(i)?); + } + + Ok(()) +} diff --git a/vortex-tensor/src/encodings/norm/vtable/mod.rs b/vortex-tensor/src/encodings/norm/vtable/mod.rs new file mode 100644 index 00000000000..d88de77d9e9 --- /dev/null +++ b/vortex-tensor/src/encodings/norm/vtable/mod.rs @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::hash::Hasher; +use std::sync::Arc; + +use vortex::array::ArrayEq; +use vortex::array::ArrayHash; +use vortex::array::ArrayRef; +use vortex::array::EmptyMetadata; +use vortex::array::ExecutionCtx; +use vortex::array::ExecutionResult; +use vortex::array::Precision; +use vortex::array::buffer::BufferHandle; +use vortex::array::serde::ArrayChildren; +use vortex::array::stats::StatsSetRef; +use vortex::array::vtable; +use vortex::array::vtable::Array; +use vortex::array::vtable::ArrayId; +use vortex::array::vtable::VTable; +use vortex::array::vtable::ValidityVTableFromChild; +use vortex::dtype::DType; +use vortex::dtype::Nullability; +use vortex::error::VortexResult; +use vortex::error::vortex_ensure_eq; +use vortex::error::vortex_err; +use vortex::error::vortex_panic; +use vortex::session::VortexSession; + +use crate::encodings::norm::array::NormVectorArray; +use crate::utils::extension_element_ptype; + +mod operations; +mod validity; + +pub(super) const VECTORS_SLOT: usize = 0; +pub(super) const NORMS_SLOT: usize = 1; +pub(super) const NUM_SLOTS: usize = 2; +pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["vectors", "norms"]; + +vtable!(NormVector); + +#[derive(Debug, Clone)] +pub struct NormVector; + +impl VTable for NormVector { + type Array = NormVectorArray; + type Metadata = EmptyMetadata; + type OperationsVTable = Self; + type ValidityVTable = ValidityVTableFromChild; + + fn vtable(_array: &Self::Array) -> &Self { + &NormVector + } + + fn id(&self) -> ArrayId { + ArrayId::new_ref("vortex.tensor.norm_vector") + } + + fn len(array: &NormVectorArray) -> usize { + array.vector_array().len() + } + + fn dtype(array: &NormVectorArray) -> &DType { + array.vector_array().dtype() + } + + fn stats(array: &NormVectorArray) -> StatsSetRef<'_> { + array.stats_set.to_ref(array.as_ref()) + } + + fn array_hash(array: &NormVectorArray, state: &mut H, precision: Precision) { + array.vector_array().array_hash(state, precision); + array.norms().array_hash(state, precision); + } + + fn array_eq(array: &NormVectorArray, other: &NormVectorArray, precision: Precision) -> bool { + array.norms().array_eq(other.norms(), precision) + && array + .vector_array() + .array_eq(other.vector_array(), precision) + } + + fn nbuffers(_array: &NormVectorArray) -> usize { + 0 + } + + fn buffer(_array: &NormVectorArray, idx: usize) -> BufferHandle { + vortex_panic!("NormVectorArray has no buffers (index {idx})") + } + + fn buffer_name(_array: &NormVectorArray, idx: usize) -> Option { + vortex_panic!("NormVectorArray has no buffers (index {idx})") + } + + fn nchildren(_array: &NormVectorArray) -> usize { + 2 + } + + fn child(array: &NormVectorArray, idx: usize) -> ArrayRef { + match idx { + 0 => array.vector_array().clone(), + 1 => array.norms().clone(), + _ => vortex_panic!("NormVectorArray child index {idx} out of bounds"), + } + } + + fn child_name(_array: &NormVectorArray, idx: usize) -> String { + match idx { + 0 => "vector_array".to_string(), + 1 => "norms".to_string(), + _ => vortex_panic!("NormVectorArray child_name index {idx} out of bounds"), + } + } + + fn metadata(_array: &NormVectorArray) -> VortexResult { + Ok(EmptyMetadata) + } + + fn serialize(_metadata: Self::Metadata) -> VortexResult>> { + Ok(Some(vec![])) + } + + fn deserialize( + _bytes: &[u8], + _dtype: &DType, + _len: usize, + _buffers: &[BufferHandle], + _session: &VortexSession, + ) -> VortexResult { + Ok(EmptyMetadata) + } + + fn build( + dtype: &DType, + len: usize, + _metadata: &Self::Metadata, + _buffers: &[BufferHandle], + children: &dyn ArrayChildren, + ) -> VortexResult { + vortex_ensure_eq!( + children.len(), + 2, + "NormVectorArray requires exactly 2 children" + ); + + let vector_array = children.get(0, dtype, len)?; + + let ext = dtype.as_extension_opt().ok_or_else(|| { + vortex_err!("NormVectorArray dtype must be an extension type, got {dtype}") + })?; + let element_ptype = extension_element_ptype(ext)?; + let nullability = Nullability::from(dtype.is_nullable()); + let norms_dtype = DType::Primitive(element_ptype, nullability); + let norms = children.get(1, &norms_dtype, len)?; + + NormVectorArray::try_new(vector_array, norms) + } + + fn slots(array: &Self::Array) -> &[Option] { + &array.slots + } + + fn slot_name(_array: &Self::Array, idx: usize) -> String { + SLOT_NAMES[idx].to_string() + } + + fn with_slots(array: &mut Self::Array, slots: Vec>) -> VortexResult<()> { + vortex_ensure_eq!( + slots.len(), + NUM_SLOTS, + "FixedSizeListArray expects exactly {NUM_SLOTS} slots, got {}", + slots.len() + ); + + array.slots = slots; + Ok(()) + } + + fn execute( + array: Arc>, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + Ok(ExecutionResult::done(array.decompress(ctx)?)) + } +} diff --git a/vortex-tensor/src/encodings/norm/vtable/operations.rs b/vortex-tensor/src/encodings/norm/vtable/operations.rs new file mode 100644 index 00000000000..94d27f26e84 --- /dev/null +++ b/vortex-tensor/src/encodings/norm/vtable/operations.rs @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::Canonical; +use vortex::array::ExecutionCtx; +use vortex::array::IntoArray; +use vortex::array::arrays::ConstantArray; +use vortex::array::builtins::ArrayBuiltins; +use vortex::array::vtable::OperationsVTable; +use vortex::dtype::Nullability; +use vortex::error::VortexResult; +use vortex::error::vortex_err; +use vortex::scalar::Scalar; +use vortex::scalar_fn::fns::operators::Operator; + +use crate::encodings::norm::NormVectorArray; +use crate::encodings::norm::vtable::NormVector; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; + +impl OperationsVTable for NormVector { + fn scalar_at( + array: &NormVectorArray, + index: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let ext_dtype = array + .vector_array() + .dtype() + .as_extension_opt() + .ok_or_else(|| { + vortex_err!( + "expected Vector extension dtype, got {}", + array.vector_array().dtype() + ) + })?; + let list_size = extension_list_size(ext_dtype)? as usize; + + // Get the storage (FixedSizeList) and slice out the elements for this row. + let storage = extension_storage(array.vector_array())?; + let fsl = storage.execute::(ctx)?.into_fixed_size_list(); + let row_elements = fsl.fixed_size_list_elements_at(index)?; + + // Multiply all elements by the norm using a ConstantArray broadcast. + let norm_scalar = array.norms().scalar_at(index)?; + let norm_broadcast = ConstantArray::new(norm_scalar, list_size).into_array(); + let scaled = row_elements.binary(norm_broadcast, Operator::Mul)?; + + // Rebuild the FSL scalar, then wrap in the extension type. + let element_dtype = ext_dtype + .storage_dtype() + .as_fixed_size_list_element_opt() + .ok_or_else(|| { + vortex_err!( + "expected FixedSizeList storage dtype, got {}", + ext_dtype.storage_dtype() + ) + })?; + + let children: Vec = (0..list_size) + .map(|i| scaled.scalar_at(i)) + .collect::>()?; + + let fsl_scalar = + Scalar::fixed_size_list(element_dtype.clone(), children, Nullability::NonNullable); + + Ok(Scalar::extension_ref(ext_dtype.clone(), fsl_scalar)) + } +} diff --git a/vortex-tensor/src/encodings/norm/vtable/validity.rs b/vortex-tensor/src/encodings/norm/vtable/validity.rs new file mode 100644 index 00000000000..8925ffc7378 --- /dev/null +++ b/vortex-tensor/src/encodings/norm/vtable/validity.rs @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex::array::ArrayRef; +use vortex::array::vtable::ValidityChild; + +use crate::encodings::norm::array::NormVectorArray; +use crate::encodings::norm::vtable::NormVector; + +impl ValidityChild for NormVector { + fn validity_child(array: &NormVectorArray) -> &ArrayRef { + array.vector_array() + } +} diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index 82e2d1f5b45..88bacf5d5ae 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -221,7 +221,6 @@ pub mod test_helpers { Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } - #[expect(dead_code, reason = "TODO(connor): Use this!")] /// Extracts the f64 rows from a [`Vector`] extension array. /// /// Returns a `Vec>` where each inner vec is one vector's elements.