diff --git a/native/core/src/execution/shuffle/spark_unsafe/list.rs b/native/core/src/execution/shuffle/spark_unsafe/list.rs index d8e39e8b09..cc72813946 100644 --- a/native/core/src/execution/shuffle/spark_unsafe/list.rs +++ b/native/core/src/execution/shuffle/spark_unsafe/list.rs @@ -51,7 +51,8 @@ impl SparkUnsafeObject for SparkUnsafeArray { impl SparkUnsafeArray { /// Creates a `SparkUnsafeArray` which points to the given address and size in bytes. pub fn new(addr: i64) -> Self { - // Read the number of elements from the first 8 bytes. + // SAFETY: addr points to valid Spark UnsafeArray data from the JVM. + // The first 8 bytes contain the element count as a little-endian i64. let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) }; let num_elements = i64::from_le_bytes(slice.try_into().unwrap()); @@ -83,6 +84,9 @@ impl SparkUnsafeArray { /// Returns true if the null bit at the given index of the array is set. #[inline] pub(crate) fn is_null_at(&self, index: usize) -> bool { + // SAFETY: row_addr points to valid Spark UnsafeArray data. The null bitset starts + // at offset 8 and contains ceil(num_elements/64) * 8 bytes. The caller ensures + // index < num_elements, so word_offset is within the bitset region. unsafe { let mask: i64 = 1i64 << (index & 0x3f); let word_offset = (self.row_addr + 8 + (((index >> 6) as i64) << 3)) as *const i64; diff --git a/native/core/src/execution/shuffle/spark_unsafe/map.rs b/native/core/src/execution/shuffle/spark_unsafe/map.rs index de2b96146b..dbb5b404aa 100644 --- a/native/core/src/execution/shuffle/spark_unsafe/map.rs +++ b/native/core/src/execution/shuffle/spark_unsafe/map.rs @@ -30,7 +30,8 @@ pub struct SparkUnsafeMap { impl SparkUnsafeMap { /// Creates a `SparkUnsafeMap` which points to the given address and size in bytes. pub(crate) fn new(addr: i64, size: i32) -> Self { - // Read the number of bytes of key array from the first 8 bytes. 
+ // SAFETY: addr points to valid Spark UnsafeMap data from the JVM. + // The first 8 bytes contain the key array size as a little-endian i64. let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) }; let key_array_size = i64::from_le_bytes(slice.try_into().unwrap()); diff --git a/native/core/src/execution/shuffle/spark_unsafe/row.rs b/native/core/src/execution/shuffle/spark_unsafe/row.rs index 1c121c506b..02daf6a34f 100644 --- a/native/core/src/execution/shuffle/spark_unsafe/row.rs +++ b/native/core/src/execution/shuffle/spark_unsafe/row.rs @@ -58,6 +58,19 @@ const NESTED_TYPE_BUILDER_CAPACITY: usize = 100; /// A common trait for Spark Unsafe classes that can be used to access the underlying data, /// e.g., `UnsafeRow` and `UnsafeArray`. This defines a set of methods that can be used to /// access the underlying data with index. +/// +/// # Safety +/// +/// Implementations must ensure that: +/// - `get_row_addr()` returns a valid pointer to JVM-allocated memory +/// - `get_element_offset()` returns a valid pointer within the row/array data region +/// - The memory layout follows Spark's UnsafeRow/UnsafeArray format +/// - The memory remains valid for the lifetime of the object (guaranteed by JVM ownership) +/// +/// All accessor methods (get_boolean, get_int, etc.) use unsafe pointer operations but are +/// safe to call as long as: +/// - The index is within bounds (caller's responsibility) +/// - The object was constructed from valid Spark UnsafeRow/UnsafeArray data pub trait SparkUnsafeObject { /// Returns the address of the row. fn get_row_addr(&self) -> i64; @@ -77,12 +90,15 @@ pub trait SparkUnsafeObject { /// Returns boolean value at the given index of the object. fn get_boolean(&self, index: usize) -> bool { let addr = self.get_element_offset(index, 1); + // SAFETY: addr points to valid element data within the UnsafeRow/UnsafeArray region. + // The caller ensures index is within bounds. 
unsafe { *addr != 0 } } /// Returns byte value at the given index of the object. fn get_byte(&self, index: usize) -> i8 { let addr = self.get_element_offset(index, 1); + // SAFETY: addr points to valid element data (1 byte) within the row/array region. let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 1) }; i8::from_le_bytes(slice.try_into().unwrap()) } @@ -90,6 +106,7 @@ pub trait SparkUnsafeObject { /// Returns short value at the given index of the object. fn get_short(&self, index: usize) -> i16 { let addr = self.get_element_offset(index, 2); + // SAFETY: addr points to valid element data (2 bytes) within the row/array region. let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 2) }; i16::from_le_bytes(slice.try_into().unwrap()) } @@ -97,6 +114,7 @@ pub trait SparkUnsafeObject { /// Returns integer value at the given index of the object. fn get_int(&self, index: usize) -> i32 { let addr = self.get_element_offset(index, 4); + // SAFETY: addr points to valid element data (4 bytes) within the row/array region. let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) }; i32::from_le_bytes(slice.try_into().unwrap()) } @@ -104,6 +122,7 @@ pub trait SparkUnsafeObject { /// Returns long value at the given index of the object. fn get_long(&self, index: usize) -> i64 { let addr = self.get_element_offset(index, 8); + // SAFETY: addr points to valid element data (8 bytes) within the row/array region. let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) }; i64::from_le_bytes(slice.try_into().unwrap()) } @@ -111,6 +130,7 @@ pub trait SparkUnsafeObject { /// Returns float value at the given index of the object. fn get_float(&self, index: usize) -> f32 { let addr = self.get_element_offset(index, 4); + // SAFETY: addr points to valid element data (4 bytes) within the row/array region. 
let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) }; f32::from_le_bytes(slice.try_into().unwrap()) } @@ -118,6 +138,7 @@ pub trait SparkUnsafeObject { /// Returns double value at the given index of the object. fn get_double(&self, index: usize) -> f64 { let addr = self.get_element_offset(index, 8); + // SAFETY: addr points to valid element data (8 bytes) within the row/array region. let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) }; f64::from_le_bytes(slice.try_into().unwrap()) } @@ -126,6 +147,8 @@ pub trait SparkUnsafeObject { fn get_string(&self, index: usize) -> &str { let (offset, len) = self.get_offset_and_len(index); let addr = self.get_row_addr() + offset as i64; + // SAFETY: addr points to valid UTF-8 string data within the variable-length region. + // Offset and length are read from the fixed-length portion of the row/array. let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) }; from_utf8(slice).unwrap() @@ -135,12 +158,15 @@ pub trait SparkUnsafeObject { fn get_binary(&self, index: usize) -> &[u8] { let (offset, len) = self.get_offset_and_len(index); let addr = self.get_row_addr() + offset as i64; + // SAFETY: addr points to valid binary data within the variable-length region. + // Offset and length are read from the fixed-length portion of the row/array. unsafe { std::slice::from_raw_parts(addr as *const u8, len as usize) } } /// Returns date value at the given index of the object. fn get_date(&self, index: usize) -> i32 { let addr = self.get_element_offset(index, 4); + // SAFETY: addr points to valid element data (4 bytes) within the row/array region. let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) }; i32::from_le_bytes(slice.try_into().unwrap()) } @@ -148,6 +174,7 @@ pub trait SparkUnsafeObject { /// Returns timestamp value at the given index of the object. 
fn get_timestamp(&self, index: usize) -> i64 { let addr = self.get_element_offset(index, 8); + // SAFETY: addr points to valid element data (8 bytes) within the row/array region. let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) }; i64::from_le_bytes(slice.try_into().unwrap()) } @@ -257,6 +284,9 @@ impl SparkUnsafeRow { /// Returns true if the null bit at the given index of the row is set. #[inline] pub(crate) fn is_null_at(&self, index: usize) -> bool { + // SAFETY: row_addr points to valid Spark UnsafeRow data with at least + // ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields. + // word_offset is within the bitset region since (index >> 6) << 3 < bitset size. unsafe { let mask: i64 = 1i64 << (index & 0x3f); let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *const i64; @@ -267,6 +297,9 @@ impl SparkUnsafeRow { /// Unsets the null bit at the given index of the row, i.e., set the bit to 0 (not null). pub fn set_not_null_at(&mut self, index: usize) { + // SAFETY: row_addr points to valid Spark UnsafeRow data with at least + // ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields. + // The write is sound only if nothing else (Rust or JVM) accesses this row concurrently. unsafe { let mask: i64 = 1i64 << (index & 0x3f); let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *mut i64; @@ -463,6 +496,8 @@ fn append_columns( let mut row = SparkUnsafeRow::new(schema); for i in row_start..row_end { + // SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least + // row_end elements. i is in [row_start, row_end) so the offset is in bounds. 
let row_addr = unsafe { *row_addresses_ptr.add(i) }; let row_size = unsafe { *row_sizes_ptr.add(i) }; row.point_to(row_addr, row_size); @@ -593,6 +628,8 @@ fn append_columns( let mut row = SparkUnsafeRow::new(schema); for i in row_start..row_end { + // SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least + // row_end elements. i is in [row_start, row_end) so the offset is in bounds. let row_addr = unsafe { *row_addresses_ptr.add(i) }; let row_size = unsafe { *row_sizes_ptr.add(i) }; row.point_to(row_addr, row_size); @@ -613,6 +650,8 @@ fn append_columns( let mut row = SparkUnsafeRow::new(schema); for i in row_start..row_end { + // SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least + // row_end elements. i is in [row_start, row_end) so the offset is in bounds. let row_addr = unsafe { *row_addresses_ptr.add(i) }; let row_size = unsafe { *row_sizes_ptr.add(i) }; row.point_to(row_addr, row_size); @@ -640,6 +679,8 @@ fn append_columns( let mut row = SparkUnsafeRow::new(schema); for i in row_start..row_end { + // SAFETY: row_addresses_ptr and row_sizes_ptr are JNI arrays with at least + // row_end elements. i is in [row_start, row_end) so the offset is in bounds. let row_addr = unsafe { *row_addresses_ptr.add(i) }; let row_size = unsafe { *row_sizes_ptr.add(i) }; row.point_to(row_addr, row_size);