Skip to content

Commit

Permalink
Convenient function to access inference methods and kwargs (#795)
Browse files Browse the repository at this point in the history
* add inference_methods class to obtain names of methods and kwargs

* re-run notebook

* update notebook to include new methods

* convienent methods for getting inference names and kwargs

* Fix `get_model_covariates()` utility function (#801)

* Support PyMC 5.13 and fix bayeux related issues (#803)

* run black to fix formatting

* add test to check for inference method names

* test get_kwargs method of InferenceMethods class

---------

Co-authored-by: Tomás Capretto <tomicapretto@gmail.com>
  • Loading branch information
GStechschulte and tomicapretto committed Apr 15, 2024
1 parent 793be6a commit b5aefcf
Show file tree
Hide file tree
Showing 7 changed files with 1,286 additions and 996 deletions.
3 changes: 2 additions & 1 deletion bambi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pymc import math

from .backend import PyMCModel
from .backend import inference_methods, PyMCModel
from .config import config
from .data import clear_data_home, load_data
from .families import Family, Likelihood, Link
Expand All @@ -25,6 +25,7 @@
"Formula",
"clear_data_home",
"config",
"inference_methods",
"load_data",
"math",
]
Expand Down
3 changes: 2 additions & 1 deletion bambi/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .pymc import PyMCModel
from .inference_methods import inference_methods

__all__ = ["PyMCModel"]
__all__ = ["inference_methods", "PyMCModel"]
119 changes: 119 additions & 0 deletions bambi/backend/inference_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import importlib
import inspect
import operator

import pymc as pm


class InferenceMethods:
"""Obtain a dictionary of available inference methods for Bambi
models and or the default kwargs of each inference method.
"""

def __init__(self):
# In order to access inference methods, a bayeux model must be initialized
self.bayeux_model = bayeux_model()
self.bayeux_methods = self._get_bayeux_methods(bayeux_model())
self.pymc_methods = self._pymc_methods()

def _get_bayeux_methods(self, model):
# Bambi only supports bayeux MCMC methods
mcmc_methods = model.methods.get("mcmc")
return {"mcmc": mcmc_methods}

def _pymc_methods(self):
return {"mcmc": ["mcmc"], "vi": ["vi"]}

def _remove_parameters(self, fn_signature_dict):
# Remove 'pm.sample' parameters that are irrelevant for Bambi users
params_to_remove = [
"progressbar",
"progressbar_theme",
"var_names",
"nuts_sampler",
"return_inferencedata",
"idata_kwargs",
"callback",
"mp_ctx",
"model",
]
return {k: v for k, v in fn_signature_dict.items() if k not in params_to_remove}

def get_kwargs(self, method):
"""Get the default kwargs for a given inference method.
Parameters
----------
method : str
The name of the inference method.
Returns
-------
dict
The default kwargs for the inference method.
"""
if method in self.bayeux_methods.get("mcmc"):
bx_method = operator.attrgetter(method)(
self.bayeux_model.mcmc # pylint: disable=no-member
)
return bx_method.get_kwargs()
elif method in self.pymc_methods.get("mcmc"):
return self._remove_parameters(get_default_signature(pm.sample))
elif method in self.pymc_methods.get("vi"):
return get_default_signature(pm.ADVI.fit)
else:
raise ValueError(
f"Inference method '{method}' not found in the list of available"
" methods. Use `bmb.inference_methods.names` to list the available methods."
)

@property
def names(self):
return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods}


def bayeux_model():
"""Dummy bayeux model for obtaining inference methods.
A dummy model is needed because algorithms are dynamically determined at
runtime, based on the libraries that are installed. A model can give
programmatic access to the available algorithms via the `methods` attribute.
Returns
-------
bayeux.Model
A dummy model with a simple quadratic likelihood function.
"""
if importlib.util.find_spec("bayeux") is None:
return {"mcmc": []}

import bayeux as bx # pylint: disable=import-outside-toplevel

return bx.Model(lambda x: -(x**2), 0.0)


def get_default_signature(fn):
"""Get the default parameter values of a function.
This function inspects the signature of the provided function and returns
a dictionary containing the default values of its parameters.
Parameters
----------
fn : callable
The function for which default argument values are to be retrieved.
Returns
-------
dict
A dictionary mapping argument names to their default values.
"""
defaults = {}
for key, val in inspect.signature(fn).parameters.items():
if val.default is not inspect.Signature.empty:
defaults[key] = val.default
return defaults


inference_methods = InferenceMethods()
32 changes: 4 additions & 28 deletions bambi/backend/pymc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import functools
import importlib
import logging
import operator
import traceback
Expand All @@ -14,6 +13,7 @@
import pytensor.tensor as pt
from pytensor.tensor.special import softmax

from bambi.backend.inference_methods import inference_methods
from bambi.backend.links import cloglog, identity, inverse_squared, logit, probit, arctan_2
from bambi.backend.model_components import ConstantComponent, DistributionalComponent
from bambi.utils import get_aliased_name
Expand Down Expand Up @@ -47,8 +47,8 @@ def __init__(self):
self.model = None
self.spec = None
self.components = {}
self.bayeux_methods = _get_bayeux_methods()
self.pymc_methods = {"mcmc": ["mcmc"], "vi": ["vi"]}
self.bayeux_methods = inference_methods.names["bayeux"]
self.pymc_methods = inference_methods.names["pymc"]

def build(self, spec):
"""Compile the PyMC model from an abstract model specification.
Expand Down Expand Up @@ -348,8 +348,7 @@ def _run_laplace(self, draws, omit_offsets, include_mean):
Mainly for pedagogical use, provides reasonable results for approximately
Gaussian posteriors. The approximation can be very poor for some models
like hierarchical ones. Use ``mcmc``, ``vi``, or JAX based MCMC methods
for better approximations.
like hierarchical ones. Use MCMC or VI methods for better approximations.
Parameters
----------
Expand Down Expand Up @@ -398,10 +397,6 @@ def constant_components(self):
def distributional_components(self):
return {k: v for k, v in self.components.items() if isinstance(v, DistributionalComponent)}

@property
def inference_methods(self):
return {"pymc": self.pymc_methods, "bayeux": self.bayeux_methods}


def _posterior_samples_to_idata(samples, model):
"""Create InferenceData from samples.
Expand Down Expand Up @@ -441,22 +436,3 @@ def _posterior_samples_to_idata(samples, model):

idata = pm.to_inference_data(pm.backends.base.MultiTrace([strace]), model=model)
return idata


def _get_bayeux_methods():
"""Gets a dictionary of usable bayeux methods if the bayeux package is installed
within the user's environment.
Returns
-------
dict
A dict where the keys are the module names and the values are the methods
available in that module.
"""
if importlib.util.find_spec("bayeux") is None:
return {"mcmc": []}

import bayeux as bx # pylint: disable=import-outside-toplevel

# Dummy log density to get access to all methods
return bx.Model(lambda x: -(x**2), 0.0).methods
4 changes: 2 additions & 2 deletions bambi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def fit(
Finally, ``"laplace"``, in which case a Laplace approximation is used and is not
recommended other than for pedagogical use.
To get a list of JAX based inference methods, call
``model.backend.inference_methods['bayeux']``. This will return a dictionary of the
``bmb.inference_methods.names['bayeux']``. This will return a dictionary of the
available methods such as ``blackjax_nuts``, ``numpyro_nuts``, among others.
init : str
Initialization method. Defaults to ``"auto"``. The available methods are:
Expand Down Expand Up @@ -307,7 +307,7 @@ def fit(
-------
An ArviZ ``InferenceData`` instance if inference_method is ``"mcmc"`` (default),
"laplace", or one of the MCMC methods in
``model.backend.inference_methods['bayeux']['mcmc]``.
``bmb.inference_methods.names['bayeux']['mcmc]``.
An ``Approximation`` object if ``"vi"``.
"""
method = kwargs.pop("method", None)
Expand Down

0 comments on commit b5aefcf

Please sign in to comment.