ENH: `stats._xp_mean`, an array API compatible `mean` with `weights` and `nan_policy` #20743
base: main
(xp_mean_1samp, tuple(), dict(), 1, 1, False, lambda x: (x,)),
(xp_mean_2samp, tuple(), dict(), 2, 1, True, lambda x: (x,)),
Most `scipy.stats` functions use the `_axis_nan_policy` decorator to implement `nan_policy`, `keepdims`, and tuple `axis`. I've implemented all these features natively for improved performance (e.g. `nan_policy='omit'` would otherwise loop over each slice), and the function still passes all the tests, which are quite stringent. So if you don't want to write tests with `hypothesis`, I'm still pretty comfortable with this.
scipy/_lib/_array_api.py
Outdated
if weights is not None and x.shape != weights.shape:
    try:
        x, weights = xp.broadcast_arrays(x, weights)
A few thoughts about broadcasting:
- Technically, `x = [1, 2, 3]` is broadcastable with `weights = [2]`, and it can be interpreted as giving all observations a weight of `2`.
- Technically, `x = [1]` is broadcastable with `weights = [1, 2, 3]`: now we have `x` being broadcast to the shape of `weights` rather than the (more natural) other way around.
- Technically, `x = []` is broadcastable with `weights = [1]`: `weights` gets broadcast to shape `(0,)`, and the weighted mean is NaN.
It's clearly simpler to just accept these sorts of things, but since they're not useful, one could argue that we shouldn't. I'd propose that we just accept them, but if there are strong opinions about not accepting them, LMK.
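The three broadcasting cases above can be reproduced with plain NumPy. This is only a sketch: `weighted_mean` is a hypothetical helper written for illustration, not the `_xp_mean` implementation from this PR, and it assumes NumPy as the array library.

```python
import numpy as np


def weighted_mean(x, weights):
    # Hypothetical helper (not SciPy's _xp_mean): broadcast, then average.
    x, weights = np.broadcast_arrays(np.asarray(x, dtype=float),
                                     np.asarray(weights, dtype=float))
    # Silence the 0/0 floating-point warning for the empty case.
    with np.errstate(invalid="ignore", divide="ignore"):
        return np.sum(x * weights) / np.sum(weights)


# weights of shape (1,) stretch to x: every observation gets weight 2
case_scalar_weight = weighted_mean([1, 2, 3], [2])
# x of shape (1,) stretches to the shape of weights
case_x_broadcast = weighted_mean([1], [1, 2, 3])
# weights of shape (1,) shrink to shape (0,); the result is 0/0
case_empty = weighted_mean([], [1])
```

Accepting all three keeps the code path simple, since a single broadcast call covers them; rejecting any of them would need extra shape checks before broadcasting.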
scipy/_lib/_array_api.py
Outdated
@@ -475,3 +476,155 @@ def xp_sign(x, xp=None):
     sign = xp.where(x < 0, -one, sign)
     sign = xp.where(x == 0, 0*one, sign)
     return sign


def xp_add_reduced_axes(res, axis, initial_shape, *, xp=None):
Could you add a note on why this is needed? Is it temporary, why can't `xp.add` be used, etc.?
Type annotations and consistency with other functions in this file would be useful too (at least if you expect this function to stay around for a while). `res` should preferably be positional-only.
Perhaps a better name would have been `xp_replace_reduced_axes` or `xp_keepdims`: it adds back axes that have been reduced away. However, when there are other comments to respond to, I'll just move the logic back into `xp_mean`, since I'm not sure if it will be used elsewhere. It can be factored out again as needed. Although the comment wasn't about `xp_mean`, I can make the first argument of `xp_mean` positional-only.
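To illustrate what "adds back axes that have been reduced away" means, here is a minimal NumPy-only sketch. The name `restore_reduced_axes` and its body are assumptions made for illustration; the helper in the diff takes an `xp` argument and its actual body is not shown in this thread.

```python
import numpy as np


def restore_reduced_axes(res, axis, initial_shape):
    # Hypothetical sketch: give `res` the shape it would have had with
    # keepdims=True by putting length-1 axes where the reduction happened.
    if axis is None:
        final_shape = (1,) * len(initial_shape)
    else:
        axes = (axis,) if isinstance(axis, int) else tuple(axis)
        final_shape = list(initial_shape)
        for ax in axes:
            final_shape[ax] = 1
    return np.reshape(res, tuple(final_shape))


x = np.arange(24.0).reshape(2, 3, 4)
reduced = x.mean(axis=(0, 2))                      # shape (3,)
restored = restore_reduced_axes(reduced, (0, 2), x.shape)
```

Here `restored` has the same shape and values as `x.mean(axis=(0, 2), keepdims=True)`.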
Why was |
|
Thanks for the link @lucascolley . To align with the naming convention of |
@@ -406,7 +422,7 @@ def unpacker(res):
 res = hypotest(*data1d, *args, nan_policy=nan_policy, **kwds)
 res_1db = unpacker(res)

-assert_equal(res_1db, res_1da)
+assert_allclose(res_1db, res_1da, 1e-15)
minor: can we pass the tol as a kwarg?
Sure. If there are other things to change, I can commit that then.
Suggested change:
-assert_allclose(res_1db, res_1da, 1e-15)
+assert_allclose(res_1db, res_1da, rtol=1e-15)
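For readers unsure why the keyword is preferred: the third positional parameter of `numpy.testing.assert_allclose` is `rtol`, so both spellings run the same check, but the keyword form makes it clear that the tolerance is relative. A small sketch with made-up data (`expected`, `actual`, and `passes` are illustrative names, not from the PR):

```python
import numpy as np
from numpy.testing import assert_allclose

expected = np.array([1.0, 2.0, 3.0])
# Perturb each element by one unit in the last place (about 2e-16 relative).
actual = expected + np.spacing(expected)

# Positional and keyword forms perform the same comparison.
assert_allclose(actual, expected, 1e-15)
assert_allclose(actual, expected, rtol=1e-15)


def passes(rtol):
    # Report whether the comparison succeeds at the given relative tolerance.
    try:
        assert_allclose(actual, expected, rtol=rtol)
    except AssertionError:
        return False
    return True
```

With the default `atol=0`, a tolerance far below machine precision (for example `passes(1e-17)`) rejects even this one-ulp difference.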
all seems pretty reasonable!
scipy/stats/tests/test_stats.py
Outdated
# Check for warning if omitting NaNs causes empty slice
message = 'After omitting NaNs...'
with pytest.warns(RuntimeWarning, match=message):
    res = _xp_mean(x * np.nan, nan_policy='omit')
res = _xp_mean(x * np.nan, nan_policy='omit')
res = _xp_mean(x * np.nan, nan_policy='omit')
scipy/stats/tests/test_stats.py
Outdated
# it's really a `SmallSampleWarning`, but not sure
# where it will be imported from yet
message = 'One or more sample arguments is too small...'
with pytest.warns(SmallSampleWarning, match=message):
for my understanding, can you explain this comment?
Comment will be removed. I had this all in `_lib` so I couldn't import `SmallSampleWarning`.
scipy/_lib/_util.py
Outdated
@@ -707,7 +707,7 @@ def _nan_allsame(a, axis, keepdims=False):
     return ((a0 == a) | np.isnan(a)).all(axis=axis, keepdims=keepdims)


-def _contains_nan(a, nan_policy='propagate', policies=None, *, xp=None):
+def _contains_nan(a, nan_policy='propagate', policies=None, *, xp_ok=False, xp=None):
can you briefly explain the intended semantics of `xp_ok`?
Temporarily, while `_axis_nan_policy` does not handle non-NumPy arrays, other functions that call `_contains_nan` want it to raise an error if `nan_policy='omit'` and `xp` is not `np`. `_xp_mean` supports `nan_policy='omit'` natively, so setting this to `True` prevents the error from being raised.
In this name, `xp` was intended to imply `xp` other than NumPy. Other possibilities include `xp_omit_ok` and `non_numpy_omit_ok`. Or we could take another perspective for naming the variable... maybe consider the name to indicate whether the calling function implements `nan_policy='omit'` itself or whether this function should raise when `xp` is not NumPy and `nan_policy='omit'`. Feel free to suggest a preferred name.
Since this function is private, the need for the argument is temporary, and the argument will probably only ever be needed by a handful of functions, I'll probably just change it. Another possibility is to eliminate the argument and just try/except the error as needed.
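The semantics described above can be sketched as follows. Everything here is hypothetical and written only to illustrate the flag: `contains_nan_sketch` is not SciPy's private `_contains_nan`, the error messages are invented, and `other_xp` is a stand-in namespace that merely wraps NumPy functions so the example runs without a second array library.

```python
import types

import numpy as np


def contains_nan_sketch(a, nan_policy="propagate", *, xp_omit_okay=False, xp=np):
    # Hypothetical stand-in for the private helper discussed in this thread.
    if nan_policy not in ("propagate", "raise", "omit"):
        raise ValueError("nan_policy must be 'propagate', 'raise', or 'omit'.")
    # Callers that cannot handle 'omit' for non-NumPy arrays get an error,
    # unless they declare that they implement 'omit' themselves.
    if nan_policy == "omit" and xp is not np and not xp_omit_okay:
        raise ValueError("nan_policy='omit' is not supported for this array type.")
    has_nan = bool(xp.any(xp.isnan(xp.asarray(a))))
    if has_nan and nan_policy == "raise":
        raise ValueError("The input contains nan values.")
    return has_nan


# Stand-in for a non-NumPy namespace (an assumption made for the demo only).
other_xp = types.SimpleNamespace(asarray=np.asarray, isnan=np.isnan, any=np.any)
data = [1.0, float("nan")]


def omit_raises(**kwargs):
    # True if requesting 'omit' with the non-NumPy namespace raises.
    try:
        contains_nan_sketch(data, nan_policy="omit", xp=other_xp, **kwargs)
    except ValueError:
        return True
    return False
```

A caller that handles `'omit'` natively, as `_xp_mean` is said to do, would pass the flag as `True`; every other caller keeps the default and gets the error.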
Maybe I should just `try`/`except`
I think the argument sounds fine. `xp_omit_okay` is probably the clearest name IMO. In any case, it would be nice to add a short docstring to explain. But feel free to just explain to reviewers each time if you would rather not add a docstring.
scipy/stats/_stats_py.py
Outdated
    arrays will be broadcasted before performing the calculation. See
    Notes for details.
keepdims : boolean, optional
    If this is set to True, the axes which are reduced are left
Suggested change:
-If this is set to True, the axes which are reduced are left
+If this is set to ``True``, the axes which are reduced are left
Well, that would be my preference. In the past, other reviewers have criticized double backticks for short literals, so I have become inconsistent as I go back and forth between that advice and my natural tendency (code should render in monospaced font, and `True` is code).
Note that I was able to get numpy/numpydoc#525 merged last week, and I opened pydata/pydata-sphinx-theme#1852, but I did not handle this aspect of the issue. I'll go ahead and open another issue along these lines in the `numpydoc` repo.
sure, I'm certainly +1 for the record
# convert integers to the default float of the array library
if not xp.isdtype(x.dtype, 'real floating'):
    dtype = xp.asarray(1.).dtype
I do think we should add an `xp_default_float` helper at some point. Now may be a good time for it.
Yeah I'm adding something like that in an upcoming PR.
@lucascolley Please see `xp_broadcast_promote` in gh-20935.
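The integer-to-float conversion discussed in this thread relies on a small trick: the dtype of `asarray(1.)` is whatever the library treats as its default float. A NumPy-only sketch follows; `to_default_float` is a hypothetical name, and `np.issubdtype` is swapped in for the `xp.isdtype` call from the diff so that the example also runs on NumPy versions without `isdtype`.

```python
import numpy as np


def to_default_float(x):
    # Hypothetical helper: cast non-floating input to NumPy's default float.
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        default_dtype = np.asarray(1.).dtype  # same trick as in the diff above
        x = x.astype(default_dtype)
    return x


from_ints = to_default_float([1, 2, 3])
from_singles = to_default_float(np.asarray([1, 2, 3], dtype=np.float32))
```

Integer input is promoted to the default float, while input that is already floating point (here `float32`) keeps its precision.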
           else too_small_nd_not_omit)
if xp_size(x) == 0:
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
Which warning is this catching?
Whatever `xp.mean` with an empty argument wants to emit. It is not consistent among array libraries. This makes it consistent.
Cool. Let's leave this comment unresolved for anyone looking this PR up from the allowed filter list.
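As a concrete instance of the pattern in this thread, NumPy emits `RuntimeWarning`s when taking the mean of an empty array; wrapping the call in `warnings.catch_warnings()` with an `"ignore"` filter hides them, so the caller can issue its own uniform warning instead. The sketch below uses NumPy only, and `quiet_mean` is a hypothetical name.

```python
import warnings

import numpy as np


def quiet_mean(x):
    # Suppress whatever the library emits for an empty input; the caller
    # can then raise one consistent warning of its own.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return np.mean(x)


# Record warnings around the call to confirm that none escape.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = quiet_mean(np.asarray([], dtype=float))
```

The result is still NaN; only the library-specific warning is dropped.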
scipy/stats/_stats_py.py
Outdated
message = (too_small_1d_omit if (x.ndim == 1 or axis is None)
           else too_small_nd_omit)
if contains_nan and nan_policy == 'omit':
    i = xp.isnan(x)
minor: `i` is not a very readable var name
I suppose I learned my ABCs before learning to read other words, so I find it readable : )
But feel free to suggest a preferred name if the meaning of the variable will be difficult to interpret in the context of this `if` block. `i_nan`? `nan_mask`?
I definitely prefer `nan_mask`!
Changed just for you as a thank you for reviewing : )
A few minor comments but otherwise looks good
scipy/stats/_stats_py.py
Outdated
Parameters
----------
x : real floating array
Perhaps misleading given the casting of integers?
Suggested change:
-x : real floating array
+x : real array
scipy/stats/_stats_py.py
Outdated
.. math::

    \bar{x}_w = \frac{ \sum_{i=0}^{n-1} w_i x_i }
                     { \sum_{i=0}^{n-1} i w_i }
typo?
Suggested change:
-{ \sum_{i=0}^{n-1} i w_i }
+{ \sum_{i=0}^{n-1} w_i }
Yup
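A quick numeric check of the corrected formula, using made-up data: the weighted mean is the sum of `w_i * x_i` divided by the sum of `w_i`, which is what `np.average` computes. The variable names are illustrative only.

```python
import numpy as np

x = np.array([1.0, 2.0, 4.0])
w = np.array([3.0, 1.0, 1.0])

# Corrected denominator: the plain sum of the weights.
weighted_mean = np.sum(w * x) / np.sum(w)

# Denominator with the stray index i, as in the typo.
typo_version = np.sum(w * x) / np.sum(np.arange(x.size) * w)
```

Only the corrected form agrees with `np.average(x, weights=w)`.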
scipy/stats/_stats_py.py
Outdated
    warnings.warn(message, SmallSampleWarning, stacklevel=2)
return res


# avoid circular import
left over comment from when this lived in `_util`?
Suggested change:
-# avoid circular import
Yes.
scipy/stats/_stats_py.py
Outdated
           else too_small_nd_omit)
if contains_nan and nan_policy == 'omit':
    i = xp.isnan(x)
    i = (i | xp.isnan(weights)) if weights is not None else i
optional
Suggested change:
-i = (i | xp.isnan(weights)) if weights is not None else i
+if weights is not None:
+    i |= xp.isnan(weights)
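One way a NaN mask like this can support `nan_policy='omit'` without looping over slices is to give masked entries zero weight. The sketch below is an assumption about the general technique, written with NumPy only; `nan_omit_weighted_mean` is a hypothetical name and is not the code from this PR.

```python
import numpy as np


def nan_omit_weighted_mean(x, weights=None, axis=None):
    # Hypothetical sketch: omit NaNs by zeroing them out in both arrays.
    x = np.asarray(x, dtype=float)
    weights = (np.ones_like(x) if weights is None
               else np.asarray(weights, dtype=float))
    x, weights = np.broadcast_arrays(x, weights)
    nan_mask = np.isnan(x) | np.isnan(weights)
    x = np.where(nan_mask, 0.0, x)              # contributes nothing to the sum
    weights = np.where(nan_mask, 0.0, weights)  # and nothing to the total weight
    # A slice with no valid entries gives 0/0; silence that warning here.
    with np.errstate(invalid="ignore", divide="ignore"):
        return np.sum(x * weights, axis=axis) / np.sum(weights, axis=axis)


data = np.array([[1.0, np.nan, 3.0],
                 [4.0, 5.0, 6.0]])
row_means = nan_omit_weighted_mean(data, axis=1)
```

For unweighted input this matches `np.nanmean(data, axis=1)` while using a single vectorized pass.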
Responses to comments committed. Thanks @lucascolley @j-bowhay! When this is in, would either of you like to tackle conversion of one or more of the other mean functions? I'd be happy to review.
Reference issue
Toward gh-20544
What does this implement/fix?
This PR adds `_xp_mean`, an array-API compatible function which combines the features of `np.mean`, `np.average`, and `np.nanmean` in an interface that fits with `scipy.stats`. This will be needed for making functions like `pmean`, `hmean`, and `gmean` array-API compatible.
Additional information
Potential reviewers: would you be willing to write some unit tests with `hypothesis`? For such a fundamental function, it's particularly important that it works flawlessly!
If it doesn't sound too crazy, I'd suggest that this and similar `var` and `std` functions be added publicly to `scipy.stats` because they provide functionality that does not exist with the array API (e.g. `weights`, which has been explicitly rejected, and `nan_policy`, which has not been standardized and may not follow SciPy's convention). Even considering NumPy alone, it would be useful to have a single function that has all the functionality of `mean`, `average`, and `nanmean` in an interface consistent with the rest of `scipy.stats`.
Not pursuing these things right now. Let's just get this in so we can finish the other mean functions.