Make take_along_axis an optional callback
Closes #1440.
Benjamin-Philip authored and josevalim committed May 13, 2024
1 parent 6139d2a commit 7c36e06
Showing 8 changed files with 26 additions and 131 deletions.
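For context on what is being made optional: take_along_axis/3 picks values along a single axis, using an indices tensor of the same rank as the input. A minimal usage sketch (values chosen purely for illustration):

    t = Nx.tensor([[1, 2, 3], [4, 5, 6]])
    i = Nx.tensor([[2, 1, 0], [0, 0, 2]])

    Nx.take_along_axis(t, i, axis: 1)
    #=> #Nx.Tensor<
    #     s64[2][3]
    #     [
    #       [3, 2, 1],
    #       [4, 4, 6]
    #     ]
    #   >

Before this commit every backend had to implement the operation; after it, a backend implements the callback only when it can beat the shared gather-based default added in nx/lib/nx.ex below.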
exla/lib/exla/backend.ex (1 change: 0 additions & 1 deletion)

@@ -325,7 +325,6 @@ defmodule EXLA.Backend do
      {:reverse, [:tensor, :axes], [:tensor]},
      {:dot, [:left, :c1, :b1, :right, :c2, :b2], [:left, :right]},
      {:clip, [:tensor, :min, :max], [:tensor, :min, :max]},
-     {:take_along_axis, [:tensor, :indices, :axis], [:tensor, :indices]},
      {:gather, [:input, :indices, :opts], [:input, :indices]},
      {:select, [:pred, :on_true, :on_false], [:pred, :on_true, :on_false]},
      {:conv, [:tensor, :kernel, :opts], [:tensor, :kernel]},
exla/lib/exla/defn.ex (37 changes: 0 additions & 37 deletions)

@@ -1259,43 +1259,6 @@ defmodule EXLA.Defn do
    Value.dynamic_update_slice(tensor, slice, start_indices, expr_to_typespec(ans))
  end

- defp to_operator(:take_along_axis, [%Value{} = tensor, indices, axis], ans, state) do
-   %{shape: indices_shape} = indices_typespec = Value.get_typespec(indices)
-   indices_rank = tuple_size(indices_shape)
-
-   axes_range = 0..(indices_rank - 1)//1
-
-   index_vector_dim = indices_rank
-   slice_sizes = List.duplicate(1, indices_rank)
-   offset_dims = []
-   collapsed_slice_dims = Enum.to_list(axes_range)
-   start_index_map = Enum.to_list(axes_range)
-
-   new_axis_typespec = Typespec.to_shape(indices_typespec, Tuple.append(indices_shape, 1))
-
-   full_indices_typespec =
-     Typespec.to_shape(indices_typespec, Tuple.append(indices_shape, indices_rank))
-
-   full_indices =
-     axes_range
-     |> Enum.map(fn
-       ^axis -> Value.reshape(indices, new_axis_typespec)
-       axis -> Value.iota(state.builder, axis, new_axis_typespec)
-     end)
-     |> Value.concatenate(indices_rank, full_indices_typespec)
-
-   Value.gather(
-     tensor,
-     full_indices,
-     index_vector_dim,
-     slice_sizes,
-     offset_dims,
-     collapsed_slice_dims,
-     start_index_map,
-     expr_to_typespec(ans)
-   )
- end
-
  defp to_operator(:gather, [%Value{} = tensor, indices, opts], ans, _state) do
    axes = Keyword.fetch!(opts, :axes)
    tensor_shape = op_shape(tensor)
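With :take_along_axis gone from EXLA.Backend's delegation list and from EXLA.Defn above, EXLA no longer carries a hand-written lowering for the operation: defn now traces the pure-Nx fallback instead, and its gather call flows through the :gather lowering that remains.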
nx/lib/nx.ex (25 changes: 20 additions & 5 deletions)

@@ -14142,7 +14142,7 @@ defmodule Nx do
      tensor
      |> gather(gather_indices, axes: [axis])
      |> transpose(axes: transpose_axes)
-     |> reshape(inner_shape, names: inner_names)
+     |> rename(inner_names)
    end)
  end
end

@@ -14302,17 +14302,32 @@ defmodule Nx do
    end

    opts = keyword!(opts, axis: 0)

    tensor = devectorize(tensor, keep_names: false)
    indices = devectorize(indices, keep_names: false)

    offset = length(vectorized_axes)

    axis = Nx.Shape.normalize_axis(tensor.shape, opts[:axis], tensor.names, offset)

    shape = Nx.Shape.take_along_axis(tensor.shape, indices.shape, axis)
+   out = %{tensor | shape: shape}

-   result = impl!(tensor).take_along_axis(%{tensor | shape: shape}, tensor, indices, axis)
+   result =
+     Nx.Shared.optional(:take_along_axis, [tensor, indices, [axis: axis]], out, fn
+       tensor, indices, _opts ->
+         axes_range = axes(indices)
+         new_axis_shape = Tuple.append(shape(indices), 1)
+
+         full_indices =
+           axes_range
+           |> Enum.map(fn
+             ^axis -> reshape(indices, new_axis_shape)
+             axis -> iota(new_axis_shape, axis: axis)
+           end)
+           |> concatenate(axis: rank(indices))
+
+         tensor
+         |> gather(full_indices)
+         |> rename(tensor.names)
+     end)

    vectorize(result, vectorized_axes)
  end
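The replacement above is the heart of the commit: the optional callback's default builds one fully-qualified coordinate per element, using the caller's indices along the chosen axis and an iota (each position names itself) along every other axis, then delegates to gather. The same construction as a standalone sketch, written only against the public Nx API (module and function names here are ours, not part of the diff):

    defmodule TakeAlongAxisSketch do
      # Assumed-equivalent rewrite of the fallback above.
      def take_along_axis_via_gather(tensor, indices, axis) do
        # {d0, ..., dn} becomes {d0, ..., dn, 1} so per-axis coordinates
        # can be concatenated into a trailing dimension
        new_axis_shape = Tuple.append(Nx.shape(indices), 1)

        full_indices =
          indices
          |> Nx.axes()
          |> Enum.map(fn
            # on the requested axis, use the caller's indices
            ^axis -> Nx.reshape(indices, new_axis_shape)
            # elsewhere, every element's coordinate is its own position
            other -> Nx.iota(new_axis_shape, axis: other)
          end)
          |> Nx.concatenate(axis: Nx.rank(indices))

        Nx.gather(tensor, full_indices)
      end
    end

For indices of shape {2, 3} and axis: 1, full_indices comes out as {2, 3, 2}, pairing each row's iota with the caller's column index, and gather returns the expected {2, 3} result.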
nx/lib/nx/backend.ex (5 changes: 3 additions & 2 deletions)

@@ -73,7 +73,6 @@ defmodule Nx.Backend do
  @callback clip(out :: tensor, tensor, min :: tensor, max :: tensor) :: tensor
  @callback slice(out :: tensor, tensor, list, list, list) :: tensor
  @callback put_slice(out :: tensor, tensor, tensor, list) :: tensor
- @callback take_along_axis(out :: tensor, input :: tensor, indices :: tensor, axis) :: tensor
  @callback gather(out :: tensor, input :: tensor, indices :: tensor, keyword) :: tensor
  @callback concatenate(out :: tensor, tensor, axis) :: tensor
  @callback select(out :: tensor, tensor, tensor, tensor) :: tensor

@@ -159,6 +158,7 @@
  @callback all_close(out :: tensor, tensor, tensor, keyword) :: tensor
  @callback top_k(out :: tensor, tensor, keyword) :: tensor
  @callback take(out :: tensor, input :: tensor, indices :: tensor, keyword) :: tensor
+ @callback take_along_axis(out :: tensor, input :: tensor, indices :: tensor, keyword) :: tensor

  @optional_callbacks [
    optional: 3,

@@ -178,7 +178,8 @@
    qr: 3,
    cholesky: 2,
    eigh: 3,
-   take: 4
+   take: 4,
+   take_along_axis: 4
  ]

  ## Inspect implementation
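Moving take_along_axis into @optional_callbacks means Nx.Shared.optional/4 decides at runtime whether to call the backend or run the pure-Nx default. A hypothetical simplification of that dispatch (the real Nx.Shared.optional also has to thread through defn expression tracing, which this sketch omits):

    # Hypothetical sketch only; not the actual Nx.Shared source.
    def optional(name, [%Nx.Tensor{} = t | _] = args, out, default_impl) do
      backend = t.data.__struct__

      if function_exported?(backend, name, length(args) + 1) do
        # the backend exports the optional callback: call it with the
        # output template first, as with any backend callback
        apply(backend, name, [out | args])
      else
        # otherwise, fall back to the shared pure-Nx implementation
        apply(default_impl, args)
      end
    end

This is also why the callback registers as take_along_axis: 4 in the list above: out, tensor, indices, and the keyword list.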
nx/lib/nx/binary_backend.ex (37 changes: 0 additions & 37 deletions)

@@ -1939,43 +1939,6 @@ defmodule Nx.BinaryBackend do
    from_binary(out, data)
  end

- @impl true
- def take_along_axis(
-       %T{type: output_type} = output,
-       %T{shape: t_shape, type: {_, t_size} = t_type} = tensor,
-       %T{shape: idx_shape, type: {_, idx_size} = idx_type} = indices,
-       axis
-     ) do
-   permutation =
-     tensor
-     |> Nx.axes()
-     |> List.delete(axis)
-     |> List.insert_at(Nx.rank(tensor) - 1, axis)
-
-   inverse_permutation = inverse_permutation(permutation)
-   shape_list = Tuple.to_list(output.shape)
-   permuted_shape = permutation |> Enum.map(&Enum.at(shape_list, &1)) |> List.to_tuple()
-
-   t_view = tensor |> to_binary() |> aggregate_axes([axis], t_shape, t_size)
-   idx_view = indices |> to_binary() |> aggregate_axes([axis], idx_shape, idx_size)
-
-   [t_view, idx_view]
-   |> Enum.zip_with(fn [data_bin, idx_bin] ->
-     data = binary_to_list(data_bin, t_type)
-
-     binary_to_binary(idx_bin, idx_type, output_type, fn idx ->
-       if idx < 0 or idx >= elem(tensor.shape, axis) do
-         raise ArgumentError,
-               "index #{idx} is out of bounds for axis #{axis} in shape #{inspect(tensor.shape)}"
-       end
-
-       Enum.at(data, idx)
-     end)
-   end)
-   |> then(&from_binary(%{output | shape: permuted_shape}, &1))
-   |> then(&transpose(output, &1, inverse_permutation))
- end
-
  @impl true
  def gather(out, tensor, indices, opts) do
    axes = opts[:axes]
nx/lib/nx/defn/expr.ex (6 changes: 0 additions & 6 deletions)

@@ -1183,12 +1183,6 @@ defmodule Nx.Defn.Expr do
    expr(out, context, :put_slice, [tensor, start, slice])
  end

- @impl true
- def take_along_axis(out, tensor, indices, axis) do
-   {[tensor, indices], context} = to_exprs([tensor, indices])
-   expr(out, context, :take_along_axis, [tensor, indices, axis])
- end
-
  @impl true
  def gather(out, tensor, indices, opts) do
    {[tensor, indices], context} = to_exprs([tensor, indices])
nx/lib/nx/defn/grad.ex (42 changes: 1 addition & 41 deletions)

@@ -150,9 +150,6 @@ defmodule Nx.Defn.Grad do
  defp reduce_args(:put_slice, %{data: %{args: [arg, _, update | _]}}, acc, fun),
    do: fun.(arg, fun.(update, acc))

- defp reduce_args(:take_along_axis, %{data: %{args: [arg | _]}}, acc, fun),
-   do: fun.(arg, acc)
-
  defp reduce_args(:gather, %{data: %{args: [arg | _]}}, acc, fun),
    do: fun.(arg, acc)

@@ -663,44 +660,6 @@
    [{t, g}]
  end

- defp grad(:take_along_axis, [t, i, axis], _ans, g) do
-   num_elements = i |> Nx.shape() |> Tuple.product()
-
-   # Convert `i`, the take_along_axis indices, to a list of
-   # fully qualified (i.e. [0, 2, 1] for a {_, _, _}-shaped tensor)
-   # indices
-
-   indices =
-     0..(Nx.rank(g) - 1)//1
-     |> Enum.map(fn
-       # For the axis of interest, we'll use the actual take_along_axis indices
-       ^axis ->
-         Nx.reshape(i, {num_elements, 1})
-
-       axis ->
-         i
-         |> Nx.shape()
-         |> Nx.iota(axis: axis)
-         |> Nx.reshape({num_elements, 1})
-     end)
-     |> Nx.concatenate(axis: 1)
-
-   # Since g is produced through the given indices,
-   # we can reshape g to be a {num_elements} shaped tensor
-   # which will directly correspond to each of the reshaped
-   # indices above
-   updates = Nx.reshape(g, {num_elements})
-
-   # The intuition for this grad is that for each index taken, we'll
-   # add the corresponding result grad to the original
-   g =
-     t
-     |> Expr.broadcast(0, Nx.shape(t), Nx.axes(t))
-     |> Nx.indexed_add(indices, updates)
-
-   [{t, g}]
- end
-
  defp grad(:gather, [t, i, opts], _ans, g) do
    i_axes = opts[:axes]
    i_shape = i.shape

@@ -714,6 +673,7 @@

    g =
      0
+     |> Nx.as_type(t.type)
      |> Nx.broadcast(t_shape)
      |> Nx.indexed_add(indices, updates, opts)

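With the dedicated grad clause deleted, take_along_axis now differentiates through the gather-based fallback, so the grad(:gather, ...) clause above is the one that runs. The one-line Nx.as_type addition makes the zero accumulator match t's type (instead of the s64 a bare 0 would produce) before the scatter-add. That scatter-add in miniature, with toy values:

    zeros = Nx.broadcast(Nx.tensor(0.0), {2, 3})
    idx = Nx.tensor([[0, 1], [1, 2]])  # coordinates that were gathered
    upd = Nx.tensor([1.0, 1.0])        # incoming output gradients

    Nx.indexed_add(zeros, idx, upd)
    #=> [[0.0, 1.0, 0.0],
    #    [0.0, 0.0, 1.0]]

Each taken coordinate receives the gradient of the output element it produced; a coordinate taken twice would accumulate both contributions.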
torchx/lib/torchx/backend.ex (4 changes: 2 additions & 2 deletions)

@@ -425,12 +425,12 @@ defmodule Torchx.Backend do
  end

  @impl true
- def take_along_axis(out, tensor, idx, axis) do
+ def take_along_axis(out, tensor, idx, opts) do
    idx_tx = idx |> from_nx() |> Torchx.to_type(:long)

    tensor
    |> from_nx()
-   |> Torchx.gather(idx_tx, axis)
+   |> Torchx.gather(idx_tx, opts[:axis])
    |> to_nx(out)
  end
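Torchx is the one backend that keeps a native take_along_axis, presumably because Torchx.gather maps directly onto LibTorch's gather, which already performs take-along-axis style lookups; only the signature changes here, from a bare axis argument to the keyword list the new optional callback expects.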