Implement a class that wraps jaxopt solvers and performs batching.
If the user passes an array that is already on the GPU, keep it there; otherwise, transfer it.
What if the user passes lazy-loaded HDF5 data? The current `GLM.fit` would complain; think about whether we want to be more permissive in the checks and try calling `solver.run` on lazy-loaded data.
Each epoch and each batch should be compiled using `fori_loop` or `scan`.
See the extensions branch for an outline of the batching.
Give the option to pass `batch=True` and `batch_size=default_int` to the regularizer.
How do we choose a default batch size that would work?
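As a sketch of the per-epoch compilation idea, here is a minimal mini-batch gradient descent epoch written as a single `jax.lax.scan` over pre-shaped batches. A least-squares loss stands in for the GLM objective; all names, shapes, and the stepsize are illustrative, not the actual nemos API:

```python
import jax
import jax.numpy as jnp

def sgd_epoch(params, X, y, stepsize=0.1, batch_size=10):
    """One epoch of mini-batch gradient descent, compiled as a single scan."""
    n_batches = X.shape[0] // batch_size
    # reshape into (n_batches, batch_size, ...) so scan can step over batches;
    # any trailing samples that do not fill a batch are dropped here
    Xb = X[: n_batches * batch_size].reshape(n_batches, batch_size, -1)
    yb = y[: n_batches * batch_size].reshape(n_batches, batch_size)

    def loss(p, x_batch, y_batch):
        # placeholder objective: mean squared error on one batch
        return jnp.mean((x_batch @ p - y_batch) ** 2)

    def step(p, batch):
        x_batch, y_batch = batch
        grads = jax.grad(loss)(p, x_batch, y_batch)
        # plain gradient step; no line search, so the stepsize is fixed
        return p - stepsize * grads, None

    params, _ = jax.lax.scan(step, params, (Xb, yb))
    return params
```

Wrapping `sgd_epoch` in `jax.jit` would then compile the whole epoch, which is the behavior the batched solver class should expose.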
Note on line search: jaxopt has a line-search method that adapts the stepsize; it must be disabled when batching, because there is no guarantee of finding a loss-reducing direction on any given batch.
The way to go is pretty simple: one can either pass a fixed value for the stepsize, which is then kept constant, or pass a callable that takes an integer (the iteration number) and returns a stepsize.
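Concretely, a stepsize schedule is just a function of the iteration count. Two illustrative options (the constants here are hypothetical, not recommended defaults):

```python
def constant_stepsize(_iter_num):
    # fixed stepsize, kept constant across iterations
    return 1e-2

def decaying_stepsize(iter_num, base=1e-2, decay=1e-3):
    # classic 1/(1 + decay * t) inverse decay, a common choice
    # for stochastic updates where line search is unavailable
    return base / (1.0 + decay * iter_num)
```

Either a plain float or a callable like these can then be handed to the solver as its stepsize.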
With PR #143, we have now set up an `update` method for the GLM class. This method can be used to batch algorithms.
The next step would be to implement modern stochastic gradient-based methods with convergence guarantees, such as SAG and SAGA; see https://arxiv.org/pdf/2010.00892.
These solvers will have a `run` method, which will be fully compiled (using a `scan` call to run through the whole dataset) when `solver.run` is called. In the GLM this will be equivalent to setting the solver to SAG/SAGA and calling `model.fit(X, y)`.
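For reference, a minimal NumPy sketch of the SAGA update rule itself (not the compiled jaxopt/nemos implementation): keep a table with the last gradient seen for each sample, and correct the fresh stochastic gradient with the stale copy and the table mean:

```python
import numpy as np

def saga_step(w, j, grad_fn, grad_table, grad_avg, stepsize):
    """One SAGA update using sample j.

    grad_table[j] holds the last gradient computed at sample j;
    grad_avg is the running mean of the table.
    """
    g_new = grad_fn(w, j)
    g_old = grad_table[j].copy()
    # variance-reduced direction: fresh gradient, minus its stale
    # copy, plus the running mean of all stored gradients
    w = w - stepsize * (g_new - g_old + grad_avg)
    # keep the table and its mean in sync
    grad_avg = grad_avg + (g_new - g_old) / len(grad_table)
    grad_table[j] = g_new
    return w, grad_avg
```

Unlike plain SGD, the expected update direction equals the full gradient while its variance shrinks as the table fills, which is what gives SAG/SAGA their convergence guarantees with a constant stepsize.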
We will also provide an example of how one could fit a larger-than-memory dataset with nemos + pynapple:

- write a custom batcher that uses pynapple to lazy-load a file from disk, get the batch, and compute the design matrix on the fly;
- use the `update` method to perform a step of SAG/SAGA;
- write a loop that iterates the algorithm until a convergence criterion is met.
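The three steps above could be wired together roughly as follows. `get_batch` and `update` are hypothetical placeholders: in the real example, `get_batch` would lazy-load data and build the design matrix with pynapple, and `update` would be the GLM's SAG/SAGA step.

```python
import numpy as np

def fit_streaming(get_batch, update, params, tol=1e-6, max_iter=10_000):
    """Run batched updates until the parameter change drops below `tol`."""
    for it in range(max_iter):
        X, y = get_batch(it)           # lazy-load one batch from disk
        new_params = update(params, X, y)
        # simple convergence criterion: max absolute parameter change
        if np.max(np.abs(new_params - params)) < tol:
            return new_params, it
        params = new_params
    return params, max_iter
```

Since only one batch is materialized at a time, memory usage is bounded by the batch size rather than the dataset size.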