
Implement a batch wrapper for jaxopt #137

Open
BalzaniEdoardo opened this issue Apr 16, 2024 · 2 comments

Comments

@BalzaniEdoardo (Collaborator)

Implement a class that wraps jaxopt solvers and performs batching.

If the user passes an array that is already on the GPU, keep it there; otherwise, transfer it.

If the user passes a lazy-loaded HDF5 array, the current GLM.fit would complain; consider whether we want to be more permissive in the checks and try to call solver.run on lazy-loaded data.

Each epoch and each batch should be compiled using fori_loop or scan.
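To make the intent concrete, here is a minimal NumPy sketch of one epoch of minibatch gradient steps on a quadratic loss. In the actual implementation the Python loop would be replaced by jax.lax.scan or jax.lax.fori_loop so the whole epoch compiles to a single XLA program; the loss and function names here are illustrative, not nemos API:

```python
import numpy as np

def grad_quadratic(params, X, y):
    # Gradient of the mean squared error 0.5 * ||X @ params - y||^2 / n.
    return X.T @ (X @ params - y) / len(y)

def run_epoch(params, X, y, batch_size, stepsize):
    """One pass over the data in minibatches with a fixed stepsize."""
    n = len(y)
    for start in range(0, n, batch_size):
        sl = slice(start, start + batch_size)
        params = params - stepsize * grad_quadratic(params, X[sl], y[sl])
    return params

# Synthetic noiseless regression problem for illustration.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
true_w = np.array([1.0, -2.0, 0.5])
y = X @ true_w
w = np.zeros(3)
for _ in range(200):
    w = run_epoch(w, X, y, batch_size=20, stepsize=0.1)
```

Because the data are noiseless, every batch gradient vanishes at the true weights, so the loop converges to `true_w`.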

See extensions branch for an outline of the batching.

Give the option to pass "batch=True" and "batch_size=default_int" to the regularizer.

How do we determine which default batch size would work?

@BalzaniEdoardo (Collaborator, Author)

Note on line search: jaxopt has a line search method that adapts the stepsize; it must be disabled when batching, because we have no guarantee of finding a loss-reducing direction on any given batch.

The way to go is pretty simple: one can pass a constant value for the stepsize, or a callable that takes the iteration number (an integer) as input and returns a stepsize.
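The two options can be normalized into a single schedule function; a small sketch (the helper name is hypothetical, not jaxopt API):

```python
def make_stepsize_schedule(stepsize):
    """Return a function mapping iteration number -> stepsize.

    Accepts either a constant float (kept fixed across iterations)
    or a callable taking the iteration number.
    """
    if callable(stepsize):
        return stepsize
    return lambda iter_num: stepsize

# Constant stepsize: same value at every iteration.
constant = make_stepsize_schedule(0.05)
# Decaying schedule: stepsize shrinks with the iteration number.
decaying = make_stepsize_schedule(lambda t: 1.0 / (1.0 + t))
```

With this normalization, the batched solver only ever sees a callable and never has to branch on the stepsize type inside the compiled loop.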

@BalzaniEdoardo (Collaborator, Author)

Add to this:

With PR #143, we have now set up an update method for the GLM class. This method can be used to implement batched algorithms.

The next step would be to implement modern stochastic gradient-based methods with convergence guarantees, such as SAG and SAGA; see https://arxiv.org/pdf/2010.00892.

These solvers will have a run method, which will be fully compiled (using a scan call for running through the whole dataset) when solver.run is called. In the GLM this will be equivalent to setting the solver to SAG/SAGA and calling model.fit(X, y).
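As a reference for the update rule, here is a minimal NumPy sketch of the standard SAGA algorithm on a least-squares loss (the jaxopt/nemos integration would wrap an update like this in a compiled update/run method; the function name and loss are illustrative):

```python
import numpy as np

def saga_quadratic(X, y, stepsize, n_steps, seed=0):
    """SAGA on the loss mean_i 0.5 * (x_i @ w - y_i)**2."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    w = np.zeros(d)
    # Table of stored per-sample gradients and their running average.
    grad_table = np.zeros((n, d))
    grad_avg = np.zeros(d)
    for _ in range(n_steps):
        j = rng.integers(n)
        g_new = (X[j] @ w - y[j]) * X[j]  # fresh gradient of sample j
        # Variance-reduced gradient estimate: new minus stored plus average.
        w = w - stepsize * (g_new - grad_table[j] + grad_avg)
        # Keep the table and its average consistent.
        grad_avg = grad_avg + (g_new - grad_table[j]) / n
        grad_table[j] = g_new
    return w

# Noiseless synthetic problem, so SAGA converges to the true weights.
rng = np.random.default_rng(1)
X = rng.normal(size=(50, 2))
true_w = np.array([2.0, -1.0])
w = saga_quadratic(X, X @ true_w, stepsize=0.02, n_steps=8000)
```

Note the memory trade-off: SAGA stores one gradient per sample, which is what makes its variance-reduced updates possible at the cost of O(n·d) extra memory.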

We will also provide an example of how one could fit a larger-than-memory dataset with nemos + pynapple:

  • write a custom batcher that uses pynapple to lazy-load a file from disk, get the batch, and compute the design matrix on the fly
  • use the update method to perform a step of SAG/SAGA
  • write a loop that iterates the algorithm until a convergence criterion is met
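The three steps above could look like the following sketch. Everything here is hypothetical and illustrative: the generator stands in for a pynapple-backed loader (here it just produces synthetic batches instead of reading from disk), and the update function stands in for the model's update method:

```python
import numpy as np

def batcher(batch_size, seed=0):
    """Stand-in for a pynapple-backed loader: yields one batch per call.

    A real batcher would lazy-load a slice from disk and build the
    design matrix on the fly; here we generate synthetic data.
    """
    rng = np.random.default_rng(seed)
    true_w = np.array([1.0, -2.0, 0.5])
    while True:
        X = rng.normal(size=(batch_size, 3))
        yield X, X @ true_w  # noiseless targets for illustration

def update(params, X, y, stepsize=0.1):
    """One gradient step on the batch (stand-in for model.update)."""
    grad = X.T @ (X @ params - y) / len(y)
    return params - stepsize * grad, np.linalg.norm(grad)

params = np.zeros(3)
batches = batcher(batch_size=32)
for _ in range(2000):                  # cap on the number of iterations
    X, y = next(batches)
    params, grad_norm = update(params, X, y)
    if grad_norm < 1e-6:               # convergence criterion
        break
```

Only one batch is ever materialized in memory at a time, which is what makes this pattern work for larger-than-memory datasets.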
