
JAX BDF solver tests failing / update [jax] versions (due to scipy.linalg.tril deprecation) #3959

Open
1 of 2 tasks
agriyakhetarpal opened this issue Apr 3, 2024 · 13 comments
Labels: bug (Something isn't working), priority: high (To be resolved as soon as possible), release blocker (Issues that need to be addressed before the creation of a release)

Comments

@agriyakhetarpal
Member

agriyakhetarpal commented Apr 3, 2024

The JAX BDF solver tests are failing on all PRs (#3846, #3945, etc.) for Python 3.9 and later because SciPy v1.13.0 removed some deprecated linear algebra routines, such as scipy.linalg.tril. The Python 3.8 tests still pass because SciPy dropped support for Python 3.8 earlier, so those jobs install an older SciPy that still ships the deprecated routines.

I'm guessing we need to bump the jax and jaxlib versions now, or relax the pin in the requirements, because there have been quite a few releases since v0.4.20; the latest version at the time of writing is v0.4.25.

Checklist

@agriyakhetarpal agriyakhetarpal changed the title JAX BDF solver tests failing / update [jax] versions (due to scipy.linalg.tril deprecation) JAX BDF solver tests failing / update [jax] versions (due to scipy.linalg.tril deprecation) Apr 3, 2024
@agriyakhetarpal agriyakhetarpal self-assigned this Apr 3, 2024
@agriyakhetarpal agriyakhetarpal added bug Something isn't working priority: high To be resolved as soon as possible labels Apr 3, 2024
@agriyakhetarpal
Member Author

It's probably not as simple as bumping the JAX version: there are a few other errors, involving JAX's JIT and spectral volumes, that I don't yet understand. I'm putting this aside for a bit and will return to it soon; others are welcome to proceed in the meantime if they make progress.

@agriyakhetarpal
Member Author

agriyakhetarpal commented Apr 3, 2024

Bumping to v0.4.24 fixes at least some of the tests; earlier versions still hit the SciPy error.

@kratman
Contributor

kratman commented Apr 3, 2024

It is worthwhile to bump jax as high as possible. We have people who are experienced with JAX and might be able to help. We are going to run into more compatibility issues as the code ages.

@agriyakhetarpal
Member Author

I agree. v0.4.26 is their latest release; should we drop the pin altogether? It might break on v0.5.x, so bounds of >0.4, <0.5 are another option.

@kratman
Contributor

kratman commented Apr 3, 2024

Pinning is fine, since it avoids unexpected changes. Realistically, we should have all major dependencies pinned and let something like Dependabot do the updates, so that the failures all show up in one place.

@kratman
Contributor

kratman commented Apr 3, 2024

Do you need help with this one?

@valentinsulzer
Member

> we should have all major dependencies pinned

We shouldn't pin to exact versions, as that may cause compatibility issues for our users (for example, if they try to use PyBaMM alongside another package that happens to pin numpy to a different version). We can specify ranges, but they should be as wide as possible.

@valentinsulzer
Member

jax is an exception: we have to pin the exact version, since every release changes the API.
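To make the two policies concrete (an exact pin for jax, wide ranges for everything else, or the >0.4, <0.5 bounds suggested earlier in the thread), here is a small pure-Python sketch that checks a version string against a specifier. It is a simplified illustration, not PEP 440: the helper names (`parse_version`, `compare`, `satisfies`) and the example version strings are hypothetical, and it only understands plain numeric releases. Real tooling such as `packaging.specifiers.SpecifierSet` also handles pre-releases, post-releases and other edge cases.

```python
from typing import Tuple


# Hypothetical helper: parse a plain numeric release such as "0.4.26" into (0, 4, 26).
# Assumption: pre-releases ("rc1"), post-releases and local versions are NOT handled.
def parse_version(v: str) -> Tuple[int, ...]:
    return tuple(int(part) for part in v.split("."))


def compare(a: Tuple[int, ...], b: Tuple[int, ...]) -> int:
    """Return -1, 0 or 1, zero-padding so that (0, 4) compares equal to (0, 4, 0)."""
    n = max(len(a), len(b))
    a = a + (0,) * (n - len(a))
    b = b + (0,) * (n - len(b))
    return (a > b) - (a < b)


# Each operator maps the result of compare() to pass/fail.
# Two-character operators are listed first so ">=" is not mistaken for ">".
_OPERATORS = (
    (">=", lambda c: c >= 0),
    ("<=", lambda c: c <= 0),
    ("==", lambda c: c == 0),
    ("!=", lambda c: c != 0),
    (">", lambda c: c > 0),
    ("<", lambda c: c < 0),
)


def satisfies(version: str, spec: str) -> bool:
    """Check a version against comma-separated clauses, e.g. ">0.4,<0.5" or "==0.4.20"."""
    v = parse_version(version)
    for clause in spec.split(","):
        clause = clause.strip()
        for op, accept in _OPERATORS:
            if clause.startswith(op):
                bound = parse_version(clause[len(op):].strip())
                if not accept(compare(v, bound)):
                    return False
                break
        else:
            raise ValueError(f"unsupported clause: {clause!r}")
    return True


if __name__ == "__main__":
    exact_pin = "==0.4.20"   # illustrative exact pin
    range_pin = ">0.4,<0.5"  # range suggested in the thread
    for candidate in ("0.4.20", "0.4.26", "0.5.0"):
        print(candidate, satisfies(candidate, exact_pin), satisfies(candidate, range_pin))
    # prints:
    # 0.4.20 True True
    # 0.4.26 False True
    # 0.5.0 False False
```

With an exact pin, every new jax release fails the check until someone edits the pin, which makes each bump deliberate. A range such as >0.4,<0.5 accepts 0.4.26 automatically, but it would also let through any 0.4.x release that happens to change the API.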

@agriyakhetarpal
Member Author

> Do you need help with this one?

I would appreciate that, since I haven't used JAX much. I was able to get the tests to pass with newer versions of JAX (a few failures can probably be ignored, since the solves are likely not being cached properly on my machine). Some spatial methods tests are still failing with IndexErrors, and my debugger doesn't help there.

> We can specify ranges but they should be as wide as possible

To add to this, we have been keeping the lower bounds in sync with the package versions available on conda-forge (an overly strict lower bound caused some trouble around the PyBaMM 23.9 release). Does it make sense to drop Python 3.8 soon, since its tests have only been passing because they still run against deprecated code?

@kratman
Contributor

kratman commented Apr 3, 2024

> It might make sense to drop Python 3.8 soon since it has been passing due to the use of deprecated code?

I was planning on putting up a PR for that this week. It seemed to align with the removal of ODEs and the removal of the JAX Windows restrictions. I will probably just go ahead and make that PR while helping with the JAX work. I should have a bit of time to take a look this afternoon; just share the branch you are working on and I will see what I can do to help.

@agriyakhetarpal
Member Author

I don't have a branch or anything concrete; I was only debugging locally. I'll add the link here once I get back to it.

@valentinsulzer
Member

Yeah, let's follow NumPy's lead on which Python versions we support; they have dropped support for 3.8.

@kratman kratman mentioned this issue Apr 3, 2024
@kratman
Contributor

kratman commented Apr 3, 2024

A few related issues were solved with #3963, #3961, and #3962. I will take another stab at updating JAX in a few days.

@agriyakhetarpal agriyakhetarpal added the release blocker Issues that need to be addressed before the creation of a release label Apr 25, 2024

3 participants