Skip to content

Commit

Permalink
Allow dynamic calls in defn
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed May 13, 2024
1 parent 83419bb commit 44313f6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 23 deletions.
20 changes: 3 additions & 17 deletions nx/lib/nx/defn/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -568,19 +568,6 @@ defmodule Nx.Defn.Compiler do
{{{:., dot_meta, [fun]}, meta, args}, state}
end

# TODO: Remove me once transform/2 is removed.
defp normalize({{:., _, [Nx.Defn.Kernel, :transform]} = call, meta, [ast, fun]}, state) do
{ast, state} = normalize(ast, state)

fun =
Macro.prewalk(fun, fn
var when is_var(var) -> normalize_var(var)
node -> node
end)

{{call, meta, [ast, fun]}, state}
end

defp normalize({{:., _, [Nx.Defn.Kernel, :hook]} = call, meta, [ast | rest]}, state) do
{ast, state} = normalize(ast, state)
{{call, meta, [ast | rest]}, state}
Expand Down Expand Up @@ -647,11 +634,10 @@ defmodule Nx.Defn.Compiler do
state}
end

defp normalize({{:., dot_meta, [remote, name]}, meta, args}, state)
# TODO: Remove args == [] once we require Elixir version where args are nil
when is_atom(name) and (args == nil or args == []) do
defp normalize({{:., dot_meta, [remote, name]}, meta, args}, state) when is_atom(name) do
{remote, state} = normalize(remote, state)
{{{:., dot_meta, [Map, :fetch!]}, meta, [remote, name]}, state}
{args, state} = normalize_list(args, state)
{{{:., dot_meta, [remote, name]}, meta, args}, state}
end

defp normalize({left, right}, state) do
Expand Down
18 changes: 12 additions & 6 deletions nx/test/nx/defn_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -953,34 +953,40 @@ defmodule Nx.DefnTest do

describe "remote functions" do
defmodule Remote do
defn(add_two(c, d), do: c + d)
defn add_two(c, d), do: c + d
end

defn(add_two_remote(a, b), do: Remote.add_two(a, b))
defn add_two_remote(a, b), do: Remote.add_two(a, b)

test "public" do
assert %T{data: %Expr{op: :add, args: [_, _]}} = add_two_remote(1, 2)
end

defn(add_two_unknown(a, b), do: Nx.DefnTest.unknown(a, b))
defn add_two_dynamic(a, b, opts \\ []), do: opts[:remote].add_two(a, b)

def not_defn(a, b), do: Nx.add(a, b)
defn(add_two_not_defn(a, b), do: Nx.DefnTest.not_defn(a, b))
test "dynamic" do
assert %T{data: %Expr{op: :add, args: [_, _]}} = add_two_remote(1, 2)
end

defn(add_two_io(a, b), do: IO.inspect({a, b}))
defn add_two_unknown(a, b), do: Nx.DefnTest.unknown(a, b)

test "undefined remote" do
assert_raise UndefinedFunctionError,
"function Nx.DefnTest.unknown/2 is undefined or private",
fn -> add_two_unknown(1, 2) end
end

def not_defn(a, b), do: Nx.add(a, b)
defn add_two_not_defn(a, b), do: Nx.DefnTest.not_defn(a, b)

test "not defn remote" do
assert_raise RuntimeError,
"cannot invoke Nx.DefnTest.not_defn/2 inside defn because it was not defined with defn",
fn -> add_two_not_defn(1, 2) end
end

defn add_two_io(a, b), do: IO.inspect({a, b})

test "IO remote" do
assert_raise RuntimeError,
"cannot invoke IO.inspect/1 inside defn because it was not defined with defn. " <>
Expand Down

0 comments on commit 44313f6

Please sign in to comment.