Simulation-based inference in JAX
sbijax implements several algorithms for simulation-based inference in JAX using Haiku, Distrax and BlackJAX. Specifically, sbijax implements
- Sequential Monte Carlo ABC (SMCABC)
- Neural Likelihood Estimation (SNL)
- Surjective Neural Likelihood Estimation (SSNL)
- Neural Posterior Estimation C (SNP)
- Contrastive Neural Ratio Estimation (SNR)
- Neural Approximate Sufficient Statistics (SNASS)
- Neural Approximate Slice Sufficient Statistics (SNASSS)
- Flow matching posterior estimation (SFMPE)
- Consistency model posterior estimation (SCMPE)
where the acronyms in parentheses denote the names of the methods in sbijax.
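Sequential Monte Carlo ABC typically refines approximate Bayesian computation (ABC) by moving a population of particles through a sequence of shrinking tolerances. As a hedged illustration of the underlying idea only, the following sketch implements plain rejection ABC, a simpler non-sequential method, directly in JAX. It is not SMCABC and does not use the sbijax API. The toy prior, the simulator, and the function names `sample_prior`, `simulate` and `rejection_abc` are all hypothetical.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy model (not part of sbijax):
# theta ~ N(0, I_2) and y | theta ~ N(theta, 0.1^2 I_2).
def sample_prior(key, n):
    return jax.random.normal(key, (n, 2))


def simulate(key, theta):
    noise = 0.1 * jax.random.normal(key, theta.shape)
    return theta + noise


def rejection_abc(key, y_obs, n_sims=100_000, n_accept=1_000):
    """Plain rejection ABC: keep the parameters whose simulated data lie closest to y_obs."""
    key_prior, key_sim = jax.random.split(key)
    # Draw parameters from the prior and simulate one data set per draw.
    theta = sample_prior(key_prior, n_sims)
    y_sim = simulate(key_sim, theta)
    # Euclidean distance between each simulated data set and the observation.
    dist = jnp.linalg.norm(y_sim - y_obs, axis=-1)
    # Accept the n_accept draws with the smallest distance; this implicitly
    # sets the tolerance to the n_accept-th smallest distance.
    idx = jnp.argsort(dist)[:n_accept]
    return theta[idx]


y_obs = jnp.array([-1.0, 1.0])
posterior_draws = rejection_abc(jax.random.PRNGKey(0), y_obs)
print(posterior_draws.shape)  # (1000, 2)
print(posterior_draws.mean(axis=0))
```

Fixing the number of accepted draws rather than an absolute tolerance keeps the output shape static. A sequential variant would reuse the accepted draws to propose new parameters under a smaller tolerance instead of always sampling from the prior.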
You can find several self-contained examples of how to use the algorithms in the examples folder.
Documentation can be found here.
Make sure to have a working JAX installation. Depending on whether you want to use CPU, GPU or TPU, please follow the official JAX installation instructions.
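A quick way to check that JAX works, and which backend it picked up, before installing sbijax is a minimal sketch like the following. It uses only standard JAX calls; the function name `add_one` is purely illustrative.

```python
import jax
import jax.numpy as jnp

# Report which backend JAX selected (e.g. "cpu", "gpu" or "tpu") and the visible devices.
print(jax.default_backend())
print(jax.devices())


# Run a tiny jit-compiled computation to confirm the installation actually executes code.
@jax.jit
def add_one(x):
    return x + 1.0


result = add_one(jnp.arange(3.0))
print(result)  # [1. 2. 3.]
```

If the reported backend is "cpu" although you installed an accelerator build, the JAX installation itself, not sbijax, is the thing to fix.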
To install from PyPI, just call the following on the command line:
pip install sbijax
To install the latest GitHub release, use the following, replacing <RELEASE> with the release tag:
pip install git+https://github.com/dirmeier/sbijax@<RELEASE>
Note: 📝 The API of the package is heavily inspired by the excellent PyTorch-based sbi package, which is substantially more feature-complete, more user-friendly, and better documented.
Simon Dirmeier sfyrbnd @ pm me