Investigate the use of Jax #101

Open
CharlesCossette opened this issue Sep 23, 2023 · 0 comments
@CharlesCossette
Member

Jax could be an interesting framework to use, given its automatic differentiation (autodiff) and just-in-time (JIT) compilation capabilities. I personally thought there was value in starting with a purely numpy-based package because it has a lower barrier to entry (many people in our lab group were familiar with Python, numpy and Lie groups, but not with Jax).

There's potentially a solution that mixes Jax with the current code: Jax would be used for JIT-compiled state/process/measurement model evaluations and autodiff Jacobians, but the model outputs would be converted back to regular numpy arrays so that the existing filter implementations can use them unchanged. While this would not produce end-to-end differentiable or compilable code, it would speed up model evaluations and Jacobian calculations compared to the current finite-differencing approach.
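A minimal sketch of this hybrid idea, assuming a hypothetical planar unicycle process model `unicycle_step(x, u, dt)` with state `[px, py, theta]` and input `[v, omega]` (neither is part of the existing code). The model is written once with `jax.numpy`, `jax.jit` compiles both the model and its forward-mode Jacobian from `jax.jacfwd`, and thin `evaluate`/`jacobian` wrappers hand plain numpy arrays back to the filters. A forward finite-difference Jacobian is included only for comparison.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Use double precision so outputs match the float64 numpy filters
# (Jax defaults to float32).
jax.config.update("jax_enable_x64", True)


def unicycle_step(x, u, dt):
    """Hypothetical process model: x = [px, py, theta], u = [v, omega]."""
    px, py, theta = x[0], x[1], x[2]
    v, omega = u[0], u[1]
    return jnp.array([
        px + dt * v * jnp.cos(theta),
        py + dt * v * jnp.sin(theta),
        theta + dt * omega,
    ])


# Compile the model and its Jacobian w.r.t. the state (argument 0) once.
_evaluate_jit = jax.jit(unicycle_step)
_jacobian_jit = jax.jit(jax.jacfwd(unicycle_step, argnums=0))


def evaluate(x, u, dt):
    # Convert the Jax array back to a plain numpy array for the existing filters.
    return np.asarray(_evaluate_jit(x, u, dt))


def jacobian(x, u, dt):
    return np.asarray(_jacobian_jit(x, u, dt))


def jacobian_fd(x, u, dt, h=1e-6):
    """Forward finite differencing, shown only for comparison."""
    f0 = evaluate(x, u, dt)
    jac = np.zeros((f0.size, x.size))
    for i in range(x.size):
        dx = np.zeros_like(x)
        dx[i] = h
        jac[:, i] = (evaluate(x + dx, u, dt) - f0) / h
    return jac


x = np.array([1.0, 2.0, 0.3])
u = np.array([0.5, 0.1])
dt = 0.1
J_ad = jacobian(x, u, dt)    # exact up to floating point, one compiled call
J_fd = jacobian_fd(x, u, dt)  # one extra model evaluation per state dimension
```

The first call to each jitted function triggers compilation, so any speedup only shows up over repeated calls with the same input shapes and dtypes.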

To start exploring this, I recommend we define new Jax-based abstract classes for states, process models and measurement models, for example `JaxState`, `JaxProcessModel` and `JaxMeasurementModel`. Their default Jacobian implementations could use Jax's autodiff, and the user would be responsible for writing Jax-compatible code in the `evaluate` method.
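One possible shape for these classes, sketched under several assumptions: the method names (`evaluate`, `jacobian`, `plus`, `minus`, `dof`) and the example models `SingleIntegrator` and `RangeToLandmark` are hypothetical rather than the existing API, and any other methods the current base classes require are left out. Each `JaxState` provides a Jax-traceable `plus`/`minus` pair, so the default Jacobians differentiate `f(x.plus(dx)).minus(f(x))` (process) or `h(x.plus(dx))` (measurement) with respect to `dx` at zero using `jax.jacfwd`. For a `VectorState` this reduces to the ordinary Jacobian.

```python
from abc import ABC, abstractmethod

import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


class JaxState(ABC):
    """Hypothetical base class: holds a jnp value; plus/minus must be Jax-traceable."""

    def __init__(self, value):
        self.value = jnp.asarray(value)

    @property
    @abstractmethod
    def dof(self):
        """Number of degrees of freedom of a perturbation dx."""

    @abstractmethod
    def plus(self, dx):
        """Return a new state x ⊕ dx."""

    @abstractmethod
    def minus(self, other):
        """Return the dof-sized vector x ⊖ other."""


class VectorState(JaxState):
    """Euclidean state: ⊕ and ⊖ are ordinary addition and subtraction."""

    @property
    def dof(self):
        return self.value.size

    def plus(self, dx):
        return VectorState(self.value + dx)

    def minus(self, other):
        return self.value - other.value


class JaxProcessModel(ABC):
    @abstractmethod
    def evaluate(self, x, u, dt):
        """User-written, Jax-compatible: return the propagated JaxState."""

    def jacobian(self, x, u, dt):
        # Default: differentiate dx -> f(x ⊕ dx) ⊖ f(x) at dx = 0 (forward-mode AD).
        y0 = self.evaluate(x, u, dt)

        def perturbed(dx):
            return self.evaluate(x.plus(dx), u, dt).minus(y0)

        return np.asarray(jax.jacfwd(perturbed)(jnp.zeros(x.dof)))


class JaxMeasurementModel(ABC):
    @abstractmethod
    def evaluate(self, x):
        """User-written, Jax-compatible: return the predicted measurement (jnp array)."""

    def jacobian(self, x):
        # Default: differentiate dx -> h(x ⊕ dx) at dx = 0.
        def perturbed(dx):
            return self.evaluate(x.plus(dx))

        return np.asarray(jax.jacfwd(perturbed)(jnp.zeros(x.dof)))


# Hypothetical user models: only evaluate() is written, Jacobians come for free.
class SingleIntegrator(JaxProcessModel):
    def evaluate(self, x, u, dt):
        return VectorState(x.value + dt * jnp.asarray(u))


class RangeToLandmark(JaxMeasurementModel):
    def __init__(self, landmark):
        self.landmark = jnp.asarray(landmark)

    def evaluate(self, x):
        return jnp.atleast_1d(jnp.linalg.norm(x.value - self.landmark))


x = VectorState([1.0, 2.0])
F = SingleIntegrator().jacobian(x, [0.5, -0.5], 0.1)  # 2x2 numpy array
G = RangeToLandmark([4.0, 6.0]).jacobian(x)           # 1x2 numpy array
```

A Lie-group state would implement `plus`/`minus` through its exponential and logarithm maps, and the same default Jacobians would then be taken with respect to the group perturbation (left or right, depending on how `plus` is defined). JIT compilation is not attempted in this sketch: `jacobian` builds a new closure on every call, so wrapping it in `jax.jit` as written would likely retrace each time; registering states as pytrees is one possible way around that.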

If you'd like to tackle this, please reach out :)
