Support blocks with multiple inputs (#574)
seanmor5 committed May 14, 2024
1 parent b6ab577 commit f48dcd1
Showing 3 changed files with 69 additions and 12 deletions.
lib/axon.ex (10 additions, 1 deletion)
@@ -746,17 +746,26 @@ defmodule Axon do
   """
   @doc type: :special
   def block(fun, opts \\ []) when is_function(fun) do
+    {:arity, arity} = Function.info(fun, :arity)
     opts = Keyword.validate!(opts, [:name, :meta])
     block_id = System.unique_integer([:positive, :monotonic])
 
-    fn inputs ->
+    block_fun(arity, fn inputs ->
       layer(:block, List.wrap(inputs),
         op_name: :block,
         name: opts[:name],
         meta: opts[:meta],
         block_fun: fun,
         block_id: block_id
       )
+    end)
+  end
+
+  for i <- 0..128 do
+    args = Macro.generate_arguments(i, __MODULE__)
+
+    defp block_fun(unquote(i), callback) do
+      fn unquote_splicing(args) -> callback.(unquote(args)) end
     end
   end
 
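Why the generated clauses: Elixir anonymous functions have a fixed arity, so once a block can take several inputs, block/2 can no longer return a plain `fn inputs -> ... end`. It now reads the arity of the user's function with Function.info/2 and dispatches to a block_fun/2 clause, generated at compile time for every arity from 0 to 128, which returns a closure of exactly that arity and hands the collected arguments to the callback as a list. A minimal standalone sketch of the same pattern (the module and function names below are illustrative, not Axon's):

    defmodule AritySketch do
      # Wrap `fun` in a closure of the same arity that receives its
      # arguments as a single list.
      def wrap(fun) when is_function(fun) do
        {:arity, arity} = Function.info(fun, :arity)
        wrap_fun(arity, fn args -> {:wrapped, args} end)
      end

      # One clause per supported arity; `unquote_splicing(args)` expands
      # to that many distinct argument variables in the fn head.
      for i <- 0..8 do
        args = Macro.generate_arguments(i, __MODULE__)

        defp wrap_fun(unquote(i), callback) do
          fn unquote_splicing(args) -> callback.(unquote(args)) end
        end
      end
    end

    # The wrapper's arity matches the wrapped function's:
    two = AritySketch.wrap(fn a, b -> a + b end)
    {:wrapped, [1, 2]} = two.(1, 2)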
lib/axon/compiler.ex (22 additions, 11 deletions)
@@ -606,17 +606,17 @@ defmodule Axon.Compiler do
          %Axon.Node{
            id: id,
            op: :block,
-           parent: [parent],
+           parent: parents,
            opts: [block_fun: block_fun, block_id: block_id],
            name: name_fn
          },
          nodes,
          cache_and_counts,
          config
        ) do
-    {[parent_id], {cache, op_counts, block_cache, model_state_meta}} =
+    {parent_ids, {cache, op_counts, block_cache, model_state_meta}} =
       Enum.map_reduce(
-        [parent],
+        parents,
         cache_and_counts,
         &to_model_funs(&1, nodes, &2, config)
       )
@@ -627,7 +627,8 @@
           {funs, name, block_cache, op_counts}
 
         %{} ->
-          funs = build(block_fun.(Axon.input("subgraph")), debug?: config.debug?)
+          inputs = Enum.with_index(parents, fn _, i -> Axon.input("subgraph#{i}") end)
+          funs = build(apply(block_fun, inputs), debug?: config.debug?)
           name = name_fn.(:block, op_counts)
           op_counts = Map.update(op_counts, :block, 1, fn x -> x + 1 end)
           {funs, name, Map.put(block_cache, block_id, {funs, name}), op_counts}
@@ -637,9 +638,9 @@
     # Recurse graph inputs and invoke cache to get parent results,
     # state, and result_cache and then apply dtype policy and hooks
     # to each input
-    {[layer_input], {state, result_cache, none?}} =
+    {layer_inputs, {state, result_cache, none?}} =
       Enum.map_reduce(
-        [parent_id],
+        parent_ids,
         {state, result_cache, false},
         fn parent_id, {state, result_cache, none?} ->
           {layer_input, {state, result_cache}} =
@@ -663,7 +664,13 @@
             {%Axon.None{}, {state, result_cache}}
           else
             block_params = params[block_name] || %{}
-            result = apply(block_predict_fun, [Axon.ModelState.new(block_params), layer_input])
+
+            inputs =
+              layer_inputs
+              |> Enum.with_index()
+              |> Map.new(fn {input, i} -> {"subgraph#{i}", input} end)
+
+            result = apply(block_predict_fun, [Axon.ModelState.new(block_params), inputs])
 
             {out_result, out_state} =
               case result do
@@ -685,8 +692,8 @@
         end
 
     init_fun = fn template, cache, result_cache, fn_stacktrace, keys ->
-      {[parent_shape], {parent_params, result_cache, none?}} =
-        Enum.map_reduce([parent_id], {%{}, result_cache, false}, fn
+      {parent_shapes, {parent_params, result_cache, none?}} =
+        Enum.map_reduce(parent_ids, {%{}, result_cache, false}, fn
           parent_id, {params, result_cache, none?} ->
             {parent_shape, {params, result_cache}} =
               call_init_cache(
@@ -706,8 +713,12 @@
         if none? do
           {%Axon.None{}, {parent_params, result_cache}}
         else
-          template = Nx.broadcast(0.0, parent_shape)
-          block_params = apply(block_init_fun, [template, Axon.ModelState.empty()])
+          templates =
+            parent_shapes
+            |> Enum.with_index()
+            |> Map.new(fn {shape, i} -> {"subgraph#{i}", Nx.broadcast(0.0, shape)} end)
+
+          block_params = apply(block_init_fun, [templates, Axon.ModelState.empty()])
 
           params =
             if block_params == %{} do
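The compiler side of the change is one convention applied three times: the block body is compiled as a standalone subgraph whose i-th input is named "subgraph#{i}", so wherever the old code handled a single parent, the new code zips a list of parents into a map keyed by those names, at build time, predict time, and init time. A runnable sketch of the convention (the tensor values are illustrative):

    block_fun = fn x, y -> Axon.add(x, y) end
    parents = [:p0, :p1]

    # Build side: one Axon.input node per parent, applied positionally.
    graph_inputs = Enum.with_index(parents, fn _, i -> Axon.input("subgraph#{i}") end)
    subgraph = apply(block_fun, graph_inputs)

    # Run side: parent results keyed under the same names.
    layer_inputs = [Nx.tensor([[1.0]]), Nx.tensor([[2.0]])]

    inputs =
      layer_inputs
      |> Enum.with_index()
      |> Map.new(fn {tensor, i} -> {"subgraph#{i}", tensor} end)

    # inputs == %{"subgraph0" => Nx.tensor([[1.0]]), "subgraph1" => Nx.tensor([[2.0]])}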
test/axon/compiler_test.exs (37 additions, 0 deletions)
@@ -5330,6 +5330,43 @@ defmodule CompilerTest do
       input = random({1, 1})
       assert_equal(predict_fn.(params, input), actual_predict_fn.(input, k, b))
     end
+
+    test "works with multiple block inputs" do
+      block =
+        Axon.block(fn x, y ->
+          dense = Axon.block(&Axon.dense(&1, 4))
+          Axon.add(dense.(y), dense.(x))
+        end)
+
+      input1 = Axon.input("input1")
+      input2 = Axon.input("input2")
+
+      model = block.(input1, input2) |> Axon.dense(1)
+
+      {init_fn, predict_fn} = Axon.build(model)
+
+      actual_predict_fn = fn %{"input1" => x, "input2" => y}, k1, b1, k2, b2 ->
+        x = Axon.Layers.dense(x, k1, b1)
+        y = Axon.Layers.dense(y, k1, b1)
+
+        x
+        |> Nx.add(y)
+        |> Axon.Layers.dense(k2, b2)
+      end
+
+      input = %{"input1" => Nx.tensor([[0.5]]), "input2" => Nx.tensor([[0.75]])}
+
+      assert %ModelState{
+               data: %{
+                 "block_0" => %{
+                   "block_0" => %{"dense_0" => %{"kernel" => k1, "bias" => b1}}
+                 },
+                 "dense_0" => %{"kernel" => k2, "bias" => b2}
+               }
+             } = params = init_fn.(input, ModelState.empty())
+
+      assert_equal(predict_fn.(params, input), actual_predict_fn.(input, k1, b1, k2, b2))
+    end
   end
 
   describe "initializers" do
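The new test doubles as a usage example: note how the nested params ("block_0" appearing once inside "block_0") reflect the inner dense block being shared between both branches. For reference, a condensed sketch of the same pattern from a user's perspective (layer sizes and input names here are illustrative, not from the commit):

    # A two-input block; the inner single-input block is called on both
    # inputs, so both branches share one set of dense parameters.
    fused =
      Axon.block(fn x, y ->
        dense = Axon.block(&Axon.dense(&1, 8))
        Axon.add(dense.(x), dense.(y))
      end)

    model =
      fused.(Axon.input("query"), Axon.input("context"))
      |> Axon.dense(1)

    {init_fn, predict_fn} = Axon.build(model)

    templates = %{
      "query" => Nx.template({1, 16}, :f32),
      "context" => Nx.template({1, 16}, :f32)
    }

    params = init_fn.(templates, Axon.ModelState.empty())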
