TFP JAX: The transition kernel drastically decreases speed. #1807

Open
SebastianSosa opened this issue Apr 9, 2024 · 0 comments

Dear all,

I am currently learning Bayesian analysis and using tensorflow_probability.substrates.jax, but I have run into a performance problem. Running NUTS alone under jax.jit is quite fast. However, when I wrap it in a TransformedTransitionKernel (to apply bijectors for constrained parameters), sampling becomes drastically slower. Here is a summary of the timings:

  • TFP GPU: NUTS alone took 118.2952 seconds
  • TFP GPU: NUTS + Bijector took 1986.8306 seconds
  • TFP GPU: NUTS + DualAveragingStepSizeAdaptation took 141.0955 seconds
  • TFP GPU: NUTS + Bijector + DualAveragingStepSizeAdaptation took 2397.5875 seconds
  • NumPyro GPU: NUTS + Bijector + DualAveragingStepSizeAdaptation took 180 seconds

I have also run the same model in NumPyro for comparison: NumPyro with dual averaging step size adaptation and parameter constraints takes roughly as long as TFP's NUTS alone, far less than the equivalent TFP configuration.
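The timings above do not say whether they include compilation. Both TFP-on-JAX under `jax.jit` and NumPyro compile before sampling, and JAX dispatches work asynchronously, so a like-for-like comparison should time a warm call and wait for the result with `block_until_ready()`. Below is a minimal sketch of that pattern; `work` is a hypothetical stand-in for the jitted sampler call (for instance, a jitted function wrapping `tfp.mcmc.sample_chain`), not code from the attached notebook.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def work(x):
    # Hypothetical stand-in computation; replace with the jitted sampler call.
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.arange(1_000_000, dtype=jnp.float32)

# First call traces and compiles; block_until_ready waits for async dispatch.
t0 = time.perf_counter()
work(x).block_until_ready()
compile_and_run = time.perf_counter() - t0

# Same shapes and dtypes, so the cached executable is reused: execution only.
t0 = time.perf_counter()
result = work(x).block_until_ready()
run_only = time.perf_counter() - t0

print(f"first call (compile + run): {compile_and_run:.4f} s")
print(f"second call (run only):     {run_only:.4f} s")
```

Reporting both numbers for each configuration would show whether the slowdown comes from longer compilation or from slower sampling steps.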

Could there be something I've missed? Is there room for optimization in this process?

Please find the data and code attached for reproducibility (rename gitissue.txt to .ipynb to open it as a notebook):
data.csv
gitissue.txt
Google Colab

Please note that I'm only using the first 100 lines of the data.

Additionally, I have observed a similar slowdown when using the LKJ distribution in other models, which may point to a related cause. (I can post one of them if needed.)

Thank you in advance for your assistance.

Sebastian
