Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ Same-CRS tiles skip reprojection entirely and are placed by direct coordinate al
| [Preview](xrspatial/preview.py) | Downsamples a raster to target pixel dimensions for visualization | Custom | ✅️ | ✅️ | ✅️ | 🔄 |
| [Rescale](xrspatial/normalize.py) | Min-max normalization to a target range (default [0, 1]) | Standard | ✅️ | ✅️ | ✅️ | ✅️ |
| [Standardize](xrspatial/normalize.py) | Z-score normalization (subtract mean, divide by std) | Standard | ✅️ | ✅️ | ✅️ | ✅️ |
| [rechunk_no_shuffle](xrspatial/utils.py) | Rechunk dask arrays using whole-chunk multiples (no shuffle) | Custom | 🔄 | ✅️ | 🔄 | ✅️ |

-----------

Expand Down
7 changes: 7 additions & 0 deletions docs/source/reference/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ Normalization
xrspatial.normalize.rescale
xrspatial.normalize.standardize

Rechunking
==========
.. autosummary::
:toctree: _autosummary

xrspatial.utils.rechunk_no_shuffle

Diagnostics
===========
.. autosummary::
Expand Down
165 changes: 165 additions & 0 deletions examples/user_guide/36_Rechunk_No_Shuffle.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Rechunk Without Shuffling\n",
"\n",
"When working with large dask-backed rasters, rechunking to bigger blocks can\n",
"speed up downstream operations like `slope()` or `focal_mean()` that use\n",
"`map_overlap`. But if the new chunk size is not an exact multiple of the\n",
"original, dask has to split and recombine blocks — essentially a shuffle —\n",
"which tanks performance.\n",
"\n",
"`rechunk_no_shuffle` picks the largest whole-chunk multiple that fits your\n",
"target size, so dask can merge blocks in place with zero shuffle overhead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import dask.array as da\n",
"import xarray as xr\n",
"import xrspatial\n",
"from xrspatial.utils import rechunk_no_shuffle"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create a synthetic dask-backed raster\n",
"\n",
"Start with a 4096 x 4096 raster chunked at 256 x 256 (about 0.25 MB per\n",
"chunk for float32)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"raw = np.random.rand(4096, 4096).astype(np.float32) * 1000\n",
"dem = xr.DataArray(\n",
" da.from_array(raw, chunks=256),\n",
" dims=['y', 'x'],\n",
" coords={\n",
" 'y': np.linspace(40.0, 41.0, 4096),\n",
" 'x': np.linspace(-105.0, -104.0, 4096),\n",
" },\n",
")\n",
"print(f'Original chunks: {dem.chunks}')\n",
"print(f'Chunks per axis: {len(dem.chunks[0])} x {len(dem.chunks[1])}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Rechunk to ~64 MB target\n",
"\n",
"Each new chunk will be an exact multiple of 256, so dask just groups\n",
"existing blocks together."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"big = rechunk_no_shuffle(dem, target_mb=64)\n",
"print(f'New chunks: {big.chunks}')\n",
"print(f'Chunks per axis: {len(big.chunks[0])} x {len(big.chunks[1])}')\n",
"print(f'Block size: {big.chunks[0][0]} x {big.chunks[1][0]}')\n",
"print(f'Multiple of 256: {big.chunks[0][0] // 256}x')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using the .xrs accessor\n",
"\n",
"The same function is available directly on any DataArray."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"big_via_accessor = dem.xrs.rechunk_no_shuffle(target_mb=64)\n",
"print(f'Accessor chunks: {big_via_accessor.chunks}')\n",
"assert big.chunks == big_via_accessor.chunks"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compare task graph sizes\n",
"\n",
"Fewer, larger chunks means a smaller task graph for downstream operations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from xrspatial.slope import slope\n",
"\n",
"slope_small = slope(dem)\n",
"slope_big = slope(big)\n",
"\n",
"print(f'slope() graph with original chunks: {len(dict(slope_small.data.__dask_graph__())):,} tasks')\n",
"print(f'slope() graph with rechunked: {len(dict(slope_big.data.__dask_graph__())):,} tasks')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Non-dask arrays pass through unchanged\n",
"\n",
"If the input is a plain numpy-backed DataArray, the function returns it\n",
"as-is — no copy, no error."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"numpy_dem = xr.DataArray(raw, dims=['y', 'x'])\n",
"result = rechunk_no_shuffle(numpy_dem, target_mb=64)\n",
"assert result is numpy_dem\n",
"print('Numpy passthrough: OK')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
1 change: 1 addition & 0 deletions xrspatial/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
from xrspatial.zonal import suggest_zonal_canvas as suggest_zonal_canvas # noqa
from xrspatial.reproject import merge # noqa
from xrspatial.reproject import reproject # noqa
from xrspatial.utils import rechunk_no_shuffle # noqa

import xrspatial.mcda # noqa: F401 — exposes xrspatial.mcda subpackage
import xrspatial.accessor # noqa: F401 — registers .xrs accessors
Expand Down
6 changes: 6 additions & 0 deletions xrspatial/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,12 @@ def to_geotiff(self, path, **kwargs):
from .geotiff import to_geotiff
return to_geotiff(self._obj, path, **kwargs)

# ---- Chunking ----

def rechunk_no_shuffle(self, **kwargs):
from .utils import rechunk_no_shuffle
return rechunk_no_shuffle(self._obj, **kwargs)


@xr.register_dataset_accessor("xrs")
class XrsSpatialDatasetAccessor:
Expand Down
1 change: 1 addition & 0 deletions xrspatial/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def test_dataarray_accessor_has_expected_methods(elevation):
'generate_terrain', 'perlin',
'ndvi', 'evi', 'arvi', 'savi', 'nbr', 'sipi',
'rasterize',
'rechunk_no_shuffle',
]
for name in expected:
assert name in names, f"Missing method: {name}"
Expand Down
123 changes: 123 additions & 0 deletions xrspatial/tests/test_rechunk_no_shuffle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Tests for rechunk_no_shuffle."""

import numpy as np
import pytest
import xarray as xr

from xrspatial.utils import rechunk_no_shuffle

da = pytest.importorskip("dask.array")


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_dask_raster(shape=(1024, 1024), chunks=256, dtype=np.float32):
data = da.zeros(shape, chunks=chunks, dtype=dtype)
dims = ['y', 'x'] if len(shape) == 2 else [f'd{i}' for i in range(len(shape))]
return xr.DataArray(data, dims=dims)


# ---------------------------------------------------------------------------
# Basic behaviour
# ---------------------------------------------------------------------------

def test_chunks_are_exact_multiples():
"""New chunks must be an integer multiple of original chunks."""
raster = _make_dask_raster(shape=(2048, 2048), chunks=128)
result = rechunk_no_shuffle(raster, target_mb=16)

for orig, new in zip(raster.chunks, result.chunks):
base = orig[0]
# every chunk (except possibly the last) should be a multiple
for c in new[:-1]:
assert c % base == 0, f"chunk {c} is not a multiple of {base}"


def test_chunks_grow():
"""Output chunks should be larger than input when target is larger."""
raster = _make_dask_raster(shape=(2048, 2048), chunks=64)
result = rechunk_no_shuffle(raster, target_mb=16)
assert result.chunks[0][0] > raster.chunks[0][0]


def test_already_large_returns_unchanged():
"""If chunks already meet or exceed target, return as-is."""
raster = _make_dask_raster(shape=(512, 512), chunks=512)
result = rechunk_no_shuffle(raster, target_mb=0.5)
assert result.chunks == raster.chunks


def test_3d_input():
"""Works with 3-D arrays (e.g. stacked bands)."""
raster = _make_dask_raster(shape=(4, 512, 512), chunks=(1, 128, 128))
result = rechunk_no_shuffle(raster, target_mb=16)
for orig, new in zip(raster.chunks, result.chunks):
base = orig[0]
for c in new[:-1]:
assert c % base == 0


def test_preserves_values():
"""Rechunked array should contain identical data."""
np.random.seed(1067)
data = da.from_array(np.random.rand(256, 256).astype(np.float32), chunks=64)
raster = xr.DataArray(data, dims=['y', 'x'])
result = rechunk_no_shuffle(raster, target_mb=1)
np.testing.assert_array_equal(raster.values, result.values)


def test_preserves_coords_and_attrs():
"""Coordinates and attributes must survive rechunking."""
data = da.zeros((256, 256), chunks=64, dtype=np.float32)
raster = xr.DataArray(
data,
dims=['y', 'x'],
coords={'y': np.arange(256), 'x': np.arange(256)},
attrs={'crs': 'EPSG:4326'},
)
result = rechunk_no_shuffle(raster, target_mb=1)
assert result.attrs == raster.attrs
xr.testing.assert_equal(result.coords.to_dataset(), raster.coords.to_dataset())


# ---------------------------------------------------------------------------
# Non-dask passthrough
# ---------------------------------------------------------------------------

def test_numpy_passthrough():
"""Numpy-backed DataArray should be returned unchanged."""
raster = xr.DataArray(np.zeros((100, 100)), dims=['y', 'x'])
result = rechunk_no_shuffle(raster, target_mb=1)
assert result is raster


# ---------------------------------------------------------------------------
# Input validation
# ---------------------------------------------------------------------------

def test_rejects_non_dataarray():
with pytest.raises(TypeError, match="expected xr.DataArray"):
rechunk_no_shuffle(np.zeros((10, 10)))


def test_rejects_nonpositive_target():
raster = _make_dask_raster()
with pytest.raises(ValueError, match="target_mb must be > 0"):
rechunk_no_shuffle(raster, target_mb=0)
with pytest.raises(ValueError, match="target_mb must be > 0"):
rechunk_no_shuffle(raster, target_mb=-1)


# ---------------------------------------------------------------------------
# Accessor integration
# ---------------------------------------------------------------------------

def test_accessor():
"""The .xrs.rechunk_no_shuffle() accessor delegates correctly."""
import xrspatial # noqa: F401 — registers accessor
raster = _make_dask_raster(shape=(1024, 1024), chunks=128)
direct = rechunk_no_shuffle(raster, target_mb=16)
via_accessor = raster.xrs.rechunk_no_shuffle(target_mb=16)
assert direct.chunks == via_accessor.chunks
Loading
Loading