Computing log_prob for tfd.Sample() with a different number of samples #1792

Open
nick-ponvert opened this issue Feb 24, 2024 · 0 comments

I am interested in constructing a joint distribution (I use JointDistributionCoroutineAutoBatched) for regression modeling that includes the predictors as part of the model specification. I want to do this so that I can use the same model to construct a generative dataset, test inference code, and then pin to real predictors and observations for inference. Here is an example to illustrate what I am trying to do.

In this example, we are regressing kcal_per_gram against neocortex_pct (from Statistical Rethinking chapter 5).

@tfd.JointDistributionCoroutineAutoBatched
def model():
    
    # Generative neocortex_pct
    mu_N = yield tfd.Normal(0, 0.2, name='mu_N')
    sigma_N = yield tfd.Exponential(1, name='sigma_N')
    neocortex_pct = yield tfd.Sample(tfd.Normal(mu_N, sigma_N), sample_shape=20, name='neocortex_pct')
    
    intercept = yield tfd.Normal(0, 0.2, name='intercept')
    beta_N = yield tfd.Normal(0, 0.5, name='beta_N')
    mu = intercept + beta_N * neocortex_pct
    
    sigma = yield tfd.Exponential(1, name='sigma')
    kcal_per_gram = yield tfd.Normal(mu, sigma, name='kcal_per_gram')

This allows constructing a generative dataset just by taking a prior sample. You can even set a beta_N value during the sampling and then test whether your algorithm can recover it (I am using the JAX substrate):

key, prior_sample_key = random.split(key)
prior_samples = model.sample(seed=prior_sample_key, beta_N=2)

The issue arises when I then want to condition this distribution on real data and run the inference algorithm. My preferred way of doing this would be to use experimental_pin():

data_dict = {'neocortex_pct': data_df['neocortex_pct'].values, 'kcal_per_gram': data_df['kcal_per_gram'].values}
model_pinned = model.experimental_pin(data_dict)

Then I use model_pinned.log_prob() along with model_pinned.experimental_default_event_space_bijector() to either do a Laplace approximation or run an MCMC chain. The problem is that if my dataset contains a different number of samples than the sample_shape I used when constructing the JointDistribution, I get broadcasting errors. Ideally, the sample_shape defined in the JointDistribution would only be used for generating data, and inference could then run against however many observations the real dataset contains. Is this something that can be achieved via broadcasting somehow? If not, is there another recommended way to achieve what I'm looking for? Thanks in advance, and let me know if I can provide any more info.
