diff --git a/vortex-array/src/arrays/assertions.rs b/vortex-array/src/arrays/assertions.rs index ef409f2af23..9be7a8c16c6 100644 --- a/vortex-array/src/arrays/assertions.rs +++ b/vortex-array/src/arrays/assertions.rs @@ -4,11 +4,38 @@ use std::fmt::Display; use itertools::Itertools; +use vortex_error::VortexExpect; -pub fn format_indices>(indices: I) -> impl Display { +use crate::ArrayRef; +use crate::DynArray; +use crate::ExecutionCtx; +use crate::IntoArray; +use crate::LEGACY_SESSION; +use crate::RecursiveCanonical; +use crate::VortexSessionExecute; + +fn format_indices>(indices: I) -> impl Display { indices.into_iter().format(",") } +/// Executes an array to recursive canonical form with the given execution context. +fn execute_to_canonical(array: ArrayRef, ctx: &mut ExecutionCtx) -> ArrayRef { + array + .execute::(ctx) + .vortex_expect("failed to execute array to recursive canonical form") + .0 + .into_array() +} + +/// Finds indices where two arrays differ based on `scalar_at` comparison. +#[expect(clippy::unwrap_used)] +fn find_mismatched_indices(left: &ArrayRef, right: &ArrayRef) -> Vec { + assert_eq!(left.len(), right.len()); + (0..left.len()) + .filter(|i| left.scalar_at(*i).unwrap() != right.scalar_at(*i).unwrap()) + .collect() +} + /// Asserts that the scalar at position `$n` in array `$arr` equals `$expected`. /// /// This is a convenience macro for testing that avoids verbose scalar comparison code. @@ -51,37 +78,64 @@ macro_rules! assert_arrays_eq { ($left:expr, $right:expr) => {{ let left = $left.clone(); let right = $right.clone(); - if left.dtype() != right.dtype() { - panic!( - "assertion left == right failed: arrays differ in type: {} != {}.\n left: {}\n right: {}", - left.dtype(), - right.dtype(), - left.display_values(), - right.display_values() - ) - } + assert_eq!( + left.dtype(), + right.dtype(), + "assertion left == right failed: arrays differ in type: {} != {}.\n left: {}\n right: {}", + left.dtype(), + right.dtype(), + left.display_values(), + right.display_values() + ); - if left.len() != right.len() { - panic!( - "assertion left == right failed: arrays differ in length: {} != {}.\n left: {}\n right: {}", - left.len(), - right.len(), - left.display_values(), - right.display_values() - ) - } + assert_eq!( + left.len(), + right.len(), + "assertion left == right failed: arrays differ in length: {} != {}.\n left: {}\n right: {}", + left.len(), + right.len(), + left.display_values(), + right.display_values() + ); - let n = left.len(); - let mismatched_indices = (0..n) - .filter(|i| left.scalar_at(*i).unwrap() != right.scalar_at(*i).unwrap()) - .collect::>(); - if mismatched_indices.len() != 0 { - panic!( - "assertion left == right failed: arrays do not match at indices: {}.\n left: {}\n right: {}", - $crate::arrays::format_indices(mismatched_indices), - left.display_values(), - right.display_values() - ) - } + #[allow(deprecated)] + let left = left.to_array(); + #[allow(deprecated)] + let right = right.to_array(); + $crate::arrays::assert_arrays_eq_impl(&left, &right); }}; } + +/// Implementation of `assert_arrays_eq!` — called by the macro after converting inputs to +/// `ArrayRef`. +#[track_caller] +#[allow(clippy::panic)] +pub fn assert_arrays_eq_impl(left: &ArrayRef, right: &ArrayRef) { + let executed = execute_to_canonical(left.clone(), &mut LEGACY_SESSION.create_execution_ctx()); + + let left_right = find_mismatched_indices(left, right); + let executed_right = find_mismatched_indices(&executed, right); + + if !left_right.is_empty() || !executed_right.is_empty() { + let mut msg = String::new(); + if !left_right.is_empty() { + msg.push_str(&format!( + "\n left != right at indices: {}", + format_indices(left_right) + )); + } + if !executed_right.is_empty() { + msg.push_str(&format!( + "\n executed != right at indices: {}", + format_indices(executed_right) + )); + } + panic!( + "assertion failed: arrays do not match:{}\n left: {}\n right: {}\n executed: {}", + msg, + left.display_values(), + right.display_values(), + executed.display_values() + ) + } +} diff --git a/vortex-array/src/arrays/mod.rs b/vortex-array/src/arrays/mod.rs index 43f8a84d49e..947116a5f73 100644 --- a/vortex-array/src/arrays/mod.rs +++ b/vortex-array/src/arrays/mod.rs @@ -7,7 +7,7 @@ mod assertions; #[cfg(any(test, feature = "_test-harness"))] -pub use assertions::format_indices; +pub use assertions::assert_arrays_eq_impl; #[cfg(test)] mod validation_tests;