Make take an optional callback
josevalim committed May 12, 2024
1 parent 50ecf0a commit b8f2254
Showing 2 changed files with 58 additions and 16 deletions.
55 changes: 47 additions & 8 deletions exla/lib/exla/defn.ex
@@ -598,8 +598,49 @@ defmodule EXLA.Defn do

   defp cached_recur_operator(
          :optional,
-         %T{data: %Expr{args: [%{data: %{op: :top_k, args: [tensor, opts]}}, expr, _callback]}} =
-           _out,
+         %T{
+           data: %Expr{
+             args: [%{data: %{op: :take, args: [tensor, indices, opts]}}, expr, _callback]
+           }
+         },
          state,
          cache
        ) do
+    axis = opts[:axis]
+    {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
+    {indices, cache} = recur_operator(indices, state, cache) |> unwrap_single_tensor!()
+
+    tensor_rank = tensor |> op_shape() |> tuple_size()
+    indices_rank = indices |> op_shape() |> tuple_size()
+    result_rank = tensor_rank - 1 + indices_rank
+
+    index_vector_dim = indices_rank
+    slice_sizes = tensor |> op_shape() |> put_elem(axis, 1) |> Tuple.to_list()
+
+    {left, right} = result_rank |> axes_for_rank() |> Enum.split(axis)
+    offset_dims = left ++ Enum.drop(right, indices_rank)
+
+    collapsed_slice_dims = [axis]
+    start_index_map = [axis]
+
+    result =
+      Value.gather(
+        tensor,
+        indices,
+        index_vector_dim,
+        slice_sizes,
+        offset_dims,
+        collapsed_slice_dims,
+        start_index_map,
+        expr_to_typespec(expr)
+      )
+
+    {result, cache}
+  end
+
+  defp cached_recur_operator(
+         :optional,
+         %T{data: %Expr{args: [%{data: %{op: :top_k, args: [tensor, opts]}}, expr, _callback]}},
+         state,
+         cache
+       ) do
@@ -612,26 +653,24 @@

   defp cached_recur_operator(
          :optional,
-         %T{data: %Expr{args: [%{data: %{op: :fft2, args: [tensor, opts]}}, _expr, _callback]}} =
-           out,
+         %T{data: %Expr{args: [%{data: %{op: :fft2, args: [tensor, opts]}}, expr, _callback]}},
          state,
          cache
        ) do
     {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()

-    {fft2(&Value.fft(&1, :fft, &2, &3), [tensor, opts], out, state), cache}
+    {fft2(&Value.fft(&1, :fft, &2, &3), [tensor, opts], expr, state), cache}
   end

   defp cached_recur_operator(
          :optional,
-         %T{data: %Expr{args: [%{data: %{op: :ifft2, args: [tensor, opts]}}, _expr, _callback]}} =
-           out,
+         %T{data: %Expr{args: [%{data: %{op: :ifft2, args: [tensor, opts]}}, expr, _callback]}},
          state,
          cache
        ) do
     {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()

-    {fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, opts], out, state), cache}
+    {fft2(&Value.fft(&1, :ifft, &2, &3), [tensor, opts], expr, state), cache}
   end

   defp cached_recur_operator(:optional, %T{data: %Expr{args: args}}, state, cache) do
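For orientation, here is a standalone sketch of the index bookkeeping the new :take clause performs before calling Value.gather. The shapes are illustrative (a tensor of shape {4, 5, 6} gathered along axis: 1 with integer indices of shape {2, 3}), not taken from the commit, and axes_for_rank/1 — an EXLA.Defn helper — is assumed to enumerate 0..rank-1:

# Plain Elixir mirroring the clause above, with hypothetical shapes.
tensor_shape = {4, 5, 6}
indices_shape = {2, 3}
axis = 1

tensor_rank = tuple_size(tensor_shape)                                # 3
indices_rank = tuple_size(indices_shape)                              # 2
result_rank = tensor_rank - 1 + indices_rank                          # 4

index_vector_dim = indices_rank                                       # 2
slice_sizes = tensor_shape |> put_elem(axis, 1) |> Tuple.to_list()    # [4, 1, 6]

{left, right} = Enum.split(Enum.to_list(0..(result_rank - 1)), axis)  # {[0], [1, 2, 3]}
offset_dims = left ++ Enum.drop(right, indices_rank)                  # [0, 3]
collapsed_slice_dims = [axis]                                         # [1]
start_index_map = [axis]                                              # [1]

# With these parameters the XLA gather produces shape {4, 2, 3, 6},
# which is exactly what Nx.take(tensor, indices, axis: 1) returns for those shapes.

In other words, the taken axis is collapsed out of each slice and replaced in place by the index batch dimensions, so no extra transpose or reshape is needed on the EXLA side.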
19 changes: 11 additions & 8 deletions nx/lib/nx.ex
@@ -14130,17 +14130,20 @@ defmodule Nx do
     else
       tensor = devectorize(tensor, keep_names: false)
       indices = devectorize(indices, keep_names: false)
-      gather_indices = new_axis(indices, rank(indices))
+      out = %{tensor | shape: inner_shape, names: inner_names}

-      {indices_axes, tensor_axes} = Enum.split(axes(inner_shape), rank(indices))
-      {leading, trailing} = Enum.split(tensor_axes, axis)
+      Nx.Shared.optional(:take, [tensor, indices, [axis: axis]], out, fn tensor, indices, _opts ->
+        gather_indices = new_axis(indices, rank(indices))
+        {indices_axes, tensor_axes} = Enum.split(axes(inner_shape), rank(indices))
+        {leading, trailing} = Enum.split(tensor_axes, axis)

-      transpose_axes = leading ++ indices_axes ++ trailing
+        transpose_axes = leading ++ indices_axes ++ trailing

-      tensor
-      |> gather(gather_indices, axes: [axis])
-      |> transpose(axes: transpose_axes)
-      |> reshape(inner_shape, names: inner_names)
+        tensor
+        |> gather(gather_indices, axes: [axis])
+        |> transpose(axes: transpose_axes)
+        |> reshape(inner_shape, names: inner_names)
+      end)
     end
   end

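With this commit, the gather/transpose/reshape pipeline above is only the default: Nx.take/3 now goes through Nx.Shared.optional/4, so a backend or defn compiler that provides its own :take implementation (such as the new EXLA clause) is called directly, and the pipeline runs only as a fallback. The user-facing behaviour should be unchanged; a small worked example, assuming an eager backend (values computed by hand, not taken from the commit):

t = Nx.iota({2, 3})
# t is [[0, 1, 2],
#       [3, 4, 5]]
Nx.take(t, Nx.tensor([2, 0]), axis: 1)
# => [[2, 0],
#     [5, 3]]   shape {2, 2}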
