Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: stats._xp_mean, an array API compatible mean with weights and nan_policy #20743

Open
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

mdhaber
Copy link
Contributor

@mdhaber mdhaber commented May 18, 2024

Reference issue

Toward gh-20544

What does this implement/fix?

This function adds _xp_mean, an array-API compatible function which combines the features of np.mean, np.average, and np.nanmean in interface that fits with scipy.stats. This will be needed for making functions like pmean, hmean, and gmean array-API compatible.

Additional information

Potential reviewers: would you be willing to write some unit tests with hypothesis? For such a fundamental function, it's particularly important that it works flawlessly!

If it doesn't sound too crazy, I'd suggest that this and similar var and std functions be added publicly to scipy.stats because they provide functionality that does not exist with the array API (e.g. weights, which has been explicitly rejected, and nan_policy, which has not been standardized and may not follow SciPy's convention). Even considering NumPy alone, it would be useful to have a single function that has all the functionality of mean, average, and nanmean in an interface consistent with the rest of scipy.stats.

Not pursuing these things right now. Let's just get this in so we can finish the other mean functions.

@mdhaber mdhaber added scipy.stats enhancement A new feature or improvement array types Items related to array API support and input array validation (see gh-18286) labels May 18, 2024
scipy/_lib/_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/tests/test_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/tests/test_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/tests/test_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/tests/test_array_api.py Outdated Show resolved Hide resolved
scipy/_lib/tests/test_array_api.py Outdated Show resolved Hide resolved
@mdhaber mdhaber marked this pull request as ready for review May 18, 2024 23:28
Comment on lines +126 to +127
(xp_mean_1samp, tuple(), dict(), 1, 1, False, lambda x: (x,)),
(xp_mean_2samp, tuple(), dict(), 2, 1, True, lambda x: (x,)),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most scipy.stats functions use the _axis_nan_policy decorator to implement nan_policy, keepdims, and tuple axis. I've implemented all these features natively for improved performance (e.g. nan_policy='omit' would otherwise loop over each slice), and the function still passes all the tests, which are quite stringent. So if you don't want to write tests with hypothesis, I'm still pretty comfortable with this.


if weights is not None and x.shape != weights.shape:
try:
x, weights = xp.broadcast_arrays(x, weights)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few thoughts about broadcasting:

  • Technically x = [1, 2, 3] is broadcastable with weights = [2], and it can be interpreted as giving all observations a weight of 2.
  • Technically, x = [1] is broadcastable with weights = [1, 2, 3]: now we have x being broadcast to the shape of weights rather than the (more natural) other way around.
  • Technically x = [] is broadcastable with weights = [1]: weights gets broadcasted to shape (0,), and the weighted mean is NaN.

It's clearly simpler to just accept these sorts of things, but since they're not useful, one could argue that we shouldn't. I'd propose that we just accept them, but if there are strong opinions about not accepting them, LMK.

@@ -475,3 +476,155 @@ def xp_sign(x, xp=None):
sign = xp.where(x < 0, -one, sign)
sign = xp.where(x == 0, 0*one, sign)
return sign


def xp_add_reduced_axes(res, axis, initial_shape, *, xp=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a note on why this is needed? Is it temporary, why can't xp.add not be used, etc.?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type annotations and consistency with other functions in this file would be useful too (at least if you expect this function to stay around for a while).

res should preferably be positional-only.

Copy link
Contributor Author

@mdhaber mdhaber May 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps a better name would have been xp_replace_reduced_axes or xp_keepdims: it adds back axes that have been reduced away. However, when there are other comments to respond to, I'll just move the logic back into xp_mean, since I'm not sure if it will be used elsewhere. It can be factored out again as needed. Although the comment wasn't about xp_mean, I can make the first argument of xp_mean positional-only.

@mdhaber mdhaber mentioned this pull request May 19, 2024
74 tasks
@fancidev
Copy link
Contributor

Why was weights explicitly rejected for the Array API? Would you by chance have a link or something for the discussion back then?

@lucascolley
Copy link
Member

Why was weights explicitly rejected for the Array API? Would you by chance have a link or something for the discussion back then?

data-apis/array-api#366

@fancidev
Copy link
Contributor

Thanks for the link @lucascolley .

To align with the naming convention of hmean, pmean, and gmean, would it be more appropriate to call the function amean (a for arithmetic)?

@lucascolley lucascolley changed the title ENH: xp_mean: an array-API compatible mean with weights and nan_policy ENH: xp_mean, an array API compatible mean with weights and nan_policy Jun 2, 2024
@mdhaber mdhaber changed the title ENH: xp_mean, an array API compatible mean with weights and nan_policy ENH: stats._xp_mean, an array API compatible mean with weights and nan_policy Jun 9, 2024
@@ -406,7 +422,7 @@ def unpacker(res):
res = hypotest(*data1d, *args, nan_policy=nan_policy, **kwds)
res_1db = unpacker(res)

assert_equal(res_1db, res_1da)
assert_allclose(res_1db, res_1da, 1e-15)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: can we pass the tol as a kwarg

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. If there are other things to change, I can commit that then.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert_allclose(res_1db, res_1da, 1e-15)
assert_allclose(res_1db, res_1da, rtol=1e-15)

Copy link
Member

@lucascolley lucascolley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all seems pretty reasonable!

# Check for warning if omitting NaNs causes empty slice
message = 'After omitting NaNs...'
with pytest.warns(RuntimeWarning, match=message):
res = _xp_mean(x * np.nan, nan_policy='omit')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
res = _xp_mean(x * np.nan, nan_policy='omit')
res = _xp_mean(x * np.nan, nan_policy='omit')

Comment on lines 9163 to 9166
# it's really a `SmallSampleWarning`, but not sure
# where it will be imported from yet
message = 'One or more sample arguments is too small...'
with pytest.warns(SmallSampleWarning, match=message):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for my understanding, can you explain this comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment will be removed. I had this all in _lib so I couldn't import SmallSampleWarning.

@@ -707,7 +707,7 @@ def _nan_allsame(a, axis, keepdims=False):
return ((a0 == a) | np.isnan(a)).all(axis=axis, keepdims=keepdims)


def _contains_nan(a, nan_policy='propagate', policies=None, *, xp=None):
def _contains_nan(a, nan_policy='propagate', policies=None, *, xp_ok=False, xp=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you briefly explain the intended semantics of xp_ok?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Temporarily, while _axis_nan_policy does not handle non-NumPy arrays, other functions that call _contains_nan want it to raise an error if nan_policy='omit' and xp is not np.

_xp_mean supports nan_policy='omit' natively, so setting this to True prevents the error from being raised.

In this name, xp was intended to imply xp other than NumPy. Other possibilities include xp_omit_ok and non_numpy_omit_ok. Or we could take another perspective for naming the variable... maybe consider the name to indicate whether the calling function implements nan_policy='omit' itself or whether this function should raise when xp is not NumPy and nan_policy='omit'. Feel free to suggest a preferred name.

Since this function is private , the need for the argument is temporary, and the argument will probably only ever be needed by a handful of functions, I'll probably just change it. Another possibility is to eliminate the argument and just try/except the error as needed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I should just try/except

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the argument sounds fine. xp_omit_okay is probably the clearest name IMO. In any case, it would be nice to add a short docstring to explain. But feel free to just explain to reviewers each time if you would rather not add a docstring.

arrays will be broadcasted before performing the calculation. See
Notes for details.
keepdims : boolean, optional
If this is set to True, the axes which are reduced are left
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
If this is set to True, the axes which are reduced are left
If this is set to ``True``, the axes which are reduced are left

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, that would be my preference. In the past, other reviewers have criticized double backticks for short literals, so I have become inconsistent as I go back and forth between that advice and my natural tendency (code should render in monospaced font, and True is code).

Note that I was able to get numpy/numpydoc#525 merged last week, and I opened pydata/pydata-sphinx-theme#1852, but I did not handle this aspect of the issue. I'll go ahead and open another issue along these lines in the numpydoc repo.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, I'm certainly +1 for the record


# convert integers to the default float of the array library
if not xp.isdtype(x.dtype, 'real floating'):
dtype = xp.asarray(1.).dtype
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do think we should add an xp_default_float helper at some point. Now may be a good time for it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I'm adding something like that in an upcoming PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lucascolley Please see xp_broadcast_promote in gh-20935.

else too_small_nd_not_omit)
if xp_size(x) == 0:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which warning is this catching?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whatever xp.mean with an empty argument wants to emit. It is not consistent among array libraries. This makes it consistent.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool. Let's leave this comment unresolved for anyone looking this PR up from the allowed filter list.

message = (too_small_1d_omit if (x.ndim == 1 or axis is None)
else too_small_nd_omit)
if contains_nan and nan_policy == 'omit':
i = xp.isnan(x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: i is not a very readable var name

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose I learned my ABCs before learning to read other words, so I find it readable : )
But feel free to suggest a preferred name if the meaning of the variable will be difficult to interpret in the context of this if block. i_nan? nan_mask?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I definitely prefer nan_mask!

Copy link
Contributor Author

@mdhaber mdhaber Jun 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed just for you as a thank you for reviewing : )

Copy link
Member

@j-bowhay j-bowhay left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few minor comments but otherwise looks good

Parameters
----------
x : real floating array
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps misleading given the casting of integers?

Suggested change
x : real floating array
x : real array

.. math::
\bar{x}_w = \frac{ \sum_{i=0}^{n-1} w_i x_i }
{ \sum_{i=0}^{n-1} i w_i }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo?

Suggested change
{ \sum_{i=0}^{n-1} i w_i }
{ \sum_{i=0}^{n-1} w_i }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup

warnings.warn(message, SmallSampleWarning, stacklevel=2)
return res

# avoid circular import
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

left over comment from when this lived in _util?

Suggested change
# avoid circular import

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

else too_small_nd_omit)
if contains_nan and nan_policy == 'omit':
i = xp.isnan(x)
i = (i | xp.isnan(weights)) if weights is not None else i
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

optional

Suggested change
i = (i | xp.isnan(weights)) if weights is not None else i
if weights is not None:
i |= xp.isnan(weights)

@mdhaber
Copy link
Contributor Author

mdhaber commented Jun 11, 2024

Responses to comments committed. Thanks @lucascolley @j-bowhay!

When this is in, would either of you like to tackle conversion of one or more of the other mean functions? I'd be happy to review.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
array types Items related to array API support and input array validation (see gh-18286) enhancement A new feature or improvement scipy._lib scipy.stats
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants