TFP JAX: The transition kernel drastically decreases speed.
Dear all,

I am currently learning Bayesian analysis and using `tensorflow_probability.substrates.jax`, but I've run into a performance issue. Running NUTS alone under `jax.jit` is quite fast. However, when NUTS is wrapped in a `TransformedTransitionKernel`, the speed drops drastically. Here's a summary of the time taken:
- TFP GPU: NUTS alone took 118.2952 seconds
- TFP GPU: NUTS + Bijector took 1986.8306 seconds
- TFP GPU: NUTS + DualAveragingStepSizeAdaptation took 141.0955 seconds
- TFP GPU: NUTS + Bijector + DualAveragingStepSizeAdaptation took 2397.5875 seconds
- NumPyro GPU: NUTS + Bijector + DualAveragingStepSizeAdaptation took 180 seconds
I've run speed tests against NumPyro, and essentially, NumPyro with dual averaging step size adaptation and parameter constraints is as fast as tensorflow_probability's NUTS alone.
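For reference, the composition being timed follows the standard TFP pattern. Below is a minimal sketch with a toy target density; the `Gamma` target, `Softplus` bijector, step counts, and `run_chain` helper are illustrative stand-ins, and the actual model, bijectors, and data are in the attached notebook:

```python
import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Toy stand-in for the real model: a log-density over one positive scalar.
def target_log_prob_fn(x):
    return tfd.Gamma(2., 2.).log_prob(x)

nuts = tfp.mcmc.NoUTurnSampler(target_log_prob_fn, step_size=0.1)

# "NUTS + Bijector": sample in unconstrained space via a constraining bijector.
kernel = tfp.mcmc.TransformedTransitionKernel(
    inner_kernel=nuts,
    bijector=tfb.Softplus())

# "+ DualAveragingStepSizeAdaptation": adapt the step size during burn-in.
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    inner_kernel=kernel,
    num_adaptation_steps=800)

@jax.jit
def run_chain(seed):
    return tfp.mcmc.sample_chain(
        num_results=1000,
        num_burnin_steps=1000,
        current_state=jnp.ones([]),
        kernel=kernel,
        trace_fn=None,
        seed=seed)

samples = run_chain(jax.random.PRNGKey(0))
```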
Could there be something I've missed? Is there room for optimization in this process?
Please find the data and code enclosed for reproducibility (the .txt file needs to be renamed to .ipynb): data.csv, gitissue.txt, Google Colab.
Please note that I'm only using the first 100 lines of the data.
Additionally, as a possible clue, I observed a similar slowdown when using the LKJ distribution in other models. (I can post one of them if needed.)
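For context, the correlation prior in those models is set up along these lines; this is a minimal sketch, and the `CholeskyLKJ`/`CorrelationCholesky` pairing, dimension, and concentration are illustrative assumptions rather than the exact code (the real priors are in the notebook):

```python
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Illustrative only: a prior over the Cholesky factor of a 3x3 correlation
# matrix, with the bijector that maps it to unconstrained space when used
# inside a TransformedTransitionKernel.
corr_chol_prior = tfd.CholeskyLKJ(dimension=3, concentration=2.)
unconstraining_bijector = tfb.CorrelationCholesky()
```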
Thank you in advance for your assistance.
Sebastian