Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 121 additions & 114 deletions src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import warnings
from abc import abstractmethod
from collections.abc import Callable, Mapping
from dataclasses import dataclass
Expand All @@ -9,6 +8,7 @@

import dask.dataframe as dd
import numpy as np
import pandas as pd
from dask.dataframe import DataFrame as DaskDataFrame
from geopandas import GeoDataFrame
from shapely.geometry import MultiPolygon, Point, Polygon
Expand Down Expand Up @@ -78,7 +78,7 @@ def _get_bounding_box_corners_in_intrinsic_coordinates(

# compute the output axes of the transformation, remove c from input and output axes, return the matrix without c
# and then build an affine transformation from that
m_without_c, input_axes_without_c, output_axes_without_c = _get_axes_of_tranformation(
m_without_c, input_axes_without_c, output_axes_without_c = _get_axes_of_transformation(
element, target_coordinate_system
)
spatial_transform = Affine(m_without_c, input_axes=input_axes_without_c, output_axes=output_axes_without_c)
Expand Down Expand Up @@ -142,7 +142,7 @@ def _get_polygon_in_intrinsic_coordinates(

polygon_gdf = ShapesModel.parse(GeoDataFrame(geometry=[polygon]))

m_without_c, input_axes_without_c, output_axes_without_c = _get_axes_of_tranformation(
m_without_c, input_axes_without_c, output_axes_without_c = _get_axes_of_transformation(
element, target_coordinate_system
)
spatial_transform = Affine(m_without_c, input_axes=input_axes_without_c, output_axes=output_axes_without_c)
Expand Down Expand Up @@ -186,7 +186,7 @@ def _get_polygon_in_intrinsic_coordinates(
return transform(polygon_gdf, to_coordinate_system="inverse")


def _get_axes_of_tranformation(
def _get_axes_of_transformation(
element: SpatialElement, target_coordinate_system: str
) -> tuple[ArrayLike, tuple[str, ...], tuple[str, ...]]:
"""
Expand Down Expand Up @@ -321,6 +321,11 @@ def _get_case_of_bounding_box_query(
return case


def _is_scaling_transform(m_linear: np.ndarray) -> bool:
"""True when the linear part is a diagonal (pure scaling) matrix."""
return np.allclose(m_linear, np.diag(np.diagonal(m_linear)))


@dataclass(frozen=True)
class BaseSpatialRequest:
"""Base class for spatial queries."""
Expand Down Expand Up @@ -386,6 +391,7 @@ def _bounding_box_mask_points(
axes: tuple[str, ...],
min_coordinate: list[Number] | ArrayLike,
max_coordinate: list[Number] | ArrayLike,
points_df: pd.DataFrame | None = None,
) -> list[ArrayLike]:
"""Compute a mask that is true for the points inside axis-aligned bounding boxes.

Expand All @@ -404,31 +410,35 @@ def _bounding_box_mask_points(
The lower right hand corners of the bounding boxes (i.e., the maximum coordinates along all dimensions).
Shape: (n_boxes, n_axes) or (n_axes,) for a single box.
{max_coordinate_docs}
points_df
A pre-computed pandas dataframe. Useful if the points_df has already been materialized; otherwise the method simply
calls .compute() on the dask data frame.

Returns
-------
The masks for the points inside the bounding boxes.
"""
element_axes = get_axes_names(points)

min_coordinate = _parse_list_into_array(min_coordinate)
max_coordinate = _parse_list_into_array(max_coordinate)

# Ensure min_coordinate and max_coordinate are 2D arrays
min_coordinate = min_coordinate[np.newaxis, :] if min_coordinate.ndim == 1 else min_coordinate
max_coordinate = max_coordinate[np.newaxis, :] if max_coordinate.ndim == 1 else max_coordinate

# Compute once here only if the caller hasn't already done so
if points_df is None:
points_df = points.compute()

n_boxes = min_coordinate.shape[0]
in_bounding_box_masks = []

for box in range(n_boxes):
box_masks = []
for axis_index, axis_name in enumerate(axes):
if axis_name not in element_axes:
continue
min_value = min_coordinate[box, axis_index]
max_value = max_coordinate[box, axis_index]
box_masks.append(points[axis_name].gt(min_value).compute() & points[axis_name].lt(max_value).compute())
col = points_df[axis_name].values
box_masks.append((col > min_value) & (col < max_value))
bounding_box_mask = np.stack(box_masks, axis=-1)
in_bounding_box_masks.append(np.all(bounding_box_mask, axis=1))
return in_bounding_box_masks
Expand Down Expand Up @@ -514,16 +524,6 @@ def _(
min_coordinate = _parse_list_into_array(min_coordinate)
max_coordinate = _parse_list_into_array(max_coordinate)
new_elements = {}
if sdata.points:
warnings.warn(
(
"The object has `points` element. Depending on the number of points, querying MAY suffer from "
"performance issues. Please consider filtering the object before calling this function by calling the "
"`subset()` method of `SpatialData`."
),
UserWarning,
stacklevel=2,
)
for element_type in ["points", "images", "labels", "shapes"]:
elements = getattr(sdata, element_type)
queried_elements = _dict_query_dispatcher(
Expand Down Expand Up @@ -630,7 +630,6 @@ def _(
max_coordinate: list[Number] | ArrayLike,
target_coordinate_system: str,
) -> DaskDataFrame | list[DaskDataFrame] | None:
from spatialdata import transform
from spatialdata.transformations import get_transformation

min_coordinate = _parse_list_into_array(min_coordinate)
Expand All @@ -648,100 +647,103 @@ def _(
max_coordinate=max_coordinate,
)

# get the four corners of the bounding box (2D case), or the 8 corners of the "3D bounding box" (3D case)
(intrinsic_bounding_box_corners, intrinsic_axes) = _get_bounding_box_corners_in_intrinsic_coordinates(
element=points,
axes=axes,
min_coordinate=min_coordinate,
max_coordinate=max_coordinate,
target_coordinate_system=target_coordinate_system,
m_without_c, input_axes_without_c, output_axes_without_c = _get_axes_of_transformation(
points, target_coordinate_system
)
m_without_c_linear = m_without_c[:-1, :-1]
_ = _get_case_of_bounding_box_query(
m_without_c_linear,
input_axes_without_c,
output_axes_without_c,
)
min_coordinate_intrinsic = intrinsic_bounding_box_corners.min(dim="corner")
max_coordinate_intrinsic = intrinsic_bounding_box_corners.max(dim="corner")

min_coordinate_intrinsic = min_coordinate_intrinsic.data
max_coordinate_intrinsic = max_coordinate_intrinsic.data

# get the points in the intrinsic coordinate bounding box
in_intrinsic_bounding_box = _bounding_box_mask_points(
points=points,
axes=intrinsic_axes,
min_coordinate=min_coordinate_intrinsic,
max_coordinate=max_coordinate_intrinsic,
axes_adjusted, min_coordinate_adjusted, max_coordinate_adjusted = _adjust_bounding_box_to_real_axes(
axes,
min_coordinate,
max_coordinate,
output_axes_without_c,
)
if axes_adjusted != output_axes_without_c:
raise RuntimeError("This should not happen")

if not (len_df := len(in_intrinsic_bounding_box)) == (len_bb := len(min_coordinate)):
raise ValueError(
f"Length of list of dataframes `{len_df}` is not equal to the number of bounding boxes axes `{len_bb}`."
)
points_in_intrinsic_bounding_box: list[DaskDataFrame | None] = []
# materialize the points in the intrinsic coordinate system once
points_pd = points.compute()
attrs = points.attrs.copy()
for mask_np in in_intrinsic_bounding_box:
if mask_np.sum() == 0:
points_in_intrinsic_bounding_box.append(None)
else:
# TODO there is a problem when mixing dask dataframe graph with dask array graph. Need to compute for now.
# we can't compute either mask or points as when we calculate either one of them
# test_query_points_multiple_partitions will fail as the mask will be used to index each partition.
# However, if we compute and then create the dask array again we get the mixed dask graph problem.
filtered_pd = points_pd[mask_np]
points_filtered = dd.from_pandas(filtered_pd, npartitions=points.npartitions)
points_filtered.attrs.update(attrs)
points_in_intrinsic_bounding_box.append(points_filtered)
if len(points_in_intrinsic_bounding_box) == 0:
return None

# assert that the number of queried points is correct
assert len(points_in_intrinsic_bounding_box) == len(min_coordinate)

# # we have to reset the index since we have subset
# # https://stackoverflow.com/questions/61395351/how-to-reset-index-on-concatenated-dataframe-in-dask
# points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.assign(idx=1)
# points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.set_index(
# points_in_intrinsic_bounding_box.idx.cumsum() - 1
# )
# points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.map_partitions(
# lambda df: df.rename(index={"idx": None})
# )
# points_in_intrinsic_bounding_box = points_in_intrinsic_bounding_box.drop(columns=["idx"])

# transform the element to the query coordinate system
# checking the type of the transformation
# in the case of an identity or scaling transform, we can skip the whole
# projection into intrinsic space and reprojection into the global coordinate system
is_identity_transform = input_axes_without_c == output_axes_without_c and np.allclose(
m_without_c, np.eye(m_without_c.shape[0])
)
is_scaling_transform = input_axes_without_c == output_axes_without_c and _is_scaling_transform(m_without_c_linear)

# if the transform is identity, we can save extra for the affine transformation
if is_identity_transform:
bounding_box_masks = _bounding_box_mask_points(
points=points,
axes=axes,
min_coordinate=min_coordinate,
max_coordinate=max_coordinate,
points_df=points_pd,
)
elif is_scaling_transform:
# Pull scale factors from the diagonal and the translation from the last column
scales = np.diagonal(m_without_c_linear) # shape: (n_axes,)
translation = m_without_c[:-1, -1] # shape: (n_axes,)

# Invert the affine: x_intrinsic = (x_output - translation) / scale
min_intrinsic = (min_coordinate_adjusted - translation) / scales
max_intrinsic = (max_coordinate_adjusted - translation) / scales

# Ensure min < max after inversion (negative scale flips the interval)
min_intrinsic, max_intrinsic = (
np.minimum(min_intrinsic, max_intrinsic),
np.maximum(min_intrinsic, max_intrinsic),
)

bounding_box_masks = _bounding_box_mask_points(
points=points,
axes=tuple(input_axes_without_c),
min_coordinate=min_intrinsic,
max_coordinate=max_intrinsic,
points_df=points_pd,
)
else:
query_coordinates = points_pd.loc[:, list(input_axes_without_c)].to_numpy(copy=False)
query_coordinates = query_coordinates @ m_without_c[:-1, :-1].T + m_without_c[:-1, -1]

bounding_box_masks = []
for box_index in range(min_coordinate_adjusted.shape[0]):
bounding_box_mask = np.ones(len(points_pd), dtype=bool)
for axis_index in range(len(output_axes_without_c)):
min_value = min_coordinate_adjusted[box_index, axis_index]
max_value = max_coordinate_adjusted[box_index, axis_index]
column = query_coordinates[:, axis_index]
bounding_box_mask &= (column > min_value) & (column < max_value)
bounding_box_masks.append(bounding_box_mask)

if not (len_df := len(bounding_box_masks)) == (len_bb := len(min_coordinate)):
raise ValueError(f"Length of list of masks `{len_df}` is not equal to the number of bounding boxes `{len_bb}`.")

old_transformations = get_transformation(points, get_all=True)
assert isinstance(old_transformations, dict)
feature_key = points.attrs.get(ATTRS_KEY, {}).get(PointsModel.FEATURE_KEY)

output: list[DaskDataFrame | None] = []
for p, min_c, max_c in zip(points_in_intrinsic_bounding_box, min_coordinate, max_coordinate, strict=True):
if p is None:
for mask_np in bounding_box_masks:
bounding_box_indices = np.flatnonzero(mask_np)
if len(bounding_box_indices) == 0:
output.append(None)
else:
points_query_coordinate_system = transform(
p, to_coordinate_system=target_coordinate_system, maintain_positioning=False
continue

# The exact mask is computed in the query coordinate system, but the returned points must stay intrinsic.
queried_points = points_pd.iloc[bounding_box_indices]
output.append(
PointsModel.parse(
dd.from_pandas(queried_points, npartitions=1),
transformations=old_transformations.copy(),
feature_key=feature_key,
)

# get a mask for the points in the bounding box
bounding_box_mask = _bounding_box_mask_points(
points=points_query_coordinate_system,
axes=axes,
min_coordinate=min_c, # type: ignore[arg-type]
max_coordinate=max_c, # type: ignore[arg-type]
)
if len(bounding_box_mask) != 1:
raise ValueError(f"Expected a single mask, got {len(bounding_box_mask)} masks. Please report this bug.")
bounding_box_indices = np.where(bounding_box_mask[0])[0]

if len(bounding_box_indices) == 0:
output.append(None)
else:
points_df = p.compute().iloc[bounding_box_indices]
old_transformations = get_transformation(p, get_all=True)
assert isinstance(old_transformations, dict)
feature_key = p.attrs.get(ATTRS_KEY, {}).get(PointsModel.FEATURE_KEY)

output.append(
PointsModel.parse(
dd.from_pandas(points_df, npartitions=1),
transformations=old_transformations.copy(),
feature_key=feature_key,
)
)
)
if len(output) == 0:
return None
if len(output) == 1:
Expand Down Expand Up @@ -791,8 +793,8 @@ def _(
)
for box_corners in intrinsic_bounding_box_corners:
bounding_box_non_axes_aligned = Polygon(box_corners.data)
indices = polygons.geometry.intersects(bounding_box_non_axes_aligned)
queried = polygons[indices]
candidate_idx = polygons.sindex.query(bounding_box_non_axes_aligned, predicate="intersects")
queried = polygons.iloc[candidate_idx]
if len(queried) == 0:
queried_polygon = None
else:
Expand Down Expand Up @@ -949,17 +951,22 @@ def _(
assert np.all(element[OLD_INDEX] == buffered.index)
else:
buffered[OLD_INDEX] = buffered.index
indices = buffered.geometry.apply(lambda x: x.intersects(polygon))
if np.sum(indices) == 0:

# Use sindex for fast candidate pre-filtering, then exact intersection check
# only on the (typically small) candidate set — same pattern as bounding_box_query.
candidate_idx = buffered.sindex.query(polygon, predicate="intersects")
if len(candidate_idx) == 0:
del buffered[OLD_INDEX]
return None
queried_shapes = element[indices]
queried_shapes.index = buffered[indices][OLD_INDEX]

queried_shapes = element.iloc[candidate_idx].copy()
queried_shapes.index = buffered.iloc[candidate_idx][OLD_INDEX]
queried_shapes.index.name = None

if clip:
if isinstance(element.geometry.iloc[0], Point):
queried_shapes = buffered[indices]
queried_shapes.index = buffered[indices][OLD_INDEX]
queried_shapes = buffered.iloc[candidate_idx].copy()
queried_shapes.index = buffered.iloc[candidate_idx][OLD_INDEX]
queried_shapes.index.name = None
queried_shapes = queried_shapes.clip(polygon_gdf, keep_geom_type=True)

Expand Down
29 changes: 29 additions & 0 deletions tests/core/query/test_spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,35 @@ def _query(p: DaskDataFrame) -> DaskDataFrame:
assert np.array_equal(q0.index.compute(), q1.index.compute())


def test_query_points_multiple_boxes_in_transformed_coordinate_system():
    """Bounding-box query with several boxes against a translated coordinate system.

    The element carries a pure-translation affine (x + 100, y - 50) into the
    "aligned" coordinate system. The first two boxes each select points there;
    the third box matches nothing and must yield ``None``.
    """
    from spatialdata.transformations import Affine

    element = _make_points(np.array([[10, 10], [20, 30], [20, 30], [40, 50]]))
    translation = Affine(
        np.array([[1, 0, 100], [0, 1, -50], [0, 0, 1]]),
        input_axes=("x", "y"),
        output_axes=("x", "y"),
    )
    set_transformation(element, transformation=translation, to_coordinate_system="aligned")

    results = bounding_box_query(
        element,
        axes=("x", "y"),
        min_coordinate=np.array([[118, -22], [138, -2], [200, 200]]),
        max_coordinate=np.array([[122, -18], [142, 2], [210, 210]]),
        target_coordinate_system="aligned",
    )

    np.testing.assert_allclose(results[0]["x"].compute(), [20, 20])
    np.testing.assert_allclose(results[0]["y"].compute(), [30, 30])
    np.testing.assert_allclose(results[1]["x"].compute(), [40])
    np.testing.assert_allclose(results[1]["y"].compute(), [50])
    assert results[2] is None


@pytest.mark.parametrize("with_polygon_query", [True, False])
@pytest.mark.parametrize(
"name",
Expand Down
Loading