RMS precision
jonatanklosko committed Mar 12, 2024
1 parent f8135eb commit 8e1fc07
Showing 2 changed files with 67 additions and 7 deletions.
68 changes: 63 additions & 5 deletions lib/bumblebee/layers.ex
@@ -1109,6 +1109,31 @@ defmodule Bumblebee.Layers do
 
   @doc """
   Adds an RMS Normalization layer to the network.
 
+  ## Options
+
+    * `:name` - layer name
+
+    * `:initializer` - initializer for the standard deviation parameter.
+      Defaults to `:ones`
+
+    * `:channel_index` - input feature index used for calculating
+      variance. Defaults to `-1`
+
+    * `:epsilon` - numerical stability term
+
+    * `:shift` - numeric shift in the scaling expression. Defaults to
+      `0.0`
+
+    * `:upcast` - adds explicit type casting to make sure the norm
+      is computed in high numerical precision. Either of:
+
+      * `:normalization` (default) - upcasts only the input normalization
+        part
+
+      * `:all` - upcasts both input normalization and the scaling
+        expression
+
   """
   # TODO: Add to Axon
   def rms_norm(input, opts \\ []) do
@@ -1118,33 +1143,66 @@ defmodule Bumblebee.Layers do
         shift: 0.0,
         channel_index: -1,
         epsilon: 1.0e-6,
+        upcast: :normalization,
         initializer: :ones
       ])
 
+    impl =
+      case opts[:upcast] do
+        :normalization ->
+          &rms_norm_impl_upcast_normalization/3
+
+        :all ->
+          &rms_norm_impl_upcast_all/3
+
+        other ->
+          raise ArgumentError,
+                "expected :upcast to be either :all or :normalization, got: #{inspect(other)}"
+      end
+
     weight =
       Axon.param("weight", &Axon.Shape.norm_param(&1, opts[:channel_index]),
         initializer: opts[:initializer]
       )
 
-    Axon.layer(&rms_norm_impl/3, [input, weight],
+    Axon.layer(impl, [input, weight],
       name: opts[:name],
       shift: opts[:shift],
       epsilon: opts[:epsilon],
       op_name: :rms_norm
     )
   end
 
-  defnp rms_norm_impl(input, weight, opts \\ []) do
+  defnp rms_norm_impl_upcast_normalization(input, weight, opts \\ []) do
     opts = keyword!(opts, shift: 0.0, epsilon: 1.0e-6, channel_index: -1, mode: :train)
 
+    normalized_input =
+      input
+      |> Nx.as_type(:f32)
+      |> rms_normalize(opts)
+      |> Nx.as_type(Nx.type(input))
+
+    normalized_input * (opts[:shift] + weight)
+  end
+
+  defnp rms_norm_impl_upcast_all(input, weight, opts \\ []) do
+    opts = keyword!(opts, shift: 0.0, epsilon: 1.0e-6, channel_index: -1, mode: :train)
+
+    input = Nx.as_type(input, :f32)
+    weight = Nx.as_type(weight, :f32)
+
+    normalized_input = rms_normalize(input, opts)
+
+    normalized_input * (opts[:shift] + weight)
+  end
+
+  defnp rms_normalize(input, opts) do
     variance =
       input
       |> Nx.pow(2)
      |> Nx.mean(axes: [opts[:channel_index]], keep_axes: true)
 
-    x = input * Nx.rsqrt(variance + opts[:epsilon])
-
-    x * (opts[:shift] + weight)
+    input * Nx.rsqrt(variance + opts[:epsilon])
   end
 
   @doc """
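Both implementations compute the same expression, rms_normalize(x) * (shift + weight), where rms_normalize(x) = x * rsqrt(mean(x^2) + epsilon); they differ only in whether the f32 upcast wraps just the normalization or the scaling as well. To illustrate why the upcast matters, here is a minimal sketch in plain Nx, outside of Axon and not part of the commit (the tensor values and the f16 input type are illustrative assumptions):

# RMS-normalizing an f16 tensor with and without upcasting. Squaring and
# averaging directly in f16 accumulates rounding error, which is what the
# :upcast option guards against.

input = Nx.tensor([0.1, 0.2, 0.3, 0.4], type: :f16)

# Naive path: the whole reduction stays in f16
naive =
  input
  |> Nx.pow(2)
  |> Nx.mean()
  |> Nx.add(1.0e-6)
  |> Nx.rsqrt()
  |> Nx.multiply(input)

# Upcast path (mirrors upcast: :normalization): normalize in f32, cast back
upcast =
  input
  |> Nx.as_type(:f32)
  |> then(fn x ->
    variance = x |> Nx.pow(2) |> Nx.mean()
    Nx.multiply(x, Nx.rsqrt(Nx.add(variance, 1.0e-6)))
  end)
  |> Nx.as_type(:f16)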
6 changes: 4 additions & 2 deletions lib/bumblebee/text/gemma.ex
@@ -289,7 +289,8 @@ defmodule Bumblebee.Text.Gemma do
       Layers.rms_norm(decoder_outputs.hidden_state,
         name: "output_norm",
         shift: 1.0,
-        epsilon: spec.layer_norm_epsilon
+        epsilon: spec.layer_norm_epsilon,
+        upcast: :all
       )
 
     %{
@@ -339,7 +340,8 @@ defmodule Bumblebee.Text.Gemma do
       num_key_value_heads: spec.num_key_value_heads,
       hidden_size: spec.hidden_size,
       kernel_initializer: kernel_initializer(spec),
-      layer_norm: &Layers.rms_norm(&1, shift: 1.0, name: &2, epsilon: spec.layer_norm_epsilon),
+      layer_norm:
+        &Layers.rms_norm(&1, shift: 1.0, name: &2, epsilon: spec.layer_norm_epsilon, upcast: :all),
       ffn:
         &gated_ffn(&1, spec.intermediate_size, spec.hidden_size,
           name: &2,
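Gemma stores its RMS norm scale zero-centered, so shift: 1.0 restores the 1 + weight form, and with this commit every Gemma norm also requests the full-precision scaling path. A sketch of the resulting call shape against the updated function, not taken from the repository (the input shape, layer name, and epsilon value are illustrative, and Bumblebee.Layers is an internal module):

# Build an Axon graph node with the Gemma-style RMS norm configuration:
# both the normalization and the (shift + weight) scaling run in f32.
input = Axon.input("hidden_state", shape: {nil, nil, 2048})

norm =
  Bumblebee.Layers.rms_norm(input,
    name: "output_norm",
    shift: 1.0,
    epsilon: 1.0e-6,
    upcast: :all
  )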
