
Add example of PyMC usage #41

Open
twiecki opened this issue Feb 14, 2022 · 11 comments

@twiecki commented Feb 14, 2022

PyMC v4 has a JAX backend and can use samplers like those from numpyro or blackjax, so it should be fairly easy to add an example of how to use SGMCMCJax with a PyMC model.

https://github.com/pymc-devs/pymc/blob/main/pymc/sampling_jax.py#L141
https://github.com/blackjax-devs/blackjax/blob/main/examples/use_with_pymc3.ipynb

@jeremiecoullon (Owner)

This is a great idea! It's related to this issue in NumPyro.

However, the thing to figure out here is how to convert PyMC3's log-posterior into a log-prior and a log-likelihood that take in a single data point as well as the model parameters. That way the gradient estimators in sgmcmcjax can be used.

In the sample_numpyro_nuts function the log-posterior is obtained and then passed into the numpyro sampler. Is the data "baked into" this log-posterior function, or is there a way to define the log-posterior with the data still as an argument?
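
For concreteness, here is a minimal sketch of the two functions sgmcmcjax expects, with toy Gaussian densities as placeholders (in the spirit of the quickstart example in the sgmcmcjax docs):

import jax.numpy as jnp

# log-likelihood of a *single* data point x given parameters theta;
# sgmcmcjax handles batching this over minibatches
def loglikelihood(theta, x):
    return -0.5 * jnp.dot(x - theta, x - theta)

# the log-prior is a function of the parameters only
def logprior(theta):
    return -0.5 * 0.01 * jnp.dot(theta, theta)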

@ricardoV94 commented Feb 24, 2022

You can obtain the "log-prior" graph via model.varlogpt and the "log-likelihood" graph via model.datalogpt (instead of just using model.logpt() in this line: https://github.com/pymc-devs/pymc/blob/47503bf6bbe1d5617e1a0e089c05a48540c14a1d/pymc/sampling_jax.py#L100). The latter has the data "baked" in. Is this sufficient, or do you still need to be able to evaluate the "log-likelihood" for subsets of the data at a time?
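
For example, a minimal sketch with a toy model (compile_fn evaluates a graph given a point dict):

import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x")
    y = pm.Normal("y", x, observed=[0, 1, 2, 3])

# "log-prior": joint log-density of the free variables
logprior_fn = model.compile_fn(model.varlogpt)
# "log-likelihood": log-density of the observed variables, data baked in
loglik_fn = model.compile_fn(model.datalogpt)

print(logprior_fn({"x": 0.0}), loglik_fn({"x": 0.0}))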

@jeremiecoullon (Owner) commented Feb 24, 2022

@ricardoV94: the log-prior graph sounds like the correct thing, yes.

However, the log-likelihood would need to be a function without the data "baked in", i.e. a function like log_likelihood(params, data).

The algorithms in sgmcmcjax use this function to evaluate the log-likelihood for only a subset of the data at a time, as you mentioned. You can see how this happens in this page of the docs, in cell number 5: mygrad = grad_log_post(get_params(state), *data) # use all the data. In that toy example all the data is passed in, but usually a subset would be passed instead.

Note that this grad_log_post function was built in cell number 4 from the log-prior and the log-likelihood: grad_log_post = build_grad_log_post(loglikelihood, logprior, data).
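
Roughly, build_grad_log_post does the following (a simplified sketch for a dataset consisting of a single array, not the library's actual implementation):

import jax
import jax.numpy as jnp

def build_grad_log_post_sketch(loglikelihood, logprior, data):
    # data: tuple of arrays whose leading dimension is the full dataset size N
    N = data[0].shape[0]
    # evaluate the single-point log-likelihood across a batch, parameters fixed
    batch_loglik = jax.vmap(loglikelihood, in_axes=(None, 0))

    def log_post(theta, x):
        # unbiased estimate of the full-data log-likelihood:
        # rescale the minibatch sum by N / batch_size
        return logprior(theta) + (N / x.shape[0]) * jnp.sum(batch_loglik(theta, x))

    return jax.grad(log_post)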

@ricardoV94 commented Feb 24, 2022

How do you make sure the data is partitioned correctly? If I have a vectorized Normal likelihood with mu.shape == (N,) and observed_data.shape == (M, N), you can pass one row at a time; but if you have observed_data.shape == (N,), you can't partition the likelihood, or can you?

Anyway, to have the data as an input, you can do something more involved like:

import pymc as pm
from pymc.sampling_jax import get_jaxified_graph

with pm.Model() as model:
    x = pm.Normal("x")
    y = pm.Normal("y", x, observed=[0, 1, 2, 3])

print(model.compile_fn(model.datalogpt)({"x": 0}))

# Replace each observed value variable with a fresh input variable of the
# same type, so the data becomes an explicit argument instead of a constant
original_data = []
dummy_data_inputs = []
for observed_RV in model.observed_RVs:
    data = model.rvs_to_values[observed_RV]
    dummy_data_input = data.type()
    # TODO: You should revert these inplace changes after you're done
    model.rvs_to_values[observed_RV] = dummy_data_input
    original_data.append(data.data)
    dummy_data_inputs.append(dummy_data_input)

loglike_fn = get_jaxified_graph(
    inputs=model.value_vars + dummy_data_inputs,
    outputs=[model.datalogpt],
)
print(
    loglike_fn(0, original_data[0][:1]),
    loglike_fn(0, original_data[0][:2]),
    loglike_fn(0, original_data[0][:3]),
    loglike_fn(0, original_data[0][:4]),
    sep="\n",
)

Output:

-10.67575413281869
(DeviceArray(-0.91893853, dtype=float64),)
(DeviceArray(-2.33787707, dtype=float64),)
(DeviceArray(-5.2568156, dtype=float64),)
(DeviceArray(-10.67575413, dtype=float64),)

@ricardoV94

PyMC v4 has a JAX backend and can use samplers like those from numpyro or blackjax, it should be pretty easy thus to add an example of how to use SGMCMCJax with a PyMC model.

https://github.com/pymc-devs/pymc/blob/main/pymc/sampling_jax.py#L141 https://github.com/blackjax-devs/blackjax/blob/main/examples/use_with_pymc3.ipynb

Your second link should be: https://github.com/blackjax-devs/blackjax/blob/main/examples/use_with_pymc.ipynb

@jeremiecoullon (Owner) commented Feb 24, 2022

How do you make sure data is partitioned correctly?

I don't quite understand what this means; could you explain some more, please?

In your example mu.shape == (N,), is N the dimensionality of the mean parameter mu? So in observed_data.shape == (M, N), M is the number of observations? But if that's the case, then observed_data.shape == (N,) corresponds to only having a single data point?

In case this is what you were asking about: the standard way to estimate the log-likelihood for these sgmcmc algorithms is to use equation (4) in this paper, which estimates the sum of potentials \sum_{i=1}^N U_i(\theta) by (N/n) \sum_{i \in S} U_i(\theta) for a random minibatch S of size n. Note that in this equation U_i(\theta) is the potential for a single data point given a parameter \theta.

This is implemented in this library here; note that batch_loglik batches over the data, with the parameters fixed.

@ricardoV94 commented Feb 24, 2022

@jeremiecoullon thanks for the reply. The idea of splitting the data in the likelihood just seemed surprising to me.

For instance, if you have a linear regression, each "datapoint" includes multiple predictors plus the observation(s). Or if you have a multivariate likelihood, you may have several observations per "datapoint". I was just curious how you handle those cases.

Anyway let me know if the snippet above is sufficient to make an example with PyMC :)

@jeremiecoullon (Owner)
Splitting data in the likelihood: the approach is exactly the same as in stochastic gradient descent (and related algorithms like Adam, RMSProp, etc.).

For instance if you have a linear regression, each "datapoint" includes multiple predictors + observation(s). Or if you have multivariate likelihood you may have several observations per "datapoint". I was just curious how did you guys handle those cases.

Are you asking what happens when the data is high dimensional? And what we do in the case of supervised learning?

To give an example, consider a dataset of 5 points: D = {(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x5, y5)}. This is a supervised problem, and x_i may be high dimensional.

A minibatch of data might be D_minibatch = {(x2, y2), (x5, y5)}, or D_minibatch = {(x1, y1), (x4, y4), (x5, y5)}. The likelihood for this smaller dataset is faster to evaluate than the likelihood for the entire dataset.

So you don’t “split” the dimensionality of x_i, and you always keep x_i and y_i together. Is this what you meant?
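
In code, only the index set is subsampled (an illustrative sketch):

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 10))  # 5 data points; each x_i is 10-dimensional
y = rng.normal(size=(5,))

# subsample indices: each (x_i, y_i) pair stays together
idx = rng.choice(5, size=2, replace=False)
X_minibatch, y_minibatch = X[idx], y[idx]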

Snippet: I'll have a look at this to understand it and see if that works!

@ricardoV94
So you don’t “split” the dimensionality of x_i, and you always keep x_i and y_i together. Is this what you meant?

Yes, that's what I meant! Thanks for clarifying.

@jeremiecoullon (Owner)
I added a notebook with a basic Gaussian model example working.

Some questions:

  • In the PyMC sampling cell (cell number 5), the sampler takes much longer than the progress bar suggests. Is something like block_until_ready() needed in there to benchmark this correctly? I simply print out the last sample to force the cell to wait until all the computations are done (see the sketch after this list).
  • I would like to add an ESS/sec calculation, as this would give a better comparison than running time alone. How do you recommend I do this? ArviZ? I usually use emcee's autocorr module, but using PyMC-related tools would make more sense.
  • I was also thinking of using PMF instead of a Gaussian example, as this would be more realistic, and the PyMC NUTS sampler takes an hour to run.
    However, when I tried to jaxify the model I got NotImplementedError: No JAX conversion for the given Op: SolveTriangular{lower=True, trans=0, unit_diagonal=False, check_finite=True}. Is it easy to add this conversion to PyMC v4?
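
For the first point, this is the kind of check I mean (a toy timing example; the ESS part is sketched in comments since it depends on how the trace is stored):

import time
import jax
import jax.numpy as jnp

# JAX dispatches work asynchronously, so without block_until_ready() the timer
# would only measure dispatch, not the computation itself
x = jax.random.normal(jax.random.PRNGKey(0), (2000, 2000))
start = time.time()
y = jnp.dot(x, x)
y.block_until_ready()
elapsed = time.time() - start
print(f"matmul took {elapsed:.3f}s")

# For ESS/sec, something like this with ArviZ (assuming `idata` is an
# InferenceData holding the trace):
# import arviz as az
# ess_per_sec = az.ess(idata) / elapsed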

@twiecki (Author) commented Mar 3, 2022

This is great @jeremiecoullon, thanks for adding that! As for SolveTriangular, we still need to add a JAX implementation for it.
