Add NumPy's new take_along_axis #3663
Comments
I'm going to work on it as soon as #3407 is merged. It can be implemented as a variant of #3407, although I don't think it will be possible to piggy-back on the exact same code because in take_along_axis you can have stacked elements from different chunks of x, e.g.
So I'll be forced to do something slower and less RAM-friendly: layering masked selections, and then stacking them on top of each other through recursive aggregation based on
|
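To make the idea of layering masked selections concrete, here is a minimal sketch in plain NumPy, with a Python loop standing in for Dask chunks. It is not the code from #3407 and not a proposed implementation; `take_along_axis_chunked` and `chunk_size` are hypothetical names, and the sketch assumes non-negative indices.

```python
import numpy as np


def take_along_axis_chunked(x, indices, chunk_size, axis=-1):
    """Hypothetical sketch: emulate np.take_along_axis on an array that is
    only visited one chunk at a time along `axis`.

    Assumes all indices are non-negative.
    """
    axis = axis % x.ndim
    out = np.zeros(indices.shape, dtype=x.dtype)
    for start in range(0, x.shape[axis], chunk_size):
        stop = min(start + chunk_size, x.shape[axis])
        sl = [slice(None)] * x.ndim
        sl[axis] = slice(start, stop)
        chunk = x[tuple(sl)]
        # Mask of the requested indices that fall inside this chunk.
        hit = (indices >= start) & (indices < stop)
        # Shift hits to chunk-local coordinates; misses point at 0 and are
        # discarded by the mask below.
        local = np.where(hit, indices - start, 0)
        partial = np.take_along_axis(chunk, local, axis=axis)
        # Layer this chunk's masked selection on top of the running result.
        out = np.where(hit, partial, out)
    return out


rng = np.random.default_rng(0)
x = rng.integers(0, 100, size=(4, 10))
idx = rng.integers(0, 10, size=(4, 6))
result = take_along_axis_chunked(x, idx, chunk_size=3, axis=-1)
expected = np.take_along_axis(x, idx, axis=-1)
```

Every chunk produces a partial result with the full shape of `indices`, which is one way to see why this kind of approach costs more memory than a plain reduction.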
Also depends on #3610 as it is going to use the same trick of passing tuples of arrays across the chunk/combine/reduce functions of |
Just a heads up that I won't be able to work on this in the near future, so if anybody wants to pick it up, they're most welcome to do so. |
This might be a good one for someone to pick up at the scipy sprint that's on in a few weeks. |
Thanks all. I've marked it as a good second issue. |
I'd be interested in taking a crack at this if no one is currently working on it. |
Crack away :) |
I want to contribute to this issue. Shall I implement this in file |
Thanks. That |
I have started working on this. |
I think a function similar to NumPy's |
That would be great! Thank you for working on this 😀 |
I have defined a function for dask in |
You could create a new branch and use |
Thanks @jakirkham , I will look into it. |
FWIW Dask has a squash merge policy. So even if there are errant commits, wouldn't be too concerned about them. It's more important that the last commit reflects the code you would like to share. Mentioning this so you don't get lost in the world of git merge conflicts 😉 |
Just a comment: Unfortunately, the effort by Saanidhyavats turned out to be not parallelizable enough to be dask friendly, so this issue is wide open. |
edit: simplify implementation
Context
I'm interested in having a
import dask.array as da
data = da.arange(150, chunks=5).reshape(10, 15)
indices = data.argtopk(10, axis=-1)
da.take_along_axis(data, indices, axis=-1)  # not working
And I don't see how to map_blocks numpy's take_along_axis.
Implementation
Building on the ideas from @crusaderky's comments above, I think the goal here is to turn
Drawbacks
As @crusaderky already suggested:
Example of steps, using numpy
# [input] data
In [196]: arr
Out[196]:
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]]])
# [input] indices of interest (could be the result of `argtopk(k =3, axis = -1)`)
In [206]: indices
Out[206]:
array([[[1, 2, 3],
[1, 2, 3]],
[[1, 2, 3],
[1, 2, 3]]])
# [intermediary output] Mask of indices with arr shape
In [217]: indices_mask
Out[217]:
array([[[False, True, True, True],
[False, True, True, True]],
[[False, True, True, True],
[False, True, True, True]]])
# [output result] Equivalent to `np.take_along_axis(arr, indices, axis=-1)`
In [218]: np.reshape(arr[indices_mask], indices.shape)
Out[218]:
array([[[ 1, 2, 3],
[ 5, 6, 7]],
[[ 9, 10, 11],
[13, 14, 15]]])
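The steps above can be reproduced with a short NumPy sketch. It uses `np.put_along_axis` to build the mask, which is an assumption on my part, since the thread does not show how `indices_mask` was computed, and `take_along_axis_via_mask` is a hypothetical helper name. Note that boolean selection returns elements in array order, so this only matches `np.take_along_axis` when the indices along the axis are unique and sorted in ascending order.

```python
import numpy as np


def take_along_axis_via_mask(arr, indices, axis=-1):
    """Hypothetical helper: emulate np.take_along_axis with a boolean mask.

    Only equivalent when indices along `axis` are unique and ascending.
    """
    # Mark every requested position with True.
    mask = np.zeros(arr.shape, dtype=bool)
    np.put_along_axis(mask, indices, True, axis=axis)
    # Move `axis` last so the flattened selection keeps each 1-D lane together.
    selected = np.moveaxis(arr, axis, -1)[np.moveaxis(mask, axis, -1)]
    out_shape = np.moveaxis(indices, axis, -1).shape
    return np.moveaxis(selected.reshape(out_shape), -1, axis)


arr = np.arange(16).reshape(2, 2, 4)
indices = np.tile(np.array([1, 2, 3]), (2, 2, 1))
masked_result = take_along_axis_via_mask(arr, indices, axis=-1)
reference = np.take_along_axis(arr, indices, axis=-1)

# Counter-example: unsorted indices lose their order under a mask.
row = np.array([[10, 20, 30, 40]])
unsorted = np.array([[3, 1]])
via_mask = take_along_axis_via_mask(row, unsorted)
via_take = np.take_along_axis(row, unsorted, axis=-1)
```

For the sorted case `masked_result` and `reference` agree; for the unsorted case `via_mask` comes back in array order while `via_take` preserves the requested order, so results such as those from `argtopk` would need extra handling.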
Implementation details
At chunk level, a function
Once every chunk has returned its corresponding mask, a
Finally, a
Final thoughts
I'm sure there are several things we can optimize here. |
In case it is useful: I had a need for a dask version of take_along_axis and implemented it in our climate index program climix. The source code is available here. I had the intention of upstreaming it, but never got around to doing it. |
Awesome Klaus, many thanks! I just replaced the sparse arrays with numpy arrays and it works as expected, probably not as memory efficient, but I don't think adding a dependency to sparse would be ok here. As a side note, I see that climix has already implemented the idea I had for distributed_percentile via argtopk, I will probably take some inspiration for that as well! |
I understand that using sparse adds a dependency, but the performance implication of using numpy arrays instead may be prohibitive. But let's discuss this in the PR. |
It would be nice to have a Dask Array implementation of the new NumPy function take_along_axis.