
Diffrax v0.5.0

Released by github-actions on 08 Jan at 23:23.

This is a fun release. :)

Diffrax was the very first project I ever released for the JAX ecosystem. Since then, many new libraries have grown up around it -- most notably jaxtyping, Lineax, and Optimistix.

All of these other libraries actually got their start because I wanted to use them for some purpose in Diffrax!

And with this release... we are now finally doing that. Diffrax now depends on jaxtyping for its type annotations, Lineax for linear solves, and Optimistix for root-finding!

That makes this release mostly just a huge internal refactor, so it shouldn't affect you (as a downstream user) very much at all.

Features

  • Added diffrax.VeryChord, which is a chord-type quasi-Newton method typically used as part of an implicit solver. (This is the most common root-finding method used in implicit differential equation solvers.)
  • Added diffrax.with_stepsize_controller_tols, which marks that a root-finder should inherit its tolerances from the stepsize_controller. For example:

        root_finder = diffrax.with_stepsize_controller_tols(diffrax.VeryChord)()
        solver = diffrax.Kvaerno5(root_finder=root_finder)
        diffrax.diffeqsolve(..., solver=solver, ...)

    This tolerance inheritance is the default for all implicit solvers.
    (Previously, tolerance inheritance was done by passing rtol/atol=None to the nonlinear solver, and that was also the default. Now that Optimistix owns the nonlinear solvers, Diffrax handles tolerance inheritance in a slightly different way.)
  • Added the linear_solver and tags arguments to diffrax.ImplicitAdjoint, as in diffrax.ImplicitAdjoint(linear_solver=..., tags=...). Implicit backpropagation can now use any Lineax solver.
  • Now static-type-checking compatible. No more having your IDE yell at you for incorrect types.
  • Diffrax should now be compatible with JAX_NUMPY_DTYPE_PROMOTION=strict and JAX_NUMPY_RANK_PROMOTION=raise. (These are JAX flags that can be used to disable dtype promotion and broadcasting, to help write more reliable code.)
  • diffrax.{ControlTerm, WeaklyDiagonalControlTerm} now support using a callable as their control, in which case the callable is treated as the evaluate method of an AbstractPath defined over [-inf, inf].
  • Experimental support for complex numbers in explicit solvers. This may still go wrong, so please report bugs / send fixing PRs as you encounter them.
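For readers unfamiliar with chord methods, here is a minimal sketch of the general idea in plain JAX. It is not Diffrax's VeryChord implementation. Newton's method recomputes the Jacobian at every iterate, whereas a chord method evaluates and factorises it once at the initial guess and then reuses it. The names `chord_solve` and `residual` are hypothetical, the residual is an illustrative backward-Euler step for y' = -y**3, and the fixed iteration count stands in for the convergence checks a real solver would use.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve


def chord_solve(fn, x0, num_iters=20):
    """Approximate a root of fn, reusing the Jacobian from x0 at every step."""
    # Evaluate and LU-factorise the Jacobian once, at the initial guess.
    # Reusing this single factorisation is what makes this a chord method.
    lu_and_piv = lu_factor(jax.jacfwd(fn)(x0))

    def step(x, _):
        # Quasi-Newton update: x <- x - J(x0)^{-1} fn(x).
        return x - lu_solve(lu_and_piv, fn(x)), None

    # Fixed iteration count for simplicity (an assumption of this sketch);
    # a real solver would stop once the updates are within tolerance.
    x, _ = jax.lax.scan(step, x0, None, length=num_iters)
    return x


# Hypothetical example: one backward-Euler step of size h for y' = -y**3,
# i.e. find z such that z - y0 - h * f(z) = 0 with f(z) = -z**3.
y0 = jnp.array([1.0, 2.0])
h = 0.1


def residual(z):
    return z - y0 + h * z**3


z = chord_solve(residual, y0)
```

Reusing one factorisation trades a slower (linear rather than quadratic) convergence rate for much cheaper iterations, which is usually a good deal inside implicit solvers, where the initial guess is already close to the root.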
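The JAX_NUMPY_DTYPE_PROMOTION and JAX_NUMPY_RANK_PROMOTION flags can also be set locally with JAX's `jax.numpy_dtype_promotion` and `jax.numpy_rank_promotion` context managers. The sketch below, in plain JAX, shows what each setting rejects. The helper `raises` and the example arrays are hypothetical; it relies on both kinds of error being raised as ValueError (or a subclass of it).

```python
import jax
import jax.numpy as jnp

x32 = jnp.ones(3, dtype=jnp.float32)
i32 = jnp.ones(3, dtype=jnp.int32)


def raises(fn):
    """Hypothetical helper: True if calling fn() raises a ValueError."""
    try:
        fn()
    except ValueError:
        return True
    return False


with jax.numpy_dtype_promotion("strict"):
    # Mixing float32 and int32 no longer promotes implicitly...
    dtype_mix_rejected = raises(lambda: x32 + i32)
    # ...but an explicit cast is fine.
    explicit_ok = x32 + i32.astype(jnp.float32)

with jax.numpy_rank_promotion("raise"):
    # Adding a rank-1 array to a rank-2 array is rejected...
    rank_mix_rejected = raises(lambda: x32 + jnp.ones((2, 3)))
    # ...while size-1 broadcasting between equal ranks is still allowed.
    same_rank_ok = x32[None, :] + jnp.ones((2, 3))
```

Running a Diffrax solve inside these contexts (or with the environment variables set) is a quick way to check that your own vector fields are also free of implicit promotion.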

Breaking changes

  • diffrax.{AbstractNonlinearSolver, NewtonNonlinearSolver, NonlinearSolution} have been removed in favour of using Optimistix. If you were using these explicitly, e.g. Kvaerno5(nonlinear_solver=NewtonNonlinearSolver(...)), then the equivalent behaviour is now given by Kvaerno5(root_finder=VeryChord(...)). Any other Optimistix root-finder can be used as well.
  • The result of a solve is now an Enumeration rather than a plain integer. For example, this means that you should write something like jnp.where(sol.result == diffrax.RESULTS.successful, ...), not jnp.where(sol.result == 0, ...).
  • A great many modules have been renamed from foo.py to _foo.py to explicitly indicate that they're private. Make sure to access features via the public API.
  • Removed the AbstractStepSizeController.wrap_solver method.

Bugfixes

  • Crash fix when using an implicit solver together with DirectAdjoint.
  • Crash fix when using dt0=None, stepsize_controller=diffrax.PIDController(...) with SDEs.
  • Crash fix when using adjoint=BacksolveAdjoint(...) with VirtualBrownianTree with jax.disable_jit on the TPU backend.


Full Changelog: v0.4.1...v0.5.0