Fix mismatch for certain optimizers with dropout (#458)
seanmor5 committed Jan 21, 2023
1 parent 76d5a45 commit 27b9b53
Showing 4 changed files with 63 additions and 7 deletions.
6 changes: 4 additions & 2 deletions lib/axon/shared.ex
@@ -135,11 +135,13 @@ defmodule Axon.Shared do
   @doc """
   Creates a fulls-like tuple of inputs.
   """
-  deftransform fulls_like(params, value) do
+  deftransform fulls_like(params, value, opts \\ []) do
+    opts = Keyword.validate!(opts, [:type])
     fun = Axon.Initializers.full(value)

     deep_new(params, fn x ->
-      fun.(Nx.shape(x), Nx.type(x))
+      type = opts[:type] || Nx.type(x)
+      fun.(Nx.shape(x), type)
     end)
   end

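The new `:type` option lets callers pin the dtype of the tensors `fulls_like` creates instead of inheriting it from each parameter. A minimal sketch of the difference, assuming an IEx session with this commit's Axon and Nx available (`Axon.Shared` is an internal module, and the parameter map here is illustrative):

    # A bf16 parameter, standing in for any non-f32 leaf in a params map
    params = %{"kernel" => Nx.iota({2, 2}, type: :bf16)}

    # Without :type, the created tensor inherits bf16 from the parameter
    Axon.Shared.fulls_like(params, 1.0e-7)

    # With type: :f32, the created tensor is f32 regardless of parameter type
    Axon.Shared.fulls_like(params, 1.0e-7, type: :f32)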
8 changes: 4 additions & 4 deletions lib/axon/updates.ex
@@ -219,7 +219,7 @@ defmodule Axon.Updates do
   end

   defnp init_scale_by_rss(params, value) do
-    sum_of_squares = fulls_like(params, value)
+    sum_of_squares = fulls_like(params, value, type: :f32)
     %{sum_of_squares: sum_of_squares}
   end

@@ -278,7 +278,7 @@
   end

   defnp init_scale_by_rms(params, scale) do
-    nu = fulls_like(params, scale)
+    nu = fulls_like(params, scale, type: :f32)
     %{nu: nu}
   end

@@ -395,7 +395,7 @@

   defnp init_scale_by_stddev(params, value) do
     mu = zeros_like(params, type: :f32)
-    nu = fulls_like(params, value)
+    nu = fulls_like(params, value, type: :f32)
     %{mu: mu, nu: nu}
   end

@@ -860,7 +860,7 @@
   end

   defnp init_scale_by_yogi(params, value) do
-    value = fulls_like(params, value)
+    value = fulls_like(params, value, type: :f32)
     mu = value
     nu = value
     count = Nx.tensor(0)
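All four init functions now pass `type: :f32`, so these accumulators are created as f32 instead of inheriting whatever type each leaf of the params container happens to have — the mismatch the commit title describes when dropout is in the model. A sketch of the observable effect, assuming `Axon.Optimizers.adagrad/1` (which is built on `scale_by_rss`, as exercised in the test below) and an illustrative f16 parameter:

    # Hypothetical params map with a lower-precision kernel
    params = %{"dense_0" => %{"kernel" => Nx.iota({4, 4}, type: :f16)}}

    {init_fn, _update_fn} = Axon.Optimizers.adagrad(1.0e-2)
    state = init_fn.(params)

    # Before this commit the sum_of_squares accumulator inherited :f16
    # from the kernel; with this change it is created as {:f, 32}.
    IO.inspect(state)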
2 changes: 1 addition & 1 deletion mix.exs
@@ -2,7 +2,7 @@ defmodule Axon.MixProject do
   use Mix.Project

   @source_url "https://github.com/elixir-nx/axon"
-  @version "0.4.0"
+  @version "0.4.1"

   def project do
     [
54 changes: 54 additions & 0 deletions test/axon/integration_test.exs
@@ -184,4 +184,58 @@ defmodule Axon.IntegrationTest do
       assert_equal(step_state1, step_state2)
     end)
   end
+
+  test "dropout with certain optimizers regression test" do
+    {train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337)
+
+    train =
+      train
+      |> Stream.map(fn {xs, ys} ->
+        {xs, one_hot(ys, num_classes: 2)}
+      end)
+      |> Enum.to_list()
+
+    [{x_test, _}] = Enum.take(train, 1)
+
+    model =
+      Axon.input("input")
+      |> Axon.dense(16)
+      |> Axon.dropout(rate: 0.1)
+      |> Axon.dense(2, activation: :softmax)
+
+    optimizers = [
+      Axon.Optimizers.rmsprop(5.0e-3, centered: true),
+      Axon.Optimizers.rmsprop(5.0e-3, centered: false),
+      :adagrad,
+      :yogi
+    ]
+
+    ExUnit.CaptureIO.capture_io(fn ->
+      for optim <- optimizers do
+        results =
+          model
+          |> Axon.Loop.trainer(:categorical_cross_entropy, optim)
+          # TODO: Fix default output transform
+          |> Map.update(:output_transform, nil, fn _ -> & &1 end)
+          |> Axon.Loop.metric(:accuracy)
+          |> Axon.Loop.validate(model, train)
+          |> Axon.Loop.run(train, %{}, epochs: 10)
+
+        assert %{step_state: %{model_state: model_state}, metrics: %{9 => last_epoch_metrics}} =
+                 results
+
+        eval_results =
+          model
+          |> Axon.Loop.evaluator()
+          |> Axon.Loop.metric(:accuracy)
+          |> Axon.Loop.run(train, model_state)
+
+        assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results
+
+        assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.7)
+        assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"])
+        assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2}
+      end
+    end)
+  end
 end
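To run just this regression test from a checkout of the repository (standard Mix usage, nothing specific to this commit):

    mix test test/axon/integration_test.exs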
