Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions src/stratify/_vinterp.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,36 @@ cdef class _Interpolation(object):
emsg = 'Shape for z_src {} is not a subset of fz_src {}.'
raise ValueError(emsg.format(z_src.shape, fz_src.shape))

# If rising is not provided, work it out from the first two values of z_src.
# Then do the same thing for the target, and if the target is rising in the
# opposite direction to the source, then we flip the source and the source data
# on the interpolation axis.
if rising is None:
if z_src.shape[zp_axis] < 2:
raise ValueError('The rising keyword must be defined when '
'the size of the source array is <2 in '
'the interpolation axis.')
z_src_indexer = [0] * z_src.ndim
z_src_indexer[zp_axis] = slice(0, 2)
src_first_two = z_src[tuple(z_src_indexer)]
rising = src_first_two[0] <= src_first_two[1]
if len(z_target) < 2:
tgt_rising = rising
else:
if z_target.ndim == 1:
tgt_first_two = z_target[:2]
else:
tgt_axis = axis % z_target.ndim
tgt_indexer = [slice(None)] * z_target.ndim
tgt_indexer[tgt_axis] = slice(0, 2)
tgt_first_two = z_target[tuple(tgt_indexer)].ravel()[:2]
tgt_rising = tgt_first_two[0] <= tgt_first_two[1]
if tgt_rising != rising:
z_src = np.flip(z_src, axis=zp_axis)
fz_src = np.flip(fz_src, axis=zp_axis)
rising = tgt_rising
self.rising = bool(rising)

if z_target.ndim == 1:
z_target_size = z_target.shape[0]
else:
Expand All @@ -645,7 +675,6 @@ cdef class _Interpolation(object):
'got ({}) != ({}).')
raise ValueError(emsg.format(sep.join(ztsp), sep.join(zssp)))
z_target_size = zts[zp_axis]

# We are going to put the source coordinate into a 3d shape for convenience of
# Cython interface. Writing generic, fast, n-dimensional Cython code
# is not possible, but it is possible to always support a 3d array with
Expand Down Expand Up @@ -692,18 +721,6 @@ cdef class _Interpolation(object):
#: The shape of the interpolated data.
self.result_shape = tuple(result_shape)

if rising is None:
if z_src.shape[zp_axis] < 2:
raise ValueError('The rising keyword must be defined when '
'the size of the source array is <2 in '
'the interpolation axis.')
z_src_indexer = [0] * z_src.ndim
z_src_indexer[zp_axis] = slice(0, 2)
first_two = z_src[tuple(z_src_indexer)]
rising = first_two[0] <= first_two[1]

self.rising = bool(rising)

# Sometimes we want to add additional constraints on our interpolation
# and extrapolation - for example, linear extrapolation requires there
# to be two coordinates to interpolate from.
Expand Down
6 changes: 3 additions & 3 deletions src/stratify/tests/test_vinterp.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_no_levels(self):

def test_wrong_rising_target(self):
r = self.interpolate([2, 1], [1, 2])
assert_array_equal(r, [1, np.inf])
assert_array_equal(r, [0.0, 1.0])
Comment on lines 115 to +117
Copy link
Contributor Author

@HGWright HGWright Feb 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We know that there are extra steps taking place in this test. But it is not indicative of the simple interpolation of this data, as shown by the example below.

import numpy as np
import stratify

z_targ = np.array([2, 1])

z_source = np.array([1, 2])
data_source = z_source.copy()
out  = stratify.interpolate(z_targ, z_source, data_source)
print(out)

[2. 1.]


def test_wrong_rising_source(self):
r = self.interpolate([1, 2], [2, 1], rising=True)
Expand All @@ -124,11 +124,11 @@ def test_wrong_rising_source_and_target(self):
# If we overshoot the first level, there is no hope,
# so we end up extrapolating.
r = self.interpolate([3, 2, 1, 0], [2, 1], rising=True)
assert_array_equal(r, [np.inf, np.inf, np.inf, np.inf])
assert_array_equal(r, [-np.inf, -np.inf, 0.0, np.inf])

def test_non_monotonic_coordinate_interp(self):
result = self.interpolate([15, 5, 15.0], [10.0, 20, 0, 20])
assert_array_equal(result, [1, 2, 3])
assert_array_equal(result, [1.0, 1.0, 2.0])

def test_non_monotonic_coordinate_extrap(self):
result = self.interpolate([0, 15, 16, 17, 5, 15.0, 25], [10.0, 40, 0, 20])
Expand Down
Loading