
[Feature Request] Normalized gradient descent #594

Open
smorad opened this issue Oct 9, 2023 · 6 comments
Labels
enhancement New feature or request

Comments

@smorad

smorad commented Oct 9, 2023

Optax has various clipping operators, but as far as I can tell, it cannot scale updates by their gradient norm. Adding this capability as a chainable transformation would allow us to use normalized gradient descent methods (e.g. normalized Adam).

A simple implementation might look like:

import jax
import jax.numpy as jnp
import optax

def scale_by_norm(scale: float = 1.0, eps: float = 1e-6) -> optax.GradientTransformation:
  def init_fn(params):
    del params
    return optax.EmptyState()

  def update_fn(updates, state, params=None):
    del params
    # Bound the divisor below by `scale` so tiny gradients are not blown up.
    g_norm = jnp.maximum(optax.global_norm(updates) + eps, scale)

    def scale_fn(t):
      return t / g_norm

    updates = jax.tree_util.tree_map(scale_fn, updates)
    return updates, state

  return optax.GradientTransformation(init_fn, update_fn)
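
Chained with existing optax transforms, this would give one possible normalized-Adam-style optimizer. A minimal usage sketch, assuming the scale_by_norm above (not an existing optax API):

# Normalize raw gradients before feeding them to Adam, then apply a learning rate.
optimizer = optax.chain(
    scale_by_norm(scale=1.0),
    optax.scale_by_adam(),
    optax.scale(-1e-3),
)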
@mtthss
Collaborator

mtthss commented Oct 10, 2023

Do you have a reference to this specific way of normalising?

@smorad
Author

smorad commented Oct 10, 2023

This textbook describes it fairly well. My example is a little fancier than needed; you could replace the maximum with

g_norm = (optax.global_norm(updates) + eps) / scale

In this case, scale would refer to alpha in Eq 6.
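
As a small worked example of that variant (the numbers are illustrative):

import jax
import jax.numpy as jnp
import optax

# With scale (alpha in Eq 6) = 0.1, the update is rescaled to global norm ~0.1,
# regardless of the raw gradient magnitude.
grads = {'w': jnp.array([3.0, 4.0])}               # global norm = 5
g_norm = (optax.global_norm(grads) + 1e-6) / 0.1
scaled = jax.tree_util.tree_map(lambda t: t / g_norm, grads)
# scaled['w'] is approximately [0.06, 0.08], i.e. global norm ~ 0.1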

@mtthss
Collaborator

mtthss commented Oct 10, 2023

Sounds like it could be a good addition. Do you want to put together a PR?

@SauravMaheshkar
Contributor

Seems like a simple extension of

def normalize() -> base.GradientTransformation:
  """Normalizes the gradient.

  Returns:
    An (init_fn, update_fn) tuple.
  """

  def init_fn(params):
    del params
    return NormalizeState()

  def update_fn(updates, state, params=None):
    del params
    g_norm = utils.global_norm(updates)
    updates = jax.tree_map(lambda g: g / g_norm, updates)
    return updates, state

  return base.GradientTransformation(init_fn, update_fn)

@mtthss can I take this up?

@smorad
Author

smorad commented Nov 25, 2023

I think this might actually be implemented in clip_by_global_norm. IIRC the code there actually scales the gradient rather than clips it. Might be worth double checking before starting.

@fabianp added the enhancement (New feature or request) label Dec 3, 2023
@vroulet
Collaborator

vroulet commented Feb 5, 2024

clip_by_global_norm clips but does not necessarily normalize (if the updates' global norm is below the clip norm, they are returned as is). In other words, clipping projects onto a ball, whereas @smorad wants to project onto a sphere.
I think @SauravMaheshkar pointed out a good starting point.
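
A small illustration of the difference (the values are arbitrary, and jax.tree_util.tree_map stands in for a proper transform):

import jax
import jax.numpy as jnp
import optax

grads = {'w': jnp.array([0.3, 0.4])}   # global norm = 0.5, below the clip threshold

# Projection onto the ball: updates already inside the ball are left untouched.
clip = optax.clip_by_global_norm(max_norm=1.0)
clipped, _ = clip.update(grads, clip.init(grads))
# clipped['w'] == [0.3, 0.4]

# Projection onto the sphere: updates are always rescaled, here to unit norm.
g_norm = optax.global_norm(grads)
normalized = jax.tree_util.tree_map(lambda g: g / g_norm, grads)
# normalized['w'] == [0.6, 0.8]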
