diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 18107f84db4..151cf5167da 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -1,5 +1,7 @@ pub mod vortex_tensor +pub mod vortex_tensor::encodings + pub mod vortex_tensor::fixed_shape pub struct vortex_tensor::fixed_shape::FixedShapeTensor @@ -136,7 +138,7 @@ pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::arity(&se pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName -pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result @@ -166,7 +168,7 @@ pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::arity(&self, _options: &Self: pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName -pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: 
&mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result diff --git a/vortex-tensor/src/encodings/mod.rs b/vortex-tensor/src/encodings/mod.rs new file mode 100644 index 00000000000..090151e9226 --- /dev/null +++ b/vortex-tensor/src/encodings/mod.rs @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Encodings for the different tensor types. + +// TODO(connor): +// pub mod norm; // Unit-normalized vectors. +// pub mod spherical; // Spherical transform on unit-normalized vectors. + +// TODO(will): +// pub mod turboquant; diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 56e96488167..c036b9854b2 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -5,8 +5,12 @@ //! including unit vectors, spherical coordinates, and similarity measures such as cosine //! similarity. 
+pub mod matcher; +pub mod scalar_fns; + pub mod fixed_shape; pub mod vector; -pub mod matcher; -pub mod scalar_fns; +pub mod encodings; + +mod utils; diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index cd2f158d719..e32c6dade9f 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -28,10 +28,10 @@ use vortex::scalar_fn::ScalarFnId; use vortex::scalar_fn::ScalarFnVTable; use crate::matcher::AnyTensor; -use crate::scalar_fns::utils::extension_element_ptype; -use crate::scalar_fns::utils::extension_list_size; -use crate::scalar_fns::utils::extension_storage; -use crate::scalar_fns::utils::extract_flat_elements; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; /// Cosine similarity between two columns. /// @@ -115,7 +115,7 @@ impl ScalarFnVTable for CosineSimilarity { &self, _options: &Self::Options, args: &dyn ExecutionArgs, - _ctx: &mut ExecutionCtx, + ctx: &mut ExecutionCtx, ) -> VortexResult { let lhs = args.get(0)?; let rhs = args.get(1)?; @@ -128,15 +128,15 @@ impl ScalarFnVTable for CosineSimilarity { lhs.dtype() ) })?; - let list_size = extension_list_size(ext)?; + let list_size = extension_list_size(ext)? as usize; // Extract the storage array from each extension input. We pass the storage (FSL) rather // than the extension array to avoid canonicalizing the extension wrapper. 
let lhs_storage = extension_storage(&lhs)?; let rhs_storage = extension_storage(&rhs)?; - let lhs_flat = extract_flat_elements(&lhs_storage, list_size)?; - let rhs_flat = extract_flat_elements(&rhs_storage, list_size)?; + let lhs_flat = extract_flat_elements(&lhs_storage, list_size, ctx)?; + let rhs_flat = extract_flat_elements(&rhs_storage, list_size, ctx)?; match_each_float_ptype!(lhs_flat.ptype(), |T| { let result: PrimitiveArray = (0..row_count) @@ -196,11 +196,11 @@ mod tests { use vortex::scalar_fn::ScalarFn; use crate::scalar_fns::cosine_similarity::CosineSimilarity; - use crate::scalar_fns::utils::test_helpers::assert_close; - use crate::scalar_fns::utils::test_helpers::constant_tensor_array; - use crate::scalar_fns::utils::test_helpers::constant_vector_array; - use crate::scalar_fns::utils::test_helpers::tensor_array; - use crate::scalar_fns::utils::test_helpers::vector_array; + use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::constant_tensor_array; + use crate::utils::test_helpers::constant_vector_array; + use crate::utils::test_helpers::tensor_array; + use crate::utils::test_helpers::vector_array; /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec`. 
fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index e0a3bac4143..43ff5c6fd7e 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -28,10 +28,10 @@ use vortex::scalar_fn::ScalarFnId; use vortex::scalar_fn::ScalarFnVTable; use crate::matcher::AnyTensor; -use crate::scalar_fns::utils::extension_element_ptype; -use crate::scalar_fns::utils::extension_list_size; -use crate::scalar_fns::utils::extension_storage; -use crate::scalar_fns::utils::extract_flat_elements; +use crate::utils::extension_element_ptype; +use crate::utils::extension_list_size; +use crate::utils::extension_storage; +use crate::utils::extract_flat_elements; /// L2 norm (Euclidean norm) of a tensor or vector column. /// @@ -98,7 +98,7 @@ impl ScalarFnVTable for L2Norm { &self, _options: &Self::Options, args: &dyn ExecutionArgs, - _ctx: &mut ExecutionCtx, + ctx: &mut ExecutionCtx, ) -> VortexResult { let input = args.get(0)?; let row_count = args.row_count(); @@ -110,10 +110,10 @@ impl ScalarFnVTable for L2Norm { input.dtype() ) })?; - let list_size = extension_list_size(ext)?; + let list_size = extension_list_size(ext)? 
as usize; let storage = extension_storage(&input)?; - let flat = extract_flat_elements(&storage, list_size)?; + let flat = extract_flat_elements(&storage, list_size, ctx)?; match_each_float_ptype!(flat.ptype(), |T| { let result: PrimitiveArray = (0..row_count) @@ -163,9 +163,9 @@ mod tests { use vortex::scalar_fn::ScalarFn; use crate::scalar_fns::l2_norm::L2Norm; - use crate::scalar_fns::utils::test_helpers::assert_close; - use crate::scalar_fns::utils::test_helpers::tensor_array; - use crate::scalar_fns::utils::test_helpers::vector_array; + use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::tensor_array; + use crate::utils::test_helpers::vector_array; /// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec`. fn eval_l2_norm(input: vortex::array::ArrayRef, len: usize) -> VortexResult> { diff --git a/vortex-tensor/src/scalar_fns/mod.rs b/vortex-tensor/src/scalar_fns/mod.rs index 2597f1115c8..2f56305cd53 100644 --- a/vortex-tensor/src/scalar_fns/mod.rs +++ b/vortex-tensor/src/scalar_fns/mod.rs @@ -5,5 +5,3 @@ pub mod cosine_similarity; pub mod l2_norm; - -mod utils; diff --git a/vortex-tensor/src/scalar_fns/utils.rs b/vortex-tensor/src/utils.rs similarity index 84% rename from vortex-tensor/src/scalar_fns/utils.rs rename to vortex-tensor/src/utils.rs index 0eb3e423ea0..82e2d1f5b45 100644 --- a/vortex-tensor/src/scalar_fns/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -2,10 +2,12 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex::array::ArrayRef; +use vortex::array::ExecutionCtx; use vortex::array::IntoArray; use vortex::array::arrays::Constant; use vortex::array::arrays::ConstantArray; use vortex::array::arrays::Extension; +use vortex::array::arrays::FixedSizeListArray; use vortex::array::arrays::PrimitiveArray; use vortex::dtype::DType; use vortex::dtype::NativePType; @@ -19,7 +21,7 @@ use vortex::error::vortex_err; /// Extracts the list size from a tensor-like extension dtype. 
/// /// The storage dtype must be a `FixedSizeList`. -pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult { +pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult { let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else { vortex_bail!( "expected FixedSizeList storage dtype, got {}", @@ -27,7 +29,7 @@ pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult { ); }; - Ok(*list_size as usize) + Ok(*list_size) } /// Extracts the float element [`PType`] from a tensor-like extension dtype. @@ -91,13 +93,17 @@ impl FlatElements { /// /// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is /// materialized to avoid expanding it to the full column length. -pub fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> VortexResult { +pub fn extract_flat_elements( + storage: &ArrayRef, + list_size: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult { if let Some(constant) = storage.as_opt::() { // Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a huge // amount of data. let single = ConstantArray::new(constant.scalar().clone(), 1).into_array(); - let fsl = single.to_canonical()?.into_fixed_size_list(); - let elems = fsl.elements().to_canonical()?.into_primitive(); + let fsl: FixedSizeListArray = single.execute(ctx)?; + let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?; return Ok(FlatElements { elems, stride: 0, @@ -106,8 +112,8 @@ pub fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> VortexResu } // Otherwise we have to fully expand all of the data. 
- let fsl = storage.to_canonical()?.into_fixed_size_list(); - let elems = fsl.elements().to_canonical()?.into_primitive(); + let fsl: FixedSizeListArray = storage.clone().execute(ctx)?; + let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?; Ok(FlatElements { elems, stride: list_size, @@ -118,6 +124,7 @@ pub fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> VortexResu #[cfg(test)] pub mod test_helpers { use vortex::array::ArrayRef; + use vortex::array::ExecutionCtx; use vortex::array::IntoArray; use vortex::array::arrays::ConstantArray; use vortex::array::arrays::ExtensionArray; @@ -128,9 +135,13 @@ pub mod test_helpers { use vortex::dtype::Nullability; use vortex::dtype::extension::ExtDType; use vortex::error::VortexResult; + use vortex::error::vortex_err; use vortex::extension::EmptyMetadata; use vortex::scalar::Scalar; + use super::extension_list_size; + use super::extension_storage; + use super::extract_flat_elements; use crate::fixed_shape::FixedShapeTensor; use crate::fixed_shape::FixedShapeTensorMetadata; use crate::vector::Vector; @@ -210,6 +221,26 @@ pub mod test_helpers { Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } + #[expect(dead_code, reason = "TODO(connor): Use this!")] + /// Extracts the f64 rows from a [`Vector`] extension array. + /// + /// Returns a `Vec<Vec<f64>>` where each inner vec is one vector's elements. + pub fn extract_vector_rows( + array: &ArrayRef, + ctx: &mut ExecutionCtx, + ) -> VortexResult<Vec<Vec<f64>>> { + let ext = array + .dtype() + .as_extension_opt() + .ok_or_else(|| vortex_err!("expected Vector extension dtype, got {}", array.dtype()))?; + let list_size = extension_list_size(ext)? 
as usize; + let storage = extension_storage(array)?; + let flat = extract_flat_elements(&storage, list_size, ctx)?; + Ok((0..array.len()) + .map(|i| flat.row::<f64>(i).to_vec()) + .collect()) + } + /// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected` /// value, with support for NaN (NaN == NaN is considered equal). #[track_caller]