From 8fb1dc992e9f1b50074d697cf12d350737efafb6 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 11 Feb 2026 11:25:05 -0500 Subject: [PATCH 1/2] fix sum decimal array overflow Signed-off-by: Connor Tsui --- .../src/arrays/decimal/compute/sum.rs | 133 +++++++++++------- 1 file changed, 80 insertions(+), 53 deletions(-) diff --git a/vortex-array/src/arrays/decimal/compute/sum.rs b/vortex-array/src/arrays/decimal/compute/sum.rs index e091162653e..60735d027bd 100644 --- a/vortex-array/src/arrays/decimal/compute/sum.rs +++ b/vortex-array/src/arrays/decimal/compute/sum.rs @@ -6,6 +6,8 @@ use num_traits::AsPrimitive; use num_traits::CheckedAdd; use vortex_buffer::BitBuffer; use vortex_buffer::Buffer; +use vortex_dtype::DType; +use vortex_dtype::DecimalDType; use vortex_dtype::DecimalType; use vortex_dtype::Nullability::Nullable; use vortex_dtype::match_each_decimal_value_type; @@ -24,75 +26,79 @@ use crate::expr::stats::Stat; use crate::register_kernel; impl SumKernel for DecimalVTable { - #[expect( - clippy::cognitive_complexity, - reason = "complexity from nested match_each_* macros" - )] fn sum(&self, array: &DecimalArray, accumulator: &Scalar) -> VortexResult<Scalar> { let return_dtype = Stat::Sum .dtype(array.dtype()) .vortex_expect("sum for decimals exists"); - let return_decimal_dtype = return_dtype + let return_decimal_dtype = *return_dtype .as_decimal_opt() .vortex_expect("must be decimal"); - // Extract the initial value as a DecimalValue + // Extract the initial value as a `DecimalValue`. let initial_decimal = accumulator .as_decimal() .decimal_value() .vortex_expect("cannot be null"); - match array.validity_mask()? 
{ - Mask::AllFalse(_) => { - vortex_bail!("invalid state, all-null array should be checked by top-level sum fn") - } - Mask::AllTrue(_) => { - let values_type = DecimalType::smallest_decimal_value_type(return_decimal_dtype); - match_each_decimal_value_type!(array.values_type(), |I| { - match_each_decimal_value_type!(values_type, |O| { - let initial_val: O = initial_decimal - .cast() - .vortex_expect("cannot fail to cast initial value"); - if let Some(sum) = sum_decimal(array.buffer::<I>(), initial_val) { - Ok(Scalar::decimal( - DecimalValue::from(sum), - *return_decimal_dtype, - Nullable, - )) - } else { - Ok(Scalar::null(return_dtype)) - } - }) - }) - } - Mask::Values(mask_values) => { - let values_type = DecimalType::smallest_decimal_value_type(return_decimal_dtype); - match_each_decimal_value_type!(array.values_type(), |I| { - match_each_decimal_value_type!(values_type, |O| { - let initial_val: O = initial_decimal - .cast() - .vortex_expect("cannot fail to cast initial value"); - - if let Some(sum) = sum_decimal_with_validity( - array.buffer::<I>(), - mask_values.bit_buffer(), - initial_val, - ) { - Ok(Scalar::decimal( - DecimalValue::from(sum), - *return_decimal_dtype, - Nullable, - )) - } else { - Ok(Scalar::null(return_dtype)) - } - }) - }) - } + let mask = array.validity_mask()?; + if matches!(&mask, Mask::AllFalse(_)) { + vortex_bail!("invalid state, all-null array should be checked by top-level sum fn") } + + let validity = match &mask { + Mask::AllTrue(_) => None, + Mask::Values(mask_values) => Some(mask_values.bit_buffer()), + Mask::AllFalse(_) => unreachable!("handled above"), + }; + + let values_type = DecimalType::smallest_decimal_value_type(&return_decimal_dtype); + match_each_decimal_value_type!(array.values_type(), |I| { + match_each_decimal_value_type!(values_type, |O| { + let initial_val: O = initial_decimal + .cast() + .vortex_expect("cannot fail to cast initial value"); + + Ok(sum_to_scalar( + array.buffer::<I>(), + validity, + initial_val, 
+ return_decimal_dtype, + &return_dtype, + )) + }) + }) } } +/// Compute the checked sum and convert the result to a [`Scalar`]. +/// +/// Returns a null scalar if the sum overflows the underlying integer type or if the result +/// exceeds the declared decimal precision. +fn sum_to_scalar<T, O>( + values: Buffer<T>, + validity: Option<&BitBuffer>, + initial: O, + return_decimal_dtype: DecimalDType, + return_dtype: &DType, +) -> Scalar +where + T: AsPrimitive<O>, + O: Copy + CheckedAdd + Into<DecimalValue> + 'static, +{ + let raw_sum = match validity { + Some(v) => sum_decimal_with_validity(values, v, initial), + None => sum_decimal(values, initial), + }; + + raw_sum + .map(Into::<DecimalValue>::into) + // We have to make sure that the decimal value fits the precision of the decimal dtype. + .filter(|v| v.fits_in_precision(return_decimal_dtype)) + .map(|v| Scalar::decimal(v, return_decimal_dtype, Nullable)) + // If an overflow occurs during summation, or final value does not fit, then return a null. + .unwrap_or_else(|| Scalar::null(return_dtype.clone())) +} + fn sum_decimal<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>( values: Buffer<T>, initial: I, @@ -371,6 +377,27 @@ mod tests { assert_eq!(result, expected); } + #[test] + fn test_sum_precision_overflow_without_i256_overflow() { + // Construct values that individually fit in precision 76 but whose sum exceeds it, + // while still fitting in `i256`. This ensures we return null for precision overflow + // and not just for arithmetic overflow. + let ten_to_38 = i256::from_i128(10i128.pow(38)); + let ten_to_75 = ten_to_38 * i256::from_i128(10i128.pow(37)); + // 6 * 10^75 is a 76-digit number, which fits in precision 76. + let val = ten_to_75 * i256::from_i128(6); + + let decimal_dtype = DecimalDType::new(76, 0); + let decimal = DecimalArray::new(buffer![val, val], decimal_dtype, Validity::AllValid); + + // Sum = 12 * 10^75 = 1.2 * 10^76, which exceeds precision 76 but fits in `i256`. 
+ let result = sum(decimal.as_ref()).unwrap(); + assert_eq!( + result, + Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable)) + ); + } + #[test] fn test_i256_overflow() { let decimal_dtype = DecimalDType::new(76, 0); From 838d3545a15511c2fe6a8ebc88245057bd3e6c41 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Wed, 11 Feb 2026 11:34:23 -0500 Subject: [PATCH 2/2] dont match validity twice Signed-off-by: Connor Tsui --- vortex-array/src/arrays/decimal/compute/sum.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vortex-array/src/arrays/decimal/compute/sum.rs b/vortex-array/src/arrays/decimal/compute/sum.rs index 60735d027bd..1456fa0a867 100644 --- a/vortex-array/src/arrays/decimal/compute/sum.rs +++ b/vortex-array/src/arrays/decimal/compute/sum.rs @@ -41,14 +41,12 @@ impl SumKernel for DecimalVTable { .vortex_expect("cannot be null"); let mask = array.validity_mask()?; - if matches!(&mask, Mask::AllFalse(_)) { - vortex_bail!("invalid state, all-null array should be checked by top-level sum fn") - } - let validity = match &mask { Mask::AllTrue(_) => None, Mask::Values(mask_values) => Some(mask_values.bit_buffer()), - Mask::AllFalse(_) => unreachable!("handled above"), + Mask::AllFalse(_) => { + vortex_bail!("invalid state, all-null array should be checked by top-level sum fn") + } }; let values_type = DecimalType::smallest_decimal_value_type(&return_decimal_dtype);