Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Bayesian regression example from NumPyro in Pyro #3006

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from

Conversation

arijeetchatterjee
Copy link

This PR contributes the implementation of a Bayesian regression example / tutorial from NumPyro to Pyro.
There is one issue that I am not able to address yet - I have to run the below cell before running the cells for Model 2 and Model 3. I did not face this issue with the NumPyro tutorial.

# Run NUTS
kernel = NUTS(model)
num_samples = 2000
mcmc = MCMC(kernel, num_samples=num_samples, warmup_steps=200)

Please let me know what you think @eb8680

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@eb8680
Copy link
Member

eb8680 commented Jan 20, 2022

@arijc76 sorry I am so slow this week, great work! Here are some minor comments:

Notebook

  • Could you use the same model function for SVI and MCMC, instead of having a separate model_svi?
  • Could you add a bit of descriptive text in the section on SVI? Perhaps a sentence or two before each of cell 20, 21 and 22 explaining what each cell is about to do and what the results mean
  • Typo: "For this we one of" -> "For this we use one of"
  • Could you use the smoke_test parameter to set num_samples = 2 and num_warmup = 2, like you did with num_iter in cell 20?

Rendering

  • Could you add your notebook's filename to tutorial/source/index.rst under the "Other inference algorithms" header?
  • Could you run cd tutorial && make html locally and check that everything in the notebook (e.g. math, images) is rendered correctly on your machine? The generated HTML files should appear in tutorial/build/html/ - start a temporary local HTTP server in the tutorial directory (e.g. via python -m http.server), open index.html in a browser and check that your example appears under the "other inference algorithms" header on the sidebar, then click the link to make sure it works and scroll through the generated HTML of your notebook and look for obvious visual errors

Testing

Your notebook is being executed correctly during CI (modulo reducing num_samples when smoke_test == True as requested above), but it is currently failing with the following error:

=================================== FAILURES ===================================
_______ tutorial/source/bayesian_regression_mcmc_and_svi.ipynb::Cell 21 ________
Notebook cell execution failed
Cell 21: Cell execution caused an exception

Input:
mcmc.run(
    age=torch.tensor(dset.AgeScaled.values, dtype=torch.float), 
    divorce=torch.tensor(dset.DivorceScaled.values, dtype=torch.float)
)
mcmc.summary()
samples_2 = mcmc.get_samples()

Traceback:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-22-596822a70e0a> in <module>
      1 mcmc.run(
      2     age=torch.tensor(dset.AgeScaled.values, dtype=torch.float),
----> 3     divorce=torch.tensor(dset.DivorceScaled.values, dtype=torch.float)
      4 )
      5 mcmc.summary()

/opt/hostedtoolcache/Python/3.6.15/x64/lib/python3.6/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

/opt/hostedtoolcache/Python/3.6.15/x64/lib/python3.6/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
    561             # requires_grad", which happens with `jit_compile` under PyTorch 1.7
    562             args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args]
--> 563             for x, chain_id in self.sampler.run(*args, **kwargs):
    564                 if num_samples[chain_id] == 0:
    565                     num_samples[chain_id] += 1

/opt/hostedtoolcache/Python/3.6.15/x64/lib/python3.6/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
    228                 i if self.num_chains > 1 else None,
    229                 *args,
--> 230                 **kwargs
    231             ):
    232                 yield sample, i  # sample, chain_id

/opt/hostedtoolcache/Python/3.6.15/x64/lib/python3.6/site-packages/pyro/infer/mcmc/api.py in _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs)
    142 
    143 def _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs):
--> 144     kernel.setup(warmup_steps, *args, **kwargs)
    145     params = kernel.initial_params
    146     save_params = getattr(kernel, "save_params", sorted(params))

/opt/hostedtoolcache/Python/3.6.15/x64/lib/python3.6/site-packages/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
    323         self._warmup_steps = warmup_steps
    324         if self.model is not None:
--> 325             self._initialize_model_properties(args, kwargs)
    326         if self.initial_params:
    327             z = {k: v.detach() for k, v in self.initial_params.items()}

/opt/hostedtoolcache/Python/3.6.15/x64/lib/python3.6/site-packages/pyro/infer/mcmc/hmc.py in _initialize_model_properties(self, model_args, model_kwargs)
    267             skip_jit_warnings=self._ignore_jit_warnings,
    268             init_strategy=self._init_strategy,
--> 269             initial_params=self._initial_params,
    270         )
    271         self.potential_fn = potential_fn

/opt/hostedtoolcache/Python/3.6.15/x64/lib/python3.6/site-packages/pyro/infer/mcmc/util.py in initialize_model(model, model_args, model_kwargs, transforms, max_plate_nesting, jit_compile, jit_options, skip_jit_warnings, num_chains, init_strategy, initial_params)
    462 
    463     if initial_params is None:
--> 464         prototype_params = {k: transforms[k](v) for k, v in prototype_samples.items()}
    465         # Note that we deliberately do not exercise jit compilation here so as to
    466         # enable potential_fn to be picklable (a torch._C.Function cannot be pickled).

/opt/hostedtoolcache/Python/3.6.15/x64/lib/python3.6/site-packages/pyro/infer/mcmc/util.py in <dictcomp>(.0)
    462 
    463     if initial_params is None:
--> 464         prototype_params = {k: transforms[k](v) for k, v in prototype_samples.items()}
    465         # Note that we deliberately do not exercise jit compilation here so as to
    466         # enable potential_fn to be picklable (a torch._C.Function cannot be pickled).

KeyError: 'bA'

@eb8680
Copy link
Member

eb8680 commented Jan 20, 2022

I have to run the below cell before running the cells for Model 2 and Model 3.

There may be some difference between the MCMC objects in Pyro and NumPyro. Could you try creating separate instances for each model rather than reusing the same one?

@eb8680
Copy link
Member

eb8680 commented Jan 20, 2022

The custom predict_fn in your updated "Predictive Utility With Effect Handlers" section can be simplified - the pyro.plate("samples", 2000) you added plays the same role as Jax's vmap in this case, so you should be able to write something like this in cell 12 that more directly follows the original structure:

def predict(post_samples, model, *args, **kwargs):
    conditioned_model = poutine.condition(model, post_samples)
    model_trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)
    return model_trace.nodes["obs"]["value"]

def predict_fn(post_samples):
    with pyro.plate("samples", num_samples):
        return predict(post_samples, model, marriage=torch.tensor(dset.MarriageScaled.values, dtype=torch.float))

@arijeetchatterjee
Copy link
Author

The custom predict_fn in your updated "Predictive Utility With Effect Handlers" section can be simplified - the pyro.plate("samples", 2000) you added plays the same role as Jax's vmap in this case, so you should be able to write something like this in cell 12 that more directly follows the original structure:

def predict(post_samples, model, *args, **kwargs):
    conditioned_model = poutine.condition(model, post_samples)
    model_trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)
    return model_trace.nodes["obs"]["value"]

def predict_fn(post_samples):
    with pyro.plate("samples", num_samples):
        return predict(post_samples, model, marriage=torch.tensor(dset.MarriageScaled.values, dtype=torch.float))

@eb8680 Thanks for this suggestion.

@arijeetchatterjee
Copy link
Author

  • cd tutorial && make html

@eb8680 When I run cd tutorial && make html, the build stops at waiting for workers.... Then after some time I get the error as shown below. Can you let me know what I need to do for this? Thanks.

Running Sphinx v4.4.0
building [mo]: targets for 0 po files that are out of date
building [html]: targets for 78 source files that are out of date
updating environment: [new config] 78 added, 0 changed, 0 removed
reading sources... [100%] tensor_shapes .. working_memory                                                                     
waiting for workers...

Warning, treated as error:
/Users/arijeetchatterjee/Documents/github_personal_projects/pyro/tutorial/source/ss-vae.ipynb:7:Unexpected indentation.
make: *** [html] Error 2

@eb8680
Copy link
Member

eb8680 commented Jan 25, 2022

Can you let me know what I need to do for this?

I can't reproduce your error, but you can tell Sphinx not to treat warnings as errors by overriding the SPHINXOPTS environment variable used in our Makefile:

SPHINXOPTS="-E -j 8" make html

@arijeetchatterjee
Copy link
Author

arijeetchatterjee commented Jan 29, 2022

Can you let me know what I need to do for this?

I can't reproduce your error, but you can tell Sphinx not to treat warnings as errors by overriding the SPHINXOPTS environment variable used in our Makefile:

SPHINXOPTS="-E -j 8" make html

Thanks @eb8680
Sorry about the delay. I was facing some issues with make html not working locally, but that's solved now (made a couple of changes in conf.py along with above suggested change for SPHINXOPTS in the Makefile).
I have completed the suggested changes and now when I run make html to render the HTML, the example appears under the "other inference algorithms" header on the sidebar. The generated HTML of the notebook does not show any visual errors.
Can I commit only the updated version of the notebook for a review?

[UPDATED] I have committed the notebook with the changes as mentioned above. Please take a look. Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants