Remove gradient accumulation for now
seanmor5 committed May 14, 2024
1 parent f3d5bf9 commit 4851084
Showing 3 changed files with 54 additions and 28 deletions.
6 changes: 3 additions & 3 deletions examples/generative/text_generator.exs
@@ -1,8 +1,8 @@
# Based on https://machinelearningmastery.com/text-generation-lstm-recurrent-neural-networks-python-keras/
Mix.install([
{:axon, "~> 0.5"},
{:nx, "~> 0.5"},
{:exla, "~> 0.5"},
{:axon, path: "/Users/sean/projects/axon"},
{:nx, "~> 0.7"},
{:exla, "~> 0.7"},
{:req, "~> 0.3.3"}
])
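
The new dependency block points Axon at a local checkout (/Users/sean/projects/axon), which only resolves on the committer's machine. For anyone else running the example, a hedged alternative is to track the upstream repository instead; the github: dependency form is standard Mix.install usage, and elixir-nx/axon is the project's public repository:

Mix.install([
  # Assumption: no local Axon checkout, so pull the library from GitHub instead
  {:axon, github: "elixir-nx/axon"},
  {:nx, "~> 0.7"},
  {:exla, "~> 0.7"},
  {:req, "~> 0.3.3"}
])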

53 changes: 46 additions & 7 deletions lib/axon.ex
@@ -339,11 +339,26 @@ defmodule Axon do
name = name(op_name, name)

id = System.unique_integer([:positive, :monotonic])
axon_node = make_node(id, op, name, op_name, mode, inputs, params, args, meta, opts, global_options)

axon_node =
make_node(id, op, name, op_name, mode, inputs, params, args, meta, opts, global_options)

%Axon{output: id, nodes: Map.put(updated_nodes, id, axon_node)}
end

defp make_node(id, op, name, op_name, mode, inputs, params, args, meta, layer_opts, global_options) do
defp make_node(
id,
op,
name,
op_name,
mode,
inputs,
params,
args,
meta,
layer_opts,
global_options
) do
{:current_stacktrace, [_process_info, _axon_layer | stacktrace]} =
Process.info(self(), :current_stacktrace)

@@ -469,7 +484,14 @@ defmodule Axon do
input_shape = opts[:shape]

output_shape = input_shape && Axon.Shape.input(input_shape)
layer(:input, [], name: name, shape: output_shape, meta: meta, op_name: :input, optional: optional)

layer(:input, [],
name: name,
shape: output_shape,
meta: meta,
op_name: :input,
optional: optional
)
end

@doc """
@@ -559,7 +581,12 @@ defmodule Axon do
def constant(number, opts) when is_number(number) do
opts = Keyword.validate!(opts, [:name, :meta])

layer(:constant, [], name: opts[:name], meta: opts[:meta], value: Nx.tensor(number), op_name: :constant)
layer(:constant, [],
name: opts[:name],
meta: opts[:meta],
value: Nx.tensor(number),
op_name: :constant
)
end

def constant(value, _) do
@@ -2137,7 +2164,9 @@ defmodule Axon do
"""
@doc type: :shape
def resize(%Axon{} = x, resize_shape, opts \\ []) do
opts = Keyword.validate!(opts, [:name, :meta, method: :nearest, antialias: true, channels: :last])
opts =
Keyword.validate!(opts, [:name, :meta, method: :nearest, antialias: true, channels: :last])

channels = opts[:channels]

layer(:resize, [x],
@@ -2384,7 +2413,12 @@ defmodule Axon do
Nx.equal(Nx.as_type(x, :s64), opts[:eos_token])
end

layer(fun, [input], eos_token: eos_token, op_name: :mask, meta: opts[:meta], name: opts[:name])
layer(fun, [input],
eos_token: eos_token,
op_name: :mask,
meta: opts[:meta],
name: opts[:name]
)
end

@doc """
@@ -3163,7 +3197,12 @@ defmodule Axon do
def stack_columns(%Axon{} = x, opts \\ []) do
opts = Keyword.validate!(opts, [:name, ignore: []])

layer(:stack_columns, [x], meta: opts[:meta], name: opts[:name], ignore: opts[:ignore], op_name: :stack_columns)
layer(:stack_columns, [x],
meta: opts[:meta],
name: opts[:name],
ignore: opts[:ignore],
op_name: :stack_columns
)
end

@doc """
23 changes: 5 additions & 18 deletions lib/axon/loop.ex
@@ -322,16 +322,11 @@ defmodule Axon.Loop do
* `:loss_scale` - type of loss-scaling to use, if any. Loss-scaling is necessary when
doing mixed precision training for numerical stability. Defaults to `:identity` or
no loss-scaling.
* `:gradient_accumulation_steps` - number of gradient accumulation steps to take during
training. Gradient accumulation decreases the number of updates by accumulating gradients
between steps, increasing the effective batch size on smaller devices. Defaults to 1.
"""
def train_step(model, loss, optimizer, opts \\ []) do
opts = Keyword.validate!(opts, [:seed, loss_scale: :identity, gradient_accumulation_steps: 1])
opts = Keyword.validate!(opts, [:seed, loss_scale: :identity])

loss_scale = opts[:loss_scale] || :identity
gradient_accumulation_steps = opts[:gradient_accumulation_steps] || 1

{init_model_fn, forward_model_fn} = build_model_fns(model, :train, opts)
loss_fn = build_loss_fn(loss)
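
With accumulation gone, train_step/4 validates only :seed and :loss_scale, and (as the next hunk shows) the per-step loss is no longer divided by an accumulation count. A minimal sketch of a call against the reduced option set; the model, loss atom, and :adam optimizer below are illustrative placeholders, not taken from this diff:

model = Axon.input("features", shape: {nil, 8}) |> Axon.dense(1)

# Only :seed and :loss_scale remain valid; passing :gradient_accumulation_steps
# now raises, because Keyword.validate!/2 rejects unknown keys.
{init_fn, step_fn} =
  Axon.Loop.train_step(model, :mean_squared_error, :adam, loss_scale: :identity)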
@@ -377,12 +372,8 @@ defmodule Axon.Loop do
tar
|> loss_fn.(model_out.prediction)
|> then(fn loss ->
scaled =
loss
|> scale_loss.(loss_scale_state)
|> Nx.divide(gradient_accumulation_steps)

{scaled, Nx.divide(loss, gradient_accumulation_steps)}
scaled = scale_loss.(loss, loss_scale_state)
{scaled, loss}
end)

{model_out, scaled_loss, unscaled_loss}
@@ -665,17 +656,13 @@ defmodule Axon.Loop do
* `:loss_scale` - type of loss-scaling to use, if any. Loss-scaling is necessary when
doing mixed precision training for numerical stability. Defaults to `:identity` or
no loss-scaling.
* `:gradient_accumulation_steps` - number of gradient accumulation steps to take during
training. Gradient accumulation decreases the number of updates by accumulating gradients
between steps, increasing the effective batch size on smaller devices. Defaults to 1.
"""
def trainer(model, loss, optimizer, opts \\ []) do
opts = Keyword.validate!(opts, [:seed, :loss_scale, :gradient_accumulation_steps, log: 50])
opts = Keyword.validate!(opts, [:seed, :loss_scale, log: 50])

# Build loss now so we can use it as a metric
loss_fn = build_loss_fn(loss)
step_opts = Keyword.take(opts, [:gradient_accumulation_steps, :loss_scale, :seed])
step_opts = Keyword.take(opts, [:loss_scale, :seed])
{init_fn, step_fn} = train_step(model, loss_fn, optimizer, step_opts)

log_interval = opts[:log] || 50
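
The same option disappears from trainer/4, so existing training loops that passed :gradient_accumulation_steps must drop it. A hedged before/after sketch, where the model, optimizer, and train_data stream are placeholders:

# Before this commit (no longer accepted; Keyword.validate!/2 raises on the unknown key):
# loop = Axon.Loop.trainer(model, :mean_squared_error, :adam, gradient_accumulation_steps: 4)

# After this commit:
loop = Axon.Loop.trainer(model, :mean_squared_error, :adam, loss_scale: :identity, log: 50)
Axon.Loop.run(loop, train_data, %{}, epochs: 5, compiler: EXLA)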
