
Adds Adan Optimizer #410 (Open)
wants to merge 34 commits into main
Conversation

joaogui1 (Contributor) commented Sep 4, 2022

Closes #401
Implementation based on the official code

mtthss (Collaborator) commented Sep 5, 2022

Thank you very much! If the paper's claims stack up, this will be very useful to the JAX community.

By the way, there is PyTorch reference code from the authors of the paper:
https://github.com/sail-sg/Adan

Would you mind loading both the PyTorch and the optax implementations in a colab and showing that they match when applying 5/10 steps with some dummy gradients as inputs? It might highlight subtle differences that are hard to spot from just staring at the code.
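Something along these lines would do (a rough sketch: optax.adan is the alias this PR adds, Adan is the PyTorch class from the reference repo, and the exact constructor arguments on both sides are assumptions):

    import jax.numpy as jnp
    import numpy as np
    import optax
    import torch
    from adan import Adan  # PyTorch reference implementation from sail-sg/Adan

    # Feed both implementations the same dummy gradients for a few steps.
    rng = np.random.default_rng(0)
    dummy_grads = [rng.standard_normal((256, 256)).astype(np.float32) for _ in range(10)]

    # optax side (adan is the alias added in this PR; arguments are assumed).
    params = jnp.zeros((256, 256))
    tx = optax.adan(learning_rate=1e-3)
    state = tx.init(params)
    for g in dummy_grads:
        updates, state = tx.update(jnp.asarray(g), state, params)
        params = optax.apply_updates(params, updates)

    # PyTorch side, driven by the same gradients (constructor args assumed).
    p = torch.zeros((256, 256), requires_grad=True)
    opt = Adan([p], lr=1e-3)
    for g in dummy_grads:
        p.grad = torch.from_numpy(g)
        opt.step()

    # The norm of the difference should stay tiny if the two match.
    print(np.linalg.norm(np.asarray(params) - p.detach().numpy()))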

joaogui1 (Contributor, Author) commented Sep 5, 2022

Hey @mtthss, always happy to help :D
Yeah, I based my code on the official implementation, and here's the colab.
At the moment I have 2 questions:

  1. The current implementation seems to generate slightly different results: for example, when using a 256x256 matrix and doing 1000 updates, the norm of the difference between the PyTorch and JAX versions is 2.4e-5 on CPU. Is that expected, or should I keep looking for more divergences in the implementation?
  2. The official implementation of Adan seems to default to a different type of weight decay, where the params are divided by (1 + lr * wd) instead of multiplied by (1 - lr * wd). This causes a large difference in behavior, but I'm not exactly sure what is meant by the no_prox parameter that controls which kind of weight decay is used (see the sketch after this list).
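To make question 2 concrete, here is a sketch of the two update rules as I read them from the reference code (function names are mine):

    def decoupled_weight_decay(params, update, lr, wd):
        # no_prox=True style: multiply params by (1 - lr * wd), AdamW-like,
        # then subtract the already-scaled update.
        return params * (1.0 - lr * wd) - lr * update

    def proximal_weight_decay(params, update, lr, wd):
        # no_prox=False (the reference default): take the step first,
        # then divide everything by (1 + lr * wd).
        return (params - lr * update) / (1.0 + lr * wd)

For small lr * wd the two agree to first order, which is presumably why the difference only shows up clearly once weight decay is non-zero.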

joaogui1 (Contributor, Author)

I've computed the relative error and it's on the order of 10^-8 (though it still grows as we do more updates). Thoughts? @mtthss

Zach-ER commented Sep 14, 2022

Hi there,
I have been looking at this optimizer as well, and thought I'd chime in!

Firstly, thanks for the work that you have done. This looks to be a credible and nicely-written implementation.

Some notes:

  1. Line 6 of the Adan algorithm says $\boldsymbol{\eta}_k = \eta / \sqrt{\mathbf{n}_k + \epsilon}$; however, in the released code the $\epsilon$ is not within the square root. This is a discrepancy between the paper and the published code, and it is an open question which we should follow here (it was the same in v1 and v2 of the paper). I have submitted an issue, so we can see what they say. We could add an eps_root parameter to let the user set it how they like (see the sketch after this list). Separately, I would change the weight_decay default to 0.02, as in the paper.
  2. Why do you have a mu_dtype but not a delta_dtype? Or, since they are both first-order accumulators, you may want to reuse the datatype.
  3. You mention that there are two different implementations of the weight decay in the released code. I think the no_prox==True condition is the one you have implemented, which fits nicely into an optax.chain in your code. However, in the colab you wrote, weight decay is set to 0., which means there is no difference between the conditions. The implementation which matches the paper is the no_prox==False condition, where the learning rate is harder to factor out into an optax chain.
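To spell out the eps_root idea from point 1 (variable names here are placeholders):

    import jax.numpy as jnp

    lr, eps, eps_root = 1e-3, 1e-8, 0.0
    n_k = jnp.full((4,), 0.25)  # dummy second-moment-style accumulator

    # Paper, Algorithm 1, line 6: epsilon inside the square root.
    step_size_paper = lr / jnp.sqrt(n_k + eps)

    # Released reference code: epsilon outside the square root.
    step_size_code = lr / (jnp.sqrt(n_k) + eps)

    # adam-style compromise: with eps_root defaulting to 0.0 we match the code,
    # while setting eps_root > 0 (and eps = 0) recovers the paper's placement.
    step_size = lr / (jnp.sqrt(n_k + eps_root) + eps)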

To deal with this final issue, we could either:

  • add a no_prox condition that copies their implementation, or
  • implement it as written in the paper (different to yours: _scale_by_learning_rate and transform.add_decayed_weights would not be chained, and the update would happen in one step), or just leave it as it is.

I did some testing myself and it looks like your implementation only really diverges for non-zero weight decay and no_prox==False, as expected.

joaogui1 (Contributor, Author)

Hi @Zach-ER! Thanks for the thorough response, for checking with the authors and running my code!

  1. From the authors' response to your issue it looks like their results were achieved with eps outside the square root, so I think we can leave it as is, wdyt?
  2. Makes sense, I believe I'll reuse it.
  3. Yeah, I'm not sure about the best way to proceed here; maybe I can implement the no_prox condition in the transform and create two aliases, adan and adan_no_prox, wdyt?

Once again thanks for the comments :D

Zach-ER commented Sep 21, 2022

  1. I would prefer leaving the defaults as they are but also having an eps_root argument, defaulting to 0.0. This has the benefit of being a closer match to the adam signature and also letting people implement it as in the paper (if they so choose).
  2. 👍🏻 — just make sure the name is appropriate.
  3. I think we should definitely have a method that matches their default implementation. I would put a condition into the adan optimizer that matches the reference behaviour. This will be slightly fiddly. Do you want to have a try at this? If not, I would be happy to draft something.

Minor nitpicking: the docstring needs fixing for b1 and b2 (and b3 needs adding). Something like:

    b1: Decay rate for the exponentially weighted average of gradients.
    b2: Decay rate for the exponentially weighted average of difference of
      gradients.
    b3: Decay rate for the exponentially weighted average of the squared term.

Zach-ER commented Sep 29, 2022

Hi there,
I have posted a new issue in the original repo here. If they say their experiments were conducted with no_prox=True, then I think we can ignore the other condition: your PR already reproduces their algorithm and fits well style-wise with the rest of the codebase.

Will update this when the authors respond.

Zach-ER commented Sep 30, 2022

OK, the authors have responded.
Their experiments use no_prox=False, the condition that you have not implemented, so I think we do need to implement that one and match their algorithm exactly.

What I would do:

  1. Add a use_proximal_operator boolean argument, defaulting to True, to match what's in the paper (this is apparently what prox is short for).
  2. If False, implement exactly as you've already done.
  3. If True, a slightly less standard implementation is needed (see the sketch after this list).
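Roughly, the final parameter step under the two settings would look like this (my reading of Algorithm 1; names are mine and the accumulator updates for m, v and n are omitted):

    import jax.numpy as jnp

    def adan_param_step(params, m, v, n, lr, b2, wd, eps, use_proximal_operator):
        # m: EWA of gradients, v: EWA of gradient differences,
        # n: EWA of the squared term (matching the docstring above).
        step_size = lr / (jnp.sqrt(n) + eps)   # eps outside the sqrt, as in the reference code
        update = step_size * (m + (1.0 - b2) * v)
        if use_proximal_operator:
            # no_prox=False in the reference code: proximal weight decay,
            # applied in the same step as the update.
            return (params - update) / (1.0 + lr * wd)
        # no_prox=True: decoupled decay, which chains naturally in optax.
        return params * (1.0 - lr * wd) - update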

I wrote a JAX version that matches their implementation, based on your gist and codebase. It is here and matches their results completely with weight decay turned on.

If you could integrate what I've written into your PR, that would be great — if not, I will find some time to do it (but quite busy at the moment).

Again, thanks for your work — this will be a great addition to the library 🎖

joaogui1 (Contributor, Author)

Thank you very much for all the help!
Right now I'm at RIIAA Ecuador, but I'll try to integrate your code as soon as I'm back home, around Sunday or Monday :)

joaogui1 (Contributor, Author)

So it does pass the alias tests, but it looks like Sphinx is erroring now. Any ideas for a quick fix (I'm not experienced with Sphinx)? @Zach-ER

hbq1 (Collaborator) commented Oct 20, 2022

@joaogui1 it should be fixed now, could you update the PR?

joaogui1 (Contributor, Author)

@hbq1 update how? Should I merge main?

hbq1 (Collaborator) commented Oct 20, 2022

Should I merge main?

Yes 👍

@@ -280,6 +288,7 @@ Optax Transforms and States
.. autofunction:: scale
.. autofunction:: scale_by_adam
.. autofunction:: scale_by_adamax
.. autofunction:: scale_by_adan
Collaborator commented on the diff:
Could you also add scale_by_proximal_adan?

the corresponding `GradientTransformation`.
"""
if use_proximal_operator:
return transform.scale_by_proximal_adan(
Collaborator commented on the diff:
Could you add a comment noting that _scale_by_learning_rate is not needed here?

hbq1 (Collaborator) commented Oct 20, 2022

@Zach-ER thanks a lot for your great comments! Since you are an experienced user of this optimiser, I was wondering if the current version looks good to you? :)

Zach-ER commented Oct 21, 2022

Yes, LGTM. Looking forward to trying it out some more 🙌🏻

joaogui1 (Contributor, Author)

Done @hbq1

mkunesch (Member) left a comment

Hi! I am doing an internal review for this PR and have noticed the following potential problem:

The optax transforms are all composable building blocks that can be chained together. While scale_by_proximal_adan can be chained, I think adding any transforms on top of it might in some situations give unexpected results: it calculates new parameters internally and then derives the new updates from them, but this calculation isn't aware of the other transforms a user might add on top of scale_by_proximal_adan.

I think in that sense it might be similar to the lookahead optimizer, which we moved to a separate file (there were other reasons for this too) and whose docstring carries warnings.

Do you agree that this could be a problem or did I miss something? If it is a problem, can we rewrite scale_by_proximal_adan such that it can be chained with further transforms? If we can't rewrite it and it is a problem, I think we should discuss a general way to deal with these cases (optimizers that cannot be chained any further) in the optax API.
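As a toy illustration of what I mean (this is not the PR's actual code, just a stand-in transform with the same structure):

    import jax
    import optax

    def toy_proximal_transform(lr=0.1, wd=0.1):
        def init_fn(params):
            return optax.EmptyState()

        def update_fn(updates, state, params=None):
            # The transform bakes lr and wd in here, computing the new parameters
            # itself and returning (new_params - params) as its "updates".
            new_params = jax.tree_util.tree_map(
                lambda p, g: (p - lr * g) / (1.0 + lr * wd), params, updates)
            new_updates = jax.tree_util.tree_map(
                lambda q, p: q - p, new_params, params)
            return new_updates, state

        return optax.GradientTransformation(init_fn, update_fn)

    # Anything chained after it rescales a step that was already meant to be
    # final, so the applied parameters no longer satisfy the proximal relation
    # computed inside the transform.
    tx = optax.chain(toy_proximal_transform(), optax.scale(0.5))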

Thanks a lot and let me know what you think!

Zach-ER commented Oct 25, 2022

Yes, I think that this could be a problem.

I agree that it is in a similar boat to the lookahead optimizer.

carlosgmartin (Contributor)

Any update on this?

fabianp (Member) commented Mar 21, 2024

@carlosgmartin no updates, this PR is currently orphaned. Do you want to take over?

carlosgmartin (Contributor)

@fabianp What changes, if any, need to be made to @joaogui1's PR?

And what's the consensus on @mkunesch's questions?

fabianp (Member) commented Mar 27, 2024

  1. A first step would be to update with main; currently there are some conflicts with the current head.
  2. I believe the issues highlighted by @mkunesch only apply to the proximal version, scale_by_proximal_adan. I would suggest focusing first on scale_by_adan, which shouldn't have those issues if I understood correctly.
