diff --git a/src/stratify/_vinterp.pyx b/src/stratify/_vinterp.pyx index f580487..0a6ae51 100644 --- a/src/stratify/_vinterp.pyx +++ b/src/stratify/_vinterp.pyx @@ -622,6 +622,36 @@ cdef class _Interpolation(object): emsg = 'Shape for z_src {} is not a subset of fz_src {}.' raise ValueError(emsg.format(z_src.shape, fz_src.shape)) + # If rising is not provided, work it out from the first two values of z_src. + # Then do the same thing for the target, and if the target is rising in the + # opposite direction to the source, then we flip the source and the source data + # on the interpolation axis. + if rising is None: + if z_src.shape[zp_axis] < 2: + raise ValueError('The rising keyword must be defined when ' + 'the size of the source array is <2 in ' + 'the interpolation axis.') + z_src_indexer = [0] * z_src.ndim + z_src_indexer[zp_axis] = slice(0, 2) + src_first_two = z_src[tuple(z_src_indexer)] + rising = src_first_two[0] <= src_first_two[1] + if len(z_target) < 2: + tgt_rising = rising + else: + if z_target.ndim == 1: + tgt_first_two = z_target[:2] + else: + tgt_axis = axis % z_target.ndim + tgt_indexer = [slice(None)] * z_target.ndim + tgt_indexer[tgt_axis] = slice(0, 2) + tgt_first_two = z_target[tuple(tgt_indexer)].ravel()[:2] + tgt_rising = tgt_first_two[0] <= tgt_first_two[1] + if tgt_rising != rising: + z_src = np.flip(z_src, axis=zp_axis) + fz_src = np.flip(fz_src, axis=zp_axis) + rising = tgt_rising + self.rising = bool(rising) + if z_target.ndim == 1: z_target_size = z_target.shape[0] else: @@ -645,7 +675,6 @@ cdef class _Interpolation(object): 'got ({}) != ({}).') raise ValueError(emsg.format(sep.join(ztsp), sep.join(zssp))) z_target_size = zts[zp_axis] - # We are going to put the source coordinate into a 3d shape for convenience of # Cython interface. Writing generic, fast, n-dimensional Cython code # is not possible, but it is possible to always support a 3d array with @@ -692,18 +721,6 @@ cdef class _Interpolation(object): #: The shape of the interpolated data. self.result_shape = tuple(result_shape) - if rising is None: - if z_src.shape[zp_axis] < 2: - raise ValueError('The rising keyword must be defined when ' - 'the size of the source array is <2 in ' - 'the interpolation axis.') - z_src_indexer = [0] * z_src.ndim - z_src_indexer[zp_axis] = slice(0, 2) - first_two = z_src[tuple(z_src_indexer)] - rising = first_two[0] <= first_two[1] - - self.rising = bool(rising) - # Sometimes we want to add additional constraints on our interpolation # and extrapolation - for example, linear extrapolation requires there # to be two coordinates to interpolate from. diff --git a/src/stratify/tests/test_vinterp.py b/src/stratify/tests/test_vinterp.py index 002bbc4..df1e411 100644 --- a/src/stratify/tests/test_vinterp.py +++ b/src/stratify/tests/test_vinterp.py @@ -114,7 +114,7 @@ def test_no_levels(self): def test_wrong_rising_target(self): r = self.interpolate([2, 1], [1, 2]) - assert_array_equal(r, [1, np.inf]) + assert_array_equal(r, [0.0, 1.0]) def test_wrong_rising_source(self): r = self.interpolate([1, 2], [2, 1], rising=True) @@ -124,11 +124,11 @@ def test_wrong_rising_source_and_target(self): # If we overshoot the first level, there is no hope, # so we end up extrapolating. r = self.interpolate([3, 2, 1, 0], [2, 1], rising=True) - assert_array_equal(r, [np.inf, np.inf, np.inf, np.inf]) + assert_array_equal(r, [-np.inf, -np.inf, 0.0, np.inf]) def test_non_monotonic_coordinate_interp(self): result = self.interpolate([15, 5, 15.0], [10.0, 20, 0, 20]) - assert_array_equal(result, [1, 2, 3]) + assert_array_equal(result, [1.0, 1.0, 2.0]) def test_non_monotonic_coordinate_extrap(self): result = self.interpolate([0, 15, 16, 17, 5, 15.0, 25], [10.0, 40, 0, 20])