
Pytree-based Optimizers #432

Open
cgarciae opened this issue Oct 4, 2022 · 3 comments
Labels
enhancement New feature or request

Comments


cgarciae commented Oct 4, 2022

This topic has been on my mind for a while. It has already been discussed extensively (e.g. #197 (comment)), but I feel it deserves new life because it could resolve the last remaining quirks in optax.

Optax optimizers have a well-defined API and, unlike neural networks, clear rules for how to update their state, which makes them perfectly suited to pytree/dataclass interfaces. Similar to what @NeilGirdhar has done here, one could express a pytree version of every optimizer by wrapping functional optax, with the added benefits:

  1. Optimizers can now pass through JAX's function transformation boundaries, e.g. jit.
  2. Hyper-parameters could be updated via immutable APIs like .replace().
  3. You could get rid of the optimizer vs opt_state separation.
  4. You can now inspect hyper-parameter updates, e.g. logging the learning rate under a schedule.

Example

For this example I'll be using Flax's PyTreeNode, but any pytree implementation would work just as well.

class SGD(PyTreeNode):
  learning_rate: ScalarOrSchedule
  momentum: Optional[float] = None
  nesterov: bool = False
  accumulator_dtype: Optional[Any] = field(pytree_node=False, default=None)
  opt_state: Optional[OptState] = None

  @property
  def tx(self):
    return optax.sgd(**{k: v for k, v in vars(self).items() if k != 'opt_state'})

  def init(self: A, params: Params) -> A:
    return self.replace(opt_state=self.tx.init(params))

  def update(
    self: A, updates: Updates, params: Optional[Params] = None
  ) -> Tuple[Updates, A]:
    updates, opt_state = self.tx.update(updates, self.opt_state, params=params)
    return updates, self.replace(opt_state=opt_state)
    
# sample usage
tx = SGD(3e-4)
tx = tx.init(params)
updates, tx = tx.update(grads)
params = optax.apply_updates(params, updates)
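To make benefits 1 and 2 concrete, here is a minimal, self-contained sketch with no optax/flax dependency: a hand-rolled SGD registered as a pytree via plain JAX, showing that the optimizer crosses a jit boundary like any other argument and that hyper-parameters can be swapped immutably. The class and method names here are illustrative assumptions, not optax API.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class SGD:
    """Toy SGD optimizer that is itself a pytree (illustrative sketch)."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    def tree_flatten(self):
        # The learning rate is a leaf, so it is visible to transformations.
        return (self.learning_rate,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

    def replace(self, learning_rate):
        # Immutable hyper-parameter update: returns a new optimizer.
        return SGD(learning_rate)

    def update(self, grads):
        return jax.tree_util.tree_map(lambda g: -self.learning_rate * g, grads)


@jax.jit  # the optimizer passes through the jit boundary as a regular pytree
def step(tx, params, grads):
    updates = tx.update(grads)
    return jax.tree_util.tree_map(lambda p, u: p + u, params, updates)


tx = SGD(0.1)
params = {"w": jnp.array(1.0)}
grads = {"w": jnp.array(2.0)}
params = step(tx, params, grads)     # w: 1.0 - 0.1 * 2.0 = 0.8
tx = tx.replace(learning_rate=0.01)  # benefit 2: no mutation, new instance
```

With Flax's PyTreeNode, `replace` and the flatten/unflatten boilerplate come for free; the hand-written versions above only stand in for them.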

Proposal

Given that any community shim would probably not succeed, how about an optax.pytree namespace (naming suggestions are welcome) where a shim could live officially and be discussed with the core team?

Contributor

8bitmp3 commented Oct 4, 2022

@mkunesch @rosshemsley @hbq1 let us know what you think 👍

@rosshemsley
Collaborator

@cgarciae Sorry for taking a while to respond! And thanks for sharing your design proposal!

There have been a few projects recently working on attaching functions to custom pytrees in JAX (e.g. equinox), and this proposal seems to have some similar ideas.

For the reasons you mentioned, this factoring can be attractive! It's worth highlighting, though, that there are some downsides to this approach too:

  1. JAX function transformations on classes can be a source of confusion for users. Functions are generally easier to reason about under JAX transformations than class methods.
  2. Many JAX users checkpoint their states using pickle, and custom pytrees cause problems with this. Keeping the state tree as close to 'a dictionary of arrays' as possible has generally been a useful goal for many of the teams currently using optax. Furthermore, any change to the optax API would have to support reloading existing checkpoints in a backwards-compatible way.
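The pickle concern can be sketched concretely. One possible mitigation (an assumption on my part, not an optax-endorsed scheme) is to checkpoint only the flattened leaves, which are plain arrays, and rebuild the structure from a treedef on load:

```python
import pickle

import jax
import jax.numpy as jnp

# Stand-in for an optimizer state; a custom pytree class would flatten the
# same way, which is the point of this sketch.
state = {"mu": jnp.zeros(3), "nu": jnp.ones(3)}
leaves, treedef = jax.tree_util.tree_flatten(state)

# Pickle only the raw arrays ("a dictionary of arrays"), never the class.
blob = pickle.dumps([jax.device_get(x) for x in leaves])

# Reloading needs the treedef, i.e. the code defining the structure -- which
# is exactly where backwards compatibility can break if the class changes.
restored = jax.tree_util.tree_unflatten(treedef, pickle.loads(blob))
```

The fragility rosshemsley points at lives in the second half: old checkpoints only reload if the new class still flattens to the same leaves in the same order.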

We are continuing to polish the optax API, although we are deliberately conservative about the changes we make: part of what makes optax successful is its ruthless simplicity, and forking the API into two sets of alternative factorings would increase the API surface area and could make it harder for us to support.

We're currently working on improving the package factoring, which will hopefully leave optax in a better place for trying out some more experimental ideas (such as this kind of API factoring), but it may be a little while before we would want to introduce big changes like this to the core library.

We'd encourage you to keep thinking about this idea though! Especially with regards to 2) above. Optax has thousands of users at the moment, and so charting a path forwards whilst retaining checkpoint compatibility is probably the biggest barrier we have to making these kinds of changes.

It would also be a good idea to try to "break" this design - e.g. what happens when using more esoteric JAX transforms (such as vmap, pmap, pjit, or grad)? Can you break this design through unexpected jit placement? (As a rule, someone somewhere has already done each of these things to every optax optimizer.)
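One quick probe in the spirit of this suggestion: because the optimizer is an ordinary pytree, jax.vmap can map over a batch of hyper-parameters. This is a self-contained sketch with a hand-rolled SGD pytree (illustrative names, not optax API), not evidence that all the esoteric cases work:

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class SGD:
    """Toy SGD pytree used only to probe vmap behavior."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    def tree_flatten(self):
        return (self.learning_rate,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

    def update(self, grads):
        return jax.tree_util.tree_map(lambda g: -self.learning_rate * g, grads)


def step(tx, params, grads):
    return jax.tree_util.tree_map(lambda p, u: p + u, params, tx.update(grads))


txs = SGD(jnp.array([0.1, 0.01]))  # a *batch* of optimizers in one pytree
params = jnp.array(1.0)
grads = jnp.array(2.0)
# Map over the optimizer's leaves only; params and grads are broadcast.
batched = jax.vmap(step, in_axes=(0, None, None))(txs, params, grads)
# batched == [1 - 0.1 * 2, 1 - 0.01 * 2] == [0.8, 0.98]
```

Transforms that depend on structure rather than leaves (e.g. pjit sharding rules, or jit placement around stateful-looking methods) are where I'd expect the sharper edges to show up.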

@rosshemsley rosshemsley added the enhancement New feature or request label Nov 21, 2022
Contributor

NeilGirdhar commented Feb 10, 2023

Just some thoughts:

JAX function transformations on classes can be a source of confusion for users. Functions are generally easier to reason about under JAX transformations than class methods.

This proposal doesn't change anything for users, since the interface is identical apart from the four benefits mentioned. The reasoning users have to do is exactly the same. If anything, it is simpler, since the sequence interface is not exposed.

It's unfortunate that we didn't reconcile this issue back when it was suggested in the very first Optax issue.

It would also be a good idea to try and "break" this design - e.g. what happens when using more esoteric JAX transforms (such as vmap, pmap, pjit, or grad)

Why don't you try breaking it? I think it might make the benefits more apparent.

I also think it would be good to at least block the sequence interface; relying on it is a misuse of the current optax design. Blocking it would make it easier to improve the design in the future.
