Enh/take along axis #11076

Open
wants to merge 3 commits into
base: main

@bzah bzah commented Apr 26, 2024

This PR adds take_along_axis that works similarly to numpy's take_along_axis.

Credit

@zklaus for providing a working solution in climix.

Example of use

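import dask.array as da
from dask.array import take_along_axis  # added by this PR

# Placeholder input; any dask array works here.
data = da.random.random((1000, 1000), chunks=(250, 250))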
top10_indices = data.argtopk(k=10, axis=-1)
top10 = take_along_axis(data, top10_indices, axis=-1)

# Equivalent to 
top10 = data.topk(k=10, axis=-1)
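
For reference, the NumPy behaviour this PR aims to match can be sketched as follows. This is a NumPy-only illustration, not code from the PR; the array shape, `k`, and the variable names are assumptions chosen for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.random((4, 6))  # small illustrative array

k = 3
# Indices of the k largest values along the last axis, largest first.
topk_indices = np.argsort(data, axis=-1)[..., ::-1][..., :k]

# take_along_axis picks, for every row, the values at those indices.
topk_values = np.take_along_axis(data, topk_indices, axis=-1)

# Same result as sorting each row in descending order and keeping k values.
expected = np.sort(data, axis=-1)[..., ::-1][..., :k]
assert np.array_equal(topk_values, expected)
```

The result has the shape of the index array, here `(4, 3)`.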

Performance

This is just a basic benchmark I ran on my laptop to compare how this implementation performs against NumPy's.

import numpy as np
from dask.array import from_array
from dask.array.slicing import take_along_axis


random_arr = np.random.rand(1000,1000,1000)
dask_random_arr = from_array(random_arr)
top50 = dask_random_arr.argtopk(k=50, axis=0)
top50_np = top50.compute()

# Compute both argtopk and take_along_axis
%timeit take_along_axis(dask_random_arr, top50, axis=0).compute()
# 6.74 s ± 9.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit np.take_along_axis(random_arr, top50_np, axis=0)
# 864 ms ± 11.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit take_along_axis(dask_random_arr, from_array(top50_np), axis=0).compute()
# 2.3 s ± 9.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
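
`%timeit` is an IPython magic, so the numbers above need an IPython session. Outside IPython, the NumPy baseline can be timed with the standard `timeit` module, roughly as sketched below. The array here is deliberately much smaller than the one above, and `np.argpartition` stands in for `argtopk` so that the sketch needs only NumPy; both choices are assumptions made for illustration.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
small_arr = rng.random((50, 50, 50))  # far smaller than the 1000**3 array above

k = 5
# Indices of the k largest values along axis 0 (order within the top k is unspecified).
topk_idx = np.argpartition(small_arr, -k, axis=0)[-k:]


def numpy_baseline():
    return np.take_along_axis(small_arr, topk_idx, axis=0)


# repeat=7, number=1 mirrors "7 runs, 1 loop each" in the %timeit output.
timings = timeit.repeat(numpy_baseline, repeat=7, number=1)
print(f"{np.mean(timings) * 1e3:.3f} ms ± {np.std(timings) * 1e3:.3f} ms per loop")
```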

@GPUtester
Collaborator

Can one of the admins verify this patch?

Admins can comment ok to test to allow this one PR to run or add to allowlist to allow all future PRs from the same author to run.

@bzah
Contributor Author

bzah commented Apr 26, 2024

edit: added benchmark for sparse implementation.
@zklaus, there are a few deviations from your original code in climix:

  • I removed the sparse arrays, since sparse is not a dependency of dask.
    Performance-wise, on my laptop, it looks better without sparse. Benchmark of the sparse-based version:
from dask_take_along_axis import dask_take_along_axis # copy-pasted module from climix

%timeit dask_take_along_axis(dask_random_arr, from_array(top50_np), axis=0).compute()
# 6.33 s ± 60.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
  • I removed the shape check because it was breaking a use case. In particular, with this assert in place, the following test fails: dask/array/tests/test_slicing.py::test_take_along_axis__indexing_twice_same_1darray
  • I removed the meta argument passed to blockwise, as it was a sparse array. I suspect it is some sort of optimization that tells blockwise the expected shape, but it's not really documented.
  • I also renamed some variables and added types, docs and integration tests.
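
To make the discussion of a sparse-free implementation easier to follow, here is a minimal NumPy-only sketch of how `take_along_axis` can be evaluated one chunk at a time. It is not the code from this PR or from climix. The helper name `chunked_take_along_axis` and its `chunk_size` parameter are hypothetical, and the sketch assumes `indices` has the same shape as `arr` on every axis except `axis` (no broadcasting).

```python
import numpy as np


def chunked_take_along_axis(arr, indices, axis, chunk_size):
    """Hypothetical helper: take_along_axis computed one chunk of `arr` at a time."""
    axis = axis % arr.ndim  # normalise a negative axis
    n = arr.shape[axis]
    indices = np.where(indices < 0, indices + n, indices)  # wrap negative indices
    out = np.empty(indices.shape, dtype=arr.dtype)

    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        # Slice out one chunk along `axis`.
        selector = [slice(None)] * arr.ndim
        selector[axis] = slice(start, stop)
        chunk = arr[tuple(selector)]

        # Which requested indices fall inside this chunk?
        in_chunk = (indices >= start) & (indices < stop)
        # Shift to chunk-local indices; clip so out-of-chunk entries stay valid
        # (their values are discarded by the mask below).
        local = np.clip(indices - start, 0, stop - start - 1)
        values = np.take_along_axis(chunk, local, axis=axis)
        out[in_chunk] = values[in_chunk]
    return out


rng = np.random.default_rng(0)
arr = rng.random((5, 7, 4))
idx = np.argsort(arr, axis=1)[:, :3, :]
result = chunked_take_along_axis(arr, idx, axis=1, chunk_size=3)
assert np.array_equal(result, np.take_along_axis(arr, idx, axis=1))
```

Every index falls into exactly one chunk, so each output element is written once. The masked per-chunk results are dense, and most of their entries are discarded, which is presumably where a sparse representation would otherwise come in.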

Contributor

Unit Test Results

See test report for an extended history of previous test failures. This is useful for diagnosing flaky tests.

     15 files  ± 0       15 suites  ±0   3h 29m 26s ⏱️ + 7m 10s
 13 124 tests + 3   12 185 ✅ + 3     931 💤 ±0  8 ❌ ±0 
162 507 runs  +39  142 400 ✅ +39  20 099 💤 ±0  8 ❌ ±0 

For more details on these failures, see this check.

Results for commit 105ef10. ± Comparison against base commit dafb6ac.

Development

Successfully merging this pull request may close these issues.

Add NumPy's new take_along_axis