Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enh/take along axis #11076

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions dask/array/__init__.py
Expand Up @@ -158,6 +158,7 @@
vstack,
where,
)
from dask.array.slicing import take_along_axis
from dask.array.tiledb_io import from_tiledb, to_tiledb
from dask.array.ufunc import (
abs,
Expand Down
39 changes: 39 additions & 0 deletions dask/array/chunk.py
Expand Up @@ -431,3 +431,42 @@ def getitem(obj, index):
pass

return result


def take_along_axis_chunk(
    arr: np.ndarray, indices: np.ndarray, offset: np.ndarray, arr_size: int, axis: int
):
    """Gather values from one chunk of an array using array-wide indices.

    Positions whose index does not land inside this chunk are filled with
    zero in the output.

    Parameters
    ----------
    arr: np.ndarray, dtype=Any
        The chunk of data to read from.
    indices: np.ndarray, dtype=int64
        Indices along ``axis``, expressed relative to the whole array
        (negative values count from the end of the whole array).
    offset: np.ndarray, shape=(1, ), dtype=int64
        Index, along ``axis``, of the first element of this chunk within
        the whole array.
    arr_size: int
        Length of the whole array along ``axis``.
    axis: int
        The axis the indices refer to.

    Returns
    -------
    out: np.ndarray
        The gathered values, with an extra length-1 dimension inserted at
        position ``axis``.
    """
    # Work on a signed copy: unsigned inputs could not represent the
    # negative values produced by the offset subtraction below.
    local_idx = indices.astype(np.int64)
    # Wrap negative indices around the full length of the array.
    local_idx = np.where(local_idx < 0, local_idx + arr_size, local_idx)
    # Re-express the indices relative to the start of this chunk.
    local_idx = local_idx - offset
    # Anything before the chunk start or past the chunk end belongs to
    # another chunk.
    outside = (local_idx < 0) | (local_idx >= arr.shape[axis])
    # Point out-of-chunk entries at a valid slot so the gather cannot fail;
    # the values read there are zeroed out right after.
    local_idx[outside] = 0
    picked = np.take_along_axis(arr, local_idx, axis=axis)
    picked[outside] = 0
    return np.expand_dims(picked, axis)
68 changes: 67 additions & 1 deletion dask/array/slicing.py
Expand Up @@ -7,16 +7,20 @@
from itertools import product
from numbers import Integral, Number
from operator import itemgetter
from typing import TYPE_CHECKING

import numpy as np
from tlz import concat, memoize, merge, pluck

from dask import config, core, utils
from dask.array.chunk import getitem
from dask.array.chunk import getitem, take_along_axis_chunk
from dask.base import is_dask_collection, tokenize
from dask.highlevelgraph import HighLevelGraph
from dask.utils import cached_cumsum, is_arraylike

if TYPE_CHECKING:
from dask.array import Array

colon = slice(None, None, None)


Expand Down Expand Up @@ -2162,3 +2166,65 @@ def setitem(x, v, indices):
) from e

return x


def take_along_axis(arr: Array, indices: Array, axis: int):
    """Slice a dask ndarray according to dask ndarray of indices along an axis.

    Each chunk of ``arr`` along ``axis`` is combined with the indices: a
    chunk contributes the values whose index falls inside it and zero
    everywhere else. Summing these per-chunk contributions over ``axis``
    produces the final result.

    Parameters
    ----------
    arr: dask.array.Array, dtype=Any
        Data array.
    indices: dask.array.Array, dtype=int64
        Indices of interest.
    axis: int
        The axis along which the indices are from. Negative values count
        from the last axis.

    Returns
    -------
    out: dask.array.Array
        The indexed arr.

    Raises
    ------
    ValueError
        If ``axis`` is out of bounds for ``arr``.
    NotImplementedError
        If the chunk sizes of ``arr`` along ``axis`` are unknown.
    """
    from dask.array.core import Array, blockwise, from_array

    # Validate explicitly rather than with ``assert``, which is stripped
    # when Python runs with -O.
    ndim = arr.ndim
    if not -ndim <= axis < ndim:
        raise ValueError(
            f"axis {axis} is out of bounds for array of dimension {ndim}"
        )
    if axis < 0:
        axis += ndim
    if np.isnan(arr.chunks[axis]).any():
        raise NotImplementedError(
            "take_along_axis for an array with unknown chunks with "
            "a dask.array of ints is not supported"
        )
    # Calculate the offset at which each chunk starts along axis
    # e.g. chunks=(..., (5, 3, 4), ...) -> offset=[0, 5, 8]
    offset = np.roll(np.cumsum(arr.chunks[axis]), 1)
    offset[0] = 0
    da_offset = from_array(offset, chunks=1)
    # Tamper with the declared chunks of offset to make blockwise align it with
    # arr[axis]
    da_offset = Array(
        da_offset.dask, da_offset.name, (arr.chunks[axis],), da_offset.dtype
    )
    # Define axis labels for blockwise
    arr_axes = tuple(range(arr.ndim))
    idx_label = (arr.ndim,)  # arbitrary unused
    index_axes = arr_axes[:axis] + idx_label + arr_axes[axis + 1 :]
    offset_axes = (axis,)
    # The output keeps the chunked ``axis`` of arr (to be reduced below) and
    # places the indices dimension right after it.
    p_axes = arr_axes[: axis + 1] + idx_label + arr_axes[axis + 1 :]
    # Compute take_along_axis for each chunk
    # TODO: Add meta argument for blockwise ?
    p = blockwise(
        take_along_axis_chunk,
        p_axes,
        arr,
        arr_axes,
        indices,
        index_axes,
        da_offset,
        offset_axes,
        arr_size=arr.shape[axis],
        axis=axis,
        dtype=arr.dtype,
    )
    # Collapse the per-chunk contributions along axis. The dtype is given
    # explicitly because a sum with the default accumulator promotes bool and
    # small integer dtypes, which would make the result dtype differ from arr.
    res = p.sum(axis=axis, dtype=arr.dtype)
    return res
41 changes: 41 additions & 0 deletions dask/array/tests/test_slicing.py
Expand Up @@ -23,6 +23,7 @@
slice_array,
slicing_plan,
take,
take_along_axis,
)
from dask.array.utils import assert_eq, same_keys

Expand Down Expand Up @@ -1054,3 +1055,43 @@ def test_slice_array_null_dimension():
array = da.from_array(np.zeros((3, 0)))
expected = np.zeros((3, 0))[[0]]
assert_eq(array[[0]], expected)


def test_take_along_axis__simple_indexing():
    """Index a 3-d array, split into single-element chunks, along its last axis."""
    # GIVEN
    unit_chunks = (1, 1, 1)
    raw_data = [[[0, 1, 2, 3], [4, 5, 6, 7]], [[8, 9, 10, 11], [12, 13, 14, 15]]]
    raw_indexes = [[[1, 2], [0, 1]], [[1, 0], [2, 1]]]
    raw_expected = [[[1, 2], [4, 5]], [[9, 8], [14, 13]]]
    data = da.from_array(raw_data, chunks=unit_chunks)
    indexes = da.from_array(raw_indexes, chunks=unit_chunks)
    expected = da.from_array(raw_expected, chunks=unit_chunks)
    # WHEN
    actual = take_along_axis(data, indexes, axis=-1)
    # THEN
    assert_eq(actual, expected)


def test_take_along_axis__indexing_twice_same_1darray():
    """Apply two rows of indices to the same single-row array."""
    # GIVEN
    source = da.from_array([[10, 20, 30, 40]], chunks=2)
    picks = da.from_array([[0, 2], [2, 3]], chunks=-1)
    wanted = da.from_array([[10, 30], [30, 40]], chunks=2)
    # WHEN
    got = take_along_axis(source, picks, axis=-1)
    # THEN
    assert_eq(got, wanted)


def test_take_along_axis__error_indexing_has_nans():
    """An array with unknown chunk sizes along the axis is rejected."""
    # GIVEN
    source = da.from_array([[10, 20, 30, 40]], chunks=2)
    picks = da.from_array([[0, 2], [2, 3]], chunks=-1)
    # Boolean masking leaves the array with a nan shape
    unknown_shape = source[source > 0]
    message = (
        "take_along_axis for an array with unknown chunks with "
        "a dask.array of ints is not supported"
    )
    # THEN
    with pytest.raises(NotImplementedError, match=message):
        # WHEN
        take_along_axis(unknown_shape, picks, axis=-1)
3 changes: 2 additions & 1 deletion docs/source/array-slicing.rst
Expand Up @@ -10,6 +10,8 @@ supports the following:
* Slicing one :class:`~dask.array.Array` with an :class:`~dask.array.Array` of bools: ``x[x > 0]``
* Slicing one :class:`~dask.array.Array` with a zero or one-dimensional :class:`~dask.array.Array`
of ints: ``a[b.argtopk(5)]``
* Slicing one :class:`~dask.array.Array` with a multi-dimensional :class:`~dask.array.Array` of ints
along a single axis. This can be done using ``dask.array.take_along_axis``.

However, it does not currently support the following:

Expand All @@ -19,7 +21,6 @@ However, it does not currently support the following:
issue. Also, users interested in this should take a look at
:attr:`~dask.array.Array.vindex`.

* Slicing one :class:`~dask.array.Array` with a multi-dimensional :class:`~dask.array.Array` of ints

.. _array.slicing.efficiency:

Expand Down