Skip to content

Commit

Permalink
Merge metadata in compiler, resolves #567
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 committed May 14, 2024
1 parent 5f9e7bc commit 0473cf9
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 3 additions & 0 deletions lib/axon/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,9 @@ defmodule Axon.Compiler do
out = Nx.Defn.Expr.metadata(Nx.Defn.Expr.tensor(out), %{axon_layer: op_name})
%{stateful | output: out}

%Nx.Tensor{data: %{op: :metadata, args: [arg, metadata]} = expr} = out ->
%{out | data: %{expr | args: [arg, Map.put(metadata, :axon_layer, op_name)]}}

%Nx.Tensor{} = out ->
Nx.Defn.Expr.metadata(Nx.Defn.Expr.tensor(out), %{axon_layer: op_name})

Expand Down
2 changes: 1 addition & 1 deletion test/axon/integration_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ defmodule Axon.IntegrationTest do
model_state =
model
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
|> Axon.Loop.run(data, Axon.ModelState.empty(), iterations: 100, epochs: 10)
|> Axon.Loop.run(data, Axon.ModelState.empty(), iterations: 100, epochs: 20)

eval_results =
model
Expand Down

0 comments on commit 0473cf9

Please sign in to comment.