diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index a4d9c38b60e..15b0fbc9326 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -237,7 +237,7 @@ impl GroupedAccumulator { if validity.value(offset) { let group = elements.slice(offset..offset + size)?; accumulator.accumulate(&group, ctx)?; - states.append_scalar(&accumulator.finish()?)?; + states.append_scalar(&accumulator.flush()?)?; } else { states.append_null() } @@ -309,7 +309,7 @@ impl GroupedAccumulator { if validity.value(i) { let group = elements.slice(offset..offset + size)?; accumulator.accumulate(&group, ctx)?; - states.append_scalar(&accumulator.finish()?)?; + states.append_scalar(&accumulator.flush()?)?; } else { states.append_null() } diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs new file mode 100644 index 00000000000..3c483372186 --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/count/mod.rs @@ -0,0 +1,267 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_error::VortexExpect; +use vortex_error::VortexResult; + +use crate::ArrayRef; +use crate::Columnar; +use crate::ExecutionCtx; +use crate::aggregate_fn::Accumulator; +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::DynAccumulator; +use crate::aggregate_fn::EmptyOptions; +use crate::dtype::DType; +use crate::dtype::Nullability; +use crate::dtype::PType; +use crate::scalar::Scalar; + +/// Return the count of non-null elements in an array. +/// +/// See [`Count`] for details. +pub fn count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let mut acc = Accumulator::try_new(Count, EmptyOptions, array.dtype().clone())?; + acc.accumulate(array, ctx)?; + let result = acc.finish()?; + + Ok(result + .as_primitive() + .typed_value::() + .vortex_expect("count result should not be null")) +} + +/// Count the number of non-null elements in an array. +/// +/// Applies to all types. Returns a `u64` count. +/// The identity value is zero. +#[derive(Clone, Debug)] +pub struct Count; + +impl AggregateFnVTable for Count { + type Options = EmptyOptions; + type Partial = u64; + + fn id(&self) -> AggregateFnId { + AggregateFnId::new_ref("vortex.count") + } + + fn serialize(&self, _options: &Self::Options) -> VortexResult>> { + Ok(Some(vec![])) + } + + fn deserialize( + &self, + _metadata: &[u8], + _session: &vortex_session::VortexSession, + ) -> VortexResult { + Ok(EmptyOptions) + } + + fn return_dtype(&self, _options: &Self::Options, _input_dtype: &DType) -> Option { + Some(DType::Primitive(PType::U64, Nullability::NonNullable)) + } + + fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option { + self.return_dtype(options, input_dtype) + } + + fn empty_partial( + &self, + _options: &Self::Options, + _input_dtype: &DType, + ) -> VortexResult { + Ok(0u64) + } + + fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { + let val = other + .as_primitive() + .typed_value::() + .vortex_expect("count partial should not be null"); + *partial += val; + Ok(()) + } + + fn to_scalar(&self, partial: &Self::Partial) -> VortexResult { + Ok(Scalar::primitive(*partial, Nullability::NonNullable)) + } + + fn reset(&self, partial: &mut Self::Partial) { + *partial = 0; + } + + #[inline] + fn is_saturated(&self, _partial: &Self::Partial) -> bool { + false + } + + fn accumulate( + &self, + partial: &mut Self::Partial, + batch: &Columnar, + _ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { + match batch { + Columnar::Constant(c) => { + if !c.scalar().is_null() { + *partial += c.len() as u64; + } + } + Columnar::Canonical(c) => { + let valid = c.as_ref().valid_count()?; + *partial += valid as u64; + } + } + Ok(()) + } + + fn finalize(&self, partials: ArrayRef) -> VortexResult { + Ok(partials) + } + + fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult { + self.to_scalar(partial) + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_error::VortexResult; + + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::aggregate_fn::Accumulator; + use crate::aggregate_fn::AggregateFnVTable; + use crate::aggregate_fn::DynAccumulator; + use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::fns::count::Count; + use crate::aggregate_fn::fns::count::count; + use crate::arrays::ChunkedArray; + use crate::arrays::ConstantArray; + use crate::arrays::PrimitiveArray; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::scalar::Scalar; + use crate::validity::Validity; + + #[test] + fn count_all_valid() -> VortexResult<()> { + let array = + PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5], Validity::NonNullable).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(count(&array, &mut ctx)?, 5); + Ok(()) + } + + #[test] + fn count_with_nulls() -> VortexResult<()> { + let array = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5)]) + .into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(count(&array, &mut ctx)?, 3); + Ok(()) + } + + #[test] + fn count_all_null() -> VortexResult<()> { + let array = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(count(&array, &mut ctx)?, 0); + Ok(()) + } + + #[test] + fn count_empty() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?; + let result = acc.finish()?; + assert_eq!(result.as_primitive().typed_value::(), Some(0)); + Ok(()) + } + + #[test] + fn count_multi_batch() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::I32, Nullability::Nullable); + let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?; + + let batch1 = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + + let batch2 = PrimitiveArray::from_option_iter([None, Some(5i32)]).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + + let result = acc.finish()?; + assert_eq!(result.as_primitive().typed_value::(), Some(3)); + Ok(()) + } + + #[test] + fn count_finish_resets_state() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::I32, Nullability::Nullable); + let mut acc = Accumulator::try_new(Count, EmptyOptions, dtype)?; + + let batch1 = PrimitiveArray::from_option_iter([Some(1i32), None]).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + let result1 = acc.finish()?; + assert_eq!(result1.as_primitive().typed_value::(), Some(1)); + + let batch2 = PrimitiveArray::from_option_iter([Some(2i32), Some(3), None]).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + let result2 = acc.finish()?; + assert_eq!(result2.as_primitive().typed_value::(), Some(2)); + Ok(()) + } + + #[test] + fn count_state_merge() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I32, Nullability::NonNullable); + let mut state = Count.empty_partial(&EmptyOptions, &dtype)?; + + let scalar1 = Scalar::primitive(5u64, Nullability::NonNullable); + Count.combine_partials(&mut state, scalar1)?; + + let scalar2 = Scalar::primitive(3u64, Nullability::NonNullable); + Count.combine_partials(&mut state, scalar2)?; + + let result = Count.to_scalar(&state)?; + Count.reset(&mut state); + assert_eq!(result.as_primitive().typed_value::(), Some(8)); + Ok(()) + } + + #[test] + fn count_constant_non_null() -> VortexResult<()> { + let array = ConstantArray::new(42i32, 10); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(count(&array.into_array(), &mut ctx)?, 10); + Ok(()) + } + + #[test] + fn count_constant_null() -> VortexResult<()> { + let array = ConstantArray::new( + Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), + 10, + ); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(count(&array.into_array(), &mut ctx)?, 0); + Ok(()) + } + + #[test] + fn count_chunked() -> VortexResult<()> { + let chunk1 = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]); + let chunk2 = PrimitiveArray::from_option_iter([None, Some(5i32), None]); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + assert_eq!(count(&chunked.into_array(), &mut ctx)?, 3); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/mean/mod.rs b/vortex-array/src/aggregate_fn/fns/mean/mod.rs new file mode 100644 index 00000000000..024fb4ea16b --- /dev/null +++ b/vortex-array/src/aggregate_fn/fns/mean/mod.rs @@ -0,0 +1,461 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use num_traits::ToPrimitive; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_mask::Mask; + +use crate::ArrayRef; +use crate::Canonical; +use crate::Columnar; +use crate::DynArray; +use crate::ExecutionCtx; +use crate::IntoArray; +use crate::aggregate_fn::Accumulator; +use crate::aggregate_fn::AggregateFnId; +use crate::aggregate_fn::AggregateFnVTable; +use crate::aggregate_fn::DynAccumulator; +use crate::aggregate_fn::EmptyOptions; +use crate::arrays::PrimitiveArray; +use crate::canonical::ToCanonical; +use crate::dtype::DType; +use crate::dtype::FieldName; +use crate::dtype::Nullability; +use crate::dtype::PType; +use crate::dtype::StructFields; +use crate::match_each_native_ptype; +use crate::scalar::Scalar; +use crate::validity::Validity; + +/// Compute the arithmetic mean of an array. +/// +/// See [`Mean`] for details. +pub fn mean(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let mut acc = Accumulator::try_new(Mean, EmptyOptions, array.dtype().clone())?; + acc.accumulate(array, ctx)?; + acc.finish() +} + +/// Compute the arithmetic mean of an array, returning `f64`. +/// +/// Applies to boolean and primitive numeric types. Returns a nullable `f64`. +/// Internally tracks sum and count, returning `sum / count` on finalize. +/// If there are no valid elements, returns null. +/// +/// The partial state is a struct `{sum: f64, count: u64}` so that partials from +/// different accumulators can be correctly combined via weighted addition. +#[derive(Clone, Debug)] +pub struct Mean; + +/// Internal accumulation state for [`Mean`]. +pub struct MeanPartial { + sum: f64, + count: u64, +} + +fn partial_struct_dtype() -> DType { + DType::Struct( + StructFields::new( + [FieldName::from("sum"), FieldName::from("count")].into(), + vec![ + DType::Primitive(PType::F64, Nullability::NonNullable), + DType::Primitive(PType::U64, Nullability::NonNullable), + ], + ), + Nullability::Nullable, + ) +} + +impl AggregateFnVTable for Mean { + type Options = EmptyOptions; + type Partial = MeanPartial; + + fn id(&self) -> AggregateFnId { + AggregateFnId::new_ref("vortex.mean") + } + + fn serialize(&self, _options: &Self::Options) -> VortexResult>> { + Ok(Some(vec![])) + } + + fn deserialize( + &self, + _metadata: &[u8], + _session: &vortex_session::VortexSession, + ) -> VortexResult { + Ok(EmptyOptions) + } + + fn return_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option { + match input_dtype { + DType::Bool(_) | DType::Primitive(..) => { + Some(DType::Primitive(PType::F64, Nullability::Nullable)) + } + _ => None, + } + } + + fn partial_dtype(&self, _options: &Self::Options, input_dtype: &DType) -> Option { + match input_dtype { + DType::Bool(_) | DType::Primitive(..) => Some(partial_struct_dtype()), + _ => None, + } + } + + fn empty_partial( + &self, + _options: &Self::Options, + _input_dtype: &DType, + ) -> VortexResult { + Ok(MeanPartial { sum: 0.0, count: 0 }) + } + + fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { + if other.is_null() { + return Ok(()); + } + let s = other.as_struct(); + let sum_scalar = s + .field("sum") + .vortex_expect("mean partial must have sum field"); + let count_scalar = s + .field("count") + .vortex_expect("mean partial must have count field"); + + partial.sum += sum_scalar + .as_primitive() + .typed_value::() + .vortex_expect("sum field should not be null"); + partial.count += count_scalar + .as_primitive() + .typed_value::() + .vortex_expect("count field should not be null"); + Ok(()) + } + + fn to_scalar(&self, partial: &Self::Partial) -> VortexResult { + if partial.count == 0 { + Ok(Scalar::null(partial_struct_dtype())) + } else { + Ok(Scalar::struct_( + partial_struct_dtype(), + vec![ + Scalar::primitive(partial.sum, Nullability::NonNullable), + Scalar::primitive(partial.count, Nullability::NonNullable), + ], + )) + } + } + + fn reset(&self, partial: &mut Self::Partial) { + partial.sum = 0.0; + partial.count = 0; + } + + #[inline] + fn is_saturated(&self, _partial: &Self::Partial) -> bool { + false + } + + fn accumulate( + &self, + partial: &mut Self::Partial, + batch: &Columnar, + _ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { + match batch { + Columnar::Constant(c) => { + if !c.scalar().is_null() { + let val = scalar_to_f64(c.scalar())?; + partial.sum += val * c.len() as f64; + partial.count += c.len() as u64; + } + } + Columnar::Canonical(canonical) => match canonical { + Canonical::Primitive(prim) => { + let mask = prim.validity_mask()?; + match_each_native_ptype!(prim.ptype(), |T| { + accumulate_values(partial, prim.as_slice::(), &mask); + }); + } + Canonical::Bool(bool_arr) => { + let mask = bool_arr.validity_mask()?; + let bits = bool_arr.to_bit_buffer(); + match &mask { + Mask::AllTrue(_) => { + partial.sum += bits.true_count() as f64; + partial.count += bool_arr.len() as u64; + } + Mask::AllFalse(_) => {} + Mask::Values(validity) => { + let valid_count = validity.true_count(); + let valid_and_true = (&bits & validity.bit_buffer()).true_count(); + partial.sum += valid_and_true as f64; + partial.count += valid_count as u64; + } + } + } + _ => vortex_bail!("Unsupported canonical type for mean: {}", batch.dtype()), + }, + } + Ok(()) + } + + fn finalize(&self, partials: ArrayRef) -> VortexResult { + let struct_arr = partials.to_struct(); + let sums = struct_arr.unmasked_field_by_name("sum")?; + let counts = struct_arr.unmasked_field_by_name("count")?; + let validity_mask = struct_arr.validity_mask()?; + + let sum_prim = sums.to_primitive(); + let count_prim = counts.to_primitive(); + let sum_values = sum_prim.as_slice::(); + let count_values = count_prim.as_slice::(); + + let means: vortex_buffer::Buffer = sum_values + .iter() + .zip(count_values.iter()) + .map(|(s, c)| if *c == 0 { 0.0 } else { s / *c as f64 }) + .collect(); + + // A mean is valid when the group itself was valid AND had at least one + // non-null element (count > 0). + let validity = Validity::from_iter( + count_values + .iter() + .enumerate() + .map(|(i, c)| validity_mask.value(i) && *c > 0), + ); + + Ok(PrimitiveArray::new(means, validity).into_array()) + } + + fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult { + if partial.count == 0 { + Ok(Scalar::null(DType::Primitive( + PType::F64, + Nullability::Nullable, + ))) + } else { + Ok(Scalar::primitive( + partial.sum / partial.count as f64, + Nullability::Nullable, + )) + } + } +} + +fn scalar_to_f64(scalar: &Scalar) -> VortexResult { + match scalar.dtype() { + DType::Bool(_) => { + let v = scalar.as_bool().value().vortex_expect("checked non-null"); + Ok(if v { 1.0 } else { 0.0 }) + } + DType::Primitive(..) => f64::try_from(scalar), + _ => vortex_bail!("Cannot convert {} to f64 for mean", scalar.dtype()), + } +} + +fn accumulate_values(partial: &mut MeanPartial, values: &[T], mask: &Mask) { + match mask { + Mask::AllTrue(_) => { + partial.count += values.len() as u64; + for v in values { + partial.sum += v.to_f64().unwrap_or(0.0); + } + } + Mask::AllFalse(_) => {} + Mask::Values(v) => { + for (val, valid) in values.iter().zip(v.bit_buffer().iter()) { + if valid { + partial.count += 1; + partial.sum += val.to_f64().unwrap_or(0.0); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_error::VortexResult; + + use crate::IntoArray; + use crate::LEGACY_SESSION; + use crate::VortexSessionExecute; + use crate::aggregate_fn::Accumulator; + use crate::aggregate_fn::AggregateFnVTable; + use crate::aggregate_fn::DynAccumulator; + use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::fns::mean::Mean; + use crate::aggregate_fn::fns::mean::mean; + use crate::aggregate_fn::fns::mean::partial_struct_dtype; + use crate::arrays::BoolArray; + use crate::arrays::ChunkedArray; + use crate::arrays::ConstantArray; + use crate::arrays::PrimitiveArray; + use crate::dtype::DType; + use crate::dtype::Nullability; + use crate::dtype::PType; + use crate::scalar::Scalar; + use crate::validity::Validity; + + #[test] + fn mean_all_valid() -> VortexResult<()> { + let array = PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0, 4.0, 5.0], Validity::NonNullable) + .into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(3.0)); + Ok(()) + } + + #[test] + fn mean_with_nulls() -> VortexResult<()> { + let array = PrimitiveArray::from_option_iter([Some(2.0f64), None, Some(4.0)]).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(3.0)); + Ok(()) + } + + #[test] + fn mean_all_null() -> VortexResult<()> { + let array = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + assert!(result.is_null()); + Ok(()) + } + + #[test] + fn mean_empty() -> VortexResult<()> { + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Mean, EmptyOptions, dtype)?; + let result = acc.finish()?; + assert!(result.is_null()); + Ok(()) + } + + #[test] + fn mean_integers() -> VortexResult<()> { + let array = PrimitiveArray::new(buffer![10i32, 20, 30], Validity::NonNullable).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array, &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(20.0)); + Ok(()) + } + + #[test] + fn mean_multi_batch() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Mean, EmptyOptions, dtype)?; + + let batch1 = + PrimitiveArray::new(buffer![1.0f64, 2.0, 3.0], Validity::NonNullable).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + + let batch2 = PrimitiveArray::new(buffer![4.0f64, 5.0], Validity::NonNullable).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + + let result = acc.finish()?; + assert_eq!(result.as_primitive().as_::(), Some(3.0)); + Ok(()) + } + + #[test] + fn mean_finish_resets_state() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let mut acc = Accumulator::try_new(Mean, EmptyOptions, dtype)?; + + let batch1 = PrimitiveArray::new(buffer![2.0f64, 4.0], Validity::NonNullable).into_array(); + acc.accumulate(&batch1, &mut ctx)?; + let result1 = acc.finish()?; + assert_eq!(result1.as_primitive().as_::(), Some(3.0)); + + let batch2 = + PrimitiveArray::new(buffer![10.0f64, 20.0, 30.0], Validity::NonNullable).into_array(); + acc.accumulate(&batch2, &mut ctx)?; + let result2 = acc.finish()?; + assert_eq!(result2.as_primitive().as_::(), Some(20.0)); + Ok(()) + } + + #[test] + fn mean_state_merge() -> VortexResult<()> { + let dtype = DType::Primitive(PType::F64, Nullability::NonNullable); + let mut state = Mean.empty_partial(&EmptyOptions, &dtype)?; + + // Partition 1: mean of [2, 4] → sum=6, count=2 + let partial1 = Scalar::struct_( + partial_struct_dtype(), + vec![ + Scalar::primitive(6.0f64, Nullability::NonNullable), + Scalar::primitive(2u64, Nullability::NonNullable), + ], + ); + Mean.combine_partials(&mut state, partial1)?; + + // Partition 2: mean of [10, 20, 30] → sum=60, count=3 + let partial2 = Scalar::struct_( + partial_struct_dtype(), + vec![ + Scalar::primitive(60.0f64, Nullability::NonNullable), + Scalar::primitive(3u64, Nullability::NonNullable), + ], + ); + Mean.combine_partials(&mut state, partial2)?; + + // Combined: (6 + 60) / (2 + 3) = 13.2 + let result = Mean.finalize_scalar(&state)?; + assert_eq!(result.as_primitive().as_::(), Some(13.2)); + Ok(()) + } + + #[test] + fn mean_constant_non_null() -> VortexResult<()> { + let array = ConstantArray::new(5.0f64, 4); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array.into_array(), &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(5.0)); + Ok(()) + } + + #[test] + fn mean_constant_null() -> VortexResult<()> { + let array = ConstantArray::new( + Scalar::null(DType::Primitive(PType::F64, Nullability::Nullable)), + 10, + ); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array.into_array(), &mut ctx)?; + assert!(result.is_null()); + Ok(()) + } + + #[test] + fn mean_bool() -> VortexResult<()> { + let array: BoolArray = [true, false, true, true].into_iter().collect(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&array.into_array(), &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(0.75)); + Ok(()) + } + + #[test] + fn mean_chunked() -> VortexResult<()> { + let chunk1 = PrimitiveArray::from_option_iter([Some(1.0f64), None, Some(3.0)]); + let chunk2 = PrimitiveArray::from_option_iter([Some(5.0f64), None]); + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let result = mean(&chunked.into_array(), &mut ctx)?; + assert_eq!(result.as_primitive().as_::(), Some(3.0)); + Ok(()) + } +} diff --git a/vortex-array/src/aggregate_fn/fns/mod.rs b/vortex-array/src/aggregate_fn/fns/mod.rs index 4c233ba4d27..4e6df22299a 100644 --- a/vortex-array/src/aggregate_fn/fns/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/mod.rs @@ -1,8 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +pub mod count; pub mod is_constant; pub mod is_sorted; +pub mod mean; pub mod min_max; pub mod nan_count; pub mod sum;