Skip to content

Commit

Permalink
BUG: complex special case for clip
Browse files Browse the repository at this point in the history
  • Loading branch information
birm committed Feb 25, 2020
1 parent a7adfda commit caea9ec
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 13 deletions.
3 changes: 3 additions & 0 deletions numpy/core/fromnumeric.py
Expand Up @@ -2039,6 +2039,9 @@ def clip(a, a_min, a_max, out=None, **kwargs):
is specified, values smaller than 0 become 0, and values larger
than 1 become 1.
If a complex array is given, the real and imaginary parts are clipped
independently.
Equivalent to but faster than ``np.minimum(a_max, np.maximum(a, a_min))``.
No check is performed to ensure ``a_min < a_max``.
Expand Down
71 changes: 61 additions & 10 deletions numpy/core/src/umath/clip.c.src
Expand Up @@ -31,19 +31,13 @@
#define _NPY_HALF_MAX(a, b) (npy_half_isnan(a) || npy_half_ge(a, b) ? (a) : (b))

/**begin repeat
* #name = FLOAT, DOUBLE, LONGDOUBLE#
* #name = FLOAT, DOUBLE, LONGDOUBLE,
* CFLOAT, CDOUBLE, CLONGDOUBLE#
*/
#define _NPY_@name@_MIN(a, b) (npy_isnan(a) ? (a) : PyArray_MIN(a, b))
#define _NPY_@name@_MAX(a, b) (npy_isnan(a) ? (a) : PyArray_MAX(a, b))
/**end repeat**/

/**begin repeat
* #name = CFLOAT, CDOUBLE, CLONGDOUBLE#
*/
#define _NPY_@name@_MIN(a, b) (npy_isnan((a).real) || npy_isnan((a).imag) || PyArray_CLT(a, b) ? (a) : (b))
#define _NPY_@name@_MAX(a, b) (npy_isnan((a).real) || npy_isnan((a).imag) || PyArray_CGT(a, b) ? (a) : (b))
/**end repeat**/

/**begin repeat
* #name = DATETIME, TIMEDELTA#
*/
Expand All @@ -65,13 +59,11 @@
* BYTE, UBYTE, SHORT, USHORT, INT, UINT,
* LONG, ULONG, LONGLONG, ULONGLONG,
* HALF, FLOAT, DOUBLE, LONGDOUBLE,
* CFLOAT, CDOUBLE, CLONGDOUBLE,
* DATETIME, TIMEDELTA#
* #type = npy_bool,
* npy_byte, npy_ubyte, npy_short, npy_ushort, npy_int, npy_uint,
* npy_long, npy_ulong, npy_longlong, npy_ulonglong,
* npy_half, npy_float, npy_double, npy_longdouble,
* npy_cfloat, npy_cdouble, npy_clongdouble,
* npy_datetime, npy_timedelta#
*/

Expand Down Expand Up @@ -117,3 +109,62 @@ NPY_NO_EXPORT void
#undef _NPY_@name@_MIN

/**end repeat**/

/**begin repeat
*
* #name = CFLOAT, CDOUBLE, CLONGDOUBLE#
* #type =
* npy_cfloat, npy_cdouble, npy_clongdouble#
*/

#define _NPY_CLIP(x, min, max) \
_NPY_@name@_MIN(_NPY_@name@_MAX((x), (min)), (max))

NPY_NO_EXPORT void
@name@_clip(char **args, npy_intp const *dimensions, npy_intp const *steps, void *NPY_UNUSED(func))
{
if (steps[1] == 0 && steps[2] == 0) {
/* min and max are constant throughout the loop, the most common case */
/* NOTE: it may be possible to optimize these checks for nan */
@type@ min_val = *(@type@ *)args[1];
@type@ max_val = *(@type@ *)args[2];

char *ip1 = args[0], *op1 = args[3];
npy_intp is1 = steps[0], os1 = steps[3];
npy_intp n = dimensions[0];

/* contiguous, branch to let the compiler optimize */
if (is1 == sizeof(@type@) && os1 == sizeof(@type@)) {
for(npy_intp i = 0; i < n; i++, ip1 += is1, op1 += os1) {
@type@ c;
c.real = _NPY_CLIP((*(@type@ *)ip1).real, min_val.real, max_val.real);
c.imag = _NPY_CLIP((*(@type@ *)ip1).imag, min_val.imag, max_val.imag);
*(@type@ *)op1 = c;
}
}
else {
for(npy_intp i = 0; i < n; i++, ip1 += is1, op1 += os1) {
@type@ c;
c.real = _NPY_CLIP((*(@type@ *)ip1).real, min_val.real, max_val.real);
c.imag = _NPY_CLIP((*(@type@ *)ip1).imag, min_val.imag, max_val.imag);
*(@type@ *)op1 = c;
}
}
}
else {
TERNARY_LOOP {
@type@ c;
c.real = _NPY_CLIP((*(@type@ *)ip1).real, (*(@type@ *)ip2).real, (*(@type@ *)ip3).real);
c.imag = _NPY_CLIP((*(@type@ *)ip1).imag, (*(@type@ *)ip2).imag, (*(@type@ *)ip3).imag);
*(@type@ *)op1 = c;
}
}
npy_clear_floatstatus_barrier((char*)dimensions);
}

// clean up the macros we defined above
#undef _NPY_CLIP
#undef _NPY_@name@_MAX
#undef _NPY_@name@_MIN

/**end repeat**/
6 changes: 6 additions & 0 deletions numpy/core/tests/test_multiarray.py
Expand Up @@ -4426,6 +4426,12 @@ def test_nan(self):
expected = np.array([-1., np.nan, 0.5, 1., 0.25, np.nan])
assert_array_equal(result, expected)

def test_complex(self):
val = np.array([0+7j, 1+6j, 2+5j, 3+4j])
result = val.clip(1+5j, 2+6j)
expected = np.array([1+6j, 1+6j, 2+5j, 2+5j])
assert_array_equal(result, expected)


class TestCompress:
def test_axis(self):
Expand Down
7 changes: 4 additions & 3 deletions numpy/core/tests/test_numeric.py
Expand Up @@ -1566,7 +1566,8 @@ def test_simple_complex(self):
a = 3 * self._generate_data_complex(self.nr, self.nc)
m = -0.5
M = 1.
ac = self.fastclip(a, m, M)
ac = np.copy(a)
ac.real = self.fastclip(ac.real, m, M)
act = self.clip(a, m, M)
assert_array_strict_equal(ac, act)

Expand Down Expand Up @@ -2060,8 +2061,8 @@ def test_clip_property(self, data, shape):
base_shape=shape,
# Commenting out the min_dims line allows zero-dimensional arrays,
# and zero-dimensional arrays containing NaN make the test fail.
min_dims=1
min_dims=1

)
)
amin = data.draw(
Expand Down

0 comments on commit caea9ec

Please sign in to comment.