Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 76 additions & 51 deletions vortex-array/src/arrays/decimal/compute/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use num_traits::AsPrimitive;
use num_traits::CheckedAdd;
use vortex_buffer::BitBuffer;
use vortex_buffer::Buffer;
use vortex_dtype::DType;
use vortex_dtype::DecimalDType;
use vortex_dtype::DecimalType;
use vortex_dtype::Nullability::Nullable;
use vortex_dtype::match_each_decimal_value_type;
Expand All @@ -24,75 +26,77 @@ use crate::expr::stats::Stat;
use crate::register_kernel;

impl SumKernel for DecimalVTable {
#[expect(
clippy::cognitive_complexity,
reason = "complexity from nested match_each_* macros"
)]
fn sum(&self, array: &DecimalArray, accumulator: &Scalar) -> VortexResult<Scalar> {
let return_dtype = Stat::Sum
.dtype(array.dtype())
.vortex_expect("sum for decimals exists");
let return_decimal_dtype = return_dtype
let return_decimal_dtype = *return_dtype
.as_decimal_opt()
.vortex_expect("must be decimal");

// Extract the initial value as a DecimalValue
// Extract the initial value as a `DecimalValue`.
let initial_decimal = accumulator
.as_decimal()
.decimal_value()
.vortex_expect("cannot be null");

match array.validity_mask()? {
let mask = array.validity_mask()?;
let validity = match &mask {
Mask::AllTrue(_) => None,
Mask::Values(mask_values) => Some(mask_values.bit_buffer()),
Mask::AllFalse(_) => {
vortex_bail!("invalid state, all-null array should be checked by top-level sum fn")
}
Mask::AllTrue(_) => {
let values_type = DecimalType::smallest_decimal_value_type(return_decimal_dtype);
match_each_decimal_value_type!(array.values_type(), |I| {
match_each_decimal_value_type!(values_type, |O| {
let initial_val: O = initial_decimal
.cast()
.vortex_expect("cannot fail to cast initial value");
if let Some(sum) = sum_decimal(array.buffer::<I>(), initial_val) {
Ok(Scalar::decimal(
DecimalValue::from(sum),
*return_decimal_dtype,
Nullable,
))
} else {
Ok(Scalar::null(return_dtype))
}
})
})
}
Mask::Values(mask_values) => {
let values_type = DecimalType::smallest_decimal_value_type(return_decimal_dtype);
match_each_decimal_value_type!(array.values_type(), |I| {
match_each_decimal_value_type!(values_type, |O| {
let initial_val: O = initial_decimal
.cast()
.vortex_expect("cannot fail to cast initial value");

if let Some(sum) = sum_decimal_with_validity(
array.buffer::<I>(),
mask_values.bit_buffer(),
initial_val,
) {
Ok(Scalar::decimal(
DecimalValue::from(sum),
*return_decimal_dtype,
Nullable,
))
} else {
Ok(Scalar::null(return_dtype))
}
})
})
}
}
};

let values_type = DecimalType::smallest_decimal_value_type(&return_decimal_dtype);
match_each_decimal_value_type!(array.values_type(), |I| {
match_each_decimal_value_type!(values_type, |O| {
let initial_val: O = initial_decimal
.cast()
.vortex_expect("cannot fail to cast initial value");

Ok(sum_to_scalar(
array.buffer::<I>(),
validity,
initial_val,
return_decimal_dtype,
&return_dtype,
))
})
})
}
}

/// Sum `values` on top of `initial` with checked arithmetic and wrap the outcome in a
/// [`Scalar`].
///
/// When `validity` is `Some`, the summation is delegated to `sum_decimal_with_validity`
/// together with that bit buffer; when it is `None`, plain `sum_decimal` is used.
///
/// The result is a nullable decimal scalar of `return_decimal_dtype`. A null scalar of
/// `return_dtype` is produced instead when either:
/// - the underlying summation helper yields `None` (arithmetic overflow of `O`), or
/// - the computed total does not fit within the precision of `return_decimal_dtype`.
fn sum_to_scalar<T, O>(
    values: Buffer<T>,
    validity: Option<&BitBuffer>,
    initial: O,
    return_decimal_dtype: DecimalDType,
    return_dtype: &DType,
) -> Scalar
where
    T: AsPrimitive<O>,
    O: Copy + CheckedAdd + Into<DecimalValue> + 'static,
{
    let checked_total = if let Some(valid_bits) = validity {
        sum_decimal_with_validity(values, valid_bits, initial)
    } else {
        sum_decimal(values, initial)
    };

    // No total means the summation overflowed: report it as a null.
    let Some(total) = checked_total else {
        return Scalar::null(return_dtype.clone());
    };

    // A total that fits the integer type may still be too wide for the declared precision,
    // which is likewise reported as a null.
    let decimal: DecimalValue = total.into();
    if decimal.fits_in_precision(return_decimal_dtype) {
        Scalar::decimal(decimal, return_decimal_dtype, Nullable)
    } else {
        Scalar::null(return_dtype.clone())
    }
}

fn sum_decimal<T: AsPrimitive<I>, I: Copy + CheckedAdd + 'static>(
values: Buffer<T>,
initial: I,
Expand Down Expand Up @@ -371,6 +375,27 @@ mod tests {
assert_eq!(result, expected);
}

#[test]
fn test_sum_precision_overflow_without_i256_overflow() {
    // Two values that each fit in precision 76 but whose total does not, even though that
    // total is still representable as an `i256`. The sum must therefore come back null
    // because of the precision check, not because of an arithmetic overflow.
    let decimal_dtype = DecimalDType::new(76, 0);

    // 10^75 is assembled as 10^38 * 10^37, each factor being built from an `i128`.
    let ten_to_75 = i256::from_i128(10i128.pow(38)) * i256::from_i128(10i128.pow(37));
    // 6 * 10^75 has 76 digits, so it fits in precision 76 on its own.
    let six_e75 = ten_to_75 * i256::from_i128(6);

    let decimal = DecimalArray::new(
        buffer![six_e75, six_e75],
        decimal_dtype,
        Validity::AllValid,
    );

    // 2 * (6 * 10^75) = 1.2 * 10^76: a 77-digit number, too wide for precision 76.
    let result = sum(decimal.as_ref()).unwrap();
    let expected = Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable));
    assert_eq!(result, expected);
}

#[test]
fn test_i256_overflow() {
let decimal_dtype = DecimalDType::new(76, 0);
Expand Down
Loading