Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions vortex-tensor/public-api.lock
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
pub mod vortex_tensor

pub mod vortex_tensor::encodings

pub mod vortex_tensor::fixed_shape

pub struct vortex_tensor::fixed_shape::FixedShapeTensor
Expand Down Expand Up @@ -136,7 +138,7 @@ pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::arity(&se

pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName

pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::ArrayRef>
pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::ArrayRef>

pub fn vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

Expand Down Expand Up @@ -166,7 +168,7 @@ pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::arity(&self, _options: &Self:

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::child_name(&self, _options: &Self::Options, child_idx: usize) -> vortex_array::scalar_fn::vtable::ChildName

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, _ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::ArrayRef>
pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::execute(&self, _options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult<vortex_array::array::ArrayRef>

pub fn vortex_tensor::scalar_fns::l2_norm::L2Norm::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result

Expand Down
11 changes: 11 additions & 0 deletions vortex-tensor/src/encodings/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Encodings for the different tensor types.

// TODO(connor):
// pub mod norm; // Unit-normalized vectors.
// pub mod spherical; // Spherical transform on unit-normalized vectors.

// TODO(will):
// pub mod turboquant;
8 changes: 6 additions & 2 deletions vortex-tensor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
//! including unit vectors, spherical coordinates, and similarity measures such as cosine
//! similarity.

pub mod matcher;
pub mod scalar_fns;

pub mod fixed_shape;
pub mod vector;

pub mod matcher;
pub mod scalar_fns;
pub mod encodings;

mod utils;
26 changes: 13 additions & 13 deletions vortex-tensor/src/scalar_fns/cosine_similarity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ use vortex::scalar_fn::ScalarFnId;
use vortex::scalar_fn::ScalarFnVTable;

use crate::matcher::AnyTensor;
use crate::scalar_fns::utils::extension_element_ptype;
use crate::scalar_fns::utils::extension_list_size;
use crate::scalar_fns::utils::extension_storage;
use crate::scalar_fns::utils::extract_flat_elements;
use crate::utils::extension_element_ptype;
use crate::utils::extension_list_size;
use crate::utils::extension_storage;
use crate::utils::extract_flat_elements;

/// Cosine similarity between two columns.
///
Expand Down Expand Up @@ -115,7 +115,7 @@ impl ScalarFnVTable for CosineSimilarity {
&self,
_options: &Self::Options,
args: &dyn ExecutionArgs,
_ctx: &mut ExecutionCtx,
ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
let lhs = args.get(0)?;
let rhs = args.get(1)?;
Expand All @@ -128,15 +128,15 @@ impl ScalarFnVTable for CosineSimilarity {
lhs.dtype()
)
})?;
let list_size = extension_list_size(ext)?;
let list_size = extension_list_size(ext)? as usize;

// Extract the storage array from each extension input. We pass the storage (FSL) rather
// than the extension array to avoid canonicalizing the extension wrapper.
let lhs_storage = extension_storage(&lhs)?;
let rhs_storage = extension_storage(&rhs)?;

let lhs_flat = extract_flat_elements(&lhs_storage, list_size)?;
let rhs_flat = extract_flat_elements(&rhs_storage, list_size)?;
let lhs_flat = extract_flat_elements(&lhs_storage, list_size, ctx)?;
let rhs_flat = extract_flat_elements(&rhs_storage, list_size, ctx)?;

match_each_float_ptype!(lhs_flat.ptype(), |T| {
let result: PrimitiveArray = (0..row_count)
Expand Down Expand Up @@ -196,11 +196,11 @@ mod tests {
use vortex::scalar_fn::ScalarFn;

use crate::scalar_fns::cosine_similarity::CosineSimilarity;
use crate::scalar_fns::utils::test_helpers::assert_close;
use crate::scalar_fns::utils::test_helpers::constant_tensor_array;
use crate::scalar_fns::utils::test_helpers::constant_vector_array;
use crate::scalar_fns::utils::test_helpers::tensor_array;
use crate::scalar_fns::utils::test_helpers::vector_array;
use crate::utils::test_helpers::assert_close;
use crate::utils::test_helpers::constant_tensor_array;
use crate::utils::test_helpers::constant_vector_array;
use crate::utils::test_helpers::tensor_array;
use crate::utils::test_helpers::vector_array;

/// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec<f64>`.
fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult<Vec<f64>> {
Expand Down
20 changes: 10 additions & 10 deletions vortex-tensor/src/scalar_fns/l2_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ use vortex::scalar_fn::ScalarFnId;
use vortex::scalar_fn::ScalarFnVTable;

use crate::matcher::AnyTensor;
use crate::scalar_fns::utils::extension_element_ptype;
use crate::scalar_fns::utils::extension_list_size;
use crate::scalar_fns::utils::extension_storage;
use crate::scalar_fns::utils::extract_flat_elements;
use crate::utils::extension_element_ptype;
use crate::utils::extension_list_size;
use crate::utils::extension_storage;
use crate::utils::extract_flat_elements;

/// L2 norm (Euclidean norm) of a tensor or vector column.
///
Expand Down Expand Up @@ -98,7 +98,7 @@ impl ScalarFnVTable for L2Norm {
&self,
_options: &Self::Options,
args: &dyn ExecutionArgs,
_ctx: &mut ExecutionCtx,
ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
let input = args.get(0)?;
let row_count = args.row_count();
Expand All @@ -110,10 +110,10 @@ impl ScalarFnVTable for L2Norm {
input.dtype()
)
})?;
let list_size = extension_list_size(ext)?;
let list_size = extension_list_size(ext)? as usize;

let storage = extension_storage(&input)?;
let flat = extract_flat_elements(&storage, list_size)?;
let flat = extract_flat_elements(&storage, list_size, ctx)?;

match_each_float_ptype!(flat.ptype(), |T| {
let result: PrimitiveArray = (0..row_count)
Expand Down Expand Up @@ -163,9 +163,9 @@ mod tests {
use vortex::scalar_fn::ScalarFn;

use crate::scalar_fns::l2_norm::L2Norm;
use crate::scalar_fns::utils::test_helpers::assert_close;
use crate::scalar_fns::utils::test_helpers::tensor_array;
use crate::scalar_fns::utils::test_helpers::vector_array;
use crate::utils::test_helpers::assert_close;
use crate::utils::test_helpers::tensor_array;
use crate::utils::test_helpers::vector_array;

/// Evaluates L2 norm on a tensor/vector array and returns the result as `Vec<f64>`.
fn eval_l2_norm(input: vortex::array::ArrayRef, len: usize) -> VortexResult<Vec<f64>> {
Expand Down
2 changes: 0 additions & 2 deletions vortex-tensor/src/scalar_fns/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,3 @@

pub mod cosine_similarity;
pub mod l2_norm;

mod utils;
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex::array::ArrayRef;
use vortex::array::ExecutionCtx;
use vortex::array::IntoArray;
use vortex::array::arrays::Constant;
use vortex::array::arrays::ConstantArray;
use vortex::array::arrays::Extension;
use vortex::array::arrays::FixedSizeListArray;
use vortex::array::arrays::PrimitiveArray;
use vortex::dtype::DType;
use vortex::dtype::NativePType;
Expand All @@ -19,15 +21,15 @@ use vortex::error::vortex_err;
/// Extracts the list size from a tensor-like extension dtype.
///
/// The storage dtype must be a `FixedSizeList`.
pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult<usize> {
pub fn extension_list_size(ext: &ExtDTypeRef) -> VortexResult<u32> {
let DType::FixedSizeList(_, list_size, _) = ext.storage_dtype() else {
vortex_bail!(
"expected FixedSizeList storage dtype, got {}",
ext.storage_dtype()
);
};

Ok(*list_size as usize)
Ok(*list_size)
}

/// Extracts the float element [`PType`] from a tensor-like extension dtype.
Expand Down Expand Up @@ -91,13 +93,17 @@ impl FlatElements {
///
/// When the input is a [`ConstantArray`] (e.g., a literal query vector), only a single row is
/// materialized to avoid expanding it to the full column length.
pub fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> VortexResult<FlatElements> {
pub fn extract_flat_elements(
storage: &ArrayRef,
list_size: usize,
ctx: &mut ExecutionCtx,
) -> VortexResult<FlatElements> {
if let Some(constant) = storage.as_opt::<Constant>() {
// Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a huge
// amount of data.
let single = ConstantArray::new(constant.scalar().clone(), 1).into_array();
let fsl = single.to_canonical()?.into_fixed_size_list();
let elems = fsl.elements().to_canonical()?.into_primitive();
let fsl: FixedSizeListArray = single.execute(ctx)?;
let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?;
return Ok(FlatElements {
elems,
stride: 0,
Expand All @@ -106,8 +112,8 @@ pub fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> VortexResu
}

// Otherwise we have to fully expand all of the data.
let fsl = storage.to_canonical()?.into_fixed_size_list();
let elems = fsl.elements().to_canonical()?.into_primitive();
let fsl: FixedSizeListArray = storage.clone().execute(ctx)?;
let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?;
Ok(FlatElements {
elems,
stride: list_size,
Expand All @@ -118,6 +124,7 @@ pub fn extract_flat_elements(storage: &ArrayRef, list_size: usize) -> VortexResu
#[cfg(test)]
pub mod test_helpers {
use vortex::array::ArrayRef;
use vortex::array::ExecutionCtx;
use vortex::array::IntoArray;
use vortex::array::arrays::ConstantArray;
use vortex::array::arrays::ExtensionArray;
Expand All @@ -128,9 +135,13 @@ pub mod test_helpers {
use vortex::dtype::Nullability;
use vortex::dtype::extension::ExtDType;
use vortex::error::VortexResult;
use vortex::error::vortex_err;
use vortex::extension::EmptyMetadata;
use vortex::scalar::Scalar;

use super::extension_list_size;
use super::extension_storage;
use super::extract_flat_elements;
use crate::fixed_shape::FixedShapeTensor;
use crate::fixed_shape::FixedShapeTensorMetadata;
use crate::vector::Vector;
Expand Down Expand Up @@ -210,6 +221,26 @@ pub mod test_helpers {
Ok(ExtensionArray::new(ext_dtype, storage).into_array())
}

/// Collects every row of a [`Vector`] extension array into its own `Vec<f64>`.
///
/// The outer vector has one entry per row of `array` (`array.len()` entries in total); each
/// inner vector is a copy of that row's elements.
///
/// # Errors
///
/// Returns an error if the dtype of `array` is not an extension dtype, or if
/// [`extension_list_size`], [`extension_storage`], or [`extract_flat_elements`] fails.
//
// NOTE(review): rows are read as `f64`; what happens for other element types depends on
// `FlatElements::row`, which is not visible here — confirm before using on non-f64 vectors.
#[expect(dead_code, reason = "TODO(connor): Use this!")]
pub fn extract_vector_rows(
    array: &ArrayRef,
    ctx: &mut ExecutionCtx,
) -> VortexResult<Vec<Vec<f64>>> {
    // Guard: only extension-typed arrays are accepted.
    let Some(ext_dtype) = array.dtype().as_extension_opt() else {
        return Err(vortex_err!(
            "expected Vector extension dtype, got {}",
            array.dtype()
        ));
    };

    // Number of elements per row, taken from the extension's storage dtype.
    let row_width = extension_list_size(ext_dtype)? as usize;

    // Work on the storage array and flatten its elements.
    let storage_array = extension_storage(array)?;
    let elements = extract_flat_elements(&storage_array, row_width, ctx)?;

    // Copy each row out into an owned vector.
    let row_count = array.len();
    let mut rows = Vec::with_capacity(row_count);
    for row_idx in 0..row_count {
        rows.push(elements.row::<f64>(row_idx).to_vec());
    }
    Ok(rows)
}

/// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected`
/// value, with support for NaN (NaN == NaN is considered equal).
#[track_caller]
Expand Down
Loading