Skip to content

Commit

Permalink
Improve gather docs, closes #1443
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed May 12, 2024
1 parent dde9152 commit ef1c08d
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -14132,7 +14132,7 @@ defmodule Nx do
indices = devectorize(indices, keep_names: false)
out = %{tensor | shape: inner_shape, names: inner_names}

Nx.Shared.optional(:take, [tensor, indices, [axis: axis]], out, fn tensor, indices, _opts ->
Nx.Shared.optional(:take, [tensor, indices, axis], out, fn tensor, indices, axis ->
gather_indices = new_axis(indices, rank(indices))
{indices_axes, tensor_axes} = Enum.split(axes(inner_shape), rank(indices))
{leading, trailing} = Enum.split(tensor_axes, axis)
Expand Down Expand Up @@ -14321,12 +14321,20 @@ defmodule Nx do
Builds a new tensor by taking individual values from the original
tensor at the given indices.
Indices must be a tensor where the last dimension is usually of the
same size as the `tensor` rank. Each entry in `indices` will be
part of the results. If the last dimension of indices is less than
the `tensor` rank, then a multidimensional tensor is gathered and
spliced into the result.
## Options
* `:axes` - controls which dimensions the indexes apply to.
It must be a sorted list of axes and be of the same size
as the second (last) dimension of the indexes tensor.
It defaults to the leading axes of the tensor.
* `:axes` - controls to which dimensions of `tensor`
each element in the last dimension of `indexes` applies to.
It defaults so the first element in indexes apply to the first
axis, the second to the second, and so on. It must be a sorted
list of axes and be of the same size as the last dimension of
the indexes tensor.
## Examples
Expand Down

0 comments on commit ef1c08d

Please sign in to comment.